diff --git a/PyITA/ITA.py b/PyITA/ITA.py index 0281be2..6d35fa2 100644 --- a/PyITA/ITA.py +++ b/PyITA/ITA.py @@ -598,6 +598,8 @@ def step4_QK(self, no_partial_softmax, mask, index): else: raise ValueError("Mask not supported") + print(self.Mask) + matrix = np.squeeze(self.A_requant) plt.imshow(matrix, cmap='viridis') plt.colorbar() @@ -689,10 +691,10 @@ def step6_O(self): self.Out_soft_requant = requantize(self.Out_soft, self.requant_eps_mult[5], self.requant_right_shift[5], self.requant_add[5]) - matrix = np.squeeze(self.Out_soft) + matrix = np.squeeze(self.Out_soft_requant) plt.imshow(matrix, cmap='viridis') plt.colorbar() - plt.title("Out_soft") + plt.title("Out_soft_requant") plt.show() if (self.S_ITA - self.S) > 0: diff --git a/PyITA/softmax.py b/PyITA/softmax.py index c4e5475..996c14a 100644 --- a/PyITA/softmax.py +++ b/PyITA/softmax.py @@ -129,7 +129,7 @@ def streamingPartialSoftmax(x, mask, integerize = True): print(f"Global Max: {global_max.shape}") print(global_max) - print(f"Global Max: {current_max.shape}") + print(f"Current Max: {current_max.shape}") print(current_max) # Update all shift values where new maximum is larger @@ -189,6 +189,10 @@ def streamingPartialSoftmax(x, mask, integerize = True): # Find the difference between the maximum and x diff = np.repeat(global_max, seq_length).reshape(n_heads, seq_length, seq_length) - x.astype(np.int32) + # The global_max can be smaller than a few positions in x because not all values in x were considered for the global_max due to the mask. + # So diff should normally not be smaller than 0 + diff[mask] = 0 + # Shift the values by B-log2B -> multiply by B/2**B = log2e*eps_x # Make sure to do use round-half-up instead of round-half-to-even if integerize: diff --git a/modelsim/sim_ita_tb_wave.tcl b/modelsim/sim_ita_tb_wave.tcl index 6b4ff01..cf72515 100644 --- a/modelsim/sim_ita_tb_wave.tcl +++ b/modelsim/sim_ita_tb_wave.tcl @@ -19,13 +19,10 @@ add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/ add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q1 add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q2 add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q3 -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_col_offset_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/max_o add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/disable_row -add wave -noupdate -expand -group {Masking Signals} -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_controller/step_q add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_x_pos_d add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_x_pos_q add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_y_pos_d @@ -36,16 +33,7 @@ add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radi add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/count_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_controller/mask_d add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_pos_q -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_inp2_mux/clk_i -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 -add wave -noupdate -expand -group {Masking Signals} -radix hexadecimal /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q1 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q2 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q3 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q4 +add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_col_offset_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q5 add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q6 add wave -noupdate /ita_tb/dut/calc_en_q7 @@ -58,8 +46,49 @@ add wave -noupdate -group Bias /ita_tb/dut/inp_bias add wave -noupdate -group Bias /ita_tb/dut/inp_bias_padded add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q1 add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q2 +add wave -noupdate /ita_tb/dut/calc_en_q4 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_mask_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 +add wave -noupdate -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate /ita_tb/dut/i_controller/step_q +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/step_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q2 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q3 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/mask_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_d +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q1 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q2 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q3 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q4 +add wave -noupdate /ita_tb/dut/calc_en +add wave -noupdate /ita_tb/dut/calc_en_q1 +add wave -noupdate /ita_tb/dut/calc_en_q2 +add wave -noupdate /ita_tb/dut/calc_en_q3 +add wave -noupdate /ita_tb/dut/calc_en_q4 +add wave -noupdate /ita_tb/dut/calc_en_q5 +add wave -noupdate /ita_tb/dut/calc_en_q6 +add wave -noupdate /ita_tb/dut/calc_en_q7 +add wave -noupdate /ita_tb/dut/calc_en_q8 +add wave -noupdate /ita_tb/dut/calc_en_q9 +add wave -noupdate /ita_tb/dut/calc_en_q10 +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_i +add wave -noupdate -radix hexadecimal /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q +add wave -noupdate -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp_i +add wave -noupdate /ita_tb/dut/inp +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o +add wave -noupdate /ita_tb/dut/inp1 +add wave -noupdate /ita_tb/dut/inp1_q add wave -noupdate /ita_tb/dut/i_accumulator/oup_i add wave -noupdate /ita_tb/dut/i_accumulator/result_d +add wave -noupdate /ita_tb/dut/i_accumulator/result_o add wave -noupdate /ita_tb/dut/i_activation/data_i add wave -noupdate /ita_tb/dut/i_activation/data_q1 add wave -noupdate /ita_tb/dut/i_activation/data_q2 @@ -72,4 +101,4 @@ add wave -noupdate /ita_tb/dut/oup_o add wave -noupdate -group Requantizer /ita_tb/dut/i_requantizer/* add wave -expand -group Controller /ita_tb/dut/i_controller/* add wave -group {Softmax Controller} ita_tb/dut/i_softmax_top/i_softmax/* -add wave -group {Accumulator} ita_tb/dut/i_accumulator/* +add wave -group {Accumulator} ita_tb/dut/i_accumulator/* \ No newline at end of file diff --git a/modelsim/sim_ita_tb_wave_important.tcl b/modelsim/sim_ita_tb_wave_important.tcl index ddd5579..509e694 100644 --- a/modelsim/sim_ita_tb_wave_important.tcl +++ b/modelsim/sim_ita_tb_wave_important.tcl @@ -17,13 +17,10 @@ add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/ add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q1 add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q2 add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_count_q3 -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_col_offset_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/max_o add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_softmax_top/i_softmax/disable_row -add wave -noupdate -expand -group {Masking Signals} -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_controller/step_q add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_x_pos_d add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_x_pos_q add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radix unsigned /ita_tb/dut/i_controller/mask_tile_y_pos_d @@ -34,16 +31,7 @@ add wave -noupdate -expand -group {Masking Signals} -group {Mask Tile Pos} -radi add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/count_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_controller/mask_d add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_pos_q -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/i_inp2_mux/clk_i -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 -add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 -add wave -noupdate -expand -group {Masking Signals} -radix hexadecimal /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q1 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q2 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q3 -add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q4 +add wave -noupdate -expand -group {Masking Signals} -radix unsigned /ita_tb/dut/i_controller/mask_col_offset_q add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q5 add wave -noupdate -expand -group {Masking Signals} /ita_tb/dut/calc_en_q6 add wave -noupdate /ita_tb/dut/calc_en_q7 @@ -56,8 +44,49 @@ add wave -noupdate -group Bias /ita_tb/dut/inp_bias add wave -noupdate -group Bias /ita_tb/dut/inp_bias_padded add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q1 add wave -noupdate -group Bias /ita_tb/dut/inp_bias_q2 +add wave -noupdate /ita_tb/dut/calc_en_q4 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_mask_q +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 +add wave -noupdate -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_inp2_mux/clk_i +add wave -noupdate /ita_tb/dut/i_controller/step_q +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/step_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q2 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q3 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/mask_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_i +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/max_o +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_d +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q1 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q2 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q3 +add wave -noupdate -expand -group {In Softmax} /ita_tb/dut/i_softmax_top/i_softmax/count_q4 +add wave -noupdate /ita_tb/dut/calc_en +add wave -noupdate /ita_tb/dut/calc_en_q1 +add wave -noupdate /ita_tb/dut/calc_en_q2 +add wave -noupdate /ita_tb/dut/calc_en_q3 +add wave -noupdate /ita_tb/dut/calc_en_q4 +add wave -noupdate /ita_tb/dut/calc_en_q5 +add wave -noupdate /ita_tb/dut/calc_en_q6 +add wave -noupdate /ita_tb/dut/calc_en_q7 +add wave -noupdate /ita_tb/dut/calc_en_q8 +add wave -noupdate /ita_tb/dut/calc_en_q9 +add wave -noupdate /ita_tb/dut/calc_en_q10 +add wave -noupdate /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_i +add wave -noupdate -radix hexadecimal /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q +add wave -noupdate -radix binary /ita_tb/dut/i_softmax_top/i_softmax/disable_col +add wave -noupdate /ita_tb/dut/i_inp1_mux/inp_i +add wave -noupdate /ita_tb/dut/inp +add wave -noupdate -radix unsigned /ita_tb/dut/i_softmax_top/i_softmax/inp_stream_soft_o +add wave -noupdate /ita_tb/dut/inp1 +add wave -noupdate /ita_tb/dut/inp1_q add wave -noupdate /ita_tb/dut/i_accumulator/oup_i add wave -noupdate /ita_tb/dut/i_accumulator/result_d +add wave -noupdate /ita_tb/dut/i_accumulator/result_o add wave -noupdate /ita_tb/dut/i_activation/data_i add wave -noupdate /ita_tb/dut/i_activation/data_q1 add wave -noupdate /ita_tb/dut/i_activation/data_q2 @@ -227,6 +256,8 @@ add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_d add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/exp_sum_q add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_d +add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q1 +add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_soft_q2 add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_d add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/count_div_q add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/addr_div_d @@ -246,6 +277,7 @@ add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/max_diff add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/shift_inp_diff +add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_stream_soft_en_q add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_d add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q1 add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/calc_en_q2 @@ -259,7 +291,6 @@ add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/fifo_usage add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_shift add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_row -add wave -noupdate -expand -group {Softmax Controller} /ita_tb/dut/i_softmax_top/i_softmax/disable_col add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/clk_i add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/rst_ni add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/calc_en_i @@ -285,10 +316,10 @@ add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/write_addr_q add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/result_d add wave -noupdate -group Accumulator /ita_tb/dut/i_accumulator/result_q TreeUpdate [SetDefaultTree] -WaveRestoreCursors {{Cursor 1} {5371458 ps} 0} {{Cursor 2} {4813035 ps} 1} {{Cursor 3} {4817000 ps} 1} {{Cursor 4} {4829000 ps} 0} -quietly wave cursor active 4 -configure wave -namecolwidth 195 -configure wave -valuecolwidth 135 +WaveRestoreCursors {{Cursor 1} {4842600 ps} 1} {{Cursor 2} {4823000 ps} 1} {{Cursor 4} {4816946 ps} 1} {{Cursor 12} {3614999 ps} 1} {{Cursor 13} {3617010 ps} 1} {{Cursor 14} {3645847 ps} 1} {{Cursor 15} {3624942 ps} 1} {{Cursor 16} {5124600 ps} 1} +quietly wave cursor active 8 +configure wave -namecolwidth 167 +configure wave -valuecolwidth 100 configure wave -justifyvalue left configure wave -signalnamewidth 1 configure wave -snapdistance 10 @@ -301,4 +332,4 @@ configure wave -griddelta 40 configure wave -timeline 0 configure wave -timelineunits ns update -WaveRestoreZoom {4812915 ps} {4830108 ps} +WaveRestoreZoom {5079893 ps} {5130074 ps} diff --git a/src/ita_controller.sv b/src/ita_controller.sv index ccd32ba..199782d 100644 --- a/src/ita_controller.sv +++ b/src/ita_controller.sv @@ -97,8 +97,8 @@ module ita_controller requant_add_d = {N {requant_add_i}}; last_time = 1'b0; inp_bias = inp_bias_i; - mask_col_offset_d = (step_q == QK) ? mask_col_offset_q : ((ctrl_i.mask_start_index-1) & (N-1)); - mask_pos_d = (step_q == QK) ? mask_pos_q : (((ctrl_i.mask_start_index-1)/N)*M); + mask_col_offset_d = (step_q == QK) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1)); + mask_pos_d = (step_q == QK) ? mask_pos_q : (((ctrl_i.mask_start_index)/N)*M); mask_tile_x_pos_d = mask_tile_x_pos_q; mask_tile_y_pos_d = mask_tile_y_pos_q; mask_d = mask_q; @@ -421,7 +421,7 @@ module ita_controller mask_pos_d = (mask_pos_q + (N - ((mask_pos_q + mask_col_offset_q) & (N-1))) + M) & ((M*M/N)-1); end for (int i = 0; i < N; i++) begin - if (((count_q + mask_col_offset_q) & (N-1)) >= i) begin + if (((count_q + mask_col_offset_q) & (N-1)) <= i) begin mask_d[i] = 1'b1; end else begin mask_d[i] = 1'b0; diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv index 528b38e..26c9611 100644 --- a/src/ita_softmax.sv +++ b/src/ita_softmax.sv @@ -52,7 +52,7 @@ module ita_softmax counter_t tile_y_q; logic unsigned [SoftmaxAccDataWidth-1:0] exp_sum_d, exp_sum_q; - counter_t count_soft_d, count_soft_q1, count_soft_q2; + counter_t count_soft_d, count_soft_q1, count_soft_q2, count_soft_mask_q; counter_t count_div_d, count_div_q, addr_div_d, addr_div_q; logic [NumDiv-1:0] div_read_d, div_read_q, div_write_d, div_write_q; @@ -248,7 +248,8 @@ module ita_softmax disable_col[i] = ((inner_tile_q*M + i) >= ctrl_i.seq_length); if ((inner_tile_q*M + i) >= ctrl_i.seq_length) begin disable_col[i] = 1'b1; - end else if ((i >= (count_soft_q1 & (M-1))) && (ctrl_i.mask_type == UpperTriangular)) begin + // This logic needs to be replaced + end else if ((i >= ((count_soft_mask_q & (M-1)) + (ctrl_i.mask_start_index & (M-1)))) && (ctrl_i.mask_type == UpperTriangular)) begin disable_col[i] = 1'b1; end else begin disable_col[i] = 1'b0; @@ -282,6 +283,7 @@ module ita_softmax count_q1 <= M*M/N; count_soft_q1 <= '0; count_soft_q2 <= '0; + count_soft_mask_q <= '0; count_div_q <= '0; div_read_q <= '0; div_write_q <= '0; @@ -303,8 +305,10 @@ module ita_softmax count_q3 <= count_q2; count_q2 <= count_q1; count_q1 <= count_d; - count_soft_q1 <= count_soft_d; - count_soft_q2 <= count_soft_q1; + count_soft_q1 <= count_soft_d; + count_soft_q2 <= count_soft_q1; + if (calc_stream_soft_en_i) + count_soft_mask_q <= count_soft_q1; count_div_q <= count_div_d; div_read_q <= div_read_d; div_write_q <= div_write_d;