From 476ea6ff064e75b4ad96ff2fc57ecffde2bc1fb0 Mon Sep 17 00:00:00 2001
From: Gabe Grand
Date: Thu, 18 Jul 2024 15:37:24 -0400
Subject: [PATCH 1/4] Add LMContext.token_count

---
 hfppl/distributions/lmcontext.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py
index b2971fe..30f4d0b 100644
--- a/hfppl/distributions/lmcontext.py
+++ b/hfppl/distributions/lmcontext.py
@@ -126,6 +126,7 @@ def __init__(self, lm, prompt, temp=1.0):
         self.temp = temp
         self.model_mask = lm.masks.ALL_TOKENS
         self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
+        self.prompt_token_count = len(self.tokens)
         self.show_prompt = False
 
     def next_token(self):
@@ -148,6 +149,10 @@ def mask_dist(self, mask):
         """
         return LMTokenMask(self, mask)
 
+    @property
+    def token_count(self):
+        return len(self.tokens) - self.prompt_token_count
+
     def __str__(self):
         base = 0 if self.show_prompt else self.prompt_string_length
         full_string = self.lm.tokenizer.decode(self.tokens)

From e3b97d65452d3e68d4c3a9fb99c51fb1c132624a Mon Sep 17 00:00:00 2001
From: Gabe Grand
Date: Thu, 18 Jul 2024 15:39:32 -0400
Subject: [PATCH 2/4] Update Masks

---
 hfppl/llms.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/hfppl/llms.py b/hfppl/llms.py
index 1dcb334..5b56064 100644
--- a/hfppl/llms.py
+++ b/hfppl/llms.py
@@ -22,8 +22,37 @@ def __init__(self, lm):
             for (i, v) in enumerate(lm.vocab)
             if all(c in "'" or c.isalpha() for c in v)
         )
-        self.PUNCTUATION = set(i for (i, v) in enumerate(lm.vocab) if v in ',:;.!?"-')
-        self.END_SENTENCE_PUNCT = set(i for (i, v) in enumerate(lm.vocab) if v in ".!?")
+        self.MID_PUNCTUATION = set(
+            i for (i, v) in enumerate(lm.vocab) if v in (",", ":", ";", "-", '"')
+        )
+        self.END_PUNCTUATION = set(
+            i for (i, v) in enumerate(lm.vocab) if v in (".", "!", "?")
+        )
+        self.PUNCTUATION = self.MID_PUNCTUATION | self.END_PUNCTUATION
+        self.CONTAINS_WHITESPACE = set(
+            i
+            for (i, v) in enumerate(lm.vocab)
+            if any(c in string.whitespace for c in v)
+        )
+
+        self.MAX_TOKEN_LENGTH = self.precompute_token_length_masks(lm)
+
+    def precompute_token_length_masks(self, lm) -> Dict[int, Set[int]]:
+        """Precompute masks for tokens of different lengths.
+
+        Each mask is the set of token ids whose string form is at most the given length. EOS is special-cased: masks[0] contains only the EOS token id, and all other masks exclude it."""
+        max_token_length = max([len(t) for t in lm.vocab])
+
+        masks = defaultdict(lambda: self.ALL_TOKENS)  # lengths beyond the longest token fall back to ALL_TOKENS
+        masks[0] = set([lm.tokenizer.eos_token_id])
+        for token_length in range(1, max_token_length + 1):
+            masks[token_length] = set(
+                i
+                for (i, v) in enumerate(lm.vocab)
+                if len(v) <= token_length and i != lm.tokenizer.eos_token_id
+            )
+
+        return masks
 
 
 class TokenSequence:
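[Series note] Usage sketch for PATCH 1/4 and PATCH 2/4, illustrative only and
not part of the diffs. It assumes "gpt2" is available locally, and that
hfppl/llms.py already imports string, defaultdict, Dict, and Set (the imports
sit outside the hunk shown above):

    from hfppl import CachedCausalLM
    from hfppl.distributions.lmcontext import LMContext

    lm = CachedCausalLM.from_pretrained("gpt2")
    ctx = LMContext(lm, "The quick brown fox")

    # token_count counts generated tokens only, not the prompt.
    assert ctx.token_count == 0

    # MAX_TOKEN_LENGTH[k] holds the ids of non-EOS tokens whose string form
    # has at most k characters; MAX_TOKEN_LENGTH[0] is just {eos_token_id}.
    short_tokens = lm.masks.MAX_TOKEN_LENGTH[3]
    assert all(len(lm.vocab[i]) <= 3 for i in short_tokens)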
From 3195a80dee579f0315fcdd8d664be1b5f72d2c11 Mon Sep 17 00:00:00 2001
From: Gabe Grand
Date: Thu, 18 Jul 2024 15:40:28 -0400
Subject: [PATCH 3/4] Add sample_word_2()

---
 hfppl/chunks.py | 86 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/hfppl/chunks.py b/hfppl/chunks.py
index 663152c..88ee656 100644
--- a/hfppl/chunks.py
+++ b/hfppl/chunks.py
@@ -49,3 +49,89 @@ async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
             punctuation = context.lm.vocab[punctuation_token.token_id]
 
     return word, punctuation
+
+
+@submodel
+async def sample_word_2(
+    self,
+    context,
+    max_chars: int = None,
+    allow_mid_punctuation: bool = True,
+    allow_end_punctuation: bool = True,
+):
+    """Sample a word from the `LMContext` object `context`.
+
+    Unlike sample_word() above, this method gives character-level control over the length of the word.
+    It also controls whether punctuation may appear in the middle and at the end of the word.
+
+    Args:
+        max_chars (int, optional): Maximum number of characters in the word. If None, the word may be of any length.
+        allow_mid_punctuation (bool): If True, the model may sample punctuation in the middle of the word.
+        allow_end_punctuation (bool): If True, the model may sample punctuation at the end of the word.
+
+    Returns:
+        Tuple[str, str]: The sampled word and its trailing punctuation (an empty string if none).
+    """
+
+    # This approach sometimes breaks with max_chars = 1, so require at least 2
+    if max_chars is not None:
+        assert max_chars > 1
+
+    last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
+    last_character = last_token[-1] if len(last_token) > 0 else ""
+    needs_space = last_character not in string.whitespace and last_character not in [
+        "-",
+        "'",
+        '"',
+    ]
+    if needs_space:
+        starts_word_mask = context.lm.masks.STARTS_NEW_WORD
+    else:
+        starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD
+
+    # Force the model to start a new word (or continue the current one if no space is needed)
+    await self.observe(context.mask_dist(starts_word_mask), True)
+
+    word = ""
+    while True:
+        # Force the model to sample a token with an appropriate number of characters
+        if max_chars is not None:
+            await self.observe(
+                context.mask_dist(
+                    context.lm.masks.MAX_TOKEN_LENGTH[max_chars - len(word.strip())]
+                ),
+                True,
+            )
+
+        token = await self.sample(context.next_token())
+        word += context.lm.vocab[token.token_id]
+
+        # If we ran out of chars, break
+        if max_chars is not None and len(word.strip()) >= max_chars:
+            await self.observe(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
+            )
+            break
+
+        # If the model wants to end the word, break
+        if not (
+            await self.sample(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
+            )
+        ):
+            break
+
+    # Sample punctuation, if desired
+    punctuation = ""
+
+    mask = set()
+    if allow_mid_punctuation:
+        mask = mask | context.lm.masks.MID_PUNCTUATION
+    if allow_end_punctuation:
+        mask = mask | context.lm.masks.END_PUNCTUATION
+
+    if mask and await self.sample(context.mask_dist(mask)):
+        punctuation_token = await self.sample(context.next_token())
+        punctuation = context.lm.vocab[punctuation_token.token_id]
+
+    return word, punctuation
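[Series note] Usage sketch for PATCH 3/4, illustrative only; the model class,
prompt, and bounds below are made up for the example. Inside a hfppl Model,
the submodel is invoked with self.call():

    from hfppl import CachedCausalLM, LMContext, Model
    from hfppl.chunks import sample_word_2

    class ConstrainedWriter(Model):
        def __init__(self, lm):
            super().__init__()
            self.context = LMContext(lm, "Here is a story: ")

        async def step(self):
            # Words of at most 6 characters; no mid-word punctuation, but a
            # sentence-ending mark is allowed so the model can stop.
            word, punct = await self.call(
                sample_word_2(
                    self.context,
                    max_chars=6,
                    allow_mid_punctuation=False,
                    allow_end_punctuation=True,
                )
            )
            if punct in (".", "!", "?"):
                self.finish()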
From ebaed8d4b214d41ac394dbe917f816b21d1d63d0 Mon Sep 17 00:00:00 2001
From: Gabe Grand
Date: Thu, 18 Jul 2024 15:42:15 -0400
Subject: [PATCH 4/4] Update LMContext

---
 hfppl/distributions/lmcontext.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py
index 30f4d0b..16b8247 100644
--- a/hfppl/distributions/lmcontext.py
+++ b/hfppl/distributions/lmcontext.py
@@ -110,7 +110,7 @@ class LMContext:
         show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`.
     """
 
-    def __init__(self, lm, prompt, temp=1.0):
+    def __init__(self, lm, prompt, temp=1.0, show_prompt=False, show_eos=True):
         """Create a new `LMContext` with a given prompt and temperature.
 
         Args:
@@ -127,7 +127,8 @@ def __init__(self, lm, prompt, temp=1.0):
         self.model_mask = lm.masks.ALL_TOKENS
         self.prompt_string_length = len(lm.tokenizer.decode(self.tokens))
         self.prompt_token_count = len(self.tokens)
-        self.show_prompt = False
+        self.show_prompt = show_prompt
+        self.show_eos = show_eos
 
     def next_token(self):
         """Distribution over the next token.
@@ -154,9 +155,12 @@ def token_count(self):
         return len(self.tokens) - self.prompt_token_count
 
     def __str__(self):
-        base = 0 if self.show_prompt else self.prompt_string_length
         full_string = self.lm.tokenizer.decode(self.tokens)
-        return full_string[base:]
+        if not self.show_prompt:
+            full_string = full_string[self.prompt_string_length:]
+        if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
+            full_string = full_string[:-len(self.lm.tokenizer.eos_token)]
+        return full_string
 
     def __deepcopy__(self, memo):
         cpy = type(self).__new__(type(self))
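[Series note] Usage sketch for PATCH 4/4, illustrative only; the model name
and prompt are made up:

    from hfppl import CachedCausalLM
    from hfppl.distributions.lmcontext import LMContext

    lm = CachedCausalLM.from_pretrained("gpt2")

    # show_prompt=True includes the prompt in str(ctx); show_eos=False strips
    # a trailing EOS token from the string once generation has ended.
    ctx = LMContext(lm, "Once upon a time", show_prompt=True, show_eos=False)
    print(str(ctx))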