feat: code generation for conditional and switch statements

This commit is contained in:
Brett Broadhurst 2026-03-03 16:22:35 -07:00
parent 889f678dd8
commit d6ff3a40bd
Failed to generate hash of commit
7 changed files with 476 additions and 59 deletions

View file

@ -104,7 +104,7 @@ pub const Node = struct {
}, },
switch_stmt: struct { switch_stmt: struct {
condition_expr: ?*Node, condition_expr: ?*Node,
cases: ?[]*Node, cases: []*Node,
}, },
knot_decl: struct { knot_decl: struct {
prototype: *Node, prototype: *Node,
@ -175,7 +175,7 @@ pub const Node = struct {
tag: Tag, tag: Tag,
loc: Span, loc: Span,
condition_expr: ?*Node, condition_expr: ?*Node,
cases_list: ?[]*Node, cases_list: []*Node,
) !*Node { ) !*Node {
const node = try Node.create(gpa, tag, loc); const node = try Node.create(gpa, tag, loc);
node.data = .{ node.data = .{
@ -227,6 +227,10 @@ pub const Error = struct {
invalid_lvalue, invalid_lvalue,
too_many_arguments, too_many_arguments,
too_many_parameters, too_many_parameters,
invalid_else_stmt,
unexpected_else_stmt,
invalid_switch_case,
}; };
}; };

View file

@ -172,6 +172,9 @@ fn renderError(r: *Render, writer: *std.Io.Writer, err: Ast.Error) !void {
.invalid_lvalue => try renderErrorf(r, writer, err, "invalid lvalue for assignment"), .invalid_lvalue => try renderErrorf(r, writer, err, "invalid lvalue for assignment"),
.too_many_arguments => try renderErrorf(r, writer, err, "too many arguments to '{s}'"), .too_many_arguments => try renderErrorf(r, writer, err, "too many arguments to '{s}'"),
.too_many_parameters => try renderErrorf(r, writer, err, "too many parameters defined for '{s}'"), .too_many_parameters => try renderErrorf(r, writer, err, "too many parameters defined for '{s}'"),
.unexpected_else_stmt => try renderErrorf(r, writer, err, "unexpected else stmt"),
.invalid_else_stmt => try renderErrorf(r, writer, err, "invalid else stmt"),
.invalid_switch_case => try renderErrorf(r, writer, err, "invalid switch case expression"),
} }
} }
@ -491,9 +494,9 @@ fn renderAstWalk(
if (expr) |n| try children.append(r.gpa, n); if (expr) |n| try children.append(r.gpa, n);
const list = node.data.switch_stmt.cases; const list = node.data.switch_stmt.cases;
if (list) |items| for (items) |n| { for (list) |case_stmt| {
try children.append(r.gpa, n); try children.append(r.gpa, case_stmt);
}; }
}, },
.inline_logic_expr => { .inline_logic_expr => {
const lhs = node.data.bin.lhs; const lhs = node.data.bin.lhs;

View file

@ -25,7 +25,7 @@ pub const CheckError = error{
TooManyConstants, TooManyConstants,
InvalidCharacter, InvalidCharacter,
NotImplemented, NotImplemented,
}; } || anyerror;
const Scope = struct { const Scope = struct {
parent: ?*Scope, parent: ?*Scope,
@ -34,7 +34,7 @@ const Scope = struct {
symbol_table: std.StringHashMapUnmanaged(Symbol), symbol_table: std.StringHashMapUnmanaged(Symbol),
jump_stack_top: usize, jump_stack_top: usize,
label_stack_top: usize, label_stack_top: usize,
exit_label: ?usize, exit_label: usize,
pub fn deinit(scope: *Scope) void { pub fn deinit(scope: *Scope) void {
const gpa = scope.global.gpa; const gpa = scope.global.gpa;
@ -92,7 +92,7 @@ const Scope = struct {
.symbol_table = .empty, .symbol_table = .empty,
.jump_stack_top = global.jump_stack.items.len, .jump_stack_top = global.jump_stack.items.len,
.label_stack_top = global.label_stack.items.len, .label_stack_top = global.label_stack.items.len,
.exit_label = null, .exit_label = parent_scope.exit_label,
}; };
} }
@ -181,13 +181,13 @@ const Scope = struct {
} }
// Sets the code offset pointed to by a label. // Sets the code offset pointed to by a label.
pub fn setLabel(scope: *Scope, label_id: usize) void { pub fn setLabel(scope: *Scope, label_index: usize) void {
const chunk = scope.chunk; const chunk = scope.chunk;
const code_offset = chunk.bytes.items.len; const code_offset = chunk.bytes.items.len;
const label_stack = &scope.global.label_stack; const label_stack = &scope.global.label_stack;
assert(label_id <= label_stack.items.len); assert(label_index <= label_stack.items.len);
const label_data = &label_stack.items[label_id]; const label_data = &label_stack.items[label_index];
label_data.code_offset = code_offset; label_data.code_offset = code_offset;
} }
@ -199,6 +199,10 @@ const Scope = struct {
return jump_index; return jump_index;
} }
pub fn setExit(scope: *Scope, label_index: usize) void {
scope.exit_label = label_index;
}
pub fn resolveLabels(scope: *Scope, start_index: usize, end_index: usize) !void { pub fn resolveLabels(scope: *Scope, start_index: usize, end_index: usize) !void {
assert(start_index <= end_index); assert(start_index <= end_index);
const jump_stack = &scope.global.jump_stack; const jump_stack = &scope.global.jump_stack;
@ -213,7 +217,7 @@ const Scope = struct {
.absolute => label.code_offset, .absolute => label.code_offset,
}; };
if (jump_offset >= std.math.maxInt(u16)) { if (jump_offset >= std.math.maxInt(u16)) {
std.debug.print("Too much code to jump over!\n", .{}); std.debug.print("Too much code to jump over! {d}\n", .{jump_offset});
return error.CompilerBug; return error.CompilerBug;
} }
@ -310,22 +314,34 @@ fn checkUnaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckErr
} }
fn checkBinaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void { fn checkBinaryOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void {
const bin_data = node.data.bin;
assert(bin_data.lhs != null and bin_data.rhs != null);
try checkExpr(scope, bin_data.lhs);
try checkExpr(scope, bin_data.rhs);
try scope.emitSimpleInst(op);
}
fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, binary_or: bool) CheckError!void {
const data = node.data.bin; const data = node.data.bin;
assert(data.lhs != null and data.rhs != null); assert(data.lhs != null and data.rhs != null);
try checkExpr(scope, data.lhs); try checkExpr(scope, data.lhs);
const else_branch = try scope.emitJumpInst(if (binary_or) .jmp_t else .jmp_f);
try scope.emitSimpleInst(.pop);
try checkExpr(scope, data.rhs); try checkExpr(scope, data.rhs);
try scope.patchJump(else_branch); try scope.emitSimpleInst(op);
}
fn checkLogicalOp(scope: *Scope, node: *const Ast.Node, op: Story.Opcode) CheckError!void {
const data = node.data.bin;
assert(data.lhs != null and data.rhs != null);
try checkExpr(scope, data.lhs);
const else_label = try scope.makeLabel();
const jump_offset = try scope.emitJumpInst(op);
_ = try scope.makeJump(.{
.mode = .relative,
.label_index = else_label,
.code_offset = jump_offset,
});
try scope.emitSimpleInst(.pop);
const rhs_label = try scope.makeLabel();
scope.setLabel(rhs_label);
try checkExpr(scope, data.rhs);
scope.setLabel(else_label);
} }
fn checkTrueLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void { fn checkTrueLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void {
@ -338,10 +354,10 @@ fn checkFalseLiteral(scope: *Scope, _: *const Ast.Node) CheckError!void {
fn checkNumberLiteral(scope: *Scope, node: *const Ast.Node) CheckError!void { fn checkNumberLiteral(scope: *Scope, node: *const Ast.Node) CheckError!void {
const lexeme = getLexemeFromNode(scope.global, node); const lexeme = getLexemeFromNode(scope.global, node);
const number_value = try std.fmt.parseFloat(f64, lexeme); const number_value = try std.fmt.parseUnsigned(i64, lexeme, 10);
const number_object = try Story.Object.Number.create( const number_object = try Story.Object.Number.create(
scope.global.story, scope.global.story,
.{ .floating = number_value }, .{ .integer = number_value },
); );
const constant_id = try scope.makeConstant(@ptrCast(number_object)); const constant_id = try scope.makeConstant(@ptrCast(number_object));
@ -396,6 +412,18 @@ fn checkExpr(scope: *Scope, expr: ?*const Ast.Node) CheckError!void {
.divide_expr => try checkBinaryOp(scope, expr_node, .div), .divide_expr => try checkBinaryOp(scope, expr_node, .div),
.mod_expr => try checkBinaryOp(scope, expr_node, .mod), .mod_expr => try checkBinaryOp(scope, expr_node, .mod),
.negate_expr => try checkUnaryOp(scope, expr_node, .neg), .negate_expr => try checkUnaryOp(scope, expr_node, .neg),
.logical_and_expr => try checkLogicalOp(scope, expr_node, .jmp_f),
.logical_or_expr => try checkLogicalOp(scope, expr_node, .jmp_t),
.logical_not_expr => try checkUnaryOp(scope, expr_node, .not),
.logical_equality_expr => try checkBinaryOp(scope, expr_node, .cmp_eq),
.logical_inequality_expr => {
try scope.emitSimpleInst(.not);
try checkBinaryOp(scope, expr_node, .cmp_eq);
},
.logical_greater_expr => try checkBinaryOp(scope, expr_node, .cmp_gt),
.logical_greater_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_gte),
.logical_lesser_expr => try checkBinaryOp(scope, expr_node, .cmp_lt),
.logical_lesser_or_equal_expr => try checkBinaryOp(scope, expr_node, .cmp_lte),
else => return error.NotImplemented, else => return error.NotImplemented,
} }
} }
@ -406,6 +434,239 @@ fn checkExprStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
try scope.emitSimpleInst(.pop); try scope.emitSimpleInst(.pop);
} }
fn validateSwitchProngs(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
var stmt_has_block: bool = false;
var stmt_has_else: bool = false;
const case_list = stmt.data.switch_stmt.cases;
const last_prong = case_list[case_list.len - 1];
for (case_list) |case_stmt| {
switch (case_stmt.tag) {
.block_stmt => stmt_has_block = true,
.switch_case, .if_branch => {
if (stmt_has_block) {
//return scope.fail(.expected_else, case_stmt);
}
},
.else_branch => {
if (case_stmt != last_prong) {
return scope.fail(.invalid_else_stmt, case_stmt);
}
if (stmt_has_else) {
return scope.fail(.unexpected_else_stmt, case_stmt);
}
stmt_has_else = true;
},
else => unreachable,
}
}
}
fn checkIfStmt(parent_scope: *Scope, stmt: *const Ast.Node) CheckError!void {
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const case_list = stmt.data.switch_stmt.cases;
const eval_expr = stmt.data.switch_stmt.condition_expr;
if (eval_expr) |expr_node| {
try validateSwitchProngs(&child_scope, stmt);
const first_prong = case_list[0];
const last_prong = case_list[case_list.len - 1];
const then_stmt: *const Ast.Node = first_prong;
const else_stmt: ?*const Ast.Node = if (first_prong == last_prong)
null
else
last_prong;
try checkExpr(&child_scope, expr_node);
const else_label = try child_scope.makeLabel();
const end_label = try child_scope.makeLabel();
const then_br = try child_scope.emitJumpInst(.jmp_f);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = else_label,
.code_offset = then_br,
});
try child_scope.emitSimpleInst(.pop);
try checkBlockStmt(&child_scope, then_stmt);
const else_br = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = end_label,
.code_offset = else_br,
});
child_scope.setLabel(else_label);
try child_scope.emitSimpleInst(.pop);
if (else_stmt) |else_node| {
const block_stmt = else_node.data.bin.rhs;
try checkBlockStmt(&child_scope, block_stmt);
}
child_scope.setLabel(end_label);
} else {
return child_scope.fail(.expected_expression, stmt);
}
}
fn checkMultiIfStmt(
parent_scope: *Scope,
stmt: *const Ast.Node,
) CheckError!void {
const gpa = parent_scope.global.gpa;
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const exit_label = try child_scope.makeLabel();
child_scope.setExit(exit_label);
try validateSwitchProngs(&child_scope, stmt);
const case_list = stmt.data.switch_stmt.cases;
// NOTE: We're going to create an array of label indexes here, since we
// may create additional labels while traversing nested expressions.
var label_list: std.ArrayList(usize) = .empty;
defer label_list.deinit(gpa);
try label_list.ensureUnusedCapacity(gpa, case_list.len);
for (case_list) |case_stmt| {
const label_index = try child_scope.makeLabel();
switch (case_stmt.tag) {
.if_branch => {
const lhs = case_stmt.data.bin.lhs orelse unreachable;
try checkExpr(&child_scope, lhs);
const jump_offset = try child_scope.emitJumpInst(.jmp_t);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = label_index,
.code_offset = jump_offset,
});
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = label_index,
.code_offset = jump_offset,
});
},
else => unreachable,
}
}
for (case_list, label_list.items) |case_stmt, label_index| {
const body_stmt = case_stmt.data.bin.rhs;
switch (case_stmt.tag) {
.if_branch => {
child_scope.setLabel(label_index);
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
child_scope.setLabel(label_index);
},
else => unreachable,
}
try checkBlockStmt(&child_scope, body_stmt);
const jump_inst = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = child_scope.exit_label,
.code_offset = jump_inst,
});
}
child_scope.setLabel(child_scope.exit_label);
}
fn checkSwitchStmt(parent_scope: *Scope, switch_stmt: *const Ast.Node) CheckError!void {
const gpa = parent_scope.global.gpa;
const current_chunk = parent_scope.chunk;
var child_scope = parent_scope.makeSubBlock();
defer child_scope.deinit();
const label_index = try child_scope.makeLabel();
child_scope.setExit(label_index);
const eval_expr = switch_stmt.data.switch_stmt.condition_expr;
const case_list = switch_stmt.data.switch_stmt.cases;
// NOTE: We're going to create an array of label indexes here, since we
// may create additional labels while traversing nested expressions.
var label_list: std.ArrayList(usize) = .empty;
defer label_list.deinit(gpa);
try label_list.ensureUnusedCapacity(gpa, case_list.len);
const stack_slot = current_chunk.arity + current_chunk.locals_count;
current_chunk.locals_count += 1;
try checkExpr(&child_scope, eval_expr);
try child_scope.emitConstInst(.store, stack_slot);
try child_scope.emitSimpleInst(.pop);
for (case_list) |case_stmt| {
const case_label_index = try child_scope.makeLabel();
label_list.appendAssumeCapacity(case_label_index);
switch (case_stmt.tag) {
.switch_case => {
const case_eval_expr = case_stmt.data.bin.lhs orelse unreachable;
switch (case_eval_expr.tag) {
.number_literal, .true_literal, .false_literal => {},
else => {
return child_scope.fail(.invalid_switch_case, case_stmt);
},
}
try child_scope.emitConstInst(.load, stack_slot);
try checkExpr(&child_scope, case_eval_expr);
try child_scope.emitSimpleInst(.cmp_eq);
const jump_offset = try child_scope.emitJumpInst(.jmp_t);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = case_label_index,
.code_offset = jump_offset,
});
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = case_label_index,
.code_offset = jump_offset,
});
},
else => unreachable,
}
}
for (case_list, label_list.items) |case_stmt, case_label_index| {
child_scope.setLabel(case_label_index);
switch (case_stmt.tag) {
.switch_case => {
try child_scope.emitSimpleInst(.pop);
},
.else_branch => {},
else => unreachable,
}
const block_stmt = case_stmt.data.bin.rhs;
try checkBlockStmt(&child_scope, block_stmt);
const jump_offset = try child_scope.emitJumpInst(.jmp);
_ = try child_scope.makeJump(.{
.mode = .relative,
.label_index = child_scope.exit_label,
.code_offset = jump_offset,
});
}
child_scope.setLabel(child_scope.exit_label);
}
fn checkInlineLogicExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void { fn checkInlineLogicExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
assert(expr.data.bin.lhs != null); assert(expr.data.bin.lhs != null);
return checkExpr(scope, expr.data.bin.lhs); return checkExpr(scope, expr.data.bin.lhs);
@ -423,6 +684,9 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
try checkInlineLogicExpr(scope, child_node); try checkInlineLogicExpr(scope, child_node);
try scope.emitSimpleInst(.stream_push); try scope.emitSimpleInst(.stream_push);
}, },
.if_stmt => try checkIfStmt(scope, child_node),
.multi_if_stmt => try checkMultiIfStmt(scope, child_node),
.switch_stmt => try checkSwitchStmt(scope, child_node),
else => return error.NotImplemented, else => return error.NotImplemented,
} }
} }
@ -430,7 +694,8 @@ fn checkContentExpr(scope: *Scope, expr: *const Ast.Node) CheckError!void {
fn checkContentStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { fn checkContentStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
const expr_node = stmt.data.bin.lhs orelse return; const expr_node = stmt.data.bin.lhs orelse return;
return checkContentExpr(scope, expr_node); try checkContentExpr(scope, expr_node);
try scope.emitSimpleInst(.stream_flush);
} }
fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void { fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
@ -460,13 +725,6 @@ fn checkAssignStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
return scope.fail(.unknown_identifier, lhs); return scope.fail(.unknown_identifier, lhs);
} }
fn checkBlockStmt(parent_scope: *Scope, block_stmt: *const Ast.Node) CheckError!void {
var block_scope = parent_scope.makeSubBlock();
defer block_scope.deinit();
const children = block_stmt.data.list.items orelse return;
for (children) |child_stmt| try checkStmt(&block_scope, child_stmt);
}
fn checkChoiceStmt(scope: *Scope, stmt_node: *const Ast.Node) CheckError!void { fn checkChoiceStmt(scope: *Scope, stmt_node: *const Ast.Node) CheckError!void {
const Choice = struct { const Choice = struct {
label_index: usize, label_index: usize,
@ -596,6 +854,23 @@ fn checkStmt(scope: *Scope, stmt: *const Ast.Node) CheckError!void {
} }
} }
fn checkBlockStmt(parent_scope: *Scope, block_stmt: ?*const Ast.Node) CheckError!void {
if (block_stmt) |stmt| {
var block_scope = parent_scope.makeSubBlock();
defer block_scope.deinit();
const children = stmt.data.list.items orelse return;
for (children) |child_stmt| try checkStmt(&block_scope, child_stmt);
} else {
const jump_offset = try parent_scope.emitJumpInst(.jmp);
_ = try parent_scope.makeJump(.{
.mode = .relative,
.label_index = parent_scope.exit_label,
.code_offset = jump_offset,
});
}
}
fn checkAnonymousKnot(parent_scope: *Scope, body: *const Ast.Node) CheckError!void { fn checkAnonymousKnot(parent_scope: *Scope, body: *const Ast.Node) CheckError!void {
const exit_label = try parent_scope.makeLabel(); const exit_label = try parent_scope.makeLabel();
parent_scope.exit_label = exit_label; parent_scope.exit_label = exit_label;
@ -623,7 +898,7 @@ fn checkFile(astgen: *AstGen, file: *const Ast.Node) CheckError!void {
.symbol_table = .empty, .symbol_table = .empty,
.jump_stack_top = astgen.jump_stack.items.len, .jump_stack_top = astgen.jump_stack.items.len,
.label_stack_top = astgen.label_stack.items.len, .label_stack_top = astgen.label_stack.items.len,
.exit_label = null, .exit_label = 0,
}; };
defer file_scope.deinit(); defer file_scope.deinit();
@ -665,6 +940,8 @@ pub fn generate(gpa: std.mem.Allocator, tree: *const Ast) !Story {
.is_exited = false, .is_exited = false,
.can_advance = false, .can_advance = false,
}; };
errdefer story.deinit();
var astgen: AstGen = .{ var astgen: AstGen = .{
.gpa = gpa, .gpa = gpa,
.tree = tree, .tree = tree,

View file

@ -245,9 +245,7 @@ fn popScratch(p: *Parse, context: *const StmtContext) *Ast.Node {
@panic("BUG: Scratch buffer popped when empty!"); @panic("BUG: Scratch buffer popped when empty!");
} }
fn nodeListFromScratch(p: *Parse, start_offset: usize, end_offset: usize) Error!?[]*Ast.Node { fn nodeListFromScratch(p: *Parse, start_offset: usize, end_offset: usize) Error![]*Ast.Node {
if (start_offset >= end_offset) return null;
const span = end_offset - start_offset; const span = end_offset - start_offset;
assert(span > 0); assert(span > 0);

View file

@ -38,7 +38,9 @@ pub const Opcode = enum(u8) {
ret, ret,
/// Pop a value off the stack, discarding it. /// Pop a value off the stack, discarding it.
pop, pop,
/// Push an object representing the boolean value of "true" to the stack.
true, true,
/// Push an object representing the boolean value of "false" to the stack.
false, false,
/// Pop two values off the stack and calculate their sum. /// Pop two values off the stack and calculate their sum.
/// The result will be pushed to the stack. /// The result will be pushed to the stack.
@ -60,8 +62,13 @@ pub const Opcode = enum(u8) {
cmp_gt, cmp_gt,
cmp_lte, cmp_lte,
cmp_gte, cmp_gte,
/// Jump unconditionally to the target address.
jmp, jmp,
/// Jump conditionally to the target address if the boolean value at the
/// top of the stack is true.
jmp_t, jmp_t,
/// Jump conditionally to the target address if the boolean value at the
/// top of the stack is false.
jmp_f, jmp_f,
call, call,
divert, divert,
@ -118,14 +125,14 @@ pub fn trace(story: *Story, writer: *std.Io.Writer, frame: *CallFrame) !void {
if (slot) |object| { if (slot) |object| {
try story_dumper.dumpObject(object); try story_dumper.dumpObject(object);
} else { } else {
try writer.writeAll("NULL"); try writer.writeAll("null");
} }
try writer.writeAll(", "); try writer.writeAll(", ");
} }
if (last_slot) |object| { if (last_slot) |object| {
try story_dumper.dumpObject(object); try story_dumper.dumpObject(object);
} else { } else {
try writer.writeAll("NULL"); try writer.writeAll("null");
} }
} }
@ -198,7 +205,7 @@ fn setGlobal(vm: *Story, key: *const Object.String, value: *Object) !void {
fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) { fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
const gpa = vm.allocator; const gpa = vm.allocator;
defer { errdefer {
vm.can_advance = false; vm.can_advance = false;
} }
if (vm.isCallStackEmpty()) return .empty; if (vm.isCallStackEmpty()) return .empty;
@ -215,11 +222,27 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
switch (code[frame.ip]) { switch (code[frame.ip]) {
.exit => { .exit => {
vm.is_exited = true; vm.is_exited = true;
vm.can_advance = false;
return .empty; return .empty;
}, },
.true => {
const true_object = try Object.Number.create(vm, .{
.boolean = true,
});
try vm.pushStack(@ptrCast(true_object));
frame.ip += 1;
},
.false => {
const false_object = try Object.Number.create(vm, .{
.boolean = false,
});
try vm.pushStack(@ptrCast(false_object));
frame.ip += 1;
},
.pop => { .pop => {
const object_top = vm.popStack(); if (vm.popStack()) |_| {} else {
if (object_top == null) return error.InvalidArgument; return error.InvalidArgument;
}
frame.ip += 1; frame.ip += 1;
}, },
.add => { .add => {
@ -243,12 +266,74 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
frame.ip += 1; frame.ip += 1;
}, },
.neg => { .neg => {
const arg_object = vm.peekStack(0); if (vm.peekStack(0)) |arg| {
if (arg_object) |arg| {
_ = Object.Number.negate(@ptrCast(arg)); _ = Object.Number.negate(@ptrCast(arg));
} else {
return error.InvalidArgument;
} }
frame.ip += 1; frame.ip += 1;
}, },
.not => {
if (vm.peekStack(0)) |arg| {
const value = try Object.Number.create(vm, .{
.boolean = arg.isFalsey(),
});
_ = vm.popStack();
try vm.pushStack(@ptrCast(value));
} else {
return error.StackOverflow;
}
frame.ip += 1;
},
.cmp_eq => {
const lhs = vm.peekStack(1) orelse return error.Bugged;
const rhs = vm.peekStack(0) orelse return error.Bugged;
const value = try Object.cmpEql(vm, @ptrCast(lhs), @ptrCast(rhs));
_ = vm.popStack();
_ = vm.popStack();
try vm.pushStack(@ptrCast(value));
frame.ip += 1;
},
.cmp_lt, .cmp_gt, .cmp_lte, .cmp_gte => |op| {
const lhs = vm.peekStack(1) orelse return error.Bugged;
const rhs = vm.peekStack(0) orelse return error.Bugged;
const value = try Object.Number.performLogic(vm, op, @ptrCast(lhs), @ptrCast(rhs));
_ = vm.popStack();
_ = vm.popStack();
try vm.pushStack(@ptrCast(value));
frame.ip += 1;
},
.jmp => {
const arg_offset = readAddress(code, frame.ip);
frame.ip += 3 + arg_offset;
},
.jmp_t => {
const arg_offset = readAddress(code, frame.ip);
frame.ip += 3;
if (vm.peekStack(0)) |condition| {
if (!condition.isFalsey()) {
frame.ip += arg_offset;
}
} else {
return error.InvalidArgument;
}
},
.jmp_f => {
const arg_offset = readAddress(code, frame.ip);
frame.ip += 3;
if (vm.peekStack(0)) |condition| {
if (condition.isFalsey()) {
frame.ip += arg_offset;
}
} else {
return error.InvalidArgument;
}
},
.load_const => { .load_const => {
const index: u8 = @intFromEnum(code[frame.ip + 1]); const index: u8 = @intFromEnum(code[frame.ip + 1]);
const value = try vm.getConstant(frame, index); const value = try vm.getConstant(frame, index);
@ -262,10 +347,11 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
frame.ip += 2; frame.ip += 2;
}, },
.store => { .store => {
const value = vm.peekStack(0); if (vm.peekStack(0)) |arg| {
if (value) |arg| {
const arg_offset: u8 = @intFromEnum(code[frame.ip + 1]); const arg_offset: u8 = @intFromEnum(code[frame.ip + 1]);
vm.setLocal(frame, arg_offset, arg); vm.setLocal(frame, arg_offset, arg);
} else {
return error.InvalidArgument;
} }
frame.ip += 2; frame.ip += 2;
}, },
@ -283,11 +369,12 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
const global_name = try vm.getConstant(frame, arg_offset); const global_name = try vm.getConstant(frame, arg_offset);
assert(global_name.tag == .string); assert(global_name.tag == .string);
const value = vm.peekStack(0); if (vm.peekStack(0)) |arg| {
if (value) |arg| {
try vm.setGlobal(@ptrCast(global_name), arg); try vm.setGlobal(@ptrCast(global_name), arg);
_ = vm.popStack(); _ = vm.popStack();
try vm.pushStack(arg); try vm.pushStack(arg);
} else {
return error.InvalidArgument;
} }
frame.ip += 2; frame.ip += 2;
}, },
@ -297,20 +384,26 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
const string_object = try Object.String.fromObject(vm, object); const string_object = try Object.String.fromObject(vm, object);
const string_bytes = string_object.bytes[0..string_object.length]; const string_bytes = string_object.bytes[0..string_object.length];
try stream_writer.writer.writeAll(string_bytes); try stream_writer.writer.writeAll(string_bytes);
} else {
return error.InvalidArgument;
} }
frame.ip += 1; frame.ip += 1;
}, },
.stream_flush => { .stream_flush => {
frame.ip += 1; frame.ip += 1;
// FIXME: This is a bit of a hack, but we have to deal with this right now.
const buffered = stream_writer.writer.buffered();
if (buffered.len == 0) continue;
return stream_writer.toArrayList(); return stream_writer.toArrayList();
}, },
.br_push => { .br_push => {
const arg_offset = std.mem.bytesToValue(u16, code[frame.ip + 1 ..][0..2]); const arg_offset = readAddress(code, frame.ip);
try vm.current_choices.append(gpa, .{ try vm.current_choices.append(gpa, .{
.text = stream_writer.toArrayList(), .text = stream_writer.toArrayList(),
.dest_offset = std.mem.bigToNative(u16, arg_offset), .dest_offset = arg_offset,
}); });
frame.ip += 3; frame.ip += 3;
}, },
@ -335,6 +428,11 @@ fn execute(vm: *Story) !std.ArrayListUnmanaged(u8) {
} }
} }
fn readAddress(code: []const Story.Opcode, offset: usize) u16 {
const arg_offset = std.mem.bytesToValue(u16, code[offset + 1 ..][0..2]);
return std.mem.bigToNative(u16, arg_offset);
}
pub fn advance(story: *Story, gpa: std.mem.Allocator) ![]const u8 { pub fn advance(story: *Story, gpa: std.mem.Allocator) ![]const u8 {
var content = try story.execute(); var content = try story.execute();
return content.toOwnedSlice(gpa); return content.toOwnedSlice(gpa);

View file

@ -50,6 +50,26 @@ pub const Object = struct {
} }
} }
pub fn cmpEql(story: *Story, lhs: *Object, rhs: *Object) !*Object {
// TODO: This is temporary
if (lhs.tag != .number or lhs.tag != rhs.tag) return error.InvalidComparison;
const result = try Object.Number.performLogic(story, .cmp_eq, @ptrCast(lhs), @ptrCast(rhs));
return @ptrCast(result);
}
pub fn isFalsey(obj: *Object) bool {
switch (obj.tag) {
.number => {
const number: *Object.Number = @ptrCast(obj);
switch (number.data) {
.boolean => |value| return !value,
else => return false,
}
},
else => return false,
}
}
pub const Number = struct { pub const Number = struct {
base: Object, base: Object,
data: Data, data: Data,
@ -96,7 +116,7 @@ pub const Object = struct {
.floating => |value| break :v .{ .floating = value }, .floating => |value| break :v .{ .floating = value },
} }
}, },
else => break :v .{ .integer = 1 }, else => break :v .{ .boolean = true },
}; };
const obj = Object.Number.create(story, data); const obj = Object.Number.create(story, data);
@ -104,6 +124,17 @@ pub const Object = struct {
return obj; return obj;
} }
fn logicalOp(comptime T: type, op: Story.Opcode, lhs: T, rhs: T) bool {
switch (op) {
.cmp_eq => return lhs == rhs,
.cmp_lt => return lhs < rhs,
.cmp_gt => return lhs > rhs,
.cmp_lte => return lhs <= rhs,
.cmp_gte => return lhs >= rhs,
else => unreachable,
}
}
fn arithmeticOp(comptime T: type, op: Story.Opcode, lhs: T, rhs: T) T { fn arithmeticOp(comptime T: type, op: Story.Opcode, lhs: T, rhs: T) T {
switch (op) { switch (op) {
.add => return lhs + rhs, .add => return lhs + rhs,
@ -124,6 +155,17 @@ pub const Object = struct {
return object; return object;
} }
pub fn performLogic(
story: *Story,
op: Story.Opcode,
lhs: *Object.Number,
rhs: *Object.Number,
) !*Object.Number {
return .create(story, .{
.boolean = logicalOp(i64, op, lhs.data.integer, rhs.data.integer),
});
}
pub fn performArithmetic( pub fn performArithmetic(
story: *Story, story: *Story,
op: Story.Opcode, op: Story.Opcode,
@ -131,17 +173,12 @@ pub const Object = struct {
rhs: *Object.Number, rhs: *Object.Number,
) !*Object.Number { ) !*Object.Number {
if (lhs.data == .floating or rhs.data == .floating) { if (lhs.data == .floating or rhs.data == .floating) {
const _lhs = try Object.Number.fromObject(story, @ptrCast(lhs));
const _rhs = try Object.Number.fromObject(story, @ptrCast(rhs));
return .create(story, .{ return .create(story, .{
.floating = arithmeticOp(f64, op, _lhs.data.floating, _rhs.data.floating), .floating = arithmeticOp(f64, op, lhs.data.floating, rhs.data.floating),
}); });
} }
const _lhs = try Object.Number.fromObject(story, @ptrCast(lhs));
const _rhs = try Object.Number.fromObject(story, @ptrCast(rhs));
return .create(story, .{ return .create(story, .{
.integer = arithmeticOp(i64, op, _lhs.data.integer, _rhs.data.integer), .integer = arithmeticOp(i64, op, lhs.data.integer, rhs.data.integer),
}); });
} }
}; };

View file

@ -305,7 +305,7 @@ pub const Tokenizer = struct {
.bang => switch (self.buffer[self.index]) { .bang => switch (self.buffer[self.index]) {
'=' => { '=' => {
self.index += 1; self.index += 1;
result.tag = .equal; result.tag = .not_equal;
}, },
else => result.tag = .exclaimation_mark, else => result.tag = .exclaimation_mark,
}, },