[Feature] Make PPO compatible with composite actions and log-probs #2665

Open
wants to merge 6 commits into
base: gh/vmoens/58/base

vmoens
Contributor

@vmoens vmoens commented Dec 18, 2024

[ghstack-poisoned]
vmoens added a commit that referenced this pull request Dec 18, 2024
ghstack-source-id: cbdaf533a39aeea41e3fbcda4e9d95a116eabfe1
Pull Request resolved: #2665

pytorch-bot bot commented Dec 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2665

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 17 New Failures, 1 Unrelated Failure

As of commit 14e639d with merge base ed656a1:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Dec 18, 2024
@vmoens
Contributor Author

vmoens commented Dec 18, 2024

In this PR, I propose letting PPO take a series of actions defined in the in-keys (rather than a single one) to better accommodate CompositeDistributions.

This PR requires pytorch/tensordict#1146 and pytorch/tensordict#1145 to be merged or checked out.

Here is a demo:
https://gist.github.com/vmoens/46175764240dcbaf311af562b9e53294

cc @matteobettini

@matteobettini
Contributor

Cool!

Just to understand a bit, how is this related to multiagent?

I see in the example that you are using different agent groups, but the feature seems to be more suited for composite single-agent actions.

In multiagent, the suggested way to do things was to create a different loss for each group. This avoids losses taking a list of dones, rewards, and actions and having to match them.

To me, this feature makes sense for composite actions within a single agent or a single MARL group (which avoids taking a list of rewards and dones).

@matteobettini
Contributor

Also, in the example you are using a single module to output actions for multiple groups.
Here too, I think the way we suggest doing things is to process different groups in different modules, so that each module can go to its own loss.

@vmoens
Contributor Author

vmoens commented Dec 19, 2024

I don't have a strong feeling about multiagent or not; the use case that was suggested to me here had a composite action space where each leaf was labelled "agent_x".
I guess that long term there's a version of this where you could have one loss for all. Since tensordict now supports arithmetic ops, you could perfectly well do

log_prob = make_some_tensordict(...)
prev_log_prob = make_some_tensordict(...)
advantage = make_a_tensordict_or_a_tensor(...)
loss = (log_prob - prev_log_prob).exp().clamp(...) * advantage

and your loss will be a tensordict itself.
This would probably break now, but I do think we could actually get this to work and simplify the code at the same time (that will require deprecating some default behaviours in CompositeDistribution in v0.9, as announced in tensordict).
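
The per-leaf arithmetic described above can be sketched without any framework. This is only an illustration under stated assumptions, not the PR's implementation: plain dicts of floats stand in for tensordicts, `clipped_surrogate` and its `eps` parameter are hypothetical names, a single advantage is assumed to be shared by all leaves, and the standard PPO clipped surrogate (the minimum of the clipped and unclipped terms) is used rather than the abbreviated expression in the snippet.

```python
import math


def clipped_surrogate(log_prob, prev_log_prob, advantage, eps=0.2):
    """Per-leaf PPO clipped loss for a single sample.

    log_prob and prev_log_prob map each action leaf (e.g. "agent_0")
    to a float log-probability. advantage is one float shared by all
    leaves, which is an assumption made for this sketch.
    """
    loss = {}
    for key in log_prob:
        # Importance ratio between the new and the old policy.
        ratio = math.exp(log_prob[key] - prev_log_prob[key])
        # Keep the ratio inside [1 - eps, 1 + eps].
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        # Pessimistic bound on the gain, negated to obtain a loss.
        loss[key] = -min(ratio * advantage, clipped * advantage)
    return loss


log_prob = {"agent_0": -1.0, "agent_1": -0.5}
prev_log_prob = {"agent_0": -1.0, "agent_1": -1.5}
loss = clipped_surrogate(log_prob, prev_log_prob, advantage=2.0)
```

The result has one entry per action leaf, which mirrors the idea that the loss "will be a tensordict itself"; how those entries are then reduced to a scalar is a separate design choice.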

@matteobettini
Contributor

matteobettini commented Dec 19, 2024

This makes sense for a composite action space, yes. But in your PR I see you are also allowing lists of dones and rewards.

This is a bit less trivial, as it opens up a bunch of compatibility use cases if you want to use this in MARL.
The done and reward keys might not map one-to-one to actions:

  • groups can have composite actions
  • reward and done could be partially or totally shared across groups (each in a different way possibly)

Supporting all these use cases might become a big headache, which is why I preferred to stick with one reward and done key per loss class.
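
The key-matching concern can be made concrete with a small, framework-free sketch. It is not TorchRL code: `GroupLossKeys`, the group names, and the key tuples are all hypothetical, chosen only to show a group with a composite action and a done flag shared across groups, while each loss still reads exactly one reward key and one done key.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GroupLossKeys:
    """Keys read by one (hypothetical) per-group loss."""

    action_keys: tuple  # one or more action leaves (composite action)
    reward_key: tuple   # exactly one reward entry
    done_key: tuple     # exactly one done entry


losses = {
    # A group with a composite action: two action leaves, one reward.
    "walkers": GroupLossKeys(
        action_keys=(("walkers", "move"), ("walkers", "jump")),
        reward_key=("walkers", "reward"),
        done_key=("done",),
    ),
    # A group with a single action leaf that shares the same done flag.
    "drones": GroupLossKeys(
        action_keys=(("drones", "action"),),
        reward_key=("drones", "reward"),
        done_key=("done",),
    ),
}

# Record which groups rely on each done key. Actions and dones are
# clearly not in one-to-one correspondence here.
shared_done = {}
for name, keys in losses.items():
    shared_done.setdefault(keys.done_key, []).append(name)
```

With one loss per group, the sharing pattern lives in this configuration rather than inside a single loss that would have to match lists of actions, rewards, and dones against each other.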

[ghstack-poisoned]
vmoens added a commit that referenced this pull request Dec 20, 2024
ghstack-source-id: f465f2017843904a510aa06768ced457df987e94
Pull Request resolved: #2665
@vmoens vmoens added the enhancement label Dec 20, 2024
[ghstack-poisoned]
vmoens added a commit that referenced this pull request Jan 9, 2025
ghstack-source-id: 3bcf7ebf9619f62d68979f85021a769796da0539
Pull Request resolved: #2665

github-actions bot commented Jan 9, 2025

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 149. Improved: $\large\color{#35bf28}35$. Worsened: $\large\color{#d91a1a}23$.

Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_simple 0.5689s 0.4674s 2.1394 Ops/s 2.1416 Ops/s $\color{#d91a1a}-0.10\%$
test_transformed 0.7436s 0.6492s 1.5403 Ops/s 1.5206 Ops/s $\color{#35bf28}+1.30\%$
test_serial 1.4874s 1.3864s 0.7213 Ops/s 0.7046 Ops/s $\color{#35bf28}+2.36\%$
test_parallel 1.3401s 1.2280s 0.8143 Ops/s 0.7603 Ops/s $\textbf{\color{#35bf28}+7.11\%}$
test_step_mdp_speed[True-True-True-True-True] 0.1648ms 30.4931μs 32.7943 KOps/s 33.5589 KOps/s $\color{#d91a1a}-2.28\%$
test_step_mdp_speed[True-True-True-True-False] 76.6840μs 18.0617μs 55.3657 KOps/s 55.9135 KOps/s $\color{#d91a1a}-0.98\%$
test_step_mdp_speed[True-True-True-False-True] 80.8810μs 17.2989μs 57.8072 KOps/s 58.8932 KOps/s $\color{#d91a1a}-1.84\%$
test_step_mdp_speed[True-True-True-False-False] 38.0020μs 10.1379μs 98.6396 KOps/s 98.9940 KOps/s $\color{#d91a1a}-0.36\%$
test_step_mdp_speed[True-True-False-True-True] 86.1020μs 32.7417μs 30.5421 KOps/s 31.1040 KOps/s $\color{#d91a1a}-1.81\%$
test_step_mdp_speed[True-True-False-True-False] 81.5130μs 19.9516μs 50.1212 KOps/s 50.4088 KOps/s $\color{#d91a1a}-0.57\%$
test_step_mdp_speed[True-True-False-False-True] 57.1070μs 19.2024μs 52.0768 KOps/s 52.4725 KOps/s $\color{#d91a1a}-0.75\%$
test_step_mdp_speed[True-True-False-False-False] 51.9570μs 12.1072μs 82.5957 KOps/s 83.9385 KOps/s $\color{#d91a1a}-1.60\%$
test_step_mdp_speed[True-False-True-True-True] 0.1309ms 35.1307μs 28.4652 KOps/s 29.4069 KOps/s $\color{#d91a1a}-3.20\%$
test_step_mdp_speed[True-False-True-True-False] 53.6100μs 22.3654μs 44.7120 KOps/s 45.9667 KOps/s $\color{#d91a1a}-2.73\%$
test_step_mdp_speed[True-False-True-False-True] 46.7180μs 19.2691μs 51.8966 KOps/s 53.6333 KOps/s $\color{#d91a1a}-3.24\%$
test_step_mdp_speed[True-False-True-False-False] 54.4420μs 12.1603μs 82.2348 KOps/s 83.6048 KOps/s $\color{#d91a1a}-1.64\%$
test_step_mdp_speed[True-False-False-True-True] 68.7190μs 36.5234μs 27.3797 KOps/s 28.1960 KOps/s $\color{#d91a1a}-2.89\%$
test_step_mdp_speed[True-False-False-True-False] 70.9320μs 24.1283μs 41.4452 KOps/s 43.0549 KOps/s $\color{#d91a1a}-3.74\%$
test_step_mdp_speed[True-False-False-False-True] 76.1430μs 22.0093μs 45.4354 KOps/s 48.3307 KOps/s $\textbf{\color{#d91a1a}-5.99\%}$
test_step_mdp_speed[True-False-False-False-False] 0.1418ms 13.8298μs 72.3077 KOps/s 72.8512 KOps/s $\color{#d91a1a}-0.75\%$
test_step_mdp_speed[False-True-True-True-True] 76.3030μs 34.6117μs 28.8920 KOps/s 29.4453 KOps/s $\color{#d91a1a}-1.88\%$
test_step_mdp_speed[False-True-True-True-False] 54.6230μs 22.0928μs 45.2636 KOps/s 46.0348 KOps/s $\color{#d91a1a}-1.68\%$
test_step_mdp_speed[False-True-True-False-True] 74.7500μs 21.9712μs 45.5141 KOps/s 44.9626 KOps/s $\color{#35bf28}+1.23\%$
test_step_mdp_speed[False-True-True-False-False] 67.9770μs 13.4551μs 74.3211 KOps/s 75.6971 KOps/s $\color{#d91a1a}-1.82\%$
test_step_mdp_speed[False-True-False-True-True] 88.8070μs 36.1414μs 27.6691 KOps/s 28.4653 KOps/s $\color{#d91a1a}-2.80\%$
test_step_mdp_speed[False-True-False-True-False] 61.0250μs 23.9141μs 41.8164 KOps/s 42.6559 KOps/s $\color{#d91a1a}-1.97\%$
test_step_mdp_speed[False-True-False-False-True] 2.6106ms 23.8353μs 41.9546 KOps/s 43.1588 KOps/s $\color{#d91a1a}-2.79\%$
test_step_mdp_speed[False-True-False-False-False] 74.5810μs 15.2323μs 65.6500 KOps/s 67.2401 KOps/s $\color{#d91a1a}-2.36\%$
test_step_mdp_speed[False-False-True-True-True] 76.3430μs 38.2271μs 26.1595 KOps/s 26.9248 KOps/s $\color{#d91a1a}-2.84\%$
test_step_mdp_speed[False-False-True-True-False] 59.1900μs 25.7107μs 38.8943 KOps/s 39.1115 KOps/s $\color{#d91a1a}-0.56\%$
test_step_mdp_speed[False-False-True-False-True] 77.7560μs 23.8203μs 41.9811 KOps/s 43.6152 KOps/s $\color{#d91a1a}-3.75\%$
test_step_mdp_speed[False-False-True-False-False] 56.8970μs 15.4390μs 64.7712 KOps/s 66.9105 KOps/s $\color{#d91a1a}-3.20\%$
test_step_mdp_speed[False-False-False-True-True] 75.7320μs 40.0003μs 24.9998 KOps/s 25.6725 KOps/s $\color{#d91a1a}-2.62\%$
test_step_mdp_speed[False-False-False-True-False] 98.1130μs 27.3908μs 36.5086 KOps/s 37.1493 KOps/s $\color{#d91a1a}-1.72\%$
test_step_mdp_speed[False-False-False-False-True] 56.6260μs 25.2534μs 39.5986 KOps/s 40.6499 KOps/s $\color{#d91a1a}-2.59\%$
test_step_mdp_speed[False-False-False-False-False] 55.2730μs 16.8121μs 59.4808 KOps/s 60.3791 KOps/s $\color{#d91a1a}-1.49\%$
test_values[generalized_advantage_estimate-True-True] 10.2948ms 10.0375ms 99.6267 Ops/s 99.6708 Ops/s $\color{#d91a1a}-0.04\%$
test_values[vec_generalized_advantage_estimate-True-True] 38.0918ms 33.7460ms 29.6332 Ops/s 29.4798 Ops/s $\color{#35bf28}+0.52\%$
test_values[td0_return_estimate-False-False] 0.2815ms 0.2124ms 4.7079 KOps/s 4.7050 KOps/s $\color{#35bf28}+0.06\%$
test_values[td1_return_estimate-False-False] 28.3563ms 24.6104ms 40.6332 Ops/s 39.9479 Ops/s $\color{#35bf28}+1.72\%$
test_values[vec_td1_return_estimate-False-False] 37.1327ms 33.6920ms 29.6807 Ops/s 29.4069 Ops/s $\color{#35bf28}+0.93\%$
test_values[td_lambda_return_estimate-True-False] 37.4898ms 35.0156ms 28.5587 Ops/s 28.0091 Ops/s $\color{#35bf28}+1.96\%$
test_values[vec_td_lambda_return_estimate-True-False] 46.5158ms 34.1923ms 29.2463 Ops/s 29.3585 Ops/s $\color{#d91a1a}-0.38\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 11.8942ms 8.5755ms 116.6109 Ops/s 115.9356 Ops/s $\color{#35bf28}+0.58\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 2.3519ms 1.8553ms 539.0096 Ops/s 539.3791 Ops/s $\color{#d91a1a}-0.07\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.6300ms 0.3612ms 2.7687 KOps/s 2.7729 KOps/s $\color{#d91a1a}-0.15\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 39.1611ms 38.4559ms 26.0038 Ops/s 26.1163 Ops/s $\color{#d91a1a}-0.43\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 4.3696ms 3.0513ms 327.7333 Ops/s 305.2308 Ops/s $\textbf{\color{#35bf28}+7.37\%}$
test_dqn_speed[False-None] 2.0935ms 1.4144ms 707.0208 Ops/s 693.1214 Ops/s $\color{#35bf28}+2.01\%$
test_dqn_speed[False-backward] 1.9754ms 1.8882ms 529.5979 Ops/s 501.3073 Ops/s $\textbf{\color{#35bf28}+5.64\%}$
test_dqn_speed[True-None] 0.7426ms 0.4784ms 2.0902 KOps/s 2.0383 KOps/s $\color{#35bf28}+2.55\%$
test_dqn_speed[True-backward] 0.9443ms 0.8949ms 1.1175 KOps/s 854.5569 Ops/s $\textbf{\color{#35bf28}+30.76\%}$
test_dqn_speed[reduce-overhead-None] 0.5922ms 0.4773ms 2.0952 KOps/s 2.0046 KOps/s $\color{#35bf28}+4.52\%$
test_dqn_speed[reduce-overhead-backward] 0.9285ms 0.8886ms 1.1254 KOps/s 1.0264 KOps/s $\textbf{\color{#35bf28}+9.65\%}$
test_ddpg_speed[False-None] 3.3212ms 2.9198ms 342.4846 Ops/s 316.1030 Ops/s $\textbf{\color{#35bf28}+8.35\%}$
test_ddpg_speed[False-backward] 4.2055ms 4.0502ms 246.9003 Ops/s 224.5300 Ops/s $\textbf{\color{#35bf28}+9.96\%}$
test_ddpg_speed[True-None] 1.4579ms 1.0210ms 979.4420 Ops/s 555.9376 Ops/s $\textbf{\color{#35bf28}+76.18\%}$
test_ddpg_speed[True-backward] 2.4231ms 1.9716ms 507.1938 Ops/s 415.0689 Ops/s $\textbf{\color{#35bf28}+22.20\%}$
test_ddpg_speed[reduce-overhead-None] 1.6882ms 1.0098ms 990.3370 Ops/s 927.0692 Ops/s $\textbf{\color{#35bf28}+6.82\%}$
test_ddpg_speed[reduce-overhead-backward] 1.9580ms 1.9069ms 524.4114 Ops/s 455.5887 Ops/s $\textbf{\color{#35bf28}+15.11\%}$
test_sac_speed[False-None] 9.3644ms 8.1158ms 123.2165 Ops/s 105.0327 Ops/s $\textbf{\color{#35bf28}+17.31\%}$
test_sac_speed[False-backward] 11.3785ms 10.9020ms 91.7265 Ops/s 78.6914 Ops/s $\textbf{\color{#35bf28}+16.56\%}$
test_sac_speed[True-None] 2.7484ms 1.8626ms 536.8895 Ops/s 470.5976 Ops/s $\textbf{\color{#35bf28}+14.09\%}$
test_sac_speed[True-backward] 3.6861ms 3.5862ms 278.8469 Ops/s 229.0318 Ops/s $\textbf{\color{#35bf28}+21.75\%}$
test_sac_speed[reduce-overhead-None] 2.1804ms 1.8414ms 543.0652 Ops/s 443.9634 Ops/s $\textbf{\color{#35bf28}+22.32\%}$
test_sac_speed[reduce-overhead-backward] 3.7551ms 3.6122ms 276.8364 Ops/s 216.0022 Ops/s $\textbf{\color{#35bf28}+28.16\%}$
test_redq_speed[False-None] 15.8075ms 13.8531ms 72.1861 Ops/s 69.3331 Ops/s $\color{#35bf28}+4.11\%$
test_redq_speed[False-backward] 0.2707s 28.3996ms 35.2117 Ops/s 41.1332 Ops/s $\textbf{\color{#d91a1a}-14.40\%}$
test_redq_speed[True-None] 6.3448ms 5.3953ms 185.3456 Ops/s 164.0274 Ops/s $\textbf{\color{#35bf28}+13.00\%}$
test_redq_speed[True-backward] 13.4297ms 12.8381ms 77.8929 Ops/s 72.1986 Ops/s $\textbf{\color{#35bf28}+7.89\%}$
test_redq_speed[reduce-overhead-None] 6.5822ms 5.8316ms 171.4789 Ops/s 164.9629 Ops/s $\color{#35bf28}+3.95\%$
test_redq_speed[reduce-overhead-backward] 14.1782ms 13.1963ms 75.7786 Ops/s 73.5108 Ops/s $\color{#35bf28}+3.09\%$
test_redq_deprec_speed[False-None] 15.4132ms 13.7509ms 72.7227 Ops/s 65.0288 Ops/s $\textbf{\color{#35bf28}+11.83\%}$
test_redq_deprec_speed[False-backward] 23.4082ms 20.6279ms 48.4780 Ops/s 44.7847 Ops/s $\textbf{\color{#35bf28}+8.25\%}$
test_redq_deprec_speed[True-None] 4.9417ms 4.2255ms 236.6579 Ops/s 203.6377 Ops/s $\textbf{\color{#35bf28}+16.22\%}$
test_redq_deprec_speed[True-backward] 10.7496ms 9.3165ms 107.3368 Ops/s 104.4682 Ops/s $\color{#35bf28}+2.75\%$
test_redq_deprec_speed[reduce-overhead-None] 4.6193ms 4.0806ms 245.0602 Ops/s 218.4479 Ops/s $\textbf{\color{#35bf28}+12.18\%}$
test_redq_deprec_speed[reduce-overhead-backward] 9.1956ms 8.9940ms 111.1857 Ops/s 102.4827 Ops/s $\textbf{\color{#35bf28}+8.49\%}$
test_td3_speed[False-None] 10.6821ms 8.4376ms 118.5172 Ops/s 104.8598 Ops/s $\textbf{\color{#35bf28}+13.02\%}$
test_td3_speed[False-backward] 12.0236ms 11.0022ms 90.8912 Ops/s 40.3497 Ops/s $\textbf{\color{#35bf28}+125.26\%}$
test_td3_speed[True-None] 2.0335ms 1.8174ms 550.2453 Ops/s 461.5162 Ops/s $\textbf{\color{#35bf28}+19.23\%}$
test_td3_speed[True-backward] 3.8084ms 3.6422ms 274.5560 Ops/s 234.1285 Ops/s $\textbf{\color{#35bf28}+17.27\%}$
test_td3_speed[reduce-overhead-None] 2.6153ms 1.8112ms 552.1273 Ops/s 526.2834 Ops/s $\color{#35bf28}+4.91\%$
test_td3_speed[reduce-overhead-backward] 3.9475ms 3.5723ms 279.9339 Ops/s 242.0293 Ops/s $\textbf{\color{#35bf28}+15.66\%}$
test_cql_speed[False-None] 39.9291ms 37.9786ms 26.3306 Ops/s 25.0564 Ops/s $\textbf{\color{#35bf28}+5.09\%}$
test_cql_speed[False-backward] 51.8552ms 48.7045ms 20.5320 Ops/s 19.7065 Ops/s $\color{#35bf28}+4.19\%$
test_cql_speed[True-None] 18.3823ms 16.3643ms 61.1087 Ops/s 61.5889 Ops/s $\color{#d91a1a}-0.78\%$
test_cql_speed[True-backward] 24.6161ms 23.3575ms 42.8128 Ops/s 42.7166 Ops/s $\color{#35bf28}+0.23\%$
test_cql_speed[reduce-overhead-None] 18.0415ms 16.5805ms 60.3117 Ops/s 61.8871 Ops/s $\color{#d91a1a}-2.55\%$
test_cql_speed[reduce-overhead-backward] 25.1492ms 24.0684ms 41.5482 Ops/s 42.2150 Ops/s $\color{#d91a1a}-1.58\%$
test_a2c_speed[False-None] 9.8296ms 8.4773ms 117.9617 Ops/s 127.4906 Ops/s $\textbf{\color{#d91a1a}-7.47\%}$
test_a2c_speed[False-backward] 16.5946ms 16.1380ms 61.9656 Ops/s 64.1320 Ops/s $\color{#d91a1a}-3.38\%$
test_a2c_speed[True-None] 5.0977ms 4.6624ms 214.4839 Ops/s 233.6190 Ops/s $\textbf{\color{#d91a1a}-8.19\%}$
test_a2c_speed[True-backward] 12.3148ms 11.8640ms 84.2884 Ops/s 91.5602 Ops/s $\textbf{\color{#d91a1a}-7.94\%}$
test_a2c_speed[reduce-overhead-None] 5.7119ms 4.8792ms 204.9506 Ops/s 235.8139 Ops/s $\textbf{\color{#d91a1a}-13.09\%}$
test_a2c_speed[reduce-overhead-backward] 12.0445ms 11.8313ms 84.5214 Ops/s 94.0394 Ops/s $\textbf{\color{#d91a1a}-10.12\%}$
test_ppo_speed[False-None] 9.6444ms 8.6675ms 115.3741 Ops/s 133.1203 Ops/s $\textbf{\color{#d91a1a}-13.33\%}$
test_ppo_speed[False-backward] 18.3843ms 16.7214ms 59.8036 Ops/s 65.0513 Ops/s $\textbf{\color{#d91a1a}-8.07\%}$
test_ppo_speed[True-None] 4.8369ms 4.3712ms 228.7712 Ops/s 264.2910 Ops/s $\textbf{\color{#d91a1a}-13.44\%}$
test_ppo_speed[True-backward] 14.9920ms 10.7936ms 92.6477 Ops/s 101.4494 Ops/s $\textbf{\color{#d91a1a}-8.68\%}$
test_ppo_speed[reduce-overhead-None] 4.7394ms 4.2130ms 237.3602 Ops/s 266.9062 Ops/s $\textbf{\color{#d91a1a}-11.07\%}$
test_ppo_speed[reduce-overhead-backward] 10.8793ms 10.6079ms 94.2696 Ops/s 101.9902 Ops/s $\textbf{\color{#d91a1a}-7.57\%}$
test_reinforce_speed[False-None] 8.8304ms 7.4774ms 133.7363 Ops/s 152.3433 Ops/s $\textbf{\color{#d91a1a}-12.21\%}$
test_reinforce_speed[False-backward] 12.3544ms 11.1352ms 89.8055 Ops/s 99.5649 Ops/s $\textbf{\color{#d91a1a}-9.80\%}$
test_reinforce_speed[True-None] 3.4598ms 3.2329ms 309.3178 Ops/s 368.2565 Ops/s $\textbf{\color{#d91a1a}-16.00\%}$
test_reinforce_speed[True-backward] 9.8995ms 9.6132ms 104.0238 Ops/s 106.9073 Ops/s $\color{#d91a1a}-2.70\%$
test_reinforce_speed[reduce-overhead-None] 3.6624ms 3.1864ms 313.8315 Ops/s 372.0995 Ops/s $\textbf{\color{#d91a1a}-15.66\%}$
test_reinforce_speed[reduce-overhead-backward] 10.5593ms 9.6947ms 103.1496 Ops/s 116.4056 Ops/s $\textbf{\color{#d91a1a}-11.39\%}$
test_iql_speed[False-None] 49.3593ms 35.4278ms 28.2264 Ops/s 29.4787 Ops/s $\color{#d91a1a}-4.25\%$
test_iql_speed[False-backward] 53.2848ms 48.5551ms 20.5952 Ops/s 15.1996 Ops/s $\textbf{\color{#35bf28}+35.50\%}$
test_iql_speed[True-None] 12.5246ms 11.4373ms 87.4335 Ops/s 87.6043 Ops/s $\color{#d91a1a}-0.19\%$
test_iql_speed[True-backward] 24.1073ms 23.0146ms 43.4508 Ops/s 43.2112 Ops/s $\color{#35bf28}+0.55\%$
test_iql_speed[reduce-overhead-None] 12.3481ms 11.4941ms 87.0011 Ops/s 87.3294 Ops/s $\color{#d91a1a}-0.38\%$
test_iql_speed[reduce-overhead-backward] 24.9293ms 23.0904ms 43.3081 Ops/s 42.5897 Ops/s $\color{#35bf28}+1.69\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 8.1250ms 5.8221ms 171.7589 Ops/s 177.6457 Ops/s $\color{#d91a1a}-3.31\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.0202ms 0.5559ms 1.7987 KOps/s 1.7642 KOps/s $\color{#35bf28}+1.96\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.8099ms 0.5320ms 1.8798 KOps/s 1.8771 KOps/s $\color{#35bf28}+0.15\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 8.2625ms 5.2444ms 190.6792 Ops/s 190.6186 Ops/s $\color{#35bf28}+0.03\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.6004ms 0.5456ms 1.8330 KOps/s 1.8592 KOps/s $\color{#d91a1a}-1.41\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.7647ms 0.5113ms 1.9559 KOps/s 1.9465 KOps/s $\color{#35bf28}+0.48\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 2.7493ms 1.7383ms 575.2806 Ops/s 596.5794 Ops/s $\color{#d91a1a}-3.57\%$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 2.2752ms 1.6616ms 601.8458 Ops/s 631.2298 Ops/s $\color{#d91a1a}-4.66\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 8.8024ms 5.7953ms 172.5531 Ops/s 192.3785 Ops/s $\textbf{\color{#d91a1a}-10.31\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.8401ms 0.7004ms 1.4277 KOps/s 1.4541 KOps/s $\color{#d91a1a}-1.81\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 2.0275ms 0.6694ms 1.4939 KOps/s 1.5070 KOps/s $\color{#d91a1a}-0.87\%$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 8.3305ms 5.4801ms 182.4769 Ops/s 184.3692 Ops/s $\color{#d91a1a}-1.03\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.5401ms 0.5593ms 1.7880 KOps/s 1.7634 KOps/s $\color{#35bf28}+1.39\%$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.5242s 1.2548ms 796.9199 Ops/s 1.8808 KOps/s $\textbf{\color{#d91a1a}-57.63\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 7.1808ms 5.3063ms 188.4567 Ops/s 189.8579 Ops/s $\color{#d91a1a}-0.74\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 1.5310ms 0.5471ms 1.8277 KOps/s 423.6051 Ops/s $\textbf{\color{#35bf28}+331.47\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.8288ms 0.5307ms 1.8844 KOps/s 1.8713 KOps/s $\color{#35bf28}+0.70\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.6119ms 5.9415ms 168.3064 Ops/s 162.8823 Ops/s $\color{#35bf28}+3.33\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.0020ms 0.6981ms 1.4324 KOps/s 1.4272 KOps/s $\color{#35bf28}+0.36\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 9.6943ms 0.6972ms 1.4344 KOps/s 1.4538 KOps/s $\color{#d91a1a}-1.33\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 8.3869ms 4.9208ms 203.2194 Ops/s 174.6613 Ops/s $\textbf{\color{#35bf28}+16.35\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 4.0959ms 2.3769ms 420.7117 Ops/s 385.2650 Ops/s $\textbf{\color{#35bf28}+9.20\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 7.9122ms 1.5877ms 629.8268 Ops/s 644.9104 Ops/s $\color{#d91a1a}-2.34\%$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 0.5115s 15.3170ms 65.2869 Ops/s 185.6160 Ops/s $\textbf{\color{#d91a1a}-64.83\%}$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 8.1117ms 2.5402ms 393.6704 Ops/s 384.5297 Ops/s $\color{#35bf28}+2.38\%$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 1.9181ms 1.3244ms 755.0434 Ops/s 665.2166 Ops/s $\textbf{\color{#35bf28}+13.50\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 8.3161ms 5.6149ms 178.0974 Ops/s 200.1906 Ops/s $\textbf{\color{#d91a1a}-11.04\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 9.8979ms 2.8863ms 346.4630 Ops/s 372.9954 Ops/s $\textbf{\color{#d91a1a}-7.11\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 3.3890ms 1.6743ms 597.2653 Ops/s 569.9659 Ops/s $\color{#35bf28}+4.79\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] 16.2573ms 13.7136ms 72.9201 Ops/s 70.0920 Ops/s $\color{#35bf28}+4.03\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] 17.5449ms 15.3295ms 65.2338 Ops/s 63.9357 Ops/s $\color{#35bf28}+2.03\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 23.5649ms 22.6039ms 44.2401 Ops/s 44.4248 Ops/s $\color{#d91a1a}-0.42\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 17.1245ms 15.5890ms 64.1477 Ops/s 63.8144 Ops/s $\color{#35bf28}+0.52\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] 23.5630ms 22.5386ms 44.3684 Ops/s 44.6329 Ops/s $\color{#d91a1a}-0.59\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] 18.4069ms 17.0650ms 58.5995 Ops/s 59.7489 Ops/s $\color{#d91a1a}-1.92\%$


github-actions bot commented Jan 9, 2025

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of GPU Benchmark Tests

Total Benchmarks: 149. Improved: $\large\color{#35bf28}23$. Worsened: $\large\color{#d91a1a}9$.

Name | Max | Mean | Ops | Ops on Repo HEAD | Change
test_simple 0.8060s 0.7235s 1.3822 Ops/s 1.3658 Ops/s $\color{#35bf28}+1.20\%$
test_transformed 0.9492s 0.9484s 1.0544 Ops/s 1.0216 Ops/s $\color{#35bf28}+3.21\%$
test_serial 2.0861s 2.0792s 0.4810 Ops/s 0.4728 Ops/s $\color{#35bf28}+1.73\%$
test_parallel 1.8853s 1.8162s 0.5506 Ops/s 0.5381 Ops/s $\color{#35bf28}+2.33\%$
test_step_mdp_speed[True-True-True-True-True] 0.2001ms 39.9129μs 25.0546 KOps/s 24.8750 KOps/s $\color{#35bf28}+0.72\%$
test_step_mdp_speed[True-True-True-True-False] 59.6220μs 23.2860μs 42.9443 KOps/s 42.1577 KOps/s $\color{#35bf28}+1.87\%$
test_step_mdp_speed[True-True-True-False-True] 57.0430μs 22.4992μs 44.4460 KOps/s 44.6167 KOps/s $\color{#d91a1a}-0.38\%$
test_step_mdp_speed[True-True-True-False-False] 43.6020μs 12.9599μs 77.1612 KOps/s 76.0541 KOps/s $\color{#35bf28}+1.46\%$
test_step_mdp_speed[True-True-False-True-True] 81.5540μs 42.8597μs 23.3320 KOps/s 23.1057 KOps/s $\color{#35bf28}+0.98\%$
test_step_mdp_speed[True-True-False-True-False] 67.3330μs 25.8639μs 38.6640 KOps/s 37.9710 KOps/s $\color{#35bf28}+1.82\%$
test_step_mdp_speed[True-True-False-False-True] 63.9030μs 24.8776μs 40.1968 KOps/s 39.9172 KOps/s $\color{#35bf28}+0.70\%$
test_step_mdp_speed[True-True-False-False-False] 49.7520μs 15.5129μs 64.4624 KOps/s 63.7846 KOps/s $\color{#35bf28}+1.06\%$
test_step_mdp_speed[True-False-True-True-True] 99.7250μs 45.8356μs 21.8171 KOps/s 21.7927 KOps/s $\color{#35bf28}+0.11\%$
test_step_mdp_speed[True-False-True-True-False] 59.8030μs 28.5072μs 35.0788 KOps/s 34.7303 KOps/s $\color{#35bf28}+1.00\%$
test_step_mdp_speed[True-False-True-False-True] 71.7440μs 24.7285μs 40.4391 KOps/s 39.8918 KOps/s $\color{#35bf28}+1.37\%$
test_step_mdp_speed[True-False-True-False-False] 46.4220μs 15.5818μs 64.1772 KOps/s 65.4093 KOps/s $\color{#d91a1a}-1.88\%$
test_step_mdp_speed[True-False-False-True-True] 80.3540μs 47.6746μs 20.9755 KOps/s 20.7085 KOps/s $\color{#35bf28}+1.29\%$
test_step_mdp_speed[True-False-False-True-False] 79.0240μs 30.0940μs 33.2292 KOps/s 32.8183 KOps/s $\color{#35bf28}+1.25\%$
test_step_mdp_speed[True-False-False-False-True] 59.6430μs 26.9922μs 37.0478 KOps/s 36.6503 KOps/s $\color{#35bf28}+1.08\%$
test_step_mdp_speed[True-False-False-False-False] 47.6020μs 17.4773μs 57.2172 KOps/s 55.7687 KOps/s $\color{#35bf28}+2.60\%$
test_step_mdp_speed[False-True-True-True-True] 92.1540μs 45.1355μs 22.1555 KOps/s 21.6677 KOps/s $\color{#35bf28}+2.25\%$
test_step_mdp_speed[False-True-True-True-False] 61.9720μs 27.9291μs 35.8050 KOps/s 35.2658 KOps/s $\color{#35bf28}+1.53\%$
test_step_mdp_speed[False-True-True-False-True] 65.0230μs 28.4115μs 35.1970 KOps/s 34.0991 KOps/s $\color{#35bf28}+3.22\%$
test_step_mdp_speed[False-True-True-False-False] 57.7730μs 17.1499μs 58.3093 KOps/s 57.9377 KOps/s $\color{#35bf28}+0.64\%$
test_step_mdp_speed[False-True-False-True-True] 85.6540μs 47.5199μs 21.0438 KOps/s 20.7287 KOps/s $\color{#35bf28}+1.52\%$
test_step_mdp_speed[False-True-False-True-False] 66.6030μs 30.3307μs 32.9699 KOps/s 32.7234 KOps/s $\color{#35bf28}+0.75\%$
test_step_mdp_speed[False-True-False-False-True] 3.1631ms 31.5930μs 31.6526 KOps/s 31.7965 KOps/s $\color{#d91a1a}-0.45\%$
test_step_mdp_speed[False-True-False-False-False] 55.2530μs 19.7668μs 50.5899 KOps/s 51.0143 KOps/s $\color{#d91a1a}-0.83\%$
test_step_mdp_speed[False-False-True-True-True] 99.4550μs 49.4296μs 20.2308 KOps/s 19.7393 KOps/s $\color{#35bf28}+2.49\%$
test_step_mdp_speed[False-False-True-True-False] 69.2630μs 33.0527μs 30.2547 KOps/s 29.7979 KOps/s $\color{#35bf28}+1.53\%$
test_step_mdp_speed[False-False-True-False-True] 69.2530μs 30.4899μs 32.7978 KOps/s 32.3723 KOps/s $\color{#35bf28}+1.31\%$
test_step_mdp_speed[False-False-True-False-False] 53.4020μs 19.6334μs 50.9335 KOps/s 51.1258 KOps/s $\color{#d91a1a}-0.38\%$
test_step_mdp_speed[False-False-False-True-True] 93.2440μs 51.8988μs 19.2683 KOps/s 19.1939 KOps/s $\color{#35bf28}+0.39\%$
test_step_mdp_speed[False-False-False-True-False] 70.0640μs 35.1579μs 28.4431 KOps/s 28.6358 KOps/s $\color{#d91a1a}-0.67\%$
test_step_mdp_speed[False-False-False-False-True] 73.5240μs 32.1815μs 31.0738 KOps/s 30.1508 KOps/s $\color{#35bf28}+3.06\%$
test_step_mdp_speed[False-False-False-False-False] 53.8020μs 21.6186μs 46.2564 KOps/s 45.9323 KOps/s $\color{#35bf28}+0.71\%$
test_values[generalized_advantage_estimate-True-True] 23.8026ms 23.3221ms 42.8778 Ops/s 42.3151 Ops/s $\color{#35bf28}+1.33\%$
test_values[vec_generalized_advantage_estimate-True-True] 97.8128ms 2.8316ms 353.1618 Ops/s 337.2235 Ops/s $\color{#35bf28}+4.73\%$
test_values[td0_return_estimate-False-False] 0.1040ms 75.1743μs 13.3024 KOps/s 12.9857 KOps/s $\color{#35bf28}+2.44\%$
test_values[td1_return_estimate-False-False] 53.1469ms 52.3855ms 19.0893 Ops/s 18.8518 Ops/s $\color{#35bf28}+1.26\%$
test_values[vec_td1_return_estimate-False-False] 1.2912ms 1.0501ms 952.2839 Ops/s 944.6772 Ops/s $\color{#35bf28}+0.81\%$
test_values[td_lambda_return_estimate-True-False] 87.5671ms 83.5382ms 11.9706 Ops/s 11.8459 Ops/s $\color{#35bf28}+1.05\%$
test_values[vec_td_lambda_return_estimate-True-False] 1.1854ms 1.0442ms 957.6430 Ops/s 948.1505 Ops/s $\color{#35bf28}+1.00\%$
test_gae_speed[generalized_advantage_estimate-False-1-512] 23.5033ms 23.0379ms 43.4068 Ops/s 42.3430 Ops/s $\color{#35bf28}+2.51\%$
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] 1.0020ms 0.7169ms 1.3949 KOps/s 1.3738 KOps/s $\color{#35bf28}+1.54\%$
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] 0.7260ms 0.6407ms 1.5608 KOps/s 1.5502 KOps/s $\color{#35bf28}+0.68\%$
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] 1.5103ms 1.4495ms 689.8972 Ops/s 686.5300 Ops/s $\color{#35bf28}+0.49\%$
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] 0.7135ms 0.6563ms 1.5237 KOps/s 1.4732 KOps/s $\color{#35bf28}+3.43\%$
test_dqn_speed[False-None] 6.8786ms 1.4900ms 671.1429 Ops/s 678.8408 Ops/s $\color{#d91a1a}-1.13\%$
test_dqn_speed[False-backward] 2.1245ms 2.0711ms 482.8300 Ops/s 486.2776 Ops/s $\color{#d91a1a}-0.71\%$
test_dqn_speed[True-None] 0.6308ms 0.5410ms 1.8485 KOps/s 1.8196 KOps/s $\color{#35bf28}+1.59\%$
test_dqn_speed[True-backward] 1.2760ms 1.1795ms 847.8456 Ops/s 830.8954 Ops/s $\color{#35bf28}+2.04\%$
test_dqn_speed[reduce-overhead-None] 0.6275ms 0.5534ms 1.8070 KOps/s 1.7506 KOps/s $\color{#35bf28}+3.22\%$
test_dqn_speed[reduce-overhead-backward] 1.1230ms 1.0437ms 958.1394 Ops/s 940.6862 Ops/s $\color{#35bf28}+1.86\%$
test_ddpg_speed[False-None] 3.0807ms 2.8075ms 356.1941 Ops/s 350.0253 Ops/s $\color{#35bf28}+1.76\%$
test_ddpg_speed[False-backward] 4.4811ms 4.0936ms 244.2820 Ops/s 241.4107 Ops/s $\color{#35bf28}+1.19\%$
test_ddpg_speed[True-None] 1.1323ms 1.0598ms 943.6076 Ops/s 905.0687 Ops/s $\color{#35bf28}+4.26\%$
test_ddpg_speed[True-backward] 2.2911ms 2.2442ms 445.5896 Ops/s 465.7315 Ops/s $\color{#d91a1a}-4.32\%$
test_ddpg_speed[reduce-overhead-None] 1.2213ms 1.1042ms 905.6713 Ops/s 884.8497 Ops/s $\color{#35bf28}+2.35\%$
test_ddpg_speed[reduce-overhead-backward] 1.7908ms 1.7412ms 574.3205 Ops/s 612.7275 Ops/s $\textbf{\color{#d91a1a}-6.27\%}$
test_sac_speed[False-None] 8.2835ms 7.8495ms 127.3962 Ops/s 125.1734 Ops/s $\color{#35bf28}+1.78\%$
test_sac_speed[False-backward] 11.3340ms 10.8214ms 92.4092 Ops/s 93.1603 Ops/s $\color{#d91a1a}-0.81\%$
test_sac_speed[True-None] 1.6665ms 1.5058ms 664.0928 Ops/s 652.9916 Ops/s $\color{#35bf28}+1.70\%$
test_sac_speed[True-backward] 3.3975ms 3.3231ms 300.9256 Ops/s 298.7448 Ops/s $\color{#35bf28}+0.73\%$
test_sac_speed[reduce-overhead-None] 23.0791ms 12.7645ms 78.3425 Ops/s 78.2932 Ops/s $\color{#35bf28}+0.06\%$
test_sac_speed[reduce-overhead-backward] 1.6053ms 1.5020ms 665.7887 Ops/s 744.9279 Ops/s $\textbf{\color{#d91a1a}-10.62\%}$
test_redq_speed[False-None] 8.0654ms 7.3334ms 136.3615 Ops/s 134.7176 Ops/s $\color{#35bf28}+1.22\%$
test_redq_speed[False-backward] 12.0885ms 11.2712ms 88.7220 Ops/s 91.0152 Ops/s $\color{#d91a1a}-2.52\%$
test_redq_speed[True-None] 2.1897ms 1.9584ms 510.6150 Ops/s 509.7029 Ops/s $\color{#35bf28}+0.18\%$
test_redq_speed[True-backward] 3.9969ms 3.7333ms 267.8589 Ops/s 262.6352 Ops/s $\color{#35bf28}+1.99\%$
test_redq_speed[reduce-overhead-None] 2.0475ms 1.9404ms 515.3534 Ops/s 510.2938 Ops/s $\color{#35bf28}+0.99\%$
test_redq_speed[reduce-overhead-backward] 3.9809ms 3.5454ms 282.0523 Ops/s 263.9370 Ops/s $\textbf{\color{#35bf28}+6.86\%}$
test_redq_deprec_speed[False-None] 9.2782ms 8.7705ms 114.0180 Ops/s 111.8246 Ops/s $\color{#35bf28}+1.96\%$
test_redq_deprec_speed[False-backward] 11.8799ms 11.5526ms 86.5608 Ops/s 83.5564 Ops/s $\color{#35bf28}+3.60\%$
test_redq_deprec_speed[True-None] 2.4192ms 2.2825ms 438.1080 Ops/s 433.1120 Ops/s $\color{#35bf28}+1.15\%$
test_redq_deprec_speed[True-backward] 4.0384ms 3.8984ms 256.5155 Ops/s 243.0773 Ops/s $\textbf{\color{#35bf28}+5.53\%}$
test_redq_deprec_speed[reduce-overhead-None] 2.3894ms 2.2801ms 438.5730 Ops/s 431.1762 Ops/s $\color{#35bf28}+1.72\%$
test_redq_deprec_speed[reduce-overhead-backward] 4.3451ms 3.9209ms 255.0412 Ops/s 244.3168 Ops/s $\color{#35bf28}+4.39\%$
test_td3_speed[False-None] 7.8777ms 7.7449ms 129.1178 Ops/s 127.5278 Ops/s $\color{#35bf28}+1.25\%$
test_td3_speed[False-backward] 10.2820ms 9.9762ms 100.2382 Ops/s 97.3108 Ops/s $\color{#35bf28}+3.01\%$
test_td3_speed[True-None] 1.6199ms 1.5709ms 636.5803 Ops/s 620.6055 Ops/s $\color{#35bf28}+2.57\%$
test_td3_speed[True-backward] 3.0846ms 3.0399ms 328.9535 Ops/s 306.9307 Ops/s $\textbf{\color{#35bf28}+7.18\%}$
test_td3_speed[reduce-overhead-None] 56.3258ms 26.2147ms 38.1466 Ops/s 38.2017 Ops/s $\color{#d91a1a}-0.14\%$
test_td3_speed[reduce-overhead-backward] 1.3381ms 1.2810ms 780.6125 Ops/s 687.5121 Ops/s $\textbf{\color{#35bf28}+13.54\%}$
test_cql_speed[False-None] 16.9255ms 16.3828ms 61.0396 Ops/s 60.5445 Ops/s $\color{#35bf28}+0.82\%$
test_cql_speed[False-backward] 22.1089ms 21.2517ms 47.0551 Ops/s 45.9246 Ops/s $\color{#35bf28}+2.46\%$
test_cql_speed[True-None] 2.9575ms 2.8674ms 348.7453 Ops/s 344.8747 Ops/s $\color{#35bf28}+1.12\%$
test_cql_speed[True-backward] 5.2995ms 4.9854ms 200.5877 Ops/s 198.6451 Ops/s $\color{#35bf28}+0.98\%$
test_cql_speed[reduce-overhead-None] 0.3663s 15.1243ms 66.1186 Ops/s 74.7428 Ops/s $\textbf{\color{#d91a1a}-11.54\%}$
test_cql_speed[reduce-overhead-backward] 1.5521ms 1.5112ms 661.7442 Ops/s 587.3954 Ops/s $\textbf{\color{#35bf28}+12.66\%}$
test_a2c_speed[False-None] 3.2165ms 3.1301ms 319.4795 Ops/s 315.9132 Ops/s $\color{#35bf28}+1.13\%$
test_a2c_speed[False-backward] 6.9728ms 5.8948ms 169.6410 Ops/s 160.8495 Ops/s $\textbf{\color{#35bf28}+5.47\%}$
test_a2c_speed[True-None] 1.1220ms 0.9958ms 1.0042 KOps/s 990.1832 Ops/s $\color{#35bf28}+1.41\%$
test_a2c_speed[True-backward] 2.5861ms 2.5184ms 397.0720 Ops/s 364.4988 Ops/s $\textbf{\color{#35bf28}+8.94\%}$
test_a2c_speed[reduce-overhead-None] 21.5077ms 11.6753ms 85.6511 Ops/s 86.1301 Ops/s $\color{#d91a1a}-0.56\%$
test_a2c_speed[reduce-overhead-backward] 0.9993ms 0.9594ms 1.0423 KOps/s 869.3047 Ops/s $\textbf{\color{#35bf28}+19.90\%}$
test_ppo_speed[False-None] 3.7306ms 3.5881ms 278.7006 Ops/s 277.3476 Ops/s $\color{#35bf28}+0.49\%$
test_ppo_speed[False-backward] 7.2988ms 6.6101ms 151.2834 Ops/s 147.3737 Ops/s $\color{#35bf28}+2.65\%$
test_ppo_speed[True-None] 1.0065ms 0.9365ms 1.0678 KOps/s 1.0396 KOps/s $\color{#35bf28}+2.71\%$
test_ppo_speed[True-backward] 2.7110ms 2.6477ms 377.6927 Ops/s 398.1029 Ops/s $\textbf{\color{#d91a1a}-5.13\%}$
test_ppo_speed[reduce-overhead-None] 0.5865ms 0.5270ms 1.8976 KOps/s 68.0171 Ops/s $\textbf{\color{#35bf28}+2689.90\%}$
test_ppo_speed[reduce-overhead-backward] 1.1553ms 1.1038ms 905.9800 Ops/s 978.7588 Ops/s $\textbf{\color{#d91a1a}-7.44\%}$
test_reinforce_speed[False-None] 2.3093ms 2.2144ms 451.5948 Ops/s 447.1864 Ops/s $\color{#35bf28}+0.99\%$
test_reinforce_speed[False-backward] 3.5669ms 3.2616ms 306.6002 Ops/s 314.1609 Ops/s $\color{#d91a1a}-2.41\%$
test_reinforce_speed[True-None] 0.8661ms 0.8198ms 1.2198 KOps/s 1.1488 KOps/s $\textbf{\color{#35bf28}+6.18\%}$
test_reinforce_speed[True-backward] 2.5952ms 2.5008ms 399.8685 Ops/s 412.7820 Ops/s $\color{#d91a1a}-3.13\%$
test_reinforce_speed[reduce-overhead-None] 0.2927s 12.1792ms 82.1070 Ops/s 86.3953 Ops/s $\color{#d91a1a}-4.96\%$
test_reinforce_speed[reduce-overhead-backward] 1.1972ms 1.1529ms 867.3974 Ops/s 969.9776 Ops/s $\textbf{\color{#d91a1a}-10.58\%}$
test_iql_speed[False-None] 9.6923ms 9.1325ms 109.4995 Ops/s 109.6957 Ops/s $\color{#d91a1a}-0.18\%$
test_iql_speed[False-backward] 13.5345ms 12.9064ms 77.4806 Ops/s 78.7920 Ops/s $\color{#d91a1a}-1.66\%$
test_iql_speed[True-None] 1.8118ms 1.7386ms 575.1608 Ops/s 570.9952 Ops/s $\color{#35bf28}+0.73\%$
test_iql_speed[True-backward] 4.7257ms 4.3091ms 232.0676 Ops/s 229.6168 Ops/s $\color{#35bf28}+1.07\%$
test_iql_speed[reduce-overhead-None] 19.9821ms 11.4565ms 87.2864 Ops/s 85.7898 Ops/s $\color{#35bf28}+1.74\%$
test_iql_speed[reduce-overhead-backward] 1.6808ms 1.5806ms 632.6706 Ops/s 615.2495 Ops/s $\color{#35bf28}+2.83\%$
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 7.8940ms 6.4110ms 155.9808 Ops/s 153.7272 Ops/s $\color{#35bf28}+1.47\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 0.5683ms 0.2798ms 3.5734 KOps/s 2.8758 KOps/s $\textbf{\color{#35bf28}+24.26\%}$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.4630ms 0.2523ms 3.9642 KOps/s 2.7948 KOps/s $\textbf{\color{#35bf28}+41.84\%}$
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.4399ms 6.1291ms 163.1561 Ops/s 161.6373 Ops/s $\color{#35bf28}+0.94\%$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 2.1949ms 0.2824ms 3.5411 KOps/s 2.9784 KOps/s $\textbf{\color{#35bf28}+18.89\%}$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.5655ms 0.2776ms 3.6024 KOps/s 3.3943 KOps/s $\textbf{\color{#35bf28}+6.13\%}$
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] 1.5070ms 1.2979ms 770.4740 Ops/s 711.0246 Ops/s $\textbf{\color{#35bf28}+8.36\%}$
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] 1.4072ms 1.1738ms 851.9129 Ops/s 767.4042 Ops/s $\textbf{\color{#35bf28}+11.01\%}$
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.4279ms 6.3068ms 158.5587 Ops/s 156.9563 Ops/s $\color{#35bf28}+1.02\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.1668ms 0.4348ms 2.2997 KOps/s 2.4190 KOps/s $\color{#d91a1a}-4.93\%$
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.6736ms 0.3940ms 2.5382 KOps/s 2.3513 KOps/s $\textbf{\color{#35bf28}+7.95\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] 6.3361ms 6.1220ms 163.3463 Ops/s 160.1816 Ops/s $\color{#35bf28}+1.98\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 1.4670ms 0.3269ms 3.0590 KOps/s 3.2842 KOps/s $\textbf{\color{#d91a1a}-6.86\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] 0.4982ms 0.3140ms 3.1844 KOps/s 3.8511 KOps/s $\textbf{\color{#d91a1a}-17.31\%}$
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] 6.3936ms 6.0443ms 165.4462 Ops/s 161.7273 Ops/s $\color{#35bf28}+2.30\%$
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] 0.6185ms 0.3058ms 3.2703 KOps/s 2.9990 KOps/s $\textbf{\color{#35bf28}+9.05\%}$
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] 0.5320ms 0.2943ms 3.3977 KOps/s 3.4791 KOps/s $\color{#d91a1a}-2.34\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] 6.4132ms 6.2486ms 160.0355 Ops/s 157.7189 Ops/s $\color{#35bf28}+1.47\%$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] 1.1169ms 0.4290ms 2.3311 KOps/s 2.1495 KOps/s $\textbf{\color{#35bf28}+8.44\%}$
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] 0.5902ms 0.3860ms 2.5910 KOps/s 2.0737 KOps/s $\textbf{\color{#35bf28}+24.95\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 7.0423ms 5.4056ms 184.9939 Ops/s 183.1324 Ops/s $\color{#35bf28}+1.02\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] 4.0088ms 1.9372ms 516.1989 Ops/s 418.7960 Ops/s $\textbf{\color{#35bf28}+23.26\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] 8.7585ms 1.2364ms 808.7999 Ops/s 907.6455 Ops/s $\textbf{\color{#d91a1a}-10.89\%}$
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] 7.5432ms 5.3224ms 187.8865 Ops/s 186.7481 Ops/s $\color{#35bf28}+0.61\%$
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] 7.9769ms 2.0505ms 487.6868 Ops/s 432.7574 Ops/s $\textbf{\color{#35bf28}+12.69\%}$
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] 7.5499ms 1.2305ms 812.6564 Ops/s 797.4705 Ops/s $\color{#35bf28}+1.90\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] 0.4977s 15.4759ms 64.6165 Ops/s 33.3746 Ops/s $\textbf{\color{#35bf28}+93.61\%}$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] 10.2648ms 2.2468ms 445.0823 Ops/s 441.9400 Ops/s $\color{#35bf28}+0.71\%$
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] 7.0725ms 1.3465ms 742.6422 Ops/s 738.2613 Ops/s $\color{#35bf28}+0.59\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] 15.8027ms 15.3267ms 65.2454 Ops/s 63.9676 Ops/s $\color{#35bf28}+2.00\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] 19.7087ms 17.7246ms 56.4189 Ops/s 55.7401 Ops/s $\color{#35bf28}+1.22\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] 21.5425ms 19.4900ms 51.3083 Ops/s 50.1443 Ops/s $\color{#35bf28}+2.32\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] 20.1289ms 17.5026ms 57.1343 Ops/s 55.6011 Ops/s $\color{#35bf28}+2.76\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] 19.8626ms 19.3887ms 51.5764 Ops/s 49.8091 Ops/s $\color{#35bf28}+3.55\%$
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] 20.6067ms 18.8721ms 52.9883 Ops/s 51.0235 Ops/s $\color{#35bf28}+3.85\%$

[ghstack-poisoned]
[ghstack-poisoned]
@vmoens
Contributor Author

vmoens commented Jan 10, 2025

Supporting all these use cases might become a big headache, which is why I preferred to stick with one reward and done key per loss class.

In its current version, this PR assumes that you can have a multihead action but reward / done are going to be tensors.

What we do is that every time we need to recompute anything from the dist, we look at it, and if it's a composite dist AND you didn't explicitly ask to aggregate the log-probs, we get a tensordict of log-probs.

From there, for PPOLoss and the KL version, the change is quite trivial.

For ClipPPOLoss, there's a bit of change in the logic: before, we were summing all the log-probs (or weights), then clamping, then multiplying by the advantage.

Now, we first clamp each weight leaf, then sum and multiply. Hopefully that is more mathematically accurate, but I'm happy to revert this if people yell at me!
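To make the difference between the two orderings concrete, here is a minimal pure-Python sketch. It is not the code from this PR: the function names, the dict standing in for a tensordict of per-head log-weights, and the clamp bounds `log(1 - eps)` / `log(1 + eps)` are assumptions made for illustration.

```python
import math


def clamp(x, lo, hi):
    """Clamp a scalar to the closed interval [lo, hi]."""
    return max(lo, min(hi, x))


def ratio_sum_then_clamp(log_weights, eps):
    """Earlier ordering: sum the per-head log-weights, then clamp once."""
    lo, hi = math.log1p(-eps), math.log1p(eps)
    return math.exp(clamp(sum(log_weights.values()), lo, hi))


def ratio_clamp_then_sum(log_weights, eps):
    """New ordering: clamp each per-head log-weight (leaf), then sum."""
    lo, hi = math.log1p(-eps), math.log1p(eps)
    return math.exp(sum(clamp(lw, lo, hi) for lw in log_weights.values()))


# Hypothetical two-head action whose log-weights move in opposite directions.
log_weights = {"head_a": 0.5, "head_b": -0.5}
old_ratio = ratio_sum_then_clamp(log_weights, eps=0.2)
new_ratio = ratio_clamp_then_sum(log_weights, eps=0.2)
print(old_ratio, new_ratio)
```

In this toy case the two heads cancel under the earlier ordering, so the clamp never triggers, whereas clamping per leaf bounds each head's contribution separately and the two results differ. Either ratio would then be multiplied by the advantage.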

There is still something I would like to implement but it could be a bit bc-breaking so I'd rather get people's opinion on it: currently, we kind of assume that users have set the return_log_prob=True in the distribution, but we could spare that. All we need is the params of the dist and the sample to compute the log-prob, and the params are presumably part of your tensordict. So we could rebuild the original dist at low cost during the loss computation.

This is what we have now:

with self.actor_network_params.to_module(
    self.actor_network
) if self.functional else contextlib.nullcontext():
    dist = self.actor_network.get_dist(tensordict)
try:
    prev_log_prob = _maybe_get_or_select(
        tensordict, self.tensor_keys.sample_log_prob
    )
except KeyError as err:
    raise _make_lp_get_error(self.tensor_keys, tensordict, err)
The way I'm thinking about this is to append a PR to this stack where:

  • if the parameters (say, loc and scale) are present, we always recompute the dist. If the log-prob is there too, we warn the user that storing it is not necessary;
  • if the parameters are not there, we fall back on a stored log-prob if there is one.
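The lookup order in the two bullets above could be sketched as follows. This is only an illustration under stated assumptions: a plain dict stands in for the tensordict, the distribution is a scalar Normal, and `resolve_prev_log_prob`, `normal_log_prob` and the key names are hypothetical. Raising a `KeyError` when neither source is available is also an assumption, since the bullets do not cover that case.

```python
import math
import warnings


def normal_log_prob(x, loc, scale):
    """Log-density of a scalar Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def resolve_prev_log_prob(data):
    """Prefer recomputing from the stored params; fall back on a stored log-prob."""
    if "loc" in data and "scale" in data:
        if "sample_log_prob" in data:
            # Hypothetical warning text: the stored value is redundant here.
            warnings.warn("sample_log_prob is stored but will be recomputed from params")
        return normal_log_prob(data["action"], data["loc"], data["scale"])
    if "sample_log_prob" in data:
        return data["sample_log_prob"]
    # Assumed behaviour: nothing to recompute from and nothing stored.
    raise KeyError("neither distribution params nor sample_log_prob found")


recomputed = resolve_prev_log_prob({"action": 0.0, "loc": 0.0, "scale": 1.0})
stored = resolve_prev_log_prob({"action": 0.0, "sample_log_prob": -1.25})
print(recomputed, stored)
```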

cc @louisfaury

@matteobettini
Copy link
Contributor

Why can't you (1) read the log-prob if there is one, and if not, (2) recompute it if the original dist params are there, and if not, (3) throw an error?
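For comparison, the three-step order suggested here might look like the sketch below. The same caveats apply: a plain dict stands in for the tensordict, a scalar Normal stands in for the policy distribution, and `read_or_recompute_log_prob` and the key names are hypothetical rather than anything in TorchRL.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a scalar Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def read_or_recompute_log_prob(data):
    # (1) A stored log-prob takes priority.
    if "sample_log_prob" in data:
        return data["sample_log_prob"]
    # (2) Otherwise recompute it from the original distribution params.
    if "loc" in data and "scale" in data:
        return normal_log_prob(data["action"], data["loc"], data["scale"])
    # (3) Otherwise there is nothing to work with.
    raise KeyError("neither sample_log_prob nor distribution params found")


both = {"action": 0.0, "loc": 0.0, "scale": 1.0, "sample_log_prob": -9.0}
print(read_or_recompute_log_prob(both))
```

The only difference from the recompute-first order is which source wins when both the stored log-prob and the params are present.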

@vmoens
Contributor Author

vmoens commented Jan 10, 2025

We could; to me it's mainly a matter of what the "default" is.
IMO the default should be to not ask users to compute anything we can do ourselves. But I agree that both options are fine (and in eager mode the new behaviour may be slightly slower, although considering the big picture it'll probably be roughly the same if you include the time it takes to compute the log-prob during inference).

@vmoens
Contributor Author

vmoens commented Jan 10, 2025

nm, this is a bit more ambitious than I thought since VTrace requires the log-prob to be present and we want the advantage to be callable outside of the loss

[ghstack-poisoned]
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request