summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlejandro Soto <alejandro@34project.org>2024-03-08 01:11:21 -0600
committerAlejandro Soto <alejandro@34project.org>2024-03-08 01:11:21 -0600
commitc8b633207f42b85480635573fd2a271b842c1260 (patch)
tree10c0a491bdc27ebed723f25dc8d3be790c6194c1
parent9ecb54925bbbc218c183c0ee9cb3b27bf79e6e80 (diff)
platform/wavelet3d: implement int->float conversion
Diffstat (limited to '')
-rw-r--r--platform/wavelet3d/gfx_fpint_lane.sv81
-rw-r--r--platform/wavelet3d/main.cpp67
2 files changed, 103 insertions, 45 deletions
diff --git a/platform/wavelet3d/gfx_fpint_lane.sv b/platform/wavelet3d/gfx_fpint_lane.sv
index 0010f06..63d56e2 100644
--- a/platform/wavelet3d/gfx_fpint_lane.sv
+++ b/platform/wavelet3d/gfx_fpint_lane.sv
@@ -10,10 +10,12 @@ module gfx_fpint_lane
float_a_1,
int_hi_a_1,
int_lo_a_1,
- zero_flags_a_1,
+ zero_flags_1,
zero_b_1,
copy_flags_2,
+ int_signed_4,
copy_flags_5,
+ int_operand_5,
enable_norm_6,
copy_flags_10,
copy_flags_11,
@@ -25,9 +27,10 @@ module gfx_fpint_lane
import gfx::*;
- /* Notas de implementación para floating-point.
+ /* Notas de implementación para floating-point
*
* === PRODUCTO ===
+ *
* Queremos calcular q = a * b.
*
* Donde a = (-1)^s * 1.m * 2^f,
@@ -55,10 +58,19 @@ module gfx_fpint_lane
* En el caso de una resta, el exponente normalizado puede ser mucho más
* pequeño que cualquiera de los exponentes de entrada. Necesitamos
* entonces de lógica CLZ (count leading zeros) para renormalizar.
+ *
+ *
+ * === CONVERSIÓN INTEGER->FP ===
+ *
+ * Esto simplemente usa el mismo datapath de fadd, con el abs del entero
+ * como entrada de clz. El exponente de referencia se fija
+ * en 30 (aludiendo al segundo msb de un entero de 32 bits). A partir de
+ * ese punto es idéntico a un fadd, las etapas de clz se encargan de ajustar
+ * el exponente.
*/
logic exp_step, guard_0, guard_1, guard_2, guard_3, guard_4, guard_5, guard_10,
- lo_msb, lo_reduce, overflow_0, overflow_1, overflow_10, overflow_12,
+ int_sign, lo_msb, lo_reduce, overflow_0, overflow_1, overflow_10, overflow_12,
round_0, round_1, round_2, round_3, round_4, round_5, round_10, sign_0,
sign_10, sign_11, sign_12, slow_1, slow_2, slow_3, slow_4, slow_5, slow_10,
slow_11, slow_12, slow_in_1, slow_in_next, slow_out, sticky_1, sticky_2,
@@ -71,7 +83,7 @@ module gfx_fpint_lane
float_class a_class_0, a_class_1, b_class_0, b_class_1,
max_class_2, max_class_3, min_class_2, min_class_3, min_class_4;
- word clz_in, product_hi, product_lo;
+ word add_sub, clz_in, normalized, product_hi, product_lo;
dword product;
float_exp exp, exp_11, exp_10, exp_12, exp_delta;
float_mant mant_10, mant_11, mant_12;
@@ -81,15 +93,13 @@ module gfx_fpint_lane
typedef logic[$bits(float_mant_full) + 1:0] extended_mant;
localparam bit[$clog2($bits(extended_mant)):0] MAX_SHIFT = 1 << $clog2($bits(extended_mant));
- localparam int SHIFT_WIDTH = {{($bits(int) - $bits(MAX_SHIFT)){1'b0}}, MAX_SHIFT};
- localparam int CLZ_EXTEND_BITS = $bits(float_exp) - $bits(clz_shift) + 1;
-
- typedef logic[$bits(float_mant_full) + 2:0] mant_sum;
-
- mant_sum add_sub, normalized;
extended_mant max_mant, min_mant, sticky_mask;
logic[$clog2(MAX_SHIFT):0] clz_shift, exp_shift;
+ localparam int INT_SHIFT_REF = $bits(word) - 2;
+ localparam int SHIFT_WIDTH = {{($bits(int) - $bits(MAX_SHIFT)){1'b0}}, MAX_SHIFT};
+ localparam int CLZ_EXTEND_BITS = $bits(float_exp) - $bits(clz_shift) + 1;
+
struct packed
{
float max;
@@ -98,7 +108,7 @@ module gfx_fpint_lane
slow,
sticky,
zero;
- mant_sum add_sub;
+ word add_sub;
} clz_hold[FADD_CLZ_STAGES], clz_hold_out;
gfx_clz #($bits(word)) clz
@@ -112,18 +122,22 @@ module gfx_fpint_lane
extend_min_max = {~in_class.exp_min, in.mant, 2'b00};
endfunction
+ function word fp_add_sub_arg(extended_mant arg);
+ fp_add_sub_arg = {1'b0, arg, {($bits(fp_add_sub_arg) - $bits(arg) - 1){1'b0}}};
+ endfunction
+
assign lo_msb = lo[$bits(lo) - 1];
assign slow_out = &exp_12 || slow_12 || overflow_12;
assign exp_delta = max_2.exp - min_2.exp;
assign lo_reduce = |lo[$bits(lo) - 2:0];
- assign normalized = add_sub << clz_shift;
+ assign normalized = clz_hold_out.add_sub << clz_shift;
assign clz_hold_out = clz_hold[FADD_CLZ_STAGES - 1];
assign slow_in_next = is_float_special(a_class_0) | is_float_special(b_class_0);
assign {product_hi, product_lo} = product;
assign {hi, guard_0, round_0, lo} = product[2 * $bits(float_mant_full) - 1:0];
always_comb begin
- clz_in = {add_sub, {($bits(clz_in) - $bits(add_sub)){1'b0}}};
+ clz_in = add_sub;
if (~enable_norm_6)
clz_in[$bits(clz_in) - 1:$bits(clz_in) - 2] = 2'b01;
end
@@ -136,8 +150,8 @@ module gfx_fpint_lane
a_mul <= a;
b_mul <= b;
- /* Nótese que el orden es sign-exp-mant. Esto coloca el 1. implícito
- * en la posición correcta para multiplicar mantisas.
+ /* Nótese que el orden es sign-exp-mant. Esto coloca el '1.' implícito
+ * en la posición correcta para multiplicar las mantisas.
*/
if (mul_float_m1) begin
a_mul.exp <= 1;
@@ -146,10 +160,10 @@ module gfx_fpint_lane
b_mul.sign <= 0;
end
- // Genera un nop junto a lo anterior
if (unit_b_m1) begin
b_mul.exp <= 0;
b_mul.mant <= 1;
+ b_mul.sign <= 0;
end
// Stage 0: multiplicación de fp o enteros
@@ -163,9 +177,6 @@ module gfx_fpint_lane
// Stage 1: normalización
- slow_in_1 <= slow_in_next;
- overflow_1 <= 0;
-
if (float_a_1) begin
slow_1 <= slow_in_next | (overflow_0 & ~a_class_0.exp_min & ~a_class_1.exp_min);
zero_1 <= a_class_0.exp_min | b_class_0.exp_min;
@@ -174,7 +185,9 @@ module gfx_fpint_lane
zero_1 <= 0;
end
+ overflow_1 <= 0;
a_add.sign <= sign_0;
+
if (hi[$bits(hi) - 1]) begin
guard_1 <= guard_0;
round_1 <= round_0;
@@ -203,8 +216,12 @@ module gfx_fpint_lane
endcase
a_class_1 <= a_class_0;
- if (zero_flags_a_1)
+ slow_in_1 <= slow_in_next;
+
+ if (zero_flags_1) begin
a_class_1 <= classify_float(0);
+ slow_in_1 <= 0;
+ end
if (zero_b_1) begin
b_add <= 0;
@@ -264,9 +281,8 @@ module gfx_fpint_lane
if (exp_delta > {{($bits(exp_delta) - $bits(MAX_SHIFT)){1'b0}}, MAX_SHIFT})
exp_shift <= MAX_SHIFT;
- // Stage 4: shifts
+ // Stage 4: shifts y abs(max) para enteros con signo
- max_4 <= max_3;
min_4 <= min_3;
slow_4 <= slow_3;
zero_4 <= zero_3;
@@ -279,7 +295,13 @@ module gfx_fpint_lane
min_mant <= extend_min_max(min_3, min_class_3) >> exp_shift;
sticky_mask <= {($bits(min_mant)){1'b1}} << exp_shift;
- // Stage 5: suma/resta y sticky
+ max_4 <= max_3;
+ int_sign <= max_3[$bits(max_3) - 1];
+
+ if (int_signed_4 & max_3[$bits(max_3) - 1])
+ max_4 <= -max_3;
+
+ // Stage 5: suma de mantisas
max_5 <= max_4;
slow_5 <= slow_4;
@@ -287,15 +309,22 @@ module gfx_fpint_lane
guard_5 <= guard_4;
round_5 <= round_4;
+ if (int_operand_5) begin
+ max_5.exp <= FLOAT_EXP_BIAS + INT_SHIFT_REF[$bits(float_exp) - 1:0];
+ max_5.sign <= int_sign;
+ end
+
if (copy_flags_5)
sticky_5 <= sticky_4;
else
sticky_5 <= |(extend_min_max(min_4, min_class_4) & ~sticky_mask);
- if (max_4.sign ^ min_4.sign)
- add_sub <= {1'b0, max_mant - min_mant};
+ if (int_operand_5)
+ add_sub <= max_4;
+ else if (max_4.sign ^ min_4.sign)
+ add_sub <= fp_add_sub_arg(max_mant) - fp_add_sub_arg(min_mant);
else
- add_sub <= {1'b0, max_mant} + {1'b0, min_mant};
+ add_sub <= fp_add_sub_arg(max_mant) + fp_add_sub_arg(min_mant);
// Stages 6-9: clz
diff --git a/platform/wavelet3d/main.cpp b/platform/wavelet3d/main.cpp
index 1bffb68..037aee4 100644
--- a/platform/wavelet3d/main.cpp
+++ b/platform/wavelet3d/main.cpp
@@ -23,18 +23,41 @@ int main(int argc, char **argv)
Py_Initialize();
- float a, b;
+ float q;
+ int a, b;
+ const char *op = "int->fp";
+
std::cin >> a >> b;
+ // int->fp
+ top.mul_float_m1 = 0;
+ top.unit_b_m1 = 1;
+ top.float_a_1 = 0;
+ top.int_hi_a_1 = 0;
+ top.int_lo_a_1 = 1;
+ top.zero_flags_1 = 1;
+ top.zero_b_1 = 1;
+ top.copy_flags_2 = 0;
+ top.int_signed_4 = 1;
+ top.int_operand_5 = 1;
+ top.copy_flags_5 = 1;
+ top.enable_norm_6 = 1;
+ top.copy_flags_10 = 0;
+ top.copy_flags_11 = 0;
+ top.enable_round_11 = 1;
+ top.encode_special_13 = 1;
+
// mul int
//top.mul_float_m1 = 0;
//top.unit_b_m1 = 0;
//top.float_a_1 = 0;
//top.int_hi_a_1 = 0;
//top.int_lo_a_1 = 1;
- //top.zero_flags_a_1 = 1;
+ //top.zero_flags_1 = 1;
//top.zero_b_1 = 1;
//top.copy_flags_2 = 1;
+ //top.int_signed_4 = 0;
+ //top.int_operand_5 = 0;
//top.copy_flags_5 = 1;
//top.enable_norm_6 = 0;
//top.copy_flags_10 = 1;
@@ -48,10 +71,12 @@ int main(int argc, char **argv)
//top.float_a_1 = 1;
//top.int_hi_a_1 = 0;
//top.int_lo_a_1 = 0;
- //top.zero_flags_a_1 = 0;
+ //top.zero_flags_1 = 0;
//top.zero_b_1 = 1;
//top.copy_flags_2 = 1;
//top.copy_flags_5 = 1;
+ //top.int_signed_4 = 0;
+ //top.int_operand_5 = 0;
//top.enable_norm_6 = 1;
//top.copy_flags_10 = 1;
//top.copy_flags_11 = 1;
@@ -59,20 +84,22 @@ int main(int argc, char **argv)
//top.encode_special_13 = 1;
// suma/resta
- top.mul_float_m1 = 0;
- top.unit_b_m1 = 1;
- top.float_a_1 = 0;
- top.int_hi_a_1 = 0;
- top.int_lo_a_1 = 1;
- top.zero_flags_a_1 = 0;
- top.zero_b_1 = 0;
- top.copy_flags_2 = 0;
- top.copy_flags_5 = 0;
- top.enable_norm_6 = 1;
- top.copy_flags_10 = 0;
- top.copy_flags_11 = 0;
- top.enable_round_11 = 1;
- top.encode_special_13 = 1;
+ //top.mul_float_m1 = 0;
+ //top.unit_b_m1 = 1;
+ //top.float_a_1 = 0;
+ //top.int_hi_a_1 = 0;
+ //top.int_lo_a_1 = 1;
+ //top.zero_flags_1 = 0;
+ //top.zero_b_1 = 0;
+ //top.copy_flags_2 = 0;
+ //top.copy_flags_5 = 0;
+ //top.int_signed_4 = 0;
+ //top.int_operand_5 = 0;
+ //top.enable_norm_6 = 1;
+ //top.copy_flags_10 = 0;
+ //top.copy_flags_11 = 0;
+ //top.enable_round_11 = 1;
+ //top.encode_special_13 = 1;
top.a = *reinterpret_cast<unsigned*>(&a);
top.b = *reinterpret_cast<unsigned*>(&b);
@@ -85,8 +112,10 @@ int main(int argc, char **argv)
top.eval();
}
- unsigned q = top.q;
- std::cout << a << " * " << b << " = " << *reinterpret_cast<decltype(a)*>(&q) << '\n';
+ unsigned q_bits = top.q;
+ q = *reinterpret_cast<decltype(q)*>(&q_bits);
+
+ std::cout << a << ' ' << op << ' ' << b << " = " << q << '\n';
bool failed = Py_FinalizeEx() < 0;