[Feature] Make PPO compatible with composite actions and log-probs #2665
base: gh/vmoens/58/base
ghstack-source-id: cbdaf533a39aeea41e3fbcda4e9d95a116eabfe1 Pull Request resolved: #2665
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2665
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:
❌ 17 New Failures, 1 Unrelated Failure
As of commit 14e639d with merge base ed656a1:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
In this PR, I propose to let PPO accept a series of actions defined in the in-keys (rather than a single one) to better accommodate CompositeDistributions. This PR requires pytorch/tensordict#1146 and pytorch/tensordict#1145 to be merged or checked out. Here is a demo:
Cool! Just to understand a bit, how is this related to multiagent? I see in the example that you are using different agent groups, but the feature seems to be more suited to composite single-agent actions. In multiagent, the suggested way to do things was to create a different loss for each group. This avoids losses taking a list of dones, rewards, and actions and having to match them. To me, this feature makes sense for composite actions within a single agent or a single MARL group (avoiding taking a list of rewards and dones).
Also, in the example you are using a single module to output actions for multiple groups.
I don't have a strong feeling RE multiagent or not; the use case that was suggested to me here had a composite action space where each leaf was labelled "agent_x".
    log_prob = make_some_tensordict(...)
    prev_log_prob = make_some_tensordict(...)
    advantage = make_a_tensordict_or_a_tensor(...)
    loss = (log_prob - prev_log_prob).exp().clamp(...) * advantage
and your loss will be a tensordict itself.
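To make the per-leaf arithmetic above concrete, here is a minimal sketch that uses plain Python dicts of floats as a stand-in for tensordicts (an assumption, so the snippet runs without torch or tensordict). The function name `clipped_surrogate`, the `eps` parameter, and the sample values are hypothetical. The sketch uses the standard PPO clipped surrogate, min(ratio * A, clip(ratio) * A), in place of the elided `clamp(...)` call, so it may differ from what the PR actually implements.

```python
import math


def clipped_surrogate(log_prob, prev_log_prob, advantage, eps=0.2):
    """Per-leaf PPO clipped objective for a composite action (illustrative only).

    log_prob, prev_log_prob: dicts mapping a leaf name to a float log-probability.
    advantage: one float shared by every leaf, or a dict with the same keys.
    Returns a dict with one objective value per leaf.
    """
    out = {}
    for key, lp in log_prob.items():
        adv = advantage[key] if isinstance(advantage, dict) else advantage
        # Importance ratio pi_new / pi_old, computed from log-probabilities.
        ratio = math.exp(lp - prev_log_prob[key])
        # Keep the ratio inside [1 - eps, 1 + eps].
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
        # Pessimistic bound: take the smaller of the two surrogate terms.
        out[key] = min(ratio * adv, clipped * adv)
    return out


# Hypothetical leaves named after the "agent_x" labels mentioned above.
log_prob = {"agent_0": -1.0, "agent_1": -0.5}
prev_log_prob = {"agent_0": -1.0, "agent_1": -1.5}
objective = clipped_surrogate(log_prob, prev_log_prob, advantage=2.0)

# A scalar for the optimizer can then be obtained by reducing over leaves.
total_loss = -sum(objective.values())
```

The result keeps the same key structure as the inputs, which is the property the comment above points out: the loss is itself a container with one entry per action leaf.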
This makes sense for a composite action space, yes. But in your PR I see you are also allowing lists of dones and rewards. This is a bit less trivial, as it opens up a bunch of compatibility use cases if you want to use this in MARL.
Supporting all these use cases might become a big headache, which is why I preferred to stick with one reward and done key per loss class.
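One way to picture the "one reward and done key per loss" layout argued for above is a small table of keys per agent group. This is only a sketch: `GroupLossKeys`, `keys_for_groups`, and every key name below are hypothetical and are not part of the TorchRL API.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

NestedKey = Tuple[str, ...]


@dataclass(frozen=True)
class GroupLossKeys:
    """Keys read by the loss of a single agent group (hypothetical container)."""

    actions: Tuple[NestedKey, ...]  # several leaves allowed: composite action
    reward: NestedKey  # exactly one reward key
    done: NestedKey  # exactly one done key


def keys_for_groups(action_leaves: Dict[str, List[str]]) -> Dict[str, GroupLossKeys]:
    """Build one key set per group, so no loss has to match lists of rewards or dones."""
    return {
        group: GroupLossKeys(
            actions=tuple((group, "action", leaf) for leaf in leaves),
            reward=(group, "reward"),
            done=(group, "done"),
        )
        for group, leaves in action_leaves.items()
    }


# Two hypothetical groups; only the first has a composite (two-leaf) action.
layout = keys_for_groups({"team_a": ["move", "shoot"], "team_b": ["move"]})
```

Each group then gets its own loss instance built from its own entry, and the composite part stays confined to the action keys.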
ghstack-source-id: f465f2017843904a510aa06768ced457df987e94 Pull Request resolved: #2665
ghstack-source-id: 3bcf7ebf9619f62d68979f85021a769796da0539 Pull Request resolved: #2665
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.5689s | 0.4674s | 2.1394 Ops/s | 2.1416 Ops/s | |
test_transformed | 0.7436s | 0.6492s | 1.5403 Ops/s | 1.5206 Ops/s | |
test_serial | 1.4874s | 1.3864s | 0.7213 Ops/s | 0.7046 Ops/s | |
test_parallel | 1.3401s | 1.2280s | 0.8143 Ops/s | 0.7603 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1648ms | 30.4931μs | 32.7943 KOps/s | 33.5589 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 76.6840μs | 18.0617μs | 55.3657 KOps/s | 55.9135 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 80.8810μs | 17.2989μs | 57.8072 KOps/s | 58.8932 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 38.0020μs | 10.1379μs | 98.6396 KOps/s | 98.9940 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 86.1020μs | 32.7417μs | 30.5421 KOps/s | 31.1040 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 81.5130μs | 19.9516μs | 50.1212 KOps/s | 50.4088 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 57.1070μs | 19.2024μs | 52.0768 KOps/s | 52.4725 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 51.9570μs | 12.1072μs | 82.5957 KOps/s | 83.9385 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 0.1309ms | 35.1307μs | 28.4652 KOps/s | 29.4069 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 53.6100μs | 22.3654μs | 44.7120 KOps/s | 45.9667 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 46.7180μs | 19.2691μs | 51.8966 KOps/s | 53.6333 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 54.4420μs | 12.1603μs | 82.2348 KOps/s | 83.6048 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 68.7190μs | 36.5234μs | 27.3797 KOps/s | 28.1960 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 70.9320μs | 24.1283μs | 41.4452 KOps/s | 43.0549 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 76.1430μs | 22.0093μs | 45.4354 KOps/s | 48.3307 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 0.1418ms | 13.8298μs | 72.3077 KOps/s | 72.8512 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 76.3030μs | 34.6117μs | 28.8920 KOps/s | 29.4453 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 54.6230μs | 22.0928μs | 45.2636 KOps/s | 46.0348 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 74.7500μs | 21.9712μs | 45.5141 KOps/s | 44.9626 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 67.9770μs | 13.4551μs | 74.3211 KOps/s | 75.6971 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 88.8070μs | 36.1414μs | 27.6691 KOps/s | 28.4653 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 61.0250μs | 23.9141μs | 41.8164 KOps/s | 42.6559 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.6106ms | 23.8353μs | 41.9546 KOps/s | 43.1588 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 74.5810μs | 15.2323μs | 65.6500 KOps/s | 67.2401 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 76.3430μs | 38.2271μs | 26.1595 KOps/s | 26.9248 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 59.1900μs | 25.7107μs | 38.8943 KOps/s | 39.1115 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 77.7560μs | 23.8203μs | 41.9811 KOps/s | 43.6152 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 56.8970μs | 15.4390μs | 64.7712 KOps/s | 66.9105 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 75.7320μs | 40.0003μs | 24.9998 KOps/s | 25.6725 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 98.1130μs | 27.3908μs | 36.5086 KOps/s | 37.1493 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 56.6260μs | 25.2534μs | 39.5986 KOps/s | 40.6499 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 55.2730μs | 16.8121μs | 59.4808 KOps/s | 60.3791 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.2948ms | 10.0375ms | 99.6267 Ops/s | 99.6708 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 38.0918ms | 33.7460ms | 29.6332 Ops/s | 29.4798 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2815ms | 0.2124ms | 4.7079 KOps/s | 4.7050 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.3563ms | 24.6104ms | 40.6332 Ops/s | 39.9479 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 37.1327ms | 33.6920ms | 29.6807 Ops/s | 29.4069 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.4898ms | 35.0156ms | 28.5587 Ops/s | 28.0091 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 46.5158ms | 34.1923ms | 29.2463 Ops/s | 29.3585 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.8942ms | 8.5755ms | 116.6109 Ops/s | 115.9356 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.3519ms | 1.8553ms | 539.0096 Ops/s | 539.3791 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6300ms | 0.3612ms | 2.7687 KOps/s | 2.7729 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 39.1611ms | 38.4559ms | 26.0038 Ops/s | 26.1163 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.3696ms | 3.0513ms | 327.7333 Ops/s | 305.2308 Ops/s | |
test_dqn_speed[False-None] | 2.0935ms | 1.4144ms | 707.0208 Ops/s | 693.1214 Ops/s | |
test_dqn_speed[False-backward] | 1.9754ms | 1.8882ms | 529.5979 Ops/s | 501.3073 Ops/s | |
test_dqn_speed[True-None] | 0.7426ms | 0.4784ms | 2.0902 KOps/s | 2.0383 KOps/s | |
test_dqn_speed[True-backward] | 0.9443ms | 0.8949ms | 1.1175 KOps/s | 854.5569 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.5922ms | 0.4773ms | 2.0952 KOps/s | 2.0046 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9285ms | 0.8886ms | 1.1254 KOps/s | 1.0264 KOps/s | |
test_ddpg_speed[False-None] | 3.3212ms | 2.9198ms | 342.4846 Ops/s | 316.1030 Ops/s | |
test_ddpg_speed[False-backward] | 4.2055ms | 4.0502ms | 246.9003 Ops/s | 224.5300 Ops/s | |
test_ddpg_speed[True-None] | 1.4579ms | 1.0210ms | 979.4420 Ops/s | 555.9376 Ops/s | |
test_ddpg_speed[True-backward] | 2.4231ms | 1.9716ms | 507.1938 Ops/s | 415.0689 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.6882ms | 1.0098ms | 990.3370 Ops/s | 927.0692 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9580ms | 1.9069ms | 524.4114 Ops/s | 455.5887 Ops/s | |
test_sac_speed[False-None] | 9.3644ms | 8.1158ms | 123.2165 Ops/s | 105.0327 Ops/s | |
test_sac_speed[False-backward] | 11.3785ms | 10.9020ms | 91.7265 Ops/s | 78.6914 Ops/s | |
test_sac_speed[True-None] | 2.7484ms | 1.8626ms | 536.8895 Ops/s | 470.5976 Ops/s | |
test_sac_speed[True-backward] | 3.6861ms | 3.5862ms | 278.8469 Ops/s | 229.0318 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.1804ms | 1.8414ms | 543.0652 Ops/s | 443.9634 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.7551ms | 3.6122ms | 276.8364 Ops/s | 216.0022 Ops/s | |
test_redq_speed[False-None] | 15.8075ms | 13.8531ms | 72.1861 Ops/s | 69.3331 Ops/s | |
test_redq_speed[False-backward] | 0.2707s | 28.3996ms | 35.2117 Ops/s | 41.1332 Ops/s | |
test_redq_speed[True-None] | 6.3448ms | 5.3953ms | 185.3456 Ops/s | 164.0274 Ops/s | |
test_redq_speed[True-backward] | 13.4297ms | 12.8381ms | 77.8929 Ops/s | 72.1986 Ops/s | |
test_redq_speed[reduce-overhead-None] | 6.5822ms | 5.8316ms | 171.4789 Ops/s | 164.9629 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 14.1782ms | 13.1963ms | 75.7786 Ops/s | 73.5108 Ops/s | |
test_redq_deprec_speed[False-None] | 15.4132ms | 13.7509ms | 72.7227 Ops/s | 65.0288 Ops/s | |
test_redq_deprec_speed[False-backward] | 23.4082ms | 20.6279ms | 48.4780 Ops/s | 44.7847 Ops/s | |
test_redq_deprec_speed[True-None] | 4.9417ms | 4.2255ms | 236.6579 Ops/s | 203.6377 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.7496ms | 9.3165ms | 107.3368 Ops/s | 104.4682 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.6193ms | 4.0806ms | 245.0602 Ops/s | 218.4479 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 9.1956ms | 8.9940ms | 111.1857 Ops/s | 102.4827 Ops/s | |
test_td3_speed[False-None] | 10.6821ms | 8.4376ms | 118.5172 Ops/s | 104.8598 Ops/s | |
test_td3_speed[False-backward] | 12.0236ms | 11.0022ms | 90.8912 Ops/s | 40.3497 Ops/s | |
test_td3_speed[True-None] | 2.0335ms | 1.8174ms | 550.2453 Ops/s | 461.5162 Ops/s | |
test_td3_speed[True-backward] | 3.8084ms | 3.6422ms | 274.5560 Ops/s | 234.1285 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.6153ms | 1.8112ms | 552.1273 Ops/s | 526.2834 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 3.9475ms | 3.5723ms | 279.9339 Ops/s | 242.0293 Ops/s | |
test_cql_speed[False-None] | 39.9291ms | 37.9786ms | 26.3306 Ops/s | 25.0564 Ops/s | |
test_cql_speed[False-backward] | 51.8552ms | 48.7045ms | 20.5320 Ops/s | 19.7065 Ops/s | |
test_cql_speed[True-None] | 18.3823ms | 16.3643ms | 61.1087 Ops/s | 61.5889 Ops/s | |
test_cql_speed[True-backward] | 24.6161ms | 23.3575ms | 42.8128 Ops/s | 42.7166 Ops/s | |
test_cql_speed[reduce-overhead-None] | 18.0415ms | 16.5805ms | 60.3117 Ops/s | 61.8871 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 25.1492ms | 24.0684ms | 41.5482 Ops/s | 42.2150 Ops/s | |
test_a2c_speed[False-None] | 9.8296ms | 8.4773ms | 117.9617 Ops/s | 127.4906 Ops/s | |
test_a2c_speed[False-backward] | 16.5946ms | 16.1380ms | 61.9656 Ops/s | 64.1320 Ops/s | |
test_a2c_speed[True-None] | 5.0977ms | 4.6624ms | 214.4839 Ops/s | 233.6190 Ops/s | |
test_a2c_speed[True-backward] | 12.3148ms | 11.8640ms | 84.2884 Ops/s | 91.5602 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 5.7119ms | 4.8792ms | 204.9506 Ops/s | 235.8139 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.0445ms | 11.8313ms | 84.5214 Ops/s | 94.0394 Ops/s | |
test_ppo_speed[False-None] | 9.6444ms | 8.6675ms | 115.3741 Ops/s | 133.1203 Ops/s | |
test_ppo_speed[False-backward] | 18.3843ms | 16.7214ms | 59.8036 Ops/s | 65.0513 Ops/s | |
test_ppo_speed[True-None] | 4.8369ms | 4.3712ms | 228.7712 Ops/s | 264.2910 Ops/s | |
test_ppo_speed[True-backward] | 14.9920ms | 10.7936ms | 92.6477 Ops/s | 101.4494 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 4.7394ms | 4.2130ms | 237.3602 Ops/s | 266.9062 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 10.8793ms | 10.6079ms | 94.2696 Ops/s | 101.9902 Ops/s | |
test_reinforce_speed[False-None] | 8.8304ms | 7.4774ms | 133.7363 Ops/s | 152.3433 Ops/s | |
test_reinforce_speed[False-backward] | 12.3544ms | 11.1352ms | 89.8055 Ops/s | 99.5649 Ops/s | |
test_reinforce_speed[True-None] | 3.4598ms | 3.2329ms | 309.3178 Ops/s | 368.2565 Ops/s | |
test_reinforce_speed[True-backward] | 9.8995ms | 9.6132ms | 104.0238 Ops/s | 106.9073 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 3.6624ms | 3.1864ms | 313.8315 Ops/s | 372.0995 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 10.5593ms | 9.6947ms | 103.1496 Ops/s | 116.4056 Ops/s | |
test_iql_speed[False-None] | 49.3593ms | 35.4278ms | 28.2264 Ops/s | 29.4787 Ops/s | |
test_iql_speed[False-backward] | 53.2848ms | 48.5551ms | 20.5952 Ops/s | 15.1996 Ops/s | |
test_iql_speed[True-None] | 12.5246ms | 11.4373ms | 87.4335 Ops/s | 87.6043 Ops/s | |
test_iql_speed[True-backward] | 24.1073ms | 23.0146ms | 43.4508 Ops/s | 43.2112 Ops/s | |
test_iql_speed[reduce-overhead-None] | 12.3481ms | 11.4941ms | 87.0011 Ops/s | 87.3294 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 24.9293ms | 23.0904ms | 43.3081 Ops/s | 42.5897 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.1250ms | 5.8221ms | 171.7589 Ops/s | 177.6457 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.0202ms | 0.5559ms | 1.7987 KOps/s | 1.7642 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8099ms | 0.5320ms | 1.8798 KOps/s | 1.8771 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 8.2625ms | 5.2444ms | 190.6792 Ops/s | 190.6186 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.6004ms | 0.5456ms | 1.8330 KOps/s | 1.8592 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7647ms | 0.5113ms | 1.9559 KOps/s | 1.9465 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.7493ms | 1.7383ms | 575.2806 Ops/s | 596.5794 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.2752ms | 1.6616ms | 601.8458 Ops/s | 631.2298 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 8.8024ms | 5.7953ms | 172.5531 Ops/s | 192.3785 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.8401ms | 0.7004ms | 1.4277 KOps/s | 1.4541 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 2.0275ms | 0.6694ms | 1.4939 KOps/s | 1.5070 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.3305ms | 5.4801ms | 182.4769 Ops/s | 184.3692 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.5401ms | 0.5593ms | 1.7880 KOps/s | 1.7634 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5242s | 1.2548ms | 796.9199 Ops/s | 1.8808 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.1808ms | 5.3063ms | 188.4567 Ops/s | 189.8579 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.5310ms | 0.5471ms | 1.8277 KOps/s | 423.6051 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8288ms | 0.5307ms | 1.8844 KOps/s | 1.8713 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.6119ms | 5.9415ms | 168.3064 Ops/s | 162.8823 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.0020ms | 0.6981ms | 1.4324 KOps/s | 1.4272 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 9.6943ms | 0.6972ms | 1.4344 KOps/s | 1.4538 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 8.3869ms | 4.9208ms | 203.2194 Ops/s | 174.6613 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 4.0959ms | 2.3769ms | 420.7117 Ops/s | 385.2650 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.9122ms | 1.5877ms | 629.8268 Ops/s | 644.9104 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5115s | 15.3170ms | 65.2869 Ops/s | 185.6160 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 8.1117ms | 2.5402ms | 393.6704 Ops/s | 384.5297 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 1.9181ms | 1.3244ms | 755.0434 Ops/s | 665.2166 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 8.3161ms | 5.6149ms | 178.0974 Ops/s | 200.1906 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 9.8979ms | 2.8863ms | 346.4630 Ops/s | 372.9954 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 3.3890ms | 1.6743ms | 597.2653 Ops/s | 569.9659 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 16.2573ms | 13.7136ms | 72.9201 Ops/s | 70.0920 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 17.5449ms | 15.3295ms | 65.2338 Ops/s | 63.9357 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 23.5649ms | 22.6039ms | 44.2401 Ops/s | 44.4248 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 17.1245ms | 15.5890ms | 64.1477 Ops/s | 63.8144 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 23.5630ms | 22.5386ms | 44.3684 Ops/s | 44.6329 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 18.4069ms | 17.0650ms | 58.5995 Ops/s | 59.7489 Ops/s |

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.8060s | 0.7235s | 1.3822 Ops/s | 1.3658 Ops/s | |
test_transformed | 0.9492s | 0.9484s | 1.0544 Ops/s | 1.0216 Ops/s | |
test_serial | 2.0861s | 2.0792s | 0.4810 Ops/s | 0.4728 Ops/s | |
test_parallel | 1.8853s | 1.8162s | 0.5506 Ops/s | 0.5381 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2001ms | 39.9129μs | 25.0546 KOps/s | 24.8750 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 59.6220μs | 23.2860μs | 42.9443 KOps/s | 42.1577 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 57.0430μs | 22.4992μs | 44.4460 KOps/s | 44.6167 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 43.6020μs | 12.9599μs | 77.1612 KOps/s | 76.0541 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 81.5540μs | 42.8597μs | 23.3320 KOps/s | 23.1057 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 67.3330μs | 25.8639μs | 38.6640 KOps/s | 37.9710 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 63.9030μs | 24.8776μs | 40.1968 KOps/s | 39.9172 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 49.7520μs | 15.5129μs | 64.4624 KOps/s | 63.7846 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 99.7250μs | 45.8356μs | 21.8171 KOps/s | 21.7927 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 59.8030μs | 28.5072μs | 35.0788 KOps/s | 34.7303 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 71.7440μs | 24.7285μs | 40.4391 KOps/s | 39.8918 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 46.4220μs | 15.5818μs | 64.1772 KOps/s | 65.4093 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 80.3540μs | 47.6746μs | 20.9755 KOps/s | 20.7085 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 79.0240μs | 30.0940μs | 33.2292 KOps/s | 32.8183 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 59.6430μs | 26.9922μs | 37.0478 KOps/s | 36.6503 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 47.6020μs | 17.4773μs | 57.2172 KOps/s | 55.7687 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 92.1540μs | 45.1355μs | 22.1555 KOps/s | 21.6677 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 61.9720μs | 27.9291μs | 35.8050 KOps/s | 35.2658 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 65.0230μs | 28.4115μs | 35.1970 KOps/s | 34.0991 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 57.7730μs | 17.1499μs | 58.3093 KOps/s | 57.9377 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 85.6540μs | 47.5199μs | 21.0438 KOps/s | 20.7287 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 66.6030μs | 30.3307μs | 32.9699 KOps/s | 32.7234 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.1631ms | 31.5930μs | 31.6526 KOps/s | 31.7965 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 55.2530μs | 19.7668μs | 50.5899 KOps/s | 51.0143 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 99.4550μs | 49.4296μs | 20.2308 KOps/s | 19.7393 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 69.2630μs | 33.0527μs | 30.2547 KOps/s | 29.7979 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 69.2530μs | 30.4899μs | 32.7978 KOps/s | 32.3723 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 53.4020μs | 19.6334μs | 50.9335 KOps/s | 51.1258 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 93.2440μs | 51.8988μs | 19.2683 KOps/s | 19.1939 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 70.0640μs | 35.1579μs | 28.4431 KOps/s | 28.6358 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 73.5240μs | 32.1815μs | 31.0738 KOps/s | 30.1508 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 53.8020μs | 21.6186μs | 46.2564 KOps/s | 45.9323 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 23.8026ms | 23.3221ms | 42.8778 Ops/s | 42.3151 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 97.8128ms | 2.8316ms | 353.1618 Ops/s | 337.2235 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1040ms | 75.1743μs | 13.3024 KOps/s | 12.9857 KOps/s | |
test_values[td1_return_estimate-False-False] | 53.1469ms | 52.3855ms | 19.0893 Ops/s | 18.8518 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.2912ms | 1.0501ms | 952.2839 Ops/s | 944.6772 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 87.5671ms | 83.5382ms | 11.9706 Ops/s | 11.8459 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.1854ms | 1.0442ms | 957.6430 Ops/s | 948.1505 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 23.5033ms | 23.0379ms | 43.4068 Ops/s | 42.3430 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0020ms | 0.7169ms | 1.3949 KOps/s | 1.3738 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.7260ms | 0.6407ms | 1.5608 KOps/s | 1.5502 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.5103ms | 1.4495ms | 689.8972 Ops/s | 686.5300 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7135ms | 0.6563ms | 1.5237 KOps/s | 1.4732 KOps/s | |
test_dqn_speed[False-None] | 6.8786ms | 1.4900ms | 671.1429 Ops/s | 678.8408 Ops/s | |
test_dqn_speed[False-backward] | 2.1245ms | 2.0711ms | 482.8300 Ops/s | 486.2776 Ops/s | |
test_dqn_speed[True-None] | 0.6308ms | 0.5410ms | 1.8485 KOps/s | 1.8196 KOps/s | |
test_dqn_speed[True-backward] | 1.2760ms | 1.1795ms | 847.8456 Ops/s | 830.8954 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.6275ms | 0.5534ms | 1.8070 KOps/s | 1.7506 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1230ms | 1.0437ms | 958.1394 Ops/s | 940.6862 Ops/s | |
test_ddpg_speed[False-None] | 3.0807ms | 2.8075ms | 356.1941 Ops/s | 350.0253 Ops/s | |
test_ddpg_speed[False-backward] | 4.4811ms | 4.0936ms | 244.2820 Ops/s | 241.4107 Ops/s | |
test_ddpg_speed[True-None] | 1.1323ms | 1.0598ms | 943.6076 Ops/s | 905.0687 Ops/s | |
test_ddpg_speed[True-backward] | 2.2911ms | 2.2442ms | 445.5896 Ops/s | 465.7315 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.2213ms | 1.1042ms | 905.6713 Ops/s | 884.8497 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.7908ms | 1.7412ms | 574.3205 Ops/s | 612.7275 Ops/s | |
test_sac_speed[False-None] | 8.2835ms | 7.8495ms | 127.3962 Ops/s | 125.1734 Ops/s | |
test_sac_speed[False-backward] | 11.3340ms | 10.8214ms | 92.4092 Ops/s | 93.1603 Ops/s | |
test_sac_speed[True-None] | 1.6665ms | 1.5058ms | 664.0928 Ops/s | 652.9916 Ops/s | |
test_sac_speed[True-backward] | 3.3975ms | 3.3231ms | 300.9256 Ops/s | 298.7448 Ops/s | |
test_sac_speed[reduce-overhead-None] | 23.0791ms | 12.7645ms | 78.3425 Ops/s | 78.2932 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.6053ms | 1.5020ms | 665.7887 Ops/s | 744.9279 Ops/s | |
test_redq_speed[False-None] | 8.0654ms | 7.3334ms | 136.3615 Ops/s | 134.7176 Ops/s | |
test_redq_speed[False-backward] | 12.0885ms | 11.2712ms | 88.7220 Ops/s | 91.0152 Ops/s | |
test_redq_speed[True-None] | 2.1897ms | 1.9584ms | 510.6150 Ops/s | 509.7029 Ops/s | |
test_redq_speed[True-backward] | 3.9969ms | 3.7333ms | 267.8589 Ops/s | 262.6352 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.0475ms | 1.9404ms | 515.3534 Ops/s | 510.2938 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 3.9809ms | 3.5454ms | 282.0523 Ops/s | 263.9370 Ops/s | |
test_redq_deprec_speed[False-None] | 9.2782ms | 8.7705ms | 114.0180 Ops/s | 111.8246 Ops/s | |
test_redq_deprec_speed[False-backward] | 11.8799ms | 11.5526ms | 86.5608 Ops/s | 83.5564 Ops/s | |
test_redq_deprec_speed[True-None] | 2.4192ms | 2.2825ms | 438.1080 Ops/s | 433.1120 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.0384ms | 3.8984ms | 256.5155 Ops/s | 243.0773 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.3894ms | 2.2801ms | 438.5730 Ops/s | 431.1762 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.3451ms | 3.9209ms | 255.0412 Ops/s | 244.3168 Ops/s | |
test_td3_speed[False-None] | 7.8777ms | 7.7449ms | 129.1178 Ops/s | 127.5278 Ops/s | |
test_td3_speed[False-backward] | 10.2820ms | 9.9762ms | 100.2382 Ops/s | 97.3108 Ops/s | |
test_td3_speed[True-None] | 1.6199ms | 1.5709ms | 636.5803 Ops/s | 620.6055 Ops/s | |
test_td3_speed[True-backward] | 3.0846ms | 3.0399ms | 328.9535 Ops/s | 306.9307 Ops/s | |
test_td3_speed[reduce-overhead-None] | 56.3258ms | 26.2147ms | 38.1466 Ops/s | 38.2017 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.3381ms | 1.2810ms | 780.6125 Ops/s | 687.5121 Ops/s | |
test_cql_speed[False-None] | 16.9255ms | 16.3828ms | 61.0396 Ops/s | 60.5445 Ops/s | |
test_cql_speed[False-backward] | 22.1089ms | 21.2517ms | 47.0551 Ops/s | 45.9246 Ops/s | |
test_cql_speed[True-None] | 2.9575ms | 2.8674ms | 348.7453 Ops/s | 344.8747 Ops/s | |
test_cql_speed[True-backward] | 5.2995ms | 4.9854ms | 200.5877 Ops/s | 198.6451 Ops/s | |
test_cql_speed[reduce-overhead-None] | 0.3663s | 15.1243ms | 66.1186 Ops/s | 74.7428 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 1.5521ms | 1.5112ms | 661.7442 Ops/s | 587.3954 Ops/s | |
test_a2c_speed[False-None] | 3.2165ms | 3.1301ms | 319.4795 Ops/s | 315.9132 Ops/s | |
test_a2c_speed[False-backward] | 6.9728ms | 5.8948ms | 169.6410 Ops/s | 160.8495 Ops/s | |
test_a2c_speed[True-None] | 1.1220ms | 0.9958ms | 1.0042 KOps/s | 990.1832 Ops/s | |
test_a2c_speed[True-backward] | 2.5861ms | 2.5184ms | 397.0720 Ops/s | 364.4988 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 21.5077ms | 11.6753ms | 85.6511 Ops/s | 86.1301 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 0.9993ms | 0.9594ms | 1.0423 KOps/s | 869.3047 Ops/s | |
test_ppo_speed[False-None] | 3.7306ms | 3.5881ms | 278.7006 Ops/s | 277.3476 Ops/s | |
test_ppo_speed[False-backward] | 7.2988ms | 6.6101ms | 151.2834 Ops/s | 147.3737 Ops/s | |
test_ppo_speed[True-None] | 1.0065ms | 0.9365ms | 1.0678 KOps/s | 1.0396 KOps/s | |
test_ppo_speed[True-backward] | 2.7110ms | 2.6477ms | 377.6927 Ops/s | 398.1029 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 0.5865ms | 0.5270ms | 1.8976 KOps/s | 68.0171 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 1.1553ms | 1.1038ms | 905.9800 Ops/s | 978.7588 Ops/s | |
test_reinforce_speed[False-None] | 2.3093ms | 2.2144ms | 451.5948 Ops/s | 447.1864 Ops/s | |
test_reinforce_speed[False-backward] | 3.5669ms | 3.2616ms | 306.6002 Ops/s | 314.1609 Ops/s | |
test_reinforce_speed[True-None] | 0.8661ms | 0.8198ms | 1.2198 KOps/s | 1.1488 KOps/s | |
test_reinforce_speed[True-backward] | 2.5952ms | 2.5008ms | 399.8685 Ops/s | 412.7820 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 0.2927s | 12.1792ms | 82.1070 Ops/s | 86.3953 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.1972ms | 1.1529ms | 867.3974 Ops/s | 969.9776 Ops/s | |
test_iql_speed[False-None] | 9.6923ms | 9.1325ms | 109.4995 Ops/s | 109.6957 Ops/s | |
test_iql_speed[False-backward] | 13.5345ms | 12.9064ms | 77.4806 Ops/s | 78.7920 Ops/s | |
test_iql_speed[True-None] | 1.8118ms | 1.7386ms | 575.1608 Ops/s | 570.9952 Ops/s | |
test_iql_speed[True-backward] | 4.7257ms | 4.3091ms | 232.0676 Ops/s | 229.6168 Ops/s | |
test_iql_speed[reduce-overhead-None] | 19.9821ms | 11.4565ms | 87.2864 Ops/s | 85.7898 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 1.6808ms | 1.5806ms | 632.6706 Ops/s | 615.2495 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.8940ms | 6.4110ms | 155.9808 Ops/s | 153.7272 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5683ms | 0.2798ms | 3.5734 KOps/s | 2.8758 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4630ms | 0.2523ms | 3.9642 KOps/s | 2.7948 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4399ms | 6.1291ms | 163.1561 Ops/s | 161.6373 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.1949ms | 0.2824ms | 3.5411 KOps/s | 2.9784 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5655ms | 0.2776ms | 3.6024 KOps/s | 3.3943 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5070ms | 1.2979ms | 770.4740 Ops/s | 711.0246 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4072ms | 1.1738ms | 851.9129 Ops/s | 767.4042 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4279ms | 6.3068ms | 158.5587 Ops/s | 156.9563 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1668ms | 0.4348ms | 2.2997 KOps/s | 2.4190 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6736ms | 0.3940ms | 2.5382 KOps/s | 2.3513 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3361ms | 6.1220ms | 163.3463 Ops/s | 160.1816 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.4670ms | 0.3269ms | 3.0590 KOps/s | 3.2842 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4982ms | 0.3140ms | 3.1844 KOps/s | 3.8511 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3936ms | 6.0443ms | 165.4462 Ops/s | 161.7273 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6185ms | 0.3058ms | 3.2703 KOps/s | 2.9990 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5320ms | 0.2943ms | 3.3977 KOps/s | 3.4791 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4132ms | 6.2486ms | 160.0355 Ops/s | 157.7189 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1169ms | 0.4290ms | 2.3311 KOps/s | 2.1495 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.5902ms | 0.3860ms | 2.5910 KOps/s | 2.0737 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0423ms | 5.4056ms | 184.9939 Ops/s | 183.1324 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 4.0088ms | 1.9372ms | 516.1989 Ops/s | 418.7960 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 8.7585ms | 1.2364ms | 808.7999 Ops/s | 907.6455 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.5432ms | 5.3224ms | 187.8865 Ops/s | 186.7481 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 7.9769ms | 2.0505ms | 487.6868 Ops/s | 432.7574 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 7.5499ms | 1.2305ms | 812.6564 Ops/s | 797.4705 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.4977s | 15.4759ms | 64.6165 Ops/s | 33.3746 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 10.2648ms | 2.2468ms | 445.0823 Ops/s | 441.9400 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 7.0725ms | 1.3465ms | 742.6422 Ops/s | 738.2613 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 15.8027ms | 15.3267ms | 65.2454 Ops/s | 63.9676 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.7087ms | 17.7246ms | 56.4189 Ops/s | 55.7401 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 21.5425ms | 19.4900ms | 51.3083 Ops/s | 50.1443 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 20.1289ms | 17.5026ms | 57.1343 Ops/s | 55.6011 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 19.8626ms | 19.3887ms | 51.5764 Ops/s | 49.8091 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 20.6067ms | 18.8721ms | 52.9883 Ops/s | 51.0235 Ops/s |
In its current version, this PR assumes that you can have a multihead action but that reward / done are going to be tensors. Every time we need to recompute anything from the dist, we inspect it: if it's a composite dist AND you didn't explicitly ask to aggregate the log-probs, we get a tensordict of log-probs. From there, for PPOLoss and the KL version, the change is quite trivial.
For ClipPPOLoss, there's a bit of a change in the logic: before, we were summing all the log-probs (or weights), then clamping, then multiplying by the advantage. Now, we first clamp each weight leaf, then sum and multiply. Hopefully that is more mathematically accurate, but I'm happy to revert this if people yell at me!
There is still something I would like to implement, but it could be a bit bc-breaking so I'd rather get people's opinion on it: currently, we kind of assume that users have set the This is what we have now: Lines 513 to 523 in 86ab9b7
The way I'm thinking about this is to append a PR to this stack where:
cc @louisfaury |
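To make the ClipPPOLoss ordering change described above concrete, here is a minimal scalar sketch, not the code from this PR. It assumes the clamp is applied to log-weights (new log-prob minus old log-prob) with bounds `log(1 - eps)` and `log(1 + eps)`; the head names `head_a` / `head_b` and both function names are hypothetical, and plain floats stand in for tensordict leaves.

```python
import math


def clamp(x, lo, hi):
    """Clamp a scalar to the closed interval [lo, hi]."""
    return min(max(x, lo), hi)


def ratio_sum_then_clip(log_weights, eps):
    """Previous ordering: sum the per-head log-weights, then clamp the total."""
    lo, hi = math.log1p(-eps), math.log1p(eps)
    return math.exp(clamp(sum(log_weights.values()), lo, hi))


def ratio_clip_then_sum(log_weights, eps):
    """New ordering: clamp each per-head log-weight (leaf), then sum."""
    lo, hi = math.log1p(-eps), math.log1p(eps)
    return math.exp(sum(clamp(lw, lo, hi) for lw in log_weights.values()))


# Hypothetical two-head action whose per-head deviations cancel out exactly.
log_weights = {"head_a": 0.5, "head_b": -0.5}
eps = 0.2

old_ratio = ratio_sum_then_clip(log_weights, eps)
new_ratio = ratio_clip_then_sum(log_weights, eps)
```

With these inputs the summed log-weight is zero, so the previous ordering sees a ratio of 1 and never clips, even though each head moved far outside the trust region. The per-leaf ordering clips both heads and yields a ratio of roughly 0.96. The sketch stops at the clipped ratio; multiplying by the advantage is unchanged between the two orderings.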
Why can't you (1) read the log-prob if there is one and, if not, (2) recompute it if the original dist params are there and, if not, (3) throw an error? |
We could, it's mainly a matter of what is the "default" to me. |
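The three-step fallback suggested above could look roughly like the sketch below. It is illustrative only: a plain dict stands in for a tensordict, the distribution is assumed to be a scalar Normal, and the function name `resolve_log_prob` and the keys `sample_log_prob`, `loc`, `scale`, and `action` are assumptions chosen for the example rather than what the loss actually reads.

```python
import math


def resolve_log_prob(data, log_prob_key="sample_log_prob"):
    """Return a log-prob using a read / recompute / raise fallback.

    (1) Read the stored log-prob if it is present.
    (2) Otherwise recompute it from the stored Normal parameters.
    (3) Otherwise raise, since nothing can be recovered.
    """
    if log_prob_key in data:
        return data[log_prob_key]
    if all(key in data for key in ("loc", "scale", "action")):
        loc, scale, action = data["loc"], data["scale"], data["action"]
        # Log-density of a scalar Normal(loc, scale) evaluated at `action`.
        return (
            -((action - loc) ** 2) / (2 * scale**2)
            - math.log(scale)
            - 0.5 * math.log(2 * math.pi)
        )
    raise KeyError(
        f"Neither '{log_prob_key}' nor the distribution parameters were found."
    )


stored = resolve_log_prob({"sample_log_prob": -1.5})
recomputed = resolve_log_prob({"loc": 0.0, "scale": 1.0, "action": 0.0})
```

Which branch runs first is exactly the "default" question raised in the reply: putting the stored value first favors speed, while recomputing first would favor consistency with the current parameters.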
nm, this is a bit more ambitious than I thought since VTrace requires the log-prob to be present and we want the advantage to be callable outside of the loss |
Stack from ghstack (oldest at bottom):