Skip to content

Commit

Permalink
Remove unnecessary alias for building_blocks in `transformations_te…
Browse files Browse the repository at this point in the history
…st`.

PiperOrigin-RevId: 690775356
  • Loading branch information
michaelreneer authored and copybara-github committed Oct 28, 2024
1 parent f4ae37f commit fdd409e
Showing 1 changed file with 145 additions and 60 deletions.
205 changes: 145 additions & 60 deletions tensorflow_federated/python/core/impl/compiler/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,88 +67,129 @@ def test_inlines_selections(self):
int_type = computation_types.TensorType(np.int32)
structed = computation_types.StructType([int_type])
double = computation_types.StructType([structed])
bb = building_blocks
before = bb.Lambda(
before = building_blocks.Lambda(
'x',
double,
bb.Block(
building_blocks.Block(
[
('y', bb.Selection(bb.Reference('x', double), index=0)),
('z', bb.Selection(bb.Reference('y', structed), index=0)),
(
'y',
building_blocks.Selection(
building_blocks.Reference('x', double), index=0
),
),
(
'z',
building_blocks.Selection(
building_blocks.Reference('y', structed), index=0
),
),
],
bb.Reference('z', int_type),
building_blocks.Reference('z', int_type),
),
)
after = transformations.to_call_dominant(before)
expected = bb.Lambda(
expected = building_blocks.Lambda(
'x',
double,
bb.Selection(bb.Selection(bb.Reference('x', double), index=0), index=0),
building_blocks.Selection(
building_blocks.Selection(
building_blocks.Reference('x', double), index=0
),
index=0,
),
)
self.assert_compact_representations_equal(after, expected)

def test_inlines_structs(self):
int_type = computation_types.TensorType(np.int32)
structed = computation_types.StructType([int_type])
double = computation_types.StructType([structed])
bb = building_blocks
before = bb.Lambda(
before = building_blocks.Lambda(
'x',
int_type,
bb.Block(
building_blocks.Block(
[
('y', bb.Struct([building_blocks.Reference('x', int_type)])),
('z', bb.Struct([building_blocks.Reference('y', structed)])),
(
'y',
building_blocks.Struct(
[building_blocks.Reference('x', int_type)]
),
),
(
'z',
building_blocks.Struct(
[building_blocks.Reference('y', structed)]
),
),
],
bb.Reference('z', double),
building_blocks.Reference('z', double),
),
)
after = transformations.to_call_dominant(before)
expected = bb.Lambda(
'x', int_type, bb.Struct([bb.Struct([bb.Reference('x', int_type)])])
expected = building_blocks.Lambda(
'x',
int_type,
building_blocks.Struct(
[building_blocks.Struct([building_blocks.Reference('x', int_type)])]
),
)
self.assert_compact_representations_equal(after, expected)

def test_inlines_selection_from_struct(self):
int_type = computation_types.TensorType(np.int32)
bb = building_blocks
before = bb.Lambda(
before = building_blocks.Lambda(
'x',
int_type,
bb.Selection(bb.Struct([bb.Reference('x', int_type)]), index=0),
building_blocks.Selection(
building_blocks.Struct([building_blocks.Reference('x', int_type)]),
index=0,
),
)
after = transformations.to_call_dominant(before)
expected = bb.Lambda('x', int_type, bb.Reference('x', int_type))
expected = building_blocks.Lambda(
'x', int_type, building_blocks.Reference('x', int_type)
)
self.assert_compact_representations_equal(after, expected)

def test_creates_binding_for_each_call(self):
int_type = computation_types.TensorType(np.int32)
int_to_int_type = computation_types.FunctionType(int_type, int_type)
bb = building_blocks
any_proto = building_block_test_utils.create_any_proto_from_array(
np.array([1, 2, 3])
)
int_to_int_fn = bb.Data(any_proto, int_to_int_type)
before = bb.Lambda(
int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type)
before = building_blocks.Lambda(
'x',
int_type,
bb.Call(
int_to_int_fn, bb.Call(int_to_int_fn, bb.Reference('x', int_type))
building_blocks.Call(
int_to_int_fn,
building_blocks.Call(
int_to_int_fn, building_blocks.Reference('x', int_type)
),
),
)
after = transformations.to_call_dominant(before)
expected = bb.Lambda(
expected = building_blocks.Lambda(
'x',
int_type,
bb.Block(
building_blocks.Block(
[
('_var1', bb.Call(int_to_int_fn, bb.Reference('x', int_type))),
(
'_var1',
building_blocks.Call(
int_to_int_fn, building_blocks.Reference('x', int_type)
),
),
(
'_var2',
bb.Call(int_to_int_fn, bb.Reference('_var1', int_type)),
building_blocks.Call(
int_to_int_fn,
building_blocks.Reference('_var1', int_type),
),
),
],
bb.Reference('_var2', int_type),
building_blocks.Reference('_var2', int_type),
),
)
self.assert_compact_representations_equal(after, expected)
Expand All @@ -157,72 +198,112 @@ def test_evaluates_called_lambdas(self):
int_type = computation_types.TensorType(np.int32)
int_to_int_type = computation_types.FunctionType(int_type, int_type)
int_thunk_type = computation_types.FunctionType(None, int_type)
bb = building_blocks
any_proto = building_block_test_utils.create_any_proto_from_array(
np.array([1, 2, 3])
)
int_to_int_fn = bb.Data(any_proto, int_to_int_type)
int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type)

# -> (let result = ext(x) in (-> result))
# Each call of the outer lambda should create a single binding, with
# calls to the inner lambda repeatedly returning references to the binding.
higher_fn = bb.Lambda(
higher_fn = building_blocks.Lambda(
None,
None,
bb.Block(
[('result', bb.Call(int_to_int_fn, bb.Reference('x', int_type)))],
bb.Lambda(None, None, bb.Reference('result', int_type)),
building_blocks.Block(
[(
'result',
building_blocks.Call(
int_to_int_fn, building_blocks.Reference('x', int_type)
),
)],
building_blocks.Lambda(
None, None, building_blocks.Reference('result', int_type)
),
),
)
block_locals = [
('fn', higher_fn),
# fn = -> (let result = ext(x) in (-> result))
('get_val1', bb.Call(bb.Reference('fn', higher_fn.type_signature))),
(
'get_val1',
building_blocks.Call(
building_blocks.Reference('fn', higher_fn.type_signature)
),
),
# _var2 = ext(x)
# get_val1 = -> _var2
('get_val2', bb.Call(bb.Reference('fn', higher_fn.type_signature))),
(
'get_val2',
building_blocks.Call(
building_blocks.Reference('fn', higher_fn.type_signature)
),
),
# _var3 = ext(x)
# get_val2 = -> _var3
('val11', bb.Call(bb.Reference('get_val1', int_thunk_type))),
(
'val11',
building_blocks.Call(
building_blocks.Reference('get_val1', int_thunk_type)
),
),
# val11 = _var2
('val12', bb.Call(bb.Reference('get_val1', int_thunk_type))),
(
'val12',
building_blocks.Call(
building_blocks.Reference('get_val1', int_thunk_type)
),
),
# val12 = _var2
('val2', bb.Call(bb.Reference('get_val2', int_thunk_type))),
(
'val2',
building_blocks.Call(
building_blocks.Reference('get_val2', int_thunk_type)
),
),
# val2 = _var3
]
before = bb.Lambda(
before = building_blocks.Lambda(
'x',
int_type,
bb.Block(
building_blocks.Block(
block_locals,
# <_var2, _var2, _var3>
bb.Struct([
bb.Reference('val11', int_type),
bb.Reference('val12', int_type),
bb.Reference('val2', int_type),
building_blocks.Struct([
building_blocks.Reference('val11', int_type),
building_blocks.Reference('val12', int_type),
building_blocks.Reference('val2', int_type),
]),
),
)
after = transformations.to_call_dominant(before)
expected = bb.Lambda(
expected = building_blocks.Lambda(
'x',
int_type,
bb.Block(
building_blocks.Block(
[
('_var2', bb.Call(int_to_int_fn, bb.Reference('x', int_type))),
('_var3', bb.Call(int_to_int_fn, bb.Reference('x', int_type))),
(
'_var2',
building_blocks.Call(
int_to_int_fn, building_blocks.Reference('x', int_type)
),
),
(
'_var3',
building_blocks.Call(
int_to_int_fn, building_blocks.Reference('x', int_type)
),
),
],
bb.Struct([
bb.Reference('_var2', int_type),
bb.Reference('_var2', int_type),
bb.Reference('_var3', int_type),
building_blocks.Struct([
building_blocks.Reference('_var2', int_type),
building_blocks.Reference('_var2', int_type),
building_blocks.Reference('_var3', int_type),
]),
),
)
self.assert_compact_representations_equal(after, expected)

def test_creates_block_for_non_lambda(self):
bb = building_blocks
int_type = computation_types.TensorType(np.int32)
two_int_type = computation_types.StructType(
[(None, int_type), (None, int_type)]
Expand All @@ -231,14 +312,18 @@ def test_creates_block_for_non_lambda(self):
any_proto = building_block_test_utils.create_any_proto_from_array(
np.array([1, 2, 3])
)
call_ext = bb.Call(bb.Data(any_proto, get_two_int_type))
before = bb.Selection(call_ext, index=0)
call_ext = building_blocks.Call(
building_blocks.Data(any_proto, get_two_int_type)
)
before = building_blocks.Selection(call_ext, index=0)
after = transformations.to_call_dominant(before)
expected = bb.Block(
expected = building_blocks.Block(
[
('_var1', call_ext),
],
bb.Selection(bb.Reference('_var1', two_int_type), index=0),
building_blocks.Selection(
building_blocks.Reference('_var1', two_int_type), index=0
),
)
self.assert_compact_representations_equal(after, expected)

Expand Down

0 comments on commit fdd409e

Please sign in to comment.