Commit 7d71e06: Use black code formatter

alex-lew committed Jul 16, 2024
1 parent d54a59f commit 7d71e06
Showing 19 changed files with 545 additions and 351 deletions.
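Nearly all of the changes below are mechanical: black normalizes string quotes to double quotes, wraps expressions that exceed its 88-column default, and standardizes blank lines around definitions. The usual way to produce such a commit is to run `black .` at the repository root; as a minimal sketch (not part of this commit, and assuming black is installed), the same normalization is reachable through black's Python API:

import black

# black rewrites source into its canonical style; this reproduces the
# quote normalization visible throughout the diff below.
src = "nltk.download('cmudict')\n"
print(black.format_str(src, mode=black.Mode()))  # nltk.download("cmudict")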
1 change: 1 addition & 0 deletions examples/grammar_constraint.py
@@ -9,6 +9,7 @@
Requires synchromesh (github.com/kanishkg/synchromesh)
"""

import asyncio
import os
from typing import List
43 changes: 28 additions & 15 deletions examples/haiku.py
@@ -4,26 +4,37 @@
import os

# download the CMU pronunciation dictionary (if we haven't already)
-nltk.download('cmudict')
+nltk.download("cmudict")

# Load the CMU pronunciation dictionary and use it for syllable counting
from nltk.corpus import cmudict

CMUDICT = cmudict.dict()


def count_syllables(word, unknown_word_syllables=100):

    # Use the dictionary to get the list of possible phonetic representations for the word
    phonetic_transcriptions = CMUDICT.get(word.strip().lower(), [])

    # Count the number of syllables based on the number of phonetic transcriptions
-    syllable_count = min([len([ph for ph in transcription if ph[-1].isdigit()]) for transcription in phonetic_transcriptions], default=unknown_word_syllables)
+    syllable_count = min(
+        [
+            len([ph for ph in transcription if ph[-1].isdigit()])
+            for transcription in phonetic_transcriptions
+        ],
+        default=unknown_word_syllables,
+    )

    return syllable_count


# Load the language model (llama2 if authorized, else mistral-7b).
-if 'HF_AUTH_TOKEN' in os.environ:
-    HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN']
-    LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN)
+if "HF_AUTH_TOKEN" in os.environ:
+    HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
+    LLM = CachedCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN
+    )
else:
    LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

@@ -74,21 +85,22 @@ def count_syllables(word, unknown_word_syllables=100):
# Useful constants
NEWLINE_TOKEN, EOS_TOKEN = 13, LLM.tokenizer.eos_token_id


# LLaMPPL model
class Haiku(Model):

    def __init__(self, prompt, syllable_pattern=[5, 7, 5]):
        super().__init__()
        self.context = LMContext(LLM, prompt, 0.7)
        self.syllable_pattern = syllable_pattern

    async def step(self):
        # Get the number of syllables required in the next line
        syllables_remaining = self.syllable_pattern.pop(0)

        # Loop to sample words until this line is over
        while syllables_remaining > 0:

            # Sample a word
            word, punctuation = await self.call(sample_word(self.context))

@@ -103,18 +115,19 @@ async def step(self):
            await self.observe(self.context.next_token(), EOS_TOKEN)
            self.finish()
            return

        # Otherwise, observe a line break
        await self.observe(self.context.next_token(), NEWLINE_TOKEN)

        # Print current result
        print(str(self.context))


# Run inference
-SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune
+SYLLABLES_PER_LINE = [5, 7, 5]  # [5, 3, 5] for a Lune
particles = asyncio.run(smc_standard(Haiku(poem_prompt, SYLLABLES_PER_LINE), 120))

print("--------")
-for (i,particle) in enumerate(particles):
+for i, particle in enumerate(particles):
    print(f"Poem {i} (weight {particle.weight}):")
    print(f"{particle.context}")
43 changes: 27 additions & 16 deletions examples/hard_constraints.py
@@ -4,22 +4,30 @@

import os

-if 'HF_AUTH_TOKEN' in os.environ:
-    HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN']
+if "HF_AUTH_TOKEN" in os.environ:
+    HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]

# Load the language model.
# Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 2,
# pass your HuggingFace API key as the optional `auth_token` argument:
# LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN)
# LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
-LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
+# LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
LLM.batch_size = 40

-MASKS = {i : set(j for (j,v) in enumerate(LLM.vocab)
-                 if j != LLM.tokenizer.eos_token_id and '\n' not in v and
-                 any(c.isalpha() or c in string.punctuation for c in v) and
-                 len(v.strip()) <= 5 and (not v[0].isalpha() or i+len(v) <= 5))
-         for i in range(6)}
+MASKS = {
+    i: set(
+        j
+        for (j, v) in enumerate(LLM.vocab)
+        if j != LLM.tokenizer.eos_token_id
+        and "\n" not in v
+        and any(c.isalpha() or c in string.punctuation for c in v)
+        and len(v.strip()) <= 5
+        and (not v[0].isalpha() or i + len(v) <= 5)
+    )
+    for i in range(6)
+}


class ConstraintModel(Model):
    def __init__(self, prompt, max_tokens):
@@ -33,26 +41,27 @@ async def step(self):

        # Condition on next token being from mask
        await self.observe(self.context.mask_dist(mask), True)

        # Generate proposed token.
        token = await self.sample(self.context.next_token())

        # Reduce number of max tokens remaining
        self.max_tokens -= 1

        print(f"{self.context}")

        # Check if done
        if token == LLM.tokenizer.eos_token_id or self.max_tokens == 0:
            self.finish()
            return

    def active_constraint_mask(self):
        string_so_far = str(self.context)
        words = string_so_far.split()
        last_word = words[-1] if len(words) > 0 else ""
        return MASKS[min(5, len(last_word))]



# From Politico.com
prompt = """3 things to watch …
@@ -64,10 +73,12 @@ def active_constraint_mask(self):

LLM.cache_kv(LLM.tokenizer.encode(prompt))


async def main():
    constraint_model = ConstraintModel(prompt, 50)
    particles = await smc_standard(constraint_model, 40)
    for p in particles:
        print(f"{p.context}")

-asyncio.run(main())
+
+asyncio.run(main())
2 changes: 1 addition & 1 deletion hfppl/__init__.py
@@ -6,4 +6,4 @@
from .distributions import *
from .modeling import *
from .inference import *
from .chunks import *
31 changes: 22 additions & 9 deletions hfppl/chunks.py
@@ -1,38 +1,51 @@
import string
from .modeling import submodel


@submodel
async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
    """Sample a word from the `LMContext` object `context`."""
    last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
    last_character = last_token[-1] if len(last_token) > 0 else ""
-    needs_space = last_character not in string.whitespace and last_character not in ['-', "'", '"']
+    needs_space = last_character not in string.whitespace and last_character not in [
+        "-",
+        "'",
+        '"',
+    ]
    if needs_space:
        starts_word_mask = context.lm.masks.STARTS_NEW_WORD
    else:
        starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD

    # Force model to start a new word
    await self.observe(context.mask_dist(starts_word_mask), True)

    word = ""
    num_tokens = 0
    while True:
        token = await self.sample(context.next_token())
        word += context.lm.vocab[token.token_id]
        num_tokens += 1

        if num_tokens == max_tokens:
-            await self.observe(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False)
+            await self.observe(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
+            )
            break

-        if not (await self.sample(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD))):
+        if not (
+            await self.sample(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
+            )
+        ):
            break

    # Sample punctuation, if desired
    punctuation = ""
-    if allow_punctuation and await self.sample(context.mask_dist(context.lm.masks.PUNCTUATION)):
+    if allow_punctuation and await self.sample(
+        context.mask_dist(context.lm.masks.PUNCTUATION)
+    ):
        punctuation_token = await self.sample(context.next_token())
        punctuation = context.lm.vocab[punctuation_token.token_id]

    return word, punctuation
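Because sample_word is declared with @submodel, it is not called directly; inside a Model subclass it runs through self.call, exactly as examples/haiku.py above does:

# Inside an async step() of a Model subclass (cf. examples/haiku.py):
word, punctuation = await self.call(sample_word(self.context))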
2 changes: 1 addition & 1 deletion hfppl/distributions/__init__.py
@@ -17,4 +17,4 @@
from .tokencategorical import TokenCategorical
from .transformer import Transformer
from .lmcontext import LMContext
from .bernoulli import Bernoulli
12 changes: 6 additions & 6 deletions hfppl/distributions/bernoulli.py
@@ -2,13 +2,13 @@

import numpy as np


class Bernoulli(Distribution):
-    """A Bernoulli distribution.
-    """
+    """A Bernoulli distribution."""

    def __init__(self, p):
        """Create a Bernoulli distribution.
        Args:
            p: the probability-of-True for the Bernoulli distribution.
        """
@@ -20,6 +20,6 @@ async def sample(self):

    async def log_prob(self, value):
        return np.log(self.p) if value else np.log1p(-self.p)

    async def argmax(self, idx):
-        return ((self.p > 0.5) if idx == 0 else (self.p < 0.5))
+        return (self.p > 0.5) if idx == 0 else (self.p < 0.5)
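With the redundant parentheses gone, argmax reads cleanly: idx indexes outcomes from most to least probable. A quick check of that behavior (a sketch; assumes Bernoulli is re-exported from hfppl, as the __init__.py diffs above suggest):

import asyncio

from hfppl import Bernoulli

b = Bernoulli(0.8)
assert asyncio.run(b.argmax(0))      # True is the most probable outcome
assert not asyncio.run(b.argmax(1))  # False is the second most probable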
15 changes: 7 additions & 8 deletions hfppl/distributions/distribution.py
@@ -1,29 +1,28 @@
class Distribution:
    """Abstract base class for a distribution."""

    async def sample(self):
        """Generate a random sample from the distribution.
        Returns:
            x: a value randomly sampled from the distribution."""
        raise NotImplementedError()

    async def log_prob(self, x):
        """Compute the log probability of a value under this distribution,
        or the log probability density if the distribution is continuous.
        Args:
            x: the point at which to evaluate the log probability.
        Returns:
            logprob (float): the log probability of `x`."""
        raise NotImplementedError()

    async def argmax(self, n):
        """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).
        Args:
            n (int): which value to return, indexed from most probable (n=0) to least probable (n=|support|).
        Returns:
            x: the nth most probable outcome from this distribution."""
        raise NotImplementedError()
15 changes: 8 additions & 7 deletions hfppl/distributions/geometric.py
@@ -1,12 +1,13 @@
from .distribution import Distribution
import numpy as np


class Geometric(Distribution):
-    """A Geometric distribution.
-    """
+    """A Geometric distribution."""

    def __init__(self, p):
        """Create a Geometric distribution.
        Args:
            p: the rate of the Geometric distribution.
        """
@@ -17,7 +18,7 @@ async def sample(self):
        return n, await self.log_prob(n)

    async def log_prob(self, value):
-        return np.log(self.p) + np.log(1 - self.p)*(value - 1)
+        return np.log(self.p) + np.log(1 - self.p) * (value - 1)

    async def argmax(self, idx):
-        return idx - 1 # Most likely outcome is 0, then 1, etc.
+        return idx - 1  # Most likely outcome is 0, then 1, etc.
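With the operator spacing fixed, log_prob reads directly as log p + (value - 1) * log(1 - p), the log-mass of a geometric variable counting trials up to and including the first success. A quick numeric check (a sketch; assumes Geometric is exported alongside the other distributions):

import asyncio

import numpy as np

from hfppl import Geometric

g = Geometric(0.5)
lp = asyncio.run(g.log_prob(3))       # log(0.5) + 2 * log(0.5)
assert np.isclose(lp, np.log(1 / 8))  # P(k = 3) = (1 - p)**2 * p = 1/8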

1 comment on commit 7d71e06

@gabegrand (Collaborator)
Good, I was going to suggest that we apply this to the whole repo.
