Skip to content

Commit

Permalink
Added upper and lower strided masks
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 13, 2024
1 parent 27a0e37 commit 343b6a9
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 22 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ else ifeq ($(mask), upper_strided)
mask_int = 4
else ifeq ($(mask), lower_strided)
mask_int = 5
else ifeq ($(mask), sliding_window_attention)
mask_int = 6
else ifeq ($(mask), strided_sliding_window_attention)
mask_int = 7
else
mask_int = 0
endif
Expand Down
57 changes: 42 additions & 15 deletions PyITA/ITA.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,32 +614,59 @@ def apply_mask(self, index):
elif (self.mask == 'upper_strided'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, self.Mask.shape[2], index):
self.Mask[h][i][j] = False
if (index % 2 == 0):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, self.Mask.shape[2], index):
self.Mask[h][i][j] = False
else:
raise ValueError(f"Index has to be a power of two for {self.mask} mask")
else:
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'lower_strided'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, self.Mask.shape[2], index):
self.Mask[h][j][i] = False
if (index % 2 == 0):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, self.Mask.shape[2], index):
self.Mask[h][j][i] = False
else:
raise ValueError(f"Index has to be a power of two for {self.mask} mask")
else:
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'sliding_window_attention'):
elif (self.mask == 'sliding_window'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
if (0 < index and index < self.S):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, (index + i)):
self.Mask[h][i][j] = False
self.Mask[h][j][i] = False
if (index % 2 == 0):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, (index + i)):
self.Mask[h][i][j] = False
self.Mask[h][j][i] = False
else:
raise ValueError(f"Index has to be a power of two for {self.mask} mask")
else:
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif (self.mask == 'strided_sliding_window'):
self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
if (0 < index and index < self.S):
if (index % 2 == 0):
for h in range(self.Mask.shape[0]):
for i in range(self.Mask.shape[1]):
for j in range(i, self.Mask.shape[2]):
if (j > (index + i)):
if (j % index == 0):
self.Mask[h][i][j] = False
self.Mask[h][j][i] = False
else:
self.Mask[h][i][j] = False
self.Mask[h][j][i] = False
else:
raise ValueError(f"Index has to be a power of two for {self.mask} mask")
else:
raise ValueError(f"Index is out of bounds for {self.mask} mask")
elif(self.mask == 'none'):
elif (self.mask == 'none'):
pass
else:
raise ValueError("Mask not supported")
Expand Down
61 changes: 56 additions & 5 deletions src/ita_masking.sv
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ module ita_masking
assign mask_o = mask_q;

always_comb begin



case (ctrl_i.mask_type)
None: begin
mask_col_offset_d = '0;
Expand All @@ -41,9 +44,9 @@ module ita_masking
UpperTriangular: begin
mask_col_offset_d = (step_i == QK || step_i == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1));
mask_tile_x_pos_d = (step_i == QK || step_i == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M);
mask_tile_y_pos_d = mask_tile_y_pos_q;
mask_tile_y_pos_d = mask_tile_y_pos_q;
mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
mask_d = '0;
mask_d = '0;

if (step_i == QK) begin
if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i && last_inner_tile_i == 1'b1) begin
Expand Down Expand Up @@ -92,11 +95,11 @@ module ita_masking
end
end
LowerTriangular: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_col_offset_d = '0;
mask_tile_x_pos_d = mask_tile_x_pos_q;
mask_tile_y_pos_d = (step_i == QK || step_i == AV) ? mask_tile_y_pos_q : ((ctrl_i.mask_start_index) / M);
mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1));
mask_d = '0;
mask_d = '0;

if (step_i == QK) begin
if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i && last_inner_tile_i == 1'b1) begin
Expand Down Expand Up @@ -166,6 +169,54 @@ module ita_masking
end
end
end
UpperStrided: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;

if (step_i == QK) begin
if (last_inner_tile_i == 1'b1) begin
for (int i = 0; i < N; i++) begin
//Marcel Kant: Only works if ctrl_i.mask_start_index is a power of two, because the bitwise AND with (mask_start_index-1) is used as a modulo
if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
((((count_i / M) * N) + i + (tile_x_i * M)) >= ((count_i & (M-1)) + (tile_y_i * M)))) begin
mask_d[i] = 1'b0;
end else begin
mask_d[i] = 1'b1;
end
end
end
end
end
LowerStrided: begin
mask_col_offset_d = '0;
mask_tile_x_pos_d = '0;
mask_tile_y_pos_d = '0;
mask_pos_d = '0;
mask_d = '0;

if (step_i == QK) begin
if (last_inner_tile_i == 1'b1) begin
for (int i = 0; i < N; i++) begin
//Marcel Kant: Only works if ctrl_i.mask_start_index is a power of two, because the bitwise AND with (mask_start_index-1) is used as a modulo
if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
((((count_i / M) * N) + i + (tile_x_i * M)) <= ((count_i & (M-1)) + (tile_y_i * M)))) begin
mask_d[i] = 1'b0;
end else begin
mask_d[i] = 1'b1;
end
end
end
end
end
SlidingWindow: begin

end
StridedSlidingWindow: begin

end
endcase
end

Expand Down
9 changes: 8 additions & 1 deletion src/ita_package.sv
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,14 @@ package ita_package;
typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t;

// Masking
typedef enum {None=0, UpperTriangular=1, LowerTriangular=2, Strided=3, UpperStrided=4, LowerStrided=5} mask_e;
typedef enum {None=0,
UpperTriangular=1,
LowerTriangular=2,
Strided=3,
UpperStrided=4,
LowerStrided=5,
SlidingWindow=6,
StridedSlidingWindow=7} mask_e;
typedef logic [WO-WI*2-2:0] mask_index_t;

// IO
Expand Down
22 changes: 22 additions & 0 deletions src/ita_softmax.sv
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,28 @@ module ita_softmax
disable_col[i] = 1'b1;
end
end
UpperStrided: begin
if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
((i + (mask_tile_x_q * M)) >= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin
disable_col[i] = 1'b0;
end else begin
disable_col[i] = 1'b1;
end
end
LowerStrided: begin
if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
((i + (mask_tile_x_q * M)) <= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin
disable_col[i] = 1'b0;
end else begin
disable_col[i] = 1'b1;
end
end
SlidingWindow: begin

end
StridedSlidingWindow: begin

end
endcase
end

Expand Down
9 changes: 8 additions & 1 deletion testGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,14 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter
default = 'none',
type = str,
help = 'Attention-Mask',
choices = ['none', 'upper_triangular', 'lower_triangular', 'strided', 'upper_strided', 'lower_strided'])
choices = ['none',
'upper_triangular',
'lower_triangular',
'strided',
'upper_strided',
'lower_strided',
'sliding_window',
'strided_sliding_window'])
self.group1.add_argument('-I', default = 1, type = int, help = 'Masking starting index')
self.group1.add_argument('--no-partial-softmax',
action = 'store_true',
Expand Down

0 comments on commit 343b6a9

Please sign in to comment.