
Renormalize LMNextToken.sample() probs to fix floating point errors #20

Merged · 1 commit · Jan 4, 2025

Conversation

gabegrand (Collaborator)

Due to numerical imprecision, the exponentiated next_token_logprobs occasionally do not sum to 1. When the discrepancy is too large, this causes np.random.choice to throw the following error:

      token_id = np.random.choice(len(probs), p=(probs))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "numpy/random/mtrand.pyx", line 975, in numpy.random.mtrand.RandomState.choice
    ValueError: probabilities do not sum to 1

This is observed empirically when using token masking and appears to occur more frequently for smaller models (e.g., llama-3.2-1B).

This PR fixes the issue by renormalizing next_token_logprobs after exponentiation. It is intended as a stopgap pending future work on the numerical stability of the token masking code.
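The fix can be sketched as follows. This is a minimal illustration, not the repository's actual `LMNextToken.sample()` implementation; the function and argument names are made up for the example:

```python
import numpy as np

def sample_token(next_token_logprobs, rng=None):
    """Sample a token id from next-token log-probabilities.

    Renormalizing after exponentiation guards against tiny
    floating-point drift tripping the sum-to-1 check inside
    np.random.choice / Generator.choice.
    """
    rng = rng or np.random.default_rng()
    probs = np.exp(next_token_logprobs)
    probs = probs / probs.sum()  # renormalize: the sum may deviate slightly from 1.0
    return int(rng.choice(len(probs), p=probs))
```

Without the renormalization line, a sum that deviates from 1 by more than NumPy's internal tolerance raises the `ValueError` shown above.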

@gabegrand gabegrand requested a review from alex-lew January 3, 2025 21:37
@alex-lew (Contributor) commented Jan 4, 2025

Looks good, as a stopgap! We probably want to have some way of emitting a warning when this happens, so that we catch future bugs relating to logprob computation.

@alex-lew alex-lew merged commit a191fca into main Jan 4, 2025
1 check passed
@gabegrand gabegrand deleted the gg/renormalize branch January 6, 2025 18:19
@gabegrand (Collaborator, Author)

Thanks! I agree about the warning. I didn't add one for now because I didn't want to introduce a ton of spam in the logs. In the future, we could consider making better use of the logging / warnings modules, which allow suppression of repeated warnings through filters.
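For illustration, the stdlib warnings filters mentioned here can deduplicate such messages. This is a hypothetical sketch, not code from this repository; `check_probs` and the tolerance are invented for the example:

```python
import warnings

# The "once" action reports each distinct warning message at most once
# per run, so repeated renormalizations don't flood the logs.
warnings.filterwarnings("once", category=RuntimeWarning)

def check_probs(probs, tol=1e-6):
    """Renormalize probs, warning if the sum drifts from 1 by more than tol."""
    total = probs.sum()
    if abs(total - 1.0) > tol:
        warnings.warn("renormalizing next-token probabilities", RuntimeWarning)
    return probs / total
```

Because the filter matches on message text and category, keeping the message constant (rather than interpolating the observed sum) is what makes the deduplication effective.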
