Skip to content

fix: for loops no longer execute once when condition is already met #1248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 27, 2024
199 changes: 96 additions & 103 deletions src/codegen/generators/statement_generator.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
// Copyright (c) 2020 Ghaith Hachem and Mathias Rieder
use super::{
expression_generator::{to_i1, ExpressionCodeGenerator},
expression_generator::{to_i1, ExpressionCodeGenerator, ExpressionValue},
llvm::Llvm,
};
use crate::{
codegen::debug::Debug,
codegen::{debug::DebugBuilderEnum, LlvmTypedIndex},
codegen::{
debug::{Debug, DebugBuilderEnum},
llvm_typesystem::cast_if_needed,
LlvmTypedIndex,
},
index::{ImplementationIndexEntry, Index},
resolver::{AnnotationMap, AstAnnotations, StatementAnnotation},
typesystem::DataTypeInformation,
typesystem::{get_bigger_type, DataTypeInformation, DINT_TYPE},
};
use inkwell::{
basic_block::BasicBlock,
builder::Builder,
context::Context,
values::{BasicValueEnum, FunctionValue, PointerValue},
values::{FunctionValue, PointerValue},
};
use plc_ast::{
ast::{
Expand Down Expand Up @@ -325,117 +328,107 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
body: &[AstNode],
) -> Result<(), Diagnostic> {
let (builder, current_function, context) = self.get_llvm_deps();
self.generate_assignment_statement(counter, start)?;
let condition_check = context.append_basic_block(current_function, "condition_check");
let for_body = context.append_basic_block(current_function, "for_body");
let increment_block = context.append_basic_block(current_function, "increment");
let continue_block = context.append_basic_block(current_function, "continue");

//Generate an initial jump to the for condition
builder.build_unconditional_branch(condition_check);

//Check loop condition
builder.position_at_end(condition_check);
let exp_gen = self.create_expr_generator();
let counter_statement = exp_gen.generate_expression(counter)?;

//. / and_2 \
//. / and 1 \
//. (counter_end_le && counter_start_ge) || (counter_end_ge && counter_start_le)
let or_eval = self.generate_compare_expression(counter, end, start, &exp_gen)?;
let end_ty = self.annotations.get_type_or_void(end, self.index);
let counter_ty = self.annotations.get_type_or_void(counter, self.index);
let cast_target_ty = get_bigger_type(self.index.get_type_or_panic(DINT_TYPE), counter_ty, self.index);
let cast_target_llty = self.llvm_index.find_associated_type(cast_target_ty.get_name()).unwrap();

let step_ty = by_step.as_ref().map(|it| {
self.register_debug_location(it);
self.annotations.get_type_or_void(it, self.index)
});

let eval_step = || {
step_ty.map_or_else(
|| self.llvm.create_const_numeric(&cast_target_llty, "1", SourceLocation::undefined()),
|step_ty| {
let step = exp_gen.generate_expression(by_step.as_ref().unwrap())?;
Ok(cast_if_needed!(exp_gen, cast_target_ty, step_ty, step, None))
},
)
};

builder.build_conditional_branch(to_i1(or_eval.into_int_value(), builder), for_body, continue_block);
let predicate_incrementing = context.append_basic_block(current_function, "predicate_sle");
let predicate_decrementing = context.append_basic_block(current_function, "predicate_sge");
let loop_body = context.append_basic_block(current_function, "loop");
let increment = context.append_basic_block(current_function, "increment");
let afterloop = context.append_basic_block(current_function, "continue");

//Enter the for loop
builder.position_at_end(for_body);
let body_generator = StatementCodeGenerator {
current_loop_exit: Some(continue_block),
current_loop_continue: Some(increment_block),
self.generate_assignment_statement(counter, start)?;
let counter = exp_gen.generate_lvalue(counter)?;

// generate loop predicate selector. since `STEP` can be a reference, this needs to be a runtime eval
// XXX(mhasel): IR could possibly be improved by generating phi instructions.
// Candidate for frontend optimization for builds without optimization when `STEP`
// is a compile-time constant
let is_incrementing = builder.build_int_compare(
inkwell::IntPredicate::SGT,
eval_step()?.into_int_value(),
self.llvm
.create_const_numeric(&cast_target_llty, "0", SourceLocation::undefined())?
.into_int_value(),
"is_incrementing",
);
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// generate predicates for incrementing and decrementing counters
let generate_predicate = |predicate| {
builder.position_at_end(match predicate {
inkwell::IntPredicate::SLE => predicate_incrementing,
inkwell::IntPredicate::SGE => predicate_decrementing,
_ => unreachable!(),
});

let end = exp_gen.generate_expression_value(end).unwrap();
let end_value = match end {
ExpressionValue::LValue(ptr) => builder.build_load(ptr, ""),
ExpressionValue::RValue(val) => val,
};
let counter_value = builder.build_load(counter, "");
let cmp = builder.build_int_compare(
predicate,
cast_if_needed!(exp_gen, cast_target_ty, counter_ty, counter_value, None).into_int_value(),
cast_if_needed!(exp_gen, cast_target_ty, end_ty, end_value, None).into_int_value(),
"condition",
);
builder.build_conditional_branch(cmp, loop_body, afterloop);
};
generate_predicate(inkwell::IntPredicate::SLE);
generate_predicate(inkwell::IntPredicate::SGE);

// generate loop body
builder.position_at_end(loop_body);
let body_builder = StatementCodeGenerator {
current_loop_continue: Some(increment),
current_loop_exit: Some(afterloop),
load_prefix: self.load_prefix.clone(),
load_suffix: self.load_suffix.clone(),
..*self
};
body_generator.generate_body(body)?;
builder.build_unconditional_branch(increment_block);

//Increment
builder.position_at_end(increment_block);
let expression_generator = self.create_expr_generator();
let step_by_value = by_step.as_ref().map_or_else(
|| {
self.llvm.create_const_numeric(
&counter_statement.get_type(),
"1",
SourceLocation::undefined(),
)
},
|step| {
self.register_debug_location(step);
expression_generator.generate_expression(step)
},
)?;

let next = builder.build_int_add(
counter_statement.into_int_value(),
step_by_value.into_int_value(),
"tmpVar",
body_builder.generate_body(body)?;

// increment counter
builder.build_unconditional_branch(increment);
builder.position_at_end(increment);
let counter_value = builder.build_load(counter, "");
let inc = inkwell::values::BasicValue::as_basic_value_enum(&builder.build_int_add(
eval_step()?.into_int_value(),
cast_if_needed!(exp_gen, cast_target_ty, counter_ty, counter_value, None).into_int_value(),
"next",
));
builder.build_store(
counter,
cast_if_needed!(exp_gen, counter_ty, cast_target_ty, inc, None).into_int_value(),
);

let ptr = expression_generator.generate_lvalue(counter)?;
builder.build_store(ptr, next);

//Loop back
builder.build_unconditional_branch(condition_check);

//Continue
builder.position_at_end(continue_block);

// check condition
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// continue
builder.position_at_end(afterloop);
Ok(())
}

fn generate_compare_expression(
&'a self,
counter: &AstNode,
end: &AstNode,
start: &AstNode,
exp_gen: &'a ExpressionCodeGenerator,
) -> Result<BasicValueEnum<'a>, Diagnostic> {
let bool_id = self.annotations.get_bool_id();
let counter_end_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
end.clone(),
bool_id,
);
let counter_start_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
start.clone(),
bool_id,
);
let counter_end_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
end.clone(),
bool_id,
);
let counter_start_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
start.clone(),
bool_id,
);
let and_1 =
AstFactory::create_binary_expression(counter_end_le, Operator::And, counter_start_ge, bool_id);
let and_2 =
AstFactory::create_binary_expression(counter_end_ge, Operator::And, counter_start_le, bool_id);
let or = AstFactory::create_binary_expression(and_1, Operator::Or, and_2, bool_id);

self.register_debug_location(&or);
let or_eval = exp_gen.generate_expression(&or)?;
Ok(or_eval)
}

/// genertes a case statement
///
/// CASE selector OF
Expand Down
139 changes: 139 additions & 0 deletions src/codegen/tests/code_gen_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,145 @@ fn for_statement_with_references_steps_test() {
insta::assert_snapshot!(result);
}

#[test]
fn for_statement_with_binary_expressions() {
let result = codegen(
"
PROGRAM prg
VAR
step: DINT;
x : DINT;
y : DINT;
z : DINT;
END_VAR
FOR x := y + 1 TO z - 2 BY step * 3 DO
x;
END_FOR
END_PROGRAM
",
);

insta::assert_snapshot!(result, @r###"
; ModuleID = 'main'
source_filename = "main"

%prg = type { i32, i32, i32, i32 }

@prg_instance = global %prg zeroinitializer, section "var-$RUSTY$prg_instance:r4i32i32i32i32"

define void @prg(%prg* %0) section "fn-$RUSTY$prg:v" {
entry:
%step = getelementptr inbounds %prg, %prg* %0, i32 0, i32 0
%x = getelementptr inbounds %prg, %prg* %0, i32 0, i32 1
%y = getelementptr inbounds %prg, %prg* %0, i32 0, i32 2
%z = getelementptr inbounds %prg, %prg* %0, i32 0, i32 3
%load_y = load i32, i32* %y, align 4
%tmpVar = add i32 %load_y, 1
store i32 %tmpVar, i32* %x, align 4
%load_step = load i32, i32* %step, align 4
%tmpVar1 = mul i32 %load_step, 3
%is_incrementing = icmp sgt i32 %tmpVar1, 0
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge

predicate_sle: ; preds = %increment, %entry
%load_z = load i32, i32* %z, align 4
%tmpVar2 = sub i32 %load_z, 2
%1 = load i32, i32* %x, align 4
%condition = icmp sle i32 %1, %tmpVar2
br i1 %condition, label %loop, label %continue

predicate_sge: ; preds = %increment, %entry
%load_z3 = load i32, i32* %z, align 4
%tmpVar4 = sub i32 %load_z3, 2
%2 = load i32, i32* %x, align 4
%condition5 = icmp sge i32 %2, %tmpVar4
br i1 %condition5, label %loop, label %continue

loop: ; preds = %predicate_sge, %predicate_sle
%load_x = load i32, i32* %x, align 4
br label %increment

increment: ; preds = %loop
%3 = load i32, i32* %x, align 4
%load_step6 = load i32, i32* %step, align 4
%tmpVar7 = mul i32 %load_step6, 3
%next = add i32 %tmpVar7, %3
store i32 %next, i32* %x, align 4
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge

continue: ; preds = %predicate_sge, %predicate_sle
ret void
}
"###);
}

#[test]
fn for_statement_type_casting() {
let result = codegen(
"FUNCTION main
VAR
a: USINT;
b: INT := 1;
END_VAR
FOR a := 0 TO 10 BY b DO
b := b * 3;
END_FOR
END_FUNCTION",
);
insta::assert_snapshot!(result, @r###"
; ModuleID = 'main'
source_filename = "main"

define void @main() section "fn-$RUSTY$main:v" {
entry:
%a = alloca i8, align 1
%b = alloca i16, align 2
store i8 0, i8* %a, align 1
store i16 1, i16* %b, align 2
store i8 0, i8* %a, align 1
%load_b = load i16, i16* %b, align 2
%0 = trunc i16 %load_b to i8
%1 = sext i8 %0 to i32
%is_incrementing = icmp sgt i32 %1, 0
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge

predicate_sle: ; preds = %increment, %entry
%2 = load i8, i8* %a, align 1
%3 = zext i8 %2 to i32
%condition = icmp sle i32 %3, 10
br i1 %condition, label %loop, label %continue

predicate_sge: ; preds = %increment, %entry
%4 = load i8, i8* %a, align 1
%5 = zext i8 %4 to i32
%condition1 = icmp sge i32 %5, 10
br i1 %condition1, label %loop, label %continue

loop: ; preds = %predicate_sge, %predicate_sle
%load_b2 = load i16, i16* %b, align 2
%6 = sext i16 %load_b2 to i32
%tmpVar = mul i32 %6, 3
%7 = trunc i32 %tmpVar to i16
store i16 %7, i16* %b, align 2
br label %increment

increment: ; preds = %loop
%8 = load i8, i8* %a, align 1
%load_b3 = load i16, i16* %b, align 2
%9 = trunc i16 %load_b3 to i8
%10 = sext i8 %9 to i32
%11 = zext i8 %8 to i32
%next = add i32 %10, %11
%12 = trunc i32 %next to i8
store i8 %12, i8* %a, align 1
br i1 %is_incrementing, label %predicate_sle, label %predicate_sge

continue: ; preds = %predicate_sge, %predicate_sle
ret void
}
"###);
}

#[test]
fn while_statement() {
let result = codegen(
Expand Down
Loading
Loading