diff options
Diffstat (limited to 'platform')
| -rw-r--r-- | platform/wavelet3d/gfx_float_lane.sv | 53 | ||||
| -rw-r--r-- | platform/wavelet3d/gfx_fmul_lane.sv | 82 | ||||
| -rw-r--r-- | platform/wavelet3d/gfx_pkg.sv | 49 | ||||
| -rw-r--r-- | platform/wavelet3d/gfx_round_lane.sv | 62 | ||||
| -rw-r--r-- | platform/wavelet3d/main.cpp | 19 | ||||
| -rw-r--r-- | platform/wavelet3d/mod.mk | 7 |
6 files changed, 270 insertions, 2 deletions
diff --git a/platform/wavelet3d/gfx_float_lane.sv b/platform/wavelet3d/gfx_float_lane.sv new file mode 100644 index 0000000..4d214f6 --- /dev/null +++ b/platform/wavelet3d/gfx_float_lane.sv @@ -0,0 +1,53 @@ +module gfx_float_lane +( + input logic clk, + + input gfx::float a, + b, + + output gfx::float q +); + + import gfx::*; + + logic slow_fmul; + float_round q_fmul; + float_special a_special, b_special; + + function float_special front_flags(float in); + front_flags.val = in; + front_flags.exp_max = &in.exp; + front_flags.exp_min = ~|in.exp; + front_flags.mant_zero = ~|in.mant; + endfunction + + function logic is_special(float_special in); + is_special = in.exp_max | (in.exp_min & ~in.mant_zero); + endfunction + + gfx_fmul_lane fmul + ( + .clk(clk), + .a(a_special), + .b(b_special), + .q(q_fmul), + .slow_in(slow_fmul) + ); + + gfx_round_lane round + ( + .clk(clk), + .in(q_fmul), + .out(q) + ); + + always_comb begin + slow_fmul = is_special(a_special) | is_special(b_special); + end + + always_ff @(posedge clk) begin + a_special <= front_flags(a); + b_special <= front_flags(b); + end + +endmodule diff --git a/platform/wavelet3d/gfx_fmul_lane.sv b/platform/wavelet3d/gfx_fmul_lane.sv new file mode 100644 index 0000000..17a988e --- /dev/null +++ b/platform/wavelet3d/gfx_fmul_lane.sv @@ -0,0 +1,82 @@ +module gfx_fmul_lane +( + input logic clk, + + input gfx::float_special a, + b, + input logic slow_in, + + output gfx::float_round q +); + + import gfx::*; + + /* Queremos calcular q = a * b. + * + * Donde a = (-1)^s * 1.m * 2^f, + * b = (-1)^t * 1.n * 2^g + * + * Entonces q = (-1)^(s + t) (1.m * 1.n) 2^(f + g) + * + * El producto es entre números >= 1.0 y < 2.0. En el peor caso: + * Mejor caso: 1.000... * 1.000... ~ 1.000... + * Peor caso: 1.999... * 1.999... ~ 3.999... = 2^1 * 1.999 + * + * Así que, si el producto es >= 2, hay que hacerle >> 1 a la mantisa + * y sumarle 1 al exponente para normalizar. + */ + + logic guard, lo_msb, lo_reduce, overflow_0, overflow_1, + round, sign, slow_0, slow_1, zero; + + float_exp exp; + float_round out; + float_mant_full hi; + logic[$bits(float_mant_full) - 3:0] lo; + + assign lo_msb = lo[$bits(lo) - 1]; + assign lo_reduce = |lo[$bits(lo) - 2:0]; + + always_comb begin + q = out; + q.slow = slow_1 | overflow_1; + end + + always_ff @(posedge clk) begin + // Stage 0: producto + + sign <= a.val.sign ^ b.val.sign; + zero <= a.exp_min | b.exp_min; + slow_0 <= slow_in; + + {overflow_0, exp} <= {1'b0, a.val.exp} + {1'b0, b.val.exp} - {1'b0, FLOAT_EXP_BIAS}; + {hi, guard, round, lo} <= full_mant(a.val.mant) * full_mant(b.val.mant); + + // Stage 1: normalización + + slow_1 <= slow_0 | overflow_0; + overflow_1 <= 0; + + out.slow <= 1'bx; // Ver 'q' + out.zero <= zero; + out.normal.sign <= sign; + + if (hi[$bits(hi) - 1]) begin + out.guard <= guard; + out.round <= round; + out.sticky <= lo_msb | lo_reduce; + out.normal.mant <= implicit_mant(hi); + {overflow_1, out.normal.exp} <= {1'b0, exp} + 1; + end else begin + /* Bit antes de msb es necesariamente 1, ya que los msb de + * ambos multiplicandos son 1. Ver assert en implicit_mant(). + */ + out.guard <= round; + out.round <= lo[$bits(lo) - 1]; + out.sticky <= lo_reduce; + out.normal.exp <= exp; + out.normal.mant <= implicit_mant({hi[$bits(hi) - 2:0], guard}); + end + end + +endmodule diff --git a/platform/wavelet3d/gfx_pkg.sv b/platform/wavelet3d/gfx_pkg.sv new file mode 100644 index 0000000..27c1117 --- /dev/null +++ b/platform/wavelet3d/gfx_pkg.sv @@ -0,0 +1,49 @@ +package gfx; + + typedef logic[31:0] float_word; + typedef logic[7:0] float_exp; + + typedef logic[$bits(float_word) - $bits(float_exp) - 2:0] float_mant; + typedef logic[$bits(float_mant):0] float_mant_full; // Incluye '1.' explícito + + localparam float_exp FLOAT_EXP_BIAS = (1 << ($bits(float_exp) - 1)) - 1; + localparam float_exp FLOAT_EXP_MAX = {($bits(float_exp)){1'b1}}; + + function float_mant_full full_mant(float_mant in); + full_mant = {1'b1, in}; + endfunction + + function float_mant implicit_mant(float_mant_full in); + assert (in[$bits(in) - 1]); + implicit_mant = in[$bits(in) - 2:0]; + endfunction + + typedef struct packed + { + logic sign; + float_exp exp; + float_mant mant; + } float; + + /* Explicación de guard, round, sticky: + * https://drilian.com/2023/01/10/floating-point-numbers-and-rounding/ + */ + typedef struct packed + { + float normal; + logic slow, + zero, + guard, + round, + sticky; + } float_round; + + typedef struct packed + { + float val; + logic exp_max, + exp_min, + mant_zero; + } float_special; + +endpackage diff --git a/platform/wavelet3d/gfx_round_lane.sv b/platform/wavelet3d/gfx_round_lane.sv new file mode 100644 index 0000000..d0b0b03 --- /dev/null +++ b/platform/wavelet3d/gfx_round_lane.sv @@ -0,0 +1,62 @@ +module gfx_round_lane +( + input logic clk, + + input gfx::float_round in, + + output gfx::float out +); + + import gfx::*; + + logic exp_step, overflow, sign_0, sign_1, slow_0, slow_1, + slow_out, zero_0, zero_1; + + float_exp exp_0, exp_1; + float_mant mant_0, mant_1; + + assign slow_out = slow_1 || overflow || &exp_1; + + always_ff @(posedge clk) begin + // Stage 0: redondeo + + exp_0 <= in.normal.exp; + sign_0 <= in.normal.sign; + slow_0 <= in.slow; + zero_0 <= in.zero; + exp_step <= 0; + + // Este es el modo más común: round to nearest, ties to even + if (in.guard & (in.round | in.sticky | in.normal.mant[0])) + {exp_step, mant_0} <= {1'b0, in.normal.mant} + 1; + else + mant_0 <= in.normal.mant; + + sign_1 <= sign_0; + slow_1 <= slow_0; + zero_1 <= zero_0; + mant_1 <= mant_0; + overflow <= 0; + + if (exp_step) + {overflow, exp_1} <= {1'b0, exp_0} + 1; + else + exp_1 <= exp_0; + + // Stage 1: ceros y slow path + + out.sign <= sign_1; + + if (slow_out) begin + out.exp <= FLOAT_EXP_MAX; + out.mant <= 1; + end else if (zero_1) begin + out.exp <= 0; + out.mant <= 0; + end else begin + out.exp <= exp_1; + out.mant <= mant_1; + end + end + +endmodule diff --git a/platform/wavelet3d/main.cpp b/platform/wavelet3d/main.cpp index ce632b6..1243dba 100644 --- a/platform/wavelet3d/main.cpp +++ b/platform/wavelet3d/main.cpp @@ -1,3 +1,4 @@ +#include <iostream> #include <cstddef> #include <cstdio> #include <cstdlib> @@ -21,6 +22,24 @@ int main(int argc, char **argv) #endif Py_Initialize(); + + float a, b; + std::cin >> a >> b; + + top.a = *reinterpret_cast<unsigned*>(&a); + top.b = *reinterpret_cast<unsigned*>(&b); + + for (int i = 0; i < 1000; ++i) { + top.clk = 0; + top.eval(); + + top.clk = 1; + top.eval(); + } + + unsigned q = top.q; + std::cout << a << " * " << b << " = " << *reinterpret_cast<float*>(&q) << '\n'; + bool failed = Py_FinalizeEx() < 0; #if VM_TRACE diff --git a/platform/wavelet3d/mod.mk b/platform/wavelet3d/mod.mk index ab54441..a12392e 100644 --- a/platform/wavelet3d/mod.mk +++ b/platform/wavelet3d/mod.mk @@ -1,6 +1,9 @@ define core - $(this)/deps := picorv32 - $(this)/rtl_top := picorv32 + $(this)/deps := dma_axi32 picorv32 + + $(this)/rtl_top := gfx_float_lane + $(this)/rtl_files := gfx_pkg.sv gfx_float_lane.sv gfx_fmul_lane.sv gfx_round_lane.sv + $(this)/vl_main := main.cpp $(this)/vl_pkgconfig := python3-embed endef |
