Skip to content

Commit

Permalink
Started to add the logic for more than one tile
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Kant committed Dec 1, 2024
1 parent 4386fce commit 2447e2a
Show file tree
Hide file tree
Showing 5 changed files with 791 additions and 74 deletions.
8 changes: 5 additions & 3 deletions PyITA/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def streamingPartialSoftmax(x, mask, integerize = True):
print(f"Shift without mask: {shift}")

# Set shift value so high that 2**8 >> shift gets zero for all masked values
shift[mask_slice] = 16
shift[mask_slice] = 32
print(f"Shift with mask: {shift}")
# # matrix = np.squeeze(shift)
# # import matplotlib.pyplot as plt
Expand Down Expand Up @@ -194,8 +194,7 @@ def streamingPartialSoftmax(x, mask, integerize = True):
else:
exp_partial_sum_inverse = 1 / exp_partial_sum

# print(exp_partial_sum_inverse.shape)
# print(exp_partial_sum_inverse[0])
print(f"Exp parital sum inverse: {exp_partial_sum_inverse}")

# Find the difference between the maximum and x
diff = np.repeat(global_max, seq_length).reshape(n_heads, seq_length, seq_length) - x.astype(np.int32)
Expand All @@ -210,6 +209,9 @@ def streamingPartialSoftmax(x, mask, integerize = True):
shift = np.floor(diff * eps_max + 0.5 + np.finfo(np.float32).eps).astype(np.int32)
else:
shift = diff * eps_max

print(f"shift value before return shape: {shift.shape}")
print(f"shift value before return: {shift}")

# Calculate the activation value
if integerize:
Expand Down
Loading

0 comments on commit 2447e2a

Please sign in to comment.