Review notes (free text ahead of the first "diff --git" header is skipped by
git apply):

  * Makefile: the new branches now test for 'sliding_window' and
    'strided_sliding_window', the names that testGenerator.py offers as
    choices and that PyITA/ITA.py compares against. The '*_attention'
    spellings could not match those names and fell through to mask_int = 0.
  * PyITA/ITA.py: 'index % 2 == 0' tested for an even index although the error
    message and the RTL ('& (ctrl_i.mask_start_index-1)') require a power of
    two. It rejected 1 (the default of -I) and accepted 6, 10, ... . Replaced
    with '(index & (index - 1)) == 0'; index > 0 is guaranteed by the
    enclosing bounds check.
  * PyITA/ITA.py, sliding_window: 'range(i, index + i)' ran past the last
    column for the final rows and raised IndexError. The range is now clamped
    to self.Mask.shape[2].
  * src/ita_masking.sv: the empty SlidingWindow / StridedSlidingWindow case
    items left every *_d output of the always_comb block unassigned. They now
    drive '0 until the real logic is written (hunk header updated from
    +169,54 to +169,64).
  * src/ita_softmax.sv: the two empty branches carry a TODO instead of a blank
    line; line counts are unchanged.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. Indentation and blank-line positions below are
reconstructed from nesting and hunk line counts, and whitespace-only -/+ pairs
appear with identical text because their original spacing could not be
recovered. The blob hashes on the 'index' lines are those of the original
patch. Regenerate the patch against the tree before applying it.

diff --git a/Makefile b/Makefile
index 0158e78..3a5a0d7 100644
--- a/Makefile
+++ b/Makefile
@@ -47,6 +47,10 @@ else ifeq ($(mask), upper_strided)
 	mask_int = 4
 else ifeq ($(mask), lower_strided)
 	mask_int = 5
+else ifeq ($(mask), sliding_window)
+	mask_int = 6
+else ifeq ($(mask), strided_sliding_window)
+	mask_int = 7
 else
 	mask_int = 0
 endif
diff --git a/PyITA/ITA.py b/PyITA/ITA.py
index 0b1dac3..d53e7cd 100644
--- a/PyITA/ITA.py
+++ b/PyITA/ITA.py
@@ -614,32 +614,59 @@ def apply_mask(self, index):
         elif (self.mask == 'upper_strided'):
             self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
             if (0 < index and index < self.S):
-                for h in range(self.Mask.shape[0]):
-                    for i in range(self.Mask.shape[1]):
-                        for j in range(i, self.Mask.shape[2], index):
-                            self.Mask[h][i][j] = False
+                if ((index & (index - 1)) == 0):  # power-of-two test; index > 0 is checked above
+                    for h in range(self.Mask.shape[0]):
+                        for i in range(self.Mask.shape[1]):
+                            for j in range(i, self.Mask.shape[2], index):
+                                self.Mask[h][i][j] = False
+                else:
+                    raise ValueError(f"Index has to be a power of two for {self.mask} mask")
             else:
                 raise ValueError(f"Index is out of bounds for {self.mask} mask")
         elif (self.mask == 'lower_strided'):
             self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
             if (0 < index and index < self.S):
-                for h in range(self.Mask.shape[0]):
-                    for i in range(self.Mask.shape[1]):
-                        for j in range(i, self.Mask.shape[2], index):
-                            self.Mask[h][j][i] = False
+                if ((index & (index - 1)) == 0):  # power-of-two test; index > 0 is checked above
+                    for h in range(self.Mask.shape[0]):
+                        for i in range(self.Mask.shape[1]):
+                            for j in range(i, self.Mask.shape[2], index):
+                                self.Mask[h][j][i] = False
+                else:
+                    raise ValueError(f"Index has to be a power of two for {self.mask} mask")
             else:
                 raise ValueError(f"Index is out of bounds for {self.mask} mask")
-        elif (self.mask == 'sliding_window_attention'):
+        elif (self.mask == 'sliding_window'):
             self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
             if (0 < index and index < self.S):
-                for h in range(self.Mask.shape[0]):
-                    for i in range(self.Mask.shape[1]):
-                        for j in range(i, (index + i)):
-                            self.Mask[h][i][j] = False
-                            self.Mask[h][j][i] = False
+                if ((index & (index - 1)) == 0):  # power-of-two test; index > 0 is checked above
+                    for h in range(self.Mask.shape[0]):
+                        for i in range(self.Mask.shape[1]):
+                            for j in range(i, min(index + i, self.Mask.shape[2])):  # clamp the window at the last column
+                                self.Mask[h][i][j] = False
+                                self.Mask[h][j][i] = False
+                else:
+                    raise ValueError(f"Index has to be a power of two for {self.mask} mask")
+            else:
+                raise ValueError(f"Index is out of bounds for {self.mask} mask")
+        elif (self.mask == 'strided_sliding_window'):
+            self.Mask = np.full((self.H, self.S, self.S), fill_value=True, dtype='bool')
+            if (0 < index and index < self.S):
+                if ((index & (index - 1)) == 0):  # power-of-two test; index > 0 is checked above
+                    for h in range(self.Mask.shape[0]):
+                        for i in range(self.Mask.shape[1]):
+                            for j in range(i, self.Mask.shape[2]):
+                                if (j > (index + i)):
+                                    if (j % index == 0):
+                                        self.Mask[h][i][j] = False
+                                        self.Mask[h][j][i] = False
+                                else:
+                                    self.Mask[h][i][j] = False
+                                    self.Mask[h][j][i] = False
+                else:
+                    raise ValueError(f"Index has to be a power of two for {self.mask} mask")
             else:
                 raise ValueError(f"Index is out of bounds for {self.mask} mask")
-        elif(self.mask == 'none'):
+        elif (self.mask == 'none'):
             pass
         else:
             raise ValueError("Mask not supported")
diff --git a/src/ita_masking.sv b/src/ita_masking.sv
index 7273309..f65cc5d 100644
--- a/src/ita_masking.sv
+++ b/src/ita_masking.sv
@@ -30,6 +30,9 @@ module ita_masking
   assign mask_o = mask_q;
 
   always_comb begin
+
+
+
     case (ctrl_i.mask_type)
       None: begin
         mask_col_offset_d = '0;
@@ -41,9 +44,9 @@ module ita_masking
       UpperTriangular: begin
         mask_col_offset_d = (step_i == QK || step_i == AV) ? mask_col_offset_q : ((ctrl_i.mask_start_index) & (N-1));
         mask_tile_x_pos_d = (step_i == QK || step_i == AV) ? mask_tile_x_pos_q : ((ctrl_i.mask_start_index) / M);
-        mask_tile_y_pos_d = mask_tile_y_pos_q;
+        mask_tile_y_pos_d = mask_tile_y_pos_q;
         mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : ((((ctrl_i.mask_start_index)/N)*M) & ((M*M/N)-1));
-        mask_d = '0;
+        mask_d = '0;
 
         if (step_i == QK) begin
           if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i && last_inner_tile_i == 1'b1) begin
@@ -92,11 +95,11 @@ module ita_masking
         end
       end
       LowerTriangular: begin
-        mask_col_offset_d = '0;
-        mask_tile_x_pos_d = '0;
+        mask_col_offset_d = '0;
+        mask_tile_x_pos_d = mask_tile_x_pos_q;
         mask_tile_y_pos_d = (step_i == QK || step_i == AV) ? mask_tile_y_pos_q : ((ctrl_i.mask_start_index) / M);
         mask_pos_d = (step_i == QK || step_i == AV) ? mask_pos_q : (ctrl_i.mask_start_index & (M-1));
-        mask_d = '0;
+        mask_d = '0;
 
         if (step_i == QK) begin
           if (mask_tile_x_pos_q == tile_x_i && mask_tile_y_pos_q == tile_y_i && last_inner_tile_i == 1'b1) begin
@@ -166,6 +169,64 @@ module ita_masking
           end
         end
       end
+      UpperStrided: begin
+        mask_col_offset_d = '0;
+        mask_tile_x_pos_d = '0;
+        mask_tile_y_pos_d = '0;
+        mask_pos_d = '0;
+        mask_d = '0;
+
+        if (step_i == QK) begin
+          if (last_inner_tile_i == 1'b1) begin
+            for (int i = 0; i < N; i++) begin
+              //Marcel Kant: Does only work if ctrl_i.mask_start_index is a power of two
+              if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
+                  ((((count_i / M) * N) + i + (tile_x_i * M)) >= ((count_i & (M-1)) + (tile_y_i * M)))) begin
+                mask_d[i] = 1'b0;
+              end else begin
+                mask_d[i] = 1'b1;
+              end
+            end
+          end
+        end
+      end
+      LowerStrided: begin
+        mask_col_offset_d = '0;
+        mask_tile_x_pos_d = '0;
+        mask_tile_y_pos_d = '0;
+        mask_pos_d = '0;
+        mask_d = '0;
+
+        if (step_i == QK) begin
+          if (last_inner_tile_i == 1'b1) begin
+            for (int i = 0; i < N; i++) begin
+              //Marcel Kant: Does only work if ctrl_i.mask_start_index is a power of two
+              if ((((((count_i / M) * N) + i + (tile_x_i * M)) - ((count_i & (M-1)) + (tile_y_i * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
+                  ((((count_i / M) * N) + i + (tile_x_i * M)) <= ((count_i & (M-1)) + (tile_y_i * M)))) begin
+                mask_d[i] = 1'b0;
+              end else begin
+                mask_d[i] = 1'b1;
+              end
+            end
+          end
+        end
+      end
+      SlidingWindow: begin
+        // TODO: masking logic not implemented yet; drive every output so always_comb infers no latch
+        mask_col_offset_d = '0;
+        mask_tile_x_pos_d = '0;
+        mask_tile_y_pos_d = '0;
+        mask_pos_d = '0;
+        mask_d = '0;
+      end
+      StridedSlidingWindow: begin
+        // TODO: masking logic not implemented yet; drive every output so always_comb infers no latch
+        mask_col_offset_d = '0;
+        mask_tile_x_pos_d = '0;
+        mask_tile_y_pos_d = '0;
+        mask_pos_d = '0;
+        mask_d = '0;
+      end
     endcase
   end
 
diff --git a/src/ita_package.sv b/src/ita_package.sv
index 184e0f9..fee10a8 100644
--- a/src/ita_package.sv
+++ b/src/ita_package.sv
@@ -41,7 +41,14 @@ package ita_package;
   typedef logic signed [GELU_OUT_WIDTH-1:0] gelu_out_t;
 
   // Masking
-  typedef enum {None=0, UpperTriangular=1, LowerTriangular=2, Strided=3, UpperStrided=4, LowerStrided=5} mask_e;
+  typedef enum {None=0,
+                UpperTriangular=1,
+                LowerTriangular=2,
+                Strided=3,
+                UpperStrided=4,
+                LowerStrided=5,
+                SlidingWindow=6,
+                StridedSlidingWindow=7} mask_e;
   typedef logic [WO-WI*2-2:0] mask_index_t;
 
   // IO
diff --git a/src/ita_softmax.sv b/src/ita_softmax.sv
index 70bdfe8..3e38f47 100644
--- a/src/ita_softmax.sv
+++ b/src/ita_softmax.sv
@@ -325,6 +325,28 @@ module ita_softmax
               disable_col[i] = 1'b1;
             end
           end
+          UpperStrided: begin
+            if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
+                ((i + (mask_tile_x_q * M)) >= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin
+              disable_col[i] = 1'b0;
+            end else begin
+              disable_col[i] = 1'b1;
+            end
+          end
+          LowerStrided: begin
+            if ((((i + (mask_tile_x_q * M)) - ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M))) & (ctrl_i.mask_start_index-1)) == 0 &&
+                ((i + (mask_tile_x_q * M)) <= ((count_soft_mask_q & (M-1)) + (mask_tile_y_q * M)))) begin
+              disable_col[i] = 1'b0;
+            end else begin
+              disable_col[i] = 1'b1;
+            end
+          end
+          SlidingWindow: begin
+            // TODO(review): not implemented; disable_col[i] is not assigned in this branch - confirm a default exists above
+          end
+          StridedSlidingWindow: begin
+            // TODO(review): not implemented; disable_col[i] is not assigned in this branch - confirm a default exists above
+          end
         endcase
       end
 
diff --git a/testGenerator.py b/testGenerator.py
index 97465d9..9079056 100644
--- a/testGenerator.py
+++ b/testGenerator.py
@@ -108,7 +108,14 @@ class ArgumentDefaultMetavarTypeFormatter(argparse.ArgumentDefaultsHelpFormatter
                                  default = 'none',
                                  type = str,
                                  help = 'Attention-Mask',
-                                 choices = ['none', 'upper_triangular', 'lower_triangular', 'strided', 'upper_strided', 'lower_strided'])
+                                 choices = ['none',
+                                            'upper_triangular',
+                                            'lower_triangular',
+                                            'strided',
+                                            'upper_strided',
+                                            'lower_strided',
+                                            'sliding_window',
+                                            'strided_sliding_window'])
         self.group1.add_argument('-I', default = 1, type = int, help = 'Masking starting index')
         self.group1.add_argument('--no-partial-softmax',
                                  action = 'store_true',