Skip to content

Commit

Permalink
Made new masking module
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 12, 2024
1 parent b3068b1 commit 9dbbd8e
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 2,064 deletions.
1 change: 1 addition & 0 deletions Bender.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ sources:
# Individual source files are simple string entries:
- src/ita_package.sv
- src/ita_accumulator.sv
- src/ita_masking.sv
- src/ita_controller.sv
- src/ita_dotp.sv
- src/ita_fifo_controller.sv
Expand Down
1,444 changes: 0 additions & 1,444 deletions PyITA/ITA.py

Large diffs are not rendered by default.

449 changes: 1 addition & 448 deletions modelsim/sim_ita_tb_wave.tcl

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions src/ita.sv
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,7 @@ module ita
.inp_bias_i (inp_bias ),
.inp_bias_pad_o (inp_bias_padded ),
.mask_o (mask ),
.busy_o (busy_o ),
.calc_en_q1_i (calc_en_q1 )
.busy_o (busy_o )
);

ita_input_sampler i_input_sampler (
Expand Down
188 changes: 23 additions & 165 deletions src/ita_controller.sv
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,14 @@ module ita_controller
input bias_t inp_bias_i ,
output bias_t inp_bias_pad_o ,
output logic [N-1:0] mask_o ,
output logic busy_o ,
input logic calc_en_q1_i
output logic busy_o
);

step_e step_d, step_q;
counter_t count_d, count_q, bias_count;
counter_t mask_pos_d, mask_pos_q;
logic [3:0] mask_col_offset_d, mask_col_offset_q;

counter_t tile_d, tile_q;
counter_t inner_tile_d, inner_tile_q;
counter_t mask_tile_x_pos_d, mask_tile_x_pos_q;
counter_t mask_tile_y_pos_d, mask_tile_y_pos_q;
counter_t tile_x_d, tile_x_q, bias_tile_x_d, bias_tile_x_q;
counter_t tile_y_d, tile_y_q, bias_tile_y_d, bias_tile_y_q;
counter_t softmax_tile_d, softmax_tile_q;
Expand All @@ -56,14 +52,12 @@ module ita_controller

bias_t inp_bias, inp_bias_padded;
logic last_time;
logic [N-1:0] mask_d, mask_q;

tile_t inner_tile_dim;
logic [WO-WI*2-2:0] first_outer_dim, second_outer_dim;
logic [WO-WI*2-2:0] first_outer_dim_d, first_outer_dim_q;
logic [WO-WI*2-2:0] second_outer_dim_d, second_outer_dim_q;
input_dim_t first_outer_dim, second_outer_dim;
input_dim_t first_outer_dim_d, first_outer_dim_q;
input_dim_t second_outer_dim_d, second_outer_dim_q;


logic softmax_fifo, softmax_div, softmax_div_done_d, softmax_div_done_q, busy_d, busy_q;
requant_oup_t requant_add, requant_add_d, requant_add_q;

Expand All @@ -74,7 +68,7 @@ module ita_controller
assign inner_tile_o = inner_tile_q;
assign requant_add_o = requant_add_q;
assign inp_bias_pad_o = inp_bias_padded;
assign mask_o = mask_q;


always_comb begin
count_d = count_q;
Expand All @@ -96,10 +90,9 @@ module ita_controller
last_time = 1'b0;
requant_add = {N {requant_add_i}};
inp_bias = inp_bias_i;

busy_d = busy_q;
softmax_fifo = 1'b0;
softmax_div = 1'b0;
busy_d = busy_q;
softmax_fifo = 1'b0;
softmax_div = 1'b0;

if (step_q != AV) begin
softmax_div_done_d = 1'b0;
Expand Down Expand Up @@ -390,143 +383,6 @@ module ita_controller
end
inp_bias_padded = inp_bias;

case (ctrl_i.mask_type)
None: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;
end
UpperTriangular: begin
mask_col_offset_d = (step_q == QK || step_q == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1));
mask_tile_x_pos_d = (step_q == QK || step_q == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M);
mask_tile_y_pos_d = mask_tile_y_pos_q;
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
mask_d = '0;

if (step_q == QK) begin
if (mask_tile_x_pos_q == tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
if (count_q == ((M*M/N)-1)) begin
mask_tile_x_pos_d = mask_tile_x_pos_q + 1'b1;
end
if ((count_q >= mask_pos_q) && (count_q < (mask_pos_q + N))) begin
if ((count_q & (M-1)) == (M-1) && !(((count_q + mask_col_offset_q) & (N-1)) == (N-1))) begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q;
mask_pos_d = ((count_q + (((ctrl_i.tile_s * (M*M/N)) - M) + 1)) & ((M*M/N)-1));
end else if ((count_q & (M-1)) == (M-1) && (((count_q + mask_col_offset_q) & (N-1)) == (N-1))) begin
if ((count_q / M) == ((M/N)-1)) begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q + 1'b1;
mask_pos_d = ((count_q + ((ctrl_i.tile_s * (M*M/N)) + 1)) & ((M*M/N)-1));
end else begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q;
mask_pos_d = ((count_q + ((ctrl_i.tile_s * (M*M/N)) + 1)) & ((M*M/N)-1));
end
end else if (((count_q + mask_col_offset_q) & (N-1)) == (N-1)) begin
mask_pos_d = (mask_pos_q + (N - ((mask_pos_q + mask_col_offset_q) & (N-1))) + M) & ((M*M/N)-1);
end
for (int i = 0; i < N; i++) begin
if (((count_q + mask_col_offset_q) & (N-1)) <= i) begin
mask_d[i] = 1'b1;
end else begin
mask_d[i] = 1'b0;
end
end
end else if ((count_q & (M-1)) < (mask_pos_q & (M-1))) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
end
end else if (mask_tile_x_pos_q <= tile_x_q && mask_tile_y_pos_q != tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
end else if (mask_tile_x_pos_q != tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b0;
end
end
end
end
LowerTriangular: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = (step_q == QK || step_q == AV) ? mask_tile_y_pos_q : ((ctrl_i.mask_start_index) / M);
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1));
mask_d = '0;

if (step_q == QK) begin
if (mask_tile_x_pos_q == tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
if (count_q == ((M*M/N)-1)) begin
mask_tile_x_pos_d = mask_tile_x_pos_q + 1'b1;
end
if ((count_q >= mask_pos_q) && (count_q < (mask_pos_q + N))) begin
if (((count_q & (M-1)) == (M-1)) && !(((count_q + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) == (N-1))) begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q;
mask_pos_d = ((count_q + (((ctrl_i.tile_s * (M*M/N)) - M) + 1)) & ((M*M/N)-1));
end else if (((count_q & (M-1)) == (M-1)) && (((count_q + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) == (N-1))) begin
if ((count_q / M) == ((M/N)-1)) begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q + 1'b1;
mask_pos_d = ((count_q + ((ctrl_i.tile_s * (M*M/N)) + 1)) & ((M*M/N)-1));
end else begin
mask_tile_y_pos_d = tile_y_q + 1'b1;
mask_tile_x_pos_d = tile_x_q;
mask_pos_d = ((count_q + ((ctrl_i.tile_s * (M*M/N)) + 1)) & ((M*M/N)-1));
end
end else if (((count_q + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) == (N-1)) begin
mask_pos_d = (mask_pos_q + (count_q - mask_pos_q + 1) + M) & ((M*M/N)-1);
end
for (int i = 0; i < N; i++) begin
if (((count_q + (N - (ctrl_i.mask_start_index & (N-1)))) & (N-1)) >= i) begin
mask_d[i] = 1'b1;
end else begin
mask_d[i] = 1'b0;
end
end
end else if ((count_q & (M-1)) >= (mask_pos_q & (M-1))) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
end
end else if (mask_tile_x_pos_q > tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
end else if (mask_tile_x_pos_q >= tile_x_q && mask_tile_y_pos_q != tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b0;
end
end
end
end
Strided: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;

if (step_q == QK) begin
if (last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
//col_pos = count_q/M + i + mask_tile_x_pos_q * M
//row_pos = count_q & (M-1) + mask_tile_y_pos_q * M
if ((((((count_q / M) * N) + i + (tile_x_q * M)) - ((count_q & (M-1)) + (tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin
mask_d[i] = 1'b0;
end else begin
mask_d[i] = 1'b1;
end
end
end
end
end
endcase

if (inp_valid_i && inp_ready_o && oup_valid_i && oup_ready_i && last_inner_tile_o) begin
ongoing_d = ongoing_q;
end else if (inp_valid_i && inp_ready_o && last_inner_tile_o) begin
Expand Down Expand Up @@ -561,11 +417,6 @@ module ita_controller
bias_tile_y_q <= '0;
first_outer_dim_q <= '0;
second_outer_dim_q <= '0;
mask_pos_q <= '0;
mask_col_offset_q <= '0;
mask_tile_x_pos_q <= '0;
mask_tile_y_pos_q <= '0;
mask_q <= '0;
end else begin
step_q <= step_d;
count_q <= count_d;
Expand All @@ -583,13 +434,20 @@ module ita_controller
bias_tile_y_q <= bias_tile_y_d;
first_outer_dim_q <= first_outer_dim_d;
second_outer_dim_q <= second_outer_dim_d;
if (calc_en_o) begin
mask_pos_q <= mask_pos_d;
mask_tile_x_pos_q <= mask_tile_x_pos_d;
mask_tile_y_pos_q <= mask_tile_y_pos_d;
end
mask_q <= mask_d;
mask_col_offset_q <= mask_col_offset_d;
end
end

ita_masking i_masking (
.clk_i (clk_i),
.rst_ni (rst_ni),
.ctrl_i (ctrl_i),
.step_i (step_o),
.calc_en_i (calc_en_o),
.last_inner_tile_i (last_inner_tile_o),
.count_i (count_q),
.tile_x_i (tile_x_o),
.tile_y_i (tile_y_o),
.mask_o (mask_o)
);

endmodule
Loading

0 comments on commit 9dbbd8e

Please sign in to comment.