summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlejandro Soto <alejandro@34project.org>2024-05-04 23:58:08 -0600
committerAlejandro Soto <alejandro@34project.org>2024-05-05 14:19:15 -0600
commit4fdcb079663eccc71ed2c120f8279d6c364de9fd (patch)
tree1df076513ef031fa2a2f55d280e2edd09748cdd5
parenta7d92072c0bdc3a3e1c99de64f353e932846bc2a (diff)
platform/wavelet3d: implement shader writeback
-rw-r--r--platform/wavelet3d/gfx_front_back.sv10
-rw-r--r--platform/wavelet3d/gfx_isa.sv2
-rw-r--r--platform/wavelet3d/gfx_pkg.sv8
-rw-r--r--platform/wavelet3d/gfx_regfile_io.sv38
-rw-r--r--platform/wavelet3d/gfx_shader.sv19
-rw-r--r--platform/wavelet3d/gfx_shader_back.sv163
-rw-r--r--platform/wavelet3d/gfx_shader_fpint.sv69
-rw-r--r--platform/wavelet3d/gfx_shader_front.sv82
-rw-r--r--platform/wavelet3d/gfx_shader_group.sv1
-rw-r--r--platform/wavelet3d/gfx_shader_mem.sv1
-rw-r--r--platform/wavelet3d/gfx_shader_regs.sv77
-rw-r--r--platform/wavelet3d/gfx_shader_schedif.rdl27
-rw-r--r--platform/wavelet3d/gfx_shader_setup.sv37
-rw-r--r--platform/wavelet3d/gfx_shader_sfu.sv1
-rw-r--r--platform/wavelet3d/gfx_top.sv1
-rw-r--r--platform/wavelet3d/gfx_wb.sv20
16 files changed, 461 insertions, 95 deletions
diff --git a/platform/wavelet3d/gfx_front_back.sv b/platform/wavelet3d/gfx_front_back.sv
index 890b734..b768532 100644
--- a/platform/wavelet3d/gfx_front_back.sv
+++ b/platform/wavelet3d/gfx_front_back.sv
@@ -3,11 +3,11 @@ import gfx::*;;
struct
{
- group_id group;
- fpint_op p0;
- mem_op p1;
- sfu_op p2;
- group_op p3;
+ wave_exec wave;
+ fpint_op p0;
+ mem_op p1;
+ sfu_op p2;
+ group_op p3;
} execute;
struct
diff --git a/platform/wavelet3d/gfx_isa.sv b/platform/wavelet3d/gfx_isa.sv
index 873e6ec..7239478 100644
--- a/platform/wavelet3d/gfx_isa.sv
+++ b/platform/wavelet3d/gfx_isa.sv
@@ -3,6 +3,8 @@ package gfx_isa;
typedef logic[3:0] sgpr_num;
typedef logic[2:0] vgpr_num;
+ typedef logic signed[7:0] pc_offset;
+
typedef union packed
{
sgpr_num sgpr;
diff --git a/platform/wavelet3d/gfx_pkg.sv b/platform/wavelet3d/gfx_pkg.sv
index 5beb399..7072967 100644
--- a/platform/wavelet3d/gfx_pkg.sv
+++ b/platform/wavelet3d/gfx_pkg.sv
@@ -241,6 +241,7 @@ package gfx;
typedef gfx_isa::sgpr_num sgpr_num;
typedef gfx_isa::vgpr_num vgpr_num;
typedef gfx_isa::xgpr_num xgpr_num;
+ typedef gfx_isa::pc_offset pc_offset;
typedef struct packed
{
@@ -251,6 +252,13 @@ package gfx;
valid;
} shader_dispatch;
+ typedef struct
+ {
+ group_id group;
+ xgpr_num dest;
+ logic dest_scalar;
+ } wave_exec;
+
localparam int FIXED_MULADD_DEPTH = 5;
localparam int FIXED_DOTADD_DEPTH = 2 * FIXED_MULADD_DEPTH;
diff --git a/platform/wavelet3d/gfx_regfile_io.sv b/platform/wavelet3d/gfx_regfile_io.sv
index 49dcd5c..2459049 100644
--- a/platform/wavelet3d/gfx_regfile_io.sv
+++ b/platform/wavelet3d/gfx_regfile_io.sv
@@ -34,8 +34,10 @@ interface gfx_regfile_io;
} vgpr_write;
word a[SHADER_LANES], b[SHADER_LANES], sgpr_write_data, vgpr_write_data[SHADER_LANES];
- word_ptr pc_front;
- group_id pc_front_group;
+ logic mask_wb_write, pc_wb_write;
+ word_ptr pc_back, pc_front, pc_wb;
+ group_id mask_back_group, mask_wb_group, pc_back_group, pc_front_group, pc_wb_group;
+ lane_mask mask_back, mask_wb;
modport ab
(
@@ -57,8 +59,22 @@ interface gfx_regfile_io;
modport wb
(
+ input pc_back,
+ mask_back,
+
output sgpr_write,
- vgpr_write
+ vgpr_write,
+
+ pc_back_group,
+ mask_back_group,
+
+ pc_wb,
+ pc_wb_group,
+ pc_wb_write,
+
+ mask_wb,
+ mask_wb_group,
+ mask_wb_write
);
modport regs
@@ -66,11 +82,25 @@ interface gfx_regfile_io;
input op,
sgpr_write,
vgpr_write,
+
+ pc_back_group,
pc_front_group,
+ mask_back_group,
+
+ pc_wb,
+ pc_wb_group,
+ pc_wb_write,
+
+ mask_wb,
+ mask_wb_group,
+ mask_wb_write,
output a,
b,
- pc_front
+
+ pc_back,
+ pc_front,
+ mask_back
);
endinterface
diff --git a/platform/wavelet3d/gfx_shader.sv b/platform/wavelet3d/gfx_shader.sv
index f8432c9..322ffb5 100644
--- a/platform/wavelet3d/gfx_shader.sv
+++ b/platform/wavelet3d/gfx_shader.sv
@@ -10,7 +10,7 @@ import gfx_shader_schedif_pkg::*;
gfx_axil.s sched
);
- axi4lite_intf #(.ADDR_WIDTH(4)) regblock();
+ axi4lite_intf #(.ADDR_WIDTH(GFX_SHADER_SCHEDIF_MIN_ADDR_WIDTH)) regblock();
gfx_axil2regblock axil2regblock
(
@@ -23,6 +23,20 @@ import gfx_shader_schedif_pkg::*;
gfx_front_back front_back();
gfx_regfile_io regfile();
+ gfx_shader_setup setup();
+
+ assign schedif_in.SETUP_CTRL.GPR_DONE.hwset = setup.sched.set_done.gpr;
+ assign schedif_in.SETUP_CTRL.MASK_DONE.hwset = setup.sched.set_done.mask;
+ assign schedif_in.SETUP_CTRL.SUBMIT_DONE.hwset = setup.sched.set_done.submit;
+
+ assign setup.sched.write.pc = schedif_out.SETUP_SUBMIT.PC.value;
+ assign setup.sched.write.gpr = schedif_out.SETUP_CTRL.XGPR.value;
+ assign setup.sched.write.mask = schedif_out.SETUP_MASK.MASK.value;
+ assign setup.sched.write.group = schedif_out.SETUP_CTRL.GROUP.value;
+ assign setup.sched.write.pc_set = schedif_out.SETUP_SUBMIT.PC.swmod;
+ assign setup.sched.write.gpr_set = schedif_out.SETUP_GPR.VALUE.swmod;
+ assign setup.sched.write.mask_set = schedif_out.SETUP_MASK.MASK.swmod;
+ assign setup.sched.write.gpr_value = schedif_out.SETUP_GPR.VALUE.value;
gfx_shader_front frontend
(
@@ -40,6 +54,7 @@ import gfx_shader_schedif_pkg::*;
.clk,
.rst_n,
.back(front_back.back),
+ .setup(setup.core),
.reg_wb(regfile.wb),
.read_data(regfile.ab)
);
@@ -47,7 +62,7 @@ import gfx_shader_schedif_pkg::*;
gfx_shader_regs regs
(
.clk,
- .io(regfile)
+ .io(regfile.regs)
);
gfx_shader_schedif schedif
diff --git a/platform/wavelet3d/gfx_shader_back.sv b/platform/wavelet3d/gfx_shader_back.sv
index c7b2855..4929192 100644
--- a/platform/wavelet3d/gfx_shader_back.sv
+++ b/platform/wavelet3d/gfx_shader_back.sv
@@ -1,13 +1,15 @@
module gfx_shader_back
import gfx::*;
(
- input logic clk,
- rst_n,
+ input logic clk,
+ rst_n,
- gfx_front_back.back back,
+ gfx_front_back.back back,
- gfx_regfile_io.ab read_data,
- gfx_regfile_io.wb reg_wb
+ gfx_regfile_io.ab read_data,
+ gfx_regfile_io.wb reg_wb,
+
+ gfx_shader_setup.core setup
);
logic abort;
@@ -30,6 +32,7 @@ import gfx::*;
.rst_n,
.op(back.execute.p0),
.wb(p0_wb.tx),
+ .wave(back.execute.wave),
.abort,
.read_data,
.in_valid(back.dispatch.valid)
@@ -41,6 +44,7 @@ import gfx::*;
.rst_n,
.op(back.execute.p1),
.wb(p1_wb.tx),
+ .wave(back.execute.wave),
.in_shake(p1_shake.rx),
.read_data
);
@@ -51,6 +55,7 @@ import gfx::*;
.rst_n,
.op(back.execute.p2),
.wb(p2_wb.tx),
+ .wave(back.execute.wave),
.in_shake(p2_shake.rx),
.read_data
);
@@ -61,6 +66,7 @@ import gfx::*;
.rst_n,
.op(back.execute.p3),
.wb(p3_wb.tx),
+ .wave(back.execute.wave),
.in_shake(p3_shake.rx),
.read_data
);
@@ -81,7 +87,10 @@ import gfx::*;
.clk,
.rst_n,
.wb(out_wb.rx),
- .regs(reg_wb)
+ .regs(reg_wb),
+ .setup,
+ .loop_group(back.loop.group),
+ .loop_valid(back.loop.valid)
);
endmodule
@@ -170,6 +179,7 @@ module gfx_shader_writeback_arbiter2_prio
//TODO
assign a.ready = out.ready;
assign b.ready = 0;
+
assign out.dest = a.dest;
assign out.lanes = a.lanes;
assign out.group = a.group;
@@ -177,18 +187,149 @@ module gfx_shader_writeback_arbiter2_prio
assign out.scalar = a.scalar;
assign out.writeback = a.writeback;
+ assign out.mask = a.mask;
+ assign out.mask_update = a.mask_update;
+
+ assign out.pc_add = a.pc_add;
+ assign out.pc_inc = a.pc_inc;
+ assign out.pc_update = a.pc_update;
+
endmodule
module gfx_shader_writeback
+import gfx::*;
(
- input logic clk,
- rst_n,
+ input logic clk,
+ rst_n,
+
+ gfx_wb.rx wb,
+
+ gfx_regfile_io.wb regs,
- gfx_wb.rx wb,
+ output logic loop_valid,
+ output group_id loop_group,
- gfx_regfile_io.wb regs
+ gfx_shader_setup.core setup
);
-
+ struct
+ {
+ group_id group;
+ word lanes[SHADER_LANES];
+ pc_offset pc_add;
+ lane_mask mask;
+ vgpr_num vgpr;
+ logic pc_update,
+ mask_update,
+ vgpr_update;
+ } loop_hold[REGFILE_STAGES], loop_out;
+
+ logic loop_valid_hold[REGFILE_STAGES], loop_out_valid, mask_wb, scalar_wb,
+ setup_gpr, setup_mask, setup_submit;
+
+ assign wb.ready = 1;
+
+ assign loop_out = loop_hold[REGFILE_STAGES - 1];
+ assign loop_out_valid = loop_valid_hold[REGFILE_STAGES - 1];
+
+ assign loop_valid = loop_out_valid | setup_submit;
+
+ assign regs.pc_back_group = wb.group;
+ assign regs.mask_back_group = wb.group;
+
+ assign regs.pc_wb_write = (loop_out_valid & loop_out.pc_update) | setup_submit;
+ assign regs.mask_wb_write = mask_wb | setup_mask;
+ assign regs.sgpr_write.write = scalar_wb | setup_gpr;
+
+ assign regs.vgpr_write.vgpr = loop_out.vgpr;
+ assign regs.vgpr_write.group = loop_out.group;
+
+ assign mask_wb = loop_out_valid & loop_out.mask_update;
+ assign scalar_wb = wb.valid & wb.writeback & wb.scalar;
+
+ always_comb begin
+ loop_group = setup.write.group;
+ regs.pc_wb = setup.write.pc;
+ regs.pc_wb_group = setup.write.group;
+
+ if (loop_out_valid) begin
+ loop_group = loop_out.group;
+ regs.pc_wb = regs.pc_back + word_ptr'(loop_out.pc_add);
+ regs.pc_wb_group = loop_out.group;
+ end
+
+ regs.mask_wb = setup.write.mask;
+ regs.mask_wb_group = setup.write.group;
+
+ if (mask_wb) begin
+ regs.mask_wb = loop_out.mask;
+ regs.mask_wb_group = loop_out.group;
+ end
+
+ regs.sgpr_write.data = setup.write.gpr_value;
+ regs.sgpr_write.sgpr = setup.write.gpr.sgpr;
+ regs.sgpr_write.group = setup.write.group;
+
+ if (scalar_wb) begin
+ regs.sgpr_write.data = wb.lanes[0];
+ regs.sgpr_write.sgpr = wb.dest.sgpr;
+ regs.sgpr_write.group = wb.group;
+ end
+
+ for (int i = 0; i < SHADER_LANES; ++i)
+ regs.vgpr_write.data[i] = loop_out.lanes[i];
+
+ regs.vgpr_write.mask = regs.mask_back;
+ if (~loop_out_valid | ~loop_out.vgpr_update)
+ regs.vgpr_write.mask = '0;
+ end
+
+ always_ff @(posedge clk) begin
+ // Blocking assignments por bug de verilator (ver for de lanes abajo)
+
+ for (int i = REGFILE_STAGES - 1; i > 0; --i)
+ loop_hold[i] = loop_hold[i - 1];
+
+ loop_hold[0].mask = wb.mask;
+ loop_hold[0].vgpr = wb.dest.vgpr.num;
+ loop_hold[0].group = wb.group;
+ loop_hold[0].pc_add = wb.pc_add;
+ loop_hold[0].pc_update = wb.pc_update;
+ loop_hold[0].mask_update = wb.mask_update;
+ loop_hold[0].vgpr_update = wb.writeback & ~wb.scalar;
+
+ // https://github.com/verilator/verilator/issues/4804
+ for (int i = 0; i < SHADER_LANES; ++i)
+ loop_hold[0].lanes[i] = wb.lanes[i];
+
+ if (wb.pc_inc)
+ loop_hold[0].pc_add = pc_offset'(1);
+ end
+
+ always_ff @(posedge clk or negedge rst_n)
+ if (~rst_n) begin
+ setup_gpr <= 0;
+ setup_mask <= 0;
+ setup_submit <= 0;
+
+ setup.set_done.gpr <= 0;
+ setup.set_done.mask <= 0;
+ setup.set_done.submit <= 0;
+
+ for (int i = 0; i < $size(loop_valid_hold); ++i)
+ loop_valid_hold[i] <= 0;
+ end else begin
+ setup_gpr <= (setup_gpr & scalar_wb) | setup.write.gpr_set;
+ setup_mask <= (setup_mask & mask_wb) | setup.write.mask_set;
+ setup_submit <= (setup_submit & loop_out_valid) | setup.write.pc_set;
+
+ setup.set_done.gpr <= setup_gpr & ~scalar_wb;
+ setup.set_done.mask <= setup_mask & ~mask_wb;
+ setup.set_done.submit <= setup_submit & ~loop_out_valid;
+
+ loop_valid_hold[0] <= wb.valid;
+ for (int i = 1; i < REGFILE_STAGES; ++i)
+ loop_valid_hold[i] <= loop_valid_hold[i - 1];
+ end
endmodule
diff --git a/platform/wavelet3d/gfx_shader_fpint.sv b/platform/wavelet3d/gfx_shader_fpint.sv
index 392a8b5..a418dcc 100644
--- a/platform/wavelet3d/gfx_shader_fpint.sv
+++ b/platform/wavelet3d/gfx_shader_fpint.sv
@@ -154,6 +154,7 @@ import gfx::*;
rst_n,
input fpint_op op,
+ input wave_exec wave,
input logic abort,
in_valid,
@@ -164,10 +165,25 @@ import gfx::*;
localparam int FPINT_STAGES = 7 + FPINT_CLZ_STAGES + 4;
+ struct
+ {
+ fpint_op op;
+ wave_exec wave;
+ } stage[FPINT_STAGES];
+
logic stage_valid[FPINT_STAGES];
- fpint_op stage_op[FPINT_STAGES];
- assign stage_op[0] = op;
+ assign wb.dest = stage[FPINT_STAGES - 1].wave.dest;
+ assign wb.mask = 'x;
+ assign wb.group = stage[FPINT_STAGES - 1].wave.group;
+ assign wb.pc_add = 'x;
+ assign wb.pc_inc = 1;
+ assign wb.scalar = stage[FPINT_STAGES - 1].wave.dest_scalar;
+ assign wb.pc_update = wb.writeback;
+ assign wb.writeback = stage[FPINT_STAGES - 1].op.writeback;
+ assign wb.mask_update = 0;
+
+ // Ojo: stage_valid[0], pero stage[0] no
assign stage_valid[0] = in_valid;
genvar lane;
@@ -179,32 +195,36 @@ import gfx::*;
.a(read_data.a[lane]),
.b(read_data.b[lane]),
.q(wb.lanes[lane]),
- .mul_float_0(stage_op[0].setup_mul_float),
- .unit_b_0(stage_op[0].setup_unit_b),
- .put_hi_2(stage_op[2].mnorm_put_hi),
- .put_lo_2(stage_op[2].mnorm_put_lo),
- .put_mul_2(stage_op[2].mnorm_put_mul),
- .zero_b_2(stage_op[2].mnorm_zero_b),
- .zero_flags_2(stage_op[2].mnorm_zero_flags),
- .abs_3(stage_op[3].minmax_abs),
- .swap_3(stage_op[3].minmax_swap),
- .zero_min_3(stage_op[3].minmax_zero_min),
- .copy_flags_3(stage_op[3].minmax_copy_flags),
- .int_signed_5(stage_op[5].shiftr_int_signed),
- .copy_flags_6(stage_op[6].addsub_copy_flags),
- .int_operand_6(stage_op[6].addsub_int_operand),
- .force_nop_7(stage_op[7].clz_force_nop),
- .copy_flags_11(stage_op[11].shiftl_copy_flags),
- .copy_flags_12(stage_op[12].round_copy_flags),
- .enable_12(stage_op[12].round_enable),
- .enable_14(stage_op[14].encode_enable)
+ .mul_float_0(op.setup_mul_float),
+ .unit_b_0(op.setup_unit_b),
+ .put_hi_2(stage[2 - 1].op.mnorm_put_hi),
+ .put_lo_2(stage[2 - 1].op.mnorm_put_lo),
+ .put_mul_2(stage[2 - 1].op.mnorm_put_mul),
+ .zero_b_2(stage[2 - 1].op.mnorm_zero_b),
+ .zero_flags_2(stage[2 - 1].op.mnorm_zero_flags),
+ .abs_3(stage[3 - 1].op.minmax_abs),
+ .swap_3(stage[3 - 1].op.minmax_swap),
+ .zero_min_3(stage[3 - 1].op.minmax_zero_min),
+ .copy_flags_3(stage[3 - 1].op.minmax_copy_flags),
+ .int_signed_5(stage[5 - 1].op.shiftr_int_signed),
+ .copy_flags_6(stage[6 - 1].op.addsub_copy_flags),
+ .int_operand_6(stage[6 - 1].op.addsub_int_operand),
+ .force_nop_7(stage[7 - 1].op.clz_force_nop),
+ .copy_flags_11(stage[11 - 1].op.shiftl_copy_flags),
+ .copy_flags_12(stage[12 - 1].op.round_copy_flags),
+ .enable_12(stage[12 - 1].op.round_enable),
+ .enable_14(stage[14 - 1].op.encode_enable)
);
end
endgenerate
- always_ff @(posedge clk)
+ always_ff @(posedge clk) begin
+ stage[0].op <= op;
+ stage[0].wave <= wave;
+
for (int i = 1; i < FPINT_STAGES; ++i)
- stage_op[i] <= stage_op[i - 1];
+ stage[i] <= stage[i - 1];
+ end
always_ff @(posedge clk or negedge rst_n)
if (~rst_n) begin
@@ -217,8 +237,7 @@ import gfx::*;
stage_valid[i] <= stage_valid[i - 1];
// Se levanta 1 ciclo luego que in_valid
- if (abort)
- stage_valid[2] <= 0;
+ stage_valid[2] <= stage_valid[1] & ~abort;
wb.valid <= stage_valid[FPINT_STAGES - 1];
end
diff --git a/platform/wavelet3d/gfx_shader_front.sv b/platform/wavelet3d/gfx_shader_front.sv
index 5ad0203..52074fd 100644
--- a/platform/wavelet3d/gfx_shader_front.sv
+++ b/platform/wavelet3d/gfx_shader_front.sv
@@ -4,7 +4,13 @@ typedef struct
retry;
gfx::group_id group;
gfx_isa::insn_word insn;
-} shader_front_wave;
+} front_wave;
+
+typedef struct
+{
+ gfx::xgpr_num dest;
+ logic dest_scalar;
+} front_reg_passthru;
typedef logic[4:0] icache_line_num;
@@ -40,7 +46,11 @@ import gfx::*;
word fetch_insn, port_insn;
logic fetch_hit, p0_writeback;
- shader_front_wave bind_wave, dec_wave, port_dec_wave;
+ front_wave bind_wave, dec_wave, port_dec_wave;
+ front_reg_passthru reg_passthru;
+
+ assign front.execute.wave.dest = reg_passthru.dest;
+ assign front.execute.wave.dest_scalar = reg_passthru.dest_scalar;
gfx_shader_bind bind_
(
@@ -60,7 +70,8 @@ import gfx::*;
.rst_n,
.in(bind_wave),
.out(dec_wave),
- .read(reg_read)
+ .read(reg_read),
+ .passthru(reg_passthru)
);
gfx_shader_decode_class class_dec
@@ -68,7 +79,7 @@ import gfx::*;
.clk,
.rst_n,
.wave(dec_wave),
- .out_group(front.execute.group),
+ .out_group(front.execute.wave.group),
.port_wave(port_dec_wave),
.dispatch(front.dispatch),
.p0_writeback
@@ -99,7 +110,7 @@ import gfx::*;
gfx_regfile_io.bind_ regs,
- output shader_front_wave wave
+ output front_wave wave
);
localparam int ICACHE_STAGES = 6;
@@ -419,31 +430,48 @@ import gfx_isa::*;
input logic clk,
rst_n,
- input shader_front_wave in,
+ input front_wave in,
gfx_regfile_io.read read,
- output shader_front_wave out
+ output front_wave out,
+ output front_reg_passthru passthru
);
- localparam int HOLD_DEPTH = REG_READ_STAGES + 1 - 2;
+ // + 1 por next-cycle de read.op
+ localparam int PASSTHRU_DEPTH = REG_READ_STAGES + 1 - 2;
+ localparam int HOLD_DEPTH = PASSTHRU_DEPTH - 2;
logic reg_rev;
- logic hold_valid[HOLD_DEPTH];
- shader_front_wave hold[HOLD_DEPTH];
+ logic valid[HOLD_DEPTH];
+ front_wave out_hold[HOLD_DEPTH];
+ front_reg_passthru passthru_hold[PASSTHRU_DEPTH];
+
+ assign passthru = passthru_hold[$size(passthru_hold) - 1];
assign reg_rev = in.insn.reg_rev;
always_comb begin
- out = hold[$size(hold) - 1];
- out.valid = hold_valid[$size(hold_valid) - 1];
+ out = out_hold[$size(out_hold) - 1];
+ out.valid = valid[$size(valid) - 1];
end
always_ff @(posedge clk) begin
- hold[0] <= in;
+ out_hold[0] <= in;
+ for (int i = 1; i < $size(out_hold); ++i)
+ out_hold[i] <= out_hold[i - 1];
+
+ passthru_hold[0].dest <= in.insn.dst_src.rr.rd;
+ unique case (in.insn.reg_mode)
+ REGS_SVS, REGS_SSS:
+ passthru_hold[0].dest_scalar <= 1;
+
+ REGS_VVS, REGS_VVV:
+ passthru_hold[0].dest_scalar <= 0;
+ endcase
- for (int i = 1; i < HOLD_DEPTH; ++i)
- hold[i] <= hold[i - 1];
+ for (int i = 1; i < $size(passthru_hold); ++i)
+ passthru_hold[i] <= passthru_hold[i - 1];
read.op.group <= in.group;
@@ -476,13 +504,13 @@ import gfx_isa::*;
always_ff @(posedge clk or negedge rst_n)
if (~rst_n)
- for (int i = 1; i < HOLD_DEPTH; ++i)
- hold_valid[i] <= 0;
+ for (int i = 0; i < HOLD_DEPTH; ++i)
+ valid[i] <= 0;
else begin
- hold_valid[0] <= in.valid;
+ valid[0] <= in.valid;
for (int i = 1; i < HOLD_DEPTH; ++i)
- hold_valid[i] <= hold_valid[i - 1];
+ valid[i] <= valid[i - 1];
end
endmodule
@@ -491,19 +519,19 @@ module gfx_shader_decode_class
import gfx::*;
import gfx_isa::*;
(
- input logic clk,
- rst_n,
+ input logic clk,
+ rst_n,
- input shader_front_wave wave,
- output shader_front_wave port_wave,
- output group_id out_group,
+ input front_wave wave,
+ output front_wave port_wave,
+ output group_id out_group,
- output shader_dispatch dispatch,
- output logic p0_writeback
+ output shader_dispatch dispatch,
+ output logic p0_writeback
);
logic is_fsu, is_mem, is_group, hold_valid, retry;
- shader_front_wave hold_wave;
+ front_wave hold_wave;
assign p0_writeback = ~(is_mem | is_fsu | is_group | retry);
diff --git a/platform/wavelet3d/gfx_shader_group.sv b/platform/wavelet3d/gfx_shader_group.sv
index 7659bb9..e668877 100644
--- a/platform/wavelet3d/gfx_shader_group.sv
+++ b/platform/wavelet3d/gfx_shader_group.sv
@@ -5,6 +5,7 @@ import gfx::*;
rst_n,
input group_op op,
+ input wave_exec wave,
gfx_regfile_io.ab read_data,
diff --git a/platform/wavelet3d/gfx_shader_mem.sv b/platform/wavelet3d/gfx_shader_mem.sv
index 97561fb..403c9e4 100644
--- a/platform/wavelet3d/gfx_shader_mem.sv
+++ b/platform/wavelet3d/gfx_shader_mem.sv
@@ -5,6 +5,7 @@ import gfx::*;
rst_n,
input mem_op op,
+ input wave_exec wave,
gfx_regfile_io.ab read_data,
diff --git a/platform/wavelet3d/gfx_shader_regs.sv b/platform/wavelet3d/gfx_shader_regs.sv
index 7ae2e14..ef3a129 100644
--- a/platform/wavelet3d/gfx_shader_regs.sv
+++ b/platform/wavelet3d/gfx_shader_regs.sv
@@ -8,10 +8,17 @@ import gfx::*;
// verilator tracing_off
+ localparam PC_TABLE_PORTS = 2;
+ localparam MASK_TABLE_PORTS = 1;
+
word hold_imm[REGFILE_STAGES], imm_out, read_a_data_sgpr, read_b_data_scalar,
read_b_data_sgpr, read_const, read_a_data_vgpr[SHADER_LANES],
read_b_data_vgpr[SHADER_LANES], sgpr_out_a, sgpr_out_b;
+ group_id mask_read_groups[MASK_TABLE_PORTS], pc_read_groups[PC_TABLE_PORTS];
+ word_ptr pc_read[PC_TABLE_PORTS];
+ lane_mask mask_read[MASK_TABLE_PORTS];
+
logic a_scalar_out, b_is_const_out, b_is_imm_out, b_scalar_out, scalar_rev_out;
group_id hold_read_group_1, hold_read_group_2;
sgpr_num hold_read_a_sgpr;
@@ -20,6 +27,14 @@ import gfx::*;
logic[REGFILE_STAGES + 1 - 1:0] hold_scalar_rev;
logic[REGFILE_STAGES + 2 - 1:0] hold_a_scalar, hold_b_scalar;
+ assign io.pc_back = pc_read[0];
+ assign io.pc_front = pc_read[1];
+ assign pc_read_groups[0] = io.pc_back_group;
+ assign pc_read_groups[1] = io.pc_front_group;
+
+ assign io.mask_back = mask_read[0];
+ assign pc_read_groups[0] = io.mask_back_group;
+
assign imm_out = hold_imm[$size(hold_imm) - 1];
assign a_scalar_out = hold_a_scalar[$bits(hold_a_scalar) - 1];
assign b_scalar_out = hold_b_scalar[$bits(hold_b_scalar) - 1];
@@ -27,11 +42,24 @@ import gfx::*;
assign b_is_const_out = hold_b_is_const[$bits(hold_b_is_const) - 1];
assign scalar_rev_out = hold_scalar_rev[$bits(hold_scalar_rev) - 1];
- gfx_shader_pc_table pcs
+ gfx_shader_table #(.DATA_WIDTH($bits(word_ptr)), .READ_PORTS(PC_TABLE_PORTS)) pc_table
+ (
+ .clk,
+ .read(pc_read),
+ .write(io.pc_wb),
+ .read_groups(pc_read_groups),
+ .write_group(io.pc_wb_group),
+ .write_enable(io.pc_wb_write)
+ );
+
+ gfx_shader_table #(.DATA_WIDTH($bits(lane_mask)), .READ_PORTS(MASK_TABLE_PORTS)) mask_table
(
.clk,
- .read(io.pc_front),
- .read_group(io.pc_front_group)
+ .read(mask_read),
+ .write(io.mask_wb),
+ .read_groups(mask_read_groups),
+ .write_group(io.mask_wb_group),
+ .write_enable(io.mask_wb_write)
);
gfx_shader_consts consts
@@ -231,23 +259,44 @@ import gfx::*;
endmodule
-module gfx_shader_pc_table
+module gfx_shader_table
import gfx::*;
+#(int DATA_WIDTH = 0,
+ int READ_PORTS = 0)
(
- input logic clk,
+ input logic clk,
+
+ input group_id write_group,
+ read_groups[READ_PORTS],
- input group_id read_group,
+ input logic[DATA_WIDTH - 1:0] write,
+ input logic write_enable,
- output word_ptr read
+ output logic[DATA_WIDTH - 1:0] read[READ_PORTS]
);
- group_id read_group_hold;
- word_ptr pcs[1 << $bits(group_id)], read_hold;
+ genvar i;
- always_ff @(posedge clk) begin
- read <= read_hold;
- read_hold <= pcs[read_group_hold];
- read_group_hold <= read_group;
- end
+ generate
+ for (i = 0; i < READ_PORTS; ++i) begin: ports
+ logic write_enable_hold;
+ group_id read_group_hold, write_group_hold;
+ logic[DATA_WIDTH - 1:0] data[1 << $bits(group_id)], read_hold, write_hold;
+
+ always_ff @(posedge clk) begin
+ write_hold <= write;
+ read_group_hold <= read_groups[i];
+ write_group_hold <= write_group;
+ write_enable_hold <= write_enable;
+
+ read_hold <= data[read_group_hold];
+
+ if (write_enable_hold)
+ data[write_group_hold] <= write_hold;
+
+ read[i] <= read_hold;
+ end
+ end
+ endgenerate
endmodule
diff --git a/platform/wavelet3d/gfx_shader_schedif.rdl b/platform/wavelet3d/gfx_shader_schedif.rdl
index 2ab31ac..c846da9 100644
--- a/platform/wavelet3d/gfx_shader_schedif.rdl
+++ b/platform/wavelet3d/gfx_shader_schedif.rdl
@@ -13,12 +13,12 @@ addrmap gfx_shader_schedif {
singlepulse;
} IFLUSH[0:0] = 0;
- } CORE @ 0x0;
+ } CORE @ 0x00;
reg {
name = "Wavefront setup control register";
- default hw = w;
+ default hw = na;
default sw = r;
default precedence = hw;
@@ -49,7 +49,14 @@ addrmap gfx_shader_schedif {
rclr;
hwset;
} GPR_DONE[17:17] = 0;
- } SETUP_CTRL @ 0x4;
+
+ field {
+ desc = "Lane mask update done";
+
+ rclr;
+ hwset;
+ } MASK_DONE[18:18] = 0;
+ } SETUP_CTRL @ 0x04;
reg {
name = "SGPR/VGPR write register";
@@ -59,7 +66,17 @@ addrmap gfx_shader_schedif {
swmod;
} VALUE[31:0];
- } SETUP_GPR @ 0x8;
+ } SETUP_GPR @ 0x08;
+
+ reg {
+ name = "Lane mask write register";
+
+ field {
+ desc = "Mask value to write";
+
+ swmod;
+ } MASK[15:0];
+ } SETUP_MASK @ 0x0c;
reg {
name = "Group submit register";
@@ -69,6 +86,6 @@ addrmap gfx_shader_schedif {
swmod;
} PC[31:2];
- } SETUP_SUBMIT @ 0xc;
+ } SETUP_SUBMIT @ 0x10;
};
diff --git a/platform/wavelet3d/gfx_shader_setup.sv b/platform/wavelet3d/gfx_shader_setup.sv
new file mode 100644
index 0000000..f46fb66
--- /dev/null
+++ b/platform/wavelet3d/gfx_shader_setup.sv
@@ -0,0 +1,37 @@
+interface gfx_shader_setup
+import gfx::*;;
+
+ struct
+ {
+ group_id group;
+ word_ptr pc;
+ xgpr_num gpr;
+ word gpr_value;
+ lane_mask mask;
+ logic pc_set,
+ gpr_set,
+ mask_set;
+ } write;
+
+ struct
+ {
+ logic gpr,
+ mask,
+ submit;
+ } set_done;
+
+ modport core
+ (
+ input write,
+
+ output set_done
+ );
+
+ modport sched
+ (
+ input set_done,
+
+ output write
+ );
+
+endinterface
diff --git a/platform/wavelet3d/gfx_shader_sfu.sv b/platform/wavelet3d/gfx_shader_sfu.sv
index 614d5a1..d65e522 100644
--- a/platform/wavelet3d/gfx_shader_sfu.sv
+++ b/platform/wavelet3d/gfx_shader_sfu.sv
@@ -5,6 +5,7 @@ import gfx::*;
rst_n,
input sfu_op op,
+ input wave_exec wave,
gfx_regfile_io.ab read_data,
diff --git a/platform/wavelet3d/gfx_top.sv b/platform/wavelet3d/gfx_top.sv
index 5b8a0ce..41ff7f4 100644
--- a/platform/wavelet3d/gfx_top.sv
+++ b/platform/wavelet3d/gfx_top.sv
@@ -100,6 +100,7 @@ import gfx::*;
.rst_n,
.op,
.wb(fpint_wb.tx),
+ .wave(),
.abort(0),
.in_valid,
.read_data(fpint_io.ab)
diff --git a/platform/wavelet3d/gfx_wb.sv b/platform/wavelet3d/gfx_wb.sv
index cc25944..20c7c64 100644
--- a/platform/wavelet3d/gfx_wb.sv
+++ b/platform/wavelet3d/gfx_wb.sv
@@ -3,9 +3,11 @@ interface gfx_wb;
import gfx::*;
word lanes[SHADER_LANES];
- logic ready, scalar, valid, writeback;
+ logic mask_update, pc_inc, pc_update, ready, scalar, valid, writeback;
group_id group;
xgpr_num dest;
+ lane_mask mask;
+ pc_offset pc_add;
modport tx
(
@@ -16,7 +18,14 @@ interface gfx_wb;
lanes,
valid,
scalar,
- writeback
+ writeback,
+
+ mask,
+ mask_update,
+
+ pc_add,
+ pc_inc,
+ pc_update
);
modport rx
@@ -28,6 +37,13 @@ interface gfx_wb;
scalar,
writeback,
+ mask,
+ mask_update,
+
+ pc_add,
+ pc_inc,
+ pc_update,
+
output ready
);