From d6ff3a40bde184da79df451ad261460f2002df42 Mon Sep 17 00:00:00 2001 From: Brett Broadhurst Date: Tue, 3 Mar 2026 16:22:35 -0700 Subject: [PATCH] feat: code generation for conditional and switch statements --- src/Ast.zig | 8 +- src/Ast/Render.zig | 9 +- src/AstGen.zig | 335 +++++++++++++++++++++++++++++++++++++++---- src/Parse.zig | 4 +- src/Story.zig | 124 ++++++++++++++-- src/Story/object.zig | 53 +++++-- src/tokenizer.zig | 2 +- 7 files changed, 476 insertions(+), 59 deletions(-) diff --git a/src/Ast.zig b/src/Ast.zig index 078609e..f711479 100644 --- a/src/Ast.zig +++ b/src/Ast.zig @@ -104,7 +104,7 @@ pub const Node = struct { }, switch_stmt: struct { condition_expr: ?*Node, - cases: ?[]*Node, + cases: []*Node, }, knot_decl: struct { prototype: *Node, @@ -175,7 +175,7 @@ pub const Node = struct { tag: Tag, loc: Span, condition_expr: ?*Node, - cases_list: ?[]*Node, + cases_list: []*Node, ) !*Node { const node = try Node.create(gpa, tag, loc); node.data = .{ @@ -227,6 +227,10 @@ pub const Error = struct { invalid_lvalue, too_many_arguments, too_many_parameters, + + invalid_else_stmt, + unexpected_else_stmt, + invalid_switch_case, }; }; diff --git a/src/Ast/Render.zig b/src/Ast/Render.zig index 7bdbb48..6e337a2 100644 --- a/src/Ast/Render.zig +++ b/src/Ast/Render.zig @@ -172,6 +172,9 @@ fn renderError(r: *Render, writer: *std.Io.Writer, err: Ast.Error) !void { .invalid_lvalue => try renderErrorf(r, writer, err, "invalid lvalue for assignment"), .too_many_arguments => try renderErrorf(r, writer, err, "too many arguments to '{s}'"), .too_many_parameters => try renderErrorf(r, writer, err, "too many parameters defined for '{s}'"), + .unexpected_else_stmt => try renderErrorf(r, writer, err, "unexpected else stmt"), + .invalid_else_stmt => try renderErrorf(r, writer, err, "invalid else stmt"), + .invalid_switch_case => try renderErrorf(r, writer, err, "invalid switch case expression"), } } @@ -491,9 +494,9 @@ fn renderAstWalk( if (expr) |n| try children.append(r.gpa, n); const list = node.data.switch_stmt.cases; - if (list) |items| for (items) |n| { - try children.append(r.gpa, n); - }; + for (list) |case_stmt| { + try children.append(r.gpa, case_stmt); + } }, .inline_logic_expr => { const lhs = node.data.bin.lhs; diff --git a/src/AstGen.zig b/src/AstGen.zig index e84df84..3231576 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -25,7 +25,7 @@ pub const CheckError = error{ TooManyConstants, InvalidCharacter, NotImplemented, -}; +} || anyerror; const Scope = struct { parent: ?*Scope, @@ -34,7 +34,7 @@ const Scope = struct { symbol_table: std.StringHashMapUnmanaged(Symbol), jump_stack_top: usize, label_stack_top: usize, - exit_label: ?usize, + exit_label: usize, pub fn deinit(scope: *Scope) void { const gpa = scope.global.gpa; @@ -92,7 +92,7 @@ const Scope = struct { .symbol_table = .empty, .jump_stack_top = global.jump_stack.items.len, .label_stack_top = global.label_stack.items.len, - .exit_label = null, + .exit_label = parent_scope.exit_label, }; } @@ -181,13 +181,13 @@ const Scope = struct { } // Sets the code offset pointed to by a label. - pub fn setLabel(scope: *Scope, label_id: usize) void { + pub fn setLabel(scope: *Scope, label_index: usize) void { const chunk = scope.chunk; const code_offset = chunk.bytes.items.len; const label_stack = &scope.global.label_stack; - assert(label_id <= label_stack.items.len); + assert(label_index <= label_stack.items.len); - const label_data = &label_stack.items[label_id]; + const label_data = &label_stack.items[label_index]; label_data.code_offset = code_offset; } @@ -199,6 +199,10 @@ const Scope = struct { return jump_index; } + pub fn setExit(scope: *Scope, label_index: usize) void { + scope.exit_label = label_index; + } + pub fn resolveLabels(scope: *Scope, start_index: usize, end_index: usize) !void { assert(start_index <= end_index); const jump_stack = &scope.global.jump_stack; @@ -213,7 +217,7 @@ const Scope = struct { .absolute => label.code_offset, }; if (jump_offset >= std.math.maxInt(u16)) { - std.debug.print("Too much code to jump over!\n", .{}); + std.debug.print("Too much code to jump over! {d}\n", .{jump_offset}); return error.CompilerBug; } @@ -310,22 +314,34 @@ fn checkUnaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckErr } fn checkBinaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void { - const bin_data = node.data.bin; - assert(bin_data.lhs != null and bin_data.rhs != null); - try checkExpr(scope, bin_data.lhs); - try checkExpr(scope, bin_data.rhs); - try scope.emitSimpleInst(op); -} - -fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, binary_or: bool) CheckError!void { const data = node.data.bin; assert(data.lhs != null and data.rhs != null); try checkExpr(scope, data.lhs); - const else_branch = try scope.emitJumpInst(if (binary_or) .jmp_t else .jmp_f); - try scope.emitSimpleInst(.pop); try checkExpr(scope, data.rhs); - try scope.patchJump(else_branch); + try scope.emitSimpleInst(op); +} + +fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void { + const data = node.data.bin; + assert(data.lhs != null and data.rhs != null); + + try checkExpr(scope, data.lhs); + + const else_label = try scope.makeLabel(); + const jump_offset = try scope.emitJumpInst(op); + _ = try scope.makeJump(.{ + .mode = .relative, + .label_index = else_label, + .code_offset = jump_offset, + }); + + try scope.emitSimpleInst(.pop); + const rhs_label = try scope.makeLabel(); + scope.setLabel(rhs_label); + + try checkExpr(scope, data.rhs); + scope.setLabel(else_label); } fn checkTrueLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void { @@ -338,10 +354,10 @@ fn checkFalseLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void { fn checkNumberLiteral(scope: *Scope, node: *const Ast.Node) CheckError!void { const lexeme = getLexemeFromNode(scope.global, node); - const number_value = try std.fmt.parseFloat(f64, lexeme); + const number_value = try std.fmt.parseUnsigned(i64, lexeme, 10); const number_object = try Story.Object.Number.create( scope.global.story, - .{ .floating = number_value }, + .{ .integer = number_value }, ); const constant_id = try scope.makeConstant(@ptrCast(number_object)); @@ -396,6 +412,18 @@ fn checkExpr(scope: *Scope, expr: ?*const Ast.Node) CheckError!void { .divide_expr => try checkBinaryOp(scope, expr_node, .div), .mod_expr => try checkBinaryOp(scope, expr_node, .mod), .negate_expr => try checkUnaryOp(scope, expr_node, .neg), + .logical_and_expr => try checkLogicalOp(scope, expr_node, .jmp_f), + .logical_or_expr => try checkLogicalOp(scope, expr_node, .jmp_t), + .logical_not_expr => try checkUnaryOp(scope, expr_node, .not), + .logical_equality_expr => try checkBinaryOp(scope, expr_node, .cmp_eq), + .logical_inequality_expr => { + try scope.emitSimpleInst(.not); + try checkBinaryOp(scope, expr_node, .cmp_eq); + }, + .logical_greater_expr => try checkBinaryOp(scope, expr_node, .cmp_gt), + .logical_greater_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_gte), + .logical_lesser_expr => try checkBinaryOp(scope, expr_node, .cmp_lt), + .logical_lesser_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_lte), else => return error.NotImplemented, } } @@ -406,6 +434,239 @@ fn checkExprStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { try scope.emitSimpleInst(.pop); } +fn validateSwitchProngs(scope: *Scope, stmt: *const Ast.Node) CheckError!void { + var stmt_has_block: bool = false; + var stmt_has_else: bool = false; + const case_list = stmt.data.switch_stmt.cases; + const last_prong = case_list[case_list.len - 1]; + for (case_list) |case_stmt| { + switch (case_stmt.tag) { + .block_stmt => stmt_has_block = true, + .switch_case, .if_branch => { + if (stmt_has_block) { + //return scope.fail(.expected_else, case_stmt); + } + }, + .else_branch => { + if (case_stmt != last_prong) { + return scope.fail(.invalid_else_stmt, case_stmt); + } + if (stmt_has_else) { + return scope.fail(.unexpected_else_stmt, case_stmt); + } + stmt_has_else = true; + }, + else => unreachable, + } + } +} + +fn checkIfStmt(parent_scope: *Scope, stmt: *const Ast.Node) CheckError!void { + var child_scope = parent_scope.makeSubBlock(); + defer child_scope.deinit(); + + const case_list = stmt.data.switch_stmt.cases; + const eval_expr = stmt.data.switch_stmt.condition_expr; + if (eval_expr) |expr_node| { + try validateSwitchProngs(&child_scope, stmt); + + const first_prong = case_list[0]; + const last_prong = case_list[case_list.len - 1]; + const then_stmt: *const Ast.Node = first_prong; + const else_stmt: ?*const Ast.Node = if (first_prong == last_prong) + null + else + last_prong; + + try checkExpr(&child_scope, expr_node); + + const else_label = try child_scope.makeLabel(); + const end_label = try child_scope.makeLabel(); + const then_br = try child_scope.emitJumpInst(.jmp_f); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = else_label, + .code_offset = then_br, + }); + try child_scope.emitSimpleInst(.pop); + try checkBlockStmt(&child_scope, then_stmt); + + const else_br = try child_scope.emitJumpInst(.jmp); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = end_label, + .code_offset = else_br, + }); + child_scope.setLabel(else_label); + try child_scope.emitSimpleInst(.pop); + + if (else_stmt) |else_node| { + const block_stmt = else_node.data.bin.rhs; + try checkBlockStmt(&child_scope, block_stmt); + } + child_scope.setLabel(end_label); + } else { + return child_scope.fail(.expected_expression, stmt); + } +} + +fn checkMultiIfStmt( + parent_scope: *Scope, + stmt: *const Ast.Node, +) CheckError!void { + const gpa = parent_scope.global.gpa; + var child_scope = parent_scope.makeSubBlock(); + defer child_scope.deinit(); + + const exit_label = try child_scope.makeLabel(); + child_scope.setExit(exit_label); + + try validateSwitchProngs(&child_scope, stmt); + const case_list = stmt.data.switch_stmt.cases; + + // NOTE: We're going to create an array of label indexes here, since we + // may create additional labels while traversing nested expressions. + var label_list: std.ArrayList(usize) = .empty; + defer label_list.deinit(gpa); + try label_list.ensureUnusedCapacity(gpa, case_list.len); + + for (case_list) |case_stmt| { + const label_index = try child_scope.makeLabel(); + switch (case_stmt.tag) { + .if_branch => { + const lhs = case_stmt.data.bin.lhs orelse unreachable; + try checkExpr(&child_scope, lhs); + + const jump_offset = try child_scope.emitJumpInst(.jmp_t); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = label_index, + .code_offset = jump_offset, + }); + try child_scope.emitSimpleInst(.pop); + }, + .else_branch => { + const jump_offset = try child_scope.emitJumpInst(.jmp); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = label_index, + .code_offset = jump_offset, + }); + }, + else => unreachable, + } + } + for (case_list, label_list.items) |case_stmt, label_index| { + const body_stmt = case_stmt.data.bin.rhs; + switch (case_stmt.tag) { + .if_branch => { + child_scope.setLabel(label_index); + try child_scope.emitSimpleInst(.pop); + }, + .else_branch => { + child_scope.setLabel(label_index); + }, + else => unreachable, + } + try checkBlockStmt(&child_scope, body_stmt); + + const jump_inst = try child_scope.emitJumpInst(.jmp); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = child_scope.exit_label, + .code_offset = jump_inst, + }); + } + + child_scope.setLabel(child_scope.exit_label); +} + +fn checkSwitchStmt(parent_scope: *Scope, switch_stmt: *const Ast.Node) CheckError!void { + const gpa = parent_scope.global.gpa; + const current_chunk = parent_scope.chunk; + var child_scope = parent_scope.makeSubBlock(); + defer child_scope.deinit(); + + const label_index = try child_scope.makeLabel(); + child_scope.setExit(label_index); + + const eval_expr = switch_stmt.data.switch_stmt.condition_expr; + const case_list = switch_stmt.data.switch_stmt.cases; + + // NOTE: We're going to create an array of label indexes here, since we + // may create additional labels while traversing nested expressions. + var label_list: std.ArrayList(usize) = .empty; + defer label_list.deinit(gpa); + try label_list.ensureUnusedCapacity(gpa, case_list.len); + + const stack_slot = current_chunk.arity + current_chunk.locals_count; + current_chunk.locals_count += 1; + + try checkExpr(&child_scope, eval_expr); + try child_scope.emitConstInst(.store, stack_slot); + try child_scope.emitSimpleInst(.pop); + + for (case_list) |case_stmt| { + const case_label_index = try child_scope.makeLabel(); + label_list.appendAssumeCapacity(case_label_index); + + switch (case_stmt.tag) { + .switch_case => { + const case_eval_expr = case_stmt.data.bin.lhs orelse unreachable; + switch (case_eval_expr.tag) { + .number_literal, .true_literal, .false_literal => {}, + else => { + return child_scope.fail(.invalid_switch_case, case_stmt); + }, + } + + try child_scope.emitConstInst(.load, stack_slot); + try checkExpr(&child_scope, case_eval_expr); + try child_scope.emitSimpleInst(.cmp_eq); + + const jump_offset = try child_scope.emitJumpInst(.jmp_t); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = case_label_index, + .code_offset = jump_offset, + }); + try child_scope.emitSimpleInst(.pop); + }, + .else_branch => { + const jump_offset = try child_scope.emitJumpInst(.jmp); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = case_label_index, + .code_offset = jump_offset, + }); + }, + else => unreachable, + } + } + for (case_list, label_list.items) |case_stmt, case_label_index| { + child_scope.setLabel(case_label_index); + + switch (case_stmt.tag) { + .switch_case => { + try child_scope.emitSimpleInst(.pop); + }, + .else_branch => {}, + else => unreachable, + } + + const block_stmt = case_stmt.data.bin.rhs; + try checkBlockStmt(&child_scope, block_stmt); + const jump_offset = try child_scope.emitJumpInst(.jmp); + _ = try child_scope.makeJump(.{ + .mode = .relative, + .label_index = child_scope.exit_label, + .code_offset = jump_offset, + }); + } + + child_scope.setLabel(child_scope.exit_label); +} + fn checkInlineLogicExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void { assert(expr.data.bin.lhs != null); return checkExpr(scope, expr.data.bin.lhs); @@ -423,6 +684,9 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void { try checkInlineLogicExpr(scope, child_node); try scope.emitSimpleInst(.stream_push); }, + .if_stmt => try checkIfStmt(scope, child_node), + .multi_if_stmt => try checkMultiIfStmt(scope, child_node), + .switch_stmt => try checkSwitchStmt(scope, child_node), else => return error.NotImplemented, } } @@ -430,7 +694,8 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void { fn checkContentStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { const expr_node = stmt.data.bin.lhs orelse return; - return checkContentExpr(scope, expr_node); + try checkContentExpr(scope, expr_node); + try scope.emitSimpleInst(.stream_flush); } fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { @@ -460,13 +725,6 @@ fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { return scope.fail(.unknown_identifier, lhs); } -fn checkBlockStmt(parent_scope: *Scope, block_stmt: *const Ast.Node) CheckError!void { - var block_scope = parent_scope.makeSubBlock(); - defer block_scope.deinit(); - const children = block_stmt.data.list.items orelse return; - for (children) |child_stmt| try checkStmt(&block_scope, child_stmt); -} - fn checkChoiceStmt(scope: *Scope, stmt_node: *const Ast.Node) CheckError!void { const Choice = struct { label_index: usize, @@ -596,6 +854,23 @@ fn checkStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { } } +fn checkBlockStmt(parent_scope: *Scope, block_stmt: ?*const Ast.Node) CheckError!void { + if (block_stmt) |stmt| { + var block_scope = parent_scope.makeSubBlock(); + defer block_scope.deinit(); + + const children = stmt.data.list.items orelse return; + for (children) |child_stmt| try checkStmt(&block_scope, child_stmt); + } else { + const jump_offset = try parent_scope.emitJumpInst(.jmp); + _ = try parent_scope.makeJump(.{ + .mode = .relative, + .label_index = parent_scope.exit_label, + .code_offset = jump_offset, + }); + } +} + fn checkAnonymousKnot(parent_scope: *Scope, body: *const Ast.Node) CheckError!void { const exit_label = try parent_scope.makeLabel(); parent_scope.exit_label = exit_label; @@ -623,7 +898,7 @@ fn checkFile(astgen: *AstGen, file: *const Ast.Node) CheckError!void { .symbol_table = .empty, .jump_stack_top = astgen.jump_stack.items.len, .label_stack_top = astgen.label_stack.items.len, - .exit_label = null, + .exit_label = 0, }; defer file_scope.deinit(); @@ -665,6 +940,8 @@ pub fn generate(gpa: std.mem.Allocator, tree: *const Ast) !Story { .is_exited = false, .can_advance = false, }; + errdefer story.deinit(); + var astgen: AstGen = .{ .gpa = gpa, .tree = tree, diff --git a/src/Parse.zig b/src/Parse.zig index c78da5b..d5d7fae 100644 --- a/src/Parse.zig +++ b/src/Parse.zig @@ -245,9 +245,7 @@ fn popScratch(p: *Parse, context: *const StmtContext) *Ast.Node { @panic("BUG: Scratch buffer popped when empty!"); } -fn nodeListFromScratch(p: *Parse, start_offset: usize, end_offset: usize) Error!?[]*Ast.Node { - if (start_offset >= end_offset) return null; - +fn nodeListFromScratch(p: *Parse, start_offset: usize, end_offset: usize) Error![]*Ast.Node { const span = end_offset - start_offset; assert(span > 0); diff --git a/src/Story.zig b/src/Story.zig index 8dbe0d1..74cf674 100644 --- a/src/Story.zig +++ b/src/Story.zig @@ -38,7 +38,9 @@ pub const Opcode = enum(u8) { ret, /// Pop a value off the stack, discarding it. pop, + /// Push an object representing the boolean value of "true" to the stack. true, + /// Push an object representing the boolean value of "false" to the stack. false, /// Pop two values off the stack and calculate their sum. /// The result will be pushed to the stack. @@ -60,8 +62,13 @@ pub const Opcode = enum(u8) { cmp_gt, cmp_lte, cmp_gte, + /// Jump unconditionally to the target address. jmp, + /// Jump conditionally to the target address if the boolean value at the + /// top of the stack is true. jmp_t, + /// Jump conditionally to the target address if the boolean value at the + /// top of the stack is false. jmp_f, call, divert, @@ -118,14 +125,14 @@ pub fn trace(story: *Story, writer: *std.Io.Writer, frame: *CallFrame) !void { if (slot) |object| { try story_dumper.dumpObject(object); } else { - try writer.writeAll("NULL"); + try writer.writeAll("null"); } try writer.writeAll(", "); } if (last_slot) |object| { try story_dumper.dumpObject(object); } else { - try writer.writeAll("NULL"); + try writer.writeAll("null"); } } @@ -198,7 +205,7 @@ fn setGlobal(vm: *Story, key: *const Object.String, value: *Object) !void { fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { const gpa = vm.allocator; - defer { + errdefer { vm.can_advance = false; } if (vm.isCallStackEmpty()) return .empty; @@ -215,11 +222,27 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { switch (code[frame.ip]) { .exit => { vm.is_exited = true; + vm.can_advance = false; return .empty; }, + .true => { + const true_object = try Object.Number.create(vm, .{ + .boolean = true, + }); + try vm.pushStack(@ptrCast(true_object)); + frame.ip += 1; + }, + .false => { + const false_object = try Object.Number.create(vm, .{ + .boolean = false, + }); + try vm.pushStack(@ptrCast(false_object)); + frame.ip += 1; + }, .pop => { - const object_top = vm.popStack(); - if (object_top == null) return error.InvalidArgument; + if (vm.popStack()) |_| {} else { + return error.InvalidArgument; + } frame.ip += 1; }, .add => { @@ -243,12 +266,74 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { frame.ip += 1; }, .neg => { - const arg_object = vm.peekStack(0); - if (arg_object) |arg| { + if (vm.peekStack(0)) |arg| { _ = Object.Number.negate(@ptrCast(arg)); + } else { + return error.InvalidArgument; } frame.ip += 1; }, + .not => { + if (vm.peekStack(0)) |arg| { + const value = try Object.Number.create(vm, .{ + .boolean = arg.isFalsey(), + }); + + _ = vm.popStack(); + try vm.pushStack(@ptrCast(value)); + } else { + return error.StackOverflow; + } + frame.ip += 1; + }, + .cmp_eq => { + const lhs = vm.peekStack(1) orelse return error.Bugged; + const rhs = vm.peekStack(0) orelse return error.Bugged; + const value = try Object.cmpEql(vm, @ptrCast(lhs), @ptrCast(rhs)); + _ = vm.popStack(); + _ = vm.popStack(); + try vm.pushStack(@ptrCast(value)); + + frame.ip += 1; + }, + .cmp_lt, .cmp_gt, .cmp_lte, .cmp_gte => |op| { + const lhs = vm.peekStack(1) orelse return error.Bugged; + const rhs = vm.peekStack(0) orelse return error.Bugged; + const value = try Object.Number.performLogic(vm, op, @ptrCast(lhs), @ptrCast(rhs)); + _ = vm.popStack(); + _ = vm.popStack(); + try vm.pushStack(@ptrCast(value)); + + frame.ip += 1; + }, + .jmp => { + const arg_offset = readAddress(code, frame.ip); + frame.ip += 3 + arg_offset; + }, + .jmp_t => { + const arg_offset = readAddress(code, frame.ip); + frame.ip += 3; + + if (vm.peekStack(0)) |condition| { + if (!condition.isFalsey()) { + frame.ip += arg_offset; + } + } else { + return error.InvalidArgument; + } + }, + .jmp_f => { + const arg_offset = readAddress(code, frame.ip); + frame.ip += 3; + + if (vm.peekStack(0)) |condition| { + if (condition.isFalsey()) { + frame.ip += arg_offset; + } + } else { + return error.InvalidArgument; + } + }, .load_const => { const index: u8 = @intFromEnum(code[frame.ip + 1]); const value = try vm.getConstant(frame, index); @@ -262,10 +347,11 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { frame.ip += 2; }, .store => { - const value = vm.peekStack(0); - if (value) |arg| { + if (vm.peekStack(0)) |arg| { const arg_offset: u8 = @intFromEnum(code[frame.ip + 1]); vm.setLocal(frame, arg_offset, arg); + } else { + return error.InvalidArgument; } frame.ip += 2; }, @@ -283,11 +369,12 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { const global_name = try vm.getConstant(frame, arg_offset); assert(global_name.tag == .string); - const value = vm.peekStack(0); - if (value) |arg| { + if (vm.peekStack(0)) |arg| { try vm.setGlobal(@ptrCast(global_name), arg); _ = vm.popStack(); try vm.pushStack(arg); + } else { + return error.InvalidArgument; } frame.ip += 2; }, @@ -297,20 +384,26 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { const string_object = try Object.String.fromObject(vm, object); const string_bytes = string_object.bytes[0..string_object.length]; try stream_writer.writer.writeAll(string_bytes); + } else { + return error.InvalidArgument; } frame.ip += 1; }, .stream_flush => { frame.ip += 1; + + // FIXME: This is a bit of a hack, but we have to deal with this right now. + const buffered = stream_writer.writer.buffered(); + if (buffered.len == 0) continue; return stream_writer.toArrayList(); }, .br_push => { - const arg_offset = std.mem.bytesToValue(u16, code[frame.ip + 1 ..][0..2]); + const arg_offset = readAddress(code, frame.ip); try vm.current_choices.append(gpa, .{ .text = stream_writer.toArrayList(), - .dest_offset = std.mem.bigToNative(u16, arg_offset), + .dest_offset = arg_offset, }); frame.ip += 3; }, @@ -335,6 +428,11 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { } } +fn readAddress(code: []const Story.Opcode, offset: usize) u16 { + const arg_offset = std.mem.bytesToValue(u16, code[offset + 1 ..][0..2]); + return std.mem.bigToNative(u16, arg_offset); +} + pub fn advance(story: *Story, gpa: std.mem.Allocator) ![]const u8 { var content = try story.execute(); return content.toOwnedSlice(gpa); diff --git a/src/Story/object.zig b/src/Story/object.zig index 2824bc8..b7c787e 100644 --- a/src/Story/object.zig +++ b/src/Story/object.zig @@ -50,6 +50,26 @@ pub const Object = struct { } } + pub fn cmpEql(story: *Story, lhs: *Object, rhs: *Object) !*Object { + // TODO: This is temporary + if (lhs.tag != .number or lhs.tag != rhs.tag) return error.InvalidComparison; + const result = try Object.Number.performLogic(story, .cmp_eq, @ptrCast(lhs), @ptrCast(rhs)); + return @ptrCast(result); + } + + pub fn isFalsey(obj: *Object) bool { + switch (obj.tag) { + .number => { + const number: *Object.Number = @ptrCast(obj); + switch (number.data) { + .boolean => |value| return !value, + else => return false, + } + }, + else => return false, + } + } + pub const Number = struct { base: Object, data: Data, @@ -96,7 +116,7 @@ pub const Object = struct { .floating => |value| break :v .{ .floating = value }, } }, - else => break :v .{ .integer = 1 }, + else => break :v .{ .boolean = true }, }; const obj = Object.Number.create(story, data); @@ -104,6 +124,17 @@ pub const Object = struct { return obj; } + fn logicalOp(comptime T: type, op: Story.Opcode, lhs: T, rhs: T) bool { + switch (op) { + .cmp_eq => return lhs == rhs, + .cmp_lt => return lhs < rhs, + .cmp_gt => return lhs > rhs, + .cmp_lte => return lhs <= rhs, + .cmp_gte => return lhs >= rhs, + else => unreachable, + } + } + fn arithmeticOp(comptime T: type, op: Story.Opcode, lhs: T, rhs: T) T { switch (op) { .add => return lhs + rhs, @@ -124,6 +155,17 @@ pub const Object = struct { return object; } + pub fn performLogic( + story: *Story, + op: Story.Opcode, + lhs: *Object.Number, + rhs: *Object.Number, + ) !*Object.Number { + return .create(story, .{ + .boolean = logicalOp(i64, op, lhs.data.integer, rhs.data.integer), + }); + } + pub fn performArithmetic( story: *Story, op: Story.Opcode, @@ -131,17 +173,12 @@ pub const Object = struct { rhs: *Object.Number, ) !*Object.Number { if (lhs.data == .floating or rhs.data == .floating) { - const _lhs = try Object.Number.fromObject(story, @ptrCast(lhs)); - const _rhs = try Object.Number.fromObject(story, @ptrCast(rhs)); return .create(story, .{ - .floating = arithmeticOp(f64, op, _lhs.data.floating, _rhs.data.floating), + .floating = arithmeticOp(f64, op, lhs.data.floating, rhs.data.floating), }); } - - const _lhs = try Object.Number.fromObject(story, @ptrCast(lhs)); - const _rhs = try Object.Number.fromObject(story, @ptrCast(rhs)); return .create(story, .{ - .integer = arithmeticOp(i64, op, _lhs.data.integer, _rhs.data.integer), + .integer = arithmeticOp(i64, op, lhs.data.integer, rhs.data.integer), }); } }; diff --git a/src/tokenizer.zig b/src/tokenizer.zig index c90018b..e09612d 100644 --- a/src/tokenizer.zig +++ b/src/tokenizer.zig @@ -305,7 +305,7 @@ pub const Tokenizer = struct { .bang => switch (self.buffer[self.index]) { '=' => { self.index += 1; - result.tag = .equal; + result.tag = .not_equal; }, else => result.tag = .exclaimation_mark, },