From 55a1c73dd1af44ee91ac7136015af7ec06641b43 Mon Sep 17 00:00:00 2001
From: je1lee
Date: Tue, 9 Apr 2024 05:43:47 +0000
Subject: [PATCH 1/3] fix: early stop when all sequences reach EOS

---
 gemma/model.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/gemma/model.py b/gemma/model.py
index 87280d2..bf57ef0 100644
--- a/gemma/model.py
+++ b/gemma/model.py
@@ -469,6 +469,7 @@ def generate(
 
         batch_size = len(prompts)
         prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
+        prompt_length = [len(p) for p in prompt_tokens]
         min_prompt_len = min(len(p) for p in prompt_tokens)
         max_prompt_len = max(len(p) for p in prompt_tokens)
         max_seq_len = max_prompt_len + output_len
@@ -511,6 +512,7 @@ def generate(
         top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
         output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
             device)
+        eos_flags_tensor = torch.tensor([False] * batch_size).to(device)
 
         # Prefill up to min_prompt_len tokens, then treat other prefill as
         # decode and ignore output.
@@ -543,6 +545,16 @@ def generate(
                 device)
             output_index = output_index + 1
 
+            # Check if all sequences have reached EOS.
+            batch_eos_idx = (next_token_ids == self.tokenizer.eos_id).nonzero(
+                as_tuple=True)[0]
+            for eos_idx in batch_eos_idx:
+                if output_index >= prompt_length[eos_idx]:
+                    eos_flags_tensor[eos_idx] = True
+	
+            if eos_flags_tensor.all():
+                break
+
         # Detokenization.
         token_ids = token_ids_tensor.tolist()
         results = []

From 488e5f24175c4fdaa87b8ba10419833d8fd6a431 Mon Sep 17 00:00:00 2001
From: je1lee
Date: Tue, 9 Apr 2024 06:57:41 +0000
Subject: [PATCH 2/3] style: tab in line

---
 gemma/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gemma/model.py b/gemma/model.py
index bf57ef0..6edbd52 100644
--- a/gemma/model.py
+++ b/gemma/model.py
@@ -551,7 +551,7 @@ def generate(
             for eos_idx in batch_eos_idx:
                 if output_index >= prompt_length[eos_idx]:
                     eos_flags_tensor[eos_idx] = True
-	
+
             if eos_flags_tensor.all():
                 break
 

From 815a0c987ebe615da69a290e862ae11696a1b180 Mon Sep 17 00:00:00 2001
From: je1lee
Date: Tue, 9 Apr 2024 06:58:29 +0000
Subject: [PATCH 3/3] fix: (xla) early stop when all sequences reach EOS

---
 scripts/run_xla.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/scripts/run_xla.py b/scripts/run_xla.py
index 4240fa7..d3beeb6 100644
--- a/scripts/run_xla.py
+++ b/scripts/run_xla.py
@@ -134,6 +134,7 @@ def generate(
     input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                         tokenizer.pad_id, dtype=torch.int64)
+    prompt_length = [len(p) for p in prompt_tokens]
     for i, p in enumerate(prompt_tokens):
         token_ids_tensor[i, :len(p)] = torch.tensor(p)
         input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
             p[:min_prompt_len])
@@ -152,9 +153,10 @@ def generate(
     top_ps_tensor = torch.FloatTensor(top_ps).to(device)
     top_ks_tensor = torch.LongTensor(top_ks).to(device)
     output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
+    eos_flags_tensor = torch.tensor([False] * batch_size).to(device)
+
     if not USE_CUDA:
         xm.mark_step()
-
     # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
     for i in range(max_seq_len - min_prompt_len):
         next_token_ids = model(
@@ -184,6 +186,16 @@ def generate(
         if not USE_CUDA:
             xm.mark_step()
 
+        # Check if all sequences have reached EOS.
+        batch_eos_idx = (next_token_ids == tokenizer.eos_id).nonzero(
+            as_tuple=True)[0]
+        for eos_idx in batch_eos_idx:
+            if output_index >= prompt_length[eos_idx]:
+                eos_flags_tensor[eos_idx] = True
+
+        if eos_flags_tensor.all():
+            break
+
     # Detokenization.
     token_ids = token_ids_tensor.tolist()
     results = []
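
Note (not part of the patches above): the early-stop check the patches add can be exercised on its own. The sketch below only mirrors the control flow of the new code; eos_id, prompt_length, output_index, and next_token_ids are made-up stand-in values rather than anything produced by the real model or tokenizer.

# Illustrative sketch only -- all values below are hypothetical.
import torch

eos_id = 1                                  # hypothetical EOS token id
prompt_length = [3, 5]                      # hypothetical per-sequence prompt lengths
batch_size = len(prompt_length)

eos_flags_tensor = torch.tensor([False] * batch_size)
output_index = torch.tensor(5, dtype=torch.int64)    # current decode position

# Pretend both sequences sampled EOS on this step.
next_token_ids = torch.tensor([eos_id, eos_id])

# Same control flow as the patch: flag a sequence only once decoding has
# moved past that sequence's own prompt.
batch_eos_idx = (next_token_ids == eos_id).nonzero(as_tuple=True)[0]
for eos_idx in batch_eos_idx:
    if output_index >= prompt_length[eos_idx]:
        eos_flags_tensor[eos_idx] = True

if eos_flags_tensor.all():
    print("all sequences reached EOS -> the decode loop breaks early")

The prompt_length guard matters because, while a sequence is still being force-fed its prompt, the sampled token is discarded in favour of the prompt token, so an EOS sampled there should not mark that sequence as finished.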