Skip to content

Commit

Permalink
annotation and types
Browse files (browse the repository at this point in the history)
  • Loading branch information
Hudson Cooper committed Feb 14, 2024
1 parent 4a18fe8 commit 2dc12e7
Showing 1 changed file with 49 additions and 38 deletions.
87 changes: 49 additions & 38 deletions src/minml/types.py
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,11 @@
from types import UnionType, NoneType, GenericAlias
from typing import get_origin, get_args
from typing_extensions import _AnnotatedAlias
from annotated_types import GroupedMetadata
from collections.abc import Collection
from pydantic import StringConstraints, BaseModel
import guidance
from guidance import gen, select
from guidance._grammar import Select
from guidance._grammar import Null, Byte, GrammarFunction, Join, Select, string

__all__ = [
"gen_bool",
@@ -16,48 +17,62 @@
"gen_type",
]

_QUOTE = Byte(b'"')
_OPEN_BRACE = Byte(b"{")
_CLOSE_BRACE = Byte(b"}")
_OPEN_BRACKET = Byte(b"[")
_CLOSE_BRACKET = Byte(b"]")
_COMMA = Byte(b",")
_COLON = Byte(b":")

def _gen_None():
return "null"
Type = type | NoneType | UnionType | GenericAlias | _AnnotatedAlias | BaseModel


def gen_bool():
def gen_None() -> GrammarFunction:
return string("null")


def gen_bool() -> GrammarFunction:
return select(["true", "false"])


def gen_int():
def gen_int() -> GrammarFunction:
return gen(regex=r"(\+|\-)?\d+")


def gen_float():
def gen_float() -> GrammarFunction:
return gen(regex=r"(\+|\-)?(\d*\.)?\d+")


def gen_str(**kwds):
delim = '"'
return delim + gen(**kwds, stop=delim) + delim
def gen_str(**kwds) -> GrammarFunction:
return Join([_QUOTE, gen(**kwds, stop='"'), _QUOTE])


def gen_list(type):
return _gen_sequence(type, "[", "]")
def gen_list(type: Type) -> GrammarFunction:
s = Select([], capture_name=None, recursive=True)
s.values = [gen_type(type), Join([s, _COMMA, gen_type(type)])]
return _OPEN_BRACKET + select([_CLOSE_BRACKET, Join([s, _CLOSE_BRACKET])])


def gen_schema(schema: BaseModel):
template = "{"
items = schema.model_fields.items()
n = len(items)
for i, (field, field_info) in enumerate(items):
def gen_pydantic(schema: BaseModel) -> GrammarFunction:
grammar = _OPEN_BRACE
model_fields = schema.model_fields.items()
for i, (field, field_info) in enumerate(model_fields):
annotation = field_info.rebuild_annotation()
template += f'"{field}": ' + gen_type(annotation)
if i < n - 1:
template += ","
template += "}"
return template


def gen_type(type):
field_grammar = Join(
[_QUOTE, string(field), _QUOTE, _COLON, gen_type(annotation)]
)
if i == 0:
grammar = Join([grammar, field_grammar])
else:
grammar = Join([grammar, _COMMA, field_grammar])
grammar = Join([grammar, _CLOSE_BRACE])
return grammar


def gen_type(type: Type | None) -> GrammarFunction:
if (type is None) or (type is NoneType):
return _gen_None()
return gen_None()
if type is bool:
return gen_bool()
if type is int:
@@ -77,24 +92,20 @@ def gen_type(type):
types = get_args(type)
return _gen_union_type(*types)
if issubclass(type, BaseModel):
return gen_schema(type)
raise NotImplementedError("Can't gen type {type!r}")


def _gen_sequence(type, opener, closer):
s = Select([], capture_name=None, recursive=True)
s.values = [gen_type(type), s + ", " + gen_type(type)]
return opener + select([closer, s + closer])
return gen_pydantic(type)
raise NotImplementedError(f"Can't gen type {type!r}")


def _gen_generic_alias_type(origin, args):
def _gen_generic_alias_type(origin: Type, args: Collection[Type]) -> GrammarFunction:
if origin is list and len(args) == 1:
type = args[0]
return gen_list(type)
raise NotImplementedError


def _gen_annotated_type(type, annotations):
def _gen_annotated_type(
type: Type, annotations: Collection[GroupedMetadata]
) -> GrammarFunction:
if type is str:
if len(annotations) == 1 and isinstance(annotations[0], StringConstraints):
kmap = {"pattern": "regex", "max_length": "max_tokens"}
@@ -110,8 +121,8 @@ def _gen_annotated_type(type, annotations):
) from e
return gen_str(**kwds)

raise NotImplementedError("Can't gen type {type!r}")
raise NotImplementedError(f"Can't gen type {type!r}")


def _gen_union_type(*types):
def _gen_union_type(*types: Type) -> GrammarFunction:
return select([gen_type(type) for type in types])

0 comments on commit 2dc12e7

Please sign in to comment.