Skip to content

Commit

Permalink
use update(T, arg1, arg2, ...) for points-to effect
Browse files Browse the repository at this point in the history
also update(T, *arg) for copy points to

Signed-off-by: Elazar Gershuni <[email protected]>
  • Loading branch information
elazarg committed Dec 2, 2024
1 parent d4d3c05 commit 2ce3c8f
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 76 deletions.
45 changes: 35 additions & 10 deletions pythia/dom_typed_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,10 @@ def expr(
func_type = predefined(var)
else:
assert False, f"Expected Var or PredefinedFunction, got {var}"
if isinstance(
func_type, ts.Instantiation
) and func_type.generic == ts.Ref("builtins.type"):
if (
isinstance(func_type, ts.Instantiation)
and func_type.generic == ts.TYPE
):
func_type = ts.get_init_func(func_type)
assert isinstance(
func_type, ts.Overloaded
Expand All @@ -703,7 +704,7 @@ def expr(

side_effect = ts.get_side_effect(applied)
dirty = make_dirty()
if side_effect.update is not None:
if side_effect.update[0] is not None:
func_obj = pythia.dom_concrete.Set[Object].squeeze(func_objects)
if isinstance(func_obj, pythia.dom_concrete.Set):
raise RuntimeError(
Expand Down Expand Up @@ -732,15 +733,39 @@ def expr(
]
# Expected two objects: self argument and locals

if new_tp.types[self_obj] != side_effect.update:
if True or new_tp.types[self_obj] != side_effect.update[0]:
if monomorophized:
raise RuntimeError(
f"Update with aliased objects: {aliasing_pointers} (not: {func_obj, LOCALS})"
)
new_tp.types[self_obj] = side_effect.update
if side_effect.name == "append":
x = arg_objects[0]
new_tp.pointers.update(self_obj, tac.Var("*"), x)
new_tp.types[self_obj] = side_effect.update[0]
arg_indices_to_point = side_effect.update[1]
if arg_indices_to_point:
for i in arg_indices_to_point:
starred = False
if isinstance(i, ts.Star):
assert len(i.items) == 1
i = i.items[0]
starred = True

if isinstance(i, ts.Literal) and isinstance(
i.value, int
):
# TODO: minus one only for self. Should be fixed on binding
v = i.value - 1
assert v < len(
arg_objects
), f"{v} >= {len(arg_objects)}"
targets = arg_objects[v]
if starred:
targets = prev_tp.pointers[
targets, tac.Var("*")
]
new_tp.pointers.update(
self_obj, tac.Var("*"), targets
)
else:
assert False, i

t = ts.get_return(applied)
assert t != ts.BOTTOM, f"Expected non-bottom return type for {locals()}"
Expand Down Expand Up @@ -792,7 +817,7 @@ def expr(
assert isinstance(applied, ts.Overloaded)
side_effect = ts.get_side_effect(applied)
dirty = make_dirty()
if side_effect.update is not None:
if side_effect.update[0] is not None:
dirty = make_dirty_from_keys(
value_objects, pythia.dom_concrete.Set[tac.Var].top()
)
Expand Down
102 changes: 56 additions & 46 deletions pythia/type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ def __repr__(self) -> str:
return self.name


ANY = Ref("typing.Any")
LIST = Ref("builtins.list")
TUPLE = Ref("builtins.tuple")
SET = Ref("builtins.set")
TYPE = Ref("builtins.type")
NONE_TYPE = Ref("builtins.NoneType")
BOOL = Ref("builtins.bool")
INT = Ref("builtins.int")
FLOAT = Ref("builtins.float")
STR = Ref("builtins.str")


@dataclass(frozen=True, slots=True)
class TypeVar:
name: str
Expand Down Expand Up @@ -132,25 +144,28 @@ def literal(value: int | str | bool | float | tuple | list | None) -> Literal:
case value if value is NULL:
ref = Ref("builtins.ellipsis")
case int():
ref = Ref("builtins.int")
ref = INT
case float():
ref = Ref("builtins.float")
ref = FLOAT
case str():
ref = Ref("builtins.str")
ref = STR
case bool():
ref = Ref("builtins.bool")
ref = BOOL
case None:
ref = Ref("builtins.NoneType")
ref = NONE_TYPE
case tuple():
ref = Ref("builtins.tuple")
ref = TUPLE
case list():
value = tuple(value)
ref = Ref("builtins.list")
ref = LIST
case _:
assert False, f"Unknown literal type {value!r}"
return Literal(value, ref)


NONE = literal(None)


@dataclass(frozen=True, slots=True)
class TypedDict:
items: frozenset[Row]
Expand Down Expand Up @@ -227,9 +242,8 @@ def __repr__(self) -> str:
class SideEffect:
new: bool
bound_method: bool = False
update: typing.Optional[TypeExpr] = None
update: tuple[typing.Optional[TypeExpr], tuple[int, ...]] = (None, ())
points_to_args: bool = False
name: typing.Optional[str] = None # ad hoc effects


@dataclass(frozen=True, slots=True)
Expand All @@ -248,7 +262,7 @@ def __repr__(self) -> str:
new = "new " if self.new() else ""
update = (
"{update " + str(self.side_effect.update) + "}@"
if self.side_effect.update
if self.side_effect.update[0]
else ""
)
return f"[{type_params}]({self.params} -> {update}{new}{self.return_type})"
Expand Down Expand Up @@ -340,9 +354,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
case Row() as row:
return replace(row, type=bind_typevars(row.type, context))
case SideEffect() as s:
if s.update is None:
if s.update[0] is None:
return s
return replace(s, update=bind_typevars(s.update, context))
return replace(s, update=(bind_typevars(s.update[0], context), s.update[1]))
case Class(
type_params=type_params, class_dict=class_dict, inherits=inherits
) as klass:
Expand Down Expand Up @@ -372,9 +386,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
return choices.items[actual_arg.value]
return Access(choices, actual_arg)
case SideEffect() as s:
if s.update is None:
if s.update[0] is None:
return s
return replace(s, update=bind_typevars(s.update, context))
return replace(s, update=(bind_typevars(s.update[0], context), s.update[1]))
raise NotImplementedError(f"{t!r}, {type(t)}")


Expand Down Expand Up @@ -414,7 +428,6 @@ def union(items: typing.Iterable[TypeExpr], squeeze=True) -> TypeExpr:

TOP = typed_dict([])
BOTTOM = Union(frozenset())
ANY = Ref("typing.Any")


def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
Expand Down Expand Up @@ -461,7 +474,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
tuple(join(t1, t2) for t1, t2 in zip(l1.value, l2.value)),
l1.ref,
)
if l1.ref == Ref("builtins.list"):
if l1.ref == LIST:
return Instantiation(
l1.ref, (join_all([*l1.value, *l2.value]),)
)
Expand All @@ -477,7 +490,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
Instantiation() as inst,
Literal(tuple() as value, ref=ref),
) if ref == inst.generic:
if ref.name == "builtins.list":
if ref == LIST:
value = (join_all(value),)
return join(inst, Instantiation(ref, value))
case (TypedDict(items1), TypedDict(items2)): # type: ignore
Expand All @@ -497,21 +510,17 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
if index1 == index2:
return Row(index1, join(t1, t2))
return BOTTOM
case (Row(_, t1), (Instantiation(Ref("builtins.list"), type_args))) | (
Instantiation(Ref("builtins.list"), type_args),
case (Row(_, t1), (Instantiation(Ref("builtings.list"), type_args))) | (
Instantiation(Ref("builtings.list"), type_args),
Row(_, t1),
):
return Instantiation(
Ref("builtins.list"), tuple(join(t1, t) for t in type_args)
)
return Instantiation(LIST, tuple(join(t1, t) for t in type_args))
case (Row(_, t1), (Instantiation(Ref("builtins.tuple"), type_args))) | (
Instantiation(Ref("builtins.tuple"), type_args),
Row(_, t1),
):
# not exact; should only join at the index of the row
return Instantiation(
Ref("builtins.tuple"), tuple(join(t1, t) for t in type_args)
)
return Instantiation(TUPLE, tuple(join(t1, t) for t in type_args))
case Class(), Class():
return TOP
case (Class(name="int") | Ref("builtins.int") as c, Literal(int())) | (
Expand All @@ -520,10 +529,11 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
):
return c
case (SideEffect() as s1, SideEffect() as s2):
assert s1.update[1] == s2.update[1]
return SideEffect(
new=s1.new | s2.new,
bound_method=s1.bound_method | s2.bound_method,
update=join(s1.update, s2.update),
update=(join(s1.update[0], s2.update[0]), s1.update[1]),
points_to_args=s1.points_to_args | s2.points_to_args,
)
case x, y:
Expand Down Expand Up @@ -556,7 +566,7 @@ def meet(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
tuple(meet(t1, t2) for t1, t2 in zip(l1.value, l2.value)),
l1.ref,
)
if l1.ref == Ref("builtins.list"):
if l1.ref == LIST:
return Instantiation(
l1.ref, (meet_all([*l1.value, *l2.value]),)
)
Expand Down Expand Up @@ -1054,13 +1064,13 @@ def get_init_func(callable: TypeExpr) -> TypeExpr:
]
)
side_effect = SideEffect(
new=True, bound_method=True, update=None, points_to_args=True
new=True, bound_method=True, update=(None, ()), points_to_args=True
)
res = bind_self(
overload(
replace(
f,
return_type=f.side_effect.update or selftype,
return_type=f.side_effect.update[0] or selftype,
side_effect=side_effect,
)
for f in init.items
Expand Down Expand Up @@ -1268,7 +1278,7 @@ def make_list_constructor() -> Overloaded:
FunctionType(
params=typed_dict([make_row(0, "args", args)]),
return_type=return_type,
side_effect=SideEffect(new=True, points_to_args=True, name="[]"),
side_effect=SideEffect(new=True, points_to_args=True),
is_property=False,
type_params=(args,),
)
Expand All @@ -1277,13 +1287,13 @@ def make_list_constructor() -> Overloaded:


def make_set_constructor() -> Overloaded:
return_type = Instantiation(Ref("builtins.set"), (union([]),))
return_type = Instantiation(SET, (union([]),))
return overload(
[
FunctionType(
params=typed_dict([]),
return_type=return_type,
side_effect=SideEffect(new=True, points_to_args=True, name="{}"),
side_effect=SideEffect(new=True, points_to_args=True),
is_property=False,
type_params=(),
)
Expand All @@ -1293,13 +1303,13 @@ def make_set_constructor() -> Overloaded:

def make_tuple_constructor() -> Overloaded:
args = TypeVar("Args", is_args=True)
return_type = Instantiation(Ref("builtins.tuple"), (args,))
return_type = Instantiation(TUPLE, (args,))
return overload(
[
FunctionType(
params=typed_dict([make_row(0, "args", args)]),
return_type=return_type,
side_effect=SideEffect(new=True, points_to_args=True, name="()"),
side_effect=SideEffect(new=True, points_to_args=True),
is_property=False,
type_params=(args,),
)
Expand All @@ -1309,8 +1319,6 @@ def make_tuple_constructor() -> Overloaded:

def make_slice_constructor() -> Overloaded:
return_type = Ref("builtins.slice")
NONE = literal(None)
INT = Ref("builtins.int")
both = union([NONE, INT])
return overload(
[
Expand All @@ -1319,7 +1327,7 @@ def make_slice_constructor() -> Overloaded:
[make_row(0, "start", both), make_row(1, "end", both)]
),
return_type=return_type,
side_effect=SideEffect(new=True, name="[:]"),
side_effect=SideEffect(new=True),
is_property=False,
type_params=(),
)
Expand Down Expand Up @@ -1619,8 +1627,9 @@ def visit_Name(self, name) -> TypeExpr:
return self.symtable.lookup(name.id)

def visit_Starred(self, starred: ast.Starred) -> TypeExpr:
assert isinstance(starred.value, ast.Name), f"{starred!r}"
return TypeVar(starred.value.id, is_args=True)
if isinstance(starred.value, ast.Name):
return TypeVar(starred.value.id, is_args=True)
return Star((self.to_type(starred.value),))

def visit_Subscript(self, subscr: ast.Subscript) -> TypeExpr:
generic = self.to_type(subscr.value)
Expand Down Expand Up @@ -1677,7 +1686,7 @@ def is_immutable(value: TypeExpr) -> bool:
case Row(type=value):
return is_immutable(value)
case FunctionType() as f:
if f.side_effect.update is not None:
if f.side_effect.update[0] is not None:
return False
return True
case Ref(name):
Expand Down Expand Up @@ -1805,21 +1814,21 @@ def visit_FunctionDef(self, fdef: ast.FunctionDef) -> FunctionType:
update = call_decorators.get("update")
if update is not None:
assert isinstance(update, ast.Call)
assert len(update.args) == 1
update_arg = update.args[0]
if isinstance(update_arg, ast.Constant) and isinstance(
update_arg.value, str
):
update_arg = ast.parse(update_arg.s).body[0].value
update_type = self.expr_to_type(update_arg)
update_args = tuple(self.expr_to_type(x) for x in update.args[1:])
else:
update_type = None
update_args = ()
# side_effect = parse_side_effect(fdef.body)
side_effect = SideEffect(
new="new" in name_decorators and not is_immutable(returns),
update=update_type,
update=(update_type, update_args),
points_to_args="points_to_args" in name_decorators,
name=fdef.name,
)
is_property = "property" in name_decorators

Expand Down Expand Up @@ -1946,13 +1955,14 @@ def is_bound_method(t: TypeExpr) -> bool:


def get_side_effect(applied: Overloaded) -> SideEffect:
[name] = {x.side_effect.name for x in applied.items}
return SideEffect(
new=any(x.side_effect.new for x in applied.items),
update=join_all(x.side_effect.update for x in applied.items),
update=(
join_all(x.side_effect.update[0] for x in applied.items),
applied.items[0].side_effect.update[1],
),
bound_method=any(is_bound_method(x) for x in applied.items),
points_to_args=any(x.side_effect.points_to_args for x in applied.items),
name=name,
)


Expand Down
Loading

0 comments on commit 2ce3c8f

Please sign in to comment.