From 9e891aa5da91cadd049c9d464fb31fad92ae346e Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Fri, 18 Oct 2024 09:50:01 -0700 Subject: [PATCH] Remove unused aggregators from tff.learning. PiperOrigin-RevId: 687336045 --- RELEASE.md | 4 + tensorflow_federated/python/learning/BUILD | 4 - .../python/learning/__init__.py | 4 - .../python/learning/algorithms/BUILD | 9 - .../learning/algorithms/fed_avg_test.py | 66 +--- .../fed_avg_with_optimizer_schedule_test.py | 32 +- .../learning/algorithms/fed_prox_test.py | 48 +-- .../learning/algorithms/fed_sgd_test.py | 28 -- .../python/learning/algorithms/mime_test.py | 28 -- .../learning/model_update_aggregator.py | 251 --------------- .../learning/model_update_aggregator_test.py | 296 ++---------------- 11 files changed, 47 insertions(+), 723 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index dad2afcf49..06b9066761 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -23,6 +23,10 @@ and this project adheres to ### Removed * `tff.types.tensorflow_to_type`, this function is no longer used. +* `tff.learning.dp_aggregator` removed. Prefer using the class methods on + `tff.aggregators.DifferentiallyPrivateFactory`. +* `tff.learning.ddp_secure_aggregator` and `tff.learning.secure_aggregator` + removed. ## Release 0.88.0 diff --git a/tensorflow_federated/python/learning/BUILD b/tensorflow_federated/python/learning/BUILD index 92498e5d1f..5e7675a54e 100644 --- a/tensorflow_federated/python/learning/BUILD +++ b/tensorflow_federated/python/learning/BUILD @@ -61,9 +61,6 @@ py_library( name = "model_update_aggregator", srcs = ["model_update_aggregator.py"], deps = [ - "//tensorflow_federated/python/aggregators:differential_privacy", - "//tensorflow_federated/python/aggregators:distributed_dp", - "//tensorflow_federated/python/aggregators:encoded", "//tensorflow_federated/python/aggregators:factory", "//tensorflow_federated/python/aggregators:mean", "//tensorflow_federated/python/aggregators:quantile_estimation", @@ -88,7 +85,6 @@ py_test( "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:iterative_process", - "//tensorflow_federated/python/core/test:static_assert", ], ) diff --git a/tensorflow_federated/python/learning/__init__.py b/tensorflow_federated/python/learning/__init__.py index d9bd571388..a1508c8a11 100644 --- a/tensorflow_federated/python/learning/__init__.py +++ b/tensorflow_federated/python/learning/__init__.py @@ -46,8 +46,4 @@ from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements_with_mixed_dtype from tensorflow_federated.python.learning.loop_builder import LoopImplementation -from tensorflow_federated.python.learning.model_update_aggregator import compression_aggregator -from tensorflow_federated.python.learning.model_update_aggregator import ddp_secure_aggregator -from tensorflow_federated.python.learning.model_update_aggregator import dp_aggregator from tensorflow_federated.python.learning.model_update_aggregator import robust_aggregator -from tensorflow_federated.python.learning.model_update_aggregator import secure_aggregator diff --git a/tensorflow_federated/python/learning/algorithms/BUILD b/tensorflow_federated/python/learning/algorithms/BUILD index 66933fdfc3..947c129173 100644 --- a/tensorflow_federated/python/learning/algorithms/BUILD +++ b/tensorflow_federated/python/learning/algorithms/BUILD @@ -69,12 +69,9 @@ py_cpu_gpu_test( deps = [ ":fed_avg", "//tensorflow_federated/python/aggregators:factory_utils", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:model_examples", - "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", ], ) @@ -118,7 +115,6 @@ py_cpu_gpu_test( shard_count = 10, deps = [ ":fed_avg_with_optimizer_schedule", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -164,10 +160,8 @@ py_cpu_gpu_test( ":fed_prox", "//tensorflow_federated/python/aggregators:factory_utils", "//tensorflow_federated/python/core/templates:iterative_process", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:model_examples", "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", @@ -330,7 +324,6 @@ py_cpu_gpu_test( deps = [ ":fed_sgd", "//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", "//tensorflow_federated/python/learning/metrics:aggregator", @@ -429,11 +422,9 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:loop_builder", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/metrics:counters", "//tensorflow_federated/python/learning/models:functional", "//tensorflow_federated/python/learning/models:keras_utils", diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py index bdfe8acd09..b18a01b89b 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_test.py @@ -18,38 +18,25 @@ from absl.testing import parameterized from tensorflow_federated.python.aggregators import factory_utils -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_avg -from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.models import model_examples -from tensorflow_federated.python.learning.models import test_models from tensorflow_federated.python.learning.optimizers import sgdm class FedAvgTest(parameterized.TestCase): """Tests construction of the FedAvg training process.""" - @parameterized.product( - optimizer_fn=[ - sgdm.build_sgdm(learning_rate=0.1), - ], - aggregation_factory=[ - model_update_aggregator.robust_aggregator, - model_update_aggregator.compression_aggregator, - model_update_aggregator.secure_aggregator, - ], - ) - def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory): + def test_construction_calls_model_fn(self): # Assert that the process building does not call `model_fn` too many times. # `model_fn` can potentially be expensive (loading weights, processing, etc # ). mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression) fed_avg.build_weighted_fed_avg( model_fn=mock_model_fn, - client_optimizer_fn=optimizer_fn, - model_aggregator=aggregation_factory(), + client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1), + model_aggregator=model_update_aggregator.robust_aggregator(), ) self.assertEqual(mock_model_fn.call_count, 3) @@ -125,34 +112,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self): model_aggregator=model_aggregator, ) - def test_weighted_fed_avg_with_only_secure_aggregation(self): - model_fn = model_examples.LinearRegression - learning_process = fed_avg.build_weighted_fed_avg( - model_fn, - client_optimizer_fn=sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=True - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - - def test_unweighted_fed_avg_with_only_secure_aggregation(self): - model_fn = model_examples.LinearRegression - learning_process = fed_avg.build_unweighted_fed_avg( - model_fn, - client_optimizer_fn=sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=False - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - class FunctionalFedAvgTest(parameterized.TestCase): """Tests construction of the FedAvg training process.""" @@ -167,25 +126,6 @@ def test_raises_on_non_callable_or_functional_model(self, constructor): model_fn=0, client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1) ) - @parameterized.named_parameters( - ('weighted', fed_avg.build_weighted_fed_avg), - ('unweighted', fed_avg.build_unweighted_fed_avg), - ) - def test_weighted_fed_avg_with_only_secure_aggregation(self, constructor): - model = test_models.build_functional_linear_regression() - learning_process = constructor( - model_fn=model, - client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1), - server_optimizer_fn=sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=constructor is fed_avg.build_weighted_fed_avg - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - if __name__ == '__main__': absltest.main() diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py index d881551901..54ebe25f9d 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py @@ -18,7 +18,6 @@ from absl.testing import parameterized import tensorflow as tf -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_avg_with_optimizer_schedule @@ -30,17 +29,7 @@ class ClientScheduledFedAvgTest(parameterized.TestCase): - @parameterized.product( - optimizer_fn=[ - lambda x: sgdm.build_sgdm(learning_rate=x), - ], - aggregation_factory=[ - model_update_aggregator.robust_aggregator, - model_update_aggregator.compression_aggregator, - model_update_aggregator.secure_aggregator, - ], - ) - def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory): + def test_construction_calls_model_fn(self): # Assert that the process building does not call `model_fn` too many times. # `model_fn` can potentially be expensive (loading weights, processing, etc # ). @@ -49,8 +38,8 @@ def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory): fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule( model_fn=mock_model_fn, client_learning_rate_fn=learning_rate_fn, - client_optimizer_fn=optimizer_fn, - model_aggregator=aggregation_factory(), + client_optimizer_fn=lambda lr: sgdm.build_sgdm(learning_rate=lr), + model_aggregator=model_update_aggregator.robust_aggregator(), ) self.assertEqual(mock_model_fn.call_count, 3) @@ -143,21 +132,6 @@ def test_raises_on_non_callable_model_fn(self): client_optimizer_fn=lambda _: sgdm.build_sgdm(), ) - def test_construction_with_only_secure_aggregation(self): - model_fn = model_examples.LinearRegression - learning_process = fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule( - model_fn, - client_learning_rate_fn=lambda x: 0.5, - client_optimizer_fn=lambda x: sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=True - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - def test_measurements_include_client_learning_rate(self): client_work = fed_avg_with_optimizer_schedule.build_scheduled_client_work( model_fn=model_examples.LinearRegression, diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py index 0ada68f128..3481dd3002 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py @@ -19,11 +19,9 @@ from tensorflow_federated.python.aggregators import factory_utils from tensorflow_federated.python.core.templates import iterative_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_prox -from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.models import model_examples from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import test_models @@ -34,17 +32,7 @@ class FedProxConstructionTest(parameterized.TestCase): """Tests construction of the FedProx training process.""" - @parameterized.product( - optimizer_fn=[ - sgdm.build_sgdm(learning_rate=0.1), - ], - aggregation_factory=[ - model_update_aggregator.robust_aggregator, - model_update_aggregator.compression_aggregator, - model_update_aggregator.secure_aggregator, - ], - ) - def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory): + def test_construction_calls_model_fn(self): # Assert that the process building does not call `model_fn` too many times. # `model_fn` can potentially be expensive (loading weights, processing, etc # ). @@ -52,8 +40,8 @@ def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory): fed_prox.build_weighted_fed_prox( model_fn=mock_model_fn, proximal_strength=1.0, - client_optimizer_fn=optimizer_fn, - model_aggregator=aggregation_factory(), + client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1), + model_aggregator=model_update_aggregator.robust_aggregator(), ) self.assertEqual(mock_model_fn.call_count, 3) @@ -160,36 +148,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self): model_aggregator=model_aggregator, ) - def test_weighted_fed_prox_with_only_secure_aggregation(self): - model_fn = model_examples.LinearRegression - learning_process = fed_prox.build_weighted_fed_prox( - model_fn, - proximal_strength=1.0, - client_optimizer_fn=sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=True - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - - def test_unweighted_fed_prox_with_only_secure_aggregation(self): - model_fn = model_examples.LinearRegression - learning_process = fed_prox.build_unweighted_fed_prox( - model_fn, - proximal_strength=1.0, - client_optimizer_fn=sgdm.build_sgdm(), - model_aggregator=model_update_aggregator.secure_aggregator( - weighted=False - ), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - if __name__ == '__main__': absltest.main() diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py index e2ad6b4ee1..445ea600dd 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py @@ -20,7 +20,6 @@ import tensorflow as tf from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_sgd @@ -160,11 +159,6 @@ def test_client_tf_dataset_reduce_fn(self, loop_implementation, mock_method): @parameterized.named_parameters( ('robust_aggregator', model_update_aggregator.robust_aggregator), - ( - 'compression_aggregator', - model_update_aggregator.compression_aggregator, - ), - ('secure_aggreagtor', model_update_aggregator.secure_aggregator), ) def test_construction_calls_model_fn(self, aggregation_factory): # Assert that the process building does not call `model_fn` too many times. @@ -177,17 +171,6 @@ def test_construction_calls_model_fn(self, aggregation_factory): # TODO: b/186451541 - reduce the number of calls to model_fn. self.assertEqual(mock_model_fn.call_count, 3) - def test_no_unsecure_aggregation_with_secure_aggregator(self): - model_fn = model_examples.LinearRegression - learning_process = fed_sgd.build_fed_sgd( - model_fn, - model_aggregator=model_update_aggregator.secure_aggregator(), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - class FunctionalFederatedSgdTest(tf.test.TestCase, parameterized.TestCase): @@ -276,17 +259,6 @@ def test_build_functional_fed_sgd_succeeds(self): model = _build_functional_model() fed_sgd.build_fed_sgd(model_fn=model) - def test_no_unsecure_aggregation_with_secure_aggregator(self): - model = _build_functional_model() - learning_process = fed_sgd.build_fed_sgd( - model, - model_aggregator=model_update_aggregator.secure_aggregator(), - metrics_aggregator=aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_federated/python/learning/algorithms/mime_test.py b/tensorflow_federated/python/learning/algorithms/mime_test.py index fe699abee3..8dc60f7e8c 100644 --- a/tensorflow_federated/python/learning/algorithms/mime_test.py +++ b/tensorflow_federated/python/learning/algorithms/mime_test.py @@ -31,13 +31,11 @@ from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import loop_builder from tensorflow_federated.python.learning import model_update_aggregator from tensorflow_federated.python.learning.algorithms import fed_avg from tensorflow_federated.python.learning.algorithms import mime -from tensorflow_federated.python.learning.metrics import aggregator as metrics_aggregator from tensorflow_federated.python.learning.metrics import counters from tensorflow_federated.python.learning.models import functional from tensorflow_federated.python.learning.models import keras_utils @@ -470,32 +468,6 @@ def test_unweighted_mime_lite_raises_on_weighted_aggregator(self): full_gradient_aggregator=aggregator, ) - def test_weighted_mime_lite_with_only_secure_aggregation(self): - aggregator = model_update_aggregator.secure_aggregator(weighted=True) - learning_process = mime.build_weighted_mime_lite( - model_examples.LinearRegression, - base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9), - model_aggregator=aggregator, - full_gradient_aggregator=aggregator, - metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - - def test_unweighted_mime_lite_with_only_secure_aggregation(self): - aggregator = model_update_aggregator.secure_aggregator(weighted=False) - learning_process = mime.build_unweighted_mime_lite( - model_examples.LinearRegression, - base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9), - model_aggregator=aggregator, - full_gradient_aggregator=aggregator, - metrics_aggregator=metrics_aggregator.secure_sum_then_finalize, - ) - static_assert.assert_not_contains_unsecure_aggregation( - learning_process.next - ) - @tensorflow_test_utils.skip_test_for_multi_gpu def test_equivalent_to_vanilla_fed_avg(self): # Mime Lite with no-momentum SGD should reduce to FedAvg. diff --git a/tensorflow_federated/python/learning/model_update_aggregator.py b/tensorflow_federated/python/learning/model_update_aggregator.py index 35b6ee4500..f707e939c2 100644 --- a/tensorflow_federated/python/learning/model_update_aggregator.py +++ b/tensorflow_federated/python/learning/model_update_aggregator.py @@ -17,9 +17,6 @@ import math from typing import Optional, TypeVar -from tensorflow_federated.python.aggregators import differential_privacy -from tensorflow_federated.python.aggregators import distributed_dp -from tensorflow_federated.python.aggregators import encoded from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.aggregators import mean from tensorflow_federated.python.aggregators import quantile_estimation @@ -152,251 +149,3 @@ def robust_aggregator( aggregation_factory = _default_zeroing(aggregation_factory) return aggregation_factory - - -def dp_aggregator( - noise_multiplier: float, clients_per_round: float, zeroing: bool = True -) -> factory.UnweightedAggregationFactory: - """Creates aggregator with adaptive zeroing and differential privacy. - - Zeroes out extremely large values for robustness to data corruption on - clients, and performs adaptive clipping and addition of Gaussian noise for - differentially private learning. For details of the DP algorithm see McMahan - et. al (2017) https://arxiv.org/abs/1710.06963. The adaptive clipping uses the - geometric method described in Andrew, Thakkar et al. (2021) - https://arxiv.org/abs/1905.03871. - - Args: - noise_multiplier: A float specifying the noise multiplier for the Gaussian - mechanism for model updates. A value of 1.0 or higher may be needed for - meaningful privacy. See above mentioned papers to compute (epsilon, delta) - privacy guarantee. - clients_per_round: A float specifying the expected number of clients per - round. Must be positive. - zeroing: Whether to enable adaptive zeroing for data corruption mitigation. - - Returns: - A `tff.aggregators.UnweightedAggregationFactory`. - """ - - aggregation_factory = ( - differential_privacy.DifferentiallyPrivateFactory.gaussian_adaptive( - noise_multiplier, clients_per_round - ) - ) - - if zeroing: - aggregation_factory = _default_zeroing(aggregation_factory) - - return aggregation_factory - - -def compression_aggregator( - *, - zeroing: bool = True, - clipping: bool = True, - weighted: bool = True, - debug_measurements_fn: Optional[ - Callable[[factory.AggregationFactory], factory.AggregationFactory] - ] = None, - **kwargs, -) -> factory.AggregationFactory: - """Creates aggregator with compression and adaptive zeroing and clipping. - - Zeroes out extremely large values for robustness to data corruption on - clients and clips in the L2 norm to moderately high norm for robustness to - outliers. After weighting in mean, the weighted values are uniformly quantized - to reduce the size of the model update communicated from clients to the - server. For details, see Suresh et al. (2017) - http://proceedings.mlr.press/v70/suresh17a/suresh17a.pdf. The default - configuration is chosen such that compression does not have adverse effect on - trained model quality in typical tasks. - - Args: - zeroing: Whether to enable adaptive zeroing for data corruption mitigation. - clipping: Whether to enable adaptive clipping in the L2 norm for robustness. - Note this clipping is performed prior to the per-coordinate clipping - required for quantization. - weighted: Whether the mean is weighted (vs. unweighted). - debug_measurements_fn: A callable to add measurements suitable for debugging - learning algorithms, with possible values as None, - `tff.learning.add_debug_measurements` or - `tff.learning.add_debug_measurements_with_mixed_dtype`. - **kwargs: Keyword arguments. - - Returns: - A `tff.aggregators.AggregationFactory`. - - Raises: - TypeError: if debug_measurement_fn yields an aggregation factory whose - weight type does not match `weighted`. - """ - aggregation_factory = encoded.EncodedSumFactory.quantize_above_threshold( - quantization_bits=8, threshold=20000, **kwargs - ) - - aggregation_factory = ( - mean.MeanFactory(aggregation_factory) - if weighted - else mean.UnweightedMeanFactory(aggregation_factory) - ) - - if debug_measurements_fn is not None: - aggregation_factory = debug_measurements_fn(aggregation_factory) - if ( - weighted - and not isinstance( - aggregation_factory, factory.WeightedAggregationFactory - ) - ) or ( - (not weighted) - and ( - not isinstance( - aggregation_factory, factory.UnweightedAggregationFactory - ) - ) - ): - raise TypeError('debug_measurements_fn should return the same type.') - - if clipping: - aggregation_factory = _default_clipping(aggregation_factory) - - if zeroing: - aggregation_factory = _default_zeroing(aggregation_factory) - - return aggregation_factory - - -def secure_aggregator( - *, - zeroing: bool = True, - clipping: bool = True, - weighted: bool = True, -) -> factory.AggregationFactory: - """Creates secure aggregator with adaptive zeroing and clipping. - - Zeroes out extremely large values for robustness to data corruption on - clients, clips to moderately high norm for robustness to outliers. After - weighting in mean, the weighted values are summed using cryptographic protocol - ensuring that the server cannot see individual updates until sufficient number - of updates have been added together. For details, see Bonawitz et al. (2017) - https://dl.acm.org/doi/abs/10.1145/3133956.3133982. In TFF, this is realized - using the `tff.federated_secure_sum_bitwidth` operator. - - Args: - zeroing: Whether to enable adaptive zeroing for data corruption mitigation. - clipping: Whether to enable adaptive clipping in the L2 norm for robustness. - Note this clipping is performed prior to the per-coordinate clipping - required for secure aggregation. - weighted: Whether the mean is weighted (vs. unweighted). - - Returns: - A `tff.aggregators.AggregationFactory`. - """ - secure_clip_bound = ( - quantile_estimation.PrivateQuantileEstimationProcess.no_noise( - initial_estimate=50.0, - target_quantile=0.95, - learning_rate=1.0, - multiplier=2.0, - secure_estimation=True, - ) - ) - - aggregation_factory = secure.SecureSumFactory(secure_clip_bound) - - if weighted: - aggregation_factory = mean.MeanFactory( - value_sum_factory=aggregation_factory, - # Use a power of 2 minus one to more accurately encode floating dtypes - # that actually contain integer values. 2 ^ 20 gives us approximately a - # range of [0, 1 million]. Existing use cases have the weights either - # all ones, or a variant of number of examples processed locally. - weight_sum_factory=secure.SecureSumFactory( - upper_bound_threshold=float(2**20 - 1), lower_bound_threshold=0.0 - ), - ) - else: - aggregation_factory = mean.UnweightedMeanFactory( - value_sum_factory=aggregation_factory, - count_sum_factory=secure.SecureSumFactory( - upper_bound_threshold=1, lower_bound_threshold=0 - ), - ) - - if clipping: - aggregation_factory = _default_clipping( - aggregation_factory, secure_estimation=True - ) - - if zeroing: - aggregation_factory = _default_zeroing( - aggregation_factory, secure_estimation=True - ) - - return aggregation_factory - - -def ddp_secure_aggregator( - noise_multiplier: float, - expected_clients_per_round: int, - bits: int = 20, - zeroing: bool = True, - rotation_type: str = 'hd', -) -> factory.UnweightedAggregationFactory: - """Creates aggregator with adaptive zeroing and distributed DP. - - Zeroes out extremely large values for robustness to data corruption on - clients, and performs distributed DP (compression, discrete noising, and - SecAgg) with adaptive clipping for differentially private learning. For - details of the two main distributed DP algorithms see - https://arxiv.org/pdf/2102.06387 - or https://arxiv.org/pdf/2110.04995.pdf. The adaptive clipping uses the - geometric method described in https://arxiv.org/abs/1905.03871. - - Args: - noise_multiplier: A float specifying the noise multiplier (with respect to - the initial L2 cipping) for the distributed DP mechanism for model - updates. A value of 1.0 or higher may be needed for meaningful privacy. - expected_clients_per_round: An integer specifying the expected number of - clients per round. Must be positive. - bits: An integer specifying the bit-width for the aggregation. Note that - this is for the noisy, quantized aggregate at the server and thus should - account for the `expected_clients_per_round`. Must be in the inclusive - range of [1, 22]. This is set to 20 bits by default, and it dictates the - computational and communication efficiency of Secure Aggregation. Setting - it to less than 20 bits should work fine for most cases. For instance, for - an expected number of securely aggregated client updates of 100, 12 bits - should be enough, and for an expected number of securely aggregated client - updates of 1000, 16 bits should be enough. - zeroing: A bool indicating whether to enable adaptive zeroing for data - corruption mitigation. Defaults to `True`. - rotation_type: A string indicating what rotation to use for distributed DP. - Valid options are 'hd' (Hadamard transform) and 'dft' (discrete Fourier - transform). Defaults to `hd`. - - Returns: - A `tff.aggregators.UnweightedAggregationFactory`. - """ - aggregation_factory = distributed_dp.DistributedDpSumFactory( - noise_multiplier=noise_multiplier, - expected_clients_per_round=expected_clients_per_round, - bits=bits, - l2_clip=0.1, - mechanism='distributed_skellam', - rotation_type=rotation_type, - auto_l2_clip=True, - ) - aggregation_factory = mean.UnweightedMeanFactory( - value_sum_factory=aggregation_factory, - count_sum_factory=secure.SecureSumFactory( - upper_bound_threshold=1, lower_bound_threshold=0 - ), - ) - - if zeroing: - aggregation_factory = _default_zeroing( - aggregation_factory, secure_estimation=True - ) - - return aggregation_factory diff --git a/tensorflow_federated/python/learning/model_update_aggregator_test.py b/tensorflow_federated/python/learning/model_update_aggregator_test.py index 4e6641f519..9b2755b7b4 100644 --- a/tensorflow_federated/python/learning/model_update_aggregator_test.py +++ b/tensorflow_federated/python/learning/model_update_aggregator_test.py @@ -26,7 +26,6 @@ from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import iterative_process -from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import debug_measurements from tensorflow_federated.python.learning import model_update_aggregator @@ -34,201 +33,42 @@ _FLOAT_MATRIX_TYPE = computation_types.TensorType(np.float32, [200, 300]) -class ModelUpdateAggregatorTest(parameterized.TestCase): - - @parameterized.named_parameters( - ('simple', False, False, None), - ('zeroing', True, False, None), - ('clipping', False, True, None), - ('zeroing_and_clipping', True, True, None), - ( - 'debug_measurements', - False, - False, - debug_measurements.add_debug_measurements, - ), - ( - 'zeroing_clipping_debug_measurements', - True, - True, - debug_measurements.add_debug_measurements, - ), - ) - def test_robust_aggregator_weighted( - self, zeroing, clipping, debug_measurements_fn - ): - factory_ = model_update_aggregator.robust_aggregator( - zeroing=zeroing, - clipping=clipping, - debug_measurements_fn=debug_measurements_fn, - ) - - self.assertIsInstance(factory_, factory.WeightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE, _FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertTrue(process.is_weighted) - - @parameterized.named_parameters( - ('simple', False, False, None), - ('zeroing', True, False, None), - ('clipping', False, True, None), - ('zeroing_and_clipping', True, True, None), - ( - 'debug_measurements', - False, - False, - debug_measurements.add_debug_measurements, - ), - ( - 'zeroing_clipping_debug_measurements', - True, - True, - debug_measurements.add_debug_measurements, - ), - ) - def test_robust_aggregator_unweighted( - self, zeroing, clipping, debug_measurements_fn - ): - factory_ = model_update_aggregator.robust_aggregator( - zeroing=zeroing, - clipping=clipping, - weighted=False, - debug_measurements_fn=debug_measurements_fn, - ) +def _mrfify_aggregator(aggregator): + """Makes aggregator compatible with MapReduceForm.""" - self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertFalse(process.is_weighted) + if aggregator.is_weighted: - @parameterized.named_parameters( - ( - 'debug_measurements', - False, - False, - debug_measurements.add_debug_measurements_with_mixed_dtype, - ), - ( - 'zeroing_clipping_debug_measurements', - True, - True, - debug_measurements.add_debug_measurements_with_mixed_dtype, - ), - ) - def test_robust_aggregator_weighted_mixed_dtype( - self, zeroing, clipping, debug_measurements_fn - ): - factory_ = model_update_aggregator.robust_aggregator( - zeroing=zeroing, - clipping=clipping, - debug_measurements_fn=debug_measurements_fn, + @federated_computation.federated_computation( + aggregator.next.type_signature.parameter[0], + computation_types.FederatedType( + ( + aggregator.next.type_signature.parameter[1].member, + aggregator.next.type_signature.parameter[2].member, + ), + placements.CLIENTS, + ), ) - - self.assertIsInstance(factory_, factory.WeightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE, _FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertTrue(process.is_weighted) - - def test_wrong_debug_measurements_fn_robust_aggregator(self): - """Expect error if debug_measurements_fn is wrong.""" - with self.assertRaises(TypeError): - - def wrong_debug_measurements_fn( - aggregation_factory: factory.AggregationFactory, - ) -> ...: - del aggregation_factory - return ( - debug_measurements._calculate_client_update_statistics_mixed_dtype( - [1.0], [1.0] - ) - ) - - model_update_aggregator.robust_aggregator( - debug_measurements_fn=wrong_debug_measurements_fn + def next_fn(state, value): + output = aggregator.next(state, value[0], value[1]) + return output.state, intrinsics.federated_zip( + (output.result, output.measurements) ) - @parameterized.named_parameters( - ('simple', False), - ('zeroing', True), - ) - def test_dp_aggregator(self, zeroing): - factory_ = model_update_aggregator.dp_aggregator( - noise_multiplier=1e-2, clients_per_round=10, zeroing=zeroing - ) - - self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertFalse(process.is_weighted) - - @parameterized.named_parameters( - ('simple', False, False), - ('zeroing', True, False), - ('clipping', False, True), - ('zeroing_and_clipping', True, True), - ) - def test_secure_aggregator_weighted(self, zeroing, clipping): - factory_ = model_update_aggregator.secure_aggregator( - zeroing=zeroing, clipping=clipping - ) - - self.assertIsInstance(factory_, factory.WeightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE, _FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertTrue(process.is_weighted) + else: - @parameterized.named_parameters( - ('simple', False, False), - ('zeroing', True, False), - ('clipping', False, True), - ('zeroing_and_clipping', True, True), - ) - def test_secure_aggregator_unweighted(self, zeroing, clipping): - factory_ = model_update_aggregator.secure_aggregator( - zeroing=zeroing, clipping=clipping, weighted=False + @federated_computation.federated_computation( + aggregator.next.type_signature.parameter ) + def next_fn(state, value): + output = aggregator.next(state, value) + return output.state, intrinsics.federated_zip( + (output.result, output.measurements) + ) - self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) - process = factory_.create(_FLOAT_TYPE) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertFalse(process.is_weighted) - - def test_weighted_secure_aggregator_only_contains_secure_aggregation(self): - aggregator = model_update_aggregator.secure_aggregator( - weighted=True - ).create(_FLOAT_MATRIX_TYPE, _FLOAT_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) - - def test_unweighted_secure_aggregator_only_contains_secure_aggregation(self): - aggregator = model_update_aggregator.secure_aggregator( - weighted=False - ).create(_FLOAT_MATRIX_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) - - def test_ddp_secure_aggregator_only_contains_secure_aggregation(self): - aggregator = model_update_aggregator.ddp_secure_aggregator( - noise_multiplier=1e-2, expected_clients_per_round=10 - ).create(_FLOAT_MATRIX_TYPE) - static_assert.assert_not_contains_unsecure_aggregation(aggregator.next) + return iterative_process.IterativeProcess(aggregator.initialize, next_fn) - @parameterized.named_parameters( - ('zeroing_float', True, _FLOAT_TYPE), - ('zeroing_float_matrix', True, _FLOAT_MATRIX_TYPE), - ('no_zeroing_float', False, _FLOAT_TYPE), - ('no_zeroing_float_matrix', False, _FLOAT_MATRIX_TYPE), - ) - def test_ddp_secure_aggregator_unweighted(self, zeroing, dtype): - aggregator = model_update_aggregator.ddp_secure_aggregator( - noise_multiplier=1e-2, - expected_clients_per_round=10, - bits=16, - zeroing=zeroing, - ) - self.assertIsInstance(aggregator, factory.UnweightedAggregationFactory) - process = aggregator.create(dtype) - self.assertIsInstance(process, aggregation_process.AggregationProcess) - self.assertFalse(process.is_weighted) +class ModelUpdateAggregatorTest(parameterized.TestCase): @parameterized.named_parameters( ('simple', False, False, None), @@ -248,10 +88,10 @@ def test_ddp_secure_aggregator_unweighted(self, zeroing, dtype): debug_measurements.add_debug_measurements, ), ) - def test_compression_aggregator_weighted( + def test_robust_aggregator_weighted( self, zeroing, clipping, debug_measurements_fn ): - factory_ = model_update_aggregator.compression_aggregator( + factory_ = model_update_aggregator.robust_aggregator( zeroing=zeroing, clipping=clipping, debug_measurements_fn=debug_measurements_fn, @@ -280,10 +120,10 @@ def test_compression_aggregator_weighted( debug_measurements.add_debug_measurements, ), ) - def test_compression_aggregator_unweighted( + def test_robust_aggregator_unweighted( self, zeroing, clipping, debug_measurements_fn ): - factory_ = model_update_aggregator.compression_aggregator( + factory_ = model_update_aggregator.robust_aggregator( zeroing=zeroing, clipping=clipping, weighted=False, @@ -309,10 +149,10 @@ def test_compression_aggregator_unweighted( debug_measurements.add_debug_measurements_with_mixed_dtype, ), ) - def test_compression_aggregator_weighted_mixed_dtype( + def test_robust_aggregator_weighted_mixed_dtype( self, zeroing, clipping, debug_measurements_fn ): - factory_ = model_update_aggregator.compression_aggregator( + factory_ = model_update_aggregator.robust_aggregator( zeroing=zeroing, clipping=clipping, debug_measurements_fn=debug_measurements_fn, @@ -323,7 +163,7 @@ def test_compression_aggregator_weighted_mixed_dtype( self.assertIsInstance(process, aggregation_process.AggregationProcess) self.assertTrue(process.is_weighted) - def test_wrong_debug_measurements_fn_compression_aggregator(self): + def test_wrong_debug_measurements_fn_robust_aggregator(self): """Expect error if debug_measurements_fn is wrong.""" with self.assertRaises(TypeError): @@ -337,7 +177,7 @@ def wrong_debug_measurements_fn( ) ) - model_update_aggregator.compression_aggregator( + model_update_aggregator.robust_aggregator( debug_measurements_fn=wrong_debug_measurements_fn ) @@ -368,74 +208,6 @@ def test_robust_aggregator(self): ) self._check_aggregated_scalar_count(aggregator, 60000 * 1.01, 60000) - def test_dp_aggregator(self): - aggregator = model_update_aggregator.dp_aggregator(0.01, 10).create( - _FLOAT_MATRIX_TYPE - ) - self._check_aggregated_scalar_count(aggregator, 60000 * 1.01, 60000) - - def test_secure_aggregator(self): - aggregator = model_update_aggregator.secure_aggregator().create( - _FLOAT_MATRIX_TYPE, _FLOAT_TYPE - ) - mrf = self._check_aggregated_scalar_count(aggregator, 60000 * 1.01, 60000) - - # The MapReduceForm should be using secure aggregation. - self.assertTrue(mrf.securely_aggregates_tensors) - - def test_compression_aggregator(self): - aggregator = model_update_aggregator.compression_aggregator().create( - _FLOAT_MATRIX_TYPE, _FLOAT_TYPE - ) - # Default compression should reduce the size aggregated by more than 60%. - self._check_aggregated_scalar_count(aggregator, 60000 * 0.4) - - def test_ddp_secure_aggregator(self): - self.skipTest('b/305747127') - aggregator = model_update_aggregator.ddp_secure_aggregator( - noise_multiplier=1e-2, expected_clients_per_round=10 - ).create(_FLOAT_MATRIX_TYPE) - # The Hadmard transform requires padding to next power of 2 - mrf = self._check_aggregated_scalar_count(aggregator, 2**16 * 1.01, 60000) - - # The MapReduceForm should be using secure aggregation. - self.assertTrue(mrf.securely_aggregates_tensors) - - -def _mrfify_aggregator(aggregator): - """Makes aggregator compatible with MapReduceForm.""" - - if aggregator.is_weighted: - - @federated_computation.federated_computation( - aggregator.next.type_signature.parameter[0], - computation_types.FederatedType( - ( - aggregator.next.type_signature.parameter[1].member, - aggregator.next.type_signature.parameter[2].member, - ), - placements.CLIENTS, - ), - ) - def next_fn(state, value): - output = aggregator.next(state, value[0], value[1]) - return output.state, intrinsics.federated_zip( - (output.result, output.measurements) - ) - - else: - - @federated_computation.federated_computation( - aggregator.next.type_signature.parameter - ) - def next_fn(state, value): - output = aggregator.next(state, value) - return output.state, intrinsics.federated_zip( - (output.result, output.measurements) - ) - - return iterative_process.IterativeProcess(aggregator.initialize, next_fn) - if __name__ == '__main__': absltest.main()