
apply 2D blocking to all kernels #156

Merged 5 commits on Jan 3, 2024
Conversation

@ahgamut (Contributor) commented Dec 30, 2023

Extracts a bit more speed in prompt eval time. Also fixes some typos.

@ahgamut (Contributor, Author) commented Dec 30, 2023

somehow GemmStridedBatchedEx produces garbage output when I try to use 2D blocks there.

@ahgamut ahgamut marked this pull request as ready for review December 30, 2023 19:25
@ahgamut (Contributor, Author) commented Dec 30, 2023

Alright, I'm unable to find the error that stops GemmStridedBatchedEx. It's a bounds check of some kind, I think.
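For reference, the usual failure mode in tiled kernels is an edge tile loaded without a bounds guard. Here is a minimal CPU sketch of the guarded-load pattern (a generic illustration only; `load_tile` and its layout are hypothetical, not the actual llamafile kernel):

```cpp
#include <cassert>

// Generic tile-load pattern (hypothetical names, not llamafile's code).
// When a BM x BK tile straddles the edge of an M x K row-major matrix,
// the out-of-range slots must be zero-filled; skipping this guard makes
// the accumulation read garbage, matching the symptom described above.
template <int BM, int BK>
void load_tile(const float *A, float tile[BM][BK],
               int M, int K, int i0, int k0) {
    for (int i = 0; i < BM; ++i)
        for (int k = 0; k < BK; ++k)
            tile[i][k] = (i0 + i < M && k0 + k < K)
                             ? A[(i0 + i) * K + (k0 + k)]  // in bounds
                             : 0.0f;                       // zero-pad edge
}
```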

@ahgamut (Contributor, Author) commented Jan 2, 2024

I applied 2D blocking to every kernel used. This gives a boost in both the CLIP and prompt_eval parts when using llava, but causes a slowdown in the eval part.

The slowdown is because the blocking for the GemmStridedBatchedEx kernel is not optimal: many threads do no work with the current values of BM, BN, BK. GemmStridedBatchedEx does much better with BM = 32, BN = 4, BK = 32 because fewer threads are wasted.
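The wasted-thread effect can be sketched with a small utilization calculation (an illustration only: the 64x64 "wide tile" is an assumption for contrast, not llamafile's actual default tiling):

```cpp
#include <cassert>

// If each thread block launches BM*BN threads to cover a BM x BN tile
// of the M x N output, utilization is the fraction of launched threads
// that actually write an output element.
static double utilization(int M, int N, int BM, int BN) {
    long blocksM = (M + BM - 1) / BM;  // grid size along rows
    long blocksN = (N + BN - 1) / BN;  // grid size along columns
    long launched = blocksM * blocksN * (long)BM * (long)BN;
    return (double)((long)M * N) / (double)launched;
}

// Single-token eval is effectively N == 1, so a wide tile idles almost
// every thread, while a narrow BM=32, BN=4 tile wastes far fewer:
//   utilization(4096, 1, 64, 64) == 0.015625   (1 thread in 64 works)
//   utilization(4096, 1, 32, 4)  == 0.25       (1 thread in 4 works)
```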

@ahgamut (Contributor, Author) commented Jan 2, 2024

The balancing act will be whether to have BM/BK/BN as template parameters for all the GPU functions, or to redefine those macros before GemmStridedBatchedEx and copy-paste the body of matmul_block2d.
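The template-parameter option might look like the following CPU sketch (hypothetical code under the naming in the comment above, not the actual llamafile kernel, which is GPU code): each call site instantiates its own tile sizes, so GemmStridedBatchedEx can use a narrow tiling without duplicating the loop body.

```cpp
#include <cassert>
#include <vector>

// Sketch: BM/BN/BK as compile-time template parameters instead of
// macros. C (M x N) = A (M x K) * B (K x N), row-major, C pre-zeroed.
template <int BM, int BN, int BK>
void matmul_block2d(const float *A, const float *B, float *C,
                    int M, int N, int K) {
    for (int i0 = 0; i0 < M; i0 += BM)
        for (int j0 = 0; j0 < N; j0 += BN)
            for (int k0 = 0; k0 < K; k0 += BK)
                // Inner loops stop at the tile edge or the matrix edge.
                for (int i = i0; i < i0 + BM && i < M; ++i)
                    for (int j = j0; j < j0 + BN && j < N; ++j) {
                        float acc = 0;
                        for (int k = k0; k < k0 + BK && k < K; ++k)
                            acc += A[i * K + k] * B[k * N + j];
                        C[i * N + j] += acc;  // partial sum per K-block
                    }
}
```

A batched-GEMM call site could then instantiate `matmul_block2d<32, 4, 32>` while the other kernels keep their own tile sizes, with no macro redefinition or copy-pasting.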

@ahgamut ahgamut changed the title apply 2D blocks to GemmBatchedEx apply 2D blocking to all kernels Jan 2, 2024
@jart (Collaborator) left a comment

Nice! It looks like this gives us a 13% performance boost for GPU inference (both eval and batch eval) for Windows users. Looks good to me.

@jart jart merged commit c0589f0 into Mozilla-Ocho:main Jan 3, 2024
1 check passed