
remove unneeded torch dependencies
slundberg committed Nov 20, 2023
1 parent 39b944c commit fe07b17
Showing 6 changed files with 11 additions and 63 deletions.
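
For reference, the torch operations removed by this commit all have direct NumPy/SciPy counterparts, which is what makes the dependency unneeded. The snippet below is a standalone sketch of that mapping, not code from the repository; the vocabulary size and temperature are made-up values.

import numpy as np
import scipy.special

vocab_size = 32000                      # made-up vocabulary size
logits = np.random.randn(vocab_size)    # stand-in for the numpy logits _get_logits now returns

# torch.nn.functional.log_softmax(x, dim=-1)  ->  scipy.special.log_softmax(x, axis=-1)
log_probs = scipy.special.log_softmax(logits, axis=-1)

# torch.argsort(-x)  ->  np.argsort(-x)   (greedy, temperature-0 ordering)
greedy_order = np.argsort(-logits)

# torch.nn.functional.softmax(x / t, dim=-1)  ->  scipy.special.softmax(x / t, axis=-1)
probs = scipy.special.softmax(logits / 0.7, axis=-1)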
10 changes: 0 additions & 10 deletions guidance/models/_llama_cpp.py
@@ -9,13 +9,6 @@
from ._local import Local
from .._utils import normalize_notebook_stdout_stderr

try:
# TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
import torch
is_torch = True
except ImportError:
is_torch = False

try:
import llama_cpp
is_llama_cpp = True
@@ -24,8 +17,6 @@

class LlamaCpp(Local):
def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, **kwargs):
if not is_torch:
raise Exception("Please install PyTorch in order to use guidance.models.LlamaCpp!")

if not is_llama_cpp:
raise Exception("Please install llama-cpp-python with `pip install llama-cpp-python` in order to use guidance.models.LlamaCpp!")
@@ -139,7 +130,6 @@ def _get_logits(self, token_ids, forced_bytes):
finally:
llama_cpp.llama_batch_free(batch)

logits = torch.from_numpy(logits)
self._cache_state["logits"] = logits

return logits
33 changes: 8 additions & 25 deletions guidance/models/_local.py
@@ -1,14 +1,10 @@
try:
import torch
except ImportError:
pass
import scipy.special
import scipy.stats
import numpy as np
from .._utils import ByteTrie
from ._model import Model
# from ..library._string import string
from .._parser import EarleyCommitParser
from .._grammar import Terminal, select
# import numba
from .._grammar import Terminal

class Local(Model):
def __init__(self, tokens, bos_token_id, eos_token_id=None, echo=True):
@@ -33,20 +29,7 @@ def _get_logits(self, token_ids, forced_bytes):
'''A fake method designed to be overridden by subclasses.'''

# pretend to extend the KV cache and update the log probs
return torch.randn(len(self.tokens))

# def _longest_token_match(self, bytes):
# '''Greedy token matching.'''
# trie_pos = self._token_trie
# for i,c in enumerate(bytes):
# if c in trie_pos.children:
# trie_pos = trie_pos.children[c]
# else:
# return bytes[:i], trie_pos.value # note that if there are redundant tokens we choose the one stored in the trie
# if len(trie_pos.children) == 0:
# return bytes[:i+1], trie_pos.value
# else:
# return None,None # more than one token can match these bytes
return np.random.randn(len(self.tokens))

def _tokenize_prefix(self, byte_string):
'''This is used to speed up the tokenization of long prompts without using the parser.'''
@@ -237,17 +220,17 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
# if requested we compute the log probabilities so we can track the probabilities of each node
# TODO: we should lower this step to C++ with pybind11
if log_probs:
_compute_log_probs(trie, torch.nn.functional.log_softmax(logits, dim=-1).cpu().numpy())
_compute_log_probs(trie, scipy.special.log_softmax(logits, axis=-1))

# get the sampling order
grammar_temp = parser.next_byte_temperature()
current_temp = grammar_temp if grammar_temp >= 0 else temperature # we prefer to use the grammar temp when it is specified
if current_temp == 0:
sampling_order = torch.argsort(-logits, descending=False).cpu().numpy() # we need numpy so the enumerate below does not get really slow...
sampling_order = np.argsort(-logits) # we need numpy so the enumerate below does not get really slow...
else:
assert top_p == 1, "Still need to add support for top_p!"
probs = torch.nn.functional.softmax(logits / current_temp, dim=-1)
sampling_order = torch.multinomial(probs, len(probs)).cpu().numpy()
probs = scipy.special.softmax(logits / current_temp, axis=-1)
sampling_order = np.random.choice(len(probs), size=len(probs), replace=False, p=probs)

# loop over the tokens looking for a valid one
for i,sampled_token_ind in enumerate(sampling_order):
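One subtlety in the hunk above: torch.multinomial(probs, len(probs)) draws without replacement, so it yields a probability-weighted permutation of the whole vocabulary, and the NumPy replacement needs replace=False to keep that behavior. A minimal standalone sketch (toy vocabulary size, not code from the repository):

import numpy as np
import scipy.special

rng = np.random.default_rng(0)
logits = rng.standard_normal(100)   # made-up toy vocabulary of 100 tokens
probs = scipy.special.softmax(logits, axis=-1)

# Weighted permutation: every token index appears exactly once,
# and higher-probability tokens tend to appear earlier.
sampling_order = rng.choice(len(probs), size=len(probs), replace=False, p=probs)
assert set(sampling_order) == set(range(len(probs)))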
12 changes: 1 addition & 11 deletions guidance/models/_local_mock.py
@@ -1,15 +1,5 @@
import re
import functools
import time
import collections
import os
import numpy as np

try:
import torch
except ImportError:
pass

from ._model import Chat
from ._local import Local

@@ -69,7 +59,7 @@ def _get_logits(self, token_ids, forced_bytes):
logits[i] += bias
bias /= 2 # if we have multiple matches then they apply with decreasing bias

return torch.tensor(logits)
return logits

def _get_next_tokens(self, byte_string):
for i,t in enumerate(self.tokens):
8 changes: 0 additions & 8 deletions guidance/models/_openai.py
@@ -13,14 +13,6 @@
from ._model import Chat, Instruct
from ._remote import Remote


try:
# TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
import torch
is_torch = True
except ImportError:
is_torch = False

try:
# TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
import openai as openai_package
9 changes: 1 addition & 8 deletions guidance/models/_remote.py
@@ -14,13 +14,6 @@
from ._local import Local


try:
# TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
import torch
is_torch = True
except ImportError:
is_torch = False

# try:
# # TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
# import vertexai
@@ -190,7 +183,7 @@ def _get_logits(self, token_ids, forced_bytes):
logits = np.ones(len(self.tokens)) * np.nan
logits[token_id] = 100

return torch.tensor(logits) # TODO: the caller needs to know to NOT use the 0 entries, but to fail out if that happens
return logits

def _get_next_token(self, pos, allow_early_stop=False):
data = self._shared_state["data"]
2 changes: 1 addition & 1 deletion guidance/models/transformers/_transformers.py
@@ -124,7 +124,7 @@ def _get_logits(self, token_ids, forced_bytes):
# save the results
self._cache_state["past_key_values"] = model_out.past_key_values
cache_token_ids.extend(new_token_ids)
self._cache_state["logits"] = model_out.logits[0, -1, :]
self._cache_state["logits"] = model_out.logits[0, -1, :].cpu().numpy()

return self._cache_state["logits"]

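The transformers backend keeps torch but now converts the final-position logits to NumPy before caching them. A minimal sketch of that conversion pattern, using a made-up tensor in place of model_out.logits:

import torch

# Stand-in for model_out.logits with shape (batch, seq_len, vocab); the values are made up.
fake_logits = torch.randn(1, 8, 32000)

# Same pattern as the diff: take the last position, move it to the CPU, convert to numpy.
# .detach() is added here only as a precaution in case gradients are being tracked.
last_logits = fake_logits[0, -1, :].detach().cpu().numpy()
print(last_logits.shape, last_logits.dtype)   # (32000,) float32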
