Commit

update

lxtGH committed Jan 8, 2025
1 parent f5ce17e commit 2e7edca
Showing 97 changed files with 19,469 additions and 1 deletion.
128 changes: 128 additions & 0 deletions .gitignore
@@ -0,0 +1,128 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/*/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
checkpoints/
checkpoints
data/
data
.vscode
.idea
.DS_Store
*.pkl
*.pkl.json
*.log.json
work_dirs/
pretrained/
temp_*/

# Pytorch
*.pth
*.py~
*.sh~

# srun
*.out
batchscript-*
2 changes: 1 addition & 1 deletion LICENSE
@@ -198,4 +198,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
78 changes: 78 additions & 0 deletions README.md
@@ -0,0 +1,78 @@
# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos

[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]]() [\[🤗 Model\]](https://huggingface.co/ByteDance/Sa2VA) [\[🎥 Introduction\]]() [\[🧑‍💻 GitHub\]](https://github.com/magic-research/Sa2VA)

![Teaser](assets/images/teaser.jpg)

## Overview
This repository contains the code for the paper "Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos".

Sa2VA is the first unified model for dense grounded understanding of both images and videos. Unlike existing multi-modal large language models, which are often limited to specific modalities and tasks, Sa2VA supports a wide range of image and video tasks, including referring segmentation and conversation, with minimal one-shot instruction tuning. Sa2VA combines SAM-2, a foundation video segmentation model, with LLaVA, an advanced vision-language model, and unifies text, image, and video into a shared LLM token space.

## Model Zoo
We provide the following models; a minimal inference sketch follows the table:
| Model Name | Base MLLM | Language Part | HF Link |
|:----------:|:-----------------------------------------------------------------:|:-----------------------------------------------------------------------------:|:----------------------------------------------------:|
| Sa2VA-1B | [InternVL2.5-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B) | [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-1B) |
| Sa2VA-4B | [InternVL2.5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-4B) |
| Sa2VA-8B | [InternVL2.5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-8B) |
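
A minimal single-image inference sketch for the released checkpoints. This is a sketch only: it assumes the Hugging Face checkpoints load through `transformers` `AutoModel`/`AutoTokenizer` with `trust_remote_code=True` and expose a `predict_forward(image, text, tokenizer)` entry point returning a `prediction` string; verify the exact interface on the model cards.

```python
# Hedged quick-start sketch; the predict_forward method and the 'prediction'
# key are assumptions to check against the Hugging Face model cards.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

image = Image.open("assets/images/teaser.jpg").convert("RGB")
result = model.predict_forward(
    image=image, text="<image>Please describe the image.", tokenizer=tokenizer)
print(result["prediction"])
```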

## Training
<details>
<summary>Installation</summary>

1. Please install Python and PyTorch first:
```bash
> conda create -n vlm python=3.10
> conda activate vlm
> conda install pytorch==2.3.1 torchvision==0.18.1 pytorch-cuda=12.1 cuda -c pytorch -c "nvidia/label/cuda-12.1.0" -c "nvidia/label/cuda-12.1.1"
```

2. Install mmcv:
```bash
> pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
```

3. Install other dependencies:
```bash
> pip install -r requirements.txt
```
</details>

<details>
<summary>Pretrained Model Preparation</summary>

Download the following pretrained models and place them in the `./pretrained` directory; a download sketch follows the list:
- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large)
- [InternVL2_5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B)
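
A minimal download sketch using `huggingface_hub` (the checkpoint file name in the SAM-2 repo and the directory layout expected by the configs are assumptions; adjust as needed):

```python
# Hedged sketch: fetch the pretrained weights into ./pretrained via huggingface_hub.
from huggingface_hub import hf_hub_download, snapshot_download

# SAM-2 checkpoint (assumes the file is published as sam2_hiera_large.pt in this repo).
hf_hub_download(repo_id="facebook/sam2-hiera-large",
                filename="sam2_hiera_large.pt",
                local_dir="./pretrained")

# InternVL2.5-4B: mirror the full model repo into ./pretrained/InternVL2_5-4B.
snapshot_download(repo_id="OpenGVLab/InternVL2_5-4B",
                  local_dir="./pretrained/InternVL2_5-4B")
```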

</details>

<details>
<summary>Data Preparation</summary>

(TODO) Please download the training datasets and place them in the `data` directory. The download link is [here](https://huggingface.co/datasets/Dense-World/Sa2VA-Training).
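
A hedged sketch for pulling the dataset repo into `./data` with `huggingface_hub` (the exact layout the dataset configs expect under `data/` is an assumption; follow the dataset card if it differs):

```python
# Hedged sketch: mirror the training-data repo into ./data.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Dense-World/Sa2VA-Training",
                  repo_type="dataset",
                  local_dir="./data")
```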

</details>


<details>
<summary>Training Script</summary>

Please run the following script to train:
```bash
> bash tools/dist.sh train projects/llava_sam2/configs/sa2va_4b.py 8
```
</details>


## References
If you find this repository useful, please consider citing the following paper:
```bibtex
@article{sa2va,
  title={Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos},
  author={Yuan, Haobo and Li, Xiangtai and Zhang, Tao and Huang, Zilong and Xu, Shilin and Ji, Shunping and Tong, Yunhai and Qi, Lu and Feng, Jiashi and Yang, Ming-Hsuan},
  journal={arXiv},
  year={2025}
}
```
Binary file added assets/images/teaser.jpg
7 changes: 7 additions & 0 deletions projects/glamm/datasets/__init__.py
@@ -0,0 +1,7 @@
from .semantic_seg_dataset import SemanticSegDataset, ADE20kSemanticSegDataset, \
    COCOStuffSemanticSegDataset, PascalPartSemanticSegDataset, PacoSemanticSegDataset
from .gcg_dataset import GCGDataset, GranDfGCGDataset, RefCOCOgGCGDataset, OpenPsgGCGDataset, Flickr30kGCGDataset
from .region_level_dataset import RefCocoGRegionDataset, VisualGenomeRegionDataset
from .refcoco_segm_dataset import ReferSegmDataset
from .utils.utils import *
from .collate_fns.glamm_collate_fn import glamm_collate_fn
136 changes: 136 additions & 0 deletions projects/glamm/datasets/collate_fns/glamm_collate_fn.py
@@ -0,0 +1,136 @@
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
    """Collate GLaMM-style training instances into a padded batch dict."""
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen Attention in multimodal training')

    if has_image:
        pixel_values = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])
        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
        #     grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
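
For reference, a minimal way to exercise this collate function with toy text-only instances. This is a sketch assuming `xtuner` and `torch` are installed, that `get_sequence_parallel_world_size()` falls back to 1 outside a distributed run, and that the package path mirrors this repository layout:

```python
# Hedged usage sketch for glamm_collate_fn (assumptions noted above).
from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

# Two text-only toy samples of different lengths; no images, masks, boxes, or points.
instances = [
    {'input_ids': [1, 2, 3, 4], 'labels': [-100, 2, 3, 4]},
    {'input_ids': [1, 2], 'labels': [-100, 2]},
]

batch = glamm_collate_fn(instances)
data = batch['data']
print(data['input_ids'].shape)    # torch.Size([2, 4]) after right-padding
print(data['attention_mask'][1])  # tensor([ True,  True, False, False])
```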