
einsum problem with linear grad sample #242

Merged: 6 commits merged into experimental_v1.0 from linear_einsum on Nov 10, 2021

Conversation

pierrestock (Contributor)

Make the linear einsum compatible with 3D activations (transformer use case). Maybe there is a way of getting rid of the if/else statement.

@facebook-github-bot added the CLA Signed label on Nov 3, 2021
# With 3-dimensional activations (transformers) the original formula does not aggregate correctly
if activations.ndim == 2:
    gs = torch.einsum("n...i,n...j->nij", backprops, activations)
else:
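For context, a hedged sketch of what the full two-branch computation might look like; the else body isn't shown in the excerpt above, and the explicit-index formula used for the 3D case here, as well as the function name, are illustrative assumptions rather than the exact PR code:

```python
import torch

def compute_linear_grad_sample(activations, backprops):
    # Per-sample gradient of a Linear layer's weight.
    # activations: (N, in_features) or (N, T, in_features)
    # backprops:   (N, out_features) or (N, T, out_features)
    if activations.ndim == 2:
        # 2D case: one outer product per sample.
        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
    else:
        # 3D (transformer) case with explicit indices: sum the per-timestep
        # outer products over the sequence dimension k.
        gs = torch.einsum("nki,nkj->nij", backprops, activations)
    return gs
```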
Contributor

Noob question: I thought ellipses in einsum handle all dimensions, and by extension even 3D. Do you know why it didn't work?
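For reference, on a recent PyTorch, dimensions matched by the ellipsis that don't appear in the output are summed over, so the ellipsis form should in principle cover the 3D case; a toy check, with illustrative shapes:

```python
import torch

N, T, d_in, d_out = 4, 5, 3, 2  # illustrative shapes
activations = torch.randn(N, T, d_in)
backprops = torch.randn(N, T, d_out)

# The sequence dimension T is matched by "..." in both inputs; since it is
# absent from the output spec, einsum sums over it.
gs_ellipsis = torch.einsum("n...i,n...j->nij", backprops, activations)

# The same contraction with the sequence dimension named explicitly.
gs_explicit = torch.einsum("nki,nkj->nij", backprops, activations)

print(torch.allclose(gs_ellipsis, gs_explicit))  # expected: True on torch >= 1.8
```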

Contributor (Author)

After careful investigation today, it appears that it's linked to a previous version of PyTorch that we used in our environment (specifically 1.7.1). The bug appears to be fixed in 1.9, since we could not reproduce it with that version.

My suggestion would then be either (1) not to merge the PR but to update the PyTorch requirements, or (2) to merge the PR so that we keep backward compatibility with older PyTorch versions.

[screenshot attached]
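The exact repro isn't included in the thread, but a check along these lines (shapes illustrative) would compare the ellipsis einsum against an explicit per-sample loop and surface any aggregation difference on older versions:

```python
import torch

N, T, d_in, d_out = 4, 5, 3, 2  # illustrative shapes
activations = torch.randn(N, T, d_in)
backprops = torch.randn(N, T, d_out)

# Ground truth: for each sample, sum the per-timestep outer products.
expected = torch.stack(
    [backprops[n].T @ activations[n] for n in range(N)]
)  # shape (N, d_out, d_in)

gs = torch.einsum("n...i,n...j->nij", backprops, activations)
print(torch.allclose(gs, expected))  # True on torch >= 1.8 per this thread
```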

Contributor

Updating the PyTorch requirements past 1.8 is problematic, because we depend on csprng, which doesn't support 1.9 and beyond.
Is the bug still reproducible in 1.8?

Contributor (Author)

Good news: the bug disappears in 1.8.0. Let's update the requirements then?

[screenshot attached]

Contributor

> Let's update the requirements then?

+1

Contributor (Author)

Updated requirements.txt with torch>=1.8. Sorry for the multiple commits to revert grad_sample/linear.py.

Contributor

@ffuuugor Are you sure csprng doesn't support torch 1.9 (in its current state)? Did you see my comment? Just checking; I didn't test it deeply.

Contributor

You're probably right, but I haven't looked too deeply either.
I was only relying on the fact that the PyPI release depends on 1.8 and that their documentation explicitly lists all the PyTorch versions they support, notably lacking 1.9 and beyond.
Is the fix as simple as installing from the main branch instead of the latest release?

@pierrestock pierrestock merged commit 6712524 into experimental_v1.0 Nov 10, 2021
@karthikprasad karthikprasad deleted the linear_einsum branch November 25, 2021 22:46
Labels: CLA Signed
Projects: None yet
5 participants