Merge pull request #18 from probcomp/gg/pre-commit
Run pre-commit on all files
alex-lew authored Jul 23, 2024
2 parents bd3bcb5 + 7351d5a commit f6efcd5
Showing 18 changed files with 75 additions and 47 deletions.
7 changes: 4 additions & 3 deletions .pre-commit-config.yaml
@@ -1,8 +1,9 @@
 repos:
-- repo: https://github.com/asottile/reorder-python-imports
-  rev: v3.13.0
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
   hooks:
-  - id: reorder-python-imports
+  - id: isort
+    args: [--profile, black, --force-single-line-imports]
 - repo: https://github.com/psf/black-pre-commit-mirror
   rev: 24.4.2
   hooks:
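The import reshuffling in the Python files below is just these two hooks applied across the repository (presumably something like pre-commit run --all-files, as the commit title suggests). As a rough sketch of what the new isort settings do (not part of the diff; the input string here is made up), the same options can be passed to isort's Python API:

import isort

# profile="black" and force_single_line=True mirror the hook's
# [--profile, black, --force-single-line-imports] arguments above.
messy = "from hfppl import Model, CachedCausalLM, LMContext, smc_standard\n"
print(isort.code(messy, profile="black", force_single_line=True))
# Expected output, one import per line, alphabetized:
#   from hfppl import CachedCausalLM
#   from hfppl import LMContext
#   from hfppl import Model
#   from hfppl import smc_standard
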
10 changes: 5 additions & 5 deletions docs/gen_reference_page.py
@@ -7,9 +7,9 @@
 nav = mkdocs_gen_files.Nav()

 for path in sorted(Path("hfppl").rglob("*.py")):
-    if any(part.startswith('.') for part in path.parts):
+    if any(part.startswith(".") for part in path.parts):
         continue

     module_path = path.relative_to(".").with_suffix("")
     doc_path = path.relative_to(".").with_suffix(".md")
     full_doc_path = Path("reference", doc_path)
@@ -22,13 +22,13 @@
     elif parts[-1] == "__main__":
         continue

-    nav[parts] = doc_path.as_posix() #
+    nav[parts] = doc_path.as_posix()  #

     with mkdocs_gen_files.open(full_doc_path, "w") as fd:
         ident = ".".join(parts)
         fd.write(f"::: {ident}")

     mkdocs_gen_files.set_edit_path(full_doc_path, path)

-with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: #
-    nav_file.writelines(nav.build_literate_nav()) #
+with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file:  #
+    nav_file.writelines(nav.build_literate_nav())  #
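The quote change above ('.' to ".") comes from black's string-quote normalization rather than from isort. A minimal sketch of the same behavior through black's Python API, black.format_str and black.Mode (the input line is made up, not taken from this repository):

import black

src = "flag = part.startswith('.')\n"
print(black.format_str(src, mode=black.Mode()), end="")
# Expected output: flag = part.startswith(".")
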
8 changes: 4 additions & 4 deletions examples/grammar_constraint.py
@@ -14,13 +14,13 @@
 import os
 from typing import List

+from synchromesh.completion_engine import LarkCompletionEngine
+from synchromesh.synchromesh import StreamingCSD
+
 from hfppl.distributions import LMContext
+from hfppl.inference import smc_standard
 from hfppl.llms import CachedCausalLM
 from hfppl.modeling import Model
-from hfppl.inference import smc_standard
-
-from synchromesh.completion_engine import LarkCompletionEngine
-from synchromesh.synchromesh import StreamingCSD


 class GrammarConstrainedSMC(Model):
10 changes: 8 additions & 2 deletions examples/haiku.py
@@ -1,8 +1,14 @@
-from hfppl import Model, CachedCausalLM, LMContext, smc_standard, sample_word
 import asyncio
-import nltk
 import os

+import nltk
+
+from hfppl import CachedCausalLM
+from hfppl import LMContext
+from hfppl import Model
+from hfppl import sample_word
+from hfppl import smc_standard
+
 # download the CMU pronunciation dictionary (if we haven't already)
 nltk.download("cmudict")

9 changes: 6 additions & 3 deletions examples/hard_constraints.py
@@ -1,8 +1,11 @@
-import string
 import asyncio
-from hfppl import Model, CachedCausalLM, LMContext, smc_standard
-
 import os
+import string
+
+from hfppl import CachedCausalLM
+from hfppl import LMContext
+from hfppl import Model
+from hfppl import smc_standard

 if "HF_AUTH_TOKEN" in os.environ:
     HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
8 changes: 4 additions & 4 deletions hfppl/__init__.py
@@ -1,9 +1,9 @@
"""Probabilistic programming with HuggingFace Transformer models.
"""

from .util import *
from .llms import *
from .chunks import *
from .distributions import *
from .modeling import *
from .inference import *
from .chunks import *
from .llms import *
from .modeling import *
from .util import *
1 change: 1 addition & 0 deletions hfppl/chunks.py
@@ -1,4 +1,5 @@
 import string
+
 from .modeling import submodel


4 changes: 2 additions & 2 deletions hfppl/distributions/__init__.py
@@ -11,10 +11,10 @@
 * `LMContext(lm: hfppl.llms.CachedCausalLM, prompt: list[int]).mask_dist(mask: set[int]) -> bool`
 """

+from .bernoulli import Bernoulli
 from .distribution import Distribution
 from .geometric import Geometric
+from .lmcontext import LMContext
 from .logcategorical import LogCategorical
 from .tokencategorical import TokenCategorical
 from .transformer import Transformer
-from .lmcontext import LMContext
-from .bernoulli import Bernoulli
4 changes: 2 additions & 2 deletions hfppl/distributions/bernoulli.py
@@ -1,7 +1,7 @@
-from .distribution import Distribution
-
 import numpy as np
+
+from .distribution import Distribution


 class Bernoulli(Distribution):
     """A Bernoulli distribution."""
3 changes: 2 additions & 1 deletion hfppl/distributions/geometric.py
@@ -1,6 +1,7 @@
-from .distribution import Distribution
 import numpy as np

+from .distribution import Distribution
+

 class Geometric(Distribution):
     """A Geometric distribution."""
15 changes: 9 additions & 6 deletions hfppl/distributions/lmcontext.py
@@ -1,9 +1,12 @@
-from ..util import log_softmax, logsumexp
-from .distribution import Distribution
-from ..llms import Token
-import numpy as np
 import copy
+
+import numpy as np
+
+from ..llms import Token
+from ..util import log_softmax
+from ..util import logsumexp
+from .distribution import Distribution


 class LMNextToken(Distribution):

@@ -157,9 +160,9 @@ def token_count(self):
     def __str__(self):
         full_string = self.lm.tokenizer.decode(self.tokens)
         if not self.show_prompt:
-            full_string = full_string[self.prompt_string_length:]
+            full_string = full_string[self.prompt_string_length :]
         if not self.show_eos and full_string.endswith(self.lm.tokenizer.eos_token):
-            full_string = full_string[:-len(self.lm.tokenizer.eos_token)]
+            full_string = full_string[: -len(self.lm.tokenizer.eos_token)]
         return full_string

     def __deepcopy__(self, memo):
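The slice rewrites above follow black's PEP 8-style rule for slices: once an operand is more than a plain name or number, the colon gets a space on that side, and the space is dropped on a side whose operand is omitted. A small sketch with made-up values (not part of the diff):

# How black spaces slice colons (illustrative values only).
s = "hello, world"
start = 3

simple = s[2:5]  # simple operands: no spaces around the colon
lower_complex = s[start + 1 :]  # complex lower bound, omitted upper: space before the colon only
upper_complex = s[: -len(", world")]  # omitted lower, complex upper: space after the colon only
print(simple, lower_complex, upper_complex)
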
5 changes: 3 additions & 2 deletions hfppl/distributions/logcategorical.py
@@ -1,7 +1,8 @@
-from .distribution import Distribution
-from ..util import log_softmax
 import numpy as np
+
+from ..util import log_softmax
+from .distribution import Distribution


 class LogCategorical(Distribution):
     """A Geometric distribution."""
7 changes: 4 additions & 3 deletions hfppl/distributions/tokencategorical.py
@@ -1,9 +1,10 @@
-from .distribution import Distribution
-from ..util import log_softmax
-from ..llms import Token
 import numpy as np
 import torch
+
+from ..llms import Token
+from ..util import log_softmax
+from .distribution import Distribution


 class TokenCategorical(Distribution):

6 changes: 4 additions & 2 deletions hfppl/distributions/transformer.py
@@ -1,7 +1,9 @@
-from .distribution import Distribution
-from ..llms import TokenSequence, Token
 import numpy as np
+
+from ..llms import Token
+from ..llms import TokenSequence
+from .distribution import Distribution


 # Transformer(lm, prompt) -- where prompt can either be a string or a list of Tokens.
 class Transformer(Distribution):
1 change: 1 addition & 0 deletions hfppl/inference/smc_record.py
@@ -1,4 +1,5 @@
 import json
+
 import numpy as np


8 changes: 5 additions & 3 deletions hfppl/inference/smc_standard.py
@@ -1,9 +1,11 @@
+import asyncio
 import copy
-from ..util import logsumexp
+from datetime import datetime
+
 import numpy as np
-import asyncio
+
+from ..util import logsumexp
 from .smc_record import SMCRecord
-from datetime import datetime


 async def smc_standard(
9 changes: 6 additions & 3 deletions hfppl/inference/smc_steer.py
@@ -1,7 +1,10 @@
-import numpy as np
-import copy
 import asyncio
-from ..util import logsumexp, softmax
+import copy
+
+import numpy as np
+
+from ..util import logsumexp
+from ..util import softmax


 def find_c(weights, N):
7 changes: 5 additions & 2 deletions hfppl/llms.py
@@ -1,10 +1,13 @@
"""Utilities for working with HuggingFace language models, including caching and auto-batching."""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import asyncio
import string

import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig


class Masks:
def __init__(self, lm):
