Skip to content

Commit

Permalink
Cleaned up the code a little bit
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 5, 2024
1 parent b00afb9 commit 73736c0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 96 deletions.
49 changes: 2 additions & 47 deletions PyITA/ITA.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,37 +601,20 @@ def apply_mask(self, index):
else:
raise ValueError("Mask not supported")


def step4_QK(self, no_partial_softmax, index):
self.A = np.array(
[np.matmul(self.Qp_requant[i], np.transpose(self.Kp_requant[i]), dtype = np.int32) for i in range(self.H)])
self.A = np.clip(self.A, -2**(self.WO - 1), 2**(self.WO - 1) - 1)
self.A_requant = requantize(self.A, self.requant_eps_mult[3], self.requant_right_shift[3], self.requant_add[3])

self.apply_mask(index)

print(self.Mask)

matrix = np.squeeze(self.A_requant)
plt.imshow(matrix, cmap='viridis')
plt.colorbar()
plt.title("A_requant/A_stream_soft_in")
plt.show()

print(f"A_requant row 0: {self.A_requant[0, 0, :]}")

if (self.S_ITA - self.S) > 0:
self.A_requant[:, -(self.S_ITA - self.S):, :] = 0
self.A_requant[:, :, -(self.S_ITA - self.S):] = 0

self.soft(no_partial_softmax)

matrix = np.squeeze(self.A_partial_softmax)
plt.imshow(matrix, cmap='viridis')
plt.colorbar()
plt.title("A_partial_softmax")
plt.show()

self.tiler_AV(self.Qp_requant, self.Kp_requant, self.A_requant, "Qp_in", "Kp_in", "A")

def soft(self, no_partial_softmax = False):
Expand All @@ -645,8 +628,6 @@ def soft(self, no_partial_softmax = False):
else:
self.A_partial_softmax = streamingPartialSoftmax(self.A_requant[:, :self.S, :self.S], self.Mask)
self.A_partial_softmax[self.Mask] = 0
print(f"inp_stream_soft_o: {self.A_partial_softmax[0,:,:]}")
print(f"Normalization Sum: {np.sum(self.A_partial_softmax[0,:,:], axis=1)}")
self.A_partial_softmax = np.pad(self.A_partial_softmax,
((0, 0), (0, self.S_ITA - self.S), (0, self.S_ITA - self.S)))

Expand All @@ -657,44 +638,24 @@ def soft(self, no_partial_softmax = False):
A_save = self.A_partial_softmax[h]
write_matrix(A_save, f"A_soft_{h}", self.paths["standalone"])

def step5_AV(self):
print(f"A_partial_softmax: {self.A_partial_softmax.shape}")
print(f"Vp_requant: {self.Vp_requant.shape}")

def step5_AV(self):
self.O_soft = np.array([
np.matmul(self.A_partial_softmax[i].astype(np.uint8), self.Vp_requant[i], dtype = np.int32)
for i in range(self.H)
])
print(f"O_soft without requant row 0: {self.O_soft[0, 62, :]}")
print(f"O_soft without requant row 0: {self.O_soft[0, 63, :]}")
print(f"O_soft without requant row 0: {self.O_soft[0, 0, :]}")
print(f"O_soft without requant row 0: {self.O_soft[0, 1, :]}")


self.O_soft = np.clip(self.O_soft, -2**(self.WO - 1), 2**(self.WO - 1) - 1)
self.O_soft_requant = requantize(self.O_soft, self.requant_eps_mult[4], self.requant_right_shift[4],
self.requant_add[4])

print(f"O_soft_requant: {self.O_soft_requant[0, 62, :]}")
print(f"O_soft_requant: {self.O_soft_requant[0, 63, :]}")
print(f"O_soft_requant: {self.O_soft_requant[0, 0, :]}")
print(f"O_soft_requant: {self.O_soft_requant[0, 1, :]}")

if (self.S_ITA - self.S) > 0:
self.O_soft_requant[:, -(self.S_ITA - self.S):, :] = 0
if (self.P_ITA - self.P) > 0:
self.O_soft_requant[:, :, -(self.P_ITA - self.P):] = 0

matrix = np.squeeze(self.O_soft_requant)
plt.imshow(matrix, cmap='viridis')
plt.colorbar()
plt.title("O_soft_requant/O_soft")
plt.show()

self.tiler_AV(self.A_requant, np.transpose(self.Vp_requant, (0, 2, 1)), self.O_soft_requant, "A_stream_soft_in",
"Vp_in", "O_soft")



def apply_activation(self, preactivation, activation):
if activation not in ["gelu", "relu", "identity"]:
raise ValueError("Activation function not supported")
Expand All @@ -719,12 +680,6 @@ def step6_O(self):
self.Out_soft_requant = requantize(self.Out_soft, self.requant_eps_mult[5], self.requant_right_shift[5],
self.requant_add[5])

matrix = np.squeeze(self.Out_soft_requant)
plt.imshow(matrix, cmap='viridis')
plt.colorbar()
plt.title("Out_soft_requant")
plt.show()

if (self.S_ITA - self.S) > 0:
self.Out_soft_requant[:, -(self.S_ITA - self.S):, :] = 0
if (self.E_ITA - self.E) > 0:
Expand Down
30 changes: 0 additions & 30 deletions PyITA/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def streamingPartialSoftmax(x, mask, integerize = True):

mask_slice = mask[... ,i*PE:(i*PE)+width]
x_slice = x[..., 0 + i * PE:width + i * PE]
print(f"Mask Slice Shape: {mask_slice.shape}")
print(f"Mask Slice: {mask_slice}")
print(f"X Slice Shape: {x_slice.shape}")
print(f"X Slice: {x_slice}")

# Find the maximum for each row in the current column block (consisting of 16 columns)
if integerize:
Expand All @@ -129,16 +125,9 @@ def streamingPartialSoftmax(x, mask, integerize = True):
else:
max_shift = (current_max - global_max) * eps_max

print(f"Global Max: {global_max.shape}")
print(global_max)
print(f"Current Max: {current_max.shape}")
print(current_max)

# Update all shift values where new maximum is larger
shift_sum[current_max > global_max] = max_shift[current_max > global_max]

print(f"Shift sum: {shift_sum}")

# Updated all maximums where they changed
global_max[current_max > global_max] = current_max[current_max > global_max]

Expand All @@ -157,45 +146,29 @@ def streamingPartialSoftmax(x, mask, integerize = True):
else:
shift = diff * eps_max

print(f"Shift Shape: {shift.shape}")
print(f"Shift without mask: {shift}")

# Set shift value so high that 2**8 >> shift gets zero for all masked values
shift[mask_slice] = 32
print(f"Shift with mask: {shift}")
# # matrix = np.squeeze(shift)
# # import matplotlib.pyplot as plt
# # plt.imshow(matrix, cmap='viridis')
# # plt.colorbar()
# # plt.title("Shift Matrix")
# # plt.show()

# Calculate exponential sum over the current part of the row and scale it by 2**10 to prevent underflow
if integerize:
exp_sum = np.sum(2**8 >> shift, -1) # or
# exp_sum = np.floor(np.sum(2**8 / 2**shift, axis = -1))
else:
exp_sum = np.sum(1 / 2**shift, axis = -1)

print(f"Exp sum: {exp_sum}")

# Update the accumulated sum and add the accumulation over the current part of the row
if integerize:
exp_partial_sum = np.floor((exp_partial_sum / 2**shift_sum)) + exp_sum
else:
exp_partial_sum = (exp_partial_sum / 2**(shift_sum.astype(np.float32))) + exp_sum

print(f"Exp parital sum: {exp_partial_sum}")

## STAGE 2: Calculate the softmax activation
# Invert the partial sum
if integerize:
exp_partial_sum_inverse = np.floor((2**8 - 1) * 2**8 / exp_partial_sum).astype(np.int32)
else:
exp_partial_sum_inverse = 1 / exp_partial_sum

print(f"Exp parital sum inverse: {exp_partial_sum_inverse}")

# Find the difference between the maximum and x
diff = np.repeat(global_max, seq_length).reshape(n_heads, seq_length, seq_length) - x.astype(np.int32)

Expand All @@ -209,9 +182,6 @@ def streamingPartialSoftmax(x, mask, integerize = True):
shift = np.floor(diff * eps_max + 0.5 + np.finfo(np.float32).eps).astype(np.int32)
else:
shift = diff * eps_max

print(f"shift value before return shape: {shift.shape}")
print(f"shift value before return: {shift}")

# Calculate the activation value
if integerize:
Expand Down
27 changes: 8 additions & 19 deletions src/ita_controller.sv
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ module ita_controller
end
inp_bias_padded = inp_bias;


mask_d = '0;
case (ctrl_i.mask_type)
None: begin
Expand Down Expand Up @@ -433,18 +432,14 @@ module ita_controller
end
end
end else if ((count_q & (M-1)) < (mask_pos_q & (M-1))) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
mask_d = '1;
end
end else if (mask_tile_x_pos_q <= tile_x_q && mask_tile_y_pos_q != tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
mask_d = '1;
end else if (mask_tile_x_pos_q != tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b0;
end
mask_d = '0;
end else begin
mask_d = '0;
end
end
end
Expand Down Expand Up @@ -483,18 +478,12 @@ module ita_controller
end
end
end else if ((count_q & (M-1)) >= (mask_pos_q & (M-1))) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
mask_d = '1;
end
end else if (mask_tile_x_pos_q > tile_x_q && mask_tile_y_pos_q == tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b1;
end
mask_d = '1;
end else if (mask_tile_x_pos_q >= tile_x_q && mask_tile_y_pos_q != tile_y_q && last_inner_tile_o == 1'b1) begin
for (int i = 0; i < N; i++) begin
mask_d[i] = 1'b0;
end
mask_d = '0;
end
end
end
Expand Down

0 comments on commit 73736c0

Please sign in to comment.