mach/libs/dusk/src/Ast.zig
2023-04-11 09:02:14 -07:00

700 lines
15 KiB
Zig

const std = @import("std");
const Parser = @import("Parser.zig");
const Token = @import("Token.zig");
const Tokenizer = @import("Tokenizer.zig");
const ErrorList = @import("ErrorList.zig");
const Extension = @import("main.zig").Extension;
const Ast = @This();
/// Growable storage for `Node`s; `std.MultiArrayList` keeps one array per field.
pub const NodeList = std.MultiArrayList(Node);
/// Growable storage for `Token`s; `std.MultiArrayList` keeps one array per field.
pub const TokenList = std.MultiArrayList(Token);
/// The source text that was handed to `parse`. It is stored as given and is
/// not freed by `deinit`.
source: [:0]const u8,
/// Tokens produced by the tokenizer in `parse`; the last one has tag `.eof`.
tokens: TokenList.Slice,
/// Nodes produced by the parser. Read them through `nodeTag`, `nodeToken`,
/// `nodeLHS` and `nodeRHS`.
nodes: NodeList.Slice,
/// Additional indices referenced by nodes; read through `spanToList` and
/// `extraData`.
extra: []const Index,
/// Error list created in `parse` and handed over from the parser.
errors: ErrorList,
/// Frees the token, node and extra storage plus the error list, then
/// invalidates the tree.
///
/// `allocator` must be the allocator that was passed to `parse`.
/// `source` is left alone; it is not released here.
pub fn deinit(tree: *Ast, allocator: std.mem.Allocator) void {
    // The four resources are independent of each other, so the release order
    // carries no meaning.
    tree.errors.deinit();
    allocator.free(tree.extra);
    tree.nodes.deinit(allocator);
    tree.tokens.deinit(allocator);

    // No field of the tree may be read after this point.
    tree.* = undefined;
}
/// Tokenizes and parses `source` as a TranslationUnit (a whole WGSL program).
///
/// The returned tree stores `source` as given (no copy is made) and owns its
/// token, node and extra storage as well as the error list. Release them with
/// `deinit`, passing the same `allocator`.
///
/// The only error reported through the error union is `error.OutOfMemory`;
/// on that path everything allocated so far is released before returning.
/// NOTE(review): syntax problems are presumably collected in `errors` rather
/// than returned — confirm in Parser.zig.
pub fn parse(allocator: std.mem.Allocator, source: [:0]const u8) error{OutOfMemory}!Ast {
    var tokens = blk: {
        // Rough guess of one token per 8 source bytes to limit regrowth.
        const estimated_tokens = source.len / 8;

        var list = std.MultiArrayList(Token){};
        errdefer list.deinit(allocator);
        try list.ensureTotalCapacity(allocator, estimated_tokens);

        var tokenizer = Tokenizer.init(source);
        while (true) {
            const tok = tokenizer.next();
            try list.append(allocator, tok);
            // The eof token is kept as the final element of the list.
            if (tok.tag == .eof) break;
        }
        break :blk list;
    };

    // The errdefer inside the block above no longer covers the token list, and
    // the parser does not own it yet, so a failure here must free it by hand.
    const errors = ErrorList.init(allocator) catch |err| {
        tokens.deinit(allocator);
        return err;
    };

    var p = Parser{
        .allocator = allocator,
        .source = source,
        .tok_i = 0,
        .tokens = tokens,
        .nodes = .{},
        .extra = .{},
        .scratch = .{},
        .errors = errors,
        .extensions = Extension.Array.initFill(false),
    };
    defer p.scratch.deinit(allocator);
    // From here on `p` is the single owner of every resource.
    errdefer {
        p.tokens.deinit(allocator);
        p.nodes.deinit(allocator);
        p.extra.deinit(allocator);
        p.errors.deinit();
    }

    // TODO: make sure tokens:nodes ratio is right
    const estimated_node_count = (p.tokens.len + 2) / 2;
    try p.nodes.ensureTotalCapacity(allocator, estimated_node_count);
    try p.translationUnit();

    // Perform the only fallible conversion first. MultiArrayList.toOwnedSlice()
    // resets its list, so if it ran before a failing `extra` conversion the
    // errdefer above would deinit empty lists and the token/node slices
    // would leak.
    const extra = try p.extra.toOwnedSlice(allocator);
    return .{
        .source = source,
        .tokens = p.tokens.toOwnedSlice(),
        .nodes = p.nodes.toOwnedSlice(),
        .extra = extra,
        .errors = p.errors,
    };
}
/// Returns the indices covered by a `span` node, i.e. `extra[LHS..RHS]`.
///
/// Asserts that the node at `span` has tag `.span`.
pub fn spanToList(tree: Ast, span: Ast.Index) []const Ast.Index {
    std.debug.assert(tree.nodeTag(span) == .span);
    const first = tree.nodeLHS(span);
    const end = tree.nodeRHS(span);
    return tree.extra[first..end];
}
/// Decodes a `T` from consecutive entries of `extra`, starting at `index`.
///
/// Fields are filled in declaration order, one `extra` entry per field.
/// Every field of `T` must be of type `Ast.Index`; this is checked at
/// compile time.
pub fn extraData(tree: Ast, comptime T: type, index: Ast.Index) T {
    var data: T = undefined;
    inline for (std.meta.fields(T), 0..) |info, offset| {
        comptime std.debug.assert(info.type == Ast.Index);
        @field(data, info.name) = tree.extra[index + offset];
    }
    return data;
}
/// Returns the tag of the token at index `i`.
pub fn tokenTag(tree: Ast, i: Index) Token.Tag {
    const tags = tree.tokens.items(.tag);
    return tags[i];
}
/// Returns the `loc` field of the token at index `i`.
pub fn tokenLoc(tree: Ast, i: Index) Token.Loc {
    const locs = tree.tokens.items(.loc);
    return locs[i];
}
/// Returns the tag of the node at index `i`.
pub fn nodeTag(tree: Ast, i: Index) Node.Tag {
    const tags = tree.nodes.items(.tag);
    return tags[i];
}
/// Returns the `main_token` index of the node at index `i`.
pub fn nodeToken(tree: Ast, i: Index) Index {
    const main_tokens = tree.nodes.items(.main_token);
    return main_tokens[i];
}
/// Returns the `lhs` field of the node at index `i`.
pub fn nodeLHS(tree: Ast, i: Index) Index {
    const all_lhs = tree.nodes.items(.lhs);
    return all_lhs[i];
}
/// Returns the `rhs` field of the node at index `i`.
pub fn nodeRHS(tree: Ast, i: Index) Index {
    const all_rhs = tree.nodes.items(.rhs);
    return all_rhs[i];
}
/// Index type used to address tokens, nodes and entries of `extra`.
pub const Index = u32;
/// Value meaning "no index"; it is the default of the optional `Node` fields.
/// Because it equals 0, it cannot be told apart from a real index 0.
pub const null_index: Index = 0;
/// One AST node. `tag` selects how `main_token`, `lhs` and `rhs` are to be
/// interpreted; the meaning per tag is documented on each `Tag` member.
pub const Node = struct {
/// Kind of this node.
tag: Tag,
/// Token index, described as `TOK` in the `Tag` documentation.
main_token: Index,
/// Described as `LHS` in the `Tag` documentation.
lhs: Index = null_index,
/// Described as `RHS` in the `Tag` documentation.
rhs: Index = null_index,
/// Node kinds.
///
/// Legend used by the member docs below:
/// - TOK, LHS, RHS describe the `main_token`, `lhs` and `rhs` fields.
/// - `--` marks a field without a documented meaning for that tag.
/// - a trailing `?` marks an optional value (`null_index` when absent).
/// - `span(X)` is a `span` node listing X entries (see `Ast.spanToList`).
/// - NOTE(review): struct names such as `GlobalVarDecl` presumably denote an
///   `extra` index to be decoded with `Ast.extraData` — confirm in Parser.zig.
pub const Tag = enum {
/// a slice of the `extra` array: extra[LHS..RHS]
/// TOK : undefined
/// LHS : Index
/// RHS : Index
span,
// ####### GlobalDecl #######
/// TOK : k_var
/// LHS : GlobalVarDecl
/// RHS : Expr?
global_variable,
/// TOK : k_const
/// LHS : Type
/// RHS : Expr
global_constant,
/// TOK : k_override
/// LHS : OverrideDecl
/// RHS : Expr
override,
/// TOK : k_type
/// LHS : Type
/// RHS : --
type_alias,
/// TOK : k_const_assert
/// LHS : Expr
/// RHS : --
const_assert,
/// TOK : k_struct
/// LHS : span(struct_member)
/// RHS : --
struct_decl,
/// TOK : ident
/// LHS : span(Attribute)
/// RHS : Type
struct_member,
/// TOK : k_fn
/// LHS : FnProto
/// RHS : block
fn_decl,
/// TOK : ident
/// LHS : ?span(Attribute)
/// RHS : type
fn_param,
// ####### Statement #######
// block = span(Statement)
/// TOK : k_return
/// LHS : Expr?
/// RHS : --
@"return",
/// TOK : k_discard
/// LHS : --
/// RHS : --
discard,
/// TOK : k_loop
/// LHS : block
/// RHS : --
loop,
/// TOK : k_continuing
/// LHS : block
/// RHS : --
continuing,
/// TOK : k_break
/// LHS : Expr
/// RHS : --
break_if,
/// TOK : k_break
/// LHS : --
/// RHS : --
@"break",
/// TOK : k_continue
/// LHS : --
/// RHS : --
@"continue",
/// TOK : k_if
/// LHS : Expr
/// RHS : block
@"if",
/// RHS is else body
/// TOK : k_if
/// LHS : if
/// RHS : block
if_else,
/// TOK : k_if
/// LHS : if
/// RHS : if, if_else, if_else_if
if_else_if,
/// TOK : k_switch
/// LHS : Expr
/// RHS : span(switch_case, switch_default, switch_case_default)
@"switch",
/// TOK : k_case
/// LHS : span(Expr)
/// RHS : block
switch_case,
/// TOK : k_default
/// LHS : block
/// RHS : --
switch_default,
/// switch_case with default (`case 1, 2, default {}`)
/// TOK : k_case
/// LHS : span(Expr)
/// RHS : block
switch_case_default,
/// TOK : k_var
/// LHS : VarDecl
/// RHS : Expr?
var_decl,
/// TOK : k_const
/// LHS : Type?
/// RHS : Expr
const_decl,
/// TOK : k_let
/// LHS : Type?
/// RHS : Expr
let_decl,
/// TOK : k_while
/// LHS : Expr
/// RHS : block
@"while",
/// TOK : k_for
/// LHS : ForHeader
/// RHS : block
@"for",
/// TOK : plus_plus, minus_minus
/// LHS : Expr
// NOTE(review): RHS is not documented for this tag — confirm in Parser.zig.
increase_decrement,
/// TOK : plus_equal, minus_equal,
/// times_equal, division_equal,
/// modulo_equal, and_equal,
/// or_equal, xor_equal,
/// shift_right_equal, shift_left_equal
/// LHS : Expr
/// RHS : Expr
compound_assign,
/// TOK : equal
/// LHS : Expr
/// RHS : --
phony_assign,
// ####### Type #######
/// TOK : k_i32, k_u32, k_f32, k_f16, k_bool
/// LHS : --
/// RHS : --
// NOTE(review): k_bool is also listed for `bool_type` below — confirm which
// tag the parser actually produces for it.
number_type,
/// TOK : k_bool
/// LHS : --
/// RHS : --
bool_type,
/// TOK : k_sampler, k_comparison_sampler
/// LHS : --
/// RHS : --
sampler_type,
/// TOK : k_vec2, k_vec3, k_vec4
/// LHS : Type
/// RHS : --
vector_type,
/// TOK : k_mat2x2, k_mat2x3, k_mat2x4,
/// k_mat3x2, k_mat3x3, k_mat3x4,
/// k_mat4x2, k_mat4x3, k_mat4x4
/// LHS : Type
/// RHS : --
matrix_type,
/// TOK : k_atomic
/// LHS : Type
/// RHS : --
atomic_type,
/// TOK : k_array
/// LHS : Type
/// RHS : Expr?
array_type,
/// TOK : k_ptr
/// LHS : Type
/// RHS : PtrType
ptr_type,
/// TOK : k_texture_1d, k_texture_2d, k_texture_2d_array,
/// k_texture_3d, k_texture_cube, k_texture_cube_array
/// LHS : Type
/// RHS : --
sampled_texture_type,
/// TOK : k_texture_multisampled_2d
/// LHS : Type
/// RHS : --
multisampled_texture_type,
/// TOK : k_texture_external
/// LHS : Type
/// RHS : --
external_texture_type,
/// TOK : k_texture_storage_1d, k_texture_storage_2d,
/// k_texture_storage_2d_array, k_texture_storage_3d
/// LHS : Index(Token(TexelFormat))
/// RHS : Index(Token(AccessMode))
storage_texture_type,
/// TOK : k_texture_depth_2d, k_texture_depth_2d_array
/// k_texture_depth_cube, k_texture_depth_cube_array
/// k_texture_depth_multisampled_2d
/// LHS : --
/// RHS : --
depth_texture_type,
// ####### Attribute #######
// TOK : attr
attr,
/// TOK : attr
/// LHS : Expr
/// RHS : --
attr_one_arg,
/// TOK : attr
/// LHS : Index(Token(BuiltinValue))
/// RHS : --
attr_builtin,
/// TOK : attr
/// LHS : WorkgroupSize
/// RHS : --
attr_workgroup_size,
/// TOK : attr
/// LHS : Index(Token(InterpolationType))
/// RHS : Index(Token(InterpolationSample))
attr_interpolate,
// ####### Expr #######
// see both Parser.zig and https://gpuweb.github.io/gpuweb/wgsl/#expression-grammar
/// TOK : *
/// LHS : Expr
/// RHS : Expr
mul,
/// TOK : /
/// LHS : Expr
/// RHS : Expr
div,
/// TOK : %
/// LHS : Expr
/// RHS : Expr
mod,
/// TOK : +
/// LHS : Expr
/// RHS : Expr
add,
/// TOK : -
/// LHS : Expr
/// RHS : Expr
sub,
/// TOK : <<
/// LHS : Expr
/// RHS : Expr
shift_left,
/// TOK : >>
/// LHS : Expr
/// RHS : Expr
shift_right,
/// TOK : &
/// LHS : Expr
/// RHS : Expr
binary_and,
/// TOK : |
/// LHS : Expr
/// RHS : Expr
binary_or,
/// TOK : ^
/// LHS : Expr
/// RHS : Expr
binary_xor,
/// TOK : &&
/// LHS : Expr
/// RHS : Expr
circuit_and,
/// TOK : ||
/// LHS : Expr
/// RHS : Expr
circuit_or,
/// TOK : !
/// LHS : Expr
/// RHS : --
not,
/// TOK : -
/// LHS : Expr
/// RHS : --
negate,
/// TOK : *
/// LHS : Expr
/// RHS : --
deref,
/// TOK : &
/// LHS : Expr
/// RHS : --
addr_of,
/// TOK : ==
/// LHS : Expr
/// RHS : Expr
equal,
/// TOK : !=
/// LHS : Expr
/// RHS : Expr
not_equal,
/// TOK : <
/// LHS : Expr
/// RHS : Expr
less,
/// TOK : <=
/// LHS : Expr
/// RHS : Expr
less_equal,
/// TOK : >
/// LHS : Expr
/// RHS : Expr
greater,
/// TOK : >=
/// LHS : Expr
/// RHS : Expr
greater_equal,
/// for identifier, array without element type specified,
/// vector prefix (e.g. vec2) and matrix prefix (e.g. mat2x2) LHS is null
/// see callExpr in Parser.zig if you don't understand this
///
/// TOK : ident, k_array, k_bool, 'number type keywords', 'vector keywords', 'matrix keywords'
/// LHS : (number_type, bool_type, vector_type, matrix_type, array_type)?
/// RHS : arguments (Expr span)
call,
/// TOK : k_bitcast
/// LHS : Type
/// RHS : Expr
bitcast,
/// TOK : ident
/// LHS : --
/// RHS : --
ident_expr,
/// LHS is prefix expression
/// TOK : ident
/// LHS : Expr
// NOTE(review): RHS is not documented for this tag — confirm in Parser.zig.
component_access,
/// LHS is prefix expression
/// TOK : bracket_left
/// LHS : Expr
/// RHS : Expr
index_access,
// ####### Literals #######
/// TOK : k_true
/// LHS : --
/// RHS : --
bool_true,
/// TOK : k_false
/// LHS : --
/// RHS : --
bool_false,
/// TOK : number
/// LHS : --
/// RHS : --
number_literal,
};
/// Referenced as the LHS of `Tag.global_variable`.
pub const GlobalVarDecl = struct {
/// span(Attr)?
attrs: Index = null_index,
/// Token(ident)
name: Index,
/// Token(AddrSpace)?
addr_space: Index = null_index,
/// Token(AccessMode)?
access_mode: Index = null_index,
/// Type?
type: Index = null_index,
};
/// Referenced as the LHS of `Tag.var_decl`.
pub const VarDecl = struct {
/// Token(ident)
name: Index,
/// Token(AddrSpace)?
addr_space: Index = null_index,
/// Token(AccessMode)?
access_mode: Index = null_index,
/// Type?
type: Index = null_index,
};
/// Referenced as the LHS of `Tag.override`.
pub const OverrideDecl = struct {
/// span(Attr)?
attrs: Index = null_index,
/// Type?
type: Index = null_index,
};
/// Referenced as the RHS of `Tag.ptr_type`.
pub const PtrType = struct {
/// Token(AddrSpace)
addr_space: Index,
/// Token(AccessMode)?
access_mode: Index = null_index,
};
/// Referenced as the LHS of `Tag.attr_workgroup_size`.
pub const WorkgroupSize = struct {
/// Expr
x: Index,
/// Expr?
y: Index = null_index,
/// Expr?
z: Index = null_index,
};
/// Referenced as the LHS of `Tag.fn_decl`.
pub const FnProto = struct {
/// span(Attr)?
attrs: Index = null_index,
/// span(fn_param)?
params: Index = null_index,
/// span(Attr)?
result_attrs: Index = null_index,
/// Type?
result_type: Index = null_index,
};
/// Condition/body pair.
/// NOTE(review): no `Tag` doc in this file references this struct — confirm
/// whether it is still used.
pub const IfStatement = struct {
/// Expr
cond: Index,
/// block
body: Index,
};
/// Referenced as the LHS of `Tag.@"for"`.
pub const ForHeader = struct {
/// var_decl, const_decl, let_decl, phony_assign, compound_assign
init: Index = null_index,
/// Expr
cond: Index = null_index,
/// call, phony_assign, compound_assign
update: Index = null_index,
};
};
/// Builtin value names. `Node.Tag.attr_builtin` documents its LHS as the index
/// of a token naming one of these.
pub const BuiltinValue = enum {
vertex_index,
instance_index,
position,
front_facing,
frag_depth,
local_invocation_id,
local_invocation_index,
global_invocation_id,
workgroup_id,
num_workgroups,
sample_index,
sample_mask,
};
/// Interpolation type names. `Node.Tag.attr_interpolate` documents its LHS as
/// the index of a token naming one of these.
pub const InterpolationType = enum {
perspective,
linear,
flat,
};
/// Interpolation sampling names. `Node.Tag.attr_interpolate` documents its RHS
/// as the index of a token naming one of these.
pub const InterpolationSample = enum {
center,
centroid,
sample,
};
/// Address space names.
/// NOTE(review): presumably what the `Node` payload docs call
/// `Token(AddrSpace)` — confirm in Parser.zig.
pub const AddressSpace = enum {
function,
private,
workgroup,
uniform,
storage,
};
/// Access mode names, referenced as `Token(AccessMode)` by
/// `Node.Tag.storage_texture_type` and by the variable/pointer payload structs
/// in `Node`.
pub const AccessMode = enum {
read,
write,
read_write,
};
/// Attribute names.
/// NOTE(review): presumably the identifiers accepted as the main token of the
/// `attr*` node tags — confirm in Parser.zig.
pub const Attribute = enum {
invariant,
@"const",
vertex,
fragment,
compute,
@"align",
binding,
group,
id,
location,
size,
builtin,
workgroup_size,
interpolate,
};
/// Texel format names. `Node.Tag.storage_texture_type` documents its LHS as
/// the index of a token naming one of these.
pub const TexelFormat = enum {
rgba8unorm,
rgba8snorm,
rgba8uint,
rgba8sint,
rgba16uint,
rgba16sint,
rgba16float,
r32uint,
r32sint,
r32float,
rg32uint,
rg32sint,
rg32float,
rgba32uint,
rgba32sint,
rgba32float,
bgra8unorm,
};