diff --git a/open_flamingo/eval/eval_models/eval_model.py b/open_flamingo/eval/eval_models/eval_model.py index ac248a0a..43126326 100644 --- a/open_flamingo/eval/eval_models/eval_model.py +++ b/open_flamingo/eval/eval_models/eval_model.py @@ -195,7 +195,7 @@ def supported_tasks(self): Parsed by checking whether the model has a method called `get_{task}_prompt`. """ return [ - task.split("_")[1] + "_".join(task.split("_")[1:-1]) for task in dir(self) if task.startswith("get_") and task.endswith("_prompt") ] @@ -207,4 +207,4 @@ def _validate_text(self, batch_text): if any([x.endswith(" ") for x in batch_text]): print( "Warning: trailing whitespace detected in text. This can cause unexpected behavior." - ) \ No newline at end of file + ) diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index 8f3632ec..8b420f72 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -534,7 +534,7 @@ def main(): if args.eval_coco: eval_dataset( args, - dataset_name="flickr30", + dataset_name="coco", eval_model=eval_model, results=results, eval_fn=evaluate_captioning, @@ -1302,4 +1302,4 @@ def evaluate_classification( if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py index dfaa7716..5176d4d7 100644 --- a/open_flamingo/src/vlm.py +++ b/open_flamingo/src/vlm.py @@ -415,13 +415,14 @@ def _prepare_inputs_for_forward( past_vision_tokens=past_vision_tokens, num_beams=num_beams, ) - past_key_values = [ - ( - k.repeat_interleave(num_beams, dim=0), - v.repeat_interleave(num_beams, dim=0) - ) - for k, v in past_key_values - ] if past_key_values is not None else None + if past_key_values is not None: + past_key_values = [ + ( + k.repeat_interleave(num_beams, dim=0), + v.repeat_interleave(num_beams, dim=0) + ) + for k, v in past_key_values + return { "input_ids": lang_x, "attention_mask": attention_mask,