This repo contains code for the paper Track2Act: Predicting Point Tracks from Internet Videos enables Generalizable Robot Manipulation
Use the environment.yml file to create the conda environment and install the dependencies.
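For reference, a typical conda workflow would look like the following (assuming a standard conda installation; the environment name is whatever is defined inside environment.yml):

conda env create -f environment.yml
conda activate <env name from environment.yml>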
To train the point track prediction model, run the following, adjusting the number of nodes, GPUs per node, and batch size as needed:
torchrun --nnodes=1 --nproc_per_node=8 train_track_pred.py --global-batch-size=480 --data-path=<folder with data files>
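As a hypothetical example, a two-node run with 8 GPUs per node could be launched on each node using torchrun's rendezvous flags, with the global batch size scaled to match; the job id, endpoint, and batch size below are placeholders, not values taken from the paper or repo:

torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=<job id> --rdzv_backend=c10d --rdzv_endpoint=<master host>:29500 train_track_pred.py --global-batch-size=960 --data-path=<folder with data files>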
For inference, specify the paths to the initial image, the goal image, and the model checkpoint (the trained model is available at this link). The visualization will be saved in the folder save_tracK_pred.
python inference_track_pred.py --ckpt=<path to model> --init=<path to initial image> --goal=<path to goal image>
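For example, a concrete invocation might look like the following (all paths are hypothetical placeholders for your own files):

python inference_track_pred.py --ckpt=checkpoints/track_pred.pt --init=examples/init_frame.png --goal=examples/goal_frame.png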
For any questions about the project, feel free to email Homanga Bharadhwaj at [email protected].
The code is licensed under the CC-BY-NC license (see License.md).
The code in this repo is based on Diffusion Transformers (https://github.com/facebookresearch/DiT) and uses open-source packages such as diffusers, scipy, opencv, numpy, and pytorch.
If you find the repository helpful, please consider citing our paper:
@misc{bharadhwaj2024track2act,
      title={Track2Act: Predicting Point Tracks from Internet Videos enables Diverse Zero-shot Robot Manipulation},
      author={Homanga Bharadhwaj and Roozbeh Mottaghi and Abhinav Gupta and Shubham Tulsiani},
      year={2024},
      eprint={2405.01527},
      archivePrefix={arXiv},
      primaryClass={cs.RO}
}