diff --git a/examples/grammar_constraint.py b/examples/grammar_constraint.py index d3f28ba..965d3b3 100644 --- a/examples/grammar_constraint.py +++ b/examples/grammar_constraint.py @@ -9,6 +9,7 @@ Requires synchromesh (github.com/kanishkg/synchromesh) """ + import asyncio import os from typing import List diff --git a/examples/haiku.py b/examples/haiku.py index c9665cb..91016be 100644 --- a/examples/haiku.py +++ b/examples/haiku.py @@ -4,26 +4,37 @@ import os # download the CMU pronunciation dictionary (if we haven't already) -nltk.download('cmudict') +nltk.download("cmudict") # Load the CMU pronunciation dictionary and use it for syllable counting from nltk.corpus import cmudict + CMUDICT = cmudict.dict() + def count_syllables(word, unknown_word_syllables=100): - + # Use the dictionary to get the list of possible phonetic representations for the word phonetic_transcriptions = CMUDICT.get(word.strip().lower(), []) - + # Count the number of syllables based on the number of phonetic transcriptions - syllable_count = min([len([ph for ph in transcription if ph[-1].isdigit()]) for transcription in phonetic_transcriptions], default=unknown_word_syllables) + syllable_count = min( + [ + len([ph for ph in transcription if ph[-1].isdigit()]) + for transcription in phonetic_transcriptions + ], + default=unknown_word_syllables, + ) return syllable_count + # Load the language model (llama2 if authorized, else mistral-7b). -if 'HF_AUTH_TOKEN' in os.environ: - HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN'] - LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN) +if "HF_AUTH_TOKEN" in os.environ: + HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"] + LLM = CachedCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN + ) else: LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") @@ -74,21 +85,22 @@ def count_syllables(word, unknown_word_syllables=100): # Useful constants NEWLINE_TOKEN, EOS_TOKEN = 13, LLM.tokenizer.eos_token_id + # LLaMPPL model class Haiku(Model): - + def __init__(self, prompt, syllable_pattern=[5, 7, 5]): super().__init__() self.context = LMContext(LLM, prompt, 0.7) self.syllable_pattern = syllable_pattern - + async def step(self): # Get the number of syllables required in the next line syllables_remaining = self.syllable_pattern.pop(0) - + # Loop to sample words until this line is over while syllables_remaining > 0: - + # Sample a word word, punctuation = await self.call(sample_word(self.context)) @@ -103,18 +115,19 @@ async def step(self): await self.observe(self.context.next_token(), EOS_TOKEN) self.finish() return - + # Otherwise, observe a line break await self.observe(self.context.next_token(), NEWLINE_TOKEN) # Print current result print(str(self.context)) + # Run inference -SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune +SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune particles = asyncio.run(smc_standard(Haiku(poem_prompt, SYLLABLES_PER_LINE), 120)) print("--------") -for (i,particle) in enumerate(particles): +for i, particle in enumerate(particles): print(f"Poem {i} (weight {particle.weight}):") - print(f"{particle.context}") \ No newline at end of file + print(f"{particle.context}") diff --git a/examples/hard_constraints.py b/examples/hard_constraints.py index 1feb848..75fce14 100644 --- a/examples/hard_constraints.py +++ b/examples/hard_constraints.py @@ -4,22 +4,30 @@ import os -if 'HF_AUTH_TOKEN' in os.environ: - HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN'] +if "HF_AUTH_TOKEN" in os.environ: + HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"] -# Load the language model. +# Load the language model. # Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 2, # pass your HuggingFace API key as the optional `auth_token` argument: # LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN) -# LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5") -LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") +LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5") +# LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") LLM.batch_size = 40 -MASKS = {i : set(j for (j,v) in enumerate(LLM.vocab) - if j != LLM.tokenizer.eos_token_id and '\n' not in v and - any(c.isalpha() or c in string.punctuation for c in v) and - len(v.strip()) <= 5 and (not v[0].isalpha() or i+len(v) <= 5)) - for i in range(6)} +MASKS = { + i: set( + j + for (j, v) in enumerate(LLM.vocab) + if j != LLM.tokenizer.eos_token_id + and "\n" not in v + and any(c.isalpha() or c in string.punctuation for c in v) + and len(v.strip()) <= 5 + and (not v[0].isalpha() or i + len(v) <= 5) + ) + for i in range(6) +} + class ConstraintModel(Model): def __init__(self, prompt, max_tokens): @@ -33,26 +41,27 @@ async def step(self): # Condition on next token being from mask await self.observe(self.context.mask_dist(mask), True) - + # Generate proposed token. token = await self.sample(self.context.next_token()) - + # Reduce number of max tokens remaining self.max_tokens -= 1 - + print(f"{self.context}") # Check if done if token == LLM.tokenizer.eos_token_id or self.max_tokens == 0: self.finish() return - + def active_constraint_mask(self): string_so_far = str(self.context) words = string_so_far.split() last_word = words[-1] if len(words) > 0 else "" return MASKS[min(5, len(last_word))] - + + # From Politico.com prompt = """3 things to watch … @@ -64,10 +73,12 @@ def active_constraint_mask(self): LLM.cache_kv(LLM.tokenizer.encode(prompt)) + async def main(): constraint_model = ConstraintModel(prompt, 50) particles = await smc_standard(constraint_model, 40) for p in particles: print(f"{p.context}") -asyncio.run(main()) \ No newline at end of file + +asyncio.run(main()) diff --git a/hfppl/__init__.py b/hfppl/__init__.py index aa8da01..ec1651a 100644 --- a/hfppl/__init__.py +++ b/hfppl/__init__.py @@ -6,4 +6,4 @@ from .distributions import * from .modeling import * from .inference import * -from .chunks import * \ No newline at end of file +from .chunks import * diff --git a/hfppl/chunks.py b/hfppl/chunks.py index 1635702..663152c 100644 --- a/hfppl/chunks.py +++ b/hfppl/chunks.py @@ -1,38 +1,51 @@ import string from .modeling import submodel + @submodel async def sample_word(self, context, max_tokens=5, allow_punctuation=True): """Sample a word from the `LMContext` object `context`.""" last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else "" last_character = last_token[-1] if len(last_token) > 0 else "" - needs_space = last_character not in string.whitespace and last_character not in ['-', "'", '"'] + needs_space = last_character not in string.whitespace and last_character not in [ + "-", + "'", + '"', + ] if needs_space: starts_word_mask = context.lm.masks.STARTS_NEW_WORD else: starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD - + # Force model to start a new word await self.observe(context.mask_dist(starts_word_mask), True) word = "" num_tokens = 0 while True: - token = await self.sample(context.next_token()) - word += context.lm.vocab[token.token_id] + token = await self.sample(context.next_token()) + word += context.lm.vocab[token.token_id] num_tokens += 1 if num_tokens == max_tokens: - await self.observe(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False) + await self.observe( + context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False + ) break - if not (await self.sample(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD))): + if not ( + await self.sample( + context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD) + ) + ): break - + # Sample punctuation, if desired punctuation = "" - if allow_punctuation and await self.sample(context.mask_dist(context.lm.masks.PUNCTUATION)): + if allow_punctuation and await self.sample( + context.mask_dist(context.lm.masks.PUNCTUATION) + ): punctuation_token = await self.sample(context.next_token()) punctuation = context.lm.vocab[punctuation_token.token_id] - return word, punctuation \ No newline at end of file + return word, punctuation diff --git a/hfppl/distributions/__init__.py b/hfppl/distributions/__init__.py index d80078d..f87cdf0 100644 --- a/hfppl/distributions/__init__.py +++ b/hfppl/distributions/__init__.py @@ -17,4 +17,4 @@ from .tokencategorical import TokenCategorical from .transformer import Transformer from .lmcontext import LMContext -from .bernoulli import Bernoulli \ No newline at end of file +from .bernoulli import Bernoulli diff --git a/hfppl/distributions/bernoulli.py b/hfppl/distributions/bernoulli.py index 0d9028a..fed945b 100644 --- a/hfppl/distributions/bernoulli.py +++ b/hfppl/distributions/bernoulli.py @@ -2,13 +2,13 @@ import numpy as np + class Bernoulli(Distribution): - """A Bernoulli distribution. - """ - + """A Bernoulli distribution.""" + def __init__(self, p): """Create a Bernoulli distribution. - + Args: p: the probability-of-True for the Bernoulli distribution. """ @@ -20,6 +20,6 @@ async def sample(self): async def log_prob(self, value): return np.log(self.p) if value else np.log1p(-self.p) - + async def argmax(self, idx): - return ((self.p > 0.5) if idx == 0 else (self.p < 0.5)) \ No newline at end of file + return (self.p > 0.5) if idx == 0 else (self.p < 0.5) diff --git a/hfppl/distributions/distribution.py b/hfppl/distributions/distribution.py index 5296da7..e063b96 100644 --- a/hfppl/distributions/distribution.py +++ b/hfppl/distributions/distribution.py @@ -1,29 +1,28 @@ class Distribution: """Abstract base class for a distribution.""" - async def sample(self): """Generate a random sample from the distribution. - + Returns: x: a value randomly sampled from the distribution.""" raise NotImplementedError() - + async def log_prob(self, x): """Compute the log probability of a value under this distribution, or the log probability density if the distribution is continuous. - + Args: x: the point at which to evaluate the log probability. Returns: - logprob (float): the log probability of `x`.""" + logprob (float): the log probability of `x`.""" raise NotImplementedError() - + async def argmax(self, n): """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution). - + Args: n (int): which value to return to, indexed from most probable (n=0) to least probable (n=|support|). Returns: x: the nth most probable outcome from this distribution.""" - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/hfppl/distributions/geometric.py b/hfppl/distributions/geometric.py index aa2be13..6ebf5e8 100644 --- a/hfppl/distributions/geometric.py +++ b/hfppl/distributions/geometric.py @@ -1,12 +1,13 @@ from .distribution import Distribution +import numpy as np + class Geometric(Distribution): - """A Geometric distribution. - """ - + """A Geometric distribution.""" + def __init__(self, p): """Create a Geometric distribution. - + Args: p: the rate of the Geometric distribution. """ @@ -17,7 +18,7 @@ async def sample(self): return n, await self.log_prob(n) async def log_prob(self, value): - return np.log(self.p) + np.log(1 - self.p)*(value - 1) - + return np.log(self.p) + np.log(1 - self.p) * (value - 1) + async def argmax(self, idx): - return idx - 1 # Most likely outcome is 0, then 1, etc. \ No newline at end of file + return idx - 1 # Most likely outcome is 0, then 1, etc. diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py index 8400066..b2971fe 100644 --- a/hfppl/distributions/lmcontext.py +++ b/hfppl/distributions/lmcontext.py @@ -4,23 +4,24 @@ import numpy as np import copy + class LMNextToken(Distribution): - + def __init__(self, ctx): self.ctx = ctx - + async def log_prob(self, x): if isinstance(x, Token): x = x.token_id - + lp = self.ctx.next_token_logprobs[x] self.ctx.tokens.append(x) updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS - + return lp - + async def sample(self): probs = np.exp(self.ctx.next_token_logprobs) token_id = np.random.choice(len(probs), p=(probs)) @@ -28,71 +29,78 @@ async def sample(self): logprob = self.ctx.next_token_logprobs[token_id] # Reset mask and update logprobs - self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS - updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) + self.ctx.model_mask = self.ctx.lm.masks.ALL_TOKENS + updated_logprobs = await self.ctx.lm.next_token_logprobs(self.ctx.tokens) self.ctx.next_token_logprobs = log_softmax(updated_logprobs / self.ctx.temp) - t = Token(self.ctx.lm, token_id, self.ctx.lm.tokenizer.convert_ids_to_tokens(token_id)) + t = Token( + self.ctx.lm, token_id, self.ctx.lm.tokenizer.convert_ids_to_tokens(token_id) + ) return t, logprob - + + class LMTokenMask(Distribution): def __init__(self, ctx, mask): - self.ctx = ctx + self.ctx = ctx self.mask = mask - + async def sample(self): - newly_bad_tokens = [i for i in self.ctx.model_mask if i not in self.mask] - good_tokens = [i for i in self.ctx.model_mask if i in self.mask] - logprob_no_mask = logsumexp(self.ctx.next_token_logprobs[newly_bad_tokens]) + newly_bad_tokens = [i for i in self.ctx.model_mask if i not in self.mask] + good_tokens = [i for i in self.ctx.model_mask if i in self.mask] + logprob_no_mask = logsumexp(self.ctx.next_token_logprobs[newly_bad_tokens]) if logprob_no_mask > 0: - logprob_yes_mask = float('-inf') + logprob_yes_mask = float("-inf") else: # When logprob_no_mask is very close to 0.0, np.log1p can raise a "divide by zero" # warning before returning -inf. We suppress this warning, because returning -inf # is the desired behavior (the LLM places no mass on 'yes'). - with np.errstate(divide='ignore'): - logprob_yes_mask = np.log1p(-np.exp(logprob_no_mask)) - decide_no_mask = np.random.rand() < np.exp(logprob_no_mask) + with np.errstate(divide="ignore"): + logprob_yes_mask = np.log1p(-np.exp(logprob_no_mask)) + decide_no_mask = np.random.rand() < np.exp(logprob_no_mask) if decide_no_mask: self.ctx.model_mask = self.ctx.model_mask - self.mask - self.ctx.next_token_logprobs[good_tokens] = float('-inf') + self.ctx.next_token_logprobs[good_tokens] = float("-inf") self.ctx.next_token_logprobs -= logprob_no_mask return False, logprob_no_mask else: self.ctx.model_mask = self.ctx.model_mask.intersection(self.mask) - self.ctx.next_token_logprobs[newly_bad_tokens] = float('-inf') + self.ctx.next_token_logprobs[newly_bad_tokens] = float("-inf") self.ctx.next_token_logprobs -= logprob_yes_mask return True, logprob_yes_mask - + async def log_prob(self, v): - good_tokens = self.ctx.model_mask.intersection(self.mask) if v else self.ctx.model_mask - self.mask - bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] + good_tokens = ( + self.ctx.model_mask.intersection(self.mask) + if v + else self.ctx.model_mask - self.mask + ) + bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) - self.ctx.next_token_logprobs[bad_tokens] = float('-inf') + self.ctx.next_token_logprobs[bad_tokens] = float("-inf") self.ctx.next_token_logprobs -= logprob_good self.ctx.model_mask = good_tokens return logprob_good - - + + class LMContext: """Represents a generation-in-progress from a language model. - + The state tracks two pieces of information: - + * A sequence of tokens — the ever-growing context for the language model. * A *current mask* — a set of tokens that have not yet been ruled out as the next token. - + Storing a mask enables _sub-token_ generation: models can use `LMContext` to sample the next token in _stages_, first deciding, e.g., whether to use an upper-case or lower-case first letter, and only later deciding which upper-case or lower-case token to generate. - + The state of a `LMContext` can be advanced in two ways: - + 1. Sampling, observing, or intervening the `next_token()` distribution. This causes a token to be added to the growing sequence of tokens. Supports auto-batching. 2. Sampling, observing, or intervening the `mask_dist(mask)` distribution for a given mask (set of token ids). This changes the current mask. - + Attributes: lm (hfppl.llms.CachedCausalLM): the language model for which this is a context tokens (list[int]): the underlying sequence of tokens, including prompt, in this context @@ -101,52 +109,57 @@ class LMContext: model_mask (set[int]): set of tokens that have not been ruled out as the next token. This mask is managed by the `LMContext` object internally; do not mutate. show_prompt (bool): controls whether the string representation of this `LMContext` includes the initial prompt or not. Defaults to `False`. """ - + def __init__(self, lm, prompt, temp=1.0): """Create a new `LMContext` with a given prompt and temperature. - + Args: lm (hfppl.llms.CachedCausalLM): the language model for which this is a context. prompt (str): a string with which to initialize the context. Will be tokenized using `lm.tokenizer`. - temp (float): temeprature for next-token distribution (0 < temp < float('inf'))""" - self.lm = lm - self.tokens = lm.tokenizer.encode(prompt) - self.next_token_logprobs = log_softmax(lm.next_token_logprobs_unbatched(self.tokens) / temp) - self.temp = temp - self.model_mask = lm.masks.ALL_TOKENS + temp (float): temeprature for next-token distribution (0 < temp < float('inf')) + """ + self.lm = lm + self.tokens = lm.tokenizer.encode(prompt) + self.next_token_logprobs = log_softmax( + lm.next_token_logprobs_unbatched(self.tokens) / temp + ) + self.temp = temp + self.model_mask = lm.masks.ALL_TOKENS self.prompt_string_length = len(lm.tokenizer.decode(self.tokens)) - self.show_prompt = False - + self.show_prompt = False + def next_token(self): """Distribution over the next token. - - Sampling or observing from this distribution advances the state of this `LMContext` instance.""" + + Sampling or observing from this distribution advances the state of this `LMContext` instance. + """ return LMNextToken(self) - + def mask_dist(self, mask): """Bernoulli distribution, with probability of True equal to the probability that the next token of this `LMContext` belongs to the given mask. - + Sampling or observing from this distribution modifies the state of this `LMContext` instance, so that the `next_token()` distribution either *will* (if True) or *will not* (if False) generate a token from the given mask. - + Args: - mask: a `set(int)` specifying which token ids are included within the mask.""" - return LMTokenMask(self, mask) - + mask: a `set(int)` specifying which token ids are included within the mask. + """ + return LMTokenMask(self, mask) + def __str__(self): base = 0 if self.show_prompt else self.prompt_string_length full_string = self.lm.tokenizer.decode(self.tokens) return full_string[base:] - - def __deepcopy__(self, memo): + + def __deepcopy__(self, memo): cpy = type(self).__new__(type(self)) - + for k, v in self.__dict__.items(): - if k in set(['lm']): + if k in set(["lm"]): setattr(cpy, k, v) else: setattr(cpy, k, copy.deepcopy(v, memo)) - - return cpy \ No newline at end of file + + return cpy diff --git a/hfppl/distributions/logcategorical.py b/hfppl/distributions/logcategorical.py index 5dcd16c..3756eec 100644 --- a/hfppl/distributions/logcategorical.py +++ b/hfppl/distributions/logcategorical.py @@ -1,13 +1,14 @@ from .distribution import Distribution + class LogCategorical(Distribution): """A Geometric distribution.""" def __init__(self, logits): - """Create a Categorical distribution from unnormalized log probabilities (logits). + """Create a Categorical distribution from unnormalized log probabilities (logits). Given an array of logits, takes their `softmax` and samples an integer in `range(len(logits))` from the resulting categorical. - + Args: logits (np.array): a numpy array of unnormalized log probabilities. """ @@ -19,6 +20,6 @@ async def sample(self): async def log_prob(self, value): return self.log_probs[value] - + async def argmax(self, idx): - return np.argsort(self.log_probs)[-idx] \ No newline at end of file + return np.argsort(self.log_probs)[-idx] diff --git a/hfppl/distributions/tokencategorical.py b/hfppl/distributions/tokencategorical.py index e3fe275..7f109f8 100644 --- a/hfppl/distributions/tokencategorical.py +++ b/hfppl/distributions/tokencategorical.py @@ -3,29 +3,38 @@ from ..llms import Token import numpy as np + class TokenCategorical(Distribution): - def __init__(self, lm, logits): - """Create a Categorical distribution whose values are Tokens, not integers. - Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`), + def __init__(self, lm, logits): + """Create a Categorical distribution whose values are Tokens, not integers. + Given a language model `lm` and an array of unnormalized log probabilities (of length `len(lm.vocab)`), uses softmax to normalize them and samples a Token from the resulting categorical. - + Args: lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary is to be generated from. logits (np.array): a numpy array of unnormalized log probabilities. """ - self.lm = lm + self.lm = lm self.log_probs = log_softmax(logits) if self.lm.tokenizer.vocab_size != len(logits): - raise RuntimeError(f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits.") + raise RuntimeError( + f"TokenCategorical: vocab size is {self.lm.tokenizer.vocab_size} but provided {len(logits)} logits." + ) async def sample(self): n = np.random.choice(len(self.log_probs), p=(np.exp(self.log_probs))) - return Token(self.lm, n, self.lm.tokenizer.convert_ids_to_tokens(n)), self.log_probs[n] + return ( + Token(self.lm, n, self.lm.tokenizer.convert_ids_to_tokens(n)), + self.log_probs[n], + ) async def log_prob(self, value): return self.log_probs[value.token_id] - + async def argmax(self, idx): tok = torch.argsort(self.log_probs)[-idx] - return Token(self.lm, tok, self.lm.tokenizer.convert_ids_to_tokens(tok)), self.log_probs[tok] \ No newline at end of file + return ( + Token(self.lm, tok, self.lm.tokenizer.convert_ids_to_tokens(tok)), + self.log_probs[tok], + ) diff --git a/hfppl/distributions/transformer.py b/hfppl/distributions/transformer.py index 32d05db..9fd834f 100644 --- a/hfppl/distributions/transformer.py +++ b/hfppl/distributions/transformer.py @@ -2,13 +2,14 @@ from ..llms import TokenSequence, Token import numpy as np + # Transformer(lm, prompt) -- where prompt can either be a string or a list of Tokens. class Transformer(Distribution): def __init__(self, lm, prompt, temp=1.0): """Create a Categorical distribution whose values are Tokens, with probabilities given by a language model. Supports auto-batching. - + Args: lm (hfppl.llms.CachedCausalLM): the language model. prompt (str | hfppl.llms.TokenSequence): the sequence of tokens to use as the prompt. If a string, `lm.tokenizer` is used to encode it. @@ -16,33 +17,36 @@ def __init__(self, lm, prompt, temp=1.0): """ self.lm = lm self.temp = temp - + # prompt will be a list of ints if isinstance(prompt, str): prompt = self.lm.tokenizer.encode(prompt) elif isinstance(prompt, TokenSequence): prompt = prompt.seq - + self.prompt = prompt - - + async def log_prob(self, x): log_probs = await self.lm.next_token_logprobs(self.prompt) log_probs = log_probs / self.temp - + if isinstance(x, Token): x = x.token_id - + return log_probs[x] - + async def sample(self): log_probs = await self.lm.next_token_logprobs(self.prompt) log_probs = log_probs / self.temp probs = np.exp(log_probs) token_id = np.random.choice(len(probs), p=(probs)) logprob = log_probs[token_id] - return Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), logprob - + return ( + Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), + logprob, + ) + + # def argmax(self, idx): # token_id = np.argsort(self.log_probs)[-idx] -# return Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), log_probs[token_id] \ No newline at end of file +# return Token(self.lm, token_id, self.lm.tokenizer.convert_ids_to_tokens(token_id)), log_probs[token_id] diff --git a/hfppl/inference/__init__.py b/hfppl/inference/__init__.py index 432e969..0dbda60 100644 --- a/hfppl/inference/__init__.py +++ b/hfppl/inference/__init__.py @@ -7,5 +7,5 @@ * `smc_steer(model, num_beams, num_expansions)`: a without-replacement SMC algorithm that resembles beam search. """ -from .smc_standard import smc_standard +from .smc_standard import smc_standard from .smc_steer import smc_steer diff --git a/hfppl/inference/smc_standard.py b/hfppl/inference/smc_standard.py index 941fb2f..10500be 100644 --- a/hfppl/inference/smc_standard.py +++ b/hfppl/inference/smc_standard.py @@ -3,39 +3,47 @@ import numpy as np import asyncio + async def smc_standard(model, n_particles, ess_threshold=0.5): """ Standard sequential Monte Carlo algorithm with multinomial resampling. - + Args: model (hfppl.modeling.Model): The model to perform inference on. n_particles (int): Number of particles to execute concurrently. ess_threshold (float): Effective sample size below which resampling is triggered, given as a fraction of `n_particles`. - + Returns: particles (list[hfppl.modeling.Model]): The completed particles after inference. """ particles = [copy.deepcopy(model) for _ in range(n_particles)] weights = [0.0 for _ in range(n_particles)] - - while (any(map(lambda p: not p.done_stepping(), particles))): + + while any(map(lambda p: not p.done_stepping(), particles)): # Step each particle for p in particles: p.untwist() await asyncio.gather(*[p.step() for p in particles if not p.done_stepping()]) - + # Normalize weights W = np.array([p.weight for p in particles]) w_sum = logsumexp(W) normalized_weights = W - w_sum - + # Resample if necessary - if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log(n_particles): + if -logsumexp(normalized_weights * 2) < np.log(ess_threshold) + np.log( + n_particles + ): # Alternative implementation uses a multinomial distribution and only makes n-1 copies, reusing existing one, but fine for now probs = np.exp(normalized_weights) - particles = [copy.deepcopy(particles[np.random.choice(range(len(particles)), p=probs)]) for _ in range(n_particles)] + particles = [ + copy.deepcopy( + particles[np.random.choice(range(len(particles)), p=probs)] + ) + for _ in range(n_particles) + ] avg_weight = w_sum - np.log(n_particles) for p in particles: p.weight = avg_weight - - return particles \ No newline at end of file + + return particles diff --git a/hfppl/inference/smc_steer.py b/hfppl/inference/smc_steer.py index cf0d88f..4a559b0 100644 --- a/hfppl/inference/smc_steer.py +++ b/hfppl/inference/smc_steer.py @@ -3,6 +3,7 @@ import asyncio from ..util import logsumexp, softmax + def find_c(weights, N): # Sort the weights sorted_weights = np.sort(weights) @@ -19,6 +20,7 @@ def find_c(weights, N): return (N - A_val) / B_val return N + def resample_optimal(weights, N): c = find_c(weights, N) # Weights for which c * w >= 1 are deterministically resampled @@ -46,20 +48,21 @@ def resample_optimal(weights, N): else: i += 1 # Concatenate the deterministic and stochastic resampled indices - #resampled = np.concatenate((deterministic, stoch_resampled)) - #return resampled + # resampled = np.concatenate((deterministic, stoch_resampled)) + # return resampled return deterministic, stoch_resampled, c + async def smc_steer(model, n_particles, n_beam): """ Modified sequential Monte Carlo algorithm that uses without-replacement resampling, as described in [our workshop abstract](https://arxiv.org/abs/2306.03081). - + Args: model (hfppl.modeling.Model): The model to perform inference on. n_particles (int): Number of particles to maintain. n_beam (int): Number of continuations to consider for each particle. - + Returns: particles (list[hfppl.modeling.Model]): The completed particles after inference. """ @@ -67,7 +70,7 @@ async def smc_steer(model, n_particles, n_beam): particles = [copy.deepcopy(model) for _ in range(n_particles)] for particle in particles: - particle.start() # TODO: allow to be async? + particle.start() # TODO: allow to be async? while any(map(lambda p: not p.done_stepping(), particles)): # Count the number of finished particles @@ -83,23 +86,29 @@ async def smc_steer(model, n_particles, n_beam): p.weight += np.log(n_total) - np.log(n_particles) else: p.weight += np.log(n_total) - np.log(n_particles) - np.log(n_beam) - super_particles.extend([copy.deepcopy(p) for _ in range(n_beam-1)]) - + super_particles.extend([copy.deepcopy(p) for _ in range(n_beam - 1)]) + # Step each super-particle - await asyncio.gather(*[p.step() for p in super_particles if not p.done_stepping()]) + await asyncio.gather( + *[p.step() for p in super_particles if not p.done_stepping()] + ) # Use optimal resampling to resample W = np.array([p.weight for p in super_particles]) W_tot = logsumexp(W) W_normalized = softmax(W) det_indices, stoch_indices, c = resample_optimal(W_normalized, n_particles) - particles = [super_particles[i] for i in np.concatenate((det_indices, stoch_indices))] + particles = [ + super_particles[i] for i in np.concatenate((det_indices, stoch_indices)) + ] # For deterministic particles: w = w * N/N' for i in det_indices: super_particles[i].weight += np.log(n_particles) - np.log(n_total) # For stochastic particles: w = 1/c * total sum(stoch weights) / num_stoch = sum(stoch weights / total) / num_stoch * total * N/M for i in stoch_indices: - super_particles[i].weight = W_tot - np.log(c) + np.log(n_particles) - np.log(n_total) + super_particles[i].weight = ( + W_tot - np.log(c) + np.log(n_particles) - np.log(n_total) + ) # Return the particles return particles diff --git a/hfppl/llms.py b/hfppl/llms.py index b3bec8a..1e554ec 100644 --- a/hfppl/llms.py +++ b/hfppl/llms.py @@ -5,34 +5,47 @@ import asyncio import string + class Masks: def __init__(self, lm): self.ALL_TOKENS = set(range(len(lm.vocab))) - self.STARTS_NEW_WORD = set(i for (i,v) in enumerate(lm.vocab) if v[0]==' ' and len(v) > 1 and v[1] not in string.whitespace and v[1] not in string.punctuation) - self.CONTINUES_CURRENT_WORD = set(i for (i,v) in enumerate(lm.vocab) if all(c in '\'' or c.isalpha() for c in v)) - self.PUNCTUATION = set(i for (i,v) in enumerate(lm.vocab) if v in ',:;.!?"-') - self.END_SENTENCE_PUNCT = set(i for (i, v) in enumerate(lm.vocab) if v in '.!?') + self.STARTS_NEW_WORD = set( + i + for (i, v) in enumerate(lm.vocab) + if v[0] == " " + and len(v) > 1 + and v[1] not in string.whitespace + and v[1] not in string.punctuation + ) + self.CONTINUES_CURRENT_WORD = set( + i + for (i, v) in enumerate(lm.vocab) + if all(c in "'" or c.isalpha() for c in v) + ) + self.PUNCTUATION = set(i for (i, v) in enumerate(lm.vocab) if v in ',:;.!?"-') + self.END_SENTENCE_PUNCT = set(i for (i, v) in enumerate(lm.vocab) if v in ".!?") class TokenSequence: """A sequence of tokens. - + Supports addition (via `+` or mutating `+=`) with: - + * other `TokenSequence` instances (concatenation) * individual tokens, represented as integers or `Token` instances * strings, which are tokenized by `lm.tokenizer` - + Attributes: lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. seq (list[hfppl.llms.Token]): the sequence of tokens.""" - + def __init__(self, lm, seq=None): """Create a `TokenSequence` from a language model and a sequence. - + Args: lm (hfppl.llms.CachedCausalLM): the language model whose vocabulary the tokens come from. - seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token.""" + seq (str | list[int]): the sequence of token ids, or a string which will be automatically tokenized. Defaults to the singleton sequence containing a bos token. + """ self.lm = lm if seq is None: self.seq = [lm.tokenizer.bos_token_id] @@ -40,10 +53,10 @@ def __init__(self, lm, seq=None): self.seq = self.lm.tokenizer.encode(seq) else: self.seq = seq - + def __str__(self): return self.lm.tokenizer.decode(self.seq) - + def __iadd__(self, other): if isinstance(other, Token): assert other.lm is self.lm @@ -58,7 +71,7 @@ def __iadd__(self, other): else: raise RuntimeError(f"Addition not supported on {type(other)}") return self - + def __radd__(self, other): if isinstance(other, Token): assert other.lm is self.lm @@ -67,30 +80,35 @@ def __radd__(self, other): assert other.lm is self.lm return TokenSequence(self.lm, other.seq + self.seq) elif isinstance(other, str): - return TokenSequence(self.lm, self.lm.tokenizer.encode(other, add_special_tokens=False) + self.seq) + return TokenSequence( + self.lm, + self.lm.tokenizer.encode(other, add_special_tokens=False) + self.seq, + ) elif isinstance(other, int): return TokenSequence(self.lm, [other, *self.seq]) else: raise RuntimeError(f"Addition not supported on {type(other)}") - + def __add__(self, other): s = TokenSequence(self.lm, self.seq) s += other return s + class Token: """Class representing a token. - + Attributes: lm (hfppl.llms.CachedCausalLM): the language model for which this is a Token. token_id (int): the integer token id (an index into the vocabulary). - token_str (str): a string, which the token represents—equal to `lm.vocab[token_id]`.""" - + token_str (str): a string, which the token represents—equal to `lm.vocab[token_id]`. + """ + def __init__(self, lm, token_id, token_str): - self.lm = lm - self.token_id = token_id + self.lm = lm + self.token_id = token_id self.token_str = token_str - + # Adding tokens def __add__(self, other): s = TokenSequence(self.lm, [self.token_id]) @@ -100,7 +118,7 @@ def __add__(self, other): def __radd__(self, other): s = TokenSequence(self.lm, [self.token_id]) return other + s - + # Support checking for EOS def __eq__(self, other): if isinstance(other, Token): @@ -112,90 +130,120 @@ def __eq__(self, other): def __str__(self): return self.token_str - + def __repr__(self): return f"<{self.token_str}|{self.token_id}>" + class TokenTrie: """Class used internally to cache language model results.""" + # Trie of tokens. - def __init__(self, parent=None, logprobs=None): - self.children = {} # maps token ID to child + def __init__(self, parent=None, logprobs=None): + self.children = {} # maps token ID to child self.logprobs = logprobs # for next token self.past_key_values = None - + def __repr__(self): - return f"{'*' if self.past_key_values is not None else ''}[" + ", ".join([f"{node_id}: {node.__repr__()}" for (node_id, node) in self.children.items()]) + "]" - + return ( + f"{'*' if self.past_key_values is not None else ''}[" + + ", ".join( + [ + f"{node_id}: {node.__repr__()}" + for (node_id, node) in self.children.items() + ] + ) + + "]" + ) + def clear_kv_cache(self): self.past_key_values = None - for (child, node) in self.children.items(): + for child, node in self.children.items(): node.clear_kv_cache() - + def has_token(self, token_id): return token_id in self.children - + def get_token(self, token_id): return self.children[token_id] - + def add_token(self, token_id, logprobs=None): self.children[token_id] = TokenTrie(self, logprobs) return self.children[token_id] - def extend_cache(self, next_token_index, token_ids, logits, base): node = self - + for j in range(next_token_index, len(token_ids)): - token_id = token_ids[j] - token_logits = logits[j-base] - token_logprobs = torch.log_softmax(token_logits, 0) - + token_id = token_ids[j] + token_logits = logits[j - base] + token_logprobs = torch.log_softmax(token_logits, 0) + node = node.add_token(token_id, token_logprobs.cpu().numpy()) - + return node + class Query: """A query to a language model, waiting to be batched.""" - + def __init__(self, prompt, future, past=None): self.prompt = prompt self.future = future self.past = past - + if self.past is not None: - self.past_len = past[0][0].shape[2] # layers, key or value, batch size, num heads, num tokens, head repr length + self.past_len = past[0][0].shape[ + 2 + ] # layers, key or value, batch size, num heads, num tokens, head repr length else: self.past_len = 0 - + @torch.no_grad() def past_padded(self, layer, j, to_length, dtype, device, past_shape): - + if self.past is not None: - return torch.cat((self.past[layer][j], torch.zeros(1, past_shape[1], to_length-self.past_len, past_shape[3], dtype=dtype, device=device)), - dim=2) + return torch.cat( + ( + self.past[layer][j], + torch.zeros( + 1, + past_shape[1], + to_length - self.past_len, + past_shape[3], + dtype=dtype, + device=device, + ), + ), + dim=2, + ) else: - return torch.zeros(1, past_shape[1], to_length, past_shape[3], dtype=dtype, device=device) - + return torch.zeros( + 1, past_shape[1], to_length, past_shape[3], dtype=dtype, device=device + ) + def prompt_padded(self, pad_token, to_length): - return [*self.prompt, *[pad_token for _ in range(to_length-len(self.prompt))]] - - + return [*self.prompt, *[pad_token for _ in range(to_length - len(self.prompt))]] + def attention_mask(self, total_past_length, total_seq_length): - return [*[1 for _ in range(self.past_len)], - *[0 for _ in range(total_past_length-self.past_len)], - *[1 for _ in range(len(self.prompt))], - *[0 for _ in range(total_seq_length-len(self.prompt))]] - + return [ + *[1 for _ in range(self.past_len)], + *[0 for _ in range(total_past_length - self.past_len)], + *[1 for _ in range(len(self.prompt))], + *[0 for _ in range(total_seq_length - len(self.prompt))], + ] + def position_ids(self, total_past_length, total_seq_length): - return [*range(self.past_len, self.past_len + len(self.prompt)), - *[0 for _ in range(total_seq_length-len(self.prompt))]] - + return [ + *range(self.past_len, self.past_len + len(self.prompt)), + *[0 for _ in range(total_seq_length - len(self.prompt))], + ] + class CachedCausalLM: """Wrapper around a HuggingFace causal language model, with support for caching. - + Attributes: model: the underlying HuggingFace model. tokenizer: the underlying HuggingFace tokenizer. @@ -205,33 +253,40 @@ class CachedCausalLM: batch_size (int): when auto-batching, maximum number of queries to process in one batch. timeout (float): number of seconds to wait since last query before processing the current batch of queries, even if not full. """ - + @classmethod def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True): """Create a [`CachedCausalLM`][hfppl.llms.CachedCausalLM] from a pretrained HuggingFace model. - + Args: model_id (str): the string identifier of the model in HuggingFace's model library. auth_token (str): a HuggingFace API key. Only necessary if using private models, e.g. Meta's Llama models, which require authorization. load_in_8bit (bool): whether to use the `bitsandbytes` library to load the model in 8-bit quantized form. - + Returns: model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model. """ if not auth_token: tok = AutoTokenizer.from_pretrained(model_id) - mod = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=load_in_8bit) + mod = AutoModelForCausalLM.from_pretrained( + model_id, device_map="auto", load_in_8bit=load_in_8bit + ) else: tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token) - mod = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=auth_token, device_map="auto", load_in_8bit=load_in_8bit) - + mod = AutoModelForCausalLM.from_pretrained( + model_id, + use_auth_token=auth_token, + device_map="auto", + load_in_8bit=load_in_8bit, + ) + return CachedCausalLM(mod, tok) - + @torch.no_grad() def __init__(self, hf_model, hf_tokenizer, batch_size=20): """ Create a `CachedCausalLM` from a loaded HuggingFace model and tokenizer. - + Args: hf_model: a HuggingFace `CausalLM`. hf_tokenizer: a HuggingFace `Tokenizer`. @@ -240,21 +295,28 @@ def __init__(self, hf_model, hf_tokenizer, batch_size=20): self.model = hf_model self.tokenizer = hf_tokenizer self.device = hf_model.device - + # TODO: remove required BOS token if self.tokenizer.bos_token_id is None: - raise RuntimeError("Causal LM has no BOS token, distribution of first word unclear") - + raise RuntimeError( + "Causal LM has no BOS token, distribution of first word unclear" + ) + # Evaluate BOS token - logits = self.model(torch.tensor([[self.tokenizer.bos_token_id]]).to(self.model.device)).logits[0][0] + logits = self.model( + torch.tensor([[self.tokenizer.bos_token_id]]).to(self.model.device) + ).logits[0][0] logprobs = torch.log_softmax(logits, 0) - + self.cache = TokenTrie(None, logprobs.cpu().numpy()) - + # Cache vocabulary - bos_len = len(self.tokenizer.decode([self.tokenizer.bos_token_id])) - self.vocab = [self.tokenizer.decode([self.tokenizer.bos_token_id,i])[bos_len:] for i in range(len(hf_tokenizer.vocab))] - + bos_len = len(self.tokenizer.decode([self.tokenizer.bos_token_id])) + self.vocab = [ + self.tokenizer.decode([self.tokenizer.bos_token_id, i])[bos_len:] + for i in range(len(hf_tokenizer.vocab)) + ] + # Precompute useful masks self.masks = Masks(self) @@ -264,64 +326,98 @@ def __init__(self, hf_model, hf_tokenizer, batch_size=20): self.batch_size = batch_size self.timeout = 0.02 self.timer = None - + def __deepcopy__(self, memo): return self - + def clear_cache(self): """Clear the cache of log probabilities and key/value pairs.""" self.cache = TokenTrie(None, self.cache.logprobs) - + def clear_kv_cache(self): """Clear any key and value vectors from the cache.""" self.cache.clear_kv_cache() - + def reset_async_queries(self): - """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing + """Clear any pending language model queries from the queue. Use this method when an exception prevented an inference algorithm from executing to completion.""" self.queries = [] - + @torch.no_grad() def cache_kv(self, prompt_tokens): """Cache the key and value vectors for a prompt. Future queries that have this prompt as a prefix will only run the LLM on new tokens. - + Args: prompt_tokens (list[int]): token ids for the prompt to cache. """ result = self.model(torch.tensor([prompt_tokens]).to(self.device)) - + node = self.cache.extend_cache(1, prompt_tokens, result.logits[0], 0) node.past_key_values = result.past_key_values - + @torch.no_grad() def batch_evaluate_queries(self): - + queries, self.queries = self.queries, [] if len(queries) == 0: return - + past_example = next((q.past for q in queries if q.past), False) max_past_length = max(q.past_len for q in queries) max_query_length = max(len(q.prompt) for q in queries) - - padding_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - - input_ids = torch.tensor([q.prompt_padded(padding_token_id, max_query_length) for q in queries]).to(self.device) - attn_masks = torch.tensor([q.attention_mask(max_past_length, max_query_length) for q in queries]).to(self.device) - posn_ids = torch.tensor([q.position_ids(max_past_length, max_query_length) for q in queries]).to(self.device) + + padding_token_id = ( + self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None + else 0 + ) + + input_ids = torch.tensor( + [q.prompt_padded(padding_token_id, max_query_length) for q in queries] + ).to(self.device) + attn_masks = torch.tensor( + [q.attention_mask(max_past_length, max_query_length) for q in queries] + ).to(self.device) + posn_ids = torch.tensor( + [q.position_ids(max_past_length, max_query_length) for q in queries] + ).to(self.device) if past_example: - pasts = [[torch.cat((*(q.past_padded(layer, j, max_past_length, past_example[0][0].dtype, self.device, past_example[0][0].shape) for q in queries),), dim=0) - for j in range(2)] for layer in range(len(past_example))] + pasts = [ + [ + torch.cat( + ( + *( + q.past_padded( + layer, + j, + max_past_length, + past_example[0][0].dtype, + self.device, + past_example[0][0].shape, + ) + for q in queries + ), + ), + dim=0, + ) + for j in range(2) + ] + for layer in range(len(past_example)) + ] else: pasts = None - - results = self.model(input_ids, attention_mask=attn_masks, - position_ids=posn_ids, past_key_values=pasts, - use_cache=pasts is not None) - - for (i, q) in enumerate(queries): + + results = self.model( + input_ids, + attention_mask=attn_masks, + position_ids=posn_ids, + past_key_values=pasts, + use_cache=pasts is not None, + ) + + for i, q in enumerate(queries): q.future.set_result(results.logits[i]) - + @torch.no_grad() def add_query(self, query, future, past): self.queries.append(Query(query, future, past)) @@ -332,11 +428,13 @@ def add_query(self, query, future, past): if len(self.queries) >= self.batch_size: self.batch_evaluate_queries() else: - self.timer = asyncio.get_running_loop().call_later(self.timeout, lambda: self.batch_evaluate_queries()) - + self.timer = asyncio.get_running_loop().call_later( + self.timeout, lambda: self.batch_evaluate_queries() + ) + def walk_cache(self, token_ids): # Walk while tokens can be found - node = self.cache + node = self.cache next_token_index = 1 past = None @@ -350,60 +448,65 @@ def walk_cache(self, token_ids): next_token_index += 1 else: break - + return node, next_token_index, past, base - + @torch.no_grad() async def next_token_logprobs(self, token_ids): - """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`. - + """Request log probabilities of next token. This version is asynchronous because it automatically batches concurrent requests; use with `await`. + Args: token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model. - + Returns: logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt. """ - + # Ensure that token list begins with BOS assert token_ids[0] == self.tokenizer.bos_token_id - + node, next_token_index, past, base = self.walk_cache(token_ids) - + # If we processed all tokens, then we're done. if next_token_index == len(token_ids): return node.logprobs - + # Create a future with the prompt future = asyncio.get_running_loop().create_future() self.add_query(token_ids[base:], future, past) logits = await future - + # Create new nodes node = node.extend_cache(next_token_index, token_ids, logits, base) - + return node.logprobs - + @torch.no_grad() def next_token_logprobs_unbatched(self, token_ids): """Request log probabilities of next token. Not asynchronous, and does not support auto-batching. - + Args: token_ids (list[int]): a list of token ids starting with `tokenizer.bos_token_id`, representing a prompt to the language model. - + Returns: - logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt.""" - + logprobs (numpy.array): a numpy array of `len(vocab)`, with the language model's log (normalized) probabilities for the next token following the prompt. + """ + # Ensure that token list begins with BOS assert token_ids[0] == self.tokenizer.bos_token_id - + # Walk while tokens can be found node, next_token_index, past, base = self.walk_cache(token_ids) - + if next_token_index == len(token_ids): return node.logprobs - - logits = self.model(torch.tensor([token_ids[base:]]).to(self.device), past_key_values=node.past_key_values, use_cache=node.past_key_values is not None).logits[0] - + + logits = self.model( + torch.tensor([token_ids[base:]]).to(self.device), + past_key_values=node.past_key_values, + use_cache=node.past_key_values is not None, + ).logits[0] + node = node.extend_cache(next_token_index, token_ids, logits, base) - + return node.logprobs diff --git a/hfppl/modeling.py b/hfppl/modeling.py index 3d7be22..b090c02 100644 --- a/hfppl/modeling.py +++ b/hfppl/modeling.py @@ -1,10 +1,11 @@ import copy + class SubModel: def __init__(self): self.parent = None - + async def run_with_parent(self, parent): old_parent = self.parent self.parent = parent @@ -13,37 +14,41 @@ async def run_with_parent(self, parent): return val async def forward(self): - raise NotImplementedError("SubModel.forward() must be implemented by subclasses") + raise NotImplementedError( + "SubModel.forward() must be implemented by subclasses" + ) async def sample(self, dist, proposal=None): return await self.parent.sample(dist, proposal) - + async def observe(self, dist, x): return await self.parent.observe(dist, x) - + async def intervene(self, dist, x): return await self.parent.intervene(dist, x) - + def condition(self, b): return self.parent.condition(b) - + def score(self, score): return self.parent.score(score) - + def twist(self, amt): return self.parent.twist(amt) - + async def call(self, submodel): - return (await submodel.run_with_parent(self.parent)) + return await submodel.run_with_parent(self.parent) + # For use as a decorator import functools + def submodel(f): """Decorator to create a SubModel implementation from an async function. - + For example: - + ```python @submodel async def sample_two_tokens(self, context): @@ -54,25 +59,27 @@ async def sample_two_tokens(self, context): This SubModel can then be used from another model or submodel, using the syntax `await self.call(sample_two_tokens(context))`. """ - @functools.wraps(f, updated=()) # unclear if this is the best way to do it + + @functools.wraps(f, updated=()) # unclear if this is the best way to do it class SubModelImpl(SubModel): def __init__(self, *args, **kwargs): super().__init__() self.args = args self.kwargs = kwargs - + async def forward(self): - return (await f(self, *self.args, **self.kwargs)) - + return await f(self, *self.args, **self.kwargs) + return SubModelImpl + class Model: """Base class for all LLaMPPL models. - + Your models should subclass this class. Minimally, you should provide an `__init__` method that calls `super().__init__(self)`, and a `step` method. """ - + def __init__(self): self.weight = 0.0 self.finished = False @@ -92,68 +99,67 @@ def reset(self): def immutable_properties(self): """Return a `set[str]` of properties that LLaMPPL may assume do not change during execution of `step`. This set is empty by default but can be overridden by subclasses to speed up inference. - + Returns: properties (set[str]): a set of immutable property names""" return set() - - def __deepcopy__(self, memo): + + def __deepcopy__(self, memo): cpy = type(self).__new__(type(self)) immutable = self.immutable_properties() - + for k, v in self.__dict__.items(): if k in immutable: setattr(cpy, k, v) else: setattr(cpy, k, copy.deepcopy(v, memo)) - + return cpy - def twist(self, amt): """Multiply this particle's weight by `exp(amt)`, but divide it back out before the next `step`. - + Use this method to provide heuristic guidance about whether a particle is "on the right track" without changing the ultimate target distribution. - + Args: amt: the logarithm of the amount by which to (temporarily) multiply this particle's weight. """ self.twist_amount += amt self.score(amt) - + def untwist(self): self.score(-self.twist_amount) self.twist_amount = 0.0 - + def finish(self): self.untwist() self.finished = True - + def done_stepping(self): return self.finished async def step(self): """Defines the computation performed in each step of the model. - + All subclasses should override this method.""" - + if not self.done_stepping(): raise NotImplementedError("Model.step() must be implemented by subclasses") - + def __str__(self): return "Particle" - + def start(self): pass - + def score(self, score): """Multiply this particle's weight by `exp(score)`. - + The `score` method is a low-level way to change the target distribution. For many use cases, it is sufficient to use `sample`, `observe`, `condition`, and `twist`, all of which are implemented in terms of `score`. - + Args: score: logarithm of the amount by which the particle's weight should be multiplied. """ @@ -161,54 +167,54 @@ def score(self, score): def condition(self, b): """Constrain a given Boolean expression to be `True`. - + If the condition is False, the particle's weight is set to zero and `self.finish()` is called, so that no further `step` calls are made. - + Args: b: the Boolean expression whose value is constrained to be True. """ if not b: - self.score(float('-inf')) + self.score(float("-inf")) self.finish() - + async def intervene(self, dist, x): """Force the distribution to take on the value `x`, but do not _condition_ on this result. - + This is useful primarily with distributions that have side effects (e.g., modifying some state). For example, a model with the code - + ```python token_1 = await self.sample(self.stateful_lm.next_token()) await self.observe(self.stateful_lm.next_token(), token_2) ``` - + encodes a posterior inference problem, to find `token_1` values that *likely preceded* `token_2`. By contrast, - + ```python token_1 = await self.sample(stateful_lm.next_token()) await self.intervene(self.stateful_lm.next_token(), token_2) ``` - + encodes a much easier task: freely generate `token_1` and then force-feed `token_2` as the following token. - + Args: dist (hfppl.distributions.distribution.Distribution): the distribution on which to intervene. x: the value to intervene with. """ await dist.log_prob(x) return x - + async def observe(self, dist, x): """Condition the model on the value `x` being sampled from the distribution `dist`. - + For discrete distributions `dist`, `await self.observe(dist, x)` specifies the same constraint as ``` val = await self.sample(dist) self.condition(val == x) ``` but can be much more efficient. - + Args: dist: a `Distribution` object from which to observe x: the value observed from `dist` @@ -216,16 +222,16 @@ async def observe(self, dist, x): p = await dist.log_prob(x) self.score(p) return x - + async def sample(self, dist, proposal=None): - """Extend the model with a sample from a given `Distribution`, with support for autobatching. + """Extend the model with a sample from a given `Distribution`, with support for autobatching. If specified, the Distribution `proposal` is used during inference to generate informed hypotheses. - + Args: dist: the `Distribution` object from which to sample proposal: if provided, inference algorithms will use this `Distribution` object to generate proposed samples, rather than `dist`. However, importance weights will be adjusted so that the target posterior is independent of the proposal. - + Returns: value: the value sampled from the distribution. """ @@ -238,7 +244,7 @@ async def sample(self, dist, proposal=None): # else: # self.score(w) # return x - + if proposal is None: x, _ = await dist.sample() return x @@ -247,6 +253,6 @@ async def sample(self, dist, proposal=None): p = await dist.log_prob(x) self.score(p - q) return x - + async def call(self, submodel): - return await submodel.run_with_parent(self) \ No newline at end of file + return await submodel.run_with_parent(self) diff --git a/hfppl/util.py b/hfppl/util.py index ee01d7e..2831945 100644 --- a/hfppl/util.py +++ b/hfppl/util.py @@ -2,20 +2,23 @@ import numpy as np + def logsumexp(nums): m = np.max(nums) return np.log(np.sum(np.exp(nums - m))) + m - + + def log_softmax(nums): """Compute log(softmax(nums)). - + Args: nums: a vector or numpy array of unnormalized log probabilities. - + Returns: np.array: an array of log (normalized) probabilities. """ return nums - logsumexp(nums) + def softmax(nums): - return np.exp(log_softmax(nums)) \ No newline at end of file + return np.exp(log_softmax(nums))