diff --git a/tensorflow_federated/python/core/impl/compiler/transformations_test.py b/tensorflow_federated/python/core/impl/compiler/transformations_test.py index 6370122b0d..f3ae2cbfe8 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations_test.py @@ -67,23 +67,37 @@ def test_inlines_selections(self): int_type = computation_types.TensorType(np.int32) structed = computation_types.StructType([int_type]) double = computation_types.StructType([structed]) - bb = building_blocks - before = bb.Lambda( + before = building_blocks.Lambda( 'x', double, - bb.Block( + building_blocks.Block( [ - ('y', bb.Selection(bb.Reference('x', double), index=0)), - ('z', bb.Selection(bb.Reference('y', structed), index=0)), + ( + 'y', + building_blocks.Selection( + building_blocks.Reference('x', double), index=0 + ), + ), + ( + 'z', + building_blocks.Selection( + building_blocks.Reference('y', structed), index=0 + ), + ), ], - bb.Reference('z', int_type), + building_blocks.Reference('z', int_type), ), ) after = transformations.to_call_dominant(before) - expected = bb.Lambda( + expected = building_blocks.Lambda( 'x', double, - bb.Selection(bb.Selection(bb.Reference('x', double), index=0), index=0), + building_blocks.Selection( + building_blocks.Selection( + building_blocks.Reference('x', double), index=0 + ), + index=0, + ), ) self.assert_compact_representations_equal(after, expected) @@ -91,64 +105,91 @@ def test_inlines_structs(self): int_type = computation_types.TensorType(np.int32) structed = computation_types.StructType([int_type]) double = computation_types.StructType([structed]) - bb = building_blocks - before = bb.Lambda( + before = building_blocks.Lambda( 'x', int_type, - bb.Block( + building_blocks.Block( [ - ('y', bb.Struct([building_blocks.Reference('x', int_type)])), - ('z', bb.Struct([building_blocks.Reference('y', structed)])), + ( + 'y', + building_blocks.Struct( + [building_blocks.Reference('x', int_type)] + ), + ), + ( + 'z', + building_blocks.Struct( + [building_blocks.Reference('y', structed)] + ), + ), ], - bb.Reference('z', double), + building_blocks.Reference('z', double), ), ) after = transformations.to_call_dominant(before) - expected = bb.Lambda( - 'x', int_type, bb.Struct([bb.Struct([bb.Reference('x', int_type)])]) + expected = building_blocks.Lambda( + 'x', + int_type, + building_blocks.Struct( + [building_blocks.Struct([building_blocks.Reference('x', int_type)])] + ), ) self.assert_compact_representations_equal(after, expected) def test_inlines_selection_from_struct(self): int_type = computation_types.TensorType(np.int32) - bb = building_blocks - before = bb.Lambda( + before = building_blocks.Lambda( 'x', int_type, - bb.Selection(bb.Struct([bb.Reference('x', int_type)]), index=0), + building_blocks.Selection( + building_blocks.Struct([building_blocks.Reference('x', int_type)]), + index=0, + ), ) after = transformations.to_call_dominant(before) - expected = bb.Lambda('x', int_type, bb.Reference('x', int_type)) + expected = building_blocks.Lambda( + 'x', int_type, building_blocks.Reference('x', int_type) + ) self.assert_compact_representations_equal(after, expected) def test_creates_binding_for_each_call(self): int_type = computation_types.TensorType(np.int32) int_to_int_type = computation_types.FunctionType(int_type, int_type) - bb = building_blocks any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - int_to_int_fn = bb.Data(any_proto, int_to_int_type) - before = bb.Lambda( + int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type) + before = building_blocks.Lambda( 'x', int_type, - bb.Call( - int_to_int_fn, bb.Call(int_to_int_fn, bb.Reference('x', int_type)) + building_blocks.Call( + int_to_int_fn, + building_blocks.Call( + int_to_int_fn, building_blocks.Reference('x', int_type) + ), ), ) after = transformations.to_call_dominant(before) - expected = bb.Lambda( + expected = building_blocks.Lambda( 'x', int_type, - bb.Block( + building_blocks.Block( [ - ('_var1', bb.Call(int_to_int_fn, bb.Reference('x', int_type))), + ( + '_var1', + building_blocks.Call( + int_to_int_fn, building_blocks.Reference('x', int_type) + ), + ), ( '_var2', - bb.Call(int_to_int_fn, bb.Reference('_var1', int_type)), + building_blocks.Call( + int_to_int_fn, + building_blocks.Reference('_var1', int_type), + ), ), ], - bb.Reference('_var2', int_type), + building_blocks.Reference('_var2', int_type), ), ) self.assert_compact_representations_equal(after, expected) @@ -157,72 +198,112 @@ def test_evaluates_called_lambdas(self): int_type = computation_types.TensorType(np.int32) int_to_int_type = computation_types.FunctionType(int_type, int_type) int_thunk_type = computation_types.FunctionType(None, int_type) - bb = building_blocks any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - int_to_int_fn = bb.Data(any_proto, int_to_int_type) + int_to_int_fn = building_blocks.Data(any_proto, int_to_int_type) # -> (let result = ext(x) in (-> result)) # Each call of the outer lambda should create a single binding, with # calls to the inner lambda repeatedly returning references to the binding. - higher_fn = bb.Lambda( + higher_fn = building_blocks.Lambda( None, None, - bb.Block( - [('result', bb.Call(int_to_int_fn, bb.Reference('x', int_type)))], - bb.Lambda(None, None, bb.Reference('result', int_type)), + building_blocks.Block( + [( + 'result', + building_blocks.Call( + int_to_int_fn, building_blocks.Reference('x', int_type) + ), + )], + building_blocks.Lambda( + None, None, building_blocks.Reference('result', int_type) + ), ), ) block_locals = [ ('fn', higher_fn), # fn = -> (let result = ext(x) in (-> result)) - ('get_val1', bb.Call(bb.Reference('fn', higher_fn.type_signature))), + ( + 'get_val1', + building_blocks.Call( + building_blocks.Reference('fn', higher_fn.type_signature) + ), + ), # _var2 = ext(x) # get_val1 = -> _var2 - ('get_val2', bb.Call(bb.Reference('fn', higher_fn.type_signature))), + ( + 'get_val2', + building_blocks.Call( + building_blocks.Reference('fn', higher_fn.type_signature) + ), + ), # _var3 = ext(x) # get_val2 = -> _var3 - ('val11', bb.Call(bb.Reference('get_val1', int_thunk_type))), + ( + 'val11', + building_blocks.Call( + building_blocks.Reference('get_val1', int_thunk_type) + ), + ), # val11 = _var2 - ('val12', bb.Call(bb.Reference('get_val1', int_thunk_type))), + ( + 'val12', + building_blocks.Call( + building_blocks.Reference('get_val1', int_thunk_type) + ), + ), # val12 = _var2 - ('val2', bb.Call(bb.Reference('get_val2', int_thunk_type))), + ( + 'val2', + building_blocks.Call( + building_blocks.Reference('get_val2', int_thunk_type) + ), + ), # val2 = _var3 ] - before = bb.Lambda( + before = building_blocks.Lambda( 'x', int_type, - bb.Block( + building_blocks.Block( block_locals, # <_var2, _var2, _var3> - bb.Struct([ - bb.Reference('val11', int_type), - bb.Reference('val12', int_type), - bb.Reference('val2', int_type), + building_blocks.Struct([ + building_blocks.Reference('val11', int_type), + building_blocks.Reference('val12', int_type), + building_blocks.Reference('val2', int_type), ]), ), ) after = transformations.to_call_dominant(before) - expected = bb.Lambda( + expected = building_blocks.Lambda( 'x', int_type, - bb.Block( + building_blocks.Block( [ - ('_var2', bb.Call(int_to_int_fn, bb.Reference('x', int_type))), - ('_var3', bb.Call(int_to_int_fn, bb.Reference('x', int_type))), + ( + '_var2', + building_blocks.Call( + int_to_int_fn, building_blocks.Reference('x', int_type) + ), + ), + ( + '_var3', + building_blocks.Call( + int_to_int_fn, building_blocks.Reference('x', int_type) + ), + ), ], - bb.Struct([ - bb.Reference('_var2', int_type), - bb.Reference('_var2', int_type), - bb.Reference('_var3', int_type), + building_blocks.Struct([ + building_blocks.Reference('_var2', int_type), + building_blocks.Reference('_var2', int_type), + building_blocks.Reference('_var3', int_type), ]), ), ) self.assert_compact_representations_equal(after, expected) def test_creates_block_for_non_lambda(self): - bb = building_blocks int_type = computation_types.TensorType(np.int32) two_int_type = computation_types.StructType( [(None, int_type), (None, int_type)] @@ -231,14 +312,18 @@ def test_creates_block_for_non_lambda(self): any_proto = building_block_test_utils.create_any_proto_from_array( np.array([1, 2, 3]) ) - call_ext = bb.Call(bb.Data(any_proto, get_two_int_type)) - before = bb.Selection(call_ext, index=0) + call_ext = building_blocks.Call( + building_blocks.Data(any_proto, get_two_int_type) + ) + before = building_blocks.Selection(call_ext, index=0) after = transformations.to_call_dominant(before) - expected = bb.Block( + expected = building_blocks.Block( [ ('_var1', call_ext), ], - bb.Selection(bb.Reference('_var1', two_int_type), index=0), + building_blocks.Selection( + building_blocks.Reference('_var1', two_int_type), index=0 + ), ) self.assert_compact_representations_equal(after, expected)