Skip to content

Commit

Permalink
move tuple-specific check from guard_tracker to vs.tuplevar (apache#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 authored Sep 25, 2023
2 parents db8f18b + 880de91 commit 7e1a676
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 66 deletions.
51 changes: 15 additions & 36 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect
from .instruction import Instruction, ci
from .cache import CachedGraph, get_frame_cache
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInTuple
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex
from . import variables as vs
from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack
from .object_table import ObjectTable
Expand Down Expand Up @@ -142,7 +142,9 @@ def from_frame(cls, frame: FrameType, read_stack: bool,
state.start_stack_size = get_value_stack_size(frame)
for i in range(state.start_stack_size):
value = get_value_stack_from_top(frame, i)
var = vs.make_var_from_value(value, True, state.fx_graph,
var = vs.make_var_from_value(value, True,
state.objects.read_only,
state.fx_graph,
[StoreInLocal(f"__stack__{i}")])
state.objects.add(var, value)
# state.written may be assigned inside make_var_from_value
Expand Down Expand Up @@ -320,30 +322,6 @@ def init_state(self, read_stack: bool = True) -> None:
self.state = State.from_frame(self.frame, read_stack, self.frame_root)
self.have_error = False

def variable_check(self, var: TupleVar,
extract_code_at_start: StorePos) -> None:
for i, sub_obj in enumerate(var.value):
sub_var = vs.make_var_from_value(
sub_obj, True, self.state.fx_graph,
[StoreInTuple(extract_code_at_start, i)])
self.state.add_object(sub_var, sub_obj)
if isinstance(sub_var, TupleVar):
self.variable_check(sub_var,
StoreInTuple(extract_code_at_start, i))

def variable_output(self, var: Variable, name_in_graph_fn: str,
store_pos: StorePos, codegen: "GraphFnCodegen") -> None:
if isinstance(var, TupleVar):
self.tuple_output(var)
var.make_output(name_in_graph_fn, store_pos, codegen)

def tuple_output(self, var: TupleVar) -> None:
for sub_val in var.value:
sub_obj = self.state.objects.get(sub_val, allow_unexist_const=True)
var.objs.append(sub_obj)
if isinstance(sub_obj, TupleVar):
self.tuple_output(sub_obj)

def record(
self, frame: FrameType, frame_id: int
) -> None: # pass frame and frame_id only for assertion
Expand Down Expand Up @@ -446,8 +424,7 @@ def commit(self, break_before_cur_inst: bool) -> None:

for i, value in enumerate(stack_objs):
var = self.state.objects.get(value, allow_unexist_const=True)
self.variable_output(var, f"__stack__{i}", StoreInStack(i),
graph_codegen)
var.make_output(f"__stack__{i}", StoreInStack(i), graph_codegen)

self.state.fx_graph.set_output_nodes(graph_codegen.get_graph_outputs())

Expand Down Expand Up @@ -625,11 +602,11 @@ def LOAD_CONST(self, _inst: Instruction) -> None:
def LOAD_FAST(self, inst: Instruction) -> None:
if inst.argval not in self.state.stored_locals:
obj = self.frame.f_locals[inst.argval]
var = vs.make_var_from_value(obj, True, self.state.fx_graph,
var = vs.make_var_from_value(obj, True,
self.state.objects.read_only,
self.state.fx_graph,
[StoreInLocal(inst.argval)])
self.state.add_object(var, obj)
if isinstance(var, TupleVar):
self.variable_check(var, StoreInLocal(inst.argval))

def LOAD_GLOBAL(self, inst: Instruction) -> None:
if inst.argval not in self.state.stored_globals:
Expand All @@ -642,19 +619,20 @@ def LOAD_GLOBAL(self, inst: Instruction) -> None:
except Exception as e:
raise UnknownTypeError(inst.argval)

var = vs.make_var_from_value(obj, True, self.state.fx_graph,
var = vs.make_var_from_value(obj, True,
self.state.objects.read_only,
self.state.fx_graph,
[StoreInGlobal(inst.argval)])
self.state.add_object(var, obj)
if isinstance(var, TupleVar):
self.variable_check(var, StoreInGlobal(inst.argval))

# heheda: we need to make sure that no unbound LOAD_METHOD is called by python runtime to avoid NULL in stack
def LOAD_METHOD(self, inst: Instruction) -> None:
self_obj = get_value_stack_from_top(self.frame, 0)
method = getattr(self_obj, inst.argval)
self_var = self.state.objects.get(self_obj)
method_var = vs.make_var_from_value(
method, self_var.need_guard_check, self.state.fx_graph, [
method, self_var.need_guard_check, self.state.objects.read_only,
self.state.fx_graph, [
StoreInAttr(self_var.extract_code_at_start[0], self_obj,
inst.argval)
] if self_var.need_guard_check else [])
Expand All @@ -668,7 +646,8 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
attr = getattr(obj, inst.argval)
obj_var = self.state.objects.get(obj)
attr_var = vs.make_var_from_value(
attr, obj_var.need_guard_check, self.state.fx_graph,
attr, obj_var.need_guard_check, self.state.objects.read_only,
self.state.fx_graph,
[StoreInAttr(obj_var.extract_code_at_start[0], obj, inst.argval)]
if obj_var.need_guard_check else [])
if isinstance(obj_var, vs.ModuleVar):
Expand Down
60 changes: 56 additions & 4 deletions frontend/object_table.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Any, get_args, Optional, Tuple
from typing import Any, get_args, Optional, Tuple, Generic
from .variables.base import Variable
from .variables import CONST_TYPES, ScalarVar, make_var_from_value
from .variables.tuple_ import TupleVar
from .utils import NullObject
from .utils import NullObject, ReadOnlyObject
from .store_pos import StorePos
from .fx_graph import FxGraph


class ObjectTable:
objs: dict[int, Variable] # id -> object
# Python caches small integers, so int variables don't have unique ids
objs_no_id: list[Variable]
read_only: 'ReadOnlyObjectTable'

def __init__(self) -> None:
self.objs = {}
self.objs_no_id = []
self.read_only = ReadOnlyObjectTable(self)

def add(self, var: Variable, value: Any) -> None:
if isinstance(value, bool):
Expand All @@ -23,9 +27,11 @@ def add(self, var: Variable, value: Any) -> None:
old_var.need_guard_check |= var.need_guard_check
else:
self.objs[id(value)] = var
var.add_subvars_to_table(self)

def add_by_id(self, var: Variable, idx: int) -> None:
self.objs[idx] = var
var.add_subvars_to_table(self)

def get_all(self) -> list[Variable]:
return list(self.objs.values()) + self.objs_no_id
Expand All @@ -36,9 +42,9 @@ def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
elif id(value) in self.objs:
return self.objs[id(value)]
elif allow_unexist_const and isinstance(value, get_args(CONST_TYPES)):
return make_var_from_value(value, False)
return make_var_from_value(value, False, self.read_only)
elif isinstance(value, tuple):
return TupleVar(value, False)
return TupleVar(value, False, self.read_only)
raise RuntimeError(f"Object {value} not found in object table")

def get_or_none(self, value: Any) -> Optional[Variable]:
Expand All @@ -47,6 +53,19 @@ def get_or_none(self, value: Any) -> Optional[Variable]:
else:
return None

def get_or_make_var(self,
value: Any,
need_guard_check: bool,
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> Variable:
if isinstance(value, bool):
return ScalarVar(value, need_guard_check, extract_code_at_start)
elif id(value) in self.objs:
return self.objs[id(value)]
else:
return make_var_from_value(value, need_guard_check, self.read_only,
fx_graph, extract_code_at_start)

def get_by_id(self, idx: int) -> Variable:
return self.objs[idx]

Expand All @@ -55,3 +74,36 @@ def contains(self, value: Any) -> bool:

def contains_by_id(self, idx: int) -> bool:
return idx in self.objs


class ReadOnlyObjectTable:
table: ObjectTable

def __init__(self, table: ObjectTable) -> None:
self.table = table

def get_all(self) -> list[Variable]:
return self.table.get_all()

def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
return self.table.get(value, allow_unexist_const)

def get_or_none(self, value: Any) -> Optional[Variable]:
return self.table.get_or_none(value)

def get_or_make_var(self,
value: Any,
need_guard_check: bool,
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> Variable:
return self.table.get_or_make_var(value, need_guard_check, fx_graph,
extract_code_at_start)

def get_by_id(self, idx: int) -> Variable:
return self.table.get_by_id(idx)

def contains(self, value: Any) -> bool:
return self.table.contains(value)

def contains_by_id(self, idx: int) -> bool:
return self.table.contains_by_id(idx)
6 changes: 3 additions & 3 deletions frontend/store_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def __str__(self) -> str:
return f"{self.self_pos}.{self.attr_name}"


class StoreInTuple(StorePos):
class StoreInIndex(StorePos):
self_pos: StorePos
self_idx: int
self_idx: Any

def __init__(self, self_pos: StorePos, self_idx: int) -> None:
def __init__(self, self_pos: StorePos, self_idx: Any) -> None:
self.self_pos = self_pos
self.self_idx = self_idx

Expand Down
22 changes: 21 additions & 1 deletion frontend/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import dis
from typing import Any, TYPE_CHECKING, Callable
from typing import Any, TYPE_CHECKING, Callable, TypeVar, Generic
from types import FrameType
import random
import operator
Expand Down Expand Up @@ -184,3 +184,23 @@ def __enter__(self) -> None:
def __exit__(self, *args: Any) -> None:
if self.old_ld_preload:
os.environ['LD_PRELOAD'] = self.old_ld_preload


T = TypeVar('T')


class ReadOnlyObject(Generic[T]):
obj: T
const_attrs: tuple[str, ...]

def __init__(self, obj: T, const_attrs: tuple[str, ...] = ()) -> None:
self.obj = obj
self.const_attrs = const_attrs

def __getattr__(self, attr: str) -> Any:
if attr in self.const_attrs:
return getattr(self.obj, attr)
else:
raise AttributeError(
f"Attribute {attr} should not be called in reader of {self.obj}"
)
23 changes: 12 additions & 11 deletions frontend/variables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union, Optional, Tuple
from typing import Any, Union, Optional, Tuple, TYPE_CHECKING
from types import ModuleType
import torch
from .base import Variable
Expand All @@ -10,6 +10,8 @@
from ..fx_graph import FxGraph
from ..utils import NullObject, UnknownTypeError
from ..store_pos import StorePos
if TYPE_CHECKING:
from ..object_table import ReadOnlyObjectTable

ty2var: dict[type[Any], type[Variable]] = {
float: ScalarVar,
Expand All @@ -27,23 +29,22 @@

def make_var_from_value(value: Any,
need_guard_check: bool,
object_table: 'ReadOnlyObjectTable',
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> Variable:
if type(value) in ty2var:
return ty2var[type(value)].from_value(value, need_guard_check, fx_graph,
return ty2var[type(value)].from_value(value, need_guard_check,
object_table, fx_graph,
extract_code_at_start)
elif isinstance(value, torch.nn.Module):
return TorchModuleVar.from_value(value, need_guard_check, fx_graph,
extract_code_at_start)
return TorchModuleVar.from_value(value, need_guard_check, object_table,
fx_graph, extract_code_at_start)
elif isinstance(value, ModuleType):
return ModuleVar.from_value(value, need_guard_check, fx_graph,
extract_code_at_start)
return ModuleVar.from_value(value, need_guard_check, object_table,
fx_graph, extract_code_at_start)
elif callable(value):
return FunctionVar.from_value(value, need_guard_check, fx_graph,
extract_code_at_start)
elif isinstance(value, tuple):
return TupleVar.from_value(value, need_guard_check, fx_graph,
extract_code_at_start)
return FunctionVar.from_value(value, need_guard_check, object_table,
fx_graph, extract_code_at_start)
raise UnknownTypeError(type(value))


Expand Down
5 changes: 5 additions & 0 deletions frontend/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.fx
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
from ..fx_graph import FxGraph, NodeArgs
from ..object_table import ReadOnlyObjectTable, ObjectTable


@dataclass
Expand All @@ -27,6 +28,7 @@ def __init__(self,
def from_value(self,
value: Any,
need_guard_check: bool,
object_table: 'ReadOnlyObjectTable',
fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> 'Variable':
raise NotImplementedError
Expand Down Expand Up @@ -55,3 +57,6 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
@abstractmethod
def as_fx_node(self) -> "NodeArgs":
raise NotImplementedError

def add_subvars_to_table(self, table: 'ObjectTable') -> None:
pass
6 changes: 6 additions & 0 deletions frontend/variables/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..store_pos import StorePos
if TYPE_CHECKING:
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
from ..object_table import ReadOnlyObjectTable


class NoneVar(Variable):
Expand All @@ -37,6 +38,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: None,
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "NoneVar":
return cls(need_guard_check, extract_code_at_start)
Expand Down Expand Up @@ -69,6 +71,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: NullObject,
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "NullVar":
return cls(need_guard_check, extract_code_at_start)
Expand Down Expand Up @@ -110,6 +113,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: slice,
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "SliceVar":
return cls(value.start, value.stop, value.step, need_guard_check,
Expand Down Expand Up @@ -154,6 +158,7 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: ModuleType,
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "ModuleVar":
if value in torch_modules:
Expand Down Expand Up @@ -189,6 +194,7 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: Callable[..., Any],
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "FunctionVar":
return cls(value, ObjectSrc.USER_DEFINED, need_guard_check,
Expand Down
2 changes: 2 additions & 0 deletions frontend/variables/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..store_pos import StorePos
if TYPE_CHECKING:
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
from ..object_table import ReadOnlyObjectTable

ScalarType = Union[int, float, bool, str]

Expand Down Expand Up @@ -46,6 +47,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
def from_value(cls,
value: ScalarType,
need_guard_check: bool,
_object_table: 'ReadOnlyObjectTable',
_fx_graph: Optional[FxGraph] = None,
extract_code_at_start: list[StorePos] = []) -> "ScalarVar":
return cls(value, need_guard_check, extract_code_at_start)
Expand Down
Loading

0 comments on commit 7e1a676

Please sign in to comment.