diff --git a/RELEASE.md b/RELEASE.md index 3b8fc77b78..364f7a0e72 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -8,6 +8,24 @@ and this project adheres to ## Unreleased +### Added + +* The `dp_noise_mechanisms` header and source files: contain functions that + generate `differential_privacy::LaplaceMechanism` or + `differential_privacy::GaussianMechanism`, based upon privacy parameters and + norm bounds. Each of these functions returns a `DPHistogramBundle` struct, + which contains the mechanism, the threshold needed for DP open-domain + histograms, and a boolean indicating whether Laplace noise was used. + +### Changed + +* `DPClosedDomainHistogram::Report` and `DPOpenDomainHistogram::Report`: they + both use the `DPHistogramBundles` produced by the `CreateDPHistogramBundle` + function in `dp_noise_mechanisms`. +* `DPGroupByFactory::CreateInternal`: when `delta` is not provided, check if + the right norm bounds are provided to compute L1 sensitivity (for the + Laplace mech). + ## Release 0.84.0 ### Added diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/BUILD b/tensorflow_federated/cc/core/impl/aggregation/core/BUILD index 887a20a032..704001b424 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/BUILD +++ b/tensorflow_federated/cc/core/impl/aggregation/core/BUILD @@ -119,6 +119,33 @@ cc_library( deps = [":tensor"], ) +cc_library( + name = "dp_noise_mechanisms", + srcs = ["dp_noise_mechanisms.cc"], + hdrs = [ + "dp_noise_mechanisms.h", + ], + deps = [ + ":dp_fedsql_constants", + "//tensorflow_federated/cc/core/impl/aggregation/base", + "@com_google_absl//absl/status:statusor", + "@com_google_cc_differential_privacy//algorithms:numerical-mechanisms", + "@com_google_cc_differential_privacy//algorithms:partition-selection", + ], +) + +cc_test( + name = "dp_noise_mechanisms_test", + srcs = ["dp_noise_mechanisms_test.cc"], + deps = [ + ":dp_fedsql_constants", + ":dp_noise_mechanisms", + "//tensorflow_federated/cc/core/impl/aggregation/base", + 
"//tensorflow_federated/cc/testing:oss_test_main", + "//tensorflow_federated/cc/testing:status_matchers", + ], +) + # TODO: b/352020454 - Create one library per cc & hh pair. Make them aggregation_cores deps. cc_library( name = "aggregation_cores", @@ -148,6 +175,7 @@ cc_library( ":agg_core_cc_proto", ":aggregator", ":dp_fedsql_constants", + ":dp_noise_mechanisms", ":fedsql_constants", ":intrinsic", ":tensor", @@ -159,11 +187,9 @@ cc_library( "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/random", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_cc_differential_privacy//algorithms:numerical-mechanisms", - "@com_google_cc_differential_privacy//algorithms:partition-selection", ], alwayslink = 1, ) diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram.cc index e61e8e342e..71bbcab541 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram.cc @@ -22,11 +22,14 @@ #include #include "absl/container/fixed_array.h" +#include "algorithms/numerical-mechanisms.h" #include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/agg_core.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/composite_key_combiner.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/datatype.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/dp_composite_key_combiner.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_fedsql_constants.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h" #include 
"tensorflow_federated/cc/core/impl/aggregation/core/group_by_aggregator.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/input_tensor_list.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h" @@ -41,23 +44,26 @@ namespace tensorflow_federated { namespace aggregation { +using differential_privacy::NumericalMechanism; + namespace { // Given a tensor containing a column of aggregates and an ordinal, push the // aggregate associated with that ordinal to the back of a MutableVectorData // container. If the ordinal is kNoOrdinal, push 0 instead. +// Adds noise if a mechanism is provided. template void CopyAggregateFromColumn(const Tensor& column_of_aggregates, - int64_t ordinal, MutableVectorData& container) { - // Get the aggregate we will be copying over if it exists. - // Future CL will initialize value to a random number instead of 0. - T value = 0; + int64_t ordinal, MutableVectorData& container, + NumericalMechanism* mechanism) { + T value = (mechanism == nullptr) ? 0 : mechanism->AddNoise(/*result=*/0); if (ordinal != kNoOrdinal) { value += column_of_aggregates.AsSpan()[ordinal]; } - // Add to the container. + // Add (possibly noisy) value to the container. container.push_back(value); } + } // namespace DPClosedDomainHistogram::DPClosedDomainHistogram( @@ -135,6 +141,35 @@ StatusOr DPClosedDomainHistogram::Report() && { domain_size *= domain_tensor.num_elements(); } + // Make a noise mechanism for each aggregation. + std::vector> mechanisms; + for (int i = 0; i < intrinsics().size(); ++i) { + const Intrinsic& intrinsic = intrinsics()[i]; + // Do not bother making mechanism if epsilon is too large. + if (epsilon_per_agg_ >= kEpsilonThreshold) { + mechanisms.push_back(nullptr); + laplace_was_used_.push_back(false); + continue; + } + + // Get norm bounds for the ith aggregation. 
+ double linfinity_bound = + intrinsic.parameters[kLinfinityIndex].CastToScalar(); + double l1_bound = intrinsic.parameters[kL1Index].CastToScalar(); + double l2_bound = intrinsic.parameters[kL2Index].CastToScalar(); + + // Create a noise mechanism out of those norm bounds and privacy params. + TFF_ASSIGN_OR_RETURN( + DPHistogramBundle noise_mechanism, + CreateDPHistogramBundle(epsilon_per_agg_, delta_per_agg_, l0_bound_, + linfinity_bound, l1_bound, l2_bound, + /*open_domain=*/false)); + mechanisms.push_back(std::move(noise_mechanism.mechanism)); + + // Record whether Laplace will be used. + laplace_was_used_.push_back(noise_mechanism.use_laplace); + } + // Create MutableVectorData containers, one for each output tensor, that are // each big enough to hold domain_size elements. // If all output tensors had the same type like int64_t we could create an @@ -154,7 +189,13 @@ StatusOr DPClosedDomainHistogram::Report() && { do { // Each composite key is associated with a row of the output. i-th entry of // that row will be written to i-th entry of noisy_aggregate_data. + + // Maintain the index of the next key to output. int64_t key_to_output = 0; + // Maintain the index of the next mechanism to use. + int64_t mech_to_use = 0; + + // Loop to populate the row of the output for the current composite key. for (int64_t i = 0; i < noisy_aggregate_data.size(); i++) { // Get the TensorData container we will be writing to (a column). TensorData& container = *(noisy_aggregate_data[i]); @@ -198,7 +239,11 @@ StatusOr DPClosedDomainHistogram::Report() && { column_of_aggregates.dtype(), T, CopyAggregateFromColumn( column_of_aggregates, ordinal, - dynamic_cast&>(container))); + dynamic_cast&>(container), + mechanisms[mech_to_use].get())); + + // Move on to the next mechanism. 
+ mech_to_use++; } } } while (IncrementDomainIndices(domain_indices)); diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram_test.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram_test.cc index 5e232fc404..1f386287e1 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram_test.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_closed_domain_histogram_test.cc @@ -102,9 +102,9 @@ std::vector CreateTopLevelParameters(EpsilonType epsilon, // First tensor contains the names of keys int64_t num_keys = key_types.size(); - MutableVectorData key_names(num_keys); + MutableVectorData key_names; for (int i = 0; i < num_keys; i++) { - key_names[i] = "key" + std::to_string(i); + key_names.push_back(absl::StrCat("key", i)); } parameters.push_back( Tensor::Create(DT_STRING, {num_keys}, @@ -171,8 +171,8 @@ Intrinsic CreateInnerIntrinsic(InputType linfinity_bound, double l1_bound, } template -Intrinsic CreateIntrinsic(double epsilon = 1000.0, double delta = 0.001, - int64_t l0_bound = 100, +Intrinsic CreateIntrinsic(double epsilon = kEpsilonThreshold, + double delta = 0.001, int64_t l0_bound = 100, InputType linfinity_bound = 100, double l1_bound = -1, double l2_bound = -1, std::vector key_types = {DT_STRING}) { @@ -327,12 +327,15 @@ TEST(DPClosedDomainHistogramTest, CatchInnerParameters_WrongTypes) { {CreateTensorSpec("key_out", DT_STRING)}, {CreateTopLevelParameters()}, {}}; - intrinsic1.nested_intrinsics.push_back(Intrinsic{ - kDPSumUri, - {CreateTensorSpec("value", DT_INT64)}, - {CreateTensorSpec("value", DT_INT64)}, - {CreateGenericDPGFSParameters("x", -1, -1)}, - {}}); + intrinsic1.nested_intrinsics.push_back( + Intrinsic{kDPSumUri, + {CreateTensorSpec("value", DT_INT64)}, + {CreateTensorSpec("value", DT_INT64)}, + {CreateGenericDPGFSParameters( + /*linfinity_bound=*/"x", + /*l1_bound=*/-1, + /*l2_bound=*/-1)}, + {}}); auto aggregator_status1 = 
CreateTensorAggregator(intrinsic1); EXPECT_THAT(aggregator_status1, StatusIs(INVALID_ARGUMENT)); EXPECT_THAT(aggregator_status1.status().message(), @@ -372,39 +375,46 @@ TEST(DPClosedDomainHistogramTest, CatchInnerParameters_WrongTypes) { HasSubstr("numerical Tensors")); } -TEST(DPOpenDomainHistogramTest, CatchInvalidParameterValues) { +TEST(DPClosedDomainHistogramTest, CatchInvalidParameterValues) { // Negative epsilon - Intrinsic intrinsic0 = CreateIntrinsic(-1, 0.001, 10); + Intrinsic intrinsic0 = CreateIntrinsic(/*epsilon=*/-1, + /*delta=*/0.001, + /*l0_bound=*/10); auto bad_epsilon = CreateTensorAggregator(intrinsic0).status(); EXPECT_THAT(bad_epsilon, StatusIs(INVALID_ARGUMENT)); EXPECT_THAT(bad_epsilon.message(), HasSubstr("Epsilon must be positive")); // Delta too large Intrinsic intrinsic2 = - CreateIntrinsic(1, 2, 10, 10, -1, -1); + CreateIntrinsic(1, /*delta=*/2, 10, 10, -1, -1); auto bad_delta1 = CreateTensorAggregator(intrinsic2).status(); EXPECT_THAT(bad_delta1, StatusIs(INVALID_ARGUMENT)); EXPECT_THAT(bad_delta1.message(), HasSubstr("delta must be less than 1")); // Missing norm bounds Intrinsic intrinsic1 = - CreateIntrinsic(1, 0.001, 3, -1, -1, -1); + CreateIntrinsic(1, 0.001, 3, /*linfinity_bound=*/-1, + /*l1_bound=*/-1, + /*l2_bound=*/-1); auto bad_bounds = CreateTensorAggregator(intrinsic1).status(); EXPECT_THAT(bad_bounds, StatusIs(INVALID_ARGUMENT)); EXPECT_THAT(bad_bounds.message(), HasSubstr("either an L1 bound, an L2 bound," " or both Linfinity and L0 bounds")); - // Delta not provided (implied 0) and L1 bound not provided + // Delta not provided and only an L2 bound was provided Intrinsic intrinsic3 = - CreateIntrinsic(1, -1, 10, 10, -1, -1); + CreateIntrinsic(1, /*delta=*/-1, -1, -1, -1, + /*l2_bound=*/3); auto bad_delta2 = CreateTensorAggregator(intrinsic3).status(); EXPECT_THAT(bad_delta2, StatusIs(INVALID_ARGUMENT)); - EXPECT_THAT(bad_delta2.message(), - HasSubstr("either a positive delta or an L1 bound")); + EXPECT_THAT( + 
bad_delta2.message(), + HasSubstr("either a positive delta or one of the following: " + "(a) an L1 bound (b) an Linfinity bound and an L0 bound")); } -// Second batch of tests validate the aggregator itself. +// Second batch of tests validate the aggregator itself, without DP noise. // Make sure we can successfully create a DPClosedDomainHistogram object. TEST(DPClosedDomainHistogramTest, CreateAggregator_Success) { @@ -426,12 +436,13 @@ TEST(DPClosedDomainHistogramTest, CreateAggregator_Success) { EXPECT_EQ(domain_tensors[0].AsSpan()[2], "c"); } -// Make sure the Report is what we expect. +// Make sure the Report without DP noise contains all composite keys and their +// aggregations. // One key taking values in the set {"a", "b", "c"} TEST(DPClosedDomainHistogramTest, NoiselessReport_OneKey) { // Create intrinsic with one string key ({"a", "b", "c"} is default domain) - Intrinsic intrinsic = - CreateIntrinsic(1, 0.001, 10, 10, -1, -1, {DT_STRING}); + Intrinsic intrinsic = CreateIntrinsic( + kEpsilonThreshold, 0.001, 10, 10, -1, -1, {DT_STRING}); // Create a DPClosedDomainHistogram object auto status = CreateTensorAggregator(intrinsic); TFF_EXPECT_OK(status); @@ -458,8 +469,7 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_OneKey) { auto report_status = std::move(*agg).Report(); TFF_EXPECT_OK(report_status); auto& report = report_status.value(); - EXPECT_EQ(report.size(), 2); - EXPECT_EQ(report[0].shape(), TensorShape({3})); + ASSERT_EQ(report.size(), 2); EXPECT_THAT(report[0], IsTensor({3}, {"a", "b", "c"})); EXPECT_THAT(report[1], IsTensor({3}, {5, 0, 1})); } @@ -468,7 +478,7 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_OneKey) { // Number of possible composite keys is 9 = 3 * 3. 
TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys) { Intrinsic intrinsic = CreateIntrinsic( - 1, 0.001, 10, 10, -1, -1, {DT_STRING, DT_INT64}); + kEpsilonThreshold, 0.001, 10, 10, -1, -1, {DT_STRING, DT_INT64}); // Create a DPClosedDomainHistogram object auto status = CreateTensorAggregator(intrinsic); TFF_EXPECT_OK(status); @@ -499,7 +509,7 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys) { TFF_EXPECT_OK(report_status); auto& report = report_status.value(); // three tensors (columns): first key, second key, aggregation - EXPECT_EQ(report.size(), 3); + ASSERT_EQ(report.size(), 3); // first key: letters cycle as a, b, c, a, b, c, a, b, c EXPECT_THAT(report[0], IsTensor({9}, {"a", "b", "c", "a", "b", @@ -516,7 +526,9 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys) { // Same as above except we do not output the key that takes numerical values. TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys_DropSecondKey) { Intrinsic intrinsic = CreateIntrinsic( - 1, 0.001, 10, 10, -1, -1, {DT_STRING, DT_INT64}); + /*epsilon=*/kEpsilonThreshold, /*delta=*/0.001, /*l0_bound=*/10, + /*linfinity_bound=*/10, /*l1_bound=*/-1, /*l2_bound=*/-1, + /*key_types=*/{DT_STRING, DT_INT64}); intrinsic.outputs[1] = CreateTensorSpec("", DT_INT64); // Create a DPClosedDomainHistogram object @@ -549,7 +561,7 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys_DropSecondKey) { TFF_EXPECT_OK(report_status); auto& report = report_status.value(); // two tensors (columns): first key of letters, then aggregation - EXPECT_EQ(report.size(), 2); + ASSERT_EQ(report.size(), 2); // first key: letters cycle as a, b, c, a, b, c, a, b, c EXPECT_THAT(report[0], IsTensor({9}, {"a", "b", "c", "a", "b", @@ -559,6 +571,47 @@ TEST(DPClosedDomainHistogramTest, NoiselessReport_TwoKeys_DropSecondKey) { // (a0 is the composite key at index 0, c1 is at index 5, a2 is at index 6) EXPECT_THAT(report[1], IsTensor({9}, {-3, 0, 0, 0, 0, 1, 5, 0, 0})); } + +// Third: Check that 
noise is added. the noised sum should not be the same as +// the unnoised sum. The chance of a false negative shrinks with epsilon. +TEST(DPClosedDomainHistogramTest, NoiseAddedForSmallEpsilons) { + Intrinsic intrinsic = + CreateIntrinsic(/*epsilon=*/0.05, + /*delta=*/1e-8, + /*l0_bound=*/2, + /*linfinity_bound=*/1); + auto aggregator = CreateTensorAggregator(intrinsic).value(); + int num_inputs = 4000; + for (int i = 0; i < num_inputs; i++) { + Tensor keys = + Tensor::Create(DT_STRING, {2}, CreateTestData({"a", "b"})) + .value(); + Tensor values = + Tensor::Create(DT_INT32, {2}, CreateTestData({1, 1})).value(); + auto acc_status = aggregator->Accumulate({&keys, &values}); + EXPECT_THAT(acc_status, IsOk()); + } + EXPECT_EQ(aggregator->GetNumInputs(), num_inputs); + EXPECT_TRUE(aggregator->CanReport()); + + auto report = std::move(*aggregator).Report(); + EXPECT_THAT(report, IsOk()); + + // There must be 2 columns, one for keys and one for aggregated values. + ASSERT_EQ(report->size(), 2); + + const auto& values = report.value()[1].AsSpan(); + + // There must be 3 rows, one per key (a, b, c) + ASSERT_EQ(values.size(), 3); + + // We expect that there is some perturbation in the output. + // The values for a and b should be num_inputs +/- noise, while the value for + // c should be 0 +/- noise. 
+ EXPECT_TRUE(values[0] != num_inputs && values[1] != num_inputs && + values[2] != 0); +} + } // namespace } // namespace aggregation } // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_group_by_factory.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_group_by_factory.cc index cb232c61f8..2d71f3ccc3 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_group_by_factory.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_group_by_factory.cc @@ -17,6 +17,7 @@ #include "tensorflow_federated/cc/core/impl/aggregation/core/dp_group_by_factory.h" #include +#include #include #include #include @@ -218,10 +219,12 @@ StatusOr> DPGroupByFactory::CreateInternal( "intrinsic must provide a positive Linfinity bound."; } } else { - // Either L1 is positive, or L2 is positive, or both Linfinity is positive - // and L0 is positive. - bool has_l1_bound = l1 > 0; - bool has_l2_bound = l2 > 0; + // Closed-domain histogram requires either an L1 bound, an L2 bound, or + // both Linfinity and L0 bounds. + bool has_l1_bound = + l1 > 0 && l1 != std::numeric_limits::infinity(); + bool has_l2_bound = + l2 > 0 && l2 != std::numeric_limits::infinity(); if ((!has_linfinity_bound || l0_bound <= 0) && !has_l1_bound && !has_l2_bound) { return TFF_STATUS(INVALID_ARGUMENT) @@ -230,14 +233,15 @@ StatusOr> DPGroupByFactory::CreateInternal( "bounds."; } // If delta is 0, we will employ the Laplace mechanism (Gaussian requires - // a positive delta). But the Laplace mechanism requires a positive L1 - // bound. - // If query author did not provide any delta (which will be indicated with - // -1), we again have to use the Laplace mechanism. - if (delta <= 0 && !has_l1_bound) { + // a positive delta). But the Laplace mechanism requires a positive and + // finite L1 sensitivity. 
+ bool has_l1_sensitivity = + has_l1_bound || (has_linfinity_bound && l0_bound > 0); + if (delta <= 0 && !has_l1_sensitivity) { return TFF_STATUS(INVALID_ARGUMENT) << "DPGroupByFactory: Closed-domain DP histograms require " - "either a positive delta or an L1 bound."; + "either a positive delta or one of the following: " + << "(a) an L1 bound (b) an Linfinity bound and an L0 bound"; } } } diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.cc new file mode 100644 index 0000000000..9d77193f12 --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.cc @@ -0,0 +1,277 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "algorithms/numerical-mechanisms.h" +#include "algorithms/partition-selection.h" +#include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_fedsql_constants.h" + +namespace tensorflow_federated { +namespace aggregation { +namespace internal { +using differential_privacy::GaussianPartitionSelection; +using differential_privacy::SafeAdd; + +constexpr double kMaxSensitivity = std::numeric_limits::infinity(); + +// Calculate L1 sensitivity from norm bounds, under replacement DP. +// A non-positive norm bound means it was not specified. +double CalculateL1Sensitivity(int64_t l0_bound, double linfinity_bound, + double l1_bound) { + double l1_sensitivity = kMaxSensitivity; + if (l0_bound > 0 && linfinity_bound > 0) { + l1_sensitivity = fmin(l1_sensitivity, 2.0 * l0_bound * linfinity_bound); + } + if (l1_bound > 0) { + l1_sensitivity = fmin(l1_sensitivity, 2.0 * l1_bound); + } + return l1_sensitivity; +} + +// Calculate L2 sensitivity from norm bounds, under replacement DP. +// A non-positive norm bound means it was not specified. +double CalculateL2Sensitivity(int64_t l0_bound, double linfinity_bound, + double l2_bound) { + double l2_sensitivity = kMaxSensitivity; + if (l0_bound > 0 && linfinity_bound > 0) { + l2_sensitivity = + fmin(l2_sensitivity, sqrt(2.0 * l0_bound) * linfinity_bound); + } + if (l2_bound > 0) { + l2_sensitivity = fmin(l2_sensitivity, 2.0 * l2_bound); + } + return l2_sensitivity; +} + +// Computes threshold needed when Gaussian noise is used to ensure DP of open- +// domain histograms. 
+absl::StatusOr CalculateGaussianThreshold( + double epsilon, double delta_for_noising, double delta_for_thresholding, + int64_t l0_sensitivity, double linfinity_bound, double l2_sensitivity) { + TFF_CHECK(epsilon > 0 && delta_for_noising > 0 && + delta_for_thresholding > 0 && l0_sensitivity > 0 && + linfinity_bound > 0 && l2_sensitivity > 0) + << "CalculateGaussianThreshold: All inputs must be positive"; + TFF_CHECK(delta_for_noising < 1) << "CalculateGaussianThreshold: " + << "delta_for_noising must be less than 1."; + TFF_CHECK(delta_for_thresholding < 1) + << "CalculateGaussianThreshold: " + << "delta_for_thresholding must be less than 1."; + + double stdev = differential_privacy::GaussianMechanism::CalculateStddev( + epsilon, delta_for_noising, l2_sensitivity); + TFF_ASSIGN_OR_RETURN(double library_threshold, + GaussianPartitionSelection::CalculateThresholdFromStddev( + stdev, delta_for_thresholding, l0_sensitivity)); + + return SafeAdd(linfinity_bound - 1, library_threshold).value; +} + +// Computes threshold needed when Laplace noise is used to ensure DP of open- +// domain histograms. +absl::StatusOr CalculateLaplaceThreshold(double epsilon, double delta, + int64_t l0_sensitivity, + double linfinity_bound, + double l1_sensitivity) { + TFF_CHECK(epsilon > 0 && delta > 0 && l0_sensitivity > 0 && + linfinity_bound > 0 && l1_sensitivity > 0) + << "CalculateLaplaceThreshold: All inputs must be positive"; + TFF_CHECK(delta < 1) << "CalculateLaplaceThreshold: delta must be less " + << "than 1"; + + // If probability of failing to drop a small value is + // 1- pow(1 - delta, 1 / l0_sensitivity) + // then the overall privacy failure probability is delta + // Below: numerically stable version of 1- pow(1 - delta, 1 / l0_sensitivity) + // Adapted from PartitionSelectionStrategy::CalculateAdjustedDelta. 
+ double adjusted_delta = -std::expm1(log1p(-delta) / l0_sensitivity); + + double laplace_tail_bound; + if (adjusted_delta > 0.5) { + laplace_tail_bound = + (l1_sensitivity / epsilon) * std::log(2 * (1 - adjusted_delta)); + } else { + laplace_tail_bound = + -(l1_sensitivity / epsilon) * (std::log(2 * adjusted_delta)); + } + + return linfinity_bound + laplace_tail_bound; +} +} // namespace internal + +// Given parameters for a DP aggregation, create a Gaussian mechanism for that +// aggregation (or return error status). If open_domain is true, then split +// delta and compute a post-aggregation threshold. +absl::StatusOr CreateGaussianMechanism( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l2_bound, bool open_domain) { + if (epsilon <= 0 || epsilon >= kEpsilonThreshold) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateGaussianMechanism: Epsilon must be positive " + "and smaller than " + << kEpsilonThreshold; + } + if (delta <= 0 || delta >= 1) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateGaussianMechanism: Delta must lie within (0, 1)."; + } + + // The following parameter determines how much of delta is consumed for + // thresholding (when open_domain is true). Currently set to 0.5, but this + // could be optimized down the line. + double fractionForThresholding = open_domain ? 
0.5 : 0.0; + double delta_for_thresholding = delta * fractionForThresholding; + double delta_for_noising = delta - delta_for_thresholding; + + double l2_sensitivity = + internal::CalculateL2Sensitivity(l0_bound, linfinity_bound, l2_bound); + + differential_privacy::GaussianMechanism::Builder gaussian_builder; + gaussian_builder.SetL2Sensitivity(l2_sensitivity) + .SetEpsilon(epsilon) + .SetDelta(delta_for_noising); + + DPHistogramBundle dp_histogram; + TFF_ASSIGN_OR_RETURN(dp_histogram.mechanism, gaussian_builder.Build()); + dp_histogram.use_laplace = false; + + if (open_domain) { + if (l0_bound <= 0 || linfinity_bound <= 0) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateGaussianMechanism: Open-domain DP " + "histogram algorithm requires valid l0_bound " + "and linfinity_bound."; + } + + // Calculate the threshold which we will impose on noisy sums. + // Note that l0_sensitivity = 2 * l0_bound because we target replacement DP. + TFF_ASSIGN_OR_RETURN( + dp_histogram.threshold, + internal::CalculateGaussianThreshold( + epsilon, delta_for_noising, delta_for_thresholding, + /*l0_sensitivity=*/2 * l0_bound, linfinity_bound, l2_sensitivity)); + } + return dp_histogram; +} + +// Given parameters for a DP aggregation, create a Laplace mechanism for that +// aggregation (or return error status). 
+absl::StatusOr CreateLaplaceMechanism( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l1_bound, bool open_domain) { + if (epsilon <= 0 || epsilon >= kEpsilonThreshold) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateLaplaceMechanism: Epsilon must be positive " + "and smaller than " + << kEpsilonThreshold; + } + + double l1_sensitivity = + internal::CalculateL1Sensitivity(l0_bound, linfinity_bound, l1_bound); + + differential_privacy::LaplaceMechanism::Builder laplace_builder; + laplace_builder.SetL1Sensitivity(l1_sensitivity).SetEpsilon(epsilon); + + DPHistogramBundle dp_histogram; + TFF_ASSIGN_OR_RETURN(dp_histogram.mechanism, laplace_builder.Build()); + dp_histogram.use_laplace = true; + + if (open_domain) { + if (delta <= 0 || delta >= 1 || l0_bound <= 0 || linfinity_bound <= 0) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateLaplaceMechanism: Open-domain DP " + "histogram algorithm requires valid delta, " + "l0_bound, and linfinity_bound."; + } + + // Calculate the threshold which we will impose on noisy sums. + // Note that l0_sensitivity = 2 * l0_bound because we target replacement DP. + TFF_ASSIGN_OR_RETURN( + dp_histogram.threshold, + internal::CalculateLaplaceThreshold(epsilon, delta, 2 * l0_bound, + linfinity_bound, l1_sensitivity)); + } + + return dp_histogram; +} + +// Given parameters for a DP histogram aggregation, create a mechanism for that +// aggregation (or return error status). The mechanism will be either Laplace +// or Gaussian, whichever has less variance for the same DP parameters. +absl::StatusOr CreateDPHistogramBundle( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l1_bound, double l2_bound, bool open_domain) { + // First we determine if we are able to make Gaussian or Laplace mechanisms + // from the given parameters. 
+ double l1_sensitivity = + internal::CalculateL1Sensitivity(l0_bound, linfinity_bound, l1_bound); + double l2_sensitivity = + internal::CalculateL2Sensitivity(l0_bound, linfinity_bound, l2_bound); + bool laplace_is_possible = (epsilon > 0 && epsilon < kEpsilonThreshold && + l1_sensitivity != internal::kMaxSensitivity); + bool gaussian_is_possible = + (epsilon > 0 && epsilon < kEpsilonThreshold && delta > 0 && delta < 1 && + l2_sensitivity != internal::kMaxSensitivity); + + if (!laplace_is_possible && !gaussian_is_possible) { + return TFF_STATUS(INVALID_ARGUMENT) + << "CreateDPHistogramBundle: Unable to make either a Laplace or a" + " Gaussian DP mechanism. Relevant parameters:" + << "\n l0_bound: " << l0_bound + << "\n linfinity_bound: " << linfinity_bound + << "\n l1_bound: " << l1_bound << "\n l2_bound: " << l2_bound + << "\n epsilon: " << epsilon << "\n delta: " << delta; + } + + // When only one mechanism can be made, make it. + if (!laplace_is_possible && gaussian_is_possible) { + return CreateGaussianMechanism(epsilon, delta, l0_bound, linfinity_bound, + l2_bound, open_domain); + } + if (laplace_is_possible && !gaussian_is_possible) { + return CreateLaplaceMechanism(epsilon, delta, l0_bound, linfinity_bound, + l1_bound, open_domain); + } + + // When both mechanisms can be made, use the one with smaller variance. + // This is a simple heuristic that will minimize average error across the + // domain of composite keys. An alternative would be to minimize the + // maximum error using tail bounds. 
+ + TFF_ASSIGN_OR_RETURN( + auto laplace_mechanism, + CreateLaplaceMechanism(epsilon, delta, l0_bound, linfinity_bound, + l1_bound, open_domain)); + TFF_ASSIGN_OR_RETURN( + auto gaussian_mechanism, + CreateGaussianMechanism(epsilon, delta, l0_bound, linfinity_bound, + l2_bound, open_domain)); + + if (gaussian_mechanism.mechanism->GetVariance() < + laplace_mechanism.mechanism->GetVariance()) { + return gaussian_mechanism; + } + return laplace_mechanism; +} + +} // namespace aggregation +} // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h new file mode 100644 index 0000000000..083d18746e --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h @@ -0,0 +1,117 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_DP_NOISE_MECHANISMS_H_ +#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_DP_NOISE_MECHANISMS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "algorithms/numerical-mechanisms.h" + +// This file specifies DP noise mechanisms to be used in DPOpenDomainHistogram +// and DPClosedDomainHistogram. The functions build upon the +// differential_privacy::LaplaceMechanism and ::GaussianMechanism classes. 
+// They calculate replacement DP sensitivity from contribution bounds, which are +// expressed as l0, linfinity, l1, and l2 norm bounds. + +namespace tensorflow_federated { +namespace aggregation { +namespace internal { + +// Calculate L1 sensitivity from norm bounds, under replacement DP. +// A non-positive norm bound means it was not specified. +double CalculateL1Sensitivity(int64_t l0_bound, double linfinity_bound, + double l1_bound); + +// Calculate L2 sensitivity from norm bounds, under replacement DP. +// A non-positive norm bound means it was not specified. +double CalculateL2Sensitivity(int64_t l0_bound, double linfinity_bound, + double l2_bound); + +// Computes threshold needed when Gaussian noise is used to ensure DP of open- +// domain histograms. +// Generalizes GaussianPartitionSelection::CalculateThresholdFromStddev, which +// assumes that linfinity_bound = 1 (we do not). The only role linfinity_bound +// plays is as an additive offset, so we simply shift the number it produces to +// compute the threshold. +absl::StatusOr CalculateGaussianThreshold( + double epsilon, double delta_for_noising, double delta_for_thresholding, + int64_t l0_sensitivity, double linfinity_bound, double l2_sensitivity); + +// Computes threshold needed when Laplace noise is used to ensure DP of open- +// domain histograms. +// Generalizes LaplacePartitionSelection from partition-selection.h, since it +// permits setting norm bounds beyond l0 (max_groups_contributed). +// l0_sensitivity and l1_sensitivity measure how much one user changes the l0 +// and l1 norms, respectively, while linfinity_bound caps the magnitude of one +// user's contributions. This distinction is important for replacement DP. 
+absl::StatusOr CalculateLaplaceThreshold(double epsilon, double delta, + int64_t l0_sensitivity, + double linfinity_bound, + double l1_sensitivity); +} // namespace internal + +// Because of substantial overlap in the logic for closed-domain and open-domain +// histogram algorithms, the following struct is used in both places. +struct DPHistogramBundle { + // A pointer to a NumericalMechanism object which introduces noise for one + // summation that satisfies replacement DP. The distribution will either be + // Laplace or Gaussian, whichever has less variance. + std::unique_ptr mechanism = nullptr; + + // A threshold below which noisy sums will be erased. The thresholding step + // consumes some or all of the delta that a customer provides. Only used in + // the open-domain case. + std::optional threshold; + + // A boolean to indicate which noise is used. + bool use_laplace = false; +}; + +// Given parameters for a DP aggregation, create a Gaussian mechanism for that +// aggregation (or return error status). If open_domain is true, then split +// delta and compute a post-aggregation threshold. +absl::StatusOr CreateGaussianMechanism( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l2_bound, bool open_domain); + +// Given parameters for a DP aggregation, create a Laplace mechanism for that +// aggregation (or return error status). +absl::StatusOr CreateLaplaceMechanism( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l1_bound, bool open_domain); + +// Given parameters for a DP aggregation, create a mechanism for it (or return +// an error status). The mechanism will be either Laplace or Gaussian, whichever +// has less variance for the same DP parameters. +// If it is not possible to make a mechanism, return an error status whose +// message includes the parameters of the aggregation and the provided index of +// the aggregation. 
+// If open_domain is true, then also compute a post-aggregation threshold. +// +// This function can be interpreted as a version of MinVarianceMechanismBuilder +// that takes L1 and L2 sensitivities. +absl::StatusOr CreateDPHistogramBundle( + double epsilon, double delta, int64_t l0_bound, double linfinity_bound, + double l1_bound, double l2_bound, bool open_domain); + +} // namespace aggregation +} // namespace tensorflow_federated + +#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_DP_NOISE_MECHANISMS_H_ diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms_test.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms_test.cc new file mode 100644 index 0000000000..83455a556d --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms_test.cc @@ -0,0 +1,286 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h" + +#include +#include + +#include "googlemock/include/gmock/gmock.h" +#include "googletest/include/gtest/gtest.h" +#include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_fedsql_constants.h" +#include "tensorflow_federated/cc/testing/status_matchers.h" + +namespace tensorflow_federated { +namespace aggregation { +namespace { + +using ::testing::DoubleEq; +using ::testing::HasSubstr; +using ::testing::Ne; +constexpr double kSmallEpsilon = 0.01; + +TEST(DPNoiseMechanismsTest, CreateLaplaceMechanismMissingEpsilon) { + // The function needs epsilon. + auto missing_epsilon = CreateLaplaceMechanism(-1, 1e-8, 10, 10, 10, false); + EXPECT_THAT(missing_epsilon, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_epsilon.status().message(), + HasSubstr("Epsilon must be positive")); +} + +TEST(DPNoiseMechanismsTest, CreateLaplaceMechanismMissingL1) { + // The function returns an error if it is unable to bound the L1 sensitivity. + auto missing_l1 = CreateLaplaceMechanism(1.0, 1e-8, -1, -1, -1, false); + EXPECT_THAT(missing_l1, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_l1.status().message(), + HasSubstr("must be finite and positive")); +} + +TEST(DPNoiseMechanismsTest, CreateLaplaceMechanismMissingParamsForOpenDomain) { + // If the goal is to support open-domain histograms but there is no valid + // delta, CreateLaplaceMechanism returns an error status. Same is true if we + // are missing one of L0 and Linf. 
+ std::string kErrorMsg = + "CreateLaplaceMechanism: Open-domain DP " + "histogram algorithm requires valid delta, " + "l0_bound, and linfinity_bound."; + + auto missing_delta = CreateLaplaceMechanism(1.0, -1, 10, 10, 10, true); + EXPECT_THAT(missing_delta, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_delta.status().message(), HasSubstr(kErrorMsg)); + + auto missing_L0 = CreateLaplaceMechanism(1.0, 1e-8, -1, 10, 10, true); + EXPECT_THAT(missing_L0, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_L0.status().message(), HasSubstr(kErrorMsg)); + + auto missing_Linf = CreateLaplaceMechanism(1.0, 1e-8, 10, -1, 10, true); + EXPECT_THAT(missing_Linf, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_Linf.status().message(), HasSubstr(kErrorMsg)); +} + +TEST(DPNoiseMechanismsTest, CreateGaussianMechanismMissingEpsilon) { + // The function needs epsilon and delta. + auto missing_epsilon = CreateGaussianMechanism(-1, 1e-8, 10, 10, 10, false); + EXPECT_THAT(missing_epsilon, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_epsilon.status().message(), + HasSubstr("Epsilon must be positive")); +} + +TEST(DPNoiseMechanismsTest, CreateGaussianMechanismMissingDelta) { + auto missing_delta = CreateGaussianMechanism(1.0, -1, 10, 10, 10, false); + EXPECT_THAT(missing_delta, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_delta.status().message(), + HasSubstr("Delta must lie within (0, 1)")); +} + +TEST(DPNoiseMechanismsTest, CreateGaussianMechanismMissingL2) { + // The function returns an error if it is unable to bound the L2 sensitivity. 
+ auto missing_l2 = CreateGaussianMechanism(1.0, 1e-8, -1, -1, -1, false); + EXPECT_THAT(missing_l2, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_l2.status().message(), + HasSubstr("must be finite and positive")); +} + +TEST(DPNoiseMechanismsTest, CreateGaussianMechanismMissingParamsForOpenDomain) { + // If the goal is to support open-domain histograms but there is no valid + // l0_bound or linfinity_bound, CreateGaussianMechanism returns an error + // status. + std::string kErrorMsg = + "CreateGaussianMechanism: Open-domain DP " + "histogram algorithm requires valid l0_bound " + "and linfinity_bound."; + + auto missing_L0 = CreateGaussianMechanism(1.0, 1e-8, -1, 10, 10, true); + EXPECT_THAT(missing_L0, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_L0.status().message(), HasSubstr(kErrorMsg)); + + auto missing_Linf = CreateGaussianMechanism(1.0, 1e-8, 10, -1, 10, true); + EXPECT_THAT(missing_Linf, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_Linf.status().message(), HasSubstr(kErrorMsg)); +} + +// If neither L1 nor L2 sensitivity can be computed, CreateDPHistogramBundle +// returns a bad status. Same is true if epsilon is invalid (< 0 or too big). 
+TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleCatchesBadParameters) { + std::string kErrorMsg = + "CreateDPHistogramBundle: Unable to make either a " + "Laplace or a Gaussian DP mechanism"; + // No norm bounds -> no sensitivity bound + auto missing_bounds = + CreateDPHistogramBundle(1.0, 1e-8, -1, -1, -1, -1, false); + EXPECT_THAT(missing_bounds, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(missing_bounds.status().message(), HasSubstr(kErrorMsg)); + + // Negative epsilon + auto negative_epsilon = + CreateDPHistogramBundle(-1, 1e-8, 10, 10, 10, 10, false); + EXPECT_THAT(negative_epsilon, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(negative_epsilon.status().message(), HasSubstr(kErrorMsg)); + + // Too large epsilon + auto too_large_epsilon = + CreateDPHistogramBundle(kEpsilonThreshold, 1e-8, 10, 10, 10, 10, false); + EXPECT_THAT(too_large_epsilon, StatusIs(INVALID_ARGUMENT)); + EXPECT_THAT(too_large_epsilon.status().message(), HasSubstr(kErrorMsg)); +} + +// If sensitivities and epsilon are valid, check that the function correctly +// switches between distributions. + +// If only the L1 norm bound is given, Laplace should be used. +TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleUsesLaplaceOnlyL1) { + auto laplace_mechanism = CreateDPHistogramBundle(/*epsilon=*/1.0, + /*delta=*/1e-8, + /*l0_bound=*/-1, + /*linfinity_bound=*/-1, + /*l1_bound=*/10, + /*l2_bound=*/-1, false); + EXPECT_THAT(laplace_mechanism, IsOk()); + EXPECT_TRUE(laplace_mechanism.value().use_laplace); + EXPECT_THAT(laplace_mechanism.value().mechanism->GetVariance(), + DoubleEq(800)); +} + +// If both Laplace and Gaussian can be used, Laplace should be used if its +// variance is smaller. +TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleUsesLaplaceWhenAppropriate) { + // L0 and Linf bounds are given, with Laplace variance smaller than Gaussian. 
+ auto agg1 = CreateDPHistogramBundle(1.0, 1e-10, 2, 10, -1, -1, false); + EXPECT_THAT(agg1, IsOk()); + EXPECT_TRUE(agg1.value().use_laplace); + + // L1 and L2 bounds are given, with Laplace variance smaller than Gaussian. + auto agg2 = CreateDPHistogramBundle(1.0, 1e-8, -1, -1, 10, 10, false); + EXPECT_THAT(agg2, IsOk()); + EXPECT_TRUE(agg2.value().use_laplace); +} + +// If only the L2 norm bound is given, Gaussian should be used. +TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleUsesGaussianOnlyL2) { + auto gaussian_mechanism = CreateDPHistogramBundle(/*epsilon=*/1.0, + /*delta=*/1e-8, + /*l0_bound=*/-1, + /*linfinity_bound=*/-1, + /*l1_bound=*/-1, + /*l2_bound=*/10, false); + EXPECT_THAT(gaussian_mechanism, IsOk()); + EXPECT_FALSE(gaussian_mechanism.value().use_laplace); +} + +TEST(DPNoiseMechanismsTest, + CreateDPHistogramBundleUsesGaussianWhenAppropriate) { + // L0 and Linf plus L2 bound are given; Laplace variance larger than Gaussian. + auto agg1 = CreateDPHistogramBundle(1.0, 1e-10, /*l0_bound=*/2, + /*linfinity_bound=*/10, -1, 2, false); + EXPECT_THAT(agg1, IsOk()); + EXPECT_FALSE(agg1.value().use_laplace); + + // L1 and L2 bounds are given, with Gaussian variance smaller than Laplace. + auto agg2 = CreateDPHistogramBundle(1.0, 1e-8, -1, -1, 46, 10, false); + EXPECT_THAT(agg2, IsOk()); + EXPECT_FALSE(agg2.value().use_laplace); +} + +TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleUsesGaussianForLargeL0) { + // If a user can contribute to L0 = x groups and there is only an L_inf bound, + // Laplace noise is linear in x while Gaussian noise scales with sqrt(x). 
+ // Hence, we should use Gaussian when we loosen x (from 2 to 20) + auto agg3 = CreateDPHistogramBundle(1.0, /*delta=*/1e-10, /*l0_bound=*/20, + /*linfinity_bound=*/10, -1, -1, false); + EXPECT_THAT(agg3, IsOk()); + EXPECT_FALSE(agg3.value().use_laplace); +} + +TEST(DPNoiseMechanismsTest, CreateDPHistogramBundleUsesGaussianForLargeDelta) { + // Gaussian noise should also be used if delta was loosened enough + auto agg4 = CreateDPHistogramBundle(1.0, /*delta=*/1e-3, /*l0_bound=*/2, + /*linfinity_bound=*/10, -1, -1, false); + EXPECT_THAT(agg4, IsOk()); + EXPECT_FALSE(agg4.value().use_laplace); +} + +// Check that noise is added at all: the noised sum should not be the same as +// the unnoised sum. The chance of a false negative shrinks with epsilon. +TEST(DPNoiseMechanismsTest, LaplaceNoiseAddedForSmallEpsilons) { + // Laplace + auto bundle = CreateDPHistogramBundle(kSmallEpsilon, 1e-8, -1, -1, + /*l1_bound=*/1, -1, false); + EXPECT_THAT(bundle, IsOk()); + int val = 1000; + auto noisy_val = bundle.value().mechanism->AddNoise(val); + EXPECT_THAT(noisy_val, Ne(val)); +} + +TEST(DPNoiseMechanismsTest, GaussianNoiseAddedForSmallEpsilons) { + // Gaussian + auto bundle = CreateDPHistogramBundle(kSmallEpsilon, 1e-8, -1, -1, -1, + /*l2_bound=*/1, false); + EXPECT_THAT(bundle, IsOk()); + int val = 1000; + auto noisy_val = bundle.value().mechanism->AddNoise(val); + EXPECT_THAT(noisy_val, Ne(val)); +} + +// Check that CalculateLaplaceThreshold computes the right threshold +// Case 1: adjusted delta less than 1/2 +TEST(DPNoiseMechanismsTest, CalculateLaplaceThresholdSucceedsSmallDelta) { + double delta = 0.468559; // = 1-(9/10)^6 + double linfinity_bound = 1; + int64_t l0_bound = 1; + + // under replacement DP: + int64_t l0_sensitivity = 2 * l0_bound; + double l1_sensitivity = 2; // = min(2 * l0_bound * linf_bound, 2 * l1_bound) + + // We'll work with eps = 1 for simplicity + auto threshold_wrapper = internal::CalculateLaplaceThreshold( + /*epsilon=*/1.0, delta, l0_sensitivity, 
linfinity_bound, l1_sensitivity); + TFF_ASSERT_OK(threshold_wrapper.status()); + + double laplace_tail_bound = 1.22497855; + // = -(l1_sensitivity / 1.0) * std::log(2.0 * adjusted_delta), + // where adjusted_delta = 1 - sqrt(1-delta) = 1 - (9/10)^3 = 1 - 0.729 = 0.271 + + EXPECT_NEAR(threshold_wrapper.value(), linfinity_bound + laplace_tail_bound, + 1e-5); +} + +// Case 2: adjusted delta greater than 1/2 +TEST(DPNoiseMechanismsTest, CalculateLaplaceThresholdSucceedsLargeDelta) { + double delta = 0.77123207545039; // 1-(9/10)^14 + + double linfinity_bound = 1; + int64_t l0_bound = 1; + + // under replacement DP: + int64_t l0_sensitivity = 2 * l0_bound; + double l1_sensitivity = 2; // = min(2 * l0_bound * linf_bound, 2 * l1_bound) + + auto threshold_wrapper = internal::CalculateLaplaceThreshold( + /*epsilon=*/1.0, delta, l0_sensitivity, linfinity_bound, l1_sensitivity); + TFF_ASSERT_OK(threshold_wrapper.status()); + + double laplace_tail_bound = -0.0887529; + // = (l1_sensitivity / 1.0) * std::log(2.0 - 2.0 * adjusted_delta), + // where adjusted_delta = 1 - sqrt(1-delta) = 1 - (9/10)^7 = 0.5217031 + EXPECT_NEAR(threshold_wrapper.value(), linfinity_bound + laplace_tail_bound, + 1e-5); +} + +} // namespace +} // namespace aggregation +} // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.cc index 1b2ac91e92..64c9c13fd6 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.cc @@ -16,26 +16,22 @@ #include "tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h" -#include #include #include #include #include -#include #include #include #include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" #include "absl/types/span.h" -#include "algorithms/numerical-mechanisms.h" 
-#include "algorithms/partition-selection.h" #include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/agg_core.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/composite_key_combiner.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/datatype.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/dp_composite_key_combiner.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/dp_fedsql_constants.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_noise_mechanisms.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/group_by_aggregator.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/input_tensor_list.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h" @@ -50,135 +46,9 @@ namespace tensorflow_federated { namespace aggregation { -using ::differential_privacy::NumericalMechanism; using ::differential_privacy::sign; namespace internal { -using ::differential_privacy::GaussianMechanism; -using ::differential_privacy::GaussianPartitionSelection; -using ::differential_privacy::LaplaceMechanism; -using ::differential_privacy::SafeAdd; - -// Struct to contain the components of the noise and threshold algorithm: -// - A pointer to a NumericalMechanism object which introduces DP noise for one -// summation that satisfies replacement DP. The distribution will either be -// Laplace or Gaussian, whichever has less variance for the same DP parameters -// - A threshold below which noisy sums will be erased. The thresholding step -// consumes some or all of the delta that a customer provides. -// - Also holds a boolean to indicate which noise is used. -template -struct NoiseAndThresholdBundle { - std::unique_ptr mechanism; - OutputType threshold; - bool use_laplace; -}; - -// Derive NoiseAndThresholdBundle from privacy parameters and clipping norms. 
-template -StatusOr> SetupNoiseAndThreshold( - double epsilon, double delta, int64_t l0_bound, OutputType linfinity_bound, - double l1_bound, double l2_bound) { - // The following constraints on DP parameters should be caught beforehand in - // factory code. - TFF_CHECK(epsilon > 0 && delta > 0 && l0_bound > 0 && linfinity_bound > 0) - << "epsilon, delta, l0_bound, and linfinity_bound must be greater than 0"; - TFF_CHECK(delta < 1) << "delta must be less than 1"; - - // For Gaussian noise, the following parameter determines how much of delta is - // consumed for thresholding. Currently set to 1/2 of delta, but this could be - // optimized down the line. - constexpr double kFractionForThresholding = 0.5; - double delta_for_thresholding = delta * kFractionForThresholding; - double delta_for_noising = delta - delta_for_thresholding; - - // Compute L1 sensitivity from the L0 and Linfinity bounds. - // We target replacement DP, which means L1 sensitivity is twice the maximum - // L1 norm of any contribution. The maximum L1 norm of any contribution can - // be derived from l0_bound and linfinity_bound (or l1_bound if provided). - double l1_sensitivity = 2.0 * l0_bound * linfinity_bound; - // If an L1 bound was given and it is tighter than the above, use it. - if (l1_bound > 0 && 2.0 * l1_bound < l1_sensitivity) { - l1_sensitivity = 2.0 * l1_bound; - } - - // Repeat for L2 sensitivity. To derive the expression, consider two - // neighboring user inputs (1, 1, 1, 0, 0, 0, 0) and (0, 0, 0, 0, 1, 1, 1) - // and fix linfinity_bound = 1 & l0_bound = 3. The L2 distance between these - // vectors---and therefore the L2 sensitivity of the sum of vectors---is - // sqrt(6 = 2 * l0_bound * linfinity_bound) - double l2_sensitivity = sqrt(2.0 * l0_bound) * linfinity_bound; - if (l2_bound > 0 && 2.0 * l2_bound < l2_sensitivity) { - l2_sensitivity = 2.0 * l2_bound; - } - - NoiseAndThresholdBundle output; - - // Pick the mechanism that will add noise with smaller standard deviation. 
- TFF_CHECK(epsilon > 0) << "epsilon must be greater than 0"; - double laplace_scale = (1.0 / epsilon) * l1_sensitivity; - double laplace_stdev = sqrt(2) * laplace_scale; - double gaussian_stdev = GaussianMechanism::CalculateStddev( - epsilon, delta_for_noising, l2_sensitivity); - - if (laplace_stdev < gaussian_stdev) { - // If we are going to use Laplace noise, - // 1. record that fact - output.use_laplace = true; - - // 2. use our parameters to create an object that will add that noise. - LaplaceMechanism::Builder laplace_builder; - laplace_builder.SetL1Sensitivity(l1_sensitivity).SetEpsilon(epsilon); - TFF_ASSIGN_OR_RETURN(output.mechanism, laplace_builder.Build()); - - // 3. Calculate the threshold which we will impose on noisy sums. - // Note that l0_sensitivity = 2 * l0_bound because we target replacement DP. - TFF_ASSIGN_OR_RETURN( - double library_threshold, - CalculateLaplaceThreshold(epsilon, delta, 2 * l0_bound, - linfinity_bound, l1_sensitivity)); - // Use ceil to err on the side of caution: - // if noisy_val is an integer less than (double) library_threshold, - // a cast of library_threshold may make them appear equal - if (std::is_integral::value) { - library_threshold = ceil(library_threshold); - } - output.threshold = static_cast(library_threshold); - - return output; - } - - // If we are going to use Gaussian noise, - // 1. record that fact - output.use_laplace = false; - - // 2. use our parameters to create an object that will add that noise. - GaussianMechanism::Builder gaussian_builder; - gaussian_builder.SetStandardDeviation(gaussian_stdev); - TFF_ASSIGN_OR_RETURN(output.mechanism, gaussian_builder.Build()); - - // 3. Calculate the threshold which we will impose on noisy sums. We use - // GaussianPartitionSelection::CalculateThresholdFromStddev. It assumes that - // linfinity_bound = 1 but the only role linfinity_bound plays is as an - // additive offset. So we can simply shift the number it produces to compute - // the threshold. 
- TFF_ASSIGN_OR_RETURN( - double library_threshold, - GaussianPartitionSelection::CalculateThresholdFromStddev( - gaussian_stdev, delta_for_thresholding, 2 * l0_bound)); - // Use ceil to err on the side of caution: - // if noisy_val is an integer less than (double) library_threshold, - // a cast of library_threshold may make them appear equal - if (std::is_integral::value) { - library_threshold = ceil(library_threshold); - } - - output.threshold = - SafeAdd(linfinity_bound - 1, - static_cast(library_threshold)) - .value; - - return output; -} // Noise is added to each value stored in a column tensor. If the noised value // falls below a given threshold, then the index of that value is removed from a @@ -201,11 +71,15 @@ StatusOr NoiseAndThreshold( absl::flat_hash_set& survivor_indices, std::vector& laplace_was_used) { TFF_ASSIGN_OR_RETURN( - auto bundle, SetupNoiseAndThreshold(epsilon, delta, l0_bound, - linfinity_bound, l1_bound, l2_bound)); + auto bundle, + CreateDPHistogramBundle(epsilon, delta, l0_bound, linfinity_bound, + l1_bound, l2_bound, true)); + laplace_was_used.push_back(bundle.use_laplace); - OutputType threshold = bundle.threshold; + TFF_CHECK(bundle.threshold.has_value()) + << "NoiseAndThreshold: threshold was not set."; + OutputType threshold = static_cast(bundle.threshold.value()); auto column_span = column_tensor.AsSpan(); auto noisy_values = std::make_unique>(); diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h index 826761cba2..4451456a47 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h @@ -17,12 +17,10 @@ #ifndef THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_DP_OPEN_DOMAIN_HISTOGRAM_H_ #define THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_DP_OPEN_DOMAIN_HISTOGRAM_H_ 
-#include #include #include #include -#include "absl/status/statusor.h" #include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/agg_core.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/composite_key_combiner.h" @@ -39,42 +37,6 @@ namespace tensorflow_federated { namespace aggregation { -namespace internal { -// Computes threshold needed when Laplace noise is used to ensure DP. -// Generalizes LaplacePartitionSelection from partition-selection.h, since it -// permits setting norm bounds beyond l0 (max_groups_contributed). -// l0_sensitivity and l1_sensitivity measure how much one user changes the l0 -// and l1 norms, respectively, while linfinity_bound caps the magnitude of one -// user's contributions. This distinction is important for replacement DP. -template -static absl::StatusOr CalculateLaplaceThreshold( - double epsilon, double delta, int64_t l0_sensitivity, - OutputType linfinity_bound, double l1_sensitivity) { - TFF_CHECK(epsilon > 0 && delta > 0 && l0_sensitivity > 0 && - linfinity_bound > 0 && l1_sensitivity > 0) - << "CalculateThreshold: All inputs must be positive"; - TFF_CHECK(delta < 1) << "CalculateThreshold: delta must be less than 1"; - - // If probability of failing to drop a small value is - // 1- pow(1 - delta, 1 / l0_sensitivity) - // then the overall privacy failure probability is delta - // Below: numerically stable version of 1- pow(1 - delta, 1 / l0_sensitivity) - // Adapted from PartitionSelectionStrategy::CalculateAdjustedDelta. 
- double adjusted_delta = -std::expm1(log1p(-delta) / l0_sensitivity); - - OutputType laplace_tail_bound; - if (adjusted_delta > 0.5) { - laplace_tail_bound = static_cast( - (l1_sensitivity / epsilon) * std::log(2 * (1 - adjusted_delta))); - } else { - laplace_tail_bound = static_cast( - -(l1_sensitivity / epsilon) * (std::log(2 * adjusted_delta))); - } - - return linfinity_bound + laplace_tail_bound; -} -} // namespace internal - // DPOpenDomainHistogram is a child class of GroupByAggregator. // ::AggregateTensorsInternal enforces a bound on the number of composite keys // (ordinals) that any one aggregation can contribute to. diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram_test.cc b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram_test.cc index d238f12ffa..5fc6542e24 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram_test.cc +++ b/tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram_test.cc @@ -14,8 +14,6 @@ * limitations under the License. */ -#include "tensorflow_federated/cc/core/impl/aggregation/core/dp_open_domain_histogram.h" - #include #include #include @@ -654,10 +652,8 @@ TEST_P(DPOpenDomainHistogramTest, NoKeyTripleAggWithAllBounds) { EXPECT_THAT(result.value()[2], IsTensor({1}, {11})); } -// Sixth: check for proper noise addition. - -// Check that noise is added at all: the noised sum should not be the same as -// the unnoised sum. The chance of a false negative shrinks with epsilon. +// Sixth: Check that noise is added at all. The noised sum should not be the +// same as the unnoised sum. The odds of a false negative shrink with epsilon.
TEST_P(DPOpenDomainHistogramTest, NoiseAddedForSmallEpsilons) { Intrinsic intrinsic = CreateIntrinsic(0.05, 1e-8, 2, 1); auto dpgba = CreateTensorAggregator(intrinsic).value(); @@ -688,90 +684,6 @@ TEST_P(DPOpenDomainHistogramTest, NoiseAddedForSmallEpsilons) { EXPECT_TRUE(values[0] != num_inputs || values[1] != num_inputs); } -// Check that SetupNoiseAndThreshold is capable of switching between -// distributions -TEST_P(DPOpenDomainHistogramTest, SetupNoiseAndThreshold_CorrectDistribution) { - Intrinsic intrinsic1{kDPGroupByUri, - {CreateTensorSpec("key", DT_STRING)}, - {CreateTensorSpec("key_out", DT_STRING)}, - {CreateTopLevelParameters(1.0, 1e-10, 2)}, - {}}; - // "Baseline" aggregation where Laplace was chosen - intrinsic1.nested_intrinsics.push_back( - CreateInnerIntrinsic(10, -1, -1)); - - // Aggregation where a given L2 norm bound is sufficiently smaller than L_0 * - // L_inf, which means Gaussian is preferred. - intrinsic1.nested_intrinsics.push_back( - CreateInnerIntrinsic(10, -1, 2)); - - auto agg1 = CreateTensorAggregator(intrinsic1).value(); - auto report = std::move(*agg1).Report(); - auto laplace_was_used = - dynamic_cast(*agg1).laplace_was_used(); - ASSERT_EQ(laplace_was_used.size(), 2); - EXPECT_TRUE(laplace_was_used[0]); - EXPECT_FALSE(laplace_was_used[1]); - - // If a user can contribute to L0 = x groups and there is only an L_inf bound, - // Laplace noise is linear in x while Gaussian noise scales with sqrt(x). 
- // Hence, we should use Gaussian when we loosen x (from 2 to 20) - Intrinsic intrinsic2 = - CreateIntrinsic(1.0, 1e-10, 20, 10, -1, -1); - auto agg2 = CreateTensorAggregator(intrinsic2).value(); - auto report2 = std::move(*agg2).Report(); - laplace_was_used = - dynamic_cast(*agg2).laplace_was_used(); - ASSERT_EQ(laplace_was_used.size(), 1); - EXPECT_FALSE(laplace_was_used[0]); - - // Gaussian noise should also be used if delta was loosened enough - Intrinsic intrinsic3 = - CreateIntrinsic(1.0, 1e-3, 2, 10, -1, -1); - auto agg3 = CreateTensorAggregator(intrinsic3).value(); - auto report3 = std::move(*agg3).Report(); - laplace_was_used = - dynamic_cast(*agg3).laplace_was_used(); - ASSERT_EQ(laplace_was_used.size(), 1); - EXPECT_FALSE(laplace_was_used[0]); -} - -// Check that CalculateLaplaceThreshold computes the right threshold -TEST(DPOpenDomainHistogramTest, CalculateLaplaceThreshold_Succeeds) { - // Case 1: adjusted delta less than 1/2 - double delta = 0.468559; // = 1-(9/10)^6 - double linfinity_bound = 1; - int64_t l0_bound = 1; - - // under replacement DP: - int64_t l0_sensitivity = 2 * l0_bound; - double l1_sensitivity = 2; // = min(2 * l0_bound * linf_bound, 2 * l1_bound) - - // We'll work with eps = 1 for simplicity - auto threshold_wrapper = internal::CalculateLaplaceThreshold( - 1.0, delta, l0_sensitivity, linfinity_bound, l1_sensitivity); - TFF_ASSERT_OK(threshold_wrapper.status()); - - double laplace_tail_bound = 1.22497855; - // = -(l1_sensitivity / 1.0) * std::log(2.0 * adjusted_delta), - // where adjusted_delta = 1 - sqrt(1-delta) = 1 - (9/10)^3 = 1 - 0.729 = 0.271 - - EXPECT_NEAR(threshold_wrapper.value(), linfinity_bound + laplace_tail_bound, - 1e-5); - - // Case 2: adjusted delta greater than 1/2 - delta = 0.77123207545039; // 1-(9/10)^14 - threshold_wrapper = internal::CalculateLaplaceThreshold( - 1.0, delta, l0_sensitivity, linfinity_bound, l1_sensitivity); - TFF_ASSERT_OK(threshold_wrapper.status()); - - laplace_tail_bound = -0.0887529; - // 
= (l1_sensitivity / 1.0) * std::log(2.0 - 2.0 * adjusted_delta), - // where adjusted_delta = 1 - sqrt(1-delta) = 1 - (9/10)^7 = 0.5217031 - EXPECT_NEAR(threshold_wrapper.value(), linfinity_bound + laplace_tail_bound, - 1e-5); -} - // Seventh: check that the right groups get dropped // Test that we will drop groups with any small aggregate and keep groups with