Add files via upload
zhangbaijin authored Jan 10, 2025
1 parent 38f4ee4 commit 21120d5
Showing 3 changed files with 122 additions and 29 deletions.
minigpt4/models/llava.py (133 changes: 109 additions & 24 deletions)
@@ -132,7 +132,7 @@ def forward(self, samples):

return {"loss": loss}

@torch.no_grad()
# @torch.no_grad()
def generate(
self,
samples,
@@ -147,14 +147,13 @@ def generate(
num_captions=1,
temperature=1,
output_attentions=False,
return_dict_in_generate=False,
# ours
opera_decoding=False,
key_position=None,
scale_factor=1.0,
threshold=1,
num_attn_candidates=5,
penalty_weights=1.0,
# opera_decoding=False,
# key_position=None,
# scale_factor=1.0,
# threshold=1,
# num_attn_candidates=5,
# penalty_weights=1.0,
):
self.llama_tokenizer.padding_side = "left"

@@ -202,12 +201,12 @@ def generate(
with self.maybe_autocast():
input_ids = torch.cat([bos, tokens_before, image_token, tokens_after], dim=1)

if key_position is None:
key_position = {
"image_start": tokens_before.shape[1]+1,
"image_end": tokens_before.shape[1]+NUM_IMAGE_TOKENS,
"response_start": input_ids.shape[1]+NUM_IMAGE_TOKENS-1,
}
# if key_position is None:
# key_position = {
# "image_start": tokens_before.shape[1]+1,
# "image_end": tokens_before.shape[1]+NUM_IMAGE_TOKENS,
# "response_start": input_ids.shape[1]+NUM_IMAGE_TOKENS-1,
# }

output_ids = self.llama_model.generate(
input_ids=input_ids,
@@ -226,24 +225,110 @@ def generate(
# num_return_sequences=num_captions,
images=image,
output_attentions=output_attentions,
return_dict_in_generate=return_dict_in_generate,
# opera
opera_decoding=opera_decoding,
key_position=key_position,
scale_factor=scale_factor,
threshold=threshold,
num_attn_candidates=num_attn_candidates,
penalty_weights=penalty_weights,
# opera_decoding=opera_decoding,
# key_position=key_position,
# scale_factor=scale_factor,
# threshold=threshold,
# num_attn_candidates=num_attn_candidates,
# penalty_weights=penalty_weights,
)

input_token_len = input_ids.shape[1]

n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()

if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
output_text = self.llama_tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
output_text = [text.split('###')[0].strip() for text in output_text]
return output_text
return (output_text, input_token_len, output_ids)

#return output_ids

def generate_output(
self,
samples,
use_nucleus_sampling=False,
num_beams=5,
max_length=256,
min_length=1,
max_new_tokens=300,
top_p=0.9,
repetition_penalty=1.0,
length_penalty=1,
num_captions=1,
temperature=1,
output_attentions=False,
# ours
opera_decoding=False,
key_position=None,
scale_factor=1.0,
threshold=1,
num_attn_candidates=5,
penalty_weights=1.0,
):
self.llama_tokenizer.padding_side = "left"

image = samples["image"]

instruction = samples["prompt"] if "prompt" in samples else None

bs = image.size(0)

if isinstance(instruction, str):
instruction = [instruction] * bs
else:
assert len(instruction) == bs, "The number of prompts must be equal to the batch size."

instruction = [self.system_message + p for p in instruction]

chunks_before, chunks_after = [], []
for p in instruction:
chunk_before, chunk_after = p.split('<ImageHere>')
chunks_before.append(chunk_before)
chunks_after.append(chunk_after)

tokens_before = self.llama_tokenizer(
chunks_before,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(image.device).input_ids

tokens_after = self.llama_tokenizer(
chunks_after,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(image.device).input_ids

bos = torch.ones([bs, 1],
dtype=torch.int64,
device=image.device) * self.llama_tokenizer.bos_token_id

image_token = torch.ones([bs, 1],
dtype=torch.int64,
device=image.device) * IMAGE_TOKEN_INDEX

with self.maybe_autocast():
input_ids = torch.cat([bos, tokens_before, image_token, tokens_after], dim=1)

if key_position is None:
key_position = {
"image_start": tokens_before.shape[1]+1,
"image_end": tokens_before.shape[1]+NUM_IMAGE_TOKENS,
"response_start": input_ids.shape[1]+NUM_IMAGE_TOKENS-1,
}

output_ids = self.llama_model(
input_ids=input_ids,
use_cache=True,
images=image,
output_attentions=output_attentions,
)

return output_ids

def embed_tokens(self, token_ids):
if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
@@ -294,4 +379,4 @@ def from_config(cls, cfg):
system_message=system_message,
)

return model
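With the @torch.no_grad() decorator commented out and generate() now returning a tuple rather than a plain list of strings, callers have to unpack (output_text, input_token_len, output_ids). Below is a minimal usage sketch; it assumes an already-instantiated model plus a preprocessed image tensor and prompt string, none of which are defined in this commit.

# Hypothetical caller; `model`, `image`, and `prompt` are assumptions, not part of this commit.
output_text, input_token_len, output_ids = model.generate(
    {"image": image, "prompt": prompt},
    num_beams=5,
    max_new_tokens=300,
)
# output_ids still contains the prompt tokens; the generated continuation
# starts at input_token_len, mirroring the slicing done inside generate().
generated_ids = output_ids[:, input_token_len:]
print(output_text[0])

The new generate_output() helper, by contrast, calls self.llama_model directly (a single forward pass) and returns its raw output instead of decoded text.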
minigpt4/models/llava_arch.py (12 changes: 9 additions & 3 deletions)
@@ -54,7 +54,7 @@ def __init__(self, vision_tower, args, delay_load=False):
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
# self.vision_tower.requires_grad_(False)

self.is_loaded = True

@@ -68,7 +68,7 @@ def feature_select(self, image_forward_outs):
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features

@torch.no_grad()
# @torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
@@ -238,6 +238,7 @@ def get_vision_tower(self):
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
image_features.retain_grad()
return image_features

def prepare_inputs_labels_for_multimodal(
@@ -257,7 +258,7 @@ def prepare_inputs_labels_for_multimodal(
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images)

image_features.retain_grad()
new_input_embeds = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
@@ -267,6 +268,7 @@
# FIXME: this is a hacky fix, for deepspeed zero3 to work
half_len = cur_input_ids.shape[0] // 2
cur_image_features = image_features[cur_image_idx]
cur_image_features.retain_grad()
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
@@ -283,6 +285,7 @@
assert cur_labels.shape == cur_input_ids.shape
while image_token_indices.numel() > 0:
cur_image_features = image_features[cur_image_idx]
cur_image_features.retain_grad()
image_token_start = image_token_indices[0]
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
@@ -297,6 +300,7 @@
else:
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
cur_new_input_embeds[0].retain_grad()
if labels is not None:
cur_new_labels.append(cur_labels[:image_token_start])
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
@@ -329,6 +333,7 @@
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
new_input_embeds_align.append(cur_new_embed)
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
new_input_embeds.retain_grad()

if labels is not None:
new_labels_align = []
@@ -349,6 +354,7 @@
assert attention_mask.shape == new_labels.shape
else:
new_input_embeds = torch.stack(new_input_embeds, dim=0)
new_input_embeds.retain_grad()
if labels is not None:
new_labels = torch.stack(new_labels, dim=0)

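The retain_grad() calls added throughout this file keep gradients on non-leaf tensors (the projected image features and the assembled input embeddings) so that .grad can be inspected after backward(); by default PyTorch only populates .grad for leaf tensors. A small, self-contained illustration of the mechanism, not code from this repository:

import torch

x = torch.randn(4, requires_grad=True)   # leaf tensor
y = x * 2                                 # non-leaf tensor produced by an op
y.retain_grad()                           # ask autograd to keep y.grad

y.sum().backward()
print(x.grad)   # populated for the leaf as usual
print(y.grad)   # tensor of ones; would be None without retain_grad()

Retained gradients only appear when gradients actually flow, which is presumably why the same commit also comments out @torch.no_grad() and the requires_grad_(False) call on the vision tower.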
minigpt4/models/llava_llama.py (6 changes: 4 additions & 2 deletions)
@@ -88,6 +88,8 @@ def forward(

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
hidden_states.retain_grad()
logits.retain_grad()

loss = None
if labels is not None:
@@ -96,8 +98,8 @@ def forward(
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
shift_labels = shift_labels.reshape(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
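The change from .view() to .reshape() in the loss computation is a safety fix rather than a behavioral one: .view() requires the tensor's memory layout to be contiguous and raises otherwise, while .reshape() returns a view when possible and silently copies when it is not. A standalone illustration, not code from this repository:

import torch

t = torch.arange(6).reshape(2, 3).t()   # transpose -> non-contiguous strides
print(t.is_contiguous())                 # False

print(t.reshape(-1))                     # works: copies because a view is impossible
try:
    t.view(-1)                           # raises RuntimeError on this layout
except RuntimeError as err:
    print("view() failed:", err)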
