I get the following error when running:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/dialog/gpt2_ppo.yml

I have double-checked that transformers==4.18.0 is installed.
Traceback (most recent call last):
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 84, in <module>
    main(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/scripts/training/train_text_generation.py", line 55, in main
    trainer.train_and_eval()
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/training_utils.py", line 232, in train_and_eval
    self._alg.learn(self._n_steps_per_iter)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/algorithms/ppo/ppo.py", line 342, in learn
    return super().learn(
  File "/opt/anaconda3/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 384, in collect_rollouts
    rollout_info = self.generate_batch(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/alg_wrappers.py", line 159, in generate_batch
    gen_output = self.policy.generate(
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/base_policy.py", line 230, in generate
    inputs=input_ids.to(self.get_policy_first_device()),
  File "/Users/stephanehatgiskessell/Desktop/RL4LMs/rl4lms/envs/text_generation/policy/causal_policy.py", line 259, in get_policy_first_device
    self._policy_model.transformer.first_device
  File "/opt/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1185, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2Model' object has no attribute 'first_device'
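The last frames show get_policy_first_device reading self._policy_model.transformer.first_device, which this GPT2Model instance does not have. One plausible explanation (an assumption, not confirmed by the traceback) is that transformers only sets first_device inside GPT-2's model-parallel parallelize() method, and that method may not be called in a CPU-only setup like this one. The sketch below shows a defensive lookup that falls back to the device of the model's first parameter. The helper name first_device_or_fallback and its default argument are hypothetical and not part of RL4LMs or transformers.

```python
def first_device_or_fallback(policy_model, default="cpu"):
    """Hypothetical helper, not part of RL4LMs or transformers.

    Returns policy_model.transformer.first_device if that attribute exists
    (assumed to be set only by transformers' parallelize()), otherwise the
    device of the model's first parameter, otherwise `default`.
    """
    # Causal LM wrappers such as GPT2LMHeadModel hold the base model in
    # `.transformer`; use the model itself if that attribute is missing.
    base = getattr(policy_model, "transformer", policy_model)

    # getattr with a default swallows the AttributeError that
    # torch.nn.Module.__getattr__ raises for an attribute that was never set.
    device = getattr(base, "first_device", None)
    if device is not None:
        return device

    # No model-parallel device map: use wherever the weights actually live.
    first_param = next(iter(policy_model.parameters()), None)
    if first_param is not None:
        return first_param.device

    # Model has no parameters at all: fall back to the caller's default.
    return default
```

In a locally patched get_policy_first_device, one could return first_device_or_fallback(self._policy_model) instead of reading .transformer.first_device directly. Whether this is the right fix for RL4LMs depends on when the library intends parallelize() to run, which this issue does not show.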