Showing 97 changed files with 19,469 additions and 1 deletion.
.gitignore
@@ -0,0 +1,128 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/*/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# custom
checkpoints/
checkpoints
data/
data
.vscode
.idea
.DS_Store
*.pkl
*.pkl.json
*.log.json
work_dirs/
pretrained/
temp_*/

# Pytorch
*.pth
*.py~
*.sh~

# srun
*.out
batchscript-*
README.md
@@ -0,0 +1,78 @@
# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos

[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]]() [\[🤗 Model\]](https://huggingface.co/ByteDance/Sa2VA) [\[🎥 Introduction\]]() [\[🧑‍💻 GitHub\]](https://github.com/magic-research/Sa2VA)

![Teaser](assets/images/teaser.jpg)

## Overview
This repository contains the code for the paper "Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos".

Sa2VA is the first unified model for dense grounded understanding of both images and videos. Unlike existing multi-modal large language models, which are often limited to specific modalities and tasks, Sa2VA supports a wide range of image and video tasks, including referring segmentation and conversation, with minimal one-shot instruction tuning. Sa2VA combines SAM-2, a foundation video segmentation model, with LLaVA, an advanced vision-language model, and unifies text, image, and video into a shared LLM token space.

## Model Zoo
We provide the following models:

| Model Name | Base MLLM | Language Part | HF Link |
|:----------:|:---------:|:-------------:|:-------:|
| Sa2VA-1B | [InternVL2.5-1B](https://huggingface.co/OpenGVLab/InternVL2_5-1B) | [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-1B) |
| Sa2VA-4B | [InternVL2.5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-4B) |
| Sa2VA-8B | [InternVL2.5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-8B) |
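For orientation, inference with a released checkpoint might look roughly like the sketch below. This snippet is illustrative only: the remote-code interface (here assumed to be a `predict_forward` helper returning a text answer and optional masks) is defined by the Hugging Face model card rather than by this commit, so consult the model card for the authoritative usage.

```python
# Hypothetical quick-start sketch; the method name and return keys are
# assumptions to be checked against the Sa2VA model card.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

image = Image.open("example.jpg").convert("RGB")
result = model.predict_forward(              # assumed helper exposed via remote code
    image=image,
    text="<image>Please segment the person on the left.",
    tokenizer=tokenizer,
)
print(result["prediction"])                  # assumed key: text answer
masks = result.get("prediction_masks")       # assumed key: predicted segmentation masks
```
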
## Training
<details>
<summary>Installation</summary>

1. Install Python and PyTorch first:
```bash
> conda create -n vlm python=3.10
> conda activate vlm
> conda install pytorch==2.3.1 torchvision==0.18.1 pytorch-cuda=12.1 cuda -c pytorch -c "nvidia/label/cuda-12.1.0" -c "nvidia/label/cuda-12.1.1"
```

2. Install mmcv:
```bash
> pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
```

3. Install other dependencies:
```bash
> pip install -r requirements.txt
```
</details>
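As an optional sanity check after these steps (not part of the original README), the pinned versions can be verified from Python:

```python
# Optional environment check; expected values follow the version pins above.
import torch
import mmcv

print(torch.__version__)          # expected: 2.3.1
print(torch.version.cuda)         # expected: 12.1
print(torch.cuda.is_available())  # True on a working CUDA setup
print(mmcv.__version__)           # expected: 2.2.0
```
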
<details>
<summary>Pretrained Model Preparation</summary>

Download the following pretrained models and place them in the `./pretrained` directory (a possible download sketch follows this list):
- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large)
- [InternVL2_5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B)
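One possible way to fetch both checkpoints is via `huggingface-cli` (a sketch under assumptions: the SAM-2 checkpoint filename and the exact directory layout expected by the training configs are not shown in this commit, so adjust paths as needed):

```bash
pip install -U "huggingface_hub[cli]"

# SAM-2 Hiera-Large weights (filename assumed to be sam2_hiera_large.pt in the repo)
huggingface-cli download facebook/sam2-hiera-large sam2_hiera_large.pt --local-dir ./pretrained

# InternVL2.5-4B base MLLM
huggingface-cli download OpenGVLab/InternVL2_5-4B --local-dir ./pretrained/InternVL2_5-4B
```
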
</details>

<details>
<summary>Data Preparation</summary>

(TODO) Please download the training datasets and place them in the `data` directory. The download link is [here](https://huggingface.co/datasets/Dense-World/Sa2VA-Training).

</details>
<details>
<summary>Training Script</summary>

Please run the following script to train:
```bash
> bash tools/dist.sh train projects/llava_sam2/configs/sa2va_4b.py 8
```
</details>
## References
If you find this repository useful, please consider citing the following paper:
```
@article{sa2va,
  title={Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos},
  author={Yuan, Haobo and Li, Xiangtai and Zhang, Tao and Huang, Zilong and Xu, Shilin and Ji, Shunping and Tong, Yunhai and Qi, Lu and Feng, Jiashi and Yang, Ming-Hsuan},
  journal={arXiv},
  year={2025}
}
```
projects/glamm/datasets/__init__.py
@@ -0,0 +1,7 @@
from .semantic_seg_dataset import SemanticSegDataset, ADE20kSemanticSegDataset, \
    COCOStuffSemanticSegDataset, PascalPartSemanticSegDataset, PacoSemanticSegDataset
from .gcg_dataset import GCGDataset, GranDfGCGDataset, RefCOCOgGCGDataset, OpenPsgGCGDataset, Flickr30kGCGDataset
from .region_level_dataset import RefCocoGRegionDataset, VisualGenomeRegionDataset
from .refcoco_segm_dataset import ReferSegmDataset
from .utils.utils import *
from .collate_fns.glamm_collate_fn import glamm_collate_fn
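For orientation, these package-level exports would typically be imported from the dataset package when the repository root is on `PYTHONPATH` (an illustrative line, not part of the commit):

```python
# Assumes the repository root is on PYTHONPATH, e.g. training launched from the repo root.
from projects.glamm.datasets import ReferSegmDataset, glamm_collate_fn
```
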
projects/glamm/datasets/collate_fns/glamm_collate_fn.py
136 changes: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
    """Collate per-sample dicts into a padded batch for multimodal training.

    Pads ``input_ids``/``labels``, builds the attention mask and position
    ids, handles sequence parallelism, and passes through the optional
    image tensors, grounding images, masks, bboxes, and point prompts.
    """
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, ('Currently, it is not configured to '
                               'accommodate the use of varlen Attention in '
                               'multimodal training')

    if has_image:
        pixel_values = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])
        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # Grounding images may differ in shape, so they are kept as a list.
        # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
        #     grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
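A small illustrative example of how this collate function could be wired into a plain PyTorch `DataLoader` (not part of the commit; `train_dataset` stands in for any of the datasets exported above, and its items are assumed to be dicts carrying at least `input_ids` and `labels`):

```python
from functools import partial

from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,                                   # hypothetical dataset instance
    batch_size=4,
    shuffle=True,
    collate_fn=partial(glamm_collate_fn, return_hf_format=False),
)

batch = next(iter(loader))
print(batch['data']['input_ids'].shape)        # (batch_size, padded_seq_len)
print(batch['data']['attention_mask'].shape)   # True on real tokens, False on padding
```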