spirv: update spec generator

For module parsing and assembling, we will also need to know
all of the SPIR-V extensions and their instructions. This commit
updates the generator to generate those. Because there are
multiple instruction sets that each have a separate list of Opcodes,
no separate enum is generated for these opcodes. Additionally, the
previous mechanism for runtime instruction information, `Opcode`'s
`fn operands()`, has been removed in favor of
`InstructionSet.core.instructions()`.

Any mapping from operand to instruction is to be done at runtime.
Using a runtime populated hashmap should also be more efficient
than the previous mechanism using `stringToEnum`.
This commit is contained in:
Robin Voetter 2024-03-09 12:00:34 +01:00
parent 3bffa58012
commit 3d5721da23
No known key found for this signature in database
2 changed files with 251 additions and 81 deletions

View File

@ -1,45 +1,110 @@
const std = @import("std");
const g = @import("spirv/grammar.zig");
const Allocator = std.mem.Allocator;
const g = @import("spirv/grammar.zig");
const CoreRegistry = g.CoreRegistry;
const ExtensionRegistry = g.ExtensionRegistry;
const Instruction = g.Instruction;
const OperandKind = g.OperandKind;
const Enumerant = g.Enumerant;
const Operand = g.Operand;
const ExtendedStructSet = std.StringHashMap(void);
const Extension = struct {
name: []const u8,
spec: ExtensionRegistry,
};
/// Sort context passed to `std.sort.block`; orders instructions by opcode.
const CmpInst = struct {
    /// Returns true when the opcode of `a` is strictly smaller than the opcode of `b`.
    fn lt(_: CmpInst, a: Instruction, b: Instruction) bool {
        return std.math.order(a.opcode, b.opcode) == .lt;
    }
};
/// Two strings treated as a single composite key.
const StringPair = struct { []const u8, []const u8 };

/// Hashing and equality context for `StringPair` keys: both operations look
/// at the bytes of the two strings rather than at the slice pointers.
const StringPairContext = struct {
    /// Feeds the first and then the second string into one Wyhash state
    /// (seed 0) and truncates the resulting digest to 32 bits.
    pub fn hash(_: @This(), a: StringPair) u32 {
        var state = std.hash.Wyhash.init(0);
        state.update(a[0]);
        state.update(a[1]);
        return @truncate(state.final());
    }

    /// Two pairs are equal when both of their strings match byte-for-byte.
    /// The trailing index argument is not used.
    pub fn eql(_: @This(), a: StringPair, b: StringPair, _: usize) bool {
        if (!std.mem.eql(u8, a[0], b[0])) return false;
        return std.mem.eql(u8, a[1], b[1]);
    }
};
const OperandKindMap = std.ArrayHashMap(StringPair, OperandKind, StringPairContext, true);
/// Entry point of the generator.
///
/// Takes exactly one command-line argument. The added lines join it with
/// `include/spirv/unified1/`, open that directory, read the core grammar and
/// every `extinst.*` grammar found there, sort each registry's instructions by
/// opcode, and render everything to stdout through a buffered writer.
///
/// NOTE(review): this listing shows removed and added diff lines side by side
/// (e.g. the two `args` declarations and the two `usageAndExit` calls), so it
/// is not compilable exactly as displayed.
pub fn main() !void {
// All allocations below go through this arena and are released together on exit.
var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
defer arena.deinit();
const allocator = arena.allocator();
const a = arena.allocator();
const args = try std.process.argsAlloc(allocator);
const args = try std.process.argsAlloc(a);
// Exactly one argument besides the program name is accepted.
if (args.len != 2) {
usageAndExit(std.io.getStdErr(), args[0], 1);
usageAndExit(args[0], 1);
}
const spec_path = args[1];
const spec = try std.fs.cwd().readFileAlloc(allocator, spec_path, std.math.maxInt(usize));
// The grammar files are looked up in `<args[1]>/include/spirv/unified1/`;
// the directory is opened with `.iterate = true` because it is iterated below.
const json_path = try std.fs.path.join(a, &.{ args[1], "include/spirv/unified1/" });
const dir = try std.fs.cwd().openDir(json_path, .{ .iterate = true });
// const spec_path = try std.fs.path.join(a, &.{spirv_headers_dir_path, "spirv.core.grammar.json"});
// const core_spec = try std.fs.cwd().readFileAlloc(a, spec_path, std.math.maxInt(usize));
const core_spec = try readRegistry(CoreRegistry, a, dir, "spirv.core.grammar.json");
// Instructions are sorted in place by ascending opcode (see `CmpInst.lt`).
std.sort.block(Instruction, core_spec.instructions, CmpInst{}, CmpInst.lt);
var exts = std.ArrayList(Extension).init(a);
var it = dir.iterate();
while (try it.next()) |entry| {
// Only regular files whose name starts with `extinst.` are extension grammars;
// every other directory entry is ignored.
if (entry.kind != .file or !std.mem.startsWith(u8, entry.name, "extinst.")) {
continue;
}
// The extension name is the part of the file name between the `extinst.`
// prefix and the `.grammar.json` suffix.
std.debug.assert(std.mem.endsWith(u8, entry.name, ".grammar.json"));
const name = entry.name["extinst.".len .. entry.name.len - ".grammar.json".len];
const spec = try readRegistry(ExtensionRegistry, a, dir, entry.name);
std.sort.block(Instruction, spec.instructions, CmpInst{}, CmpInst.lt);
// The name is duplicated with `a` before being stored — presumably because
// `entry.name` does not stay valid across `it.next()` calls; confirm against
// std.fs.Dir.Iterator.
try exts.append(.{ .name = try a.dupe(u8, name), .spec = spec });
}
// Output is buffered and flushed once after rendering has finished.
var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
try render(bw.writer(), a, core_spec, exts.items);
try bw.flush();
}
/// Reads the file `path` relative to `dir` and parses its contents as JSON
/// into a value of type `RegistryType`.
///
/// All allocations are made with `a`; neither the file contents nor the
/// parse result are freed here, and only `parsed.value` is handed back.
/// When parsing fails, the added lines print `path` together with the line
/// and column reported by the scanner diagnostics via `std.debug.print`, then
/// return the error.
///
/// NOTE(review): this listing shows removed and added diff lines side by side
/// (two `scanner` and two `parsed` declarations, plus the old render/flush
/// lines before the final `return`), so it is not compilable exactly as
/// displayed.
fn readRegistry(comptime RegistryType: type, a: Allocator, dir: std.fs.Dir, path: []const u8) !RegistryType {
// The whole file is loaded into memory; the size limit is effectively unbounded.
const spec = try dir.readFileAlloc(a, path, std.math.maxInt(usize));
// Required for json parsing.
@setEvalBranchQuota(10000);
var scanner = std.json.Scanner.initCompleteInput(allocator, spec);
var scanner = std.json.Scanner.initCompleteInput(a, spec);
// Diagnostics are enabled so a parse failure can be reported with a position.
var diagnostics = std.json.Diagnostics{};
scanner.enableDiagnostics(&diagnostics);
const parsed = std.json.parseFromTokenSource(g.CoreRegistry, allocator, &scanner, .{}) catch |err| {
std.debug.print("line,col: {},{}\n", .{ diagnostics.getLine(), diagnostics.getColumn() });
const parsed = std.json.parseFromTokenSource(RegistryType, a, &scanner, .{}) catch |err| {
std.debug.print("{s}:{}:{}:\n", .{ path, diagnostics.getLine(), diagnostics.getColumn() });
return err;
};
var bw = std.io.bufferedWriter(std.io.getStdOut().writer());
try render(bw.writer(), allocator, parsed.value);
try bw.flush();
return parsed.value;
}
/// Returns a set with types that require an extra struct for the `Instruction` interface
/// to the spir-v spec, or whether the original type can be used.
fn extendedStructs(
arena: Allocator,
kinds: []const g.OperandKind,
a: Allocator,
kinds: []const OperandKind,
) !ExtendedStructSet {
var map = ExtendedStructSet.init(arena);
var map = ExtendedStructSet.init(a);
try map.ensureTotalCapacity(@as(u32, @intCast(kinds.len)));
for (kinds) |kind| {
@ -73,7 +138,7 @@ fn tagPriorityScore(tag: []const u8) usize {
}
}
fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void {
fn render(writer: anytype, a: Allocator, registry: CoreRegistry, extensions: []const Extension) !void {
try writer.writeAll(
\\//! This file is auto-generated by tools/gen_spirv_spec.zig.
\\
@ -99,6 +164,7 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
\\pub const IdScope = IdRef;
\\
\\pub const LiteralInteger = Word;
\\pub const LiteralFloat = Word;
\\pub const LiteralString = []const u8;
\\pub const LiteralContextDependentNumber = union(enum) {
\\ int32: i32,
@ -139,6 +205,12 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
\\ parameters: []const OperandKind,
\\};
\\
\\pub const Instruction = struct {
\\ name: []const u8,
\\ opcode: Word,
\\ operands: []const Operand,
\\};
\\
\\
);
@ -151,15 +223,123 @@ fn render(writer: anytype, allocator: Allocator, registry: g.CoreRegistry) !void
.{ registry.major_version, registry.minor_version, registry.revision, registry.magic_number },
);
const extended_structs = try extendedStructs(allocator, registry.operand_kinds);
try renderClass(writer, allocator, registry.instructions);
try renderOperandKind(writer, registry.operand_kinds);
try renderOpcodes(writer, allocator, registry.instructions, extended_structs);
try renderOperandKinds(writer, allocator, registry.operand_kinds, extended_structs);
// Merge the operand kinds from all extensions together.
// var all_operand_kinds = std.ArrayList(OperandKind).init(a);
// try all_operand_kinds.appendSlice(registry.operand_kinds);
var all_operand_kinds = OperandKindMap.init(a);
for (registry.operand_kinds) |kind| {
try all_operand_kinds.putNoClobber(.{ "core", kind.kind }, kind);
}
for (extensions) |ext| {
// Note: extensions may define the same operand kind, with different
// parameters. Instead of trying to merge them, just discriminate them
// using the name of the extension. This is similar to what
// the official headers do.
try all_operand_kinds.ensureUnusedCapacity(ext.spec.operand_kinds.len);
for (ext.spec.operand_kinds) |kind| {
var new_kind = kind;
new_kind.kind = try std.mem.join(a, ".", &.{ ext.name, kind.kind });
try all_operand_kinds.putNoClobber(.{ ext.name, kind.kind }, new_kind);
}
}
const extended_structs = try extendedStructs(a, all_operand_kinds.values());
// Note: extensions don't seem to have class.
try renderClass(writer, a, registry.instructions);
try renderOperandKind(writer, all_operand_kinds.values());
try renderOpcodes(writer, a, registry.instructions, extended_structs);
try renderOperandKinds(writer, a, all_operand_kinds.values(), extended_structs);
try renderInstructionSet(writer, a, registry, extensions, all_operand_kinds);
}
fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.Instruction) !void {
var class_map = std.StringArrayHashMap(void).init(allocator);
/// Emits the `InstructionSet` enum into the generated file.
///
/// The enum gets a `core` field followed by one field per entry of
/// `extensions`, and an `instructions()` method whose `switch` has one prong
/// per instruction set (written by `renderInstructionsCase`).
/// `all_operand_kinds` is only forwarded to `renderInstructionsCase`.
/// The allocator `a` is currently unused.
fn renderInstructionSet(
    writer: anytype,
    a: Allocator,
    core: CoreRegistry,
    extensions: []const Extension,
    all_operand_kinds: OperandKindMap,
) !void {
    _ = a;

    // Enum header and the fixed `core` field. The literal has no trailing
    // newline, so the first extension field follows it on the same line.
    try writer.writeAll(
        \\pub const InstructionSet = enum {
        \\ core,
    );

    // One enum field per extension instruction set.
    for (extensions) |extension| try writer.print("{},\n", .{std.zig.fmtId(extension.name)});

    // Opening of `instructions()` and its switch.
    try writer.writeAll(
        \\
        \\ pub fn instructions(self: InstructionSet) []const Instruction {
        \\ return switch (self) {
        \\
    );

    // Switch prongs: `core` first, then the extensions in the order given.
    try renderInstructionsCase(writer, "core", core.instructions, all_operand_kinds);
    for (extensions) |extension| try renderInstructionsCase(
        writer,
        extension.name,
        extension.spec.instructions,
        all_operand_kinds,
    );

    // Close the switch, the method and the enum.
    try writer.writeAll(
        \\ };
        \\ }
        \\};
        \\
    );
}
/// Emits one prong of the `InstructionSet.instructions()` switch:
/// `.<set_name> => &[_]Instruction{ ... },` with one entry per element of
/// `instructions`, each listing its name, opcode and operands.
///
/// Every operand kind is looked up in `all_operand_kinds`, first under
/// `set_name` and then under `"core"`; the `kind` field of the entry found is
/// what gets written. If the kind is registered under neither key, a message
/// is printed with `std.debug.print` and `error.InvalidRegistry` is returned
/// instead of panicking on a null unwrap.
fn renderInstructionsCase(
    writer: anytype,
    set_name: []const u8,
    instructions: []const Instruction,
    all_operand_kinds: OperandKindMap,
) !void {
    // Note: theoretically we could dedup from tags and give every instruction a list of aliases,
    // but there aren't so many total aliases and that would add more overhead in total. We will
    // just filter those out when needed.

    try writer.print(".{} => &[_]Instruction{{\n", .{std.zig.fmtId(set_name)});

    for (instructions) |inst| {
        try writer.print(
            \\.{{
            \\ .name = "{s}",
            \\ .opcode = {},
            \\ .operands = &[_]Operand{{
            \\
        , .{ inst.opname, inst.opcode });

        for (inst.operands) |operand| {
            // A missing quantifier means the operand is mandatory.
            const quantifier = if (operand.quantifier) |q|
                switch (q) {
                    .@"?" => "optional",
                    .@"*" => "variadic",
                }
            else
                "required";

            // Kinds defined by this instruction set take precedence over core kinds.
            const kind = all_operand_kinds.get(.{ set_name, operand.kind }) orelse
                all_operand_kinds.get(.{ "core", operand.kind }) orelse
                {
                    std.debug.print(
                        "{s}: instruction '{s}' uses unknown operand kind '{s}'\n",
                        .{ set_name, inst.opname, operand.kind },
                    );
                    return error.InvalidRegistry;
                };

            try writer.print(".{{.kind = .{}, .quantifier = .{s}}},\n", .{ std.zig.fmtId(kind.kind), quantifier });
        }

        try writer.writeAll(
            \\ },
            \\},
            \\
        );
    }

    try writer.writeAll(
        \\},
        \\
    );
}
fn renderClass(writer: anytype, a: Allocator, instructions: []const Instruction) !void {
var class_map = std.StringArrayHashMap(void).init(a);
for (instructions) |inst| {
if (std.mem.eql(u8, inst.class.?, "@exclude")) {
@ -173,7 +353,7 @@ fn renderClass(writer: anytype, allocator: Allocator, instructions: []const g.In
try renderInstructionClass(writer, class);
try writer.writeAll(",\n");
}
try writer.writeAll("};\n");
try writer.writeAll("};\n\n");
}
fn renderInstructionClass(writer: anytype, class: []const u8) !void {
@ -192,7 +372,7 @@ fn renderInstructionClass(writer: anytype, class: []const u8) !void {
}
}
fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
fn renderOperandKind(writer: anytype, operands: []const OperandKind) !void {
try writer.writeAll("pub const OperandKind = enum {\n");
for (operands) |operand| {
try writer.print("{},\n", .{std.zig.fmtId(operand.kind)});
@ -242,7 +422,7 @@ fn renderOperandKind(writer: anytype, operands: []const g.OperandKind) !void {
try writer.writeAll("};\n}\n};\n");
}
fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void {
fn renderEnumerant(writer: anytype, enumerant: Enumerant) !void {
try writer.print(".{{.name = \"{s}\", .value = ", .{enumerant.enumerant});
switch (enumerant.value) {
.bitflag => |flag| try writer.writeAll(flag),
@ -260,14 +440,14 @@ fn renderEnumerant(writer: anytype, enumerant: g.Enumerant) !void {
fn renderOpcodes(
writer: anytype,
allocator: Allocator,
instructions: []const g.Instruction,
a: Allocator,
instructions: []const Instruction,
extended_structs: ExtendedStructSet,
) !void {
var inst_map = std.AutoArrayHashMap(u32, usize).init(allocator);
var inst_map = std.AutoArrayHashMap(u32, usize).init(a);
try inst_map.ensureTotalCapacity(instructions.len);
var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(allocator);
var aliases = std.ArrayList(struct { inst: usize, alias: usize }).init(a);
try aliases.ensureTotalCapacity(instructions.len);
for (instructions, 0..) |inst, i| {
@ -323,31 +503,6 @@ fn renderOpcodes(
try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs);
}
try writer.writeAll(
\\};
\\}
\\pub fn operands(self: Opcode) []const Operand {
\\return switch (self) {
\\
);
for (instructions_indices) |i| {
const inst = instructions[i];
try writer.print(".{} => &[_]Operand{{", .{std.zig.fmtId(inst.opname)});
for (inst.operands) |operand| {
const quantifier = if (operand.quantifier) |q|
switch (q) {
.@"?" => "optional",
.@"*" => "variadic",
}
else
"required";
try writer.print(".{{.kind = .{s}, .quantifier = .{s}}},", .{ operand.kind, quantifier });
}
try writer.writeAll("},\n");
}
try writer.writeAll(
\\};
\\}
@ -368,14 +523,14 @@ fn renderOpcodes(
fn renderOperandKinds(
writer: anytype,
allocator: Allocator,
kinds: []const g.OperandKind,
a: Allocator,
kinds: []const OperandKind,
extended_structs: ExtendedStructSet,
) !void {
for (kinds) |kind| {
switch (kind.category) {
.ValueEnum => try renderValueEnum(writer, allocator, kind, extended_structs),
.BitEnum => try renderBitEnum(writer, allocator, kind, extended_structs),
.ValueEnum => try renderValueEnum(writer, a, kind, extended_structs),
.BitEnum => try renderBitEnum(writer, a, kind, extended_structs),
else => {},
}
}
@ -383,20 +538,26 @@ fn renderOperandKinds(
fn renderValueEnum(
writer: anytype,
allocator: Allocator,
enumeration: g.OperandKind,
a: Allocator,
enumeration: OperandKind,
extended_structs: ExtendedStructSet,
) !void {
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
var enum_map = std.AutoArrayHashMap(u32, usize).init(allocator);
var enum_map = std.AutoArrayHashMap(u32, usize).init(a);
try enum_map.ensureTotalCapacity(enumerants.len);
var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(allocator);
var aliases = std.ArrayList(struct { enumerant: usize, alias: usize }).init(a);
try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants, 0..) |enumerant, i| {
const result = enum_map.getOrPutAssumeCapacity(enumerant.value.int);
try writer.context.flush();
const value: u31 = switch (enumerant.value) {
.int => |value| value,
// Some extensions declare ints as string
.bitflag => |value| try std.fmt.parseInt(u31, value, 10),
};
const result = enum_map.getOrPutAssumeCapacity(value);
if (!result.found_existing) {
result.value_ptr.* = i;
continue;
@ -422,9 +583,12 @@ fn renderValueEnum(
for (enum_indices) |i| {
const enumerant = enumerants[i];
if (enumerant.value != .int) return error.InvalidRegistry;
// if (enumerant.value != .int) return error.InvalidRegistry;
try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), enumerant.value.int });
switch (enumerant.value) {
.int => |value| try writer.print("{} = {},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
.bitflag => |value| try writer.print("{} = {s},\n", .{ std.zig.fmtId(enumerant.enumerant), value }),
}
}
try writer.writeByte('\n');
@ -454,8 +618,8 @@ fn renderValueEnum(
fn renderBitEnum(
writer: anytype,
allocator: Allocator,
enumeration: g.OperandKind,
a: Allocator,
enumeration: OperandKind,
extended_structs: ExtendedStructSet,
) !void {
try writer.print("pub const {s} = packed struct {{\n", .{std.zig.fmtId(enumeration.kind)});
@ -463,7 +627,7 @@ fn renderBitEnum(
var flags_by_bitpos = [_]?usize{null} ** 32;
const enumerants = enumeration.enumerants orelse return error.InvalidRegistry;
var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(allocator);
var aliases = std.ArrayList(struct { flag: usize, alias: u5 }).init(a);
try aliases.ensureTotalCapacity(enumerants.len);
for (enumerants, 0..) |enumerant, i| {
@ -471,6 +635,10 @@ fn renderBitEnum(
const value = try parseHexInt(enumerant.value.bitflag);
if (value == 0) {
continue; // Skip 'none' items
} else if (std.mem.eql(u8, enumerant.enumerant, "FlagIsPublic")) {
// This flag is special and poorly defined in the json files.
// Just skip it for now
continue;
}
std.debug.assert(@popCount(value) == 1);
@ -540,7 +708,7 @@ fn renderOperand(
mask,
},
field_name: []const u8,
parameters: []const g.Operand,
parameters: []const Operand,
extended_structs: ExtendedStructSet,
) !void {
if (kind == .instruction) {
@ -606,7 +774,7 @@ fn renderOperand(
try writer.writeAll(",\n");
}
fn renderFieldName(writer: anytype, operands: []const g.Operand, field_index: usize) !void {
fn renderFieldName(writer: anytype, operands: []const Operand, field_index: usize) !void {
const operand = operands[field_index];
// Should be enough for all names - adjust as needed.
@ -673,16 +841,16 @@ fn parseHexInt(text: []const u8) !u31 {
return try std.fmt.parseInt(u31, text[prefix.len..], 16);
}
fn usageAndExit(file: std.fs.File, arg0: []const u8, code: u8) noreturn {
file.writer().print(
\\Usage: {s} <spirv json spec>
fn usageAndExit(arg0: []const u8, code: u8) noreturn {
std.io.getStdErr().writer().print(
\\Usage: {s} <SPIRV-Headers repository path>
\\
\\Generates Zig bindings for a SPIR-V specification .json (either core or
\\extinst versions). The result, printed to stdout, should be used to update
\\Generates Zig bindings for SPIR-V specifications found in the SPIRV-Headers
\\repository. The result, printed to stdout, should be used to update
\\files in src/codegen/spirv. Don't forget to format the output.
\\
\\The relevant specifications can be obtained from the SPIR-V registry:
\\https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/unified1/
\\<SPIRV-Headers repository path> should point to a clone of
\\https://github.com/KhronosGroup/SPIRV-Headers/
\\
, .{arg0}) catch std.process.exit(1);
std.process.exit(code);

View File

@ -22,8 +22,8 @@ pub const CoreRegistry = struct {
};
pub const ExtensionRegistry = struct {
copyright: [][]const u8,
version: u32,
copyright: ?[][]const u8 = null,
version: ?u32 = null,
revision: u32,
instructions: []Instruction,
operand_kinds: []OperandKind = &[_]OperandKind{},
@ -40,6 +40,8 @@ pub const Instruction = struct {
opcode: u32,
operands: []Operand = &[_]Operand{},
capabilities: [][]const u8 = &[_][]const u8{},
// DebugModuleINTEL has this...
capability: ?[]const u8 = null,
extensions: [][]const u8 = &[_][]const u8{},
version: ?[]const u8 = null,