feat: code generation for conditional and switch statements

This commit is contained in:
Brett Broadhurst 2026-03-03 16:22:35 -07:00
parent 889f678dd8
commit d6ff3a40bd
Failed to generate hash of commit
7 changed files with 476 additions and 59 deletions

View file

@ -25,7 +25,7 @@ pub const CheckError = error{
TooManyConstants,
InvalidCharacter,
NotImplemented,
};
} || anyerror;
const Scope = struct {
parent: ?*Scope,
@ -34,7 +34,7 @@ const Scope = struct {
symbol_table: std.StringHashMapUnmanaged(Symbol),
jump_stack_top: usize,
label_stack_top: usize,
exit_label: ?usize,
exit_label: usize,
pub fn deinit(scope: *Scope) void {
const gpa = scope.global.gpa;
@ -92,7 +92,7 @@ const Scope = struct {
.symbol_table = .empty,
.jump_stack_top = global.jump_stack.items.len,
.label_stack_top = global.label_stack.items.len,
.exit_label = null,
.exit_label = parent_scope.exit_label,
};
}
@ -181,13 +181,13 @@ const Scope = struct {
}
// Sets the code offset pointed to by a label.
pub fn setLabel(scope: *Scope, label_id: usize) void {
pub fn setLabel(scope: *Scope, label_index: usize) void {
const chunk = scope.chunk;
const code_offset = chunk.bytes.items.len;
const label_stack = &scope.global.label_stack;
assert(label_id <= label_stack.items.len);
assert(label_index <= label_stack.items.len);
const label_data = &label_stack.items[label_id];
const label_data = &label_stack.items[label_index];
label_data.code_offset = code_offset;
}
@ -199,6 +199,10 @@ const Scope = struct {
return jump_index;
}
pub fn setExit(scope: *Scope, label_index: usize) void {
scope.exit_label = label_index;
}
pub fn resolveLabels(scope: *Scope, start_index: usize, end_index: usize) !void {
assert(start_index <= end_index);
const jump_stack = &scope.global.jump_stack;
@ -213,7 +217,7 @@ const Scope = struct {
.absolute => label.code_offset,
};
if (jump_offset >= std.math.maxInt(u16)) {
std.debug.print("Too much code to jump over!\n", .{});
std.debug.print("Too much code to jump over! {d}\n", .{jump_offset});
return error.CompilerBug;
}
@ -310,22 +314,34 @@ fn checkUnaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckErr
}
fn checkBinaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void {
const bin_data = node.data.bin;
assert(bin_data.lhs != null and bin_data.rhs != null);
try checkExpr(scope, bin_data.lhs);
try checkExpr(scope, bin_data.rhs);
try scope.emitSimpleInst(op);
}
fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, binary_or: bool) CheckError!void {
const data = node.data.bin;
assert(data.lhs != null and data.rhs != null);
try checkExpr(scope, data.lhs);
const else_branch = try scope.emitJumpInst(if (binary_or) .jmp_t else .jmp_f);
try scope.emitSimpleInst(.pop);
try checkExpr(scope, data.rhs);
try scope.patchJump(else_branch);
try scope.emitSimpleInst(op);
}
fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void {
const data = node.data.bin;
assert(data.lhs != null and data.rhs != null);
try checkExpr(scope, data.lhs);
const else_label = try scope.makeLabel();
const jump_offset = try scope.emitJumpInst(op);
_ = try scope.makeJump(.{
.mode = .relative,
.label_index = else_label,
.code_offset = jump_offset,
});
try scope.emitSimpleInst(.pop);
const rhs_label = try scope.makeLabel();
scope.setLabel(rhs_label);
try checkExpr(scope, data.rhs);
scope.setLabel(else_label);
}
fn checkTrueLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void {
@ -338,10 +354,10 @@ fn checkFalseLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void {
fn checkNumberLiteral(scope: *Scope, node: *const Ast.Node) CheckError!void {
const lexeme = getLexemeFromNode(scope.global, node);
const number_value = try std.fmt.parseFloat(f64, lexeme);
const number_value = try std.fmt.parseUnsigned(i64, lexeme, 10);
const number_object = try Story.Object.Number.create(
scope.global.story,
.{ .floating = number_value },
.{ .integer = number_value },
);
const constant_id = try scope.makeConstant(@ptrCast(number_object));
@ -396,6 +412,18 @@ fn checkExpr(scope: *Scope, expr: ?*const Ast.Node) CheckError!void {
.divide_expr => try checkBinaryOp(scope, expr_node, .div),
.mod_expr => try checkBinaryOp(scope, expr_node, .mod),
.negate_expr => try checkUnaryOp(scope, expr_node, .neg),
.logical_and_expr => try checkLogicalOp(scope, expr_node, .jmp_f),
.logical_or_expr => try checkLogicalOp(scope, expr_node, .jmp_t),
.logical_not_expr => try checkUnaryOp(scope, expr_node, .not),
.logical_equality_expr => try checkBinaryOp(scope, expr_node, .cmp_eq),
.logical_inequality_expr => {
try scope.emitSimpleInst(.not);
try checkBinaryOp(scope, expr_node, .cmp_eq);
},
.logical_greater_expr => try checkBinaryOp(scope, expr_node, .cmp_gt),
.logical_greater_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_gte),
.logical_lesser_expr => try checkBinaryOp(scope, expr_node, .cmp_lt),
.logical_lesser_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_lte),
else => return error.NotImplemented,
}
}
@ -406,6 +434,239 @@ fn checkExprStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
try scope.emitSimpleInst(.pop);
}
fn validateSwitchProngs(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
var stmt_has_block: bool = false;
var stmt_has_else: bool = false;
const case_list = stmt.data.switch_stmt.cases;
const last_prong = case_list[case_list.len - 1];
for (case_list) |case_stmt| {
switch (case_stmt.tag) {
.block_stmt => stmt_has_block = true,
.switch_case, .if_branch => {
if (stmt_has_block) {
//return scope.fail(.expected_else, case_stmt);
}
},
.else_branch => {
if (case_stmt != last_prong) {
return scope.fail(.invalid_else_stmt, case_stmt);
}
if (stmt_has_else) {
return scope.fail(.unexpected_else_stmt, case_stmt);
}
stmt_has_else = true;
},
else => unreachable,
}
}
}
fn checkIfStmt(parent_scope: *Scope, stmt: *const Ast.Node) CheckError!void {
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const case_list = stmt.data.switch_stmt.cases;
const eval_expr = stmt.data.switch_stmt.condition_expr;
if (eval_expr) |expr_node| {
try validateSwitchProngs(&child_scope, stmt);
const first_prong = case_list[0];
const last_prong = case_list[case_list.len - 1];
const then_stmt: *const Ast.Node = first_prong;
const else_stmt: ?*const Ast.Node = if (first_prong == last_prong)
null
else
last_prong;
try checkExpr(&child_scope, expr_node);
const else_label = try child_scope.makeLabel();
const end_label = try child_scope.makeLabel();
const then_br = try child_scope.emitJumpInst(.jmp_f);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = else_label,
.code_offset = then_br,
});
try child_scope.emitSimpleInst(.pop);
try checkBlockStmt(&child_scope, then_stmt);
const else_br = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = end_label,
.code_offset = else_br,
});
child_scope.setLabel(else_label);
try child_scope.emitSimpleInst(.pop);
if (else_stmt) |else_node| {
const block_stmt = else_node.data.bin.rhs;
try checkBlockStmt(&child_scope, block_stmt);
}
child_scope.setLabel(end_label);
} else {
return child_scope.fail(.expected_expression, stmt);
}
}
fn checkMultiIfStmt(
parent_scope: *Scope,
stmt: *const Ast.Node,
) CheckError!void {
const gpa = parent_scope.global.gpa;
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const exit_label = try child_scope.makeLabel();
child_scope.setExit(exit_label);
try validateSwitchProngs(&child_scope, stmt);
const case_list = stmt.data.switch_stmt.cases;
// NOTE: We're going to create an array of label indexes here, since we
// may create additional labels while traversing nested expressions.
var label_list: std.ArrayList(usize) = .empty;
defer label_list.deinit(gpa);
try label_list.ensureUnusedCapacity(gpa, case_list.len);
for (case_list) |case_stmt| {
const label_index = try child_scope.makeLabel();
switch (case_stmt.tag) {
.if_branch => {
const lhs = case_stmt.data.bin.lhs orelse unreachable;
try checkExpr(&child_scope, lhs);
const jump_offset = try child_scope.emitJumpInst(.jmp_t);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = label_index,
.code_offset = jump_offset,
});
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = label_index,
.code_offset = jump_offset,
});
},
else => unreachable,
}
}
for (case_list, label_list.items) |case_stmt, label_index| {
const body_stmt = case_stmt.data.bin.rhs;
switch (case_stmt.tag) {
.if_branch => {
child_scope.setLabel(label_index);
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
child_scope.setLabel(label_index);
},
else => unreachable,
}
try checkBlockStmt(&child_scope, body_stmt);
const jump_inst = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = child_scope.exit_label,
.code_offset = jump_inst,
});
}
child_scope.setLabel(child_scope.exit_label);
}
fn checkSwitchStmt(parent_scope: *Scope, switch_stmt: *const Ast.Node) CheckError!void {
const gpa = parent_scope.global.gpa;
const current_chunk = parent_scope.chunk;
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const label_index = try child_scope.makeLabel();
child_scope.setExit(label_index);
const eval_expr = switch_stmt.data.switch_stmt.condition_expr;
const case_list = switch_stmt.data.switch_stmt.cases;
// NOTE: We're going to create an array of label indexes here, since we
// may create additional labels while traversing nested expressions.
var label_list: std.ArrayList(usize) = .empty;
defer label_list.deinit(gpa);
try label_list.ensureUnusedCapacity(gpa, case_list.len);
const stack_slot = current_chunk.arity + current_chunk.locals_count;
current_chunk.locals_count += 1;
try checkExpr(&child_scope, eval_expr);
try child_scope.emitConstInst(.store, stack_slot);
try child_scope.emitSimpleInst(.pop);
for (case_list) |case_stmt| {
const case_label_index = try child_scope.makeLabel();
label_list.appendAssumeCapacity(case_label_index);
switch (case_stmt.tag) {
.switch_case => {
const case_eval_expr = case_stmt.data.bin.lhs orelse unreachable;
switch (case_eval_expr.tag) {
.number_literal, .true_literal, .false_literal => {},
else => {
return child_scope.fail(.invalid_switch_case, case_stmt);
},
}
try child_scope.emitConstInst(.load, stack_slot);
try checkExpr(&child_scope, case_eval_expr);
try child_scope.emitSimpleInst(.cmp_eq);
const jump_offset = try child_scope.emitJumpInst(.jmp_t);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = case_label_index,
.code_offset = jump_offset,
});
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = case_label_index,
.code_offset = jump_offset,
});
},
else => unreachable,
}
}
for (case_list, label_list.items) |case_stmt, case_label_index| {
child_scope.setLabel(case_label_index);
switch (case_stmt.tag) {
.switch_case => {
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {},
else => unreachable,
}
const block_stmt = case_stmt.data.bin.rhs;
try checkBlockStmt(&child_scope, block_stmt);
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = child_scope.exit_label,
.code_offset = jump_offset,
});
}
child_scope.setLabel(child_scope.exit_label);
}
fn checkInlineLogicExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
assert(expr.data.bin.lhs != null);
return checkExpr(scope, expr.data.bin.lhs);
@ -423,6 +684,9 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
try checkInlineLogicExpr(scope, child_node);
try scope.emitSimpleInst(.stream_push);
},
.if_stmt => try checkIfStmt(scope, child_node),
.multi_if_stmt => try checkMultiIfStmt(scope, child_node),
.switch_stmt => try checkSwitchStmt(scope, child_node),
else => return error.NotImplemented,
}
}
@ -430,7 +694,8 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
fn checkContentStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
const expr_node = stmt.data.bin.lhs orelse return;
return checkContentExpr(scope, expr_node);
try checkContentExpr(scope, expr_node);
try scope.emitSimpleInst(.stream_flush);
}
fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
@ -460,13 +725,6 @@ fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
return scope.fail(.unknown_identifier, lhs);
}
fn checkBlockStmt(parent_scope: *Scope, block_stmt: *const Ast.Node) CheckError!void {
var block_scope = parent_scope.makeSubBlock();
defer block_scope.deinit();
const children = block_stmt.data.list.items orelse return;
for (children) |child_stmt| try checkStmt(&block_scope, child_stmt);
}
fn checkChoiceStmt(scope: *Scope, stmt_node: *const Ast.Node) CheckError!void {
const Choice = struct {
label_index: usize,
@ -596,6 +854,23 @@ fn checkStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
}
}
fn checkBlockStmt(parent_scope: *Scope, block_stmt: ?*const Ast.Node) CheckError!void {
if (block_stmt) |stmt| {
var block_scope = parent_scope.makeSubBlock();
defer block_scope.deinit();
const children = stmt.data.list.items orelse return;
for (children) |child_stmt| try checkStmt(&block_scope, child_stmt);
} else {
const jump_offset = try parent_scope.emitJumpInst(.jmp);
_ = try parent_scope.makeJump(.{
.mode = .relative,
.label_index = parent_scope.exit_label,
.code_offset = jump_offset,
});
}
}
fn checkAnonymousKnot(parent_scope: *Scope, body: *const Ast.Node) CheckError!void {
const exit_label = try parent_scope.makeLabel();
parent_scope.exit_label = exit_label;
@ -623,7 +898,7 @@ fn checkFile(astgen: *AstGen, file: *const Ast.Node) CheckError!void {
.symbol_table = .empty,
.jump_stack_top = astgen.jump_stack.items.len,
.label_stack_top = astgen.label_stack.items.len,
.exit_label = null,
.exit_label = 0,
};
defer file_scope.deinit();
@ -665,6 +940,8 @@ pub fn generate(gpa: std.mem.Allocator, tree: *const Ast) !Story {
.is_exited = false,
.can_advance = false,
};
errdefer story.deinit();
var astgen: AstGen = .{
.gpa = gpa,
.tree = tree,