summary refs log tree commit diff
path: root/platform
diff options
context:
space:
mode:
author Alejandro Soto <alejandro@34project.org> 2024-03-03 20:39:39 -0600
committer Alejandro Soto <alejandro@34project.org> 2024-03-03 20:42:36 -0600
commit 2b7b2185d381f2c5fd4aee19bd3a3508b4c9557f (patch)
tree acd92b7a9de4678a08cf7e55f99b27fc23ff8938 /platform
parent cce507d21c86f20a83eec1b09fe3231399ffb10c (diff)
platform/wavelet3d: implement rounded fmul
Diffstat (limited to 'platform')
-rw-r--r-- platform/wavelet3d/gfx_float_lane.sv 53
-rw-r--r-- platform/wavelet3d/gfx_fmul_lane.sv 82
-rw-r--r-- platform/wavelet3d/gfx_pkg.sv 49
-rw-r--r-- platform/wavelet3d/gfx_round_lane.sv 62
-rw-r--r-- platform/wavelet3d/main.cpp 19
-rw-r--r-- platform/wavelet3d/mod.mk 7
6 files changed, 270 insertions, 2 deletions
diff --git a/platform/wavelet3d/gfx_float_lane.sv b/platform/wavelet3d/gfx_float_lane.sv
new file mode 100644
index 0000000..4d214f6
--- /dev/null
+++ b/platform/wavelet3d/gfx_float_lane.sv
@@ -0,0 +1,53 @@
+module gfx_float_lane // Single float-multiply lane: register operands with flags -> gfx_fmul_lane -> gfx_round_lane
+(
+ input logic clk,
+
+ input gfx::float a, // Operands of the product q = a * b
+ b,
+
+ output gfx::float q // Registered output of the rounding lane
+);
+
+ import gfx::*;
+
+ logic slow_fmul; // Set when either registered operand is "special" (see is_special)
+ float_round q_fmul; // Multiplier result with guard/round/sticky bits, before rounding
+ float_special a_special, b_special; // Registered operands together with their decoded flags
+
+ function float_special front_flags(float in); // Wraps 'in' with exponent/mantissa classification flags
+ front_flags.val = in;
+ front_flags.exp_max = &in.exp; // Exponent field is all ones
+ front_flags.exp_min = ~|in.exp; // Exponent field is all zeros
+ front_flags.mant_zero = ~|in.mant; // Mantissa field is all zeros
+ endfunction
+
+ function logic is_special(float_special in); // True for an all-ones exponent, or an all-zeros exponent with non-zero mantissa
+ is_special = in.exp_max | (in.exp_min & ~in.mant_zero); // A true zero (exp_min with mant_zero) is NOT special
+ endfunction
+
+ gfx_fmul_lane fmul // Multiplies the registered operands
+ (
+ .clk(clk),
+ .a(a_special),
+ .b(b_special),
+ .q(q_fmul),
+ .slow_in(slow_fmul)
+ );
+
+ gfx_round_lane round // Rounds the multiplier result down to a gfx::float
+ (
+ .clk(clk),
+ .in(q_fmul),
+ .out(q)
+ );
+
+ always_comb begin
+ slow_fmul = is_special(a_special) | is_special(b_special); // Derived from the registered operands, so it lines up with them
+ end
+
+ always_ff @(posedge clk) begin
+ a_special <= front_flags(a); // Input register stage: operands are sampled here before entering fmul
+ b_special <= front_flags(b);
+ end
+
+endmodule
diff --git a/platform/wavelet3d/gfx_fmul_lane.sv b/platform/wavelet3d/gfx_fmul_lane.sv
new file mode 100644
index 0000000..17a988e
--- /dev/null
+++ b/platform/wavelet3d/gfx_fmul_lane.sv
@@ -0,0 +1,82 @@
+module gfx_fmul_lane // Two-register-stage multiplier: product, then normalization into a float_round
+(
+ input logic clk,
+
+ input gfx::float_special a, // Operands with pre-decoded flags
+ b,
+ input logic slow_in, // Caller-provided slow flag, carried along with the data
+
+ output gfx::float_round q // Normalized product plus guard/round/sticky and slow/zero flags
+);
+
+ import gfx::*;
+
+ /* We want to compute q = a * b.
+ *
+ * Where a = (-1)^s * 1.m * 2^f,
+ * b = (-1)^t * 1.n * 2^g
+ *
+ * Then q = (-1)^(s + t) (1.m * 1.n) 2^(f + g)
+ *
+ * The product is between numbers >= 1.0 and < 2.0, so its range is:
+ * Best case: 1.000... * 1.000... ~ 1.000...
+ * Worst case: 1.999... * 1.999... ~ 3.999... = 2^1 * 1.999
+ *
+ * So, if the product is >= 2, the mantissa must be shifted >> 1
+ * and 1 added to the exponent in order to normalize.
+ */
+
+ logic guard, lo_msb, lo_reduce, overflow_0, overflow_1,
+ round, sign, slow_0, slow_1, zero;
+
+ float_exp exp; // Stage-0 result exponent (biased), before normalization
+ float_round out; // Stage-1 registers; out.slow is a don't-care, see 'q'
+ float_mant_full hi; // Upper bits of the full-width mantissa product
+ logic[$bits(float_mant_full) - 3:0] lo; // Product bits below guard and round
+
+ assign lo_msb = lo[$bits(lo) - 1]; // Top bit of lo
+ assign lo_reduce = |lo[$bits(lo) - 2:0]; // OR of every bit of lo below its top bit
+
+ always_comb begin
+ q = out;
+ q.slow = slow_1 | overflow_1; // Replaces the registered don't-care out.slow with the real flag
+ end
+
+ always_ff @(posedge clk) begin
+ // Stage 0: product
+
+ sign <= a.val.sign ^ b.val.sign;
+ zero <= a.exp_min | b.exp_min; // Either operand has an all-zeros exponent
+ slow_0 <= slow_in;
+
+ {overflow_0, exp} <= {1'b0, a.val.exp} + {1'b0, b.val.exp} - {1'b0, FLOAT_EXP_BIAS}; // One extra top bit: overflow_0 is set when the result leaves the float_exp range either way. NOTE(review): a zero operand still goes through this sum, so overflow_0 can be set together with 'zero'; a result of exactly 0 is not flagged - confirm both are intended
+ {hi, guard, round, lo} <= full_mant(a.val.mant) * full_mant(b.val.mant); // Full-width product split into hi / guard / round / lo
+
+ // Stage 1: normalization
+
+ slow_1 <= slow_0 | overflow_0;
+ overflow_1 <= 0; // Default; overridden below when the exponent is incremented
+
+ out.slow <= 1'bx; // See 'q'
+ out.zero <= zero;
+ out.normal.sign <= sign;
+
+ if (hi[$bits(hi) - 1]) begin // Product >= 2.0: hi already holds the leading 1, so bump the exponent
+ out.guard <= guard;
+ out.round <= round;
+ out.sticky <= lo_msb | lo_reduce; // OR of all of lo
+ out.normal.mant <= implicit_mant(hi);
+ {overflow_1, out.normal.exp} <= {1'b0, exp} + 1;
+ end else begin // Product < 2.0: take the window one bit lower, exponent unchanged
+ /* The bit just below the msb is necessarily 1, since the msb of
+ * both multiplicands is 1. See the assert in implicit_mant().
+ */
+ out.guard <= round;
+ out.round <= lo[$bits(lo) - 1];
+ out.sticky <= lo_reduce;
+ out.normal.exp <= exp;
+ out.normal.mant <= implicit_mant({hi[$bits(hi) - 2:0], guard});
+ end
+ end
+
+endmodule
diff --git a/platform/wavelet3d/gfx_pkg.sv b/platform/wavelet3d/gfx_pkg.sv
new file mode 100644
index 0000000..27c1117
--- /dev/null
+++ b/platform/wavelet3d/gfx_pkg.sv
@@ -0,0 +1,49 @@
+package gfx; // Shared floating-point types and helpers for the gfx_*_lane modules
+
+ typedef logic[31:0] float_word; // Total width of an encoded float
+ typedef logic[7:0] float_exp; // Biased exponent field
+
+ typedef logic[$bits(float_word) - $bits(float_exp) - 2:0] float_mant; // Stored mantissa: word minus exponent minus sign bit
+ typedef logic[$bits(float_mant):0] float_mant_full; // Includes the explicit leading '1.'
+
+ localparam float_exp FLOAT_EXP_BIAS = (1 << ($bits(float_exp) - 1)) - 1; // Half the exponent range minus one
+ localparam float_exp FLOAT_EXP_MAX = {($bits(float_exp)){1'b1}}; // All-ones exponent
+
+ function float_mant_full full_mant(float_mant in); // Prepends the leading 1 bit to a stored mantissa
+ full_mant = {1'b1, in};
+ endfunction
+
+ function float_mant implicit_mant(float_mant_full in); // Drops the leading bit of a full mantissa
+ assert (in[$bits(in) - 1]); // The dropped bit must be 1, i.e. 'in' must already be normalized
+ implicit_mant = in[$bits(in) - 2:0];
+ endfunction
+
+ typedef struct packed
+ {
+ logic sign;
+ float_exp exp;
+ float_mant mant;
+ } float; // Packed as sign, exponent, mantissa from msb to lsb
+
+ /* Explanation of guard, round, sticky:
+ * https://drilian.com/2023/01/10/floating-point-numbers-and-rounding/
+ */
+ typedef struct packed
+ {
+ float normal;
+ logic slow,
+ zero,
+ guard,
+ round,
+ sticky;
+ } float_round; // Unrounded result: value plus slow/zero flags and the three rounding bits
+
+ typedef struct packed
+ {
+ float val;
+ logic exp_max,
+ exp_min,
+ mant_zero;
+ } float_special; // Value plus flags for all-ones exponent, all-zeros exponent, all-zeros mantissa
+
+endpackage
diff --git a/platform/wavelet3d/gfx_round_lane.sv b/platform/wavelet3d/gfx_round_lane.sv
new file mode 100644
index 0000000..d0b0b03
--- /dev/null
+++ b/platform/wavelet3d/gfx_round_lane.sv
@@ -0,0 +1,62 @@
+module gfx_round_lane // Rounds a float_round to a float through three register levels
+(
+ input logic clk,
+
+ input gfx::float_round in, // Unrounded value with guard/round/sticky and slow/zero flags
+
+ output gfx::float out // Registered, rounded result
+);
+
+ import gfx::*;
+
+ logic exp_step, overflow, sign_0, sign_1, slow_0, slow_1,
+ slow_out, zero_0, zero_1;
+
+ float_exp exp_0, exp_1;
+ float_mant mant_0, mant_1;
+
+ assign slow_out = slow_1 || overflow || &exp_1; // Flagged upstream, exponent increment carried out, or exponent is all ones
+
+ always_ff @(posedge clk) begin
+ // Stage 0: rounding
+
+ exp_0 <= in.normal.exp;
+ sign_0 <= in.normal.sign;
+ slow_0 <= in.slow;
+ zero_0 <= in.zero;
+ exp_step <= 0; // Default; overridden below if rounding up carries out of the mantissa
+
+ // This is the most common mode: round to nearest, ties to even
+ if (in.guard & (in.round | in.sticky | in.normal.mant[0]))
+ {exp_step, mant_0} <= {1'b0, in.normal.mant} + 1;
+ else
+ mant_0 <= in.normal.mant;
+
+ sign_1 <= sign_0; // Intermediate register level: carry state forward and apply the rounding carry to the exponent
+ slow_1 <= slow_0;
+ zero_1 <= zero_0;
+ mant_1 <= mant_0;
+ overflow <= 0; // Default; overridden below if the exponent increment carries out
+
+ if (exp_step)
+ {overflow, exp_1} <= {1'b0, exp_0} + 1;
+ else
+ exp_1 <= exp_0;
+
+ // Stage 1: zeros and slow path
+
+ out.sign <= sign_1;
+
+ if (slow_out) begin // NOTE(review): checked before zero_1, so a value flagged both slow and zero takes this branch - confirm intended
+ out.exp <= FLOAT_EXP_MAX;
+ out.mant <= 1; // All-ones exponent with a non-zero mantissa
+ end else if (zero_1) begin
+ out.exp <= 0;
+ out.mant <= 0;
+ end else begin
+ out.exp <= exp_1;
+ out.mant <= mant_1;
+ end
+ end
+
+endmodule
diff --git a/platform/wavelet3d/main.cpp b/platform/wavelet3d/main.cpp
index ce632b6..1243dba 100644
--- a/platform/wavelet3d/main.cpp
+++ b/platform/wavelet3d/main.cpp
@@ -1,3 +1,4 @@
+#include <iostream>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
@@ -21,6 +22,24 @@ int main(int argc, char **argv)
#endif
Py_Initialize();
+
+ float a, b;
+ std::cin >> a >> b; // NOTE(review): stream state is not checked after extraction - confirm bad input is acceptable here
+
+ top.a = *reinterpret_cast<unsigned*>(&a); // Passes the raw bit pattern of 'a'. NOTE(review): punning through reinterpret_cast breaks strict aliasing; std::memcpy or std::bit_cast would be well-defined
+ top.b = *reinterpret_cast<unsigned*>(&b); // Same for 'b'; 'top' is declared outside this hunk
+
+ for (int i = 0; i < 1000; ++i) { // Drive 1000 full clock cycles
+ top.clk = 0;
+ top.eval();
+
+ top.clk = 1; // Rising edge
+ top.eval();
+ }
+
+ unsigned q = top.q; // Read the result bits after the last cycle
+ std::cout << a << " * " << b << " = " << *reinterpret_cast<float*>(&q) << '\n'; // NOTE(review): same aliasing concern when reading the bits back as float
+
+
bool failed = Py_FinalizeEx() < 0;
#if VM_TRACE
diff --git a/platform/wavelet3d/mod.mk b/platform/wavelet3d/mod.mk
index ab54441..a12392e 100644
--- a/platform/wavelet3d/mod.mk
+++ b/platform/wavelet3d/mod.mk
@@ -1,6 +1,9 @@
define core
- $(this)/deps := picorv32
- $(this)/rtl_top := picorv32
+ $(this)/deps := dma_axi32 picorv32
+
+ $(this)/rtl_top := gfx_float_lane
+ $(this)/rtl_files := gfx_pkg.sv gfx_float_lane.sv gfx_fmul_lane.sv gfx_round_lane.sv
+
$(this)/vl_main := main.cpp
$(this)/vl_pkgconfig := python3-embed
endef