Skip to content

Commit

Permalink
Added functionality for the strided mask but not tested yet
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 8, 2024
1 parent 5612e8d commit b3068b1
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 18 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ ifeq ($(mask), upper_triangular)
mask_int = 1
else ifeq ($(mask), lower_triangular)
mask_int = 2
else ifeq ($(mask), strided)
mask_int = 3
else ifeq ($(mask), upper_strided)
mask_int = 4
else ifeq ($(mask), lower_strided)
mask_int = 5
else
mask_int = 0
endif
Expand Down
27 changes: 23 additions & 4 deletions PyITA/ITA.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,23 +579,42 @@ def step3_Vp(self):
self.tiler_V(self.V, self.Wv, self.Bv, self.Vp_requant, "V", "Wv", "Bv", "Vp")

def apply_mask(self, index):
self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool')

if (self.mask == 'upper_triangular'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range((i + index), self.Mask.shape[2]):
self.Mask[h][i][j] = True
else:
raise ValueError("Index is out of bounds")
elif(self.mask == 'lower_triangular'):
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'lower_triangular'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=False, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(index, self.Mask.shape[1]):
for j in range((i-(index-1))):
self.Mask[h][i][j] = True
else:
raise ValueError("Index is out of bounds")
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'strided'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
self.Mask[h][i][i] = False
for j in range(i, self.Mask.shape[2], index):
self.Mask[h][i][j] = False
self.Mask[h][j][i] = False
else:
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'upper_strided'):
pass
elif (self.mask == 'lower_strided'):
pass
elif (self.mask == 'lower_local'):
pass
elif(self.mask == 'none'):
pass
else:
Expand Down
37 changes: 28 additions & 9 deletions src/ita_controller.sv
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ module ita_controller
softmax_div_done_d = softmax_div_done_q;
last_time = 1'b0;
requant_add = {N {requant_add_i}};
mask_col_offset_d = (step_q == QK || step_q == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1));
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
mask_tile_x_pos_d = (step_q == QK || step_q == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M);
inp_bias = inp_bias_i;
mask_tile_y_pos_d = mask_tile_y_pos_q;
mask_d = mask_q;

busy_d = busy_q;
softmax_fifo = 1'b0;
Expand Down Expand Up @@ -395,8 +390,6 @@ module ita_controller
end
inp_bias_padded = inp_bias;


mask_d = '0;
case (ctrl_i.mask_type)
None: begin
mask_col_offset_d = '0;
Expand All @@ -407,8 +400,10 @@ module ita_controller
end
UpperTriangular: begin
mask_col_offset_d = (step_q == QK || step_q == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1));
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
mask_tile_x_pos_d = (step_q == QK || step_q == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M);
mask_tile_y_pos_d = mask_tile_y_pos_q;
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
mask_d = '0;

if (step_q == QK) begin
if (mask_tile_x_pos_q == tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
Expand Down Expand Up @@ -457,8 +452,11 @@ module ita_controller
end
end
LowerTriangular: begin
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1));
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = (step_q == QK || step_q == AV) ? mask_tile_y_pos_q : ((ctrl_i.mask_start_index) / M);
mask_pos_d = (step_q == QK || step_q == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1));
mask_d = '0;

if (step_q == QK) begin
if (mask_tile_x_pos_q == tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
Expand Down Expand Up @@ -506,6 +504,27 @@ module ita_controller
end
end
end
Strided: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;

if (step_q == QK) begin
if (last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
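//col_pos = (count_q / M) * N + i + tile_x_q * M
//row_pos = (count_q & (M-1)) + tile_y_q * M
//NOTE: "& (mask_start_index-1)" is equivalent to "% mask_start_index" only when mask_start_index is a power of two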
if ((((((count_q / M) * N) + i + (tile_x_q * M)) - ((count_q & (M-1)) + (tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin
mask_d[i] = 1'b0;
end else begin
mask_d[i] = 1'b1;
end
end
end
end
end
endcase

if (inp_valid_i && inp_ready_o && oup_valid_i && oup_ready_i && last_inner_tile_o) begin
Expand Down
2 changes: 1 addition & 1 deletion src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ package ita_package;
typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t;

// Masking
typedef enum {None=0, UpperTriangular=1, LowerTriangular=2} mask_e;
typedef enum {None=0, UpperTriangular=1, LowerTriangular=2, Strided=3, UpperStrided=4, LowerStrided=5} mask_e;
typedef logic [WO-WI*2-2:0] mask_index_t;

// IO
Expand Down
15 changes: 12 additions & 3 deletions src/ita_softmax.sv
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ module ita_softmax
disable_col[i] = 1'b1;
end else begin
case (ctrl_i.mask_type)
None: begin
disable_col[i] = 1'b0;
end
UpperTriangular: begin
// (ctrl_i.mask_start_index / M) -> tile where the masking starts
if (mask_tile_x_q == mask_tile_y_q + (ctrl_i.mask_start_index / M)) begin
Expand Down Expand Up @@ -312,9 +315,15 @@ module ita_softmax
disable_col[i] = 1'b0;
end
end
None: begin
disable_col[i] = 1'b0;
end
Strided: begin
//col_pos = i + mask_tile_x_q * M
//row_pos = (count_soft_mask_q & (M-1)) + mask_tile_y_q * M
if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0) begin
disable_col[i] = 1'b0;
end else begin
disable_col[i] = 1'b1;
end
end
endcase
end

Expand Down
2 changes: 1 addition & 1 deletion testGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter
default = 'none',
type = str,
help = 'Attention-Mask',
choices = ['none', 'upper_triangular', 'lower_triangular'])
choices = ['none', 'upper_triangular', 'lower_triangular', 'strided', 'upper_strided', 'lower_strided'])
self.group1.add_argument('-I', default = 1, type = int, help = 'Masking starting index')
self.group1.add_argument('--no-partial-softmax',
action = 'store_true',
Expand Down

0 comments on commit b3068b1

Please sign in to comment.