Commit 7d71e06
Showing 19 changed files with 545 additions and 351 deletions.
```diff
@@ -1,38 +1,51 @@
 import string
 from .modeling import submodel

+
 @submodel
 async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
     """Sample a word from the `LMContext` object `context`."""
     last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
     last_character = last_token[-1] if len(last_token) > 0 else ""
-    needs_space = last_character not in string.whitespace and last_character not in ['-', "'", '"']
+    needs_space = last_character not in string.whitespace and last_character not in [
+        "-",
+        "'",
+        '"',
+    ]
     if needs_space:
         starts_word_mask = context.lm.masks.STARTS_NEW_WORD
     else:
         starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD

     # Force model to start a new word
     await self.observe(context.mask_dist(starts_word_mask), True)

     word = ""
     num_tokens = 0
     while True:
         token = await self.sample(context.next_token())
         word += context.lm.vocab[token.token_id]
         num_tokens += 1

         if num_tokens == max_tokens:
-            await self.observe(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False)
+            await self.observe(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
+            )
             break

-        if not (await self.sample(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD))):
+        if not (
+            await self.sample(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
+            )
+        ):
             break

     # Sample punctuation, if desired
     punctuation = ""
-    if allow_punctuation and await self.sample(context.mask_dist(context.lm.masks.PUNCTUATION)):
+    if allow_punctuation and await self.sample(
+        context.mask_dist(context.lm.masks.PUNCTUATION)
+    ):
         punctuation_token = await self.sample(context.next_token())
         punctuation = context.lm.vocab[punctuation_token.token_id]

-    return word, punctuation
\ No newline at end of file
+    return word, punctuation
```
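The hunk above is formatting-only (long lines broken at a formatter's line length); `sample_word`'s behavior is unchanged. For readers unfamiliar with how an `@submodel` is consumed, here is a minimal usage sketch. It assumes the hfppl-style `Model`/`LMContext`/`CachedCausalLM` API that `sample_word` appears to be written against; the `WordByWord` class, the prompt, the stopping rule, and the import path for `sample_word` are all hypothetical.

```python
# Hypothetical usage sketch -- not part of this commit. Assumes an
# hfppl-style API (Model, LMContext, CachedCausalLM, Model.call);
# WordByWord and its stopping rule are invented for illustration.
from hfppl import CachedCausalLM, LMContext, Model
from hfppl.chunks import sample_word  # assumed module path for the submodel above


class WordByWord(Model):
    """Generate text one whole word at a time."""

    def __init__(self, lm, prompt, max_words=10):
        super().__init__()
        self.context = LMContext(lm, prompt)
        self.max_words = max_words
        self.num_words = 0

    async def step(self):
        # Run the sample_word submodel against this model's LM context.
        word, punctuation = await self.call(sample_word(self.context))
        self.num_words += 1
        # Stop at sentence-ending punctuation or after max_words words.
        if punctuation in {".", "!", "?"} or self.num_words >= self.max_words:
            self.finish()


# lm = CachedCausalLM.from_pretrained("gpt2")  # assumed constructor
# model = WordByWord(lm, "The weather today is")
```

Inside `sample_word`, the `observe` and `sample` calls on `mask_dist` are what enforce word boundaries, so a caller like this receives whole words regardless of how the LM happens to tokenize them.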
```diff
@@ -1,29 +1,28 @@
 class Distribution:
     """Abstract base class for a distribution."""

-
     async def sample(self):
         """Generate a random sample from the distribution.
         Returns:
             x: a value randomly sampled from the distribution."""
         raise NotImplementedError()

     async def log_prob(self, x):
         """Compute the log probability of a value under this distribution,
         or the log probability density if the distribution is continuous.
         Args:
             x: the point at which to evaluate the log probability.
         Returns:
-        logprob (float): the log probability of `x`."""
+            logprob (float): the log probability of `x`."""
         raise NotImplementedError()

     async def argmax(self, n):
         """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).
         Args:
             n (int): which value to return to, indexed from most probable (n=0) to least probable (n=|support|).
         Returns:
             x: the nth most probable outcome from this distribution."""
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
```
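`Distribution` only fixes the interface; concrete subclasses supply the three coroutines. As a minimal sketch of what an implementation looks like (written from scratch for illustration, not taken from this repository):

```python
# Illustrative sketch only -- not code from this repository.
# Subclasses the Distribution ABC shown in the diff above.
import math
import random


class Bernoulli(Distribution):
    """A coin flip over {True, False} with P(True) = p, for p in (0, 1)."""

    def __init__(self, p):
        self.p = p

    async def sample(self):
        return random.random() < self.p

    async def log_prob(self, x):
        return math.log(self.p) if x else math.log(1.0 - self.p)

    async def argmax(self, n):
        # n=0 is the more probable outcome, n=1 the less probable one.
        more_likely = self.p >= 0.5
        return more_likely if n == 0 else not more_likely
```

The interface methods are declared `async` presumably so that LM-backed distributions can await batched GPU or network work inside `sample` and `log_prob`; a toy subclass like this one simply never suspends.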
Comment on 7d71e06:
Good, I was going to suggest that we apply this to the whole repo.