Remove Distributed Point Functions library
This third-party library is no longer integrated with any Chromium code following https://crrev.com/c/6505925, so we can remove the library. Note that this also removes the associated fuzzer. We can't also remove the third_party/highway library, as it is now used in //third_party/blink/renderer/core/html/parser/html_document_parser_fastpath.cc; we update its OWNERS to align with that new usage. Bug: 40178420 Change-Id: I879d718fbd83d6069adef18801de0d1c13ab3037 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6507840 Reviewed-by: Anton Bikineev <bikineev@chromium.org> Reviewed-by: Andrew Grieve <agrieve@chromium.org> Reviewed-by: Nan Lin <linnan@chromium.org> Commit-Queue: Alex Turner <alexmt@chromium.org> Cr-Commit-Position: refs/heads/main@{#1456337}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
23f70fe18f
commit
a1a7802b3b
BUILD.gn
content
infra
third_party
distributed_point_functions
BUILD.gnDEPSDIR_METADATALICENSEOWNERSREADME.chromium
code
.bazelrc.clang-format.gitattributes.gitignoreBUILDCODE_OF_CONDUCT.mdCONTRIBUTING.mdLICENSEREADME.mdSECURITY.mdWORKSPACE.bazel
features.gnidpf
BUILDaes_128_fixed_key_hash.ccaes_128_fixed_key_hash.haes_128_fixed_key_hash_test.ccdistributed_point_function.ccdistributed_point_function.hdistributed_point_function.protodistributed_point_function_benchmark.ccdistributed_point_function_test.ccint_mod_n.ccint_mod_n.hint_mod_n_benchmark.ccint_mod_n_test.cc
internal
BUILDaes_128_fixed_key_hash_hwy.haes_128_fixed_key_hash_hwy_test.ccevaluate_prg_hwy.ccevaluate_prg_hwy.hevaluate_prg_hwy_test.ccget_hwy_mode.ccget_hwy_mode.hmaybe_deref_span.hmaybe_deref_span_test.ccproto_validator.ccproto_validator.hproto_validator_test.ccproto_validator_test.textprotostatus_matchers.ccstatus_matchers.hvalue_type_helpers.ccvalue_type_helpers.hvalue_type_helpers_test.cc
status_macros.htuple.htuple_test.ccxor_wrapper.hxor_wrapper_test.ccfuzz
shim
highway
1
BUILD.gn
1
BUILD.gn
@ -114,7 +114,6 @@ group("gn_all") {
|
|||||||
"//third_party/angle/src/tests:angle_end2end_tests",
|
"//third_party/angle/src/tests:angle_end2end_tests",
|
||||||
"//third_party/angle/src/tests:angle_unittests",
|
"//third_party/angle/src/tests:angle_unittests",
|
||||||
"//third_party/angle/src/tests:angle_white_box_tests",
|
"//third_party/angle/src/tests:angle_white_box_tests",
|
||||||
"//third_party/distributed_point_functions/shim:distributed_point_functions_shim_unittests",
|
|
||||||
"//third_party/flatbuffers:flatbuffers_unittests",
|
"//third_party/flatbuffers:flatbuffers_unittests",
|
||||||
"//third_party/highway:highway_tests",
|
"//third_party/highway:highway_tests",
|
||||||
"//third_party/liburlpattern:liburlpattern_unittests",
|
"//third_party/liburlpattern:liburlpattern_unittests",
|
||||||
|
@ -289,7 +289,6 @@ source_set("browser") {
|
|||||||
"//third_party/blink/public/strings",
|
"//third_party/blink/public/strings",
|
||||||
"//third_party/boringssl",
|
"//third_party/boringssl",
|
||||||
"//third_party/brotli:dec",
|
"//third_party/brotli:dec",
|
||||||
"//third_party/distributed_point_functions",
|
|
||||||
"//third_party/icu",
|
"//third_party/icu",
|
||||||
"//third_party/inspector_protocol:crdtp",
|
"//third_party/inspector_protocol:crdtp",
|
||||||
"//third_party/libyuv",
|
"//third_party/libyuv",
|
||||||
|
@ -3252,7 +3252,6 @@ test("content_unittests") {
|
|||||||
"//third_party/blink/public:test_support",
|
"//third_party/blink/public:test_support",
|
||||||
"//third_party/blink/public/common:font_enumeration_table_proto",
|
"//third_party/blink/public/common:font_enumeration_table_proto",
|
||||||
"//third_party/blink/public/common:headers",
|
"//third_party/blink/public/common:headers",
|
||||||
"//third_party/distributed_point_functions/shim:buildflags",
|
|
||||||
"//third_party/icu",
|
"//third_party/icu",
|
||||||
"//third_party/inspector_protocol:crdtp",
|
"//third_party/inspector_protocol:crdtp",
|
||||||
"//third_party/inspector_protocol:crdtp_test",
|
"//third_party/inspector_protocol:crdtp_test",
|
||||||
|
@ -481,7 +481,6 @@ third_party/crashpad/crashpad/third_party/linux 1 1
|
|||||||
third_party/crashpad/crashpad/third_party/ninja 1 1
|
third_party/crashpad/crashpad/third_party/ninja 1 1
|
||||||
third_party/crashpad/crashpad/util/misc 1 1
|
third_party/crashpad/crashpad/util/misc 1 1
|
||||||
third_party/dav1d 2 2
|
third_party/dav1d 2 2
|
||||||
third_party/distributed_point_functions/code 2 1
|
|
||||||
third_party/expat 2 2
|
third_party/expat 2 2
|
||||||
third_party/fdlibm 1 1
|
third_party/fdlibm 1 1
|
||||||
third_party/fusejs/dist 3 1
|
third_party/fusejs/dist 3 1
|
||||||
|
81
third_party/distributed_point_functions/BUILD.gn
vendored
81
third_party/distributed_point_functions/BUILD.gn
vendored
@ -1,81 +0,0 @@
|
|||||||
# Copyright 2021 The Chromium Authors
|
|
||||||
# Use of this source code is governed by a BSD-style license that can be
|
|
||||||
# found in the LICENSE file.
|
|
||||||
|
|
||||||
import("//testing/libfuzzer/fuzzer_test.gni")
|
|
||||||
import("//third_party/distributed_point_functions/features.gni")
|
|
||||||
import("//third_party/protobuf/proto_library.gni")
|
|
||||||
|
|
||||||
# This is Chromium's interface with the third-party distributed_point_functions
|
|
||||||
# library. Targets outside of //third_party/distributed_point_functions should
|
|
||||||
# depend on this target rather than using the source directly. This extra layer
|
|
||||||
# prevents macros from leaking into Chromium code via header includes.
|
|
||||||
source_set("distributed_point_functions") {
|
|
||||||
public_deps = [ "//third_party/distributed_point_functions/shim" ]
|
|
||||||
}
|
|
||||||
|
|
||||||
proto_library("proto") {
|
|
||||||
sources = [ "code/dpf/distributed_point_function.proto" ]
|
|
||||||
proto_out_dir = "third_party/distributed_point_functions/dpf"
|
|
||||||
cc_generator_options = "lite"
|
|
||||||
}
|
|
||||||
|
|
||||||
fuzzer_test("dpf_fuzzer") {
|
|
||||||
sources = [ "fuzz/dpf_fuzzer.cc" ]
|
|
||||||
deps = [ ":internal" ]
|
|
||||||
|
|
||||||
# Do not apply Chromium code rules to this third-party code.
|
|
||||||
suppressed_configs = [ "//build/config/compiler:chromium_code" ]
|
|
||||||
additional_configs = [ "//build/config/compiler:no_chromium_code" ]
|
|
||||||
|
|
||||||
additional_configs += [ ":includes" ]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Targets below this line are only visible within this file and shim/.
|
|
||||||
visibility = [
|
|
||||||
":*",
|
|
||||||
"//third_party/distributed_point_functions/shim:*",
|
|
||||||
]
|
|
||||||
|
|
||||||
config("includes") {
|
|
||||||
include_dirs = [
|
|
||||||
"code",
|
|
||||||
"$target_gen_dir",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
source_set("internal") {
|
|
||||||
sources = [
|
|
||||||
"code/dpf/aes_128_fixed_key_hash.cc",
|
|
||||||
"code/dpf/aes_128_fixed_key_hash.h",
|
|
||||||
"code/dpf/distributed_point_function.cc",
|
|
||||||
"code/dpf/distributed_point_function.h",
|
|
||||||
"code/dpf/int_mod_n.cc",
|
|
||||||
"code/dpf/int_mod_n.h",
|
|
||||||
"code/dpf/internal/evaluate_prg_hwy.cc",
|
|
||||||
"code/dpf/internal/evaluate_prg_hwy.h",
|
|
||||||
"code/dpf/internal/get_hwy_mode.cc",
|
|
||||||
"code/dpf/internal/get_hwy_mode.h",
|
|
||||||
"code/dpf/internal/proto_validator.cc",
|
|
||||||
"code/dpf/internal/proto_validator.h",
|
|
||||||
"code/dpf/internal/value_type_helpers.cc",
|
|
||||||
"code/dpf/internal/value_type_helpers.h",
|
|
||||||
"code/dpf/status_macros.h",
|
|
||||||
"code/dpf/tuple.h",
|
|
||||||
"code/dpf/xor_wrapper.h",
|
|
||||||
]
|
|
||||||
|
|
||||||
public_deps = [
|
|
||||||
":proto",
|
|
||||||
"$dpf_abseil_cpp_dir:absl",
|
|
||||||
"$dpf_highway_cpp_dir:libhwy",
|
|
||||||
"//third_party/boringssl",
|
|
||||||
"//third_party/protobuf:protobuf_lite",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Do not apply Chromium code rules to this third-party code.
|
|
||||||
configs -= [ "//build/config/compiler:chromium_code" ]
|
|
||||||
configs += [ "//build/config/compiler:no_chromium_code" ]
|
|
||||||
|
|
||||||
configs += [ ":includes" ]
|
|
||||||
}
|
|
11
third_party/distributed_point_functions/DEPS
vendored
11
third_party/distributed_point_functions/DEPS
vendored
@ -1,11 +0,0 @@
|
|||||||
include_rules = [
|
|
||||||
"+absl",
|
|
||||||
"+benchmark",
|
|
||||||
"+dpf",
|
|
||||||
"+gmock",
|
|
||||||
"+google/protobuf",
|
|
||||||
"+gtest",
|
|
||||||
"+testing",
|
|
||||||
"+hwy",
|
|
||||||
"+openssl",
|
|
||||||
]
|
|
@ -1,6 +0,0 @@
|
|||||||
monorail: {
|
|
||||||
component: "Internals>AttributionReporting"
|
|
||||||
}
|
|
||||||
buganizer_public: {
|
|
||||||
component_id: 1456103
|
|
||||||
}
|
|
202
third_party/distributed_point_functions/LICENSE
vendored
202
third_party/distributed_point_functions/LICENSE
vendored
@ -1,202 +0,0 @@
|
|||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
@ -1,3 +0,0 @@
|
|||||||
alexmt@chromium.org
|
|
||||||
csharrison@chromium.org
|
|
||||||
linnan@chromium.org
|
|
@ -1,25 +0,0 @@
|
|||||||
Name: The Incremental Distributed Point Functions library
|
|
||||||
Short Name: distributed_point_functions
|
|
||||||
URL: https://github.com/google/distributed_point_functions
|
|
||||||
Version: N/A
|
|
||||||
Revision: 2db593b64a99f178f682ef0db222d417c23e5bb5
|
|
||||||
Date: 2023-11-16
|
|
||||||
License: Apache-2.0
|
|
||||||
License File: LICENSE
|
|
||||||
Security Critical: Yes
|
|
||||||
Shipped: yes
|
|
||||||
CPEPrefix: unknown
|
|
||||||
|
|
||||||
Description:
|
|
||||||
This library contains an implementation of incremental distributed point
|
|
||||||
functions, based on the paper by Boneh et al.
|
|
||||||
|
|
||||||
Local Modifications:
|
|
||||||
The directory code/ is a copy of the source code, modified in three ways. First,
|
|
||||||
all top-level directories other than dpf/ have been removed as they are unused.
|
|
||||||
Second, a .clang-format file has been added to disable automatic code
|
|
||||||
formatting. Parts of code/dpf/distributed_point_function_test.cc are also
|
|
||||||
adapted for fuzzing in fuzz/dpf_fuzzer.cc.
|
|
||||||
Third, a missing absl/strings/str_cat.h include backported from revision
|
|
||||||
c662ca975068bfa884cc4a96f3a1db40a7611e5e to fix a build error when compiled with
|
|
||||||
the latest version of abseil.
|
|
@ -1 +0,0 @@
|
|||||||
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
|
|
@ -1 +0,0 @@
|
|||||||
DisableFormat: true
|
|
@ -1 +0,0 @@
|
|||||||
experiments/data/* filter=lfs diff=lfs merge=lfs -text
|
|
@ -1,2 +0,0 @@
|
|||||||
# Bazel generated symlinks
|
|
||||||
bazel-*
|
|
@ -1,9 +0,0 @@
|
|||||||
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = [":allowlist"],
|
|
||||||
)
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
@ -1,93 +0,0 @@
|
|||||||
# Code of Conduct
|
|
||||||
|
|
||||||
## Our Pledge
|
|
||||||
|
|
||||||
In the interest of fostering an open and welcoming environment, we as
|
|
||||||
contributors and maintainers pledge to making participation in our project and
|
|
||||||
our community a harassment-free experience for everyone, regardless of age, body
|
|
||||||
size, disability, ethnicity, gender identity and expression, level of
|
|
||||||
experience, education, socio-economic status, nationality, personal appearance,
|
|
||||||
race, religion, or sexual identity and orientation.
|
|
||||||
|
|
||||||
## Our Standards
|
|
||||||
|
|
||||||
Examples of behavior that contributes to creating a positive environment
|
|
||||||
include:
|
|
||||||
|
|
||||||
* Using welcoming and inclusive language
|
|
||||||
* Being respectful of differing viewpoints and experiences
|
|
||||||
* Gracefully accepting constructive criticism
|
|
||||||
* Focusing on what is best for the community
|
|
||||||
* Showing empathy towards other community members
|
|
||||||
|
|
||||||
Examples of unacceptable behavior by participants include:
|
|
||||||
|
|
||||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
|
||||||
advances
|
|
||||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
|
||||||
* Public or private harassment
|
|
||||||
* Publishing others' private information, such as a physical or electronic
|
|
||||||
address, without explicit permission
|
|
||||||
* Other conduct which could reasonably be considered inappropriate in a
|
|
||||||
professional setting
|
|
||||||
|
|
||||||
## Our Responsibilities
|
|
||||||
|
|
||||||
Project maintainers are responsible for clarifying the standards of acceptable
|
|
||||||
behavior and are expected to take appropriate and fair corrective action in
|
|
||||||
response to any instances of unacceptable behavior.
|
|
||||||
|
|
||||||
Project maintainers have the right and responsibility to remove, edit, or reject
|
|
||||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
|
||||||
not aligned to this Code of Conduct, or to ban temporarily or permanently any
|
|
||||||
contributor for other behaviors that they deem inappropriate, threatening,
|
|
||||||
offensive, or harmful.
|
|
||||||
|
|
||||||
## Scope
|
|
||||||
|
|
||||||
This Code of Conduct applies both within project spaces and in public spaces
|
|
||||||
when an individual is representing the project or its community. Examples of
|
|
||||||
representing a project or community include using an official project e-mail
|
|
||||||
address, posting via an official social media account, or acting as an appointed
|
|
||||||
representative at an online or offline event. Representation of a project may be
|
|
||||||
further defined and clarified by project maintainers.
|
|
||||||
|
|
||||||
This Code of Conduct also applies outside the project spaces when the Project
|
|
||||||
Steward has a reasonable belief that an individual's behavior may have a
|
|
||||||
negative impact on the project or its community.
|
|
||||||
|
|
||||||
## Conflict Resolution
|
|
||||||
|
|
||||||
We do not believe that all conflict is bad; healthy debate and disagreement
|
|
||||||
often yield positive results. However, it is never okay to be disrespectful or
|
|
||||||
to engage in behavior that violates the project’s code of conduct.
|
|
||||||
|
|
||||||
If you see someone violating the code of conduct, you are encouraged to address
|
|
||||||
the behavior directly with those involved. Many issues can be resolved quickly
|
|
||||||
and easily, and this gives people more control over the outcome of their
|
|
||||||
dispute. If you are unable to resolve the matter for any reason, or if the
|
|
||||||
behavior is threatening or harassing, report it. We are dedicated to providing
|
|
||||||
an environment where participants feel welcome and safe.
|
|
||||||
|
|
||||||
Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the
|
|
||||||
Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to
|
|
||||||
receive and address reported violations of the code of conduct. They will then
|
|
||||||
work with a committee consisting of representatives from the Open Source
|
|
||||||
Programs Office and the Google Open Source Strategy team. If for any reason you
|
|
||||||
are uncomfortable reaching out to the Project Steward, please email
|
|
||||||
opensource@google.com.
|
|
||||||
|
|
||||||
We will investigate every complaint, but you may not receive a direct response.
|
|
||||||
We will use our discretion in determining when and how to follow up on reported
|
|
||||||
incidents, which may range from not taking action to permanent expulsion from
|
|
||||||
the project and project-sponsored spaces. We will notify the accused of the
|
|
||||||
report and provide them an opportunity to discuss it before any action is taken.
|
|
||||||
The identity of the reporter will be omitted from the details of the report
|
|
||||||
supplied to the accused. In potentially harmful situations, such as ongoing
|
|
||||||
harassment or threats to anyone's safety, we may take action without notice.
|
|
||||||
|
|
||||||
## Attribution
|
|
||||||
|
|
||||||
This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
|
|
||||||
available at
|
|
||||||
https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
|
@ -1,29 +0,0 @@
|
|||||||
# How to Contribute
|
|
||||||
|
|
||||||
We'd love to accept your patches and contributions to this project. There are
|
|
||||||
just a few small guidelines you need to follow.
|
|
||||||
|
|
||||||
## Contributor License Agreement
|
|
||||||
|
|
||||||
Contributions to this project must be accompanied by a Contributor License
|
|
||||||
Agreement (CLA). You (or your employer) retain the copyright to your
|
|
||||||
contribution; this simply gives us permission to use and redistribute your
|
|
||||||
contributions as part of the project. Head over to
|
|
||||||
<https://cla.developers.google.com/> to see your current agreements on file or
|
|
||||||
to sign a new one.
|
|
||||||
|
|
||||||
You generally only need to submit a CLA once, so if you've already submitted one
|
|
||||||
(even if it was for a different project), you probably don't need to do it
|
|
||||||
again.
|
|
||||||
|
|
||||||
## Code reviews
|
|
||||||
|
|
||||||
All submissions, including submissions by project members, require review. We
|
|
||||||
use GitHub pull requests for this purpose. Consult
|
|
||||||
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
|
|
||||||
information on using pull requests.
|
|
||||||
|
|
||||||
## Community Guidelines
|
|
||||||
|
|
||||||
This project follows
|
|
||||||
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
|
|
202
third_party/distributed_point_functions/code/LICENSE
vendored
202
third_party/distributed_point_functions/code/LICENSE
vendored
@ -1,202 +0,0 @@
|
|||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
@ -1,48 +0,0 @@
|
|||||||
# An Implementation of Incremental Distributed Point Functions in C++ [](https://buildkite.com/bazel/google-distributed-point-functions)
|
|
||||||
|
|
||||||
This library contains an implementation of incremental distributed point
|
|
||||||
functions, based on the following paper:
|
|
||||||
> Boneh, D., Boyle, E., Corrigan-Gibbs, H., Gilboa, N., & Ishai, Y. (2020).
|
|
||||||
Lightweight Techniques for Private Heavy Hitters. arXiv preprint
|
|
||||||
> arXiv:2012.14884. https://arxiv.org/abs/2012.14884
|
|
||||||
|
|
||||||
## About Incremental Distributed Point Functions
|
|
||||||
|
|
||||||
A distributed point function (DPF) is parameterized by an index `alpha` and a
|
|
||||||
value `beta`. It consists of two algorithms: key generation and evaluation.
|
|
||||||
The key generation procedure produces two keys `k_a` and `k_b`, given `alpha`
|
|
||||||
and `beta`. Evaluating each key on any point `x` in the DPF domain results in an
|
|
||||||
additive secret share of `beta`, if `x == alpha`, and a share of 0 otherwise.
|
|
||||||
|
|
||||||
Incremental DPFs additionally can be evaluated on prefixes of the index domain.
|
|
||||||
More precisely, an incremental DPF is parameterized by a hierarchy of index
|
|
||||||
domains, each a power of two larger than the previous. Key generation now takes
|
|
||||||
a vector `beta`, one value `beta[i]` for each hierarchy level.
|
|
||||||
When evaluated on a `b`-bit prefix of `alpha`, where b is the log domain size of
|
|
||||||
the `i`-th hierarchy level, the incremental DPF returns a secret share of
|
|
||||||
`beta[i]`, otherwise a share of 0.
|
|
||||||
|
|
||||||
For more details, see the above paper, as well as the
|
|
||||||
[`DistributedPointFunction` class documentation](dpf/distributed_point_function.h).
|
|
||||||
|
|
||||||
|
|
||||||
## Building/Running Tests
|
|
||||||
|
|
||||||
This repository requires Bazel. You can install Bazel by
|
|
||||||
following the instructions for your platform on the
|
|
||||||
[Bazel website](https://docs.bazel.build/versions/master/install.html).
|
|
||||||
|
|
||||||
Once you have installed Bazel you can clone this repository and run all tests
|
|
||||||
that are included by navigating into the root folder and running:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
bazel test //...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Security
|
|
||||||
To report a security issue, please read [SECURITY.md](SECURITY.md).
|
|
||||||
|
|
||||||
## Disclaimer
|
|
||||||
|
|
||||||
This is not an officially supported Google product. The code is provided as-is,
|
|
||||||
with no guarantees of correctness or security.
|
|
@ -1,5 +0,0 @@
|
|||||||
# Security
|
|
||||||
To report a security issue, please use http://g.co/vulnz. We use
|
|
||||||
http://g.co/vulnz for our intake, and do coordination and disclosure here on
|
|
||||||
GitHub (including using GitHub Security Advisory). The Google Security Team will
|
|
||||||
respond within 5 working days of your report on g.co/vulnz.
|
|
@ -1,170 +0,0 @@
|
|||||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
|
||||||
|
|
||||||
# rules_proto defines abstract rules for building Protocol Buffers.
|
|
||||||
# https://github.com/bazelbuild/rules_proto
|
|
||||||
http_archive(
|
|
||||||
name = "rules_proto",
|
|
||||||
sha256 = "0daa4fc5b2b820705fcbf239557515f9ab809be45a1e7c6dfaa1d465d5c615d4",
|
|
||||||
strip_prefix = "rules_proto-3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/bazelbuild/rules_proto/archive/3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
|
|
||||||
|
|
||||||
rules_proto_dependencies()
|
|
||||||
|
|
||||||
rules_proto_toolchains()
|
|
||||||
|
|
||||||
# rules_cc defines rules for generating C++ code from Protocol Buffers.
|
|
||||||
# https://github.com/bazelbuild/rules_cc
|
|
||||||
http_archive(
|
|
||||||
name = "rules_cc",
|
|
||||||
sha256 = "e17cca44563e0918a36a8ea2a50acb99ea9ad726bbd3cad8ba95a643a40121ab",
|
|
||||||
strip_prefix = "rules_cc-d7c11265cb157c9b962d87d9ab67b8c24e3a875f",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/bazelbuild/rules_cc/archive/d7c11265cb157c9b962d87d9ab67b8c24e3a875f.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
|
|
||||||
|
|
||||||
rules_cc_dependencies()
|
|
||||||
|
|
||||||
# io_bazel_rules_go defines rules for generating C++ code from Protocol Buffers.
|
|
||||||
# https://github.com/bazelbuild/rules_go
|
|
||||||
http_archive(
|
|
||||||
name = "io_bazel_rules_go",
|
|
||||||
sha256 = "7c35e8515012279ef7bcbc39c4ef4b54a86756d853848cb621b7da49f156c82f",
|
|
||||||
strip_prefix = "rules_go-b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/bazelbuild/rules_go/archive/b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
|
|
||||||
|
|
||||||
go_rules_dependencies()
|
|
||||||
|
|
||||||
go_register_toolchains(version = "1.19.3")
|
|
||||||
|
|
||||||
# Install gtest.
|
|
||||||
# https://github.com/google/googletest
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_googletest",
|
|
||||||
sha256 = "3e91944af2d909a79f18ee9760765624810146ccfae8f1a8f990037a1677d44b",
|
|
||||||
strip_prefix = "googletest-ac7a126f39d5bcd909b78c9e69900c76659b1bbb",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/googletest/archive/ac7a126f39d5bcd909b78c9e69900c76659b1bbb.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# abseil-cpp
|
|
||||||
# https://github.com/abseil/abseil-cpp
|
|
||||||
http_archive(
|
|
||||||
name = "com_google_absl",
|
|
||||||
sha256 = "431c0c47217c36106f90e2ca4fcdf45af618ea21adde880804661b1ecb240056",
|
|
||||||
strip_prefix = "abseil-cpp-1fb3830b1cf685999bb2bbd0294be0a53c9440a6",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/abseil/abseil-cpp/archive/1fb3830b1cf685999bb2bbd0294be0a53c9440a6.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# BoringSSL
|
|
||||||
# https://github.com/google/boringssl
|
|
||||||
http_archive(
|
|
||||||
name = "boringssl",
|
|
||||||
sha256 = "88e4330f4f65ebfdf24847e4807c25f3eacfd5bf1a93f6629d3941196ff9b0b3",
|
|
||||||
strip_prefix = "boringssl-6347808f2a480a3792148bf7732232229db9b909",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/boringssl/archive/6347808f2a480a3792148bf7732232229db9b909.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Benchmarks
|
|
||||||
# https://github.com/google/benchmark
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_benchmark",
|
|
||||||
sha256 = "5f98b44165f3250f1d749b728018318d654f763ea0f4d7ea156e10e6e0cc678a",
|
|
||||||
strip_prefix = "benchmark-5e78bedfb07c615edb2b646d1e354980268c1728",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/benchmark/archive/5e78bedfb07c615edb2b646d1e354980268c1728.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# gflags needed for glog.
|
|
||||||
# https://github.com/gflags/gflags
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_gflags_gflags",
|
|
||||||
sha256 = "017e0a91531bfc45be9eaf07e4d8fed33c488b90b58509dbd2e33a33b2648ae6",
|
|
||||||
strip_prefix = "gflags-a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/gflags/gflags/archive/a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# glog for logging
|
|
||||||
# https://github.com/google/glog
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_glog",
|
|
||||||
sha256 = "0f91ee6cc1edc3b1c53a286382e69a37e5d172ce208b7e5b305be8770d8c21b1",
|
|
||||||
strip_prefix = "glog-f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/glog/archive/f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# IREE for cc_embed_data.
|
|
||||||
# https://github.com/google/iree
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_iree",
|
|
||||||
sha256 = "aa369b29a5c45ae9d7aa8bf49ea1308221d1711277222f0755df6e0a575f6879",
|
|
||||||
strip_prefix = "iree-7e6012468cbaafaaf30302748a2943771b40e2c3",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/iree/archive/7e6012468cbaafaaf30302748a2943771b40e2c3.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# riegeli for file IO
|
|
||||||
# https://github.com/google/riegeli
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_riegeli",
|
|
||||||
sha256 = "3de21a222271a1e2c5d728e7f46b63ab4520da829c09ef9727a322e693c9ac18",
|
|
||||||
strip_prefix = "riegeli-43b7ef9f995469609b6ab07f6becc82186314bfb",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/riegeli/archive/43b7ef9f995469609b6ab07f6becc82186314bfb.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# rules_license needed for Highway
|
|
||||||
# https://github.com/bazelbuild/rules_license
|
|
||||||
http_archive(
|
|
||||||
name = "rules_license",
|
|
||||||
sha256 = "6157e1e68378532d0241ecd15d3c45f6e5cfd98fc10846045509fb2a7cc9e381",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/bazelbuild/rules_license/releases/download/0.0.4/rules_license-0.0.4.tar.gz",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Highway for SIMD operations.
|
|
||||||
# https://github.com/google/highway
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_google_highway",
|
|
||||||
sha256 = "cdba0eb21796598dd50fa0a4aa3651fa466c0d37c39d149ee383f725434e4314",
|
|
||||||
strip_prefix = "highway-45c98184ab7f81cf592c07633070b75fced14a52",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/google/highway/archive/45c98184ab7f81cf592c07633070b75fced14a52.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# cppitertools for logging
|
|
||||||
# https://github.com/ryanhaining/cppitertools
|
|
||||||
http_archive(
|
|
||||||
name = "com_github_ryanhaining_cppitertools",
|
|
||||||
sha256 = "1608ddbe3c12b0c6e653b992ff63b5dceab9af5347ad93be8714d05e5dc17afb",
|
|
||||||
strip_prefix = "cppitertools-add5acc932dea2c78acd80747bab71ec0b5bce27",
|
|
||||||
urls = [
|
|
||||||
"https://github.com/ryanhaining/cppitertools/archive/add5acc932dea2c78acd80747bab71ec0b5bce27.zip",
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,235 +0,0 @@
|
|||||||
# Copyright 2023 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
|
|
||||||
load("@rules_cc//cc:defs.bzl", "cc_library")
|
|
||||||
load("@rules_proto//proto:defs.bzl", "proto_library")
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|
||||||
licenses(["notice"])
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "int_mod_n",
|
|
||||||
srcs = ["int_mod_n.cc"],
|
|
||||||
hdrs = ["int_mod_n.h"],
|
|
||||||
deps = [
|
|
||||||
"@com_google_absl//absl/base:config",
|
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
|
||||||
"@com_google_absl//absl/log:absl_check",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "int_mod_n_test",
|
|
||||||
srcs = ["int_mod_n_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":int_mod_n",
|
|
||||||
"//dpf/internal:status_matchers",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/base:config",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "int_mod_n_benchmark",
|
|
||||||
srcs = ["int_mod_n_benchmark.cc"],
|
|
||||||
deps = [
|
|
||||||
":int_mod_n",
|
|
||||||
"@boringssl//:crypto",
|
|
||||||
"@com_github_google_benchmark//:benchmark",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "distributed_point_function",
|
|
||||||
srcs = ["distributed_point_function.cc"],
|
|
||||||
hdrs = ["distributed_point_function.h"],
|
|
||||||
deps = [
|
|
||||||
":aes_128_fixed_key_hash",
|
|
||||||
":distributed_point_function_cc_proto",
|
|
||||||
":status_macros",
|
|
||||||
"//dpf/internal:evaluate_prg_hwy",
|
|
||||||
"//dpf/internal:get_hwy_mode",
|
|
||||||
"//dpf/internal:maybe_deref_span",
|
|
||||||
"//dpf/internal:proto_validator",
|
|
||||||
"//dpf/internal:value_type_helpers",
|
|
||||||
"@boringssl//:crypto",
|
|
||||||
"@com_github_google_highway//:hwy",
|
|
||||||
"@com_google_absl//absl/container:btree",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
|
||||||
"@com_google_absl//absl/log:absl_check",
|
|
||||||
"@com_google_absl//absl/log:absl_log",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/meta:type_traits",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
"@com_google_protobuf//:protobuf",
|
|
||||||
"@com_google_protobuf//:protobuf_lite",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "distributed_point_function_test",
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["distributed_point_function_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":distributed_point_function",
|
|
||||||
":distributed_point_function_cc_proto",
|
|
||||||
":xor_wrapper",
|
|
||||||
"//dpf/internal:proto_validator",
|
|
||||||
"//dpf/internal:status_matchers",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/algorithm:container",
|
|
||||||
"@com_google_absl//absl/base:config",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/random",
|
|
||||||
"@com_google_absl//absl/random:distributions",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
"@com_google_absl//absl/utility",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
proto_library(
|
|
||||||
name = "distributed_point_function_proto",
|
|
||||||
srcs = ["distributed_point_function.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_proto_library(
|
|
||||||
name = "distributed_point_function_cc_proto",
|
|
||||||
deps = [":distributed_point_function_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
go_proto_library(
|
|
||||||
name = "distributed_point_function_go_proto",
|
|
||||||
importpath = "github.com/google/distributed_point_functions/dpf/distributed_point_function_go_proto",
|
|
||||||
protos = [":distributed_point_function_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "distributed_point_function_benchmark",
|
|
||||||
srcs = [
|
|
||||||
"distributed_point_function_benchmark.cc",
|
|
||||||
],
|
|
||||||
tags = ["benchmark"],
|
|
||||||
deps = [
|
|
||||||
":distributed_point_function",
|
|
||||||
"@com_github_google_benchmark//:benchmark",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_github_google_highway//:hwy",
|
|
||||||
"@com_google_absl//absl/container:btree",
|
|
||||||
"@com_google_absl//absl/log:absl_check",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/random",
|
|
||||||
"@com_google_absl//absl/random:distributions",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
"@com_google_protobuf//:protobuf",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "status_macros",
|
|
||||||
hdrs = ["status_macros.h"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "aes_128_fixed_key_hash",
|
|
||||||
srcs = ["aes_128_fixed_key_hash.cc"],
|
|
||||||
hdrs = ["aes_128_fixed_key_hash.h"],
|
|
||||||
deps = [
|
|
||||||
"@boringssl//:crypto",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "aes_128_fixed_key_hash_test",
|
|
||||||
srcs = ["aes_128_fixed_key_hash_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":aes_128_fixed_key_hash",
|
|
||||||
"//dpf/internal:status_matchers",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/types:span",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tuple",
|
|
||||||
hdrs = ["tuple.h"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "tuple_test",
|
|
||||||
srcs = [
|
|
||||||
"tuple_test.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":tuple",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "xor_wrapper",
|
|
||||||
hdrs = ["xor_wrapper.h"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "xor_wrapper_test",
|
|
||||||
srcs = [
|
|
||||||
"xor_wrapper_test.cc",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":xor_wrapper",
|
|
||||||
"@com_github_google_googletest//:gtest_main",
|
|
||||||
"@com_google_absl//absl/numeric:int128",
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,102 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <array>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "openssl/err.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
Aes128FixedKeyHash::Aes128FixedKeyHash(
|
|
||||||
bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx, absl::uint128 key)
|
|
||||||
: cipher_ctx_(std::move(cipher_ctx)), key_(key) {}
|
|
||||||
|
|
||||||
absl::StatusOr<Aes128FixedKeyHash> Aes128FixedKeyHash::Create(
|
|
||||||
absl::uint128 key) {
|
|
||||||
bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx(EVP_CIPHER_CTX_new());
|
|
||||||
if (!cipher_ctx) {
|
|
||||||
return absl::InternalError("Failed to allocate AES context");
|
|
||||||
}
|
|
||||||
// Set up the OpenSSL encryption context. We want to evaluate the PRG in
|
|
||||||
// parallel on many seeds (see class comment in pseudorandom_generator.h), so
|
|
||||||
// we're using ECB mode here to achieve that. This batched evaluation is not
|
|
||||||
// to be confused with encryption of an array, for which ECB would be
|
|
||||||
// insecure.
|
|
||||||
int openssl_status =
|
|
||||||
EVP_EncryptInit_ex(cipher_ctx.get(), EVP_aes_128_ecb(), nullptr,
|
|
||||||
reinterpret_cast<const uint8_t*>(&key), nullptr);
|
|
||||||
if (openssl_status != 1) {
|
|
||||||
return absl::InternalError("Failed to set up AES context");
|
|
||||||
}
|
|
||||||
return Aes128FixedKeyHash(std::move(cipher_ctx), key);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Aes128FixedKeyHash::Evaluate(absl::Span<const absl::uint128> in,
|
|
||||||
absl::Span<absl::uint128> out) const {
|
|
||||||
if (in.size() != out.size()) {
|
|
||||||
return absl::InvalidArgumentError("Input and output sizes don't match");
|
|
||||||
}
|
|
||||||
if (in.empty()) {
|
|
||||||
// Nothing to do.
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute orthomorphism sigma for each element in `in`, `kBatchSize` elements
|
|
||||||
// at a time.
|
|
||||||
auto in_size = static_cast<int64_t>(in.size());
|
|
||||||
std::array<absl::uint128, kBatchSize> sigma_in;
|
|
||||||
for (int64_t start_block = 0; start_block < in_size;
|
|
||||||
start_block += kBatchSize) {
|
|
||||||
int64_t batch_size = std::min<int64_t>(in_size - start_block, kBatchSize);
|
|
||||||
for (int i = 0; i < batch_size; ++i) {
|
|
||||||
sigma_in[i] =
|
|
||||||
absl::MakeUint128(absl::Uint128High64(in[start_block + i]) ^
|
|
||||||
absl::Uint128Low64(in[start_block + i]),
|
|
||||||
absl::Uint128High64(in[start_block + i]));
|
|
||||||
}
|
|
||||||
|
|
||||||
// We use EVP_Cipher here instead of EVP_EncryptUpdate, since it doesn't
|
|
||||||
// mutate the context in ECB mode, and so this call is thread-safe.
|
|
||||||
int openssl_status = EVP_Cipher(
|
|
||||||
cipher_ctx_.get(), reinterpret_cast<uint8_t*>(out.data() + start_block),
|
|
||||||
reinterpret_cast<const uint8_t*>(sigma_in.data()),
|
|
||||||
static_cast<int>(batch_size * sizeof(absl::uint128)));
|
|
||||||
if (openssl_status != 1) {
|
|
||||||
char buf[256];
|
|
||||||
ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
|
|
||||||
return absl::InternalError(
|
|
||||||
absl::StrCat("AES encryption failed: ", std::string(buf)));
|
|
||||||
}
|
|
||||||
for (int64_t i = 0; i < batch_size; ++i) {
|
|
||||||
out[start_block + i] ^= sigma_in[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,86 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "openssl/cipher.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
// Aes128FixedKeyHash is a circular correlation-robust hash function based on
|
|
||||||
// AES. For key `key`, input `in` and output `out`, the hash function is defined
|
|
||||||
// as
|
|
||||||
//
|
|
||||||
// out[i] = AES.Encrypt(key, sigma(in[i])) ^ sigma(in[i]),
|
|
||||||
//
|
|
||||||
// where sigma(x) = (x.high64 ^ x.low64, x.high64). This is the
|
|
||||||
// circular correlation-robust MMO construction from
|
|
||||||
// https://eprint.iacr.org/2019/074.pdf (pp. 18-19). Note that unlike
|
|
||||||
// cryptographic hash functions such as SHA-256, this hash function is *not*
|
|
||||||
// compressing and is not designed to provide any security guarantees beyond
|
|
||||||
// circular correlation-robustness. Use with appropriate caution.
|
|
||||||
class Aes128FixedKeyHash {
|
|
||||||
public:
|
|
||||||
// Creates a new Aes128FixedKeyHash with the given `key`.
|
|
||||||
//
|
|
||||||
// Returns INTERNAL in case of allocation failures or OpenSSL errors.
|
|
||||||
static absl::StatusOr<Aes128FixedKeyHash> Create(absl::uint128 key);
|
|
||||||
|
|
||||||
// Computes hash values of each block in `in`, writing the output to `out`.
|
|
||||||
// It is safe to call this method if `in` and `out` overlap.
|
|
||||||
//
|
|
||||||
// Returns INVALID_ARGUMENT if sizes of `in` and `out` don't match or their
|
|
||||||
// sizes in bytes exceed an `int`, or INTERNAL in case of OpenSSL errors.
|
|
||||||
absl::Status Evaluate(absl::Span<const absl::uint128> in,
|
|
||||||
absl::Span<absl::uint128> out) const;
|
|
||||||
|
|
||||||
// Aes128FixedKeyHash is not copyable.
|
|
||||||
Aes128FixedKeyHash(const Aes128FixedKeyHash&) = delete;
|
|
||||||
Aes128FixedKeyHash& operator=(const Aes128FixedKeyHash&) = delete;
|
|
||||||
|
|
||||||
// Aes128FixedKeyHash is movable (it just wraps a bssl::UniquePtr).
|
|
||||||
Aes128FixedKeyHash(Aes128FixedKeyHash&&) = default;
|
|
||||||
Aes128FixedKeyHash& operator=(Aes128FixedKeyHash&&) = default;
|
|
||||||
|
|
||||||
// Returns the key used to construct this hash function.
|
|
||||||
// DO NOT SEND THIS TO ANY OTHER PARTY!
|
|
||||||
const absl::uint128& key() const { return key_; }
|
|
||||||
|
|
||||||
// The maximum number of AES blocks encrypted at once. Chosen to pipeline AES
|
|
||||||
// as much as possible, while still allowing both source and destination to
|
|
||||||
// comfortably fit in the L1 CPU cache.
|
|
||||||
static constexpr int kBatchSize = 64;
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Called by `Create`.
|
|
||||||
Aes128FixedKeyHash(bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx,
|
|
||||||
absl::uint128 key);
|
|
||||||
|
|
||||||
// The OpenSSL encryption context used by `Evaluate`.
|
|
||||||
bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx_;
|
|
||||||
|
|
||||||
// The key used to construct this hash function.
|
|
||||||
absl::uint128 key_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_
|
|
@ -1,178 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
|
|
||||||
#include <thread> // NOLINT(build/c++11)
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using dpf_internal::StatusIs;
|
|
||||||
|
|
||||||
// Test blocks for keys, inputs, and outputs.
|
|
||||||
constexpr absl::uint128 kKey0 =
|
|
||||||
absl::MakeUint128(0x0000000000000000, 0x0000000000000000);
|
|
||||||
constexpr absl::uint128 kKey1 =
|
|
||||||
absl::MakeUint128(0x1111111111111111, 0x1111111111111111);
|
|
||||||
constexpr absl::uint128 kSeed0 =
|
|
||||||
absl::MakeUint128(0x0123012301230123, 0x0123012301230123);
|
|
||||||
constexpr absl::uint128 kSeed1 =
|
|
||||||
absl::MakeUint128(0x4567456745674567, 0x4567456745674567);
|
|
||||||
constexpr absl::uint128 kSeed2 =
|
|
||||||
absl::MakeUint128(0x89ab89ab89ab89ab, 0x89ab89ab89ab89ab);
|
|
||||||
constexpr absl::uint128 kSeed3 =
|
|
||||||
absl::MakeUint128(0xcdefcdefcdefcdef, 0xcdefcdefcdefcdef);
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, CreateSucceeds) {
|
|
||||||
DPF_EXPECT_OK(Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, SameKeysAndSeedsGenerateSameOutput) {
|
|
||||||
std::vector<absl::uint128> in;
|
|
||||||
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
in = {kSeed0};
|
|
||||||
// Initialize output arrays with different values, to make sure they are the
|
|
||||||
// same afterwards.
|
|
||||||
std::vector<absl::uint128> out_0(in.size(), kSeed2), out_1(in.size(), kSeed3);
|
|
||||||
|
|
||||||
DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0)));
|
|
||||||
DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1)));
|
|
||||||
EXPECT_THAT(out_0, testing::ElementsAreArray(out_1));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, DifferentKeysGenerateDifferentOutput) {
|
|
||||||
std::vector<absl::uint128> in{kSeed0};
|
|
||||||
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1,
|
|
||||||
Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
// Initialize output arrays with the same values, to make sure they are
|
|
||||||
// different afterwards.
|
|
||||||
std::vector<absl::uint128> out_0(in.size(), kSeed2), out_1(in.size(), kSeed2);
|
|
||||||
|
|
||||||
DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0)));
|
|
||||||
DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1)));
|
|
||||||
EXPECT_THAT(out_0, testing::Not(testing::ElementsAreArray(out_1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, DifferentSeedsGenerateDifferentOutput) {
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
std::vector<absl::uint128> in_0, in_1;
|
|
||||||
|
|
||||||
in_0 = {kSeed0};
|
|
||||||
in_1 = {kSeed1};
|
|
||||||
// Initialize output arrays with the same values, to make sure they are
|
|
||||||
// different afterwards.
|
|
||||||
std::vector<absl::uint128> out_0(in_0.size(), kSeed2),
|
|
||||||
out_1(in_1.size(), kSeed2);
|
|
||||||
|
|
||||||
DPF_EXPECT_OK(prg.Evaluate(in_0, absl::MakeSpan(out_0)));
|
|
||||||
DPF_EXPECT_OK(prg.Evaluate(in_1, absl::MakeSpan(out_1)));
|
|
||||||
EXPECT_THAT(out_0, testing::Not(testing::ElementsAreArray(out_1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, BatchedEvaluationEqualsBlockWiseEvaluation) {
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
std::vector<absl::uint128> in_0, in_1, in_2;
|
|
||||||
|
|
||||||
in_0 = {kSeed0};
|
|
||||||
in_1 = {kSeed1};
|
|
||||||
in_2 = {kSeed0, kSeed1};
|
|
||||||
std::vector<absl::uint128> out_0(in_0.size()), out_1(in_1.size()),
|
|
||||||
out_2(in_2.size());
|
|
||||||
|
|
||||||
DPF_EXPECT_OK(prg.Evaluate(in_0, absl::MakeSpan(out_0)));
|
|
||||||
DPF_EXPECT_OK(prg.Evaluate(in_1, absl::MakeSpan(out_1)));
|
|
||||||
DPF_EXPECT_OK(prg.Evaluate(in_2, absl::MakeSpan(out_2)));
|
|
||||||
EXPECT_THAT(out_2, testing::ElementsAre(out_0[0], out_1[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, TestSpecificOutputValues) {
|
|
||||||
std::vector<absl::uint128> in, out_0, out_1;
|
|
||||||
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1,
|
|
||||||
Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
in = {kSeed0, kSeed1};
|
|
||||||
out_0.resize(in.size());
|
|
||||||
out_1.resize(in.size());
|
|
||||||
|
|
||||||
DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0)));
|
|
||||||
DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1)));
|
|
||||||
EXPECT_THAT(out_0,
|
|
||||||
testing::ElementsAre(
|
|
||||||
absl::MakeUint128(0x73c2dc14812be4ef, 0xeac64d09c8adf8ed),
|
|
||||||
absl::MakeUint128(0xb8f33653a53a8436, 0xaedf39b62de91d95)));
|
|
||||||
EXPECT_THAT(out_1,
|
|
||||||
testing::ElementsAre(
|
|
||||||
absl::MakeUint128(0x934704aff58fa233, 0xd3c20d1b9cc18d8f),
|
|
||||||
absl::MakeUint128(0x530098817046d284, 0x43e61d3273a04f7c)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, EvaluateFailsWhenSizesDontMatch) {
|
|
||||||
std::vector<absl::uint128> in{kSeed0};
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
|
|
||||||
std::vector<absl::uint128> out(in.size() + 1);
|
|
||||||
|
|
||||||
EXPECT_THAT(prg.Evaluate(in, absl::MakeSpan(out)),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Input and output sizes don't match"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashTest, TestThreadSafety) {
|
|
||||||
std::vector<absl::uint128> in{kSeed0};
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg,
|
|
||||||
Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
constexpr int kNumThreads = 1024;
|
|
||||||
|
|
||||||
auto do_evaluation = [&prg, &in]() {
|
|
||||||
absl::uint128 out;
|
|
||||||
DPF_ASSERT_OK(prg.Evaluate(in, absl::MakeSpan(&out, 1)));
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<std::thread> threads;
|
|
||||||
threads.reserve(kNumThreads);
|
|
||||||
for (int i = 0; i < kNumThreads; ++i) {
|
|
||||||
threads.emplace_back(do_evaluation);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& thread : threads) {
|
|
||||||
thread.join();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,732 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/distributed_point_function.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <array>
|
|
||||||
#include <cstddef>
|
|
||||||
#include <limits>
|
|
||||||
#include <memory>
|
|
||||||
#include <numeric>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/container/btree_map.h"
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/container/inlined_vector.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/log/absl_log.h"
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/internal/evaluate_prg_hwy.h"
|
|
||||||
#include "dpf/internal/get_hwy_mode.h"
|
|
||||||
#include "dpf/internal/proto_validator.h"
|
|
||||||
#include "dpf/internal/value_type_helpers.h"
|
|
||||||
#include "dpf/status_macros.h"
|
|
||||||
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
#include "openssl/rand.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// PRG keys used to expand seeds using AES. The first two are used to compute
|
|
||||||
// correction words of seeds, while the last is used to compute correction
|
|
||||||
// words of the incremental DPF values. Values were computed by taking the
|
|
||||||
// first half of the SHA256 sum of the constant name, e.g., `echo
|
|
||||||
// "DistributedPointFunction::kPrgKeyLeft" | sha256sum`
|
|
||||||
constexpr absl::uint128 kPrgKeyLeft =
|
|
||||||
absl::MakeUint128(0x5be037ccf6a03de5ULL, 0x935f08d0a5b6a2fdULL);
|
|
||||||
constexpr absl::uint128 kPrgKeyRight =
|
|
||||||
absl::MakeUint128(0xef94b6aedebb026cULL, 0xe2ea1fe0f66f4d0bULL);
|
|
||||||
constexpr absl::uint128 kPrgKeyValue =
|
|
||||||
absl::MakeUint128(0x05a5d1588c5423e3ULL, 0x46a31101b21d1c98ULL);
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Takes ownership of the validator, PRGs, and all per-level data. The cached
// parameter and level-mapping members are read from `proto_validator_` after
// it has been moved into place.
DistributedPointFunction::DistributedPointFunction(
    std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
    std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
    Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
    absl::flat_hash_map<std::string, ValueCorrectionFunction>
        value_correction_functions)
    : proto_validator_(std::move(proto_validator)),
      parameters_(proto_validator_->parameters()),
      tree_levels_needed_(proto_validator_->tree_levels_needed()),
      tree_to_hierarchy_(proto_validator_->tree_to_hierarchy()),
      hierarchy_to_tree_(proto_validator_->hierarchy_to_tree()),
      blocks_needed_(std::move(blocks_needed)),
      prg_left_(std::move(prg_left)),
      prg_right_(std::move(prg_right)),
      prg_value_(std::move(prg_value)),
      // `value_correction_functions` is a by-value sink parameter, so move it
      // into the member instead of copying the whole map a second time.
      value_correction_functions_(std::move(value_correction_functions)) {}
|
|
||||||
|
|
||||||
// Computes the value correction word(s) at `hierarchy_level`.
//
// `seeds` must hold exactly two seeds, one per party. Each seed x_i is expanded
// to x_i+0, ..., x_i+k-1 (k = blocks_needed_[hierarchy_level]) and hashed with
// `prg_value_`. The resulting byte strings, the position of `alpha` within the
// output block, `beta`, and `invert` are then passed unchanged to the value
// correction function registered for this level's parameters.
//
// Returns any error from PRG evaluation or from looking up the value
// correction function, otherwise the function's result.
absl::StatusOr<std::vector<Value>>
DistributedPointFunction::ComputeValueCorrection(
    int hierarchy_level, absl::Span<const absl::uint128> seeds,
    absl::uint128 alpha, const Value& beta, bool invert) const {
  ABSL_DCHECK(seeds.size() == 2);

  // Compute value output component of the PRG on current seeds. To that end, we
  // compute x_0+0, ..., x_0+k-1, and x_1+0, ..., x_1+k-1, where x_i is the seed
  // for helper i, and k is the number of blocks needed at the current hierarchy
  // level. We use a single contiguous vector for both helpers, which allows us
  // to use a single call to prg_value_.Evaluate.
  const int blocks_needed = blocks_needed_[hierarchy_level];
  std::vector<absl::uint128> expanded_seeds(2 * blocks_needed);
  // Take subspans rather than `&expanded_seeds[i]`, which would index past the
  // end of the vector when `blocks_needed` is 0.
  const absl::Span<absl::uint128> all_seeds = absl::MakeSpan(expanded_seeds);
  absl::Span<absl::uint128> expanded_seed_a = all_seeds.subspan(0, blocks_needed);
  absl::Span<absl::uint128> expanded_seed_b =
      all_seeds.subspan(blocks_needed, blocks_needed);
  std::iota(expanded_seed_a.begin(), expanded_seed_a.end(), seeds[0]);
  std::iota(expanded_seed_b.begin(), expanded_seed_b.end(), seeds[1]);

  // Evaluate PRG in place (this is safe as `Evaluate` computes each batch's
  // inputs before writing that batch's outputs).
  DPF_RETURN_IF_ERROR(prg_value_.Evaluate(expanded_seeds, all_seeds));

  // Compute index in block for alpha at the current hierarchy level.
  const int index_in_block = DomainToBlockIndex(alpha, hierarchy_level);

  // Choose implementation depending on element_bitsize.
  DPF_ASSIGN_OR_RETURN(
      ValueCorrectionFunction func,
      GetValueCorrectionFunction(parameters_[hierarchy_level]));
  return func(
      absl::string_view(reinterpret_cast<const char*>(expanded_seed_a.data()),
                        expanded_seed_a.size() * sizeof(absl::uint128)),
      absl::string_view(reinterpret_cast<const char*>(expanded_seed_b.data()),
                        expanded_seed_b.size() * sizeof(absl::uint128)),
      index_in_block, beta, invert);
}
|
|
||||||
|
|
||||||
// Expands the PRG seeds at the next `tree_level`, updates `seeds` and
// `control_bits`, and writes the next correction word to `keys`.
//
// Parameters, as used in the body below:
//   tree_level:   the tree level being generated. If `tree_level - 1` maps to a
//                 hierarchy level, a value correction word for that hierarchy
//                 level is computed first and stored in the new correction
//                 word.
//   alpha:        the DPF point. Its bit at position
//                 `log_domain_size - tree_level` (of the last hierarchy level)
//                 selects the "keep" branch. That bit is 0 if the shift is 128
//                 or more.
//   beta:         per-hierarchy-level values; only `beta[hierarchy_level]` is
//                 read.
//   seeds:        the two parties' current seeds (indices 0 and 1). They are
//                 overwritten with the (corrected) seeds of the kept branch.
//   control_bits: the two parties' current control bits, updated in place.
//   keys:         the two keys. A new correction word is appended to keys[0]
//                 and a copy of it is appended to keys[1].
//
// Returns an error if PRG evaluation or the value correction computation fails.
//
// NOTE: statement order matters. Seeds are updated using the *previous*
// control bits, and only afterwards are the control bits themselves updated.
absl::Status DistributedPointFunction::GenerateNext(
    int tree_level, absl::uint128 alpha, absl::Span<const Value> beta,
    absl::Span<absl::uint128> seeds, absl::Span<bool> control_bits,
    absl::Span<DpfKey> keys) const {
  // As in `GenerateKeysIncremental`, we annotate code with the corresponding
  // lines from https://arxiv.org/pdf/2012.14884.pdf#figure.caption.12.
  //
  // Lines 13 & 14: Compute value correction word if there is a value on the
  // current level. This is done here already, since we use the "PRG evaluation
  // optimization" described in Appendix C.2 of the paper. Since we are using
  // fixed-key AES as PRG, which can have arbitrary stretch, this optimization
  // works even for large output groups.
  CorrectionWord* correction_word = keys[0].add_correction_words();
  if (tree_to_hierarchy_.contains(tree_level - 1)) {
    int hierarchy_level = tree_to_hierarchy_.at(tree_level - 1);
    // Prefix of `alpha` at this hierarchy level. It stays 0 if the shift would
    // be 128 bits or more, which would be undefined for absl::uint128.
    absl::uint128 alpha_prefix = 0;
    int shift_amount = parameters_.back().log_domain_size() -
                       parameters_[hierarchy_level].log_domain_size();
    if (shift_amount < 128) {
      alpha_prefix = alpha >> shift_amount;
    }
    DPF_ASSIGN_OR_RETURN(
        std::vector<Value> value_correction,
        ComputeValueCorrection(hierarchy_level, seeds, alpha_prefix,
                               beta[hierarchy_level], control_bits[1]));
    for (const Value& value : value_correction) {
      *(correction_word->add_value_correction()) = value;
    }
  }

  // Line 5: Expand seeds from previous level.
  // expanded_seeds[0] holds the left children (prg_left_), expanded_seeds[1]
  // the right children (prg_right_), each for both parties.
  std::array<std::array<absl::uint128, 2>, 2> expanded_seeds;
  DPF_RETURN_IF_ERROR(
      prg_left_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[0])));
  DPF_RETURN_IF_ERROR(
      prg_right_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[1])));
  // The lowest bit of each expanded seed becomes that child's control bit and
  // is cleared from the seed.
  std::array<std::array<bool, 2>, 2> expanded_control_bits;
  expanded_control_bits[0][0] =
      dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][0]);
  expanded_control_bits[0][1] =
      dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][1]);
  expanded_control_bits[1][0] =
      dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][0]);
  expanded_control_bits[1][1] =
      dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][1]);

  // Lines 6-8: Assign keep/lose branch depending on current bit of `alpha`.
  bool current_bit = 0;
  if (parameters_.back().log_domain_size() - tree_level < 128) {
    current_bit =
        (alpha & (absl::uint128{1}
                  << (parameters_.back().log_domain_size() - tree_level))) != 0;
  }
  bool keep = current_bit, lose = !current_bit;

  // Line 9: Compute seed correction word.
  absl::uint128 seed_correction =
      expanded_seeds[lose][0] ^ expanded_seeds[lose][1];

  // Line 10: Compute control bit correction words.
  std::array<bool, 2> control_bit_correction;
  control_bit_correction[0] = expanded_control_bits[0][0] ^
                              expanded_control_bits[0][1] ^ current_bit ^ 1;
  control_bit_correction[1] =
      expanded_control_bits[1][0] ^ expanded_control_bits[1][1] ^ current_bit;

  // We swap lines 11 and 12, since we first need to use the previous level's
  // control bits before updating them.

  // Line 12: Update seeds. Note that there is a typo in the paper: The
  // multiplication / AND needs to be done with the control bit of iteration
  // l-1, not l. Note that unlike the original algorithm, we are using the
  // corrected seed directly for the next iteration. This is secure as we're
  // using AES with a different key (kPrgKeyValue) to compute the value
  // correction word below.
  seeds[0] = expanded_seeds[keep][0];
  seeds[1] = expanded_seeds[keep][1];
  if (control_bits[0]) {
    seeds[0] ^= seed_correction;
  }
  if (control_bits[1]) {
    seeds[1] ^= seed_correction;
  }

  // Line 11: Update control bits. Again, same typo as in Line 12.
  control_bits[0] = expanded_control_bits[keep][0] ^
                    (control_bits[0] && control_bit_correction[keep]);
  control_bits[1] = expanded_control_bits[keep][1] ^
                    (control_bits[1] && control_bit_correction[keep]);

  // Line 15: Assemble correction word and add it to keys[0].
  correction_word->mutable_seed()->set_high(
      absl::Uint128High64(seed_correction));
  correction_word->mutable_seed()->set_low(absl::Uint128Low64(seed_correction));
  correction_word->set_control_left(control_bit_correction[0]);
  correction_word->set_control_right(control_bit_correction[1]);

  // Copy correction word to second key.
  *(keys[1].add_correction_words()) = *correction_word;

  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Maps `domain_index` at `hierarchy_level` to its tree index by discarding
// the low-order bits that select an element inside an output block. The
// number of discarded bits is this level's log_domain_size minus its tree
// level (hierarchy_to_tree_).
absl::uint128 DistributedPointFunction::DomainToTreeIndex(
    absl::uint128 domain_index, int hierarchy_level) const {
  const int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  const int tree_level = hierarchy_to_tree_[hierarchy_level];
  const int bits_within_block = log_domain_size - tree_level;
  ABSL_DCHECK_LT(bits_within_block, 128);
  const absl::uint128 tree_index = domain_index >> bits_within_block;
  return tree_index;
}
|
|
||||||
|
|
||||||
// Returns the position of `domain_index` inside its output block at
// `hierarchy_level`, i.e. its low-order
// (log_domain_size - hierarchy_to_tree_[hierarchy_level]) bits.
int DistributedPointFunction::DomainToBlockIndex(absl::uint128 domain_index,
                                                 int hierarchy_level) const {
  const int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  const int tree_level = hierarchy_to_tree_[hierarchy_level];
  const int bits_within_block = log_domain_size - tree_level;
  ABSL_DCHECK_LT(bits_within_block, 128);
  const absl::uint128 low_bits_mask =
      (absl::uint128{1} << bits_within_block) - 1;
  const absl::uint128 index_in_block = domain_index & low_bits_mask;
  return static_cast<int>(index_in_block);
}
|
|
||||||
|
|
||||||
// Evaluates each of the `seeds` (with its control bit) along its path in
// `paths` through `correction_words.size()` tree levels. The resulting seeds
// and control bits are written to `seeds_out` and `control_bits_out`.
//
// Returns INVALID_ARGUMENT if `seeds`, `control_bits`, `paths`, `seeds_out`,
// and `control_bits_out` do not all have the same size. Returns
// RESOURCE_EXHAUSTED if the aligned correction-seed buffer cannot be
// allocated. Otherwise returns any error from dpf_internal::EvaluateSeeds.
// Empty input, or no correction words, is a no-op.
absl::Status DistributedPointFunction::EvaluateSeeds(
    absl::Span<const absl::uint128> seeds, absl::Span<const bool> control_bits,
    absl::Span<const absl::uint128> paths,
    absl::Span<const CorrectionWord* const> correction_words,
    absl::Span<absl::uint128> seeds_out,
    absl::Span<bool> control_bits_out) const {
  if (seeds.size() != control_bits.size() || seeds.size() != paths.size() ||
      seeds.size() != seeds_out.size() ||
      seeds.size() != control_bits_out.size()) {
    return absl::InvalidArgumentError(
        "`seeds`, `control_bits`, `paths`, `seeds_out`, and `control_bits_out` "
        "must all have equal sizes");
  }
  auto num_seeds = static_cast<int64_t>(seeds.size());
  auto num_levels = static_cast<int>(correction_words.size());
  if (num_seeds == 0 || num_levels == 0) {
    return absl::OkStatus();  // Nothing to do.
  }

  // Parse correction words for each level.
  auto correction_seeds = hwy::AllocateAligned<absl::uint128>(num_levels);
  if (correction_seeds == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }
  BitVector correction_controls_left(num_levels),
      correction_controls_right(num_levels);
  for (int level = 0; level < num_levels; ++level) {
    const CorrectionWord& correction = *(correction_words[level]);
    correction_seeds[level] =
        absl::MakeUint128(correction.seed().high(), correction.seed().low());
    correction_controls_left[level] = correction.control_left();
    correction_controls_right[level] = correction.control_right();
  }

  // No further size checks are needed here. All input/output span sizes were
  // validated against `num_seeds` above, and both control-bit vectors were
  // constructed with `num_levels` elements. (The previous ABSL_DCHECKs
  // repeated these facts and compared size_t against signed ints.)
  DPF_RETURN_IF_ERROR(dpf_internal::EvaluateSeeds(
      num_seeds, num_levels, num_levels, seeds.data(), control_bits.data(),
      paths.data(), 0, correction_seeds.get(), correction_controls_left.data(),
      correction_controls_right.data(), prg_left_, prg_right_, seeds_out.data(),
      control_bits_out.data()));
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Expands every seed in `partial_evaluations` through
// `correction_words.size()` further tree levels. Each level doubles the number
// of seeds, so the result holds `control_bits.size() << correction_words.size()`
// seeds and control bits.
//
// Returns ResourceExhaustedError if a seed buffer cannot be allocated.
// Errors from the left/right PRG evaluations are propagated unchanged.
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ExpandSeeds(
    const DpfExpansion& partial_evaluations,
    absl::Span<const CorrectionWord* const> correction_words) const {
  const int levels_to_expand = static_cast<int>(correction_words.size());

  // The caller is expected to have verified that the final size fits into a
  // size_t, so a debug-only check is sufficient here.
  ABSL_DCHECK_LT(levels_to_expand, 63);
  auto width = static_cast<int64_t>(partial_evaluations.control_bits.size());
  const absl::uint128 final_width_128 = absl::uint128{width}
                                        << levels_to_expand;
  ABSL_DCHECK_LE(final_width_128, std::numeric_limits<size_t>::max() / 2);
  const auto final_width = static_cast<size_t>(final_width_128);

  // Scratch buffers holding one batch of PRG outputs per branch.
  const int64_t batch_capacity = Aes128FixedKeyHash::kBatchSize;
  std::vector<absl::uint128> left_out(batch_capacity);
  std::vector<absl::uint128> right_out(batch_capacity);

  // `current` holds the level being expanded and `next` receives its children.
  // Both are sized for the final level up front so that nothing is reallocated
  // later; they trade places after every level.
  DpfExpansion current;
  current.seeds = hwy::AllocateAligned<absl::uint128>(final_width);
  if (current.seeds == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }
  std::copy_n(partial_evaluations.seeds.get(), width, current.seeds.get());
  current.control_bits = partial_evaluations.control_bits;
  current.control_bits.reserve(final_width);

  DpfExpansion next;
  next.seeds = hwy::AllocateAligned<absl::uint128>(final_width);
  if (next.seeds == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }
  next.control_bits.reserve(final_width);

  // Expand one full level at a time so that AES calls can be batched.
  for (const CorrectionWord* word : correction_words) {
    next.control_bits.resize(0);
    const absl::uint128 seed_correction =
        absl::MakeUint128(word->seed().high(), word->seed().low());
    const bool left_bit_correction = word->control_left();
    const bool right_bit_correction = word->control_right();

    int64_t offset = 0;
    while (offset < width) {
      const int64_t count = std::min<int64_t>(width - offset, batch_capacity);
      const auto parents =
          absl::MakeConstSpan(current.seeds.get() + offset, count);
      DPF_RETURN_IF_ERROR(prg_left_.Evaluate(
          parents, absl::MakeSpan(left_out).subspan(0, count)));
      DPF_RETURN_IF_ERROR(prg_right_.Evaluate(
          parents, absl::MakeSpan(right_out).subspan(0, count)));

      // Write both children of each parent, applying the correction word
      // wherever the parent's control bit is set.
      for (int64_t k = 0; k < count; ++k) {
        const bool parent_bit = current.control_bits[offset + k];
        if (parent_bit) {
          left_out[k] ^= seed_correction;
          right_out[k] ^= seed_correction;
        }
        const int64_t child = 2 * (offset + k);
        next.seeds[child] = left_out[k];
        next.seeds[child + 1] = right_out[k];
        // The lowest bit of each child seed becomes its control bit.
        next.control_bits.push_back(
            dpf_internal::ExtractAndClearLowestBit(next.seeds[child]));
        next.control_bits.push_back(
            dpf_internal::ExtractAndClearLowestBit(next.seeds[child + 1]));
        if (parent_bit) {
          next.control_bits[child] ^= left_bit_correction;
          next.control_bits[child + 1] ^= right_bit_correction;
        }
      }
      offset += count;
    }

    std::swap(current, next);
    width *= 2;
  }
  return current;
}
|
|
||||||
|
|
||||||
// Computes the seeds and control bits at tree level
// `hierarchy_to_tree_[hierarchy_level]` for every element of `prefixes`.
//
// If `ctx` holds partial evaluations at a tree level no deeper than the
// target, evaluation resumes from them. Each prefix is shifted right by the
// level difference to find its stored ancestor. Otherwise evaluation starts
// from the key seed, with the key's party as the control bit, at tree level 0.
//
// The previous `ctx.partial_evaluations` are always cleared. They are
// repopulated with the new results only when `update_ctx` is true.
// `ctx.partial_evaluations_level` is set to `hierarchy_level` in either case.
//
// Returns InvalidArgumentError if `ctx` contains a duplicated prefix with
// mismatching data, or has no entry for the ancestor of one of `prefixes`.
// Returns ResourceExhaustedError if the seed buffer cannot be allocated.
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ComputePartialEvaluations(
    absl::Span<const absl::uint128> prefixes, int hierarchy_level,
    bool update_ctx, EvaluationContext& ctx) const {
  int64_t num_prefixes = static_cast<int64_t>(prefixes.size());

  DpfExpansion partial_evaluations;
  int start_level = hierarchy_to_tree_[ctx.partial_evaluations_level()];
  int stop_level = hierarchy_to_tree_[hierarchy_level];
  if (ctx.partial_evaluations_size() > 0 && start_level <= stop_level) {
    // We have partial evaluations from a tree level before the current one.
    // Parse `ctx.partial_evaluations` into a btree_map for quick lookups by
    // prefix. `ctx.partial_evaluations()` will usually be sorted, which is why
    // the insertions below pass `end()` as a hint.
    absl::btree_map<absl::uint128, std::pair<absl::uint128, bool>>
        previous_partial_evaluations;
    for (const PartialEvaluation& element : ctx.partial_evaluations()) {
      absl::uint128 prefix =
          absl::MakeUint128(element.prefix().high(), element.prefix().low());
      // Try inserting `(seed, control_bit)` at `prefix`. Return an error if
      // `prefix` is already present with a different seed or control bit.
      auto value = std::make_pair(
          absl::MakeUint128(element.seed().high(), element.seed().low()),
          element.control_bit());
      auto it = previous_partial_evaluations.try_emplace(
          previous_partial_evaluations.end(), prefix, value);
      if (it->second != value) {
        return absl::InvalidArgumentError(
            "Duplicate prefix in `ctx.partial_evaluations()` with mismatching "
            "seed or control bit");
      }
    }

    // Now select all partial evaluations from the map that correspond to
    // `prefixes`.
    partial_evaluations.seeds =
        hwy::AllocateAligned<absl::uint128>(num_prefixes);
    if (partial_evaluations.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    partial_evaluations.control_bits.reserve(num_prefixes);
    for (int64_t i = 0; i < num_prefixes; ++i) {
      // Shifting a uint128 by 128 or more bits is undefined, so that case
      // maps every prefix to the root prefix 0.
      absl::uint128 previous_prefix = 0;
      if (stop_level - start_level < 128) {
        previous_prefix = prefixes[i] >> (stop_level - start_level);
      }
      auto it = previous_partial_evaluations.find(previous_prefix);
      if (it == previous_partial_evaluations.end()) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Prefix not present in ctx.partial_evaluations at hierarchy level ",
            hierarchy_level));
      }
      const std::pair<absl::uint128, bool>& partial_evaluation = it->second;
      partial_evaluations.seeds[partial_evaluations.control_bits.size()] =
          partial_evaluation.first;
      partial_evaluations.control_bits.push_back(partial_evaluation.second);
    }
  } else {
    // No usable partial evaluations in `ctx` -> Start from the beginning.
    partial_evaluations.seeds =
        hwy::AllocateAligned<absl::uint128>(num_prefixes);
    if (partial_evaluations.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(), num_prefixes);
    std::fill(
        seeds.begin(), seeds.end(),
        absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low()));
    partial_evaluations.control_bits.resize(
        num_prefixes, static_cast<bool>(ctx.key().party()));
    start_level = 0;
  }

  // Evaluate the DPF from `start_level` up to `stop_level`, in place.
  auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(),
                              partial_evaluations.control_bits.size());
  DPF_RETURN_IF_ERROR(
      EvaluateSeeds(seeds, partial_evaluations.control_bits, prefixes,
                    absl::MakeConstSpan(ctx.key().correction_words())
                        .subspan(start_level, stop_level - start_level),
                    seeds, absl::MakeSpan(partial_evaluations.control_bits)));

  // Replace the partial evaluations stored in `ctx`. New entries are only
  // written if there are more evaluations to come (`update_ctx`), so capacity
  // is reserved only in that case instead of being allocated and left unused.
  ctx.clear_partial_evaluations();
  if (update_ctx) {
    ctx.mutable_partial_evaluations()->Reserve(static_cast<int>(num_prefixes));
    for (int64_t i = 0; i < num_prefixes; ++i) {
      PartialEvaluation* current_element = ctx.add_partial_evaluations();
      current_element->mutable_prefix()->set_high(
          absl::Uint128High64(prefixes[i]));
      current_element->mutable_prefix()->set_low(
          absl::Uint128Low64(prefixes[i]));
      current_element->mutable_seed()->set_high(
          absl::Uint128High64(partial_evaluations.seeds[i]));
      current_element->mutable_seed()->set_low(
          absl::Uint128Low64(partial_evaluations.seeds[i]));
      current_element->set_control_bit(partial_evaluations.control_bits[i]);
    }
  }
  ctx.set_partial_evaluations_level(hierarchy_level);
  return partial_evaluations;
}
|
|
||||||
|
|
||||||
// Expands `ctx` from its previously evaluated hierarchy level up to
// `hierarchy_level`, then records `hierarchy_level` as the new previous level.
//
// With empty `prefixes` (first evaluation), expansion starts at the tree root
// from the key's seed, with the key's party as the control bit. Otherwise the
// seeds for `prefixes` are first recovered by ComputePartialEvaluations at
// `ctx.previous_hierarchy_level()`. That call also refreshes
// `ctx.partial_evaluations` unless `hierarchy_level` is the last level.
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ExpandAndUpdateContext(
    int hierarchy_level, absl::Span<const absl::uint128> prefixes,
    EvaluationContext& ctx) const {
  DpfExpansion starting_point;
  int first_tree_level = 0;
  if (!prefixes.empty()) {
    // Later evaluation: extract the seeds for `prefixes` from `ctx`. They only
    // need to be stored back if another hierarchy level follows this one.
    const bool is_last_level =
        hierarchy_level >= static_cast<int>(parameters_.size()) - 1;
    ABSL_DCHECK(ctx.previous_hierarchy_level() >= 0);
    DPF_ASSIGN_OR_RETURN(
        starting_point,
        ComputePartialEvaluations(prefixes, ctx.previous_hierarchy_level(),
                                  /*update_ctx=*/!is_last_level, ctx));
    first_tree_level = hierarchy_to_tree_[ctx.previous_hierarchy_level()];
  } else {
    // First evaluation: a single seed at the root of the tree.
    starting_point.seeds = hwy::AllocateAligned<absl::uint128>(1);
    if (starting_point.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    starting_point.seeds[0] =
        absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low());
    starting_point.control_bits = {static_cast<bool>(ctx.key().party())};
  }

  // Expand through the correction words between the two tree levels.
  const int last_tree_level = hierarchy_to_tree_[hierarchy_level];
  const auto level_words =
      absl::MakeConstSpan(ctx.key().correction_words())
          .subspan(first_tree_level, last_tree_level - first_tree_level);
  DPF_ASSIGN_OR_RETURN(DpfExpansion expansion,
                       ExpandSeeds(starting_point, level_words));

  ctx.set_previous_hierarchy_level(hierarchy_level);
  return expansion;
}
|
|
||||||
|
|
||||||
// Hashes each seed of `expansion` into `blocks_needed_[hierarchy_level]`
// blocks. Output block `i * blocks_needed + j` is the `prg_value_` image of
// `expansion[i] + j`.
//
// Returns ResourceExhaustedError if the output buffer cannot be allocated.
// PRG errors are propagated unchanged.
absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>>
DistributedPointFunction::HashExpandedSeeds(
    int hierarchy_level, absl::Span<const absl::uint128> expansion) const {
  const int blocks_per_seed = blocks_needed_[hierarchy_level];
  const int64_t total_blocks =
      static_cast<int64_t>(expansion.size()) * blocks_per_seed;

  auto hashed = hwy::AllocateAligned<absl::uint128>(total_blocks);
  if (hashed == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }
  absl::Span<absl::uint128> out(hashed.get(), total_blocks);

  // Lay out the PRG inputs: consecutive offsets 0..blocks_per_seed-1 of each
  // seed, seed after seed.
  int64_t pos = 0;
  for (const absl::uint128& seed : expansion) {
    for (int j = 0; j < blocks_per_seed; ++j) {
      out[pos++] = seed + j;
    }
  }

  // Hash in place. This is safe because `Evaluate` creates a copy of its
  // input.
  DPF_RETURN_IF_ERROR(prg_value_.Evaluate(out, out));
  return hashed;
}
|
|
||||||
|
|
||||||
// Serializes `value_type` with deterministic protobuf serialization enabled.
// The resulting string is stable enough to be used as a lookup key.
// Returns InternalError if serialization fails.
absl::StatusOr<std::string>
DistributedPointFunction::SerializeValueTypeDeterministically(
    const ValueType& value_type) {
  // Deterministic serialization has to go through a CodedOutputStream
  // explicitly.
  std::string out;
  bool serialized = false;
  {
    // Scoped so the streams are destroyed, and `out` finalized, before it is
    // read below.
    ::google::protobuf::io::StringOutputStream sink(&out);
    ::google::protobuf::io::CodedOutputStream coder(&sink);
    coder.SetSerializationDeterministic(true);
    serialized = value_type.SerializeToCodedStream(&coder);
  }
  if (!serialized) {
    return absl::InternalError("Serializing value_type to string failed");
  }
  return out;
}
|
|
||||||
|
|
||||||
// Looks up the value correction function registered for the value type in
// `parameters`. Registered functions are keyed by the deterministic
// serialization of their ValueType. Returns FailedPreconditionError if no
// function has been registered for that type.
absl::StatusOr<DistributedPointFunction::ValueCorrectionFunction>
DistributedPointFunction::GetValueCorrectionFunction(
    const DpfParameters& parameters) const {
  DPF_ASSIGN_OR_RETURN(
      std::string type_key,
      SerializeValueTypeDeterministically(parameters.value_type()));
  if (auto found = value_correction_functions_.find(type_key);
      found != value_correction_functions_.end()) {
    return found->second;
  }
  return absl::FailedPreconditionError(absl::StrCat(
      "No value correction function known for the following parameters:\n",
      parameters.DebugString(),
      "Did you call RegisterValueType<T>() with your value type?"));
}
|
|
||||||
|
|
||||||
// Creates a DPF with a single hierarchy level described by `parameters`.
// Equivalent to calling CreateIncremental() with a one-element span.
absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
DistributedPointFunction::Create(const DpfParameters& parameters) {
  // View the single `parameters` object as a span of length one.
  return CreateIncremental(absl::MakeConstSpan(&parameters, 1));
}
|
|
||||||
|
|
||||||
// Creates an incremental DPF with one hierarchy level per element of
// `parameters`.
//
// Steps:
//   1. Validate `parameters` with dpf_internal::ProtoValidator.
//   2. Precompute the number of 128-bit value-correction blocks per level.
//   3. Instantiate the three fixed-key AES hashes used as PRGs.
//   4. Pre-register correction functions for the single unsigned integer
//      value types.
// Any error from these steps is returned unchanged.
absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
DistributedPointFunction::CreateIncremental(
    absl::Span<const DpfParameters> parameters) {
  // Log the active Highway mode for debugging, only the first time this line
  // is reached.
  ABSL_LOG_FIRST_N(INFO, 1)
      << "Highway is in " << dpf_internal::GetHwyModeAsString() << " mode";

  // The validator is kept by the new DPF for later key/value checks.
  DPF_ASSIGN_OR_RETURN(std::unique_ptr<dpf_internal::ProtoValidator> validator,
                       dpf_internal::ProtoValidator::Create(parameters));

  // Round each level's required bit count up to whole 128-bit blocks.
  std::vector<int> blocks_per_level;
  blocks_per_level.reserve(parameters.size());
  for (const DpfParameters& level : parameters) {
    DPF_ASSIGN_OR_RETURN(
        int bits, dpf_internal::BitsNeeded(level.value_type(),
                                           level.security_parameter()));
    blocks_per_level.push_back((bits + 127) / 128);
  }

  // Fixed-key hashes for the left and right PRG branches and for values.
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash left_hash,
                       Aes128FixedKeyHash::Create(kPrgKeyLeft));
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash right_hash,
                       Aes128FixedKeyHash::Create(kPrgKeyRight));
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash value_hash,
                       Aes128FixedKeyHash::Create(kPrgKeyValue));

  // All single unsigned integer types are registered up front for backwards
  // compatibility.
  absl::flat_hash_map<std::string, ValueCorrectionFunction> corrections;
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint8_t>(corrections));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint16_t>(corrections));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint32_t>(corrections));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint64_t>(corrections));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<absl::uint128>(corrections));

  return absl::WrapUnique(new DistributedPointFunction(
      std::move(validator), std::move(blocks_per_level), std::move(left_hash),
      std::move(right_hash), std::move(value_hash), std::move(corrections)));
}
|
|
||||||
|
|
||||||
// Generates a pair of incremental DPF keys for the point `alpha`, where
// `beta[i]` is the value at hierarchy level i.
//
// Returns InvalidArgumentError in these cases:
//   - `beta` does not have one element per hierarchy level.
//   - An element of `beta` fails validation.
//   - `alpha` lies outside the last level's domain.
// Returns InternalError if the random seeds cannot be sampled. Errors from
// computing correction words are propagated unchanged.
absl::StatusOr<std::pair<DpfKey, DpfKey>>
DistributedPointFunction::GenerateKeysIncremental(
    absl::uint128 alpha, absl::Span<const Value> beta) {
  // Check validity of beta.
  if (beta.size() != parameters_.size()) {
    return absl::InvalidArgumentError(
        "`beta` has to have the same size as `parameters` passed at "
        "construction");
  }
  for (int i = 0; i < static_cast<int>(parameters_.size()); ++i) {
    absl::Status status = proto_validator_->ValidateValue(beta[i], i);
    if (!status.ok()) {
      return status;
    }
  }

  // Check validity of alpha. A 128-bit shift would be undefined, and every
  // alpha fits a domain of 2^128 elements anyway.
  int last_level_log_domain_size = parameters_.back().log_domain_size();
  if (last_level_log_domain_size < 128 &&
      alpha >= (absl::uint128{1} << last_level_log_domain_size)) {
    return absl::InvalidArgumentError(
        "`alpha` must be smaller than the output domain size");
  }

  std::array<DpfKey, 2> keys;
  keys[0].set_party(0);
  keys[1].set_party(1);

  // We will annotate the following code with the corresponding lines from the
  // pseudocode in the Incremental DPF paper
  // (https://arxiv.org/pdf/2012.14884.pdf, Figure 11).
  //
  // There are two possible dimensions for each variable at each level: Parties
  // (0 or 1) and branches (left or right). For two-dimensional arrays, we use
  // the outer dimension for the branch, and the inner dimension for the party.
  //
  // Line 2: Sample random seeds for each party. RAND_bytes reports failure
  // through its return value (1 means success). Continuing after a failure
  // would produce keys from unrandomized seeds, so abort instead.
  std::array<absl::uint128, 2> seeds{};
  for (absl::uint128& seed : seeds) {
    if (RAND_bytes(reinterpret_cast<uint8_t*>(&seed),
                   sizeof(absl::uint128)) != 1) {
      return absl::InternalError("Failed to sample random DPF seeds");
    }
  }
  keys[0].mutable_seed()->set_high(absl::Uint128High64(seeds[0]));
  keys[0].mutable_seed()->set_low(absl::Uint128Low64(seeds[0]));
  keys[1].mutable_seed()->set_high(absl::Uint128High64(seeds[1]));
  keys[1].mutable_seed()->set_low(absl::Uint128Low64(seeds[1]));

  // Line 3: Initialize control bits.
  std::array<bool, 2> control_bits{false, true};

  // Line 4: Compute correction words for each level after the first one.
  keys[0].mutable_correction_words()->Reserve(tree_levels_needed_ - 1);
  keys[1].mutable_correction_words()->Reserve(tree_levels_needed_ - 1);
  for (int i = 1; i < tree_levels_needed_; i++) {
    DPF_RETURN_IF_ERROR(GenerateNext(i, alpha, beta, absl::MakeSpan(seeds),
                                     absl::MakeSpan(control_bits),
                                     absl::MakeSpan(keys)));
  }

  // Compute output correction word for last layer.
  DPF_ASSIGN_OR_RETURN(
      std::vector<Value> last_level_value_correction,
      ComputeValueCorrection(parameters_.size() - 1, seeds, alpha, beta.back(),
                             control_bits[1]));
  for (const Value& value : last_level_value_correction) {
    *(keys[0].add_last_level_value_correction()) = value;
    *(keys[1].add_last_level_value_correction()) = value;
  }

  return std::make_pair(std::move(keys[0]), std::move(keys[1]));
}
|
|
||||||
|
|
||||||
// Wraps `key` in a fresh EvaluationContext carrying a copy of this DPF's
// parameters. Returns the error from `proto_validator_->ValidateDpfKey` if
// `key` is rejected.
absl::StatusOr<EvaluationContext>
DistributedPointFunction::CreateEvaluationContext(DpfKey key) const {
  DPF_RETURN_IF_ERROR(proto_validator_->ValidateDpfKey(key));

  EvaluationContext ctx;
  for (const DpfParameters& level_parameters : parameters_) {
    *ctx.add_parameters() = level_parameters;
  }
  *ctx.mutable_key() = std::move(key);
  // A previous_hierarchy_level of -1 marks a context that has not been
  // evaluated yet.
  ctx.set_previous_hierarchy_level(-1);
  return ctx;
}
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
File diff suppressed because it is too large
Load Diff
@ -1,171 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
syntax = "proto3";

package distributed_point_functions;

// For faster allocations of sub-messages.
option cc_enable_arenas = true;
|
|
||||||
|
|
||||||
// Describes the type of a single DPF output value. Any additional types added
// here should also be supported in internal/value_type_helpers.h.
// LINT.IfChange
message ValueType {
  // Describes an integer modulo 2^l. Maps to the C++ types `uint8_t`,
  // `uint16_t`, `uint32_t`, `uint64_t`, and `absl::uint128`.
  message Integer {
    // Number of bits per integer. Must be a power of 2 and at most 128.
    int32 bitsize = 1;
  }

  // Describes a tuple of value types.
  message Tuple {
    repeated ValueType elements = 1;
  }

  // Describes an integer ring modulo `modulus`.
  message IntModN {
    // The underlying integer type used to represent elements in the ring.
    Integer base_integer = 1;
    // The modulus.
    Value.Integer modulus = 2;
  }

  oneof type {
    // A single integer modulo 2^l.
    Integer integer = 1;
    // A tuple of values.
    Tuple tuple = 2;
    // An integer with a custom modulus.
    IntModN int_mod_n = 3;
    // An XOR-wrapped integer. Corresponds to the XorWrapper C++ class.
    Integer xor_wrapper = 4;
  }
  // Do not add fields outside of the `oneof` above, to ensure that messages
  // with known ValueTypes are serialized deterministically. See the
  // documentation of `value_correction_functions_` in
  // distributed_point_function.h for details.
}
|
|
||||||
|
|
||||||
// Used to correct output values to the desired DPF magnitude. Holds the values
// corresponding to the types defined in `ValueType`.
message Value {
  message Integer {
    oneof value {
      // Any value up to 64 bits.
      uint64 value_uint64 = 1;
      // 128-bit values.
      Block value_uint128 = 2;
    }
  }

  message Tuple {
    repeated Value elements = 1;
  }

  oneof value {
    Integer integer = 1;
    Tuple tuple = 2;
    // The value of an IntModN is represented by its base_integer type.
    Integer int_mod_n = 3;
    Integer xor_wrapper = 4;
  }
}
// LINT.ThenChange(
//     internal/value_type_helpers.h,
//     internal/value_type_helpers.cc
// )
|
|
||||||
|
|
||||||
// Parameters of a single hierarchy level of a distributed point function (DPF).
message DpfParameters {
  reserved 2;
  // Base-2 logarithm of the number of elements.
  int32 log_domain_size = 1;
  // Describes the type of output values at this hierarchy level.
  ValueType value_type = 3;
  // The negative logarithm of the total variation distance from uniform that an
  // evaluation at a *single point* at this hierarchy level is allowed to have.
  // The correct value for this parameter depends on the maximum number of
  // points at which this hierarchy level is evaluated. It should be at least
  // 40 + log2(number_of_evaluation_points). Defaults to
  // ProtoValidator::kDefaultSecurityParameter + log_domain_size.
  double security_parameter = 4;
}
|
|
||||||
|
|
||||||
// A single 128-bit AES block, split into its upper and lower 64 bits.
message Block {
  uint64 high = 1;
  uint64 low = 2;
}
|
|
||||||
|
|
||||||
// A correction word used to evaluate a single layer in the DPF evaluation tree.
message CorrectionWord {
  // Block used to correct the new seeds after PRG evaluation.
  Block seed = 1;
  // Correction bits for the left and right control bits.
  bool control_left = 2;
  bool control_right = 3;
  // Reserved for deprecated value correction field.
  reserved 4;
  // Used to correct the output value at the previous tree layer. Only included
  // if the previous tree layer is an output layer. Repeated to capture the case
  // where multiple correction values are needed due to packing.
  repeated Value value_correction = 5;
}
|
|
||||||
|
|
||||||
// A key of a distributed point function (DPF).
message DpfKey {
  // Initial seed at the first level.
  Block seed = 1;
  // Correction words for each level after expansion.
  repeated CorrectionWord correction_words = 2;
  // Party this DpfKey belongs to (0 or 1).
  int32 party = 3;
  // Deprecated last level value correction.
  reserved 4;
  // Output correction for the last level of the evaluation tree.
  repeated Value last_level_value_correction = 5;
}
|
|
||||||
|
|
||||||
// Maps a single prefix of a DPF index to a PRG seed. Used to store partial
// evaluation state between hierarchy levels in `EvaluationContext`.
message PartialEvaluation {
  // Prefix in the FSS evaluation tree. Does not necessarily coincide with the
  // corresponding prefix of the output domain at this hierarchy level.
  Block prefix = 1;
  // Seed for the next evaluation.
  Block seed = 2;
  // Control bit for the correction in the next evaluation.
  bool control_bit = 3;
}
|
|
||||||
|
|
||||||
// An EvaluationContext holds the state of a partially evaluated incremental
// DPF.
message EvaluationContext {
  // The parameters of the DPF being evaluated. One set of parameters for each
  // hierarchy level of the incremental DPF.
  repeated DpfParameters parameters = 1;
  // The DPF key being evaluated.
  DpfKey key = 2;
  // The hierarchy level that this EvaluationContext was last evaluated on.
  int32 previous_hierarchy_level = 3;
  // Maps prefixes from an earlier hierarchy level to PRG seeds, which are used
  // to continue the evaluation under each prefix. Uses a repeated message field
  // since Protobuf doesn't allow messages (such as `Block`) as map keys.
  repeated PartialEvaluation partial_evaluations = 4;
  // The hierarchy level `partial_evaluations` corresponds to. Ignored when
  // `partial_evaluations` is empty.
  int32 partial_evaluations_level = 5;
}
|
|
418
third_party/distributed_point_functions/code/dpf/distributed_point_function_benchmark.cc
vendored
418
third_party/distributed_point_functions/code/dpf/distributed_point_function_benchmark.cc
vendored
@ -1,418 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
|
||||||
#include <memory>
|
|
||||||
#include <numeric>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/container/btree_set.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/random/random.h"
|
|
||||||
#include "absl/random/uniform_int_distribution.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "benchmark/benchmark.h"
|
|
||||||
#include "dpf/distributed_point_function.h"
|
|
||||||
#include "google/protobuf/arena.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Benchmarks a regular DPF evaluation. Expects the first range argument to
|
|
||||||
// specify the output log domain size.
|
|
||||||
template <typename T>
|
|
||||||
void BM_EvaluateRegularDpf(benchmark::State& state) {
|
|
||||||
DpfParameters parameters;
|
|
||||||
parameters.set_log_domain_size(state.range(0));
|
|
||||||
*(parameters.mutable_value_type()) = ToValueType<T>();
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
DistributedPointFunction::Create(parameters).value();
|
|
||||||
absl::uint128 alpha = 0;
|
|
||||||
T beta{};
|
|
||||||
ABSL_CHECK(dpf->RegisterValueType<T>().ok());
|
|
||||||
std::pair<DpfKey, DpfKey> keys = dpf->GenerateKeys(alpha, beta).value();
|
|
||||||
EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value();
|
|
||||||
for (auto s : state) {
|
|
||||||
google::protobuf::Arena arena;
|
|
||||||
EvaluationContext* ctx =
|
|
||||||
google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
|
|
||||||
*ctx = ctx_0;
|
|
||||||
std::vector<T> result = dpf->EvaluateNext<T>({}, *ctx).value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint8_t)->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint16_t)->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint32_t)->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint64_t)->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, absl::uint128)->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint32_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint64_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint64_t, uint64_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
|
|
||||||
Tuple<uint32_t, uint32_t, uint32_t, uint32_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
|
|
||||||
Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
BENCHMARK_TEMPLATE(
|
|
||||||
BM_EvaluateRegularDpf,
|
|
||||||
Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
|
|
||||||
using MyIntModN = IntModN<uint32_t, 4294967291u>; // 2**32 - 5.
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
|
|
||||||
Tuple<MyIntModN, MyIntModN, MyIntModN, MyIntModN, MyIntModN>)
|
|
||||||
->DenseRange(12, 24, 2);
|
|
||||||
using MyIntModN64 = IntModN<uint64_t, 18446744073709551557ull>; // 2**64 - 59.
|
|
||||||
BENCHMARK_TEMPLATE(
|
|
||||||
BM_EvaluateRegularDpf,
|
|
||||||
Tuple<MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64>)
|
|
||||||
->DenseRange(12, 22, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, XorWrapper<absl::uint128>)
|
|
||||||
->DenseRange(1, 24, 1);
|
|
||||||
|
|
||||||
// Benchmarks full evaluation of all hierarchy levels. Expects the first range
|
|
||||||
// argument to specify the number of iterations. The output domain size is fixed
|
|
||||||
// to 2**20.
|
|
||||||
template <typename T>
|
|
||||||
void BM_EvaluateHierarchicalFull(benchmark::State& state) {
|
|
||||||
// Set up DPF with the given parameters.
|
|
||||||
const int kMaxLogDomainSize = 20;
|
|
||||||
int num_hierarchy_levels = state.range(0);
|
|
||||||
std::vector<DpfParameters> parameters(num_hierarchy_levels);
|
|
||||||
for (int i = 0; i < num_hierarchy_levels; ++i) {
|
|
||||||
parameters[i].set_log_domain_size(static_cast<int>(
|
|
||||||
static_cast<double>(i + 1) / num_hierarchy_levels * kMaxLogDomainSize));
|
|
||||||
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
|
|
||||||
sizeof(T) * 8);
|
|
||||||
}
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
DistributedPointFunction::CreateIncremental(parameters).value();
|
|
||||||
|
|
||||||
// Generate keys.
|
|
||||||
absl::uint128 alpha = 12345;
|
|
||||||
std::vector<absl::uint128> beta(num_hierarchy_levels);
|
|
||||||
for (int i = 0; i < num_hierarchy_levels; ++i) {
|
|
||||||
beta[i] = i;
|
|
||||||
}
|
|
||||||
std::pair<DpfKey, DpfKey> keys =
|
|
||||||
dpf->GenerateKeysIncremental(alpha, beta).value();
|
|
||||||
|
|
||||||
// Set up evaluation context and evaluation prefixes for each level.
|
|
||||||
EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value();
|
|
||||||
std::vector<std::vector<absl::uint128>> prefixes(num_hierarchy_levels);
|
|
||||||
for (int i = 1; i < num_hierarchy_levels; ++i) {
|
|
||||||
prefixes[i].resize(1 << parameters[i - 1].log_domain_size());
|
|
||||||
std::iota(prefixes[i].begin(), prefixes[i].end(), absl::uint128{0});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run hierarchical evaluation.
|
|
||||||
for (auto s : state) {
|
|
||||||
google::protobuf::Arena arena;
|
|
||||||
EvaluationContext* ctx =
|
|
||||||
google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
|
|
||||||
*ctx = ctx_0;
|
|
||||||
for (int i = 0; i < num_hierarchy_levels; ++i) {
|
|
||||||
std::vector<T> result = dpf->EvaluateNext<T>(prefixes[i], *ctx).value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
benchmark::DoNotOptimize(*ctx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint8_t)->DenseRange(1, 16, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint16_t)->DenseRange(1, 16, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint32_t)->DenseRange(1, 16, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint64_t)->DenseRange(1, 16, 2);
|
|
||||||
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, absl::uint128)
|
|
||||||
->DenseRange(1, 16, 2);
|
|
||||||
|
|
||||||
// Generates random prefixes for the given set of `parameters`. Generates
|
|
||||||
// `num_nonzeros[i]` prefixes at hierarchy level `i`.
|
|
||||||
std::vector<std::vector<absl::uint128>> GenerateRandomPrefixes(
|
|
||||||
absl::Span<const DpfParameters> parameters,
|
|
||||||
absl::Span<const int> num_nonzeros) {
|
|
||||||
auto num_hierarchy_levels = static_cast<int>(parameters.size());
|
|
||||||
std::vector<std::vector<absl::uint128>> prefixes(parameters.size());
|
|
||||||
|
|
||||||
absl::BitGen rng;
|
|
||||||
absl::uniform_int_distribution<uint32_t> dist_index, dist_value;
|
|
||||||
for (int i = 0; i < num_hierarchy_levels; ++i) {
|
|
||||||
if (i > 0) { // prefixes must be empty for the first level.
|
|
||||||
prefixes[i] = std::vector<absl::uint128>(num_nonzeros[i - 1]);
|
|
||||||
absl::uint128 prefix = 0;
|
|
||||||
// Difference between the previous domain size and the one before that.
|
|
||||||
// This is the amount of bits we have to shift prefixes from the previous
|
|
||||||
// level to append the current level.
|
|
||||||
int previous_domain_size_difference = parameters[i - 1].log_domain_size();
|
|
||||||
if (i > 1) {
|
|
||||||
previous_domain_size_difference -= parameters[i - 2].log_domain_size();
|
|
||||||
}
|
|
||||||
dist_value = absl::uniform_int_distribution<uint32_t>(
|
|
||||||
0, (1 << previous_domain_size_difference) - 1);
|
|
||||||
if (i > 1) {
|
|
||||||
dist_index = absl::uniform_int_distribution<uint32_t>(
|
|
||||||
0, prefixes[i - 1].size() - 1);
|
|
||||||
}
|
|
||||||
for (int j = 0; i > 0 && j < num_nonzeros[i - 1]; ++j) {
|
|
||||||
if (i > 1) {
|
|
||||||
// Choose a random prefix from the previous level to extend.
|
|
||||||
prefix = prefixes[i - 1][dist_index(rng)]
|
|
||||||
<< previous_domain_size_difference;
|
|
||||||
}
|
|
||||||
prefixes[i][j] = prefix | dist_value(rng);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::sort(prefixes[i].begin(), prefixes[i].end());
|
|
||||||
}
|
|
||||||
return prefixes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark the example used here:
|
|
||||||
// https://github.com/abetterinternet/prio-documents/issues/18#issuecomment-801248636
|
|
||||||
void BM_IsrgExampleHierarchy(benchmark::State& state) {
|
|
||||||
const int kNumHierarchyLevels = 2;
|
|
||||||
std::vector<DpfParameters> parameters(kNumHierarchyLevels);
|
|
||||||
std::vector<int> num_nonzeros(kNumHierarchyLevels - 1);
|
|
||||||
|
|
||||||
parameters[0].set_log_domain_size(12);
|
|
||||||
parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32);
|
|
||||||
num_nonzeros[0] = 32;
|
|
||||||
|
|
||||||
parameters[1].set_log_domain_size(25);
|
|
||||||
parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(32);
|
|
||||||
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
DistributedPointFunction::CreateIncremental(parameters).value();
|
|
||||||
|
|
||||||
// Create DPF keys.
|
|
||||||
absl::uint128 alpha = 1234567;
|
|
||||||
std::vector<absl::uint128> beta(kNumHierarchyLevels, 1);
|
|
||||||
std::pair<DpfKey, DpfKey> keys =
|
|
||||||
dpf->GenerateKeysIncremental(alpha, beta).value();
|
|
||||||
|
|
||||||
// Generate prefixes for evaluation with the appropriate number of nonzeros.
|
|
||||||
std::vector<std::vector<absl::uint128>> prefixes =
|
|
||||||
GenerateRandomPrefixes(parameters, num_nonzeros);
|
|
||||||
|
|
||||||
// Run hierarchical evaluation.
|
|
||||||
EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value();
|
|
||||||
for (auto s : state) {
|
|
||||||
google::protobuf::Arena arena;
|
|
||||||
EvaluationContext* ctx =
|
|
||||||
google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
|
|
||||||
*ctx = ctx_0;
|
|
||||||
for (int i = 0; i < kNumHierarchyLevels; ++i) {
|
|
||||||
std::vector<uint32_t> result =
|
|
||||||
dpf->EvaluateNext<uint32_t>(prefixes[i], *ctx).value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
benchmark::DoNotOptimize(*ctx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK(BM_IsrgExampleHierarchy);
|
|
||||||
|
|
||||||
// Benchmarks the time needed to generate keys. The log domain size is read from
|
|
||||||
// the first range argument. If `direct_evaluation` is true, a single hierarchy
|
|
||||||
// level will be used. Otherwise, the number of hierarchy levels is eqaual to
|
|
||||||
// the log domain size (i.e., one level per bit in the domain).
|
|
||||||
template <bool direct_evaluation>
|
|
||||||
void BM_KeyGeneration(benchmark::State& state) {
|
|
||||||
int last_level_log_domain_size = state.range(0);
|
|
||||||
std::vector<DpfParameters> parameters(1);
|
|
||||||
if (direct_evaluation) {
|
|
||||||
parameters[0].set_log_domain_size(last_level_log_domain_size);
|
|
||||||
parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32);
|
|
||||||
} else {
|
|
||||||
parameters.resize(last_level_log_domain_size);
|
|
||||||
for (int i = 0; i < last_level_log_domain_size; ++i) {
|
|
||||||
parameters[i].set_log_domain_size(i + 1);
|
|
||||||
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
*(DistributedPointFunction::CreateIncremental(parameters));
|
|
||||||
|
|
||||||
std::vector<absl::uint128> beta(parameters.size(), 23);
|
|
||||||
absl::BitGen rng;
|
|
||||||
absl::uniform_int_distribution<uint64_t> dist;
|
|
||||||
absl::uint128 alpha_mask =
|
|
||||||
(absl::uint128{1} << parameters.back().log_domain_size()) - 1;
|
|
||||||
std::pair<DpfKey, DpfKey> result;
|
|
||||||
for (auto s : state) {
|
|
||||||
// Sample alpha randomly, so we don't rely on any structure here.
|
|
||||||
absl::uint128 alpha = absl::MakeUint128(dist(rng), dist(rng)) & alpha_mask;
|
|
||||||
result = dpf->GenerateKeysIncremental(alpha, beta).value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
state.SetLabel(absl::StrCat("key_size: ", result.first.ByteSizeLong()));
|
|
||||||
}
|
|
||||||
BENCHMARK_TEMPLATE(BM_KeyGeneration, true)->RangeMultiplier(2)->Range(1, 128);
|
|
||||||
BENCHMARK_TEMPLATE(BM_KeyGeneration, false)->RangeMultiplier(2)->Range(1, 128);
|
|
||||||
|
|
||||||
// Generates `num_nonzeros` uniform indices, and computes their prefixes for
|
|
||||||
// each hierarchy level in `parameters`.
|
|
||||||
absl::StatusOr<std::vector<std::vector<absl::uint128>>> GenerateUniformPrefixes(
|
|
||||||
absl::Span<const DpfParameters> parameters, int num_nonzeros) {
|
|
||||||
int num_parameters = static_cast<int>(parameters.size());
|
|
||||||
std::vector<std::vector<absl::uint128>> result(num_parameters);
|
|
||||||
if (num_parameters <= 1) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
if (std::log2(num_nonzeros) >
|
|
||||||
parameters[num_parameters - 2].log_domain_size()) {
|
|
||||||
return absl::InvalidArgumentError("num_nonzeros out of range");
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::BitGen rng;
|
|
||||||
absl::uniform_int_distribution<uint64_t> dist;
|
|
||||||
|
|
||||||
// Generate prefixes for last level.
|
|
||||||
absl::btree_set<absl::uint128> last_level_prefixes;
|
|
||||||
while (static_cast<int>(last_level_prefixes.size()) < num_nonzeros) {
|
|
||||||
absl::uint128 mask = (absl::uint128{1} << parameters[parameters.size() - 2]
|
|
||||||
.log_domain_size()) -
|
|
||||||
1;
|
|
||||||
last_level_prefixes.insert(absl::MakeUint128(dist(rng), dist(rng)) & mask);
|
|
||||||
}
|
|
||||||
result.back() = std::vector<absl::uint128>(last_level_prefixes.begin(),
|
|
||||||
last_level_prefixes.end());
|
|
||||||
|
|
||||||
// Iterate backwards through previous levels, computing prefixes by
|
|
||||||
// appropriately shifting the ones from higher levels.
|
|
||||||
for (int i = static_cast<int>(result.size()) - 1; i > 1; --i) {
|
|
||||||
absl::btree_set<absl::uint128> current_level_prefixes;
|
|
||||||
for (const auto& x : result[i]) {
|
|
||||||
absl::uint128 prefix = x >> (parameters[i - 1].log_domain_size() -
|
|
||||||
parameters[i - 2].log_domain_size());
|
|
||||||
current_level_prefixes.insert(prefix);
|
|
||||||
}
|
|
||||||
result[i - 1] = std::vector<absl::uint128>(current_level_prefixes.begin(),
|
|
||||||
current_level_prefixes.end());
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark a bit-wise hierarchy as in https://github.com/henrycg/heavyhitters.
|
|
||||||
// Uses a variable domain size with 10000 uniform non-zeros at the last
|
|
||||||
// hierarchy level, and evaluate at every bit.
|
|
||||||
void BM_HeavyHitters(benchmark::State& state) {
|
|
||||||
int num_parameters = state.range(0);
|
|
||||||
const int kNumNonzeros = 10000;
|
|
||||||
std::vector<DpfParameters> parameters(num_parameters);
|
|
||||||
for (int i = 0; i < num_parameters; ++i) {
|
|
||||||
parameters[i].set_log_domain_size(i + 1);
|
|
||||||
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(64);
|
|
||||||
}
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
*(DistributedPointFunction::CreateIncremental(parameters));
|
|
||||||
|
|
||||||
std::vector<absl::uint128> beta(num_parameters, 23);
|
|
||||||
absl::uint128 alpha = 42;
|
|
||||||
DpfKey key = dpf->GenerateKeysIncremental(alpha, beta).value().first;
|
|
||||||
std::vector<std::vector<absl::uint128>> prefixes =
|
|
||||||
GenerateUniformPrefixes(parameters, kNumNonzeros).value();
|
|
||||||
|
|
||||||
// Run hierarchical evaluation.
|
|
||||||
EvaluationContext ctx_0 = dpf->CreateEvaluationContext(key).value();
|
|
||||||
for (auto s : state) {
|
|
||||||
google::protobuf::Arena arena;
|
|
||||||
EvaluationContext* ctx =
|
|
||||||
google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
|
|
||||||
*ctx = ctx_0;
|
|
||||||
for (int i = 0; i < num_parameters; ++i) {
|
|
||||||
std::vector<uint64_t> result =
|
|
||||||
dpf->EvaluateNext<uint64_t>(prefixes[i], *ctx).value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
benchmark::DoNotOptimize(*ctx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK(BM_HeavyHitters)->RangeMultiplier(2)->Range(16, 128);
|
|
||||||
|
|
||||||
// Benchmark batch evaluation of multiple DPF keys at a single point each.
|
|
||||||
// The first argument specifies the number of keys, the second the domain size,
|
|
||||||
// and the last the number of evaluation points per key.
|
|
||||||
template <typename T>
|
|
||||||
void BM_BatchEvaluation(benchmark::State& state) {
|
|
||||||
const int num_keys = state.range(0);
|
|
||||||
const int evaluation_points_per_key = state.range(1);
|
|
||||||
constexpr int kLogDomainSize = 63 - 7;
|
|
||||||
|
|
||||||
absl::uint128 domain_mask = absl::Uint128Max();
|
|
||||||
if (kLogDomainSize < 128) {
|
|
||||||
domain_mask = (absl::uint128{1} << kLogDomainSize) - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
DpfParameters parameters;
|
|
||||||
parameters.set_log_domain_size(kLogDomainSize);
|
|
||||||
*(parameters.mutable_value_type()) = ToValueType<T>();
|
|
||||||
|
|
||||||
std::unique_ptr<DistributedPointFunction> dpf =
|
|
||||||
DistributedPointFunction::Create(parameters).value();
|
|
||||||
|
|
||||||
absl::BitGen rng;
|
|
||||||
google::protobuf::Arena arena;
|
|
||||||
std::vector<const DpfKey*> key_pointers(num_keys * evaluation_points_per_key);
|
|
||||||
auto evaluation_points =
|
|
||||||
hwy::AllocateAligned<absl::uint128>(num_keys * evaluation_points_per_key);
|
|
||||||
ABSL_CHECK(evaluation_points != nullptr);
|
|
||||||
for (int i = 0; i < num_keys; ++i) {
|
|
||||||
absl::uint128 alpha = absl::MakeUint128(absl::Uniform<uint64_t>(rng),
|
|
||||||
absl::Uniform<uint64_t>(rng)) &
|
|
||||||
domain_mask;
|
|
||||||
T beta{};
|
|
||||||
DpfKey* key = google::protobuf::Arena::CreateMessage<DpfKey>(&arena);
|
|
||||||
*key = dpf->GenerateKeys(alpha, beta).value().first;
|
|
||||||
|
|
||||||
for (int j = 0; j < evaluation_points_per_key; ++j) {
|
|
||||||
key_pointers[i * evaluation_points_per_key + j] = key;
|
|
||||||
evaluation_points[i * evaluation_points_per_key + j] =
|
|
||||||
absl::MakeUint128(absl::Uniform<uint64_t>(rng),
|
|
||||||
absl::Uniform<uint64_t>(rng)) &
|
|
||||||
domain_mask;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto s : state) {
|
|
||||||
for (int i = 0; i < num_keys; ++i) {
|
|
||||||
std::vector<T> result =
|
|
||||||
dpf->EvaluateAt<T>(
|
|
||||||
*(key_pointers[i]), 0,
|
|
||||||
absl::MakeConstSpan(
|
|
||||||
evaluation_points.get() + i * evaluation_points_per_key,
|
|
||||||
evaluation_points_per_key))
|
|
||||||
.value();
|
|
||||||
benchmark::DoNotOptimize(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK_TEMPLATE(BM_BatchEvaluation, XorWrapper<absl::uint128>)
|
|
||||||
->ArgPair(1, 400000)
|
|
||||||
->ArgPair(10, 40000)
|
|
||||||
->ArgPair(100, 4000);
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace distributed_point_functions
|
|
File diff suppressed because it is too large
Load Diff
@ -1,88 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
|
|
||||||
#include <cmath>
#include <limits>
#include <string>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
double IntModNBase::GetSecurityLevel(int num_samples, absl::uint128 modulus) {
|
|
||||||
return 128 + 3 -
|
|
||||||
(std::log2(static_cast<double>(modulus)) +
|
|
||||||
std::log2(static_cast<double>(num_samples)) +
|
|
||||||
std::log2(static_cast<double>(num_samples + 1)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validates IntModN sampling parameters. Checks are applied in a fixed order and
// the first failing one determines the returned INVALID_ARGUMENT status:
// positive `num_samples`, `base_integer_bitsize` in [1, 128], `modulus` not
// larger than 2**base_integer_bitsize, and an achievable security level (see
// GetSecurityLevel) of at least `security_parameter`.
absl::Status IntModNBase::CheckParameters(int num_samples,
                                          int base_integer_bitsize,
                                          absl::uint128 modulus,
                                          double security_parameter) {
  if (num_samples < 1) {
    return absl::InvalidArgumentError("num_samples must be positive");
  }
  if (base_integer_bitsize < 1) {
    return absl::InvalidArgumentError("base_integer_bitsize must be positive");
  }
  if (base_integer_bitsize > 128) {
    return absl::InvalidArgumentError(
        "base_integer_bitsize must be at most 128");
  }

  // A 128-bit base integer accepts every absl::uint128 modulus; this
  // short-circuit also keeps the shift below strictly under 128 bits.
  const bool modulus_fits =
      base_integer_bitsize == 128 ||
      modulus <= (absl::uint128{1} << base_integer_bitsize);
  if (!modulus_fits) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "kModulus %d out of range for base_integer_bitsize = %d", modulus,
        base_integer_bitsize));
  }

  // Reject parameter sets whose statistical security is insufficient.
  const double achievable_security = GetSecurityLevel(num_samples, modulus);
  if (achievable_security < security_parameter) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "For num_samples = %d and kModulus = %d this approach can only provide "
        "%f bits of statistical security. You can try calling this function "
        "several times with smaller values of num_samples.",
        num_samples, modulus, achievable_security));
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Returns the number of bytes needed to sample `num_samples` values modulo
// `modulus` with a base integer of `base_integer_bitsize` bits: one 16-byte
// block plus one base integer per additional sample.
//
// Returns INVALID_ARGUMENT if CheckParameters rejects the inputs, or if the
// byte count does not fit in an `int`.
absl::StatusOr<int> IntModNBase::GetNumBytesRequired(
    int num_samples, int base_integer_bitsize, absl::uint128 modulus,
    double security_parameter) {
  absl::Status status = CheckParameters(num_samples, base_integer_bitsize,
                                        modulus, security_parameter);
  if (!status.ok()) {
    return status;
  }

  // Both values are >= 1 here because CheckParameters succeeded.
  const int base_integer_bytes = ((base_integer_bitsize + 7) / 8);
  // We start the sampling by requiring a 128-bit (16 bytes) block, see
  // function `SampleFromBytes`.
  constexpr int kInitialBlockBytes = 16;
  // CheckParameters does not bound num_samples when security_parameter is
  // small, so guard the `int` arithmetic below against overflow.
  if (num_samples - 1 >
      (std::numeric_limits<int>::max() - kInitialBlockBytes) /
          base_integer_bytes) {
    return absl::InvalidArgumentError(
        "num_samples too large: required number of bytes overflows int");
  }
  return kInitialBlockBytes + base_integer_bytes * (num_samples - 1);
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,282 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <string>
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#include "absl/base/config.h"
|
|
||||||
#include "absl/container/inlined_vector.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// Base class holding common functions of IntModN that are independent of the
|
|
||||||
// template parameter.
|
|
||||||
class IntModNBase {
|
|
||||||
public:
|
|
||||||
// Computes the security level achievable when sampling `num_samples` elements
|
|
||||||
// with the given `kModulus`.
|
|
||||||
//
|
|
||||||
static double GetSecurityLevel(int num_samples, absl::uint128 modulus);
|
|
||||||
|
|
||||||
// Checks if the given parameters are consistent and valid for an IntModN.
|
|
||||||
//
|
|
||||||
// Returns OK for valid parameters, and INVALID_ARGUMENT otherwise.
|
|
||||||
static absl::Status CheckParameters(int num_samples, int base_integer_bitsize,
|
|
||||||
absl::uint128 modulus,
|
|
||||||
double security_parameter);
|
|
||||||
|
|
||||||
// Computes the number of bytes required to sample `num_samples` integers
|
|
||||||
// modulo `kModulus` with an underlying integer type of
|
|
||||||
// `base_integer_bitsize`.
|
|
||||||
//
|
|
||||||
// Returns INVALID_ARGUMENT if the achievable security level with the given
|
|
||||||
// parameters is less than `security_parameter`, or if the parameters are
|
|
||||||
// invalid.
|
|
||||||
static absl::StatusOr<int> GetNumBytesRequired(int num_samples,
|
|
||||||
int base_integer_bitsize,
|
|
||||||
absl::uint128 modulus,
|
|
||||||
double security_parameter);
|
|
||||||
|
|
||||||
// Creates a value of type T from the given `bytes`, using little-endian
|
|
||||||
// encoding. Called by SampleFromBytes. Crashes if bytes.size() != sizeof(T).
|
|
||||||
//
|
|
||||||
// This is a reimplementation of dpf_internal::ConvertBytesTo for integers,
|
|
||||||
// to avoid depending on value_type_helpers here.
|
|
||||||
template <typename T>
|
|
||||||
static T ConvertBytesTo(absl::string_view bytes) {
|
|
||||||
ABSL_CHECK(bytes.size() == sizeof(T));
|
|
||||||
T out{0};
|
|
||||||
#ifdef ABSL_IS_LITTLE_ENDIAN
|
|
||||||
std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out));
|
|
||||||
#else
|
|
||||||
for (int i = sizeof(T) - 1; i >= 0; --i) {
|
|
||||||
out |= absl::bit_cast<uint8_t>(bytes[i]);
|
|
||||||
out <<= 8;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
class IntModNImpl : public IntModNBase {
|
|
||||||
static_assert(sizeof(BaseInteger) <= sizeof(absl::uint128),
|
|
||||||
"BaseInteger may be at most 128 bits large");
|
|
||||||
static_assert(
|
|
||||||
std::is_same<BaseInteger, absl::uint128>::value ||
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
// std::is_unsigned_v<unsigned __int128> is not true everywhere:
|
|
||||||
// https://quuxplusone.github.io/blog/2019/02/28/is-int128-integral/#signedness
|
|
||||||
std::is_same<BaseInteger, unsigned __int128>::value ||
|
|
||||||
#endif
|
|
||||||
std::is_unsigned<BaseInteger>::value,
|
|
||||||
"BaseInteger must be unsigned");
|
|
||||||
static_assert(kModulus <= ModulusType(BaseInteger(-1)),
|
|
||||||
"kModulus must fit in BaseInteger");
|
|
||||||
|
|
||||||
public:
|
|
||||||
using Base = BaseInteger;
|
|
||||||
|
|
||||||
constexpr IntModNImpl() : value_(0) {}
|
|
||||||
explicit constexpr IntModNImpl(BaseInteger value)
|
|
||||||
: value_(value % kModulus) {}
|
|
||||||
|
|
||||||
// Copyable.
|
|
||||||
constexpr IntModNImpl(const IntModNImpl& a) = default;
|
|
||||||
|
|
||||||
constexpr IntModNImpl& operator=(const IntModNImpl& a) = default;
|
|
||||||
|
|
||||||
// Assignment operators.
|
|
||||||
constexpr IntModNImpl& operator=(const BaseInteger& a) {
|
|
||||||
value_ = a % kModulus;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr IntModNImpl& operator+=(const IntModNImpl& a) {
|
|
||||||
AddBaseInteger(a.value_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr IntModNImpl& operator-=(const IntModNImpl& a) {
|
|
||||||
SubtractBaseInteger(a.value_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the underlying representation as a BaseInteger.
|
|
||||||
constexpr BaseInteger value() const { return value_; }
|
|
||||||
|
|
||||||
// Returns the modulus of this IntModNImpl type.
|
|
||||||
static constexpr BaseInteger modulus() { return kModulus; }
|
|
||||||
|
|
||||||
// Returns the number of (pseudo)random bytes required to extract
|
|
||||||
// `num_samples` samples r1, ..., rn
|
|
||||||
// so that the stream r1, ..., rn is close to a truly (pseudo) random
|
|
||||||
// sequence up to total variation distance < 2^(-`security_parameter`)
|
|
||||||
static absl::StatusOr<int> GetNumBytesRequired(int num_samples,
|
|
||||||
double security_parameter) {
|
|
||||||
return IntModNBase::GetNumBytesRequired(
|
|
||||||
num_samples, 8 * sizeof(BaseInteger), kModulus, security_parameter);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extracts `samples.size()` samples r1, ..., rn so that the stream r1, ...,
|
|
||||||
// rn is close to a truly (pseudo) random sequence up to total variation
|
|
||||||
// distance < 2^(-`security_parameter`). Returns r1, ..., rn in `samples`.
|
|
||||||
//
|
|
||||||
// The optional template argument allows users to specify the number of
|
|
||||||
// samples at compile time, which can save heap allocations.
|
|
||||||
//
|
|
||||||
// Caution: For performance reasons, this function does not check whether
|
|
||||||
// `bytes` is long enough for the required number of samples and security
|
|
||||||
// parameter. Use `GetNumBytesRequired` or `SampleFromBytes` if such checks
|
|
||||||
// are needed.
|
|
||||||
//
|
|
||||||
template <int kCompiledNumSamples = 1>
|
|
||||||
static void UnsafeSampleFromBytes(absl::string_view bytes,
|
|
||||||
double security_parameter,
|
|
||||||
absl::Span<IntModNImpl> samples) {
|
|
||||||
static_assert(kCompiledNumSamples >= 1,
|
|
||||||
"kCompiledNumSamples must be positive");
|
|
||||||
absl::uint128 r = ConvertBytesTo<absl::uint128>(bytes.substr(0, 16));
|
|
||||||
absl::InlinedVector<BaseInteger, std::max(1, kCompiledNumSamples - 1)>
|
|
||||||
randomness(samples.size() - 1);
|
|
||||||
for (int i = 0; i < static_cast<int>(randomness.size()); ++i) {
|
|
||||||
randomness[i] = ConvertBytesTo<BaseInteger>(
|
|
||||||
bytes.substr(16 + i * sizeof(BaseInteger), sizeof(BaseInteger)));
|
|
||||||
}
|
|
||||||
for (int i = 0; i < static_cast<int>(samples.size()); ++i) {
|
|
||||||
samples[i] = IntModNImpl(static_cast<BaseInteger>(r % kModulus));
|
|
||||||
if (i < static_cast<int>(randomness.size())) {
|
|
||||||
r /= kModulus;
|
|
||||||
if (sizeof(BaseInteger) < sizeof(absl::uint128)) {
|
|
||||||
r <<= (sizeof(BaseInteger) * 8);
|
|
||||||
}
|
|
||||||
r |= randomness[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks that length(`bytes`) is enough to extract
|
|
||||||
// `samples.size()` samples r1, ..., rn
|
|
||||||
// so that the stream r1, ..., rn is close to a truly (pseudo) random
|
|
||||||
// sequence up to total variation distance < 2^(-`security_parameter`) and
|
|
||||||
// fails if that is not the case.
|
|
||||||
// Otherwise returns r1, ..., rn in `samples`.
|
|
||||||
static absl::Status SampleFromBytes(absl::string_view bytes,
|
|
||||||
double security_parameter,
|
|
||||||
absl::Span<IntModNImpl> samples) {
|
|
||||||
if (samples.empty()) {
|
|
||||||
return absl::InvalidArgumentError(
|
|
||||||
"The number of samples required must be > 0");
|
|
||||||
}
|
|
||||||
absl::StatusOr<int> num_bytes_lower_bound =
|
|
||||||
GetNumBytesRequired(samples.size(), security_parameter);
|
|
||||||
if (!num_bytes_lower_bound.ok()) {
|
|
||||||
return num_bytes_lower_bound.status();
|
|
||||||
}
|
|
||||||
if (*num_bytes_lower_bound > bytes.size()) {
|
|
||||||
return absl::InvalidArgumentError(
|
|
||||||
absl::StrCat("The number of bytes provided (", bytes.size(),
|
|
||||||
") is insufficient for the required "
|
|
||||||
"statistical security and number of samples."));
|
|
||||||
}
|
|
||||||
UnsafeSampleFromBytes(bytes, security_parameter, samples);
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Subtracts `a` from the stored value modulo kModulus.
constexpr void SubtractBaseInteger(const BaseInteger& a) {
  if (value_ < a) {
    // The plain difference would underflow, so borrow one kModulus.
    value_ = kModulus - a + value_;
    return;
  }
  value_ -= a;
}
|
|
||||||
|
|
||||||
constexpr void AddBaseInteger(const BaseInteger& a) {
|
|
||||||
SubtractBaseInteger(kModulus - a);
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseInteger value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator+(
|
|
||||||
IntModNImpl<BaseInteger, ModulusType, kModulus> a,
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
|
|
||||||
a += b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-(
|
|
||||||
IntModNImpl<BaseInteger, ModulusType, kModulus> a,
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
|
|
||||||
a -= b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-(
|
|
||||||
IntModNImpl<BaseInteger, ModulusType, kModulus> a) {
|
|
||||||
IntModNImpl<BaseInteger, ModulusType, kModulus> result(BaseInteger{0});
|
|
||||||
result -= a;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
constexpr bool operator==(
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& a,
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
|
|
||||||
return a.value() == b.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
constexpr bool operator!=(
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& a,
|
|
||||||
const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
|
|
||||||
// Since `absl::uint128` is not an alias to `unsigned __int128`, but a struct,
|
|
||||||
// we cannot use it as a template parameter type. So if we have an intrinsic
|
|
||||||
// int128, we always use that as the modulus type. Otherwise, the modulus type
|
|
||||||
// is the same as BaseInteger.
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
template <typename BaseInteger, unsigned __int128 kModulus>
|
|
||||||
using IntModN =
|
|
||||||
dpf_internal::IntModNImpl<BaseInteger, unsigned __int128, kModulus>;
|
|
||||||
#else
|
|
||||||
template <typename BaseInteger, BaseInteger kModulus>
|
|
||||||
using IntModN = dpf_internal::IntModNImpl<BaseInteger, BaseInteger, kModulus>;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_
|
|
@ -1,54 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "benchmark/benchmark.h"
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
#include "openssl/rand.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Modulus of the benchmarked type: 2**32 - 5.
constexpr uint32_t kMyIntModulus = 4294967291u;
using MyInt = IntModN<uint32_t, kMyIntModulus>;

// Samples drawn per UnsafeSampleFromBytes call.
constexpr int kNumSamples = 5;
|
|
||||||
|
|
||||||
void BM_Sample(benchmark::State& state) {
|
|
||||||
int num_iterations = state.range(0);
|
|
||||||
double security_parameter = 40 + std::log2(num_iterations);
|
|
||||||
std::vector<uint8_t> bytes(
|
|
||||||
MyInt::GetNumBytesRequired(kNumSamples, security_parameter).value());
|
|
||||||
RAND_bytes(bytes.data(), bytes.size());
|
|
||||||
std::vector<MyInt> output(num_iterations * kNumSamples);
|
|
||||||
for (auto s : state) {
|
|
||||||
for (int i = 0; i < num_iterations; ++i) {
|
|
||||||
MyInt::UnsafeSampleFromBytes<kNumSamples>(
|
|
||||||
absl::string_view(reinterpret_cast<const char*>(bytes.data()),
|
|
||||||
bytes.size()),
|
|
||||||
security_parameter,
|
|
||||||
absl::MakeSpan(&output[i * kNumSamples], kNumSamples));
|
|
||||||
}
|
|
||||||
benchmark::DoNotOptimize(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BENCHMARK(BM_Sample)->Range(1, 1 << 20);
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,258 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/base/config.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Security parameters for which GetNumBytesRequired is expected to succeed
// and to fail, respectively, in the tests below.
constexpr double kFeasibleSecurityParameter = 40.0;
constexpr double kUnfeasibleSecurityParameter = 95.0;

// Number of samples requested by the typed tests.
constexpr int kNumSamples = 5;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class IntModNTest : public testing::Test {};
|
|
||||||
using IntModNTypes = ::testing::Types<
|
|
||||||
IntModN<uint32_t, 4294967291u>, // 2**32-5
|
|
||||||
IntModN<uint64_t, 18446744073709551557ull> // 2**64-59
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
,
|
|
||||||
IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128(
|
|
||||||
65535u, 18446744073709551551ull))> // 2**80-65
|
|
||||||
#endif
|
|
||||||
>;
|
|
||||||
TYPED_TEST_SUITE(IntModNTest, IntModNTypes);
|
|
||||||
|
|
||||||
// A default-constructed element holds 0.
TYPED_TEST(IntModNTest, DefaultValueIsZero) {
  TypeParam default_constructed;
  EXPECT_EQ(default_constructed.value(), 0);
}
|
|
||||||
|
|
||||||
// Assigning an integer replaces the stored value.
TYPED_TEST(IntModNTest, SetValueWorks) {
  TypeParam element;
  EXPECT_EQ(element.value(), 0);

  element = 23;
  EXPECT_EQ(element.value(), 23);
}
|
|
||||||
|
|
||||||
// operator+= without reaching the modulus behaves like plain addition.
TYPED_TEST(IntModNTest, AdditionWithoutWrapAroundWorks) {
  TypeParam sum;
  TypeParam addend;

  sum += addend;  // 0 + 0
  EXPECT_EQ(sum.value(), 0);

  addend = 23;
  sum += addend;  // 0 + 23
  EXPECT_EQ(sum.value(), 23);

  // 23 + 4294967200 = 4294967223 stays below every modulus in IntModNTypes.
  addend = 4294967200;
  sum += addend;
  EXPECT_EQ(sum.value(), 4294967223);
}
|
|
||||||
|
|
||||||
// operator+= reduces results that exceed the modulus.
TYPED_TEST(IntModNTest, AdditionWithWrapAroundWorks) {
  TypeParam sum;
  TypeParam addend;

  sum += addend;  // 0 + 0
  EXPECT_EQ(sum.value(), 0);

  addend = 23;
  sum += addend;  // 0 + 23
  EXPECT_EQ(sum.value(), 23);

  // 23 + (modulus - 10) wraps around to 13.
  addend = TypeParam::modulus() - 10;
  sum += addend;
  EXPECT_EQ(sum.value(), 13);
}
|
|
||||||
|
|
||||||
// x + (-x) == 0.
TYPED_TEST(IntModNTest, NegationWorks) {
  const TypeParam ten(10);
  const TypeParam minus_ten = -ten;
  EXPECT_EQ(ten + minus_ten, TypeParam(0));
}
|
|
||||||
|
|
||||||
// GetNumBytesRequired rejects a security parameter that cannot be met, with
// an InvalidArgument error naming the sample count and modulus.
TYPED_TEST(IntModNTest, GetNumBytesRequiredFailsIfUnfeasible) {
  absl::StatusOr<int> result =
      TypeParam::GetNumBytesRequired(kNumSamples, kUnfeasibleSecurityParameter);
  // The expected prefix is built from kNumSamples (not a hard-coded 5) so it
  // stays in sync with the request above.
  EXPECT_THAT(result, dpf_internal::StatusIs(
                          absl::StatusCode::kInvalidArgument,
                          testing::StartsWith(absl::StrFormat(
                              "For num_samples = %d and kModulus = %d",
                              kNumSamples,
                              absl::uint128(TypeParam::modulus())))));
}
|
|
||||||
|
|
||||||
// GetNumBytesRequired succeeds for a feasible security parameter.
TYPED_TEST(IntModNTest, GetNumBytesRequiredSucceedsIfFeasible) {
  absl::StatusOr<int> result =
      TypeParam::GetNumBytesRequired(kNumSamples, kFeasibleSecurityParameter);
  EXPECT_TRUE(result.ok()) << result.status();
}
|
|
||||||
|
|
||||||
// SampleFromBytes rejects a byte string shorter than GetNumBytesRequired
// demands.
TYPED_TEST(IntModNTest, SampleFailsIfUnfeasible) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(kNumSamples, kFeasibleSecurityParameter);
  // ASSERT rather than EXPECT: reading the value of a non-OK StatusOr below
  // would crash the test binary.
  ASSERT_TRUE(r_getnum.ok()) << r_getnum.status();

  const std::string bytes(16, '#');
  EXPECT_GT(*r_getnum, static_cast<int>(bytes.size()));

  std::vector<TypeParam> samples(kNumSamples);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_THAT(
      r_sample,
      dpf_internal::StatusIs(
          absl::StatusCode::kInvalidArgument,
          "The number of bytes provided (16) is insufficient for the required "
          "statistical security and number of samples."));
}
|
|
||||||
|
|
||||||
// SampleFromBytes succeeds when given exactly GetNumBytesRequired bytes.
TYPED_TEST(IntModNTest, SampleSucceedsIfFeasible) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(kNumSamples, kFeasibleSecurityParameter);
  // ASSERT: the byte count below is read from `r_getnum`.
  ASSERT_TRUE(r_getnum.ok()) << r_getnum.status();

  const std::string bytes(*r_getnum, '#');
  std::vector<TypeParam> samples(kNumSamples);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_TRUE(r_sample.ok()) << r_sample;
}
|
|
||||||
|
|
||||||
// The first sample is the first 16 bytes, read as a uint128, reduced modulo
// the modulus.
TYPED_TEST(IntModNTest, FirstEntryOfSamplesIsAsExpected) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(kNumSamples, kFeasibleSecurityParameter);
  // ASSERT: the byte count below is read from `r_getnum`.
  ASSERT_TRUE(r_getnum.ok()) << r_getnum.status();

  const std::string bytes(*r_getnum, '#');
  std::vector<TypeParam> samples(kNumSamples);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  ASSERT_TRUE(r_sample.ok()) << r_sample;

  EXPECT_EQ(
      samples[0].value(),
      TypeParam::template ConvertBytesTo<absl::uint128>(bytes.substr(0, 16)) %
          TypeParam::modulus());
}
|
|
||||||
|
|
||||||
using BaseInteger = uint32_t;
|
|
||||||
constexpr BaseInteger kModulus32 = 4294967291u; // 2**32 - 5
|
|
||||||
using MyIntModN = IntModN<BaseInteger, kModulus32>;
|
|
||||||
|
|
||||||
// Recomputes all five samples of a fixed 32-byte input by hand and compares
// them with SampleFromBytes.
TEST(IntModNTest, SampleFromBytesWorksInConcreteExample) {
  absl::StatusOr<int> r_getnum =
      MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // ASSERT: `*r_getnum` is dereferenced next.
  ASSERT_TRUE(r_getnum.ok()) << r_getnum.status();
  EXPECT_EQ(*r_getnum, 32);
  const std::string bytes = "this is a length 32 test string.";
  ASSERT_EQ(bytes.size(), 32u);

  std::vector<MyIntModN> samples(5);
  absl::Status r_sample = MyIntModN::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  ASSERT_TRUE(r_sample.ok()) << r_sample;

  // The first 16 bytes seed a 128-bit accumulator. Each later sample keeps
  // the quotient, shifts it left by one BaseInteger and ORs in the next
  // 4-byte chunk.
  absl::uint128 r =
      MyIntModN::ConvertBytesTo<absl::uint128>("this is a length");
  EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus());
  int sample_index = 1;
  for (const char* chunk : {" 32 ", "test", " str", "ing."}) {
    r /= MyIntModN::modulus();
    r <<= (sizeof(MyIntModN::Base) * 8);
    r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(chunk);
    EXPECT_EQ(samples[sample_index].value(), r % MyIntModN::modulus())
        << "sample " << sample_index;
    ++sample_index;
  }
}
|
|
||||||
|
|
||||||
// Same as the concrete example above, but the last expected sample is derived
// from a tampered chunk and must therefore differ.
TEST(IntModNTest, SampleFromBytesFailsAsExpectedInConcreteExample) {
  absl::StatusOr<int> r_getnum =
      MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // ASSERT: `*r_getnum` is dereferenced next.
  ASSERT_TRUE(r_getnum.ok()) << r_getnum.status();
  EXPECT_EQ(*r_getnum, 32);
  const std::string bytes = "this is a length 32 test string.";
  ASSERT_EQ(bytes.size(), 32u);

  std::vector<MyIntModN> samples(5);
  absl::Status r_sample = MyIntModN::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  ASSERT_TRUE(r_sample.ok()) << r_sample;

  absl::uint128 r =
      MyIntModN::ConvertBytesTo<absl::uint128>("this is a length");
  EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus());
  int sample_index = 1;
  for (const char* chunk : {" 32 ", "test", " str"}) {
    r /= MyIntModN::modulus();
    r <<= (sizeof(MyIntModN::Base) * 8);
    r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(chunk);
    EXPECT_EQ(samples[sample_index].value(), r % MyIntModN::modulus())
        << "sample " << sample_index;
    ++sample_index;
  }

  // Use '#' instead of the real final '.', so the last sample must not match.
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("ing#");
  EXPECT_NE(samples[4].value(), r % MyIntModN::modulus());
}
|
|
||||||
|
|
||||||
// Test if IntModN operators are in fact constexpr. This will fail to compile
|
|
||||||
// otherwise.
|
|
||||||
constexpr MyIntModN TestAddition() { return MyIntModN(2) + MyIntModN(5); }
|
|
||||||
static_assert(TestAddition().value() == 7,
|
|
||||||
"constexpr addition of IntModNs incorrect");
|
|
||||||
|
|
||||||
constexpr MyIntModN TestSubtraction() { return MyIntModN(5) - MyIntModN(2); }
|
|
||||||
static_assert(TestSubtraction().value() == 3,
|
|
||||||
"constexpr subtraction of IntModNs incorrect");
|
|
||||||
|
|
||||||
constexpr MyIntModN TestAssignment() {
|
|
||||||
MyIntModN x(0);
|
|
||||||
x = 5;
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
static_assert(TestAssignment().value() == 5,
|
|
||||||
"constexpr assignment to IntModN incorrect");
|
|
||||||
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
// 128-bit modulus 2**128 - 159, written as the all-ones value minus 158 so the
// constant matches its description. The old (unsigned __int128)(-1) was
// 2**128 - 1.
constexpr unsigned __int128 kModulus128 =
    ~static_cast<unsigned __int128>(0) - 158;  // 2**128 - 159
using MyIntModN128 = IntModN<unsigned __int128, kModulus128>;

// Compile-time check that 128-bit IntModN addition is constexpr.
constexpr MyIntModN128 TestAddition128() {
  return MyIntModN128(2) + MyIntModN128(5);
}
static_assert(TestAddition128().value() == 7,
              "constexpr addition of IntModNs incorrect");
#endif
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,238 +0,0 @@
|
|||||||
# Copyright 2023 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
|
|
||||||
load("@com_github_google_iree//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
|
|
||||||
|
|
||||||
# Targets in this package are visible to every package in the workspace.
package(default_visibility = ["//:__subpackages__"])

licenses(["notice"])
|
|
||||||
|
|
||||||
# Library built from value_type_helpers.{h,cc}.
cc_library(
    name = "value_type_helpers",
    srcs = [
        "value_type_helpers.cc",
    ],
    hdrs = [
        "value_type_helpers.h",
    ],
    deps = [
        "//dpf:distributed_point_function_cc_proto",
        "//dpf:int_mod_n",
        "//dpf:status_macros",
        "//dpf:tuple",
        "//dpf:xor_wrapper",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/utility",
        "@com_google_protobuf//:protobuf_lite",
    ],
)
|
|
||||||
|
|
||||||
# Unit tests for :value_type_helpers.
cc_test(
    name = "value_type_helpers_test",
    srcs = [
        "value_type_helpers_test.cc",
    ],
    deps = [
        ":status_matchers",
        ":value_type_helpers",
        "//dpf:distributed_point_function_cc_proto",
        "//dpf:int_mod_n",
        "//dpf:tuple",
        "@com_github_google_googletest//:gtest_main",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)
|
|
||||||
|
|
||||||
# Test-only gMock matchers for absl::Status / absl::StatusOr.
cc_library(
    name = "status_matchers",
    testonly = True,
    srcs = ["status_matchers.cc"],
    hdrs = ["status_matchers.h"],
    deps = [
        "//dpf:status_macros",
        "@com_github_google_googletest//:gtest",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)
|
|
||||||
|
|
||||||
# Library built from proto_validator.{h,cc}.
cc_library(
    name = "proto_validator",
    srcs = ["proto_validator.cc"],
    hdrs = ["proto_validator.h"],
    deps = [
        ":value_type_helpers",
        "//dpf:distributed_point_function_cc_proto",
        "//dpf:status_macros",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "@com_google_protobuf//:protobuf_lite",
    ],
)
|
|
||||||
|
|
||||||
# Embeds proto_validator_test.textproto as C++ data in namespace
# distributed_point_functions::dpf_internal.
cc_embed_data(
    name = "proto_validator_test_textproto_embed",
    srcs = ["proto_validator_test.textproto"],
    cc_file_output = "proto_validator_test_textproto_embed.cc",
    cpp_namespace = "distributed_point_functions::dpf_internal",
    h_file_output = "proto_validator_test_textproto_embed.h",
)
|
|
||||||
|
|
||||||
# Unit tests for :proto_validator.
cc_test(
    name = "proto_validator_test",
    srcs = ["proto_validator_test.cc"],
    data = ["proto_validator_test.textproto"],
    deps = [
        ":proto_validator",
        ":proto_validator_test_textproto_embed",
        ":status_matchers",
        "//dpf:distributed_point_function_cc_proto",
        "//dpf:tuple",
        "@com_github_google_googletest//:gtest_main",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_protobuf//:protobuf",
    ],
)
|
|
||||||
|
|
||||||
# Highway-based library built from evaluate_prg_hwy.{h,cc}.
cc_library(
    name = "evaluate_prg_hwy",
    srcs = [
        "evaluate_prg_hwy.cc",
    ],
    hdrs = [
        "evaluate_prg_hwy.h",
    ],
    deps = [
        ":aes_128_fixed_key_hash_hwy",
        "//dpf:aes_128_fixed_key_hash",
        "//dpf:status_macros",
        "@boringssl//:crypto",
        "@com_github_google_highway//:hwy",
        "@com_google_absl//absl/base:config",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/log:absl_check",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/types:span",
    ],
)
|
|
||||||
|
|
||||||
# Unit tests for :evaluate_prg_hwy.
cc_test(
    name = "evaluate_prg_hwy_test",
    srcs = ["evaluate_prg_hwy_test.cc"],
    deps = [
        ":evaluate_prg_hwy",
        ":status_matchers",
        "//dpf:aes_128_fixed_key_hash",
        "@com_github_google_googletest//:gtest_main",
        "@com_github_google_highway//:hwy",
        "@com_github_google_highway//:hwy_test_util",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)
|
|
||||||
|
|
||||||
# Library built from get_hwy_mode.{h,cc}.
cc_library(
    name = "get_hwy_mode",
    srcs = [
        "get_hwy_mode.cc",
    ],
    hdrs = [
        "get_hwy_mode.h",
    ],
    deps = [
        "@com_github_google_highway//:hwy",
        "@com_google_absl//absl/strings",
    ],
)
|
|
||||||
|
|
||||||
# Header-only Highway implementation of the fixed-key AES hash.
cc_library(
    name = "aes_128_fixed_key_hash_hwy",
    hdrs = ["aes_128_fixed_key_hash_hwy.h"],
    deps = [
        "@com_github_google_highway//:hwy",
        "@com_google_absl//absl/numeric:int128",
    ],
)
|
|
||||||
|
|
||||||
# Header-only library: maybe_deref_span.h.
cc_library(
    name = "maybe_deref_span",
    hdrs = [
        "maybe_deref_span.h",
    ],
    deps = [
        "@com_google_absl//absl/meta:type_traits",
        "@com_google_absl//absl/types:span",
        "@com_google_absl//absl/types:variant",
    ],
)
|
|
||||||
|
|
||||||
# Unit tests for :aes_128_fixed_key_hash_hwy.
cc_test(
    name = "aes_128_fixed_key_hash_hwy_test",
    srcs = ["aes_128_fixed_key_hash_hwy_test.cc"],
    deps = [
        ":aes_128_fixed_key_hash_hwy",
        ":get_hwy_mode",
        ":status_matchers",
        "//dpf:aes_128_fixed_key_hash",
        "@boringssl//:crypto",
        "@com_github_google_googletest//:gtest_main",
        "@com_github_google_highway//:hwy",
        "@com_github_google_highway//:hwy_test_util",
        "@com_google_absl//absl/flags:parse",
        "@com_google_absl//absl/log:absl_log",
        "@com_google_absl//absl/numeric:int128",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/types:span",
    ],
)
|
|
||||||
|
|
||||||
# Unit tests for :maybe_deref_span.
cc_test(
    name = "maybe_deref_span_test",
    srcs = [
        "maybe_deref_span_test.cc",
    ],
    deps = [
        ":maybe_deref_span",
        "@com_github_google_googletest//:gtest_main",
    ],
)
|
|
@ -1,237 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Highway-specific include guard, ensuring the header can get included once per
|
|
||||||
// target architecture.
|
|
||||||
#if defined( \
|
|
||||||
DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_) == \
|
|
||||||
defined(HWY_TARGET_TOGGLE)
|
|
||||||
#ifdef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
|
|
||||||
#undef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
|
|
||||||
#else
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <limits>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace HWY_NAMESPACE {
|
|
||||||
|
|
||||||
// There is no AES support on HWY_SCALAR, but we still want to be able to
|
|
||||||
// include this header when compiling for HWY_SCALAR. The caller has to make
|
|
||||||
// sure to only call the functions defined here when not on HWY_SCALAR.
|
|
||||||
#if HWY_TARGET != HWY_SCALAR
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
constexpr int kAesBlockSize = 16;
|
|
||||||
|
|
||||||
// Helper to convert a Highway tag to a tag for vectors of the same bit size,
|
|
||||||
// but with 64-bit lanes.
|
|
||||||
template <typename D>
|
|
||||||
constexpr auto To64(D d) {
|
|
||||||
return hn::Repartition<uint64_t, D>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The following macros define parts of the fixed-key AES hash function
|
|
||||||
// implementation. We use macros here since Highway doesn't allow creating
|
|
||||||
// arrays of vectors/SIMD registers. That way, we can access each register by a
|
|
||||||
// unique variable name. All inputs and outputs are assumed to be of type
|
|
||||||
// hn::ScalableTag<uint8_t>.
|
|
||||||
|
|
||||||
// Loads the AES round key for the given round and key_index.
|
|
||||||
#define DPF_AES_LOAD_ROUND_KEY(key_index, round) \
|
|
||||||
const auto round_##round##_key_##key_index = \
|
|
||||||
hn::LoadDup128(d, round_keys_##key_index + kAesBlockSize * round);
|
|
||||||
|
|
||||||
// Selects key_0 or key_1 for the given block_index and round, depending on the
|
|
||||||
// bits in `mask`. Keys are first converted to 64-bit vectors to apply the more
|
|
||||||
// efficient 64 bit masks.
|
|
||||||
#define DPF_AES_SELECT_KEY(block_index, round) \
|
|
||||||
const auto selected_round_##round##_key_##block_index = hn::BitCast( \
|
|
||||||
d, hn::IfThenElse(mask_##block_index, \
|
|
||||||
hn::BitCast(To64(d), round_##round##_key_1), \
|
|
||||||
hn::BitCast(To64(d), round_##round##_key_0)));
|
|
||||||
|
|
||||||
// Load mask for computing {0, x.high64}, for computing sigma(x) below.
|
|
||||||
HWY_ALIGN constexpr absl::uint128 kSigmaMask =
|
|
||||||
absl::MakeUint128(std::numeric_limits<uint64_t>::max(), 0);
|
|
||||||
#define DPF_AES_LOAD_SIGMA_MASK() \
|
|
||||||
const auto sigma_mask = \
|
|
||||||
hn::LoadDup128(To64(d), reinterpret_cast<const uint64_t*>(&kSigmaMask));
|
|
||||||
|
|
||||||
// Compute sigma(x) = {x.high64, x.high64^x.low64} (in little-endian notation).
|
|
||||||
#define DPF_AES_COMPUTE_SIGMA(block_index) \
|
|
||||||
const auto in_##block_index##_64 = hn::BitCast(To64(d), in_##block_index); \
|
|
||||||
const auto sigma_##block_index = \
|
|
||||||
hn::BitCast(d, hn::Xor(hn::Shuffle01(in_##block_index##_64), \
|
|
||||||
hn::And(sigma_mask, in_##block_index##_64)));
|
|
||||||
|
|
||||||
// Performs the first round of AES for the given block_index, using sigma as the
|
|
||||||
// input.
|
|
||||||
#define DPF_AES_FIRST_ROUND(block_index) \
|
|
||||||
out_##block_index = \
|
|
||||||
hn::Xor(sigma_##block_index, selected_round_0_key_##block_index)
|
|
||||||
|
|
||||||
// Performs a middle round of AES for the given block_index.
|
|
||||||
#define DPF_AES_MIDDLE_ROUND(block_index, round) \
|
|
||||||
out_##block_index = hn::AESRound( \
|
|
||||||
out_##block_index, selected_round_##round##_key_##block_index);
|
|
||||||
|
|
||||||
// Performs the last round of AES for the given block_index.
|
|
||||||
#define DPF_AES_LAST_ROUND(block_index) \
|
|
||||||
out_##block_index = hn::AESLastRound(out_##block_index, \
|
|
||||||
selected_round_10_key_##block_index);
|
|
||||||
|
|
||||||
// Finalize the hash by XORing with sigma.
|
|
||||||
#define DPF_AES_FINALIZE_HASH(block_index) \
|
|
||||||
out_##block_index = hn::Xor(out_##block_index, sigma_##block_index);
|
|
||||||
|
|
||||||
// Helper macro for hashing a single vector.
|
|
||||||
#define DPF_AES_MIDDLE_ROUND_1(round) \
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(0, round); \
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(1, round); \
|
|
||||||
DPF_AES_SELECT_KEY(0, round); \
|
|
||||||
DPF_AES_MIDDLE_ROUND(0, round);
|
|
||||||
|
|
||||||
// Hashes a vector `in_0`, writing the output to `out_0`. Each block is hashed
|
|
||||||
// using either `round_keys_0` or `round_keys_1`, which both must point to a
|
|
||||||
// byte array containing two expanded AES keys. Which key is used for each block
|
|
||||||
// depends on `mask_0`: If the mask 0, then `round_keys_0` is used, otherwise
|
|
||||||
// `round_keys_1`. Note that the masks are masks on 64 bit integers, so there
|
|
||||||
// are two mask bits per AES block. The caller is responsible for making sure
|
|
||||||
// that the masks for the two halves of any given block have the same value.
|
|
||||||
template <typename V, typename D, typename M>
|
|
||||||
void HashOneWithKeyMask(D d, V in_0, M mask_0,
|
|
||||||
const uint8_t* HWY_RESTRICT round_keys_0,
|
|
||||||
const uint8_t* HWY_RESTRICT round_keys_1, V& out_0) {
|
|
||||||
// Compute sigma(in_0)
|
|
||||||
DPF_AES_LOAD_SIGMA_MASK();
|
|
||||||
DPF_AES_COMPUTE_SIGMA(0);
|
|
||||||
|
|
||||||
// First AES round.
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(0, 0);
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(1, 0);
|
|
||||||
DPF_AES_SELECT_KEY(0, 0);
|
|
||||||
DPF_AES_FIRST_ROUND(0);
|
|
||||||
|
|
||||||
// Middle AES rounds.
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(1);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(2);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(3);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(4);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(5);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(6);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(7);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(8);
|
|
||||||
DPF_AES_MIDDLE_ROUND_1(9);
|
|
||||||
|
|
||||||
// Last AES round.
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(0, 10);
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(1, 10);
|
|
||||||
DPF_AES_SELECT_KEY(0, 10)
|
|
||||||
DPF_AES_LAST_ROUND(0);
|
|
||||||
|
|
||||||
// Finalize hash.
|
|
||||||
DPF_AES_FINALIZE_HASH(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Four-vector variants of the round helpers, used by HashFourWithKeyMask.

// Runs DPF_AES_SELECT_KEY for blocks 0..3 of the given round.
#define DPF_AES_SELECT_KEY_4(round) \
  DPF_AES_SELECT_KEY(0, round);     \
  DPF_AES_SELECT_KEY(1, round);     \
  DPF_AES_SELECT_KEY(2, round);     \
  DPF_AES_SELECT_KEY(3, round);

// One middle round for four vectors. Both candidate round keys are loaded once
// and shared by all four vectors. The four AESRound calls are then issued
// back-to-back so their latencies can overlap.
#define DPF_AES_MIDDLE_ROUND_4(round) \
  DPF_AES_LOAD_ROUND_KEY(0, round);   \
  DPF_AES_LOAD_ROUND_KEY(1, round);   \
  DPF_AES_SELECT_KEY_4(round);        \
  DPF_AES_MIDDLE_ROUND(0, round);     \
  DPF_AES_MIDDLE_ROUND(1, round);     \
  DPF_AES_MIDDLE_ROUND(2, round);     \
  DPF_AES_MIDDLE_ROUND(3, round);
|
|
||||||
|
|
||||||
// Hashes four vectors `in_0, ..., in_3`, writing the results to `out_0, ...,
|
|
||||||
// out_3`. This improves pipelining of AES instructions, and improves
|
|
||||||
// performance by about 10%. Each block is hashed using either `round_keys_0` or
|
|
||||||
// `round_keys_1`, which both must point to a byte array containing two expanded
|
|
||||||
// AES keys. Which key is used for each block depends on `mask_0, ... mask_3`:
|
|
||||||
// If the mask 0, then `round_keys_0` is used, otherwise `round_keys_1`. Note
|
|
||||||
// that the masks are masks on 64 bit integers, so there are two mask bits per
|
|
||||||
// AES block. The caller is responsible for making sure that the masks for the
|
|
||||||
// two halves of any given block have the same value.
|
|
||||||
template <typename V, typename D, typename M>
|
|
||||||
void HashFourWithKeyMask(D d, V in_0, V in_1, V in_2, V in_3, M mask_0,
|
|
||||||
M mask_1, M mask_2, M mask_3,
|
|
||||||
const uint8_t* HWY_RESTRICT round_keys_0,
|
|
||||||
const uint8_t* HWY_RESTRICT round_keys_1, V& out_0,
|
|
||||||
V& out_1, V& out_2, V& out_3) {
|
|
||||||
// Compute sigma(in_0), ..., sigma(in_3)
|
|
||||||
DPF_AES_LOAD_SIGMA_MASK();
|
|
||||||
DPF_AES_COMPUTE_SIGMA(0);
|
|
||||||
DPF_AES_COMPUTE_SIGMA(1);
|
|
||||||
DPF_AES_COMPUTE_SIGMA(2);
|
|
||||||
DPF_AES_COMPUTE_SIGMA(3);
|
|
||||||
|
|
||||||
// First AES round.
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(0, 0);
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(1, 0);
|
|
||||||
DPF_AES_SELECT_KEY_4(0)
|
|
||||||
DPF_AES_FIRST_ROUND(0);
|
|
||||||
DPF_AES_FIRST_ROUND(1);
|
|
||||||
DPF_AES_FIRST_ROUND(2);
|
|
||||||
DPF_AES_FIRST_ROUND(3);
|
|
||||||
|
|
||||||
// Middle AES rounds.
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(1);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(2);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(3);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(4);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(5);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(6);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(7);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(8);
|
|
||||||
DPF_AES_MIDDLE_ROUND_4(9);
|
|
||||||
|
|
||||||
// Last AES round.
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(0, 10);
|
|
||||||
DPF_AES_LOAD_ROUND_KEY(1, 10);
|
|
||||||
DPF_AES_SELECT_KEY_4(10)
|
|
||||||
DPF_AES_LAST_ROUND(0);
|
|
||||||
DPF_AES_LAST_ROUND(1);
|
|
||||||
DPF_AES_LAST_ROUND(2);
|
|
||||||
DPF_AES_LAST_ROUND(3);
|
|
||||||
|
|
||||||
// Finalize hash.
|
|
||||||
DPF_AES_FINALIZE_HASH(0);
|
|
||||||
DPF_AES_FINALIZE_HASH(1);
|
|
||||||
DPF_AES_FINALIZE_HASH(2);
|
|
||||||
DPF_AES_FINALIZE_HASH(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HWY_TARGET != HWY_SCALAR
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
HWY_AFTER_NAMESPACE();
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
|
|
232
third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy_test.cc
vendored
232
third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy_test.cc
vendored
@ -1,232 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include <limits>
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/flags/parse.h"
|
|
||||||
#include "absl/log/absl_log.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
#include "dpf/internal/get_hwy_mode.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
#include "hwy/detect_targets.h"
|
|
||||||
#include "openssl/aes.h"
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#define HWY_IS_TEST 1
|
|
||||||
#undef HWY_TARGET_INCLUDE
|
|
||||||
#define HWY_TARGET_INCLUDE "dpf/internal/aes_128_fixed_key_hash_hwy_test.cc" // NOLINT
|
|
||||||
#include "hwy/foreach_target.h"
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
#include "dpf/internal/aes_128_fixed_key_hash_hwy.h"
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
#include "hwy/tests/test_util-inl.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace HWY_NAMESPACE {
|
|
||||||
|
|
||||||
#if HWY_TARGET == HWY_SCALAR
|
|
||||||
|
|
||||||
// HWY_SCALAR provides no AES instructions, so this target has nothing to test.
void TestAllAes() {}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
constexpr absl::uint128 kKey0 =
|
|
||||||
absl::MakeUint128(0x0000000000000000, 0x0000000000000000);
|
|
||||||
constexpr absl::uint128 kKey1 =
|
|
||||||
absl::MakeUint128(0x1111111111111111, 0x1111111111111111);
|
|
||||||
constexpr int kNumBlocks = 128; // Must be divisible by (4 * hn::Lanes(d)).
|
|
||||||
constexpr int kNumBytes = kNumBlocks * sizeof(absl::uint128);
|
|
||||||
|
|
||||||
class TestOutputMatchesOpenSSL {
|
|
||||||
public:
|
|
||||||
template <typename T, typename D>
|
|
||||||
HWY_NOINLINE void operator()(T /*unused*/, D d) {
|
|
||||||
Reset();
|
|
||||||
EvaluateOne(d);
|
|
||||||
CheckResult();
|
|
||||||
Reset();
|
|
||||||
EvaluateFour(d);
|
|
||||||
CheckResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
void Reset() {
|
|
||||||
inputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks);
|
|
||||||
ASSERT_NE(inputs_, nullptr);
|
|
||||||
masks_ = hwy::AllocateAligned<uint64_t>(2 * kNumBlocks);
|
|
||||||
ASSERT_NE(masks_, nullptr);
|
|
||||||
for (int i = 0; i < kNumBlocks; ++i) {
|
|
||||||
inputs_[i] = absl::MakeUint128(i, i + 1);
|
|
||||||
masks_[2 * i] = masks_[2 * i + 1] =
|
|
||||||
(i % 3 == 0) ? std::numeric_limits<uint64_t>::max() : 0;
|
|
||||||
}
|
|
||||||
outputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks);
|
|
||||||
ASSERT_NE(outputs_, nullptr);
|
|
||||||
ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey0),
|
|
||||||
128, &expanded_key_0_));
|
|
||||||
ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey1),
|
|
||||||
128, &expanded_key_1_));
|
|
||||||
input_ptr_ = reinterpret_cast<const uint8_t*>(inputs_.get());
|
|
||||||
output_ptr_ = reinterpret_cast<uint8_t*>(outputs_.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckResult() {
|
|
||||||
// Check the result by comparing with OpenSSL-based AES hash.
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash hash_0,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash hash_1,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks);
|
|
||||||
DPF_ASSERT_OK(
|
|
||||||
hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
|
|
||||||
absl::MakeSpan(wanted_0)));
|
|
||||||
DPF_ASSERT_OK(
|
|
||||||
hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
|
|
||||||
absl::MakeSpan(wanted_1)));
|
|
||||||
for (int i = 0; i < kNumBlocks; ++i) {
|
|
||||||
if (i % 3 == 0) {
|
|
||||||
EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i;
|
|
||||||
} else {
|
|
||||||
EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename D>
|
|
||||||
void EvaluateOne(D d) {
|
|
||||||
hn::Repartition<uint64_t, D> d64;
|
|
||||||
for (int i = 0; i + hn::Lanes(d) <= kNumBytes; i += hn::Lanes(d)) {
|
|
||||||
const auto in = hn::Load(d, input_ptr_ + i);
|
|
||||||
const auto mask =
|
|
||||||
hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t)));
|
|
||||||
auto out = hn::Undefined(d);
|
|
||||||
HashOneWithKeyMask(
|
|
||||||
d, in, mask, reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key),
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out);
|
|
||||||
hn::Store(out, d, output_ptr_ + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename D>
|
|
||||||
void EvaluateFour(D d) {
|
|
||||||
hn::Repartition<uint64_t, D> d64;
|
|
||||||
// Evaluate four vectors at once. Assumes kNumBytes is divisible by (4 *
|
|
||||||
// hn::Lanes(d)).
|
|
||||||
for (int i = 0; i < kNumBytes; i += 4 * hn::Lanes(d)) {
|
|
||||||
const auto in_0 = hn::Load(d, input_ptr_ + i);
|
|
||||||
const auto in_1 = hn::Load(d, input_ptr_ + i + 1 * hn::Lanes(d));
|
|
||||||
const auto in_2 = hn::Load(d, input_ptr_ + i + 2 * hn::Lanes(d));
|
|
||||||
const auto in_3 = hn::Load(d, input_ptr_ + i + 3 * hn::Lanes(d));
|
|
||||||
const auto mask_0 =
|
|
||||||
hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t)));
|
|
||||||
const auto mask_1 = hn::MaskFromVec(hn::Load(
|
|
||||||
d64, masks_.get() + (i + 1 * hn::Lanes(d)) / sizeof(uint64_t)));
|
|
||||||
const auto mask_2 = hn::MaskFromVec(hn::Load(
|
|
||||||
d64, masks_.get() + (i + 2 * hn::Lanes(d)) / sizeof(uint64_t)));
|
|
||||||
const auto mask_3 = hn::MaskFromVec(hn::Load(
|
|
||||||
d64, masks_.get() + (i + 3 * hn::Lanes(d)) / sizeof(uint64_t)));
|
|
||||||
auto out_0 = hn::Undefined(d);
|
|
||||||
auto out_1 = hn::Undefined(d);
|
|
||||||
auto out_2 = hn::Undefined(d);
|
|
||||||
auto out_3 = hn::Undefined(d);
|
|
||||||
HashFourWithKeyMask(
|
|
||||||
d, in_0, in_1, in_2, in_3, mask_0, mask_1, mask_2, mask_3,
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key),
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out_0,
|
|
||||||
out_1, out_2, out_3);
|
|
||||||
hn::Store(out_0, d, output_ptr_ + i);
|
|
||||||
hn::Store(out_1, d, output_ptr_ + i + 1 * hn::Lanes(d));
|
|
||||||
hn::Store(out_2, d, output_ptr_ + i + 2 * hn::Lanes(d));
|
|
||||||
hn::Store(out_3, d, output_ptr_ + i + 3 * hn::Lanes(d));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the result by comparing with OpenSSL-based AES hash.
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash hash_0,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash hash_1,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks);
|
|
||||||
DPF_ASSERT_OK(
|
|
||||||
hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
|
|
||||||
absl::MakeSpan(wanted_0)));
|
|
||||||
DPF_ASSERT_OK(
|
|
||||||
hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
|
|
||||||
absl::MakeSpan(wanted_1)));
|
|
||||||
for (int i = 0; i < kNumBlocks; ++i) {
|
|
||||||
if (i % 3 == 0) {
|
|
||||||
EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i;
|
|
||||||
} else {
|
|
||||||
EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> inputs_, outputs_;
|
|
||||||
hwy::AlignedFreeUniquePtr<uint64_t[]> masks_;
|
|
||||||
const uint8_t* input_ptr_;
|
|
||||||
uint8_t* output_ptr_;
|
|
||||||
HWY_ALIGN AES_KEY expanded_key_0_, expanded_key_1_;
|
|
||||||
};
|
|
||||||
|
|
||||||
void TestAllAes() {
|
|
||||||
hn::ForGE128Vectors<TestOutputMatchesOpenSSL>()(uint8_t{0});
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HWY_TARGET == HWY_SCALAR
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
HWY_AFTER_NAMESPACE();
|
|
||||||
|
|
||||||
#if HWY_ONCE
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
HWY_BEFORE_TEST(Aes128FixedKeyHashHwyTest);
|
|
||||||
HWY_EXPORT_AND_TEST_P(Aes128FixedKeyHashHwyTest, TestAllAes);
|
|
||||||
|
|
||||||
TEST(Aes128FixedKeyHashHwy, LogHwyMode) {
|
|
||||||
ABSL_LOG(INFO) << "Highway is in " << GetHwyModeAsString() << " mode";
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,662 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/evaluate_prg_hwy.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/base/config.h"
|
|
||||||
#include "absl/base/optimization.h"
|
|
||||||
#include "absl/container/inlined_vector.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
#include "dpf/status_macros.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
#include "openssl/aes.h"
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#undef HWY_TARGET_INCLUDE
|
|
||||||
#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy.cc"
|
|
||||||
#include "hwy/foreach_target.h"
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
#include "dpf/internal/aes_128_fixed_key_hash_hwy.h"
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace HWY_NAMESPACE {
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
#if HWY_TARGET == HWY_SCALAR
|
|
||||||
|
|
||||||
// HWY_SCALAR has no AES instructions, so this target delegates to the portable
// implementation.
//
// The signature must match the vector-target EvaluateSeedsHwy in this file.
// All targets are dispatched through one function type, and EvaluateSeedsNoHwy
// takes the same 14 arguments, including `num_correction_words` and
// `paths_rightshift`. Every argument is forwarded unchanged.
absl::Status EvaluateSeedsHwy(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  return EvaluateSeedsNoHwy(num_seeds, num_levels, num_correction_words,
                            seeds_in, control_bits_in, paths, paths_rightshift,
                            correction_seeds, correction_controls_left,
                            correction_controls_right, prg_left, prg_right,
                            seeds_out, control_bits_out);
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
// Converts a bool array to a block-level mask suitable for vectors described by
|
|
||||||
// `d`. The mask value for each integer in the i-th block is set to input[i].
|
|
||||||
// If `max_blocks > 0`, returns after reading `max_blocks` bools from `input`.
|
|
||||||
template <typename D>
|
|
||||||
auto MaskFromBools(D d, const bool* input, int max_blocks = 0) {
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T);
|
|
||||||
constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64);
|
|
||||||
uint8_t mask_bits[buffer_size] = {0};
|
|
||||||
for (int i = 0; i < hn::Lanes(d); ++i) {
|
|
||||||
int block_idx = i / ints_per_block;
|
|
||||||
if (max_blocks > 0 && block_idx >= max_blocks) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (input[block_idx]) {
|
|
||||||
mask_bits[i / 8] |= uint8_t{1} << (i % 8);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hn::LoadMaskBits(d, mask_bits);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts a mask for types `d` to a bool array. Assumes that the mask value
|
|
||||||
// for all integers in the i-th block is equal, and writes that value to
|
|
||||||
// output[i]. If `max_blocks > 0`, returns after writing `max_blocks` bools to
|
|
||||||
// `output`.
|
|
||||||
template <typename D, typename M>
|
|
||||||
void BoolsFromMask(D d, M mask, bool* output, int max_blocks = 0) {
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T);
|
|
||||||
int num_outputs = hn::Lanes(d) / ints_per_block;
|
|
||||||
if (max_blocks > 0) {
|
|
||||||
num_outputs = max_blocks;
|
|
||||||
}
|
|
||||||
constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64);
|
|
||||||
uint8_t mask_bits[buffer_size] = {0};
|
|
||||||
hn::StoreMaskBits(d, mask, mask_bits);
|
|
||||||
for (int i = 0; i < num_outputs; ++i) {
|
|
||||||
int mask_idx = i * ints_per_block;
|
|
||||||
output[i] = (mask_bits[mask_idx / 8] & (uint8_t{1} << (mask_idx % 8))) != 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename M>
|
|
||||||
M IfThenElseMask(M condition, M true_value, M false_value) {
|
|
||||||
return hn::Or(hn::And(condition, true_value),
|
|
||||||
hn::And(hn::Not(condition), false_value));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a mask that is `true` on all blocks where `input[i] & (1 << index)`
|
|
||||||
// is nonzero. The mask is a 64-bit-level mask, suitable for AES hashing.
|
|
||||||
template <typename V, typename D>
|
|
||||||
auto IsBitSet(D d, const V input, int index) {
|
|
||||||
// First create a 128-bit block with the `index`-th bit set.
|
|
||||||
HWY_ALIGN absl::uint128 mask = 0;
|
|
||||||
if (index < 128) {
|
|
||||||
mask = absl::uint128{1} << index;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now load it into a vector of 64-bit integers. Note that every second
|
|
||||||
// element of that vector will be 0.
|
|
||||||
const hn::Repartition<uint64_t, D> d64;
|
|
||||||
static_assert(ABSL_IS_LITTLE_ENDIAN);
|
|
||||||
const auto mask_64 =
|
|
||||||
hn::LoadDup128(d64, reinterpret_cast<const uint64_t*>(&mask));
|
|
||||||
|
|
||||||
// Compute input AND mask_64 on 64-bit integers.
|
|
||||||
auto input_64 = hn::BitCast(d64, input);
|
|
||||||
input_64 = hn::And(input_64, mask_64);
|
|
||||||
|
|
||||||
// Take the OR of every two adjacent 64-bit integers. This ensures that each
|
|
||||||
// half of an 128-bit block is nonzero iff at least one half was nonzero.
|
|
||||||
input_64 = hn::Or(input_64, hn::Shuffle01(input_64));
|
|
||||||
|
|
||||||
// Compute a 64-bit mask that checks which integers are nonzero.
|
|
||||||
return hn::Ne(input_64, hn::Zero(d64));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dummy struct to get HWY_ALIGN as a number, for testing if an array of
|
|
||||||
// absl::uint128 is aligned.
|
|
||||||
struct HWY_ALIGN Aligned128 {
|
|
||||||
absl::uint128 _;
|
|
||||||
};
|
|
||||||
|
|
||||||
absl::Status EvaluateSeedsHwy(
|
|
||||||
int64_t num_seeds, int num_levels, int num_correction_words,
|
|
||||||
const absl::uint128* seeds_in, const bool* control_bits_in,
|
|
||||||
const absl::uint128* paths, int paths_rightshift,
|
|
||||||
const absl::uint128* correction_seeds, const bool* correction_controls_left,
|
|
||||||
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
|
|
||||||
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
|
|
||||||
bool* control_bits_out) {
|
|
||||||
// Exit early if inputs are empty.
|
|
||||||
if (num_seeds == 0 || num_levels == 0) {
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if inputs and outputs are aligned.
|
|
||||||
constexpr size_t kHwyAlignment = alignof(Aligned128);
|
|
||||||
const bool is_aligned =
|
|
||||||
(reinterpret_cast<uintptr_t>(seeds_in) % kHwyAlignment == 0) &&
|
|
||||||
(reinterpret_cast<uintptr_t>(paths) % kHwyAlignment == 0) &&
|
|
||||||
(reinterpret_cast<uintptr_t>(correction_seeds) % kHwyAlignment == 0) &&
|
|
||||||
(reinterpret_cast<uintptr_t>(seeds_out) % kHwyAlignment == 0);
|
|
||||||
// Vector type used throughout this function: Largest byte vector available.
|
|
||||||
const hn::ScalableTag<uint8_t> d8;
|
|
||||||
// Only run the highway version if
|
|
||||||
// - the inputs are aligned,
|
|
||||||
// - the number of bytes in a vector is at least 16, and
|
|
||||||
// - the number of bytes in a vector is a multiple of 16.
|
|
||||||
if (ABSL_PREDICT_FALSE(!is_aligned || hn::Lanes(d8) < 16 ||
|
|
||||||
hn::Lanes(d8) % 16 != 0)) {
|
|
||||||
return EvaluateSeedsNoHwy(
|
|
||||||
num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in,
|
|
||||||
paths, paths_rightshift, correction_seeds, correction_controls_left,
|
|
||||||
correction_controls_right, prg_left, prg_right, seeds_out,
|
|
||||||
control_bits_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Do AES key schedule.
|
|
||||||
HWY_ALIGN AES_KEY expanded_key_0;
|
|
||||||
HWY_ALIGN AES_KEY expanded_key_1;
|
|
||||||
int openssl_status = AES_set_encrypt_key(
|
|
||||||
reinterpret_cast<const uint8_t*>(&prg_left.key()), 128, &expanded_key_0);
|
|
||||||
if (openssl_status != 0) {
|
|
||||||
return absl::InternalError("Failed to set up AES key");
|
|
||||||
}
|
|
||||||
openssl_status = AES_set_encrypt_key(
|
|
||||||
reinterpret_cast<const uint8_t*>(&prg_right.key()), 128, &expanded_key_1);
|
|
||||||
if (openssl_status != 0) {
|
|
||||||
return absl::InternalError("Failed to set up AES key");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper variables.
|
|
||||||
const hn::Repartition<uint64_t, decltype(d8)> d64;
|
|
||||||
HWY_ALIGN absl::uint128 clear_lowest_bit_128 = ~absl::uint128{1};
|
|
||||||
const auto clear_lowest_bit = hn::LoadDup128(
|
|
||||||
d8, reinterpret_cast<const uint8_t*>(&clear_lowest_bit_128));
|
|
||||||
const auto mask_all_zero = hn::FirstN(d64, 0);
|
|
||||||
const auto mask_all_one = hn::Not(mask_all_zero);
|
|
||||||
const int64_t num_bytes = num_seeds * sizeof(absl::uint128);
|
|
||||||
const int bytes_per_vec = hn::Lanes(d8);
|
|
||||||
const int blocks_per_vec = bytes_per_vec / sizeof(absl::uint128);
|
|
||||||
const int64_t correction_words_per_level = num_correction_words / num_levels;
|
|
||||||
|
|
||||||
// Pointer aliases for reading and writing data.
|
|
||||||
const uint8_t* seeds_in_ptr = reinterpret_cast<const uint8_t*>(seeds_in);
|
|
||||||
const uint8_t* paths_ptr = reinterpret_cast<const uint8_t*>(paths);
|
|
||||||
uint8_t* seeds_out_ptr = reinterpret_cast<uint8_t*>(seeds_out);
|
|
||||||
// Four vectors at a time.
|
|
||||||
int64_t i = 0;
|
|
||||||
for (; i + 4 * bytes_per_vec <= num_bytes; i += 4 * bytes_per_vec) {
|
|
||||||
const int64_t start_block = i / sizeof(absl::uint128);
|
|
||||||
// Load initial seeds and paths into vectors.
|
|
||||||
auto vec_0 = hn::Load(d8, seeds_in_ptr + i);
|
|
||||||
auto vec_1 = hn::Load(d8, seeds_in_ptr + i + 1 * bytes_per_vec);
|
|
||||||
auto vec_2 = hn::Load(d8, seeds_in_ptr + i + 2 * bytes_per_vec);
|
|
||||||
auto vec_3 = hn::Load(d8, seeds_in_ptr + i + 3 * bytes_per_vec);
|
|
||||||
const auto path_0 = hn::Load(d8, paths_ptr + i);
|
|
||||||
const auto path_1 = hn::Load(d8, paths_ptr + i + 1 * bytes_per_vec);
|
|
||||||
const auto path_2 = hn::Load(d8, paths_ptr + i + 2 * bytes_per_vec);
|
|
||||||
const auto path_3 = hn::Load(d8, paths_ptr + i + 3 * bytes_per_vec);
|
|
||||||
auto control_mask_0 = MaskFromBools(d64, control_bits_in + start_block);
|
|
||||||
auto control_mask_1 =
|
|
||||||
MaskFromBools(d64, control_bits_in + start_block + 1 * blocks_per_vec);
|
|
||||||
auto control_mask_2 =
|
|
||||||
MaskFromBools(d64, control_bits_in + start_block + 2 * blocks_per_vec);
|
|
||||||
auto control_mask_3 =
|
|
||||||
MaskFromBools(d64, control_bits_in + start_block + 3 * blocks_per_vec);
|
|
||||||
for (int j = 0; j < num_levels; ++j) {
|
|
||||||
// Convert path bits to masks and evaluate PRG.
|
|
||||||
const int bit_index = num_levels - j - 1 + paths_rightshift;
|
|
||||||
const auto path_mask_0 = IsBitSet(d8, path_0, bit_index);
|
|
||||||
const auto path_mask_1 = IsBitSet(d8, path_1, bit_index);
|
|
||||||
const auto path_mask_2 = IsBitSet(d8, path_2, bit_index);
|
|
||||||
const auto path_mask_3 = IsBitSet(d8, path_3, bit_index);
|
|
||||||
HashFourWithKeyMask(
|
|
||||||
d8, vec_0, vec_1, vec_2, vec_3, path_mask_0, path_mask_1, path_mask_2,
|
|
||||||
path_mask_3, reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec_0, vec_1,
|
|
||||||
vec_2, vec_3);
|
|
||||||
|
|
||||||
// Apply correction.
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
const auto correction_seed = hn::LoadDup128(
|
|
||||||
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
|
|
||||||
vec_0 = hn::Xor(vec_0,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_0,
|
|
||||||
correction_seed)));
|
|
||||||
vec_1 = hn::Xor(vec_1,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_1,
|
|
||||||
correction_seed)));
|
|
||||||
vec_2 = hn::Xor(vec_2,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_2,
|
|
||||||
correction_seed)));
|
|
||||||
vec_3 = hn::Xor(vec_3,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_3,
|
|
||||||
correction_seed)));
|
|
||||||
} else { // correction_words_per_level == num_seeds.
|
|
||||||
const uint8_t* correction_seeds_ptr = reinterpret_cast<const uint8_t*>(
|
|
||||||
correction_seeds + j * correction_words_per_level);
|
|
||||||
hn::Vec<decltype(d64)> correction_seed_0, correction_seed_1,
|
|
||||||
correction_seed_2, correction_seed_3;
|
|
||||||
if (ABSL_PREDICT_TRUE(
|
|
||||||
correction_words_per_level % blocks_per_vec == 0 || j == 0)) {
|
|
||||||
correction_seed_0 =
|
|
||||||
hn::BitCast(d64, hn::Load(d8, correction_seeds_ptr + i));
|
|
||||||
correction_seed_1 = hn::BitCast(
|
|
||||||
d64, hn::Load(d8, correction_seeds_ptr + i + 1 * bytes_per_vec));
|
|
||||||
correction_seed_2 = hn::BitCast(
|
|
||||||
d64, hn::Load(d8, correction_seeds_ptr + i + 2 * bytes_per_vec));
|
|
||||||
correction_seed_3 = hn::BitCast(
|
|
||||||
d64, hn::Load(d8, correction_seeds_ptr + i + 3 * bytes_per_vec));
|
|
||||||
} else {
|
|
||||||
correction_seed_0 =
|
|
||||||
hn::BitCast(d64, hn::LoadU(d8, correction_seeds_ptr + i));
|
|
||||||
correction_seed_1 = hn::BitCast(
|
|
||||||
d64, hn::LoadU(d8, correction_seeds_ptr + i + 1 * bytes_per_vec));
|
|
||||||
correction_seed_2 = hn::BitCast(
|
|
||||||
d64, hn::LoadU(d8, correction_seeds_ptr + i + 2 * bytes_per_vec));
|
|
||||||
correction_seed_3 = hn::BitCast(
|
|
||||||
d64, hn::LoadU(d8, correction_seeds_ptr + i + 3 * bytes_per_vec));
|
|
||||||
}
|
|
||||||
vec_0 = hn::Xor(vec_0,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_0,
|
|
||||||
correction_seed_0)));
|
|
||||||
vec_1 = hn::Xor(vec_1,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_1,
|
|
||||||
correction_seed_1)));
|
|
||||||
vec_2 = hn::Xor(vec_2,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_2,
|
|
||||||
correction_seed_2)));
|
|
||||||
vec_3 = hn::Xor(vec_3,
|
|
||||||
hn::BitCast(d8, hn::IfThenElseZero(control_mask_3,
|
|
||||||
correction_seed_3)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract control bit for next level.
|
|
||||||
const auto next_control_mask_0 = IsBitSet(d8, vec_0, 0);
|
|
||||||
const auto next_control_mask_1 = IsBitSet(d8, vec_1, 0);
|
|
||||||
const auto next_control_mask_2 = IsBitSet(d8, vec_2, 0);
|
|
||||||
const auto next_control_mask_3 = IsBitSet(d8, vec_3, 0);
|
|
||||||
vec_0 = hn::And(vec_0, clear_lowest_bit);
|
|
||||||
vec_1 = hn::And(vec_1, clear_lowest_bit);
|
|
||||||
vec_2 = hn::And(vec_2, clear_lowest_bit);
|
|
||||||
vec_3 = hn::And(vec_3, clear_lowest_bit);
|
|
||||||
|
|
||||||
// Perform control bit correction.
|
|
||||||
auto correction_control_mask_0 = mask_all_zero,
|
|
||||||
correction_control_mask_1 = mask_all_zero,
|
|
||||||
correction_control_mask_2 = mask_all_zero,
|
|
||||||
correction_control_mask_3 = mask_all_zero;
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
const auto correction_control_mask_left =
|
|
||||||
correction_controls_left[j] ? mask_all_one : mask_all_zero;
|
|
||||||
const auto correction_control_mask_right =
|
|
||||||
correction_controls_right[j] ? mask_all_one : mask_all_zero;
|
|
||||||
correction_control_mask_0 =
|
|
||||||
IfThenElseMask(path_mask_0, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
correction_control_mask_1 =
|
|
||||||
IfThenElseMask(path_mask_1, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
correction_control_mask_2 =
|
|
||||||
IfThenElseMask(path_mask_2, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
correction_control_mask_3 =
|
|
||||||
IfThenElseMask(path_mask_3, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
} else { // correction_words_per_level == num_seeds.
|
|
||||||
const bool* correction_controls_left_j =
|
|
||||||
correction_controls_left + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
const bool* correction_controls_right_j =
|
|
||||||
correction_controls_right + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
correction_control_mask_0 = IfThenElseMask(
|
|
||||||
path_mask_0, MaskFromBools(d64, correction_controls_right_j),
|
|
||||||
MaskFromBools(d64, correction_controls_left_j));
|
|
||||||
correction_control_mask_1 = IfThenElseMask(
|
|
||||||
path_mask_1,
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_right_j + 1 * blocks_per_vec),
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_left_j + 1 * blocks_per_vec));
|
|
||||||
correction_control_mask_2 = IfThenElseMask(
|
|
||||||
path_mask_2,
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_right_j + 2 * blocks_per_vec),
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_left_j + 2 * blocks_per_vec));
|
|
||||||
correction_control_mask_3 = IfThenElseMask(
|
|
||||||
path_mask_3,
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_right_j + 3 * blocks_per_vec),
|
|
||||||
MaskFromBools(d64,
|
|
||||||
correction_controls_left_j + 3 * blocks_per_vec));
|
|
||||||
}
|
|
||||||
|
|
||||||
control_mask_0 =
|
|
||||||
hn::Xor(next_control_mask_0,
|
|
||||||
(hn::And(control_mask_0, correction_control_mask_0)));
|
|
||||||
control_mask_1 =
|
|
||||||
hn::Xor(next_control_mask_1,
|
|
||||||
(hn::And(control_mask_1, correction_control_mask_1)));
|
|
||||||
control_mask_2 =
|
|
||||||
hn::Xor(next_control_mask_2,
|
|
||||||
(hn::And(control_mask_2, correction_control_mask_2)));
|
|
||||||
control_mask_3 =
|
|
||||||
hn::Xor(next_control_mask_3,
|
|
||||||
(hn::And(control_mask_3, correction_control_mask_3)));
|
|
||||||
}
|
|
||||||
// Write the evaluated outputs to memory.
|
|
||||||
hn::Store(vec_0, d8, seeds_out_ptr + i);
|
|
||||||
hn::Store(vec_1, d8, seeds_out_ptr + i + 1 * bytes_per_vec);
|
|
||||||
hn::Store(vec_2, d8, seeds_out_ptr + i + 2 * bytes_per_vec);
|
|
||||||
hn::Store(vec_3, d8, seeds_out_ptr + i + 3 * bytes_per_vec);
|
|
||||||
BoolsFromMask(d64, control_mask_0, control_bits_out + start_block);
|
|
||||||
BoolsFromMask(d64, control_mask_1,
|
|
||||||
control_bits_out + start_block + 1 * blocks_per_vec);
|
|
||||||
BoolsFromMask(d64, control_mask_2,
|
|
||||||
control_bits_out + start_block + 2 * blocks_per_vec);
|
|
||||||
BoolsFromMask(d64, control_mask_3,
|
|
||||||
control_bits_out + start_block + 3 * blocks_per_vec);
|
|
||||||
}
|
|
||||||
ABSL_DCHECK_GT(i + 4 * bytes_per_vec, num_bytes);
|
|
||||||
|
|
||||||
// Single full vectors.
|
|
||||||
for (; i + bytes_per_vec <= num_bytes; i += bytes_per_vec) {
|
|
||||||
const int64_t start_block = i / sizeof(absl::uint128);
|
|
||||||
auto vec = hn::Load(d8, seeds_in_ptr + i);
|
|
||||||
const auto path = hn::Load(d8, paths_ptr + i);
|
|
||||||
auto control_mask = MaskFromBools(d64, control_bits_in + start_block);
|
|
||||||
for (int j = 0; j < num_levels; ++j) {
|
|
||||||
const int bit_index = num_levels - j - 1 + paths_rightshift;
|
|
||||||
const auto path_mask = IsBitSet(d8, path, bit_index);
|
|
||||||
HashOneWithKeyMask(
|
|
||||||
d8, vec, path_mask,
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec);
|
|
||||||
|
|
||||||
// Apply correction.
|
|
||||||
hn::Vec<decltype(d64)> correction_seed;
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
correction_seed = hn::LoadDup128(
|
|
||||||
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
|
|
||||||
} else {
|
|
||||||
const uint64_t* correction_seeds_ptr =
|
|
||||||
reinterpret_cast<const uint64_t*>(correction_seeds +
|
|
||||||
j * correction_words_per_level +
|
|
||||||
start_block);
|
|
||||||
if (ABSL_PREDICT_TRUE(
|
|
||||||
correction_words_per_level % blocks_per_vec == 0 || j == 0)) {
|
|
||||||
correction_seed = hn::Load(d64, correction_seeds_ptr);
|
|
||||||
} else {
|
|
||||||
correction_seed = hn::LoadU(d64, correction_seeds_ptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask,
|
|
||||||
correction_seed)));
|
|
||||||
|
|
||||||
// Extract control bit for next level.
|
|
||||||
const auto next_control_mask = IsBitSet(d8, vec, 0);
|
|
||||||
vec = hn::And(vec, clear_lowest_bit);
|
|
||||||
|
|
||||||
// Perform control bit correction.
|
|
||||||
auto correction_control_mask = mask_all_zero;
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
const auto correction_control_mask_left =
|
|
||||||
correction_controls_left[j] ? mask_all_one : mask_all_zero;
|
|
||||||
const auto correction_control_mask_right =
|
|
||||||
correction_controls_right[j] ? mask_all_one : mask_all_zero;
|
|
||||||
correction_control_mask =
|
|
||||||
IfThenElseMask(path_mask, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
} else {
|
|
||||||
const bool* correction_controls_left_j =
|
|
||||||
correction_controls_left + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
const bool* correction_controls_right_j =
|
|
||||||
correction_controls_right + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
correction_control_mask = IfThenElseMask(
|
|
||||||
path_mask, MaskFromBools(d64, correction_controls_right_j),
|
|
||||||
MaskFromBools(d64, correction_controls_left_j));
|
|
||||||
}
|
|
||||||
control_mask = hn::Xor(next_control_mask,
|
|
||||||
(hn::And(control_mask, correction_control_mask)));
|
|
||||||
}
|
|
||||||
hn::Store(vec, d8, seeds_out_ptr + i);
|
|
||||||
BoolsFromMask(d64, control_mask, control_bits_out + start_block);
|
|
||||||
}
|
|
||||||
ABSL_DCHECK_GT(i + bytes_per_vec, num_bytes);
|
|
||||||
|
|
||||||
// Elements less than a full vector.
|
|
||||||
int remaining_blocks = num_seeds - i / sizeof(absl::uint128);
|
|
||||||
if (remaining_blocks > 0) {
|
|
||||||
const int64_t start_block = i / sizeof(absl::uint128);
|
|
||||||
const int remaining_bytes = num_bytes - i;
|
|
||||||
// Copy to a buffer first, to ensure we have at least bytes_per_vec bytes
|
|
||||||
// to read. Calling MaskedLoad directly instead might lead to out-of-bounds
|
|
||||||
// accesses.
|
|
||||||
auto buffer = hwy::AllocateAligned<absl::uint128>(2 * blocks_per_vec);
|
|
||||||
if (buffer == nullptr) {
|
|
||||||
return absl::ResourceExhaustedError("Memory allocation error");
|
|
||||||
}
|
|
||||||
auto buffer_ptr = reinterpret_cast<uint8_t*>(buffer.get());
|
|
||||||
std::copy_n(seeds_in + start_block, remaining_blocks, buffer.get());
|
|
||||||
std::copy_n(paths + start_block, remaining_blocks,
|
|
||||||
buffer.get() + blocks_per_vec);
|
|
||||||
const auto load_mask = hn::FirstN(d8, remaining_bytes);
|
|
||||||
auto vec = hn::MaskedLoad(load_mask, d8, buffer_ptr);
|
|
||||||
const auto path = hn::MaskedLoad(load_mask, d8, buffer_ptr + bytes_per_vec);
|
|
||||||
auto control_mask =
|
|
||||||
MaskFromBools(d64, control_bits_in + start_block, remaining_blocks);
|
|
||||||
for (int j = 0; j < num_levels; ++j) {
|
|
||||||
const int bit_index = num_levels - j - 1 + paths_rightshift;
|
|
||||||
const auto path_mask = IsBitSet(d8, path, bit_index);
|
|
||||||
HashOneWithKeyMask(
|
|
||||||
d8, vec, path_mask,
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
|
|
||||||
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec);
|
|
||||||
|
|
||||||
// Perform seed correction.
|
|
||||||
hn::Vec<decltype(d64)> correction_seed;
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
correction_seed = hn::LoadDup128(
|
|
||||||
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
|
|
||||||
} else {
|
|
||||||
std::copy_n(
|
|
||||||
correction_seeds + j * correction_words_per_level + start_block,
|
|
||||||
remaining_blocks, buffer.get());
|
|
||||||
correction_seed =
|
|
||||||
hn::BitCast(d64, hn::MaskedLoad(load_mask, d8, buffer_ptr));
|
|
||||||
}
|
|
||||||
vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask,
|
|
||||||
correction_seed)));
|
|
||||||
const auto next_control_mask = IsBitSet(d8, vec, 0);
|
|
||||||
vec = hn::And(vec, clear_lowest_bit);
|
|
||||||
|
|
||||||
// Perform control bit correction.
|
|
||||||
auto correction_control_mask = mask_all_zero;
|
|
||||||
if (correction_words_per_level == 1) {
|
|
||||||
const auto correction_control_mask_left =
|
|
||||||
correction_controls_left[j] ? mask_all_one : mask_all_zero;
|
|
||||||
const auto correction_control_mask_right =
|
|
||||||
correction_controls_right[j] ? mask_all_one : mask_all_zero;
|
|
||||||
correction_control_mask =
|
|
||||||
IfThenElseMask(path_mask, correction_control_mask_right,
|
|
||||||
correction_control_mask_left);
|
|
||||||
} else {
|
|
||||||
const bool* correction_controls_left_j =
|
|
||||||
correction_controls_left + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
const bool* correction_controls_right_j =
|
|
||||||
correction_controls_right + j * correction_words_per_level +
|
|
||||||
start_block;
|
|
||||||
correction_control_mask = IfThenElseMask(
|
|
||||||
path_mask,
|
|
||||||
MaskFromBools(d64, correction_controls_right_j, remaining_blocks),
|
|
||||||
MaskFromBools(d64, correction_controls_left_j, remaining_blocks));
|
|
||||||
}
|
|
||||||
control_mask = hn::Xor(next_control_mask,
|
|
||||||
(hn::And(control_mask, correction_control_mask)));
|
|
||||||
}
|
|
||||||
// Store back into buffer, then copy to seeds_out.
|
|
||||||
hn::Store(vec, d8, buffer_ptr);
|
|
||||||
std::copy_n(buffer.get(), remaining_blocks, seeds_out + start_block);
|
|
||||||
BoolsFromMask(d64, control_mask, control_bits_out + start_block,
|
|
||||||
remaining_blocks);
|
|
||||||
}
|
|
||||||
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HWY_TARGET == HWY_SCALAR
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
HWY_AFTER_NAMESPACE();
|
|
||||||
|
|
||||||
#if HWY_ONCE || HWY_IDE
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// Scalar (non-SIMD) implementation of DPF seed evaluation. It processes the
// seeds in batches of at most `Aes128FixedKeyHash::kBatchSize`.
//
// For each level, it expands every seed with both `prg_left` and `prg_right`
// and keeps the expansion selected by the seed's path bit. If the seed's
// control bit is set, the selected expansion is XOR-corrected with
// `correction_seeds`. The lowest bit of the result becomes the next control
// bit, corrected with `correction_controls_left` / `correction_controls_right`.
//
// Does not validate `num_correction_words`. More correction words than levels
// is interpreted as one correction word per (level, seed) pair.
//
// Returns the first error reported by the PRG evaluation, OK otherwise.
absl::Status EvaluateSeedsNoHwy(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  using BitVector =
      absl::InlinedVector<bool,
                          std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;
  constexpr int64_t max_batch_size = Aes128FixedKeyHash::kBatchSize;

  // With more correction words than levels, there is one correction word per
  // (level, seed) pair, stored level-major at `level * num_seeds + seed`.
  // Otherwise a single correction word per level is shared by all seeds.
  const bool has_per_seed_correction_words = num_correction_words > num_levels;

  // Allocate buffers once; they are reused for every batch and level.
  std::vector<absl::uint128> buffer_left(max_batch_size);
  std::vector<absl::uint128> buffer_right(max_batch_size);
  BitVector path_bits(max_batch_size), control_bits(max_batch_size);

  // Perform DPF evaluation in blocks.
  for (int64_t start_block = 0; start_block < num_seeds;
       start_block += max_batch_size) {
    const int64_t current_batch_size =
        std::min<int64_t>(num_seeds - start_block, max_batch_size);

    for (int level = 0; level < num_levels; ++level) {
      // Evaluate PRG. We evaluate both left and right expansions, but only use
      // one of them (depending on path_bits). This seems to be faster than
      // first sorting the seeds by path_bits and then expanding.
      absl::Span<const absl::uint128> seeds =
          absl::MakeConstSpan((level == 0 ? seeds_in : seeds_out) + start_block,
                              current_batch_size);
      DPF_RETURN_IF_ERROR(prg_left.Evaluate(
          seeds, absl::MakeSpan(buffer_left).subspan(0, current_batch_size)));
      DPF_RETURN_IF_ERROR(prg_right.Evaluate(
          seeds, absl::MakeSpan(buffer_right).subspan(0, current_batch_size)));

      // Merge back into result: pick the left or right expansion per seed.
      const int bit_index = num_levels - level - 1 + paths_rightshift;
      for (int64_t i = 0; i < current_batch_size; ++i) {
        // Bit positions >= 128 are outside the path and read as 0. The bound
        // check also avoids an out-of-range shift.
        path_bits[i] =
            bit_index < 128 &&
            ((paths[start_block + i]) & (absl::uint128{1} << bit_index)) != 0;
        seeds_out[start_block + i] =
            path_bits[i] ? buffer_right[i] : buffer_left[i];
      }

      // Compute correction. Copying the previous control bits into a local
      // buffer (instead of updating `control_bits_out` in place) improves
      // pipelining. Do benchmarks before removing this.
      std::copy_n(
          &(level == 0 ? control_bits_in : control_bits_out)[start_block],
          current_batch_size, control_bits.data());
      for (int64_t i = 0; i < current_batch_size; ++i) {
        // Use int64_t: `level * num_seeds + ...` can exceed the range of int.
        const int64_t correction_index =
            has_per_seed_correction_words ? level * num_seeds + start_block + i
                                          : level;
        if (control_bits[i]) {
          seeds_out[start_block + i] ^= correction_seeds[correction_index];
        }
        bool current_control_bit =
            ExtractAndClearLowestBit(seeds_out[start_block + i]);
        if (control_bits[i]) {
          current_control_bit ^=
              path_bits[i] ? correction_controls_right[correction_index]
                           : correction_controls_left[correction_index];
        }
        control_bits_out[start_block + i] = current_control_bit;
      }
    }
  }

  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
HWY_EXPORT(EvaluateSeedsHwy);
|
|
||||||
|
|
||||||
// Validates the correction-word layout, then forwards all arguments to the
// Highway implementation.
absl::Status EvaluateSeeds(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  // Only two layouts are accepted: a single correction word shared by all
  // seeds on each level, or one correction word per seed on each level.
  const bool shared_per_level = num_correction_words == num_levels;
  const bool one_per_seed_per_level =
      num_correction_words == num_levels * num_seeds;
  if (!shared_per_level && !one_per_seed_per_level) {
    return absl::InvalidArgumentError(
        "`num_correction_words` must be equal to `num_levels` or `num_levels * "
        "num_seeds`");
  }

  // Forward to the EvaluateSeedsHwy instantiation selected by Highway's
  // runtime dispatch (exported via HWY_EXPORT).
  return HWY_DYNAMIC_DISPATCH(EvaluateSeedsHwy)(
      num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in,
      paths, paths_rightshift, correction_seeds, correction_controls_left,
      correction_controls_right, prg_left, prg_right, seeds_out,
      control_bits_out);
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
#endif
|
|
@ -1,92 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
using distributed_point_functions::Aes128FixedKeyHash;
|
|
||||||
|
|
||||||
// Extracts the lowest bit of `x` and sets it to 0 in `x`.
|
|
||||||
inline bool ExtractAndClearLowestBit(absl::uint128& x) {
|
|
||||||
bool bit = ((x & absl::uint128{1}) != 0);
|
|
||||||
x &= ~absl::uint128{1};
|
|
||||||
return bit;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Performs DPF evaluation of the seeds given in `seeds_in` using `prg_left` or
|
|
||||||
// `prg_right, and the given `control_bits_in`, and correction words given by
|
|
||||||
// `correction_seeds`, `correction_controls_left`, and
|
|
||||||
// `correction_controls_right`. At each level `l < num_level`, the evaluation
|
|
||||||
// for the i-th seed continues along the left or right path depending on the
|
|
||||||
// l-th most significant bit among the lowest `num_levels` bits of `paths[i]`,
|
|
||||||
// after right-shifting each `paths[i]` by `paths_rightshift`.
|
|
||||||
//
|
|
||||||
// This function takes raw pointers instead of absl::Span for performance
|
|
||||||
// reasons. No bounds checks are performed, so it is the caller's responsibility
|
|
||||||
// to ensure that
|
|
||||||
// - `seeds_in`, `control_bits_in`, `seeds_out`, and `control_bits_out` have at
|
|
||||||
// least `num_seeds` elements, and
|
|
||||||
// - `correction_seeds`, `correction_controls_left`, and
|
|
||||||
// `correction_controls_right` have at least `num_levels` elements.
|
|
||||||
//
|
|
||||||
// If the inputs are aligned (e.g. using HWY_ALIGN, or hwy::AllocateAligned),
|
|
||||||
// and if SIMD operations are supported, then the evaluation will be done using
|
|
||||||
// SIMD operations. Otherwise, falls back to `EvaluateSeedsNoHwy`, which is at
|
|
||||||
// least 2x slower.
|
|
||||||
//
|
|
||||||
// `num_correction_words` can either be equal to `num_levels`, or equal to
|
|
||||||
// `num_seeds * num_levels`. In the first case, the same correction word is used
|
|
||||||
// for every seed at a given level. In the second case, correction word at index
|
|
||||||
// `i * num_seeds + j` is used to correct seed `i` at level `j`.
|
|
||||||
// If `num_correction_words == num_seeds * num_levels`, then `num_seeds` should
|
|
||||||
// be smaller than or divisible by the size of a SIMD vector for optimal
|
|
||||||
// performance.
|
|
||||||
//
|
|
||||||
// Returns OK on success, INVALID_ARGUMENT in case num_correction_words is not
|
|
||||||
// equal to `num_levels` or `num_seeds * num_levels`, and INTERNAL in case of
|
|
||||||
// OpenSSL errors.
|
|
||||||
absl::Status EvaluateSeeds(
|
|
||||||
int64_t num_seeds, int num_levels, int num_correction_words,
|
|
||||||
const absl::uint128* seeds_in, const bool* control_bits_in,
|
|
||||||
const absl::uint128* paths, int paths_rightshift,
|
|
||||||
const absl::uint128* correction_seeds, const bool* correction_controls_left,
|
|
||||||
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
|
|
||||||
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
|
|
||||||
bool* control_bits_out);
|
|
||||||
|
|
||||||
// As `EvaluateSeeds`, but does not require any SIMD support.
|
|
||||||
absl::Status EvaluateSeedsNoHwy(
|
|
||||||
int64_t num_seeds, int num_levels, int num_correction_words,
|
|
||||||
const absl::uint128* seeds_in, const bool* control_bits_in,
|
|
||||||
const absl::uint128* paths, int paths_rightshift,
|
|
||||||
const absl::uint128* correction_seeds, const bool* correction_controls_left,
|
|
||||||
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
|
|
||||||
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
|
|
||||||
bool* control_bits_out);
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_
|
|
@ -1,257 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/evaluate_prg_hwy.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "dpf/aes_128_fixed_key_hash.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
// Object-like macro: must expand to a bare value. A trailing `;` would be
// pasted into every expansion (e.g. breaking `#if HWY_IS_TEST`).
#define HWY_IS_TEST 1
|
|
||||||
#undef HWY_TARGET_INCLUDE
|
|
||||||
#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy_test.cc" // NOLINT
|
|
||||||
#include "hwy/foreach_target.h"
|
|
||||||
// clang-format on
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace HWY_NAMESPACE {
|
|
||||||
|
|
||||||
using ::testing::HasSubstr;
|
|
||||||
|
|
||||||
// Keys passed to `Aes128FixedKeyHash::Create` for the left (kKey0) and right
// (kKey1) PRGs in the tests below.
constexpr absl::uint128 kKey0 = absl::MakeUint128(0, 0);
constexpr absl::uint128 kKey1 =
    absl::MakeUint128(0x1111111111111111, 0x1111111111111111);
|
|
||||||
|
|
||||||
void TestOutputMatchesNoHwyVersion(int num_seeds, int num_levels,
|
|
||||||
int num_correction_words,
|
|
||||||
int paths_rightshift) {
|
|
||||||
// Generate seeds.
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> control_bits_in;
|
|
||||||
if (num_seeds > 0) {
|
|
||||||
seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(seeds_in, nullptr);
|
|
||||||
paths = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(paths, nullptr);
|
|
||||||
control_bits_in = hwy::AllocateAligned<bool>(num_seeds);
|
|
||||||
ASSERT_NE(control_bits_in, nullptr);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_seeds; ++i) {
|
|
||||||
// All of these are arbitrary.
|
|
||||||
seeds_in[i] = absl::MakeUint128(i, i + 1);
|
|
||||||
paths[i] = absl::MakeUint128(23 * i + 42, 42 * i + 23);
|
|
||||||
control_bits_in[i] = (i % 7 == 0);
|
|
||||||
}
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out;
|
|
||||||
if (num_seeds > 0) {
|
|
||||||
seeds_out = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(seeds_out, nullptr);
|
|
||||||
control_bits_out = hwy::AllocateAligned<bool>(num_seeds);
|
|
||||||
ASSERT_NE(control_bits_out, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate correction words.
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left,
|
|
||||||
correction_controls_right;
|
|
||||||
if (num_correction_words > 0) {
|
|
||||||
correction_seeds =
|
|
||||||
hwy::AllocateAligned<absl::uint128>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_seeds, nullptr);
|
|
||||||
correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_controls_left, nullptr);
|
|
||||||
correction_controls_right =
|
|
||||||
hwy::AllocateAligned<bool>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_controls_right, nullptr);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_correction_words; ++i) {
|
|
||||||
correction_seeds[i] = absl::MakeUint128(i + 1, i);
|
|
||||||
correction_controls_left[i] = (i % 23 == 0);
|
|
||||||
correction_controls_right[i] = (i % 42 != 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up PRGs.
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto prg_left,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto prg_right,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
|
|
||||||
// Evaluate with Highway enabled.
|
|
||||||
DPF_ASSERT_OK(
|
|
||||||
EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(),
|
|
||||||
control_bits_in.get(), paths.get(), paths_rightshift,
|
|
||||||
correction_seeds.get(), correction_controls_left.get(),
|
|
||||||
correction_controls_right.get(), prg_left, prg_right,
|
|
||||||
seeds_out.get(), control_bits_out.get()));
|
|
||||||
|
|
||||||
// Evaluate without highway.
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted;
|
|
||||||
if (num_seeds > 0) {
|
|
||||||
seeds_out_wanted = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(seeds_out_wanted, nullptr);
|
|
||||||
control_bits_out_wanted = hwy::AllocateAligned<bool>(num_seeds);
|
|
||||||
ASSERT_NE(control_bits_out_wanted, nullptr);
|
|
||||||
}
|
|
||||||
DPF_ASSERT_OK(EvaluateSeedsNoHwy(
|
|
||||||
num_seeds, num_levels, num_correction_words, seeds_in.get(),
|
|
||||||
control_bits_in.get(), paths.get(), paths_rightshift,
|
|
||||||
correction_seeds.get(), correction_controls_left.get(),
|
|
||||||
correction_controls_right.get(), prg_left, prg_right,
|
|
||||||
seeds_out_wanted.get(), control_bits_out_wanted.get()));
|
|
||||||
|
|
||||||
// Check that both evaluations are equal, if there was anything to evaluate.
|
|
||||||
if (num_levels > 0) {
|
|
||||||
for (int i = 0; i < num_seeds; ++i) {
|
|
||||||
EXPECT_EQ(seeds_out[i], seeds_out_wanted[i]);
|
|
||||||
EXPECT_EQ(control_bits_out[i], control_bits_out_wanted[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Evaluate without paths_rightshift
|
|
||||||
if (paths_rightshift != 0) {
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> paths_in2;
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted2;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted2;
|
|
||||||
if (num_seeds > 0) {
|
|
||||||
paths_in2 = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(paths_in2, nullptr);
|
|
||||||
seeds_out_wanted2 = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(seeds_out_wanted2, nullptr);
|
|
||||||
control_bits_out_wanted2 = hwy::AllocateAligned<bool>(num_seeds);
|
|
||||||
ASSERT_NE(control_bits_out_wanted2, nullptr);
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_seeds; ++i) {
|
|
||||||
paths_in2[i] = 0;
|
|
||||||
if (paths_rightshift < 128) {
|
|
||||||
paths_in2[i] = paths[i] >> paths_rightshift;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
DPF_ASSERT_OK(EvaluateSeedsNoHwy(
|
|
||||||
num_seeds, num_levels, num_correction_words, seeds_in.get(),
|
|
||||||
control_bits_in.get(), paths_in2.get(), 0, correction_seeds.get(),
|
|
||||||
correction_controls_left.get(), correction_controls_right.get(),
|
|
||||||
prg_left, prg_right, seeds_out_wanted2.get(),
|
|
||||||
control_bits_out_wanted2.get()));
|
|
||||||
// Check that both evaluations are equal, if there was anything to evaluate.
|
|
||||||
if (num_levels > 0) {
|
|
||||||
for (int i = 0; i < num_seeds; ++i) {
|
|
||||||
EXPECT_EQ(seeds_out[i], seeds_out_wanted2[i]);
|
|
||||||
EXPECT_EQ(control_bits_out[i], control_bits_out_wanted2[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAll() {
|
|
||||||
for (int num_seeds : {0, 1, 2, 101, 128, 1000}) {
|
|
||||||
for (int num_levels : {0, 1, 2, 32, 63, 64, 127, 128}) {
|
|
||||||
for (int num_correction_words : {num_levels, num_levels * num_seeds}) {
|
|
||||||
TestOutputMatchesNoHwyVersion(num_seeds, num_levels,
|
|
||||||
num_correction_words, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestPathsRightshift() {
|
|
||||||
constexpr int num_levels = 128;
|
|
||||||
for (int num_seeds : {0, 1, 101}) {
|
|
||||||
for (int paths_rightshift = 0; paths_rightshift <= 128;
|
|
||||||
++paths_rightshift) {
|
|
||||||
TestOutputMatchesNoHwyVersion(num_seeds, num_levels, num_levels,
|
|
||||||
paths_rightshift);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void FailsIfNumCorrectionWordsIsWrong() {
|
|
||||||
constexpr int num_seeds = 1000;
|
|
||||||
constexpr int num_levels = 10;
|
|
||||||
constexpr int num_correction_words = 12;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> control_bits_in;
|
|
||||||
seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(seeds_in, nullptr);
|
|
||||||
paths = hwy::AllocateAligned<absl::uint128>(num_seeds);
|
|
||||||
ASSERT_NE(paths, nullptr);
|
|
||||||
control_bits_in = hwy::AllocateAligned<bool>(num_seeds);
|
|
||||||
ASSERT_NE(control_bits_in, nullptr);
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds;
|
|
||||||
hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left,
|
|
||||||
correction_controls_right;
|
|
||||||
correction_seeds = hwy::AllocateAligned<absl::uint128>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_seeds, nullptr);
|
|
||||||
correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_controls_left, nullptr);
|
|
||||||
correction_controls_right = hwy::AllocateAligned<bool>(num_correction_words);
|
|
||||||
ASSERT_NE(correction_controls_right, nullptr);
|
|
||||||
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto prg_left,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto prg_right,
|
|
||||||
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(),
|
|
||||||
control_bits_in.get(), paths.get(), 0,
|
|
||||||
correction_seeds.get(), correction_controls_left.get(),
|
|
||||||
correction_controls_right.get(), prg_left, prg_right,
|
|
||||||
seeds_in.get(), control_bits_in.get()),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
HasSubstr("num_correction_words")));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
HWY_AFTER_NAMESPACE();
|
|
||||||
|
|
||||||
#if HWY_ONCE
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
HWY_BEFORE_TEST(EvaluatePrgHwyTest);
|
|
||||||
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestAll);
|
|
||||||
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestPathsRightshift);
|
|
||||||
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, FailsIfNumCorrectionWordsIsWrong);
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,48 +0,0 @@
|
|||||||
// Copyright 2022 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/get_hwy_mode.h"
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
#undef HWY_TARGET_INCLUDE
|
|
||||||
#define HWY_TARGET_INCLUDE "dpf/internal/get_hwy_mode.cc"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "hwy/foreach_target.h"
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {

// Per-target implementation. Because this file sets HWY_TARGET_INCLUDE to
// itself and includes hwy/foreach_target.h, this function is compiled once
// per Highway target. Each copy returns the name of the target it was
// compiled for.
const absl::string_view GetHwyModeAsString() {
  return hwy::TargetName(HWY_TARGET);
}

} // namespace HWY_NAMESPACE

#if HWY_ONCE || HWY_IDE

// Emitted only once: collects the per-target versions above into a dispatch
// table.
HWY_EXPORT(GetHwyModeAsString);

// Public entry point declared in get_hwy_mode.h. Calls the per-target version
// chosen by Highway's dynamic dispatch, so the returned name reflects the
// code path actually selected at runtime.
const absl::string_view GetHwyModeAsString() {
  return HWY_DYNAMIC_DISPATCH(GetHwyModeAsString)();
}

#endif

} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();
|
|
@ -1,32 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2022 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_
|
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
namespace dpf_internal {

// Utility function for printing the mode (the Highway target) selected by
// Highway at runtime. Used for debugging.
//
// NOTE(review): the `const` on the by-value return type has no effect for
// callers. It must stay in sync with the definition in get_hwy_mode.cc.
const absl::string_view GetHwyModeAsString();

} // namespace dpf_internal
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_
|
|
@ -1,92 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2023 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_
|
|
||||||
|
|
||||||
// A class that can serve the purpose of both absl::Span<T> and absl::Span<T*>
|
|
||||||
// at the same time. Introduces the run-time overhead of a std::variant check.
|
|
||||||
//
|
|
||||||
// Note that this class DOES NOT provide common container features, such as
|
|
||||||
// iterators. It is not intended to be used by users of this library. Any
|
|
||||||
// function that takes a MaybeDerefSpan<T> should be called with either an
|
|
||||||
// absl::Span<T> or an absl::Span<T*>.
|
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#include "absl/meta/type_traits.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "absl/types/variant.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
template <typename T>
class MaybeDerefSpan {
 private:
  // Evaluates to U only if this span's element type T is const-qualified.
  // The condition deliberately tests T rather than U. U only makes the alias
  // dependent on the member template's own parameter, so that a non-const T
  // yields a substitution failure (removing the overload) instead of a hard
  // error.
  template <typename U>
  using EnableIfValueIsConst =
      typename absl::enable_if_t<std::is_const<T>::value, U>;

  // Evaluates to U only if U is convertible to at least one of the two
  // underlying span types, absl::Span<T> or absl::Span<T* const>.
  template <typename U>
  using EnableIfValueIsConvertibleToSpan = typename absl::enable_if_t<
      absl::disjunction<std::is_convertible<U, absl::Span<T>>,
                        std::is_convertible<U, absl::Span<T* const>>>::value,
      U>;

 public:
  // Implicit constructors from the underlying absl::Span.
  MaybeDerefSpan(absl::Span<T> span)  // NOLINT(runtime/explicit)
      : span_(span) {}
  MaybeDerefSpan(absl::Span<T* const> span)  // NOLINT(runtime/explicit)
      : span_(span) {}

  // Implicit constructor of a const MaybeDerefSpan from a non-const one.
  // Only available when T is const. The alternative held by `other` is
  // converted with absl::ConvertVariantTo.
  template <typename T2 = T, typename = EnableIfValueIsConst<T2>>
  MaybeDerefSpan(  // NOLINT(runtime/explicit)
      const MaybeDerefSpan<typename std::remove_const<T>::type>& other)
      : span_(absl::ConvertVariantTo<decltype(span_)>(other.span_)) {
  }

  // Implicit constructor of a const MaybeDerefSpan from anything that is
  // convertible to one of the underlying spans (e.g. a std::vector<T> or a
  // std::vector<T*>). Only available when T is const, so containers convert
  // implicitly only to const MaybeDerefSpans.
  template <typename V, typename = EnableIfValueIsConst<V>,
            typename = EnableIfValueIsConvertibleToSpan<V>>
  MaybeDerefSpan(const V& span)  // NOLINT(runtime/explicit)
      : span_(absl::MakeConstSpan(span)) {}

  // Returns the element at `index`. If this holds a span of pointers, the
  // pointer at `index` is dereferenced first. Performs no bounds checking of
  // its own.
  inline constexpr T& operator[](size_t index) const {
    if (absl::holds_alternative<absl::Span<T* const>>(span_)) {
      return *absl::get<absl::Span<T* const>>(span_)[index];
    }
    return absl::get<absl::Span<T>>(span_)[index];
  }

  // Returns the number of entries (elements or pointers) in the held span.
  inline constexpr size_t size() const {
    return absl::visit([](auto v) { return v.size(); }, span_);
  }

 private:
  // Lets the const-converting constructor read `other.span_`.
  template <typename U>
  friend class MaybeDerefSpan;

  // Either a direct span over the elements or a span of pointers to them.
  absl::variant<absl::Span<T>, absl::Span<T* const>> span_;
};
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_
|
|
@ -1,195 +0,0 @@
|
|||||||
// Copyright 2023 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/maybe_deref_span.h"
|
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using T = int;
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestExplicitMutableDirectSpan) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
absl::Span<T> span(x);
|
|
||||||
MaybeDerefSpan<T> span2(span);
|
|
||||||
|
|
||||||
EXPECT_EQ(span2.size(), x.size());
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(span2[1], x[1]);
|
|
||||||
EXPECT_EQ(&span2[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span2[1], &x[1]);
|
|
||||||
|
|
||||||
span2[0] = 3;
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(x[0], 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestExplicitMutableSpan) {
|
|
||||||
const std::vector<T> x = {1, 2};
|
|
||||||
absl::Span<const T> span(x);
|
|
||||||
MaybeDerefSpan<const T> span2(span);
|
|
||||||
|
|
||||||
EXPECT_EQ(span2.size(), x.size());
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(span2[1], x[1]);
|
|
||||||
EXPECT_EQ(&span2[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span2[1], &x[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestExplicitMutablePointerSpan) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
std::vector<T*> x2 = {&x[0], &x[1]};
|
|
||||||
absl::Span<T*> span(x2);
|
|
||||||
MaybeDerefSpan<T> span2(span);
|
|
||||||
|
|
||||||
EXPECT_EQ(span2.size(), x.size());
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(span2[1], x[1]);
|
|
||||||
EXPECT_EQ(&span2[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span2[1], &x[1]);
|
|
||||||
|
|
||||||
span2[0] = 3;
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(x[0], 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestExplicitMutablePointerConstSpan) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
const std::vector<T*> x2 = {&x[0], &x[1]};
|
|
||||||
absl::Span<T* const> span(x2);
|
|
||||||
MaybeDerefSpan<T> span2(span);
|
|
||||||
|
|
||||||
EXPECT_EQ(span2.size(), x.size());
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(span2[1], x[1]);
|
|
||||||
EXPECT_EQ(&span2[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span2[1], &x[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestExplicitConstPointerConstSpan) {
|
|
||||||
const std::vector<T> x = {1, 2};
|
|
||||||
const std::vector<const T*> x2 = {&x[0], &x[1]};
|
|
||||||
absl::Span<const T* const> span(x2);
|
|
||||||
MaybeDerefSpan<const T> span2(span);
|
|
||||||
|
|
||||||
EXPECT_EQ(span2.size(), x.size());
|
|
||||||
EXPECT_EQ(span2[0], x[0]);
|
|
||||||
EXPECT_EQ(span2[1], x[1]);
|
|
||||||
EXPECT_EQ(&span2[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span2[1], &x[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestMutableSpanToConstSpan) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
absl::Span<T> span(x);
|
|
||||||
MaybeDerefSpan<T> span2(span);
|
|
||||||
MaybeDerefSpan<const T> span3(span2);
|
|
||||||
|
|
||||||
EXPECT_EQ(span3.size(), x.size());
|
|
||||||
EXPECT_EQ(span3[0], x[0]);
|
|
||||||
EXPECT_EQ(span3[1], x[1]);
|
|
||||||
EXPECT_EQ(&span3[0], &x[0]);
|
|
||||||
EXPECT_EQ(&span3[1], &x[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestImplicitConstSpan) {
  // A const vector is accepted directly and the view aliases its elements.
  const std::vector<T> values = {1, 2};
  MaybeDerefSpan<const T> view(values);

  EXPECT_EQ(view.size(), values.size());
  for (std::size_t i = 0; i < 2; ++i) {
    EXPECT_EQ(view[i], values[i]);
    EXPECT_EQ(&view[i], &values[i]);
  }
}

TEST(MaybeDerefSpanTest, TestImplicitPointerConstSpan) {
  // A const vector of pointers is accepted directly. Indexing dereferences
  // the pointers, so the view aliases the pointees in `values`.
  const std::vector<T> values = {1, 2};
  const std::vector<const T*> pointers = {&values[0], &values[1]};
  MaybeDerefSpan<const T> view(pointers);

  EXPECT_EQ(view.size(), values.size());
  for (std::size_t i = 0; i < 2; ++i) {
    EXPECT_EQ(view[i], values[i]);
    EXPECT_EQ(&view[i], &values[i]);
  }
}
|
|
||||||
|
|
||||||
void TestEq(MaybeDerefSpan<const T> span, const std::vector<T>& vector) {
|
|
||||||
EXPECT_EQ(span.size(), vector.size());
|
|
||||||
EXPECT_EQ(span[0], vector[0]);
|
|
||||||
EXPECT_EQ(span[1], vector[1]);
|
|
||||||
EXPECT_EQ(&span[0], &vector[0]);
|
|
||||||
EXPECT_EQ(&span[1], &vector[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestFunctionCallMutableVector) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
|
|
||||||
TestEq(x, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerVector) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
std::vector<T*> x2 = {&x[0], &x[1]};
|
|
||||||
|
|
||||||
TestEq(x2, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestFunctionCallConstVector) {
|
|
||||||
const std::vector<T> x = {1, 2};
|
|
||||||
|
|
||||||
TestEq(x, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerConstVector) {
|
|
||||||
std::vector<T> x = {1, 2};
|
|
||||||
const std::vector<T*> x2 = {&x[0], &x[1]};
|
|
||||||
|
|
||||||
TestEq(x2, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(MaybeDerefSpanTest, TestFunctionCallConstPointerConstVector) {
|
|
||||||
const std::vector<T> x = {1, 2};
|
|
||||||
const std::vector<const T*> x2 = {&x[0], &x[1]};
|
|
||||||
|
|
||||||
TestEq(x2, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Taken from https://en.cppreference.com/w/cpp/types/is_convertible.
|
|
||||||
template <class From, class To>
|
|
||||||
auto test_implicitly_convertible(int)
|
|
||||||
-> decltype(void(std::declval<void (&)(To)>()(std::declval<From>())),
|
|
||||||
std::true_type{});
|
|
||||||
template <class, class>
|
|
||||||
auto test_implicitly_convertible(...) -> std::false_type;
|
|
||||||
|
|
||||||
// Test that vectors are convertible only to const spans.
|
|
||||||
static_assert(
|
|
||||||
decltype(test_implicitly_convertible<std::vector<T>, MaybeDerefSpan<T>>(
|
|
||||||
0))::value == false);
|
|
||||||
static_assert(decltype(test_implicitly_convertible<
|
|
||||||
std::vector<T>, MaybeDerefSpan<const T>>(0))::value ==
|
|
||||||
true);
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,336 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/proto_validator.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/memory/memory.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
#include "dpf/internal/value_type_helpers.h"
|
|
||||||
#include "dpf/status_macros.h"
|
|
||||||
#include "google/protobuf/repeated_field.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
inline double GetDefaultSecurityParameter(const DpfParameters& parameters) {
|
|
||||||
return ProtoValidator::kDefaultSecurityParameter +
|
|
||||||
parameters.log_domain_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool AlmostEqual(double a, double b) {
|
|
||||||
return std::abs(a - b) <= ProtoValidator::kSecurityParameterEpsilon;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns whether `lhs` and `rhs` describe the same DPF level. Domain sizes
// must match exactly. Security parameters may differ in representation (see
// below). The value-type comparison is delegated to ValueTypesAreEqual.
absl::StatusOr<bool> ParametersAreEqual(const DpfParameters& lhs,
                                        const DpfParameters& rhs) {
  if (lhs.log_domain_size() != rhs.log_domain_size()) {
    return false;
  }

  const auto lhs_security = lhs.security_parameter();
  const auto rhs_security = rhs.security_parameter();
  // Security parameters are equivalent when they are (almost) equal, or when
  // one side is unset (zero) and the other equals the default that an unset
  // value would be replaced with.
  const bool security_matches =
      AlmostEqual(lhs_security, rhs_security) ||
      (lhs_security == 0 &&
       AlmostEqual(rhs_security, GetDefaultSecurityParameter(rhs))) ||
      (rhs_security == 0 &&
       AlmostEqual(lhs_security, GetDefaultSecurityParameter(lhs)));
  if (!security_matches) {
    return false;
  }

  return ValueTypesAreEqual(lhs.value_type(), rhs.value_type());
}
|
|
||||||
|
|
||||||
// Checks that `type.bitsize()` is a power of two in [1, 128]. Constraints are
// checked in order, and the first violation determines the error returned.
absl::Status ValidateIntegerType(const ValueType::Integer& type) {
  const int bits = type.bitsize();
  if (bits <= 0) {
    return absl::InvalidArgumentError("`bitsize` must be positive");
  }
  if (128 < bits) {
    return absl::InvalidArgumentError(
        "`bitsize` must be less than or equal to 128");
  }
  // Clearing the lowest set bit of a power of two leaves zero.
  const bool is_power_of_two = (bits & (bits - 1)) == 0;
  if (!is_power_of_two) {
    return absl::InvalidArgumentError("`bitsize` must be a power of 2");
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Checks that `value` fits into an unsigned integer of `type.bitsize()` bits.
//
// `type` is validated first. This function is reachable from the public
// ValidateValue with a type that was never passed through
// ValidateValueType, and a negative bitsize would otherwise reach the shift
// below, which is undefined behavior.
//
// For a 128-bit type no range check is performed, since every uint128 fits.
absl::Status ValidateIntegerValue(const Value::Integer& value,
                                  const ValueType::Integer& type) {
  DPF_RETURN_IF_ERROR(ValidateIntegerType(type));
  if (type.bitsize() < 128) {
    DPF_ASSIGN_OR_RETURN(absl::uint128 value_128, ValueIntegerToUint128(value));
    // bitsize is now known to be in [1, 127], so the shift is well-defined.
    if (value_128 >= absl::uint128{1} << type.bitsize()) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Value (= %d) too large for ValueType with bitsize = %d", value_128,
          type.bitsize()));
    }
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Stores the parameters and the tree/hierarchy mappings derived from them.
// Instances are built by Create(), which validates `parameters`, fills in
// default security parameters, and computes the remaining arguments.
ProtoValidator::ProtoValidator(std::vector<DpfParameters> parameters,
                               int tree_levels_needed,
                               absl::flat_hash_map<int, int> tree_to_hierarchy,
                               std::vector<int> hierarchy_to_tree)
    : parameters_(std::move(parameters)),
      tree_levels_needed_(tree_levels_needed),
      tree_to_hierarchy_(std::move(tree_to_hierarchy)),
      hierarchy_to_tree_(std::move(hierarchy_to_tree)) {}
|
|
||||||
|
|
||||||
// Validates `parameters_in`, replaces unset (zero) security parameters with
// their defaults, and computes which evaluation-tree level stores the value
// correction of each hierarchy level (and vice versa).
absl::StatusOr<std::unique_ptr<ProtoValidator>> ProtoValidator::Create(
    absl::Span<const DpfParameters> parameters_in) {
  DPF_RETURN_IF_ERROR(ValidateParameters(parameters_in));

  // Work on a copy so that defaults can be written in.
  std::vector<DpfParameters> parameters(parameters_in.begin(),
                                        parameters_in.end());
  for (DpfParameters& level_params : parameters) {
    if (level_params.security_parameter() == 0) {
      level_params.set_security_parameter(
          GetDefaultSecurityParameter(level_params));
    }
  }

  const int num_hierarchy_levels = static_cast<int>(parameters.size());
  absl::flat_hash_map<int, int> tree_to_hierarchy;
  std::vector<int> hierarchy_to_tree(parameters.size());
  // Height of the evaluation tree required by the levels processed so far.
  int tree_levels_needed = 0;
  for (int level = 0; level < num_hierarchy_levels; ++level) {
    const DpfParameters& level_params = parameters[level];
    DPF_ASSIGN_OR_RETURN(int bits_needed,
                         BitsNeeded(level_params.value_type(),
                                    level_params.security_parameter()));
    const int log_bits_needed =
        static_cast<int>(std::ceil(std::log2(bits_needed)));

    // One AES block holds 128 = 2^7 bits, so an element needing 2^k < 2^7
    // bits lets 2^(7-k) elements share a leaf, reducing the tree level by
    // 7 - k. Taking the max with the levels already used keeps every
    // hierarchy level on a distinct tree level.
    const int tree_level = std::max(
        tree_levels_needed,
        level_params.log_domain_size() - 7 + std::min(log_bits_needed, 7));
    tree_to_hierarchy[tree_level] = level;
    hierarchy_to_tree[level] = tree_level;
    tree_levels_needed = std::max(tree_levels_needed, tree_level + 1);
  }

  return absl::WrapUnique(new ProtoValidator(
      std::move(parameters), tree_levels_needed, std::move(tree_to_hierarchy),
      std::move(hierarchy_to_tree)));
}
|
|
||||||
|
|
||||||
// Checks that `parameters` is non-empty and that every element has:
// - a log_domain_size in [0, 128], strictly increasing across elements,
// - a valid value_type,
// - a non-NaN security_parameter in [0, 128].
absl::Status ProtoValidator::ValidateParameters(
    absl::Span<const DpfParameters> parameters) {
  if (parameters.empty()) {
    return absl::InvalidArgumentError("`parameters` must not be empty");
  }

  bool is_first = true;
  int previous_log_domain_size = 0;
  for (const DpfParameters& level_params : parameters) {
    const int log_domain_size = level_params.log_domain_size();
    if (log_domain_size < 0) {
      return absl::InvalidArgumentError(
          "`log_domain_size` must be non-negative");
    }
    if (log_domain_size > 128) {
      return absl::InvalidArgumentError("`log_domain_size` must be <= 128");
    }
    if (!is_first && log_domain_size <= previous_log_domain_size) {
      return absl::InvalidArgumentError(
          "`log_domain_size` fields must be in ascending order in "
          "`parameters`");
    }
    is_first = false;
    previous_log_domain_size = log_domain_size;

    if (!level_params.has_value_type()) {
      return absl::InvalidArgumentError("`value_type` is required");
    }
    DPF_RETURN_IF_ERROR(ValidateValueType(level_params.value_type()));

    const auto security_parameter = level_params.security_parameter();
    if (std::isnan(security_parameter)) {
      return absl::InvalidArgumentError("`security_parameter` must not be NaN");
    }
    // AES-128 is used for the PRG, so more than 128 bits of security are not
    // achievable.
    if (security_parameter < 0 || security_parameter > 128) {
      return absl::InvalidArgumentError(
          "`security_parameter` must be in [0, 128]");
    }
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Checks that `key` is well-formed for the DPF described by `parameters_`:
// - it has a seed and a non-empty last_level_value_correction,
// - it has exactly one correction word per tree level except the last,
// - every hierarchy level not stored at the last tree level has a non-empty
//   value correction in its tree level's correction word.
absl::Status ProtoValidator::ValidateDpfKey(const DpfKey& key) const {
  if (!key.has_seed()) {
    return absl::InvalidArgumentError("key.seed must be present");
  }
  if (key.last_level_value_correction().empty()) {
    return absl::InvalidArgumentError(
        "key.last_level_value_correction must be present");
  }

  const int expected_correction_words = tree_levels_needed_ - 1;
  if (key.correction_words_size() != expected_correction_words) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Malformed DpfKey: expected ", expected_correction_words,
        " correction words, but got ", key.correction_words_size()));
  }

  const int num_hierarchy_levels = static_cast<int>(hierarchy_to_tree_.size());
  for (int level = 0; level < num_hierarchy_levels; ++level) {
    const int tree_level = hierarchy_to_tree_[level];
    // The value correction of the last tree level is stored in
    // last_level_value_correction, which was checked above.
    if (tree_level == expected_correction_words) {
      continue;
    }
    ABSL_DCHECK(tree_level < key.correction_words_size());
    if (key.correction_words(tree_level).value_correction().empty()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Malformed DpfKey: expected correction_words[", tree_level,
          "] to contain the value correction of hierarchy level ", level));
    }
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Checks that `ctx` was created for this DPF's parameters, carries a valid
// key, has not been fully evaluated yet, and that any stored partial
// evaluations do not come from a level after previous_hierarchy_level.
absl::Status ProtoValidator::ValidateEvaluationContext(
    const EvaluationContext& ctx) const {
  const int num_parameters = ctx.parameters_size();
  if (num_parameters != static_cast<int>(parameters_.size())) {
    return absl::InvalidArgumentError(
        "Number of parameters in `ctx` doesn't match");
  }
  for (int i = 0; i < num_parameters; ++i) {
    DPF_ASSIGN_OR_RETURN(bool same_parameters,
                         ParametersAreEqual(parameters_[i], ctx.parameters(i)));
    if (!same_parameters) {
      return absl::InvalidArgumentError(
          absl::StrCat("Parameter ", i, " in `ctx` doesn't match"));
    }
  }

  if (!ctx.has_key()) {
    return absl::InvalidArgumentError("ctx.key must be present");
  }
  DPF_RETURN_IF_ERROR(ValidateDpfKey(ctx.key()));

  if (ctx.previous_hierarchy_level() >= num_parameters - 1) {
    return absl::InvalidArgumentError(
        "This context has already been fully evaluated");
  }

  const bool has_partial_evaluations = !ctx.partial_evaluations().empty();
  if (has_partial_evaluations &&
      ctx.partial_evaluations_level() > ctx.previous_hierarchy_level()) {
    return absl::InvalidArgumentError(
        "ctx.partial_evaluations_level must be less than or equal to "
        "ctx.previous_hierarchy_level");
  }
  return absl::OkStatus();
}
|
|
||||||
|
|
||||||
// Recursively validates `value_type`. Integer-like types must have a valid
// bitsize, tuples must contain only valid types, and an IntModN modulus must
// fit in its base integer. Any other type case is rejected.
absl::Status ProtoValidator::ValidateValueType(const ValueType& value_type) {
  switch (value_type.type_case()) {
    case ValueType::kInteger: {
      return ValidateIntegerType(value_type.integer());
    }
    case ValueType::kTuple: {
      for (const ValueType& element : value_type.tuple().elements()) {
        DPF_RETURN_IF_ERROR(ValidateValueType(element));
      }
      return absl::OkStatus();
    }
    case ValueType::kIntModN: {
      const ValueType::Integer& base_integer =
          value_type.int_mod_n().base_integer();
      DPF_RETURN_IF_ERROR(ValidateIntegerType(base_integer));
      return ValidateIntegerValue(value_type.int_mod_n().modulus(),
                                  base_integer);
    }
    case ValueType::kXorWrapper: {
      return ValidateIntegerType(value_type.xor_wrapper());
    }
    default: {
      break;
    }
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "ValidateValueType: Unsupported ValueType:\n", value_type.DebugString()));
}
|
|
||||||
|
|
||||||
// Checks that `value` is a valid instance of `type`:
// - the oneof case set in `value` must match the type,
// - tuples must have matching arity and valid elements,
// - integers must fit into the type's bitsize,
// - IntModN values must additionally be less than the modulus.
absl::Status ProtoValidator::ValidateValue(const Value& value,
                                           const ValueType& type) {
  if (type.type_case() == ValueType::kInteger) {
    // Integers.
    if (value.value_case() != Value::kInteger) {
      return absl::InvalidArgumentError("Expected integer value");
    }
    return ValidateIntegerValue(value.integer(), type.integer());
  } else if (type.type_case() == ValueType::kTuple) {
    // Tuples.
    if (value.value_case() != Value::kTuple) {
      return absl::InvalidArgumentError("Expected tuple value");
    }
    if (value.tuple().elements_size() != type.tuple().elements_size()) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Expected tuple value of size ", type.tuple().elements_size(),
          " but got size ", value.tuple().elements_size()));
    }
    for (int i = 0; i < type.tuple().elements_size(); ++i) {
      DPF_RETURN_IF_ERROR(
          ValidateValue(value.tuple().elements(i), type.tuple().elements(i)));
    }
    return absl::OkStatus();
  } else if (type.type_case() == ValueType::kIntModN) {
    // IntModN values. Without this check, a Value with a different oneof case
    // set would be validated against the default `int_mod_n()` instance and
    // accepted, silently ignoring the actual payload.
    if (value.value_case() != Value::kIntModN) {
      return absl::InvalidArgumentError("Expected IntModN value");
    }
    DPF_RETURN_IF_ERROR(ValidateIntegerValue(value.int_mod_n(),
                                             type.int_mod_n().base_integer()));
    DPF_ASSIGN_OR_RETURN(absl::uint128 value_128,
                         ValueIntegerToUint128(value.int_mod_n()));
    DPF_ASSIGN_OR_RETURN(absl::uint128 modulus_128,
                         ValueIntegerToUint128(type.int_mod_n().modulus()));
    if (value_128 >= modulus_128) {
      return absl::InvalidArgumentError(
          absl::StrFormat("Value (= %d) is too large for modulus (= %d)",
                          value_128, modulus_128));
    }
    return absl::OkStatus();
  } else if (type.type_case() == ValueType::kXorWrapper) {
    // XorWrapper values.
    if (value.value_case() != Value::kXorWrapper) {
      return absl::InvalidArgumentError("Expected XorWrapper value");
    }
    return ValidateIntegerValue(value.xor_wrapper(), type.xor_wrapper());
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "ValidateValue: Unsupported ValueType:\n", type.DebugString()));
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,120 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/types/span.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
// ProtoValidator is used to validate protos for DPF parameters, keys, and
|
|
||||||
// evaluation contexts. Also holds information computed from the DPF parameters,
|
|
||||||
// such as the mappings between hierarchy and tree levels.
|
|
||||||
class ProtoValidator {
|
|
||||||
public:
|
|
||||||
// The negative logarithm of the total variation distance from uniform that a
|
|
||||||
// *full* evaluation of a hierarchy level is allowed to have. Used as the
|
|
||||||
// default value for DpfParameters that don't have an explicit per-element
|
|
||||||
// security parameter set.
|
|
||||||
static constexpr double kDefaultSecurityParameter = 40;
|
|
||||||
|
|
||||||
// Security parameters that differ by less than this are considered equal.
|
|
||||||
static constexpr double kSecurityParameterEpsilon = 0.0001;
|
|
||||||
|
|
||||||
// Checks the validity of `parameters` and returns a ProtoValidator, which
|
|
||||||
// will be used to validate DPF keys and evaluation contexts afterwards.
|
|
||||||
//
|
|
||||||
// Returns INVALID_ARGUMENT if `parameters` are invalid.
|
|
||||||
static absl::StatusOr<std::unique_ptr<ProtoValidator>> Create(
|
|
||||||
absl::Span<const DpfParameters> parameters);
|
|
||||||
|
|
||||||
// Checks the validity of `parameters`.
|
|
||||||
// Returns OK on success, and INVALID_ARGUMENT otherwise.
|
|
||||||
static absl::Status ValidateParameters(
|
|
||||||
absl::Span<const DpfParameters> parameters);
|
|
||||||
|
|
||||||
// Checks that `key` is valid for the `parameters` passed at construction.
|
|
||||||
// Returns OK on success, and INVALID_ARGUMENT otherwise.
|
|
||||||
absl::Status ValidateDpfKey(const DpfKey& key) const;
|
|
||||||
|
|
||||||
// Checks that `ctx` is valid for the `parameters` passed at construction.
|
|
||||||
// Returns OK on success, and INVALID_ARGUMENT otherwise.
|
|
||||||
absl::Status ValidateEvaluationContext(const EvaluationContext& ctx) const;
|
|
||||||
|
|
||||||
// Checks that the given ValueType is valid.
|
|
||||||
// Returns OK on success and INVALID_ARGUMENT otherwise.
|
|
||||||
static absl::Status ValidateValueType(const ValueType& value_type);
|
|
||||||
|
|
||||||
// Checks that `value` is valid for `type`.
|
|
||||||
// Returns OK on success and INVALID_ARGUMENT otherwise.
|
|
||||||
static absl::Status ValidateValue(const Value& value, const ValueType& type);
|
|
||||||
|
|
||||||
// Checks that `value` is valid for `parameters[i]` passed at construction.
|
|
||||||
// Returns OK on success and INVALID_ARGUMENT otherwise.
|
|
||||||
inline absl::Status ValidateValue(const Value& value, int i) const {
|
|
||||||
return ValidateValue(value, parameters_[i].value_type());
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProtoValidator is not copyable.
|
|
||||||
ProtoValidator(const ProtoValidator&) = delete;
|
|
||||||
ProtoValidator& operator=(const ProtoValidator&) = delete;
|
|
||||||
|
|
||||||
// ProtoValidator is movable.
|
|
||||||
ProtoValidator(ProtoValidator&&) = default;
|
|
||||||
ProtoValidator& operator=(ProtoValidator&&) = default;
|
|
||||||
|
|
||||||
// Getters.
|
|
||||||
absl::Span<const DpfParameters> parameters() const { return parameters_; }
|
|
||||||
int tree_levels_needed() const { return tree_levels_needed_; }
|
|
||||||
const absl::flat_hash_map<int, int>& tree_to_hierarchy() const {
|
|
||||||
return tree_to_hierarchy_;
|
|
||||||
}
|
|
||||||
const std::vector<int>& hierarchy_to_tree() const {
|
|
||||||
return hierarchy_to_tree_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
ProtoValidator(std::vector<DpfParameters> parameters, int tree_levels_needed,
|
|
||||||
absl::flat_hash_map<int, int> tree_to_hierarchy,
|
|
||||||
std::vector<int> hierarchy_to_tree);
|
|
||||||
|
|
||||||
// The DpfParameters passed at construction.
|
|
||||||
std::vector<DpfParameters> parameters_;
|
|
||||||
|
|
||||||
// Number of levels in the evaluation tree. This is always less than or equal
|
|
||||||
// to the largest log_domain_size in parameters_.
|
|
||||||
int tree_levels_needed_;
|
|
||||||
|
|
||||||
// Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
|
|
||||||
// of parameters_).
|
|
||||||
absl::flat_hash_map<int, int> tree_to_hierarchy_;
|
|
||||||
|
|
||||||
// The inverse of tree_to_hierarchy_.
|
|
||||||
std::vector<int> hierarchy_to_tree_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_
|
|
@ -1,417 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/proto_validator.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
#include "dpf/internal/proto_validator_test_textproto_embed.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "dpf/tuple.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "google/protobuf/repeated_field.h"
|
|
||||||
#include "google/protobuf/text_format.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using ::testing::Ne;
|
|
||||||
using ::testing::StartsWith;
|
|
||||||
|
|
||||||
class ProtoValidatorTest : public testing::Test {
|
|
||||||
protected:
|
|
||||||
void SetUp() override {
|
|
||||||
const auto* const toc = proto_validator_test_textproto_embed_create();
|
|
||||||
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
|
|
||||||
std::string(toc->data, toc->size), &ctx_));
|
|
||||||
parameters_ = std::vector<DpfParameters>(ctx_.parameters().begin(),
|
|
||||||
ctx_.parameters().end());
|
|
||||||
dpf_key_ = ctx_.key();
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(proto_validator_,
|
|
||||||
ProtoValidator::Create(parameters_));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<DpfParameters> parameters_;
|
|
||||||
DpfKey dpf_key_;
|
|
||||||
EvaluationContext ctx_;
|
|
||||||
std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWithoutParameters) {
|
|
||||||
EXPECT_THAT(ProtoValidator::Create({}),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`parameters` must not be empty"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenParametersNotSorted) {
|
|
||||||
parameters_.resize(2);
|
|
||||||
parameters_[0].set_log_domain_size(10);
|
|
||||||
parameters_[1].set_log_domain_size(8);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`log_domain_size` fields must be in ascending order in "
|
|
||||||
"`parameters`"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeNegative) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].set_log_domain_size(-1);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`log_domain_size` must be non-negative"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeTooLarge) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].set_log_domain_size(129);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`log_domain_size` must be <= 128"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNegative) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(-1);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be positive"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeZero) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(0);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be positive"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeTooLarge) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(256);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be less than or equal to 128"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNotAPowerOfTwo) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(23);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be a power of 2"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNaN) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].set_security_parameter(std::nan(""));
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`security_parameter` must not be NaN"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNegative) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].set_security_parameter(-0.01);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`security_parameter` must be in [0, 128]"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsTooLarge) {
|
|
||||||
parameters_.resize(1);
|
|
||||||
parameters_[0].set_security_parameter(128.01);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`security_parameter` must be in [0, 128]"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateWorksWhenElementBitsizesDecrease) {
|
|
||||||
parameters_.resize(2);
|
|
||||||
parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(64);
|
|
||||||
parameters_[1].mutable_value_type()->mutable_integer()->set_bitsize(32);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_), IsOkAndHolds(Ne(nullptr)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, CreateWorksWhenHierarchiesAreFarApart) {
|
|
||||||
parameters_.resize(2);
|
|
||||||
parameters_[0].set_log_domain_size(10);
|
|
||||||
parameters_[1].set_log_domain_size(128);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::Create(parameters_), IsOkAndHolds(Ne(nullptr)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateDpfKeyFailsIfNumberOfCorrectionWordsDoesntMatch) {
|
|
||||||
dpf_key_.add_correction_words();
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateDpfKey(dpf_key_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat("Malformed DpfKey: expected ",
|
|
||||||
dpf_key_.correction_words_size() - 1,
|
|
||||||
" correction words, but got ",
|
|
||||||
dpf_key_.correction_words_size())));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfSeedIsMissing) {
|
|
||||||
dpf_key_.clear_seed();
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateDpfKey(dpf_key_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument, "key.seed must be present"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateDpfKeyFailsIfLastLevelOutputCorrectionIsMissing) {
|
|
||||||
dpf_key_.clear_last_level_value_correction();
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateDpfKey(dpf_key_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"key.last_level_value_correction must be present"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfOutputCorrectionIsMissing) {
|
|
||||||
for (CorrectionWord& cw : *(dpf_key_.mutable_correction_words())) {
|
|
||||||
cw.clear_value_correction();
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateDpfKey(dpf_key_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
StartsWith("Malformed DpfKey: expected correction_words")));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateEvaluationContextFailsIfKeyIsMissing) {
|
|
||||||
ctx_.clear_key();
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument, "ctx.key must be present"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextFailsIfParameterSizeDoesntMatch) {
|
|
||||||
ctx_.mutable_parameters()->erase(ctx_.parameters().end() - 1);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Number of parameters in `ctx` doesn't match"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextFailsIfLogDomainSizeDoesntMatch) {
|
|
||||||
ctx_.mutable_parameters(0)->set_log_domain_size(
|
|
||||||
ctx_.parameters(0).log_domain_size() + 1);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Parameter 0 in `ctx` doesn't match"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextSucceedsIfSecurityParameterIsDefault) {
|
|
||||||
parameters_[0].set_security_parameter(0);
|
|
||||||
DPF_ASSERT_OK_AND_ASSIGN(proto_validator_,
|
|
||||||
ProtoValidator::Create(parameters_));
|
|
||||||
|
|
||||||
ctx_.mutable_parameters(0)->set_security_parameter(0);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), IsOk());
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextFailsIfSecurityParameterDoesntMatch) {
|
|
||||||
ctx_.mutable_parameters(0)->set_security_parameter(
|
|
||||||
ctx_.parameters(0).security_parameter() + 1);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Parameter 0 in `ctx` doesn't match"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextFailsIfContextFullyEvaluated) {
|
|
||||||
ctx_.set_previous_hierarchy_level(parameters_.size() - 1);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"This context has already been fully evaluated"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest,
|
|
||||||
ValidateEvaluationContextFailsIfPartialEvaluationsLevelTooLarge) {
|
|
||||||
ctx_.set_previous_hierarchy_level(0);
|
|
||||||
ctx_.set_partial_evaluations_level(1);
|
|
||||||
ctx_.add_partial_evaluations();
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateEvaluationContext(ctx_),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"ctx.partial_evaluations_level must be less than or equal to "
|
|
||||||
"ctx.previous_hierarchy_level"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotInteger) {
|
|
||||||
ValueType type;
|
|
||||||
type.mutable_integer()->set_bitsize(32);
|
|
||||||
Value value;
|
|
||||||
value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(
|
|
||||||
23);
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument, "Expected integer value"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfIntegerTooLarge) {
|
|
||||||
ValueType type;
|
|
||||||
Value value;
|
|
||||||
|
|
||||||
int element_bitsize = 32;
|
|
||||||
type.mutable_integer()->set_bitsize(element_bitsize);
|
|
||||||
auto value_64 = uint64_t{1} << element_bitsize;
|
|
||||||
value.mutable_integer()->set_value_uint64(value_64);
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrFormat(
|
|
||||||
"Value (= %d) too large for ValueType with bitsize = %d",
|
|
||||||
value_64, element_bitsize)));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotTuple) {
|
|
||||||
ValueType type;
|
|
||||||
type.mutable_tuple()->add_elements()->mutable_integer()->set_bitsize(32);
|
|
||||||
Value value;
|
|
||||||
value.mutable_integer()->set_value_uint64(23);
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument, "Expected tuple value"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTupleSizeDoesntMatch) {
|
|
||||||
ValueType type;
|
|
||||||
type.mutable_tuple()->add_elements()->mutable_integer()->set_bitsize(32);
|
|
||||||
Value value;
|
|
||||||
|
|
||||||
value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(
|
|
||||||
23);
|
|
||||||
value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(
|
|
||||||
42);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Expected tuple value of size 1 but got size 2"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueLargerThanModulus) {
|
|
||||||
constexpr uint64_t kModulus = 3;
|
|
||||||
ValueType type;
|
|
||||||
type.mutable_int_mod_n()->mutable_base_integer()->set_bitsize(64);
|
|
||||||
type.mutable_int_mod_n()->mutable_modulus()->set_value_uint64(kModulus);
|
|
||||||
Value value;
|
|
||||||
|
|
||||||
value.mutable_int_mod_n()->set_value_uint64(kModulus);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Value (= 3) is too large for modulus (= 3)"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotXorWrapper) {
|
|
||||||
ValueType type;
|
|
||||||
type.mutable_xor_wrapper()->set_bitsize(32);
|
|
||||||
Value value;
|
|
||||||
value.mutable_integer()->set_value_uint64(23);
|
|
||||||
|
|
||||||
EXPECT_THAT(proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"Expected XorWrapper value"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueIsUnknown) {
|
|
||||||
ValueType type;
|
|
||||||
Value value;
|
|
||||||
|
|
||||||
EXPECT_THAT(
|
|
||||||
proto_validator_->ValidateValue(value, type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
testing::StartsWith("ValidateValue: Unsupported ValueType:")));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPositive) {
|
|
||||||
ValueType type;
|
|
||||||
|
|
||||||
type.mutable_integer()->set_bitsize(0);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::ValidateValueType(type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be positive"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeTooLarge) {
|
|
||||||
ValueType type;
|
|
||||||
|
|
||||||
type.mutable_integer()->set_bitsize(256);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::ValidateValueType(type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be less than or equal to 128"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPowerOfTwo) {
|
|
||||||
ValueType type;
|
|
||||||
|
|
||||||
type.mutable_integer()->set_bitsize(17);
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::ValidateValueType(type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
"`bitsize` must be a power of 2"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ProtoValidator, ValidateValueTypeFailsIfNoTypeChosen) {
|
|
||||||
ValueType type;
|
|
||||||
|
|
||||||
EXPECT_THAT(ProtoValidator::ValidateValueType(type),
|
|
||||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
|
||||||
StartsWith("ValidateValueType: Unsupported ValueType")));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
108
third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.textproto
vendored
108
third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.textproto
vendored
@ -1,108 +0,0 @@
|
|||||||
# Copyright 2021 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# proto-file dpf/distributed_point_function.proto
|
|
||||||
# proto-message: EvaluationContext
|
|
||||||
|
|
||||||
parameters {
|
|
||||||
log_domain_size: 4
|
|
||||||
value_type {
|
|
||||||
integer {
|
|
||||||
bitsize: 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
security_parameter: 44
|
|
||||||
}
|
|
||||||
parameters {
|
|
||||||
log_domain_size: 6
|
|
||||||
value_type {
|
|
||||||
integer {
|
|
||||||
bitsize: 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
security_parameter: 46
|
|
||||||
}
|
|
||||||
parameters {
|
|
||||||
log_domain_size: 8
|
|
||||||
value_type {
|
|
||||||
integer {
|
|
||||||
bitsize: 32
|
|
||||||
}
|
|
||||||
}
|
|
||||||
security_parameter: 48
|
|
||||||
}
|
|
||||||
key {
|
|
||||||
seed {
|
|
||||||
high: 11559904407150645412
|
|
||||||
low: 10793182457266619527
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 17231204231811741091
|
|
||||||
low: 13184625655696690000
|
|
||||||
}
|
|
||||||
control_left: true
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 3072212389250066354
|
|
||||||
low: 1361245143349174348
|
|
||||||
}
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 2882988684359810666
|
|
||||||
low: 16992210518729579018
|
|
||||||
}
|
|
||||||
control_right: true
|
|
||||||
value_correction: {
|
|
||||||
integer: {
|
|
||||||
value_uint64: 536412310
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 4993590839844520517
|
|
||||||
low: 13033365507284852634
|
|
||||||
}
|
|
||||||
control_right: true
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 10673753674550143002
|
|
||||||
low: 3019916643383017704
|
|
||||||
}
|
|
||||||
control_left: true
|
|
||||||
control_right: true
|
|
||||||
value_correction: {
|
|
||||||
integer: {
|
|
||||||
value_uint64: 841224518
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
correction_words {
|
|
||||||
seed {
|
|
||||||
high: 2423099213299230757
|
|
||||||
low: 12788496417753523946
|
|
||||||
}
|
|
||||||
control_right: true
|
|
||||||
}
|
|
||||||
last_level_value_correction: {
|
|
||||||
integer: {
|
|
||||||
value_uint64: 8471844854
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
previous_hierarchy_level: -1
|
|
@ -1,63 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
|
|
||||||
#include <ostream>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
void StatusIsMatcherCommonImpl::DescribeTo(std::ostream* os) const {
|
|
||||||
*os << "has a status code that ";
|
|
||||||
code_matcher_.DescribeTo(os);
|
|
||||||
*os << ", and has an error message that ";
|
|
||||||
message_matcher_.DescribeTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
void StatusIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const {
|
|
||||||
*os << "has a status code that ";
|
|
||||||
code_matcher_.DescribeNegationTo(os);
|
|
||||||
*os << ", or has an error message that ";
|
|
||||||
message_matcher_.DescribeNegationTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool StatusIsMatcherCommonImpl::MatchAndExplain(
|
|
||||||
const ::absl::Status& status,
|
|
||||||
::testing::MatchResultListener* result_listener) const {
|
|
||||||
::testing::StringMatchResultListener inner_listener;
|
|
||||||
if (!code_matcher_.MatchAndExplain(status.code(), &inner_listener)) {
|
|
||||||
*result_listener << (inner_listener.str().empty()
|
|
||||||
? "whose status code is wrong"
|
|
||||||
: "which has a status code " +
|
|
||||||
inner_listener.str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!message_matcher_.Matches(std::string(status.message()))) {
|
|
||||||
*result_listener << "whose error message is wrong: " << status.message();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,390 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Testing utilities for working with absl::Status and absl::StatusOr.
|
|
||||||
//
|
|
||||||
// Defines the following utilities:
|
|
||||||
//
|
|
||||||
// =================
|
|
||||||
// DPF_EXPECT_OK(s)
|
|
||||||
//
|
|
||||||
// DPF_ASSERT_OK(s)
|
|
||||||
// =================
|
|
||||||
// Convenience macros for `EXPECT_THAT(s, IsOk())`, where `s` is either
|
|
||||||
// a `Status` or a `StatusOr<T>`.
|
|
||||||
//
|
|
||||||
// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not
|
|
||||||
// provide much value (when they fail, they would just print the OK status
|
|
||||||
// which conveys no more information than `EXPECT_FALSE(s.ok())`. You can
|
|
||||||
// of course use `EXPECT_THAT(s, Not(IsOk()))` if you prefer _THAT style.
|
|
||||||
//
|
|
||||||
// If you want to check for particular errors, better alternatives are:
|
|
||||||
// EXPECT_THAT(s, StatusIs(expected_error));
|
|
||||||
// EXPECT_THAT(s, StatusIs(_, HasSubstr("expected error")));
|
|
||||||
//
|
|
||||||
// ===============
|
|
||||||
// IsOkAndHolds(m)
|
|
||||||
// ===============
|
|
||||||
//
|
|
||||||
// This gMock matcher matches a StatusOr<T> value whose status is OK
|
|
||||||
// and whose inner value matches matcher m. Example:
|
|
||||||
//
|
|
||||||
// using ::testing::MatchesRegex;
|
|
||||||
// using distributed_point_functions::IsOkAndHolds;
|
|
||||||
// ...
|
|
||||||
// absl::StatusOr<string> maybe_name = ...;
|
|
||||||
// EXPECT_THAT(maybe_name, IsOkAndHolds(MatchesRegex("John .*")));
|
|
||||||
//
|
|
||||||
// ===============================
|
|
||||||
// StatusIs(status_code_matcher,
|
|
||||||
// error_message_matcher)
|
|
||||||
// ===============================
|
|
||||||
//
|
|
||||||
// This gMock matcher matches a Status or StatusOr<T> value if all of the
|
|
||||||
// following are true:
|
|
||||||
//
|
|
||||||
// - the status' error_code() matches status_code_matcher, and
|
|
||||||
// - the status' error_message() matches error_message_matcher.
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
//
|
|
||||||
// enum FooErrorCode {
|
|
||||||
// ...
|
|
||||||
// kServerError
|
|
||||||
// };
|
|
||||||
//
|
|
||||||
// using ::testing::HasSubstr;
|
|
||||||
// using ::testing::MatchesRegex;
|
|
||||||
// using ::testing::Ne;
|
|
||||||
// using ::testing::_;
|
|
||||||
// using distributed_point_functions::StatusIs;
|
|
||||||
// absl::StatusOr<string> GetName(int id);
|
|
||||||
// ...
|
|
||||||
//
|
|
||||||
// // The status code must be kServerError; the error message can be
|
|
||||||
// // anything.
|
|
||||||
// EXPECT_THAT(GetName(42),
|
|
||||||
// StatusIs(kServerError, _));
|
|
||||||
// // The status code can be anything; the error message must match the
|
|
||||||
// // regex.
|
|
||||||
// EXPECT_THAT(GetName(43),
|
|
||||||
// StatusIs(_, MatchesRegex("server.*time-out")));
|
|
||||||
//
|
|
||||||
// // The status code should not be kServerError; the error message can be
|
|
||||||
// // anything with "client" in it.
|
|
||||||
// EXPECT_CALL(mock_env, HandleStatus(
|
|
||||||
// StatusIs(Ne(kServerError), HasSubstr("client"))));
|
|
||||||
//
|
|
||||||
// ===============================
|
|
||||||
// StatusIs(status_code_matcher)
|
|
||||||
// ===============================
|
|
||||||
//
|
|
||||||
// This is a shorthand for
|
|
||||||
// StatusIs(status_code_matcher,
|
|
||||||
// testing::_)
|
|
||||||
// In other words, it's like the two-argument StatusIs(), except that it
|
|
||||||
// ignores the error message.
|
|
||||||
//
|
|
||||||
// ===============
|
|
||||||
// IsOk()
|
|
||||||
// ===============
|
|
||||||
//
|
|
||||||
// Matches an absl::Status or absl::StatusOr<T> value whose status value is
|
|
||||||
// absl::StatusCode::kOk. Equivalent to 'StatusIs(absl::StatusCode::kOk)'.
|
|
||||||
// Example:
|
|
||||||
// using distributed_point_functions::dpf_internal::IsOk;
|
|
||||||
// ...
|
|
||||||
// absl::StatusOr<string> maybe_name = ...;
|
|
||||||
// EXPECT_THAT(maybe_name, IsOk());
|
|
||||||
// Status s = ...;
|
|
||||||
// EXPECT_THAT(s, IsOk());
|
|
||||||
//
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_
|
|
||||||
|
|
||||||
#include <ostream>
|
|
||||||
#include <string>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "dpf/status_macros.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// Overload for a plain absl::Status: the status to inspect is the argument
// itself, returned by reference.
inline const absl::Status& GetStatus(const absl::Status& status) { return status; }
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline const absl::Status& GetStatus(const absl::StatusOr<T>& status) {
|
|
||||||
return status.status();
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
|
||||||
// Implementation of IsOkAndHolds().
|
|
||||||
|
|
||||||
// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a
|
|
||||||
// reference to StatusOr<T>.
|
|
||||||
template <typename StatusOrType>
|
|
||||||
class IsOkAndHoldsMatcherImpl
|
|
||||||
: public ::testing::MatcherInterface<StatusOrType> {
|
|
||||||
public:
|
|
||||||
typedef
|
|
||||||
typename std::remove_reference<StatusOrType>::type::value_type value_type;
|
|
||||||
|
|
||||||
template <typename InnerMatcher>
|
|
||||||
explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
|
|
||||||
: inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
|
|
||||||
std::forward<InnerMatcher>(inner_matcher))) {}
|
|
||||||
|
|
||||||
void DescribeTo(std::ostream* os) const override {
|
|
||||||
*os << "is OK and has a value that ";
|
|
||||||
inner_matcher_.DescribeTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DescribeNegationTo(std::ostream* os) const override {
|
|
||||||
*os << "isn't OK or has a value that ";
|
|
||||||
inner_matcher_.DescribeNegationTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool MatchAndExplain(
|
|
||||||
StatusOrType actual_value,
|
|
||||||
::testing::MatchResultListener* result_listener) const override {
|
|
||||||
if (!actual_value.ok()) {
|
|
||||||
*result_listener << "which has status " << actual_value.status();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
::testing::StringMatchResultListener inner_listener;
|
|
||||||
const bool matches =
|
|
||||||
inner_matcher_.MatchAndExplain(*actual_value, &inner_listener);
|
|
||||||
const std::string inner_explanation = inner_listener.str();
|
|
||||||
if (!inner_explanation.empty()) {
|
|
||||||
*result_listener << "which contains value "
|
|
||||||
<< ::testing::PrintToString(*actual_value) << ", "
|
|
||||||
<< inner_explanation;
|
|
||||||
}
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ::testing::Matcher<const value_type&> inner_matcher_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Implements IsOkAndHolds(m) as a polymorphic matcher.
|
|
||||||
template <typename InnerMatcher>
|
|
||||||
class IsOkAndHoldsMatcher {
|
|
||||||
public:
|
|
||||||
explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
|
|
||||||
: inner_matcher_(std::move(inner_matcher)) {}
|
|
||||||
|
|
||||||
// Converts this polymorphic matcher to a monomorphic matcher of the
|
|
||||||
// given type. StatusOrType can be either StatusOr<T> or a
|
|
||||||
// reference to StatusOr<T>.
|
|
||||||
template <typename StatusOrType>
|
|
||||||
operator ::testing::Matcher<StatusOrType>() const { // NOLINT
|
|
||||||
return ::testing::Matcher<StatusOrType>(
|
|
||||||
new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const InnerMatcher inner_matcher_;
|
|
||||||
};
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////
|
|
||||||
// Implementation of StatusIs().
|
|
||||||
|
|
||||||
// StatusIs() is a polymorphic matcher. This class is the common
|
|
||||||
// implementation of it shared by all types T where StatusIs() can be
|
|
||||||
// used as a Matcher<T>.
|
|
||||||
class StatusIsMatcherCommonImpl {
|
|
||||||
public:
|
|
||||||
StatusIsMatcherCommonImpl(
|
|
||||||
::testing::Matcher<absl::StatusCode> code_matcher,
|
|
||||||
::testing::Matcher<const std::string&> message_matcher)
|
|
||||||
: code_matcher_(std::move(code_matcher)),
|
|
||||||
message_matcher_(std::move(message_matcher)) {}
|
|
||||||
|
|
||||||
void DescribeTo(std::ostream* os) const;
|
|
||||||
|
|
||||||
void DescribeNegationTo(std::ostream* os) const;
|
|
||||||
|
|
||||||
bool MatchAndExplain(const absl::Status& status,
|
|
||||||
::testing::MatchResultListener* result_listener) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ::testing::Matcher<absl::StatusCode> code_matcher_;
|
|
||||||
const ::testing::Matcher<const std::string&> message_matcher_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Monomorphic implementation of matcher StatusIs() for a given type
|
|
||||||
// T. T can be Status, StatusOr<>, or a reference to either of them.
|
|
||||||
template <typename T>
|
|
||||||
class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface<T> {
|
|
||||||
public:
|
|
||||||
explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl)
|
|
||||||
: common_impl_(std::move(common_impl)) {}
|
|
||||||
|
|
||||||
void DescribeTo(std::ostream* os) const override {
|
|
||||||
common_impl_.DescribeTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
void DescribeNegationTo(std::ostream* os) const override {
|
|
||||||
common_impl_.DescribeNegationTo(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool MatchAndExplain(
|
|
||||||
T actual_value,
|
|
||||||
::testing::MatchResultListener* result_listener) const override {
|
|
||||||
return common_impl_.MatchAndExplain(GetStatus(actual_value),
|
|
||||||
result_listener);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
StatusIsMatcherCommonImpl common_impl_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Implements StatusIs() as a polymorphic matcher.
|
|
||||||
class StatusIsMatcher {
|
|
||||||
public:
|
|
||||||
template <typename StatusCodeMatcher, typename StatusMessageMatcher>
|
|
||||||
StatusIsMatcher(StatusCodeMatcher&& code_matcher,
|
|
||||||
StatusMessageMatcher&& message_matcher)
|
|
||||||
: common_impl_(::testing::MatcherCast<absl::StatusCode>(
|
|
||||||
std::forward<StatusCodeMatcher>(code_matcher)),
|
|
||||||
::testing::MatcherCast<const std::string&>(
|
|
||||||
std::forward<StatusMessageMatcher>(message_matcher))) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts this polymorphic matcher to a monomorphic matcher of the
|
|
||||||
// given type. T can be StatusOr<>, Status, or a reference to
|
|
||||||
// either of them.
|
|
||||||
template <typename T>
|
|
||||||
operator ::testing::Matcher<T>() const { // NOLINT
|
|
||||||
return ::testing::Matcher<T>(
|
|
||||||
new MonoStatusIsMatcherImpl<const T&>(common_impl_));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const StatusIsMatcherCommonImpl common_impl_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Monomorphic implementation of matcher IsOk() for a given type T.
|
|
||||||
// T can be Status, StatusOr<>, or a reference to either of them.
|
|
||||||
template <typename T>
|
|
||||||
class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
|
|
||||||
public:
|
|
||||||
void DescribeTo(std::ostream* os) const override { *os << "is OK"; }
|
|
||||||
void DescribeNegationTo(std::ostream* os) const override {
|
|
||||||
*os << "is not OK";
|
|
||||||
}
|
|
||||||
bool MatchAndExplain(
|
|
||||||
T actual_value,
|
|
||||||
::testing::MatchResultListener* result_listener) const override {
|
|
||||||
if (!actual_value.ok()) {
|
|
||||||
*result_listener << "whose status is "
|
|
||||||
<< GetStatus(actual_value).message();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Implements IsOk() as a polymorphic matcher.
|
|
||||||
class IsOkMatcher {
|
|
||||||
public:
|
|
||||||
template <typename T>
|
|
||||||
operator ::testing::Matcher<T>() const { // NOLINT
|
|
||||||
return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<const T&>());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Macros for testing the results of functions that return absl::Status or
// absl::StatusOr<T> (for any type T).
//
// DPF_EXPECT_OK uses EXPECT_THAT (non-fatal failure); DPF_ASSERT_OK uses
// ASSERT_THAT (fatal failure that returns from the current function).
// The matcher is spelled with a leading `::`. That way the macros still
// resolve correctly when expanded inside a namespace that declares its own
// nested `distributed_point_functions` namespace.
#define DPF_EXPECT_OK(expression) \
  EXPECT_THAT(expression, ::distributed_point_functions::dpf_internal::IsOk())
#define DPF_ASSERT_OK(expression) \
  ASSERT_THAT(expression, ::distributed_point_functions::dpf_internal::IsOk())
|
|
||||||
|
|
||||||
// Executes an expression that returns an absl::StatusOr, and assigns the
// contained value to lhs if the status is OK.
// If the status is non-OK, generates a test failure and returns from the
// current function, which must have a void return type.
//
// Example: Declaring and initializing a new value
//   DPF_ASSERT_OK_AND_ASSIGN(const ValueType& value, MaybeGetValue(arg));
//
// Example: Assigning to an existing value
//   ValueType value;
//   DPF_ASSERT_OK_AND_ASSIGN(value, MaybeGetValue(arg));
//
// The value assignment example expands into something like:
//   auto _status_or_valueNN = (MaybeGetValue(arg));
//   DPF_ASSERT_OK(_status_or_valueNN);
//   value = std::move(_status_or_valueNN).value();
// where the temporary's name is built from `_status_or_value` and __LINE__
// via DPF_STATUS_MACROS_IMPL_CONCAT_ (from dpf/status_macros.h).
//
// WARNING: Like DPF_ASSIGN_OR_RETURN, DPF_ASSERT_OK_AND_ASSIGN expands into
// multiple statements; it cannot be used in a single statement (e.g. as the
// body of an if statement without {})!
#define DPF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
  DPF_ASSERT_OK_AND_ASSIGN_IMPL_(            \
      DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr)

#define DPF_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \
  auto statusor = (rexpr);                                   \
  DPF_ASSERT_OK(statusor);                                   \
  lhs = std::move(statusor).value();
|
|
||||||
|
|
||||||
// Returns a gMock matcher that matches a StatusOr<> whose status is
|
|
||||||
// OK and whose value matches the inner matcher.
|
|
||||||
template <typename InnerMatcher>
|
|
||||||
dpf_internal::IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>
|
|
||||||
IsOkAndHolds(InnerMatcher&& inner_matcher) {
|
|
||||||
return dpf_internal::IsOkAndHoldsMatcher<
|
|
||||||
typename std::decay<InnerMatcher>::type>(
|
|
||||||
std::forward<InnerMatcher>(inner_matcher));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a gMock matcher that matches a Status or StatusOr<> whose status code
|
|
||||||
// matches code_matcher, and whose error message matches message_matcher.
|
|
||||||
template <typename StatusCodeMatcher, typename StatusMessageMatcher>
|
|
||||||
dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher,
|
|
||||||
StatusMessageMatcher&& message_matcher) {
|
|
||||||
return dpf_internal::StatusIsMatcher(
|
|
||||||
std::forward<StatusCodeMatcher>(code_matcher),
|
|
||||||
std::forward<StatusMessageMatcher>(message_matcher));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a gMock matcher that matches a Status or StatusOr<> whose status code
|
|
||||||
// matches code_matcher.
|
|
||||||
template <typename StatusCodeMatcher>
|
|
||||||
dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher) {
|
|
||||||
return StatusIs(std::forward<StatusCodeMatcher>(code_matcher), ::testing::_);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a gMock matcher that matches a Status or StatusOr<> which is OK.
|
|
||||||
inline dpf_internal::IsOkMatcher IsOk() { return dpf_internal::IsOkMatcher(); }
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_
|
|
@ -1,169 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/value_type_helpers.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
#include "dpf/status_macros.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// Reports whether `lhs` and `rhs` describe the same value type:
//  - integers and XOR wrappers match iff their bit sizes match;
//  - IntModN types match iff base-integer bit size and modulus both match;
//  - tuples match iff they have equally many elements and every pair of
//    corresponding elements matches (compared recursively).
// Types of different kinds never match. Returns INVALID_ARGUMENT if either
// argument has no type set. Errors from nested comparisons or modulus
// conversions are propagated.
absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs,
                                        const ValueType& rhs) {
  const ValueType::TypeCase lhs_case = lhs.type_case();
  const ValueType::TypeCase rhs_case = rhs.type_case();
  if (lhs_case == ValueType::TypeCase::TYPE_NOT_SET ||
      rhs_case == ValueType::TypeCase::TYPE_NOT_SET) {
    return absl::InvalidArgumentError(
        "Both arguments must be valid ValueTypes");
  }
  if (lhs_case != rhs_case) {
    return false;
  }

  switch (lhs_case) {
    case ValueType::kInteger:
      return lhs.integer().bitsize() == rhs.integer().bitsize();
    case ValueType::kTuple: {
      const int num_elements = lhs.tuple().elements_size();
      if (num_elements != rhs.tuple().elements_size()) {
        return false;
      }
      // All element pairs are visited (no early exit), so an invalid element
      // anywhere in the tuple is still reported as an error.
      bool all_equal = true;
      for (int idx = 0; idx < num_elements; ++idx) {
        DPF_ASSIGN_OR_RETURN(bool elements_equal,
                             ValueTypesAreEqual(lhs.tuple().elements(idx),
                                                rhs.tuple().elements(idx)));
        all_equal = all_equal && elements_equal;
      }
      return all_equal;
    }
    case ValueType::kIntModN: {
      DPF_ASSIGN_OR_RETURN(absl::uint128 lhs_modulus,
                           ValueIntegerToUint128(lhs.int_mod_n().modulus()));
      DPF_ASSIGN_OR_RETURN(absl::uint128 rhs_modulus,
                           ValueIntegerToUint128(rhs.int_mod_n().modulus()));
      return lhs.int_mod_n().base_integer().bitsize() ==
                 rhs.int_mod_n().base_integer().bitsize() &&
             lhs_modulus == rhs_modulus;
    }
    case ValueType::kXorWrapper:
      return lhs.xor_wrapper().bitsize() == rhs.xor_wrapper().bitsize();
    default:
      return false;
  }
}
|
|
||||||
|
|
||||||
// Returns the number of pseudorandom bits needed to produce one value of
// `value_type`.
//  - Integers and XOR wrappers need exactly their bit size.
//  - A single IntModN needs 8 * IntModNBase::GetNumBytesRequired(...) bits,
//    computed with `security_parameter`.
//  - For tuples, all IntModN elements (which must all be the same type) are
//    sized together in one GetNumBytesRequired call. Every other element is
//    sized individually (recursively), with its security parameter raised by
//    log2(number of non-IntModN elements).
//
// Returns UNIMPLEMENTED if a tuple mixes different IntModN types, and
// INVALID_ARGUMENT for unsupported value types. Errors from helper calls are
// propagated.
absl::StatusOr<int> BitsNeeded(const ValueType& value_type,
                               double security_parameter) {
  if (value_type.type_case() == ValueType::kInteger) {
    return value_type.integer().bitsize();
  } else if (value_type.type_case() == ValueType::kTuple) {
    // We handle elements of type IntModN separately, since we can sample them
    // together.
    int num_ints_mod_n = 0;
    int num_other = 0;
    const ValueType* int_mod_n = nullptr;
    int bitsize_ints_mod_n = 0;
    int bitsize_other = 0;
    for (const ValueType& el : value_type.tuple().elements()) {
      if (el.type_case() == ValueType::kIntModN) {
        // Element is integer mod N -> check if it is the same as the others in
        // this tuple and increase counter.
        if (!int_mod_n) {
          int_mod_n = &el;
        } else {
          absl::StatusOr<bool> types_are_equal =
              ValueTypesAreEqual(el, *int_mod_n);
          if (!types_are_equal.ok()) {
            return types_are_equal.status();
          }
          if (!*types_are_equal) {
            return absl::UnimplementedError(
                "All elements of type IntModN in a tuple must be the same");
          }
        }
        ++num_ints_mod_n;
      } else {
        ++num_other;
      }
    }
    if (num_other > 0) {
      const double per_element_security_parameter =
          security_parameter + std::log2(static_cast<double>(num_other));
      // Visit exactly the non-IntModN elements. Indexing elements(0) ..
      // elements(num_other - 1) would pick the wrong elements whenever an
      // IntModN precedes another element in the tuple.
      for (const ValueType& el : value_type.tuple().elements()) {
        if (el.type_case() == ValueType::kIntModN) {
          continue;  // Accounted for jointly below.
        }
        DPF_ASSIGN_OR_RETURN(int el_bitsize,
                             BitsNeeded(el, per_element_security_parameter));
        bitsize_other += el_bitsize;
      }
    }
    if (num_ints_mod_n > 0) {
      DPF_ASSIGN_OR_RETURN(
          absl::uint128 modulus,
          ValueIntegerToUint128(int_mod_n->int_mod_n().modulus()));
      DPF_ASSIGN_OR_RETURN(
          int64_t bytes_needed_ints_mod_n,
          dpf_internal::IntModNBase::GetNumBytesRequired(
              num_ints_mod_n, int_mod_n->int_mod_n().base_integer().bitsize(),
              modulus, security_parameter));
      bitsize_ints_mod_n = bytes_needed_ints_mod_n * 8;
    }
    return bitsize_ints_mod_n + bitsize_other;
  } else if (value_type.type_case() == ValueType::kIntModN) {
    DPF_ASSIGN_OR_RETURN(
        absl::uint128 modulus,
        ValueIntegerToUint128(value_type.int_mod_n().modulus()));
    DPF_ASSIGN_OR_RETURN(int64_t bytes_needed_ints_mod_n,
                         dpf_internal::IntModNBase::GetNumBytesRequired(
                             1, value_type.int_mod_n().base_integer().bitsize(),
                             modulus, security_parameter));
    return 8 * bytes_needed_ints_mod_n;
  } else if (value_type.type_case() == ValueType::kXorWrapper) {
    return value_type.xor_wrapper().bitsize();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "BitsNeeded: Unsupported ValueType:\n", value_type.DebugString()));
}
|
|
||||||
|
|
||||||
// Integer Helpers
|
|
||||||
|
|
||||||
// Encodes `input` as a Value::Integer proto. If the upper 64 bits are zero,
// the `value_uint64` field is set. Otherwise the high and low halves are
// stored in `value_uint128`.
Value::Integer Uint128ToValueInteger(absl::uint128 input) {
  const uint64_t high_bits = absl::Uint128High64(input);
  const uint64_t low_bits = absl::Uint128Low64(input);
  Value::Integer encoded;
  if (high_bits != 0) {
    Block* wide = encoded.mutable_value_uint128();
    wide->set_high(high_bits);
    wide->set_low(low_bits);
  } else {
    encoded.set_value_uint64(low_bits);
  }
  return encoded;
}
|
|
||||||
|
|
||||||
// Decodes a Value::Integer proto into an absl::uint128. Both the 128-bit and
// the 64-bit representation are accepted. Returns INVALID_ARGUMENT when
// neither is set.
absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in) {
  switch (in.value_case()) {
    case Value::Integer::kValueUint128: {
      const Block& wide = in.value_uint128();
      return absl::MakeUint128(wide.high(), wide.low());
    }
    case Value::Integer::kValueUint64:
      return in.value_uint64();
    default:
      break;
  }
  return absl::InvalidArgumentError(
      "Unknown value case for the given integer Value");
}
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,673 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <array>
|
|
||||||
#include <limits>
|
|
||||||
#include <string>
|
|
||||||
#include <tuple>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/base/config.h"
|
|
||||||
#include "absl/log/absl_check.h"
|
|
||||||
#include "absl/meta/type_traits.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/str_format.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "absl/utility/utility.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
#include "dpf/tuple.h"
|
|
||||||
#include "dpf/xor_wrapper.h"
|
|
||||||
#include "google/protobuf/repeated_field.h"
|
|
||||||
|
|
||||||
// Contains a collection of helper functions for different DPF value types. This
|
|
||||||
// includes functions for converting between Value protos and the corresponding
|
|
||||||
// C++ objects, as well as functions for sampling values from uniformly random
|
|
||||||
// byte strings.
|
|
||||||
//
|
|
||||||
// This file contains the templated declarations, instantiations for all
|
|
||||||
// supported types, as well as type-independent function declarations.
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// A helper struct containing declarations for all templated functions we need.
|
|
||||||
// This is needed since C++ doesn't support partial function template
|
|
||||||
// specialization, and should be specialized for all supported types.
|
|
||||||
template <typename T, typename = void>
|
|
||||||
struct ValueTypeHelper {
|
|
||||||
// General type traits and conversion functions. Should be implemented by all
|
|
||||||
// types.
|
|
||||||
|
|
||||||
// Type trait for all supported types. Used to provide meaningful error
|
|
||||||
// messages in std::enable_if template guards.
|
|
||||||
static constexpr bool IsSupportedType() { return false; }
|
|
||||||
|
|
||||||
// Checks if the template parameter can be converted directly from a string of
|
|
||||||
// bytes.
|
|
||||||
static constexpr bool CanBeConvertedDirectly();
|
|
||||||
|
|
||||||
// Converts a given Value to the template parameter T.
|
|
||||||
static absl::StatusOr<T> FromValue(const Value& value);
|
|
||||||
|
|
||||||
// ToValue Converts the argument to a Value proto.
|
|
||||||
static Value ToValue(const T& input);
|
|
||||||
|
|
||||||
// ToValueType<T> Returns a `ValueType` message describing T.
|
|
||||||
static ValueType ToValueType();
|
|
||||||
|
|
||||||
// Functions for converting from a byte string to T. There are two approaches:
|
|
||||||
// Either converting directly (i.e., each byte is copied 1-to-1 into the
|
|
||||||
// result), or by sampling (when a direct conversion is not possible). Types
|
|
||||||
// for which CanBeConvertedDirectly() can be true should implement the former,
|
|
||||||
// and all types should implement the latter (to support types composed of
|
|
||||||
// directly-convertible and not-directly-convertible types).
|
|
||||||
|
|
||||||
// Functions for direct conversions from bytes. Should be implemented when
|
|
||||||
// CanBeConvertedDirectly() can be true.
|
|
||||||
|
|
||||||
// Returns the total number of bits in a T.
|
|
||||||
static constexpr int TotalBitSize();
|
|
||||||
|
|
||||||
static T DirectlyFromBytes(absl::string_view bytes);
|
|
||||||
|
|
||||||
// Functions for sampling from a string of bytes. Should be implemented by all
|
|
||||||
// types.
|
|
||||||
|
|
||||||
// Converts `block` to type T. Then, if `update == true`, fills up `block`
|
|
||||||
// from `remaining_bytes` and advances `remaining_bytes` by the amount of
|
|
||||||
// bytes read.
|
|
||||||
static T SampleAndUpdateBytes(bool update, absl::uint128& block,
|
|
||||||
absl::string_view& remaining_bytes);
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// Type traits //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
// Type trait for all supported types. Used to provide meaningful error messages
|
|
||||||
// in std::enable_if template guards.
|
|
||||||
template <typename T>
|
|
||||||
struct is_supported_type {
|
|
||||||
static constexpr bool value =
|
|
||||||
dpf_internal::ValueTypeHelper<T>::IsSupportedType();
|
|
||||||
};
|
|
||||||
template <typename T>
|
|
||||||
constexpr bool is_supported_type_v = is_supported_type<T>::value;
|
|
||||||
|
|
||||||
// Checks if the template parameter can be converted directly from a string of
|
|
||||||
// bytes.
|
|
||||||
template <typename T>
|
|
||||||
struct can_be_converted_directly {
|
|
||||||
static constexpr bool value =
|
|
||||||
dpf_internal::ValueTypeHelper<T>::CanBeConvertedDirectly();
|
|
||||||
};
|
|
||||||
template <typename T>
|
|
||||||
constexpr bool can_be_converted_directly_v =
|
|
||||||
can_be_converted_directly<T>::value;
|
|
||||||
|
|
||||||
// Returns the total number of bits in a T.
|
|
||||||
template <typename T,
|
|
||||||
typename = absl::enable_if_t<can_be_converted_directly_v<T>>>
|
|
||||||
static constexpr int TotalBitSize() {
|
|
||||||
return ValueTypeHelper<T>::TotalBitSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// Integer Helpers //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
// Type trait for all integer types we support, i.e., 8 to 128 bit types.
|
|
||||||
template <typename T>
|
|
||||||
using is_unsigned_integer =
|
|
||||||
absl::disjunction<std::is_same<T, uint8_t>, std::is_same<T, uint16_t>,
|
|
||||||
std::is_same<T, uint32_t>, std::is_same<T, uint64_t>,
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
std::is_same<T, unsigned __int128>,
|
|
||||||
#endif
|
|
||||||
std::is_same<T, absl::uint128>>;
|
|
||||||
template <typename T>
|
|
||||||
constexpr bool is_unsigned_integer_v = is_unsigned_integer<T>::value;
|
|
||||||
|
|
||||||
// Converts the given Value::Integer to an absl::uint128. Used as a helper
|
|
||||||
// function in `ConvertValueTo` and `ValueTypesAreEqual`.
|
|
||||||
//
|
|
||||||
// Returns INVALID_ARGUMENT if `in` is not a simple integer or IntModN.
|
|
||||||
absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in);
|
|
||||||
|
|
||||||
// Converts an absl::uint128 to a Value::Integer. Used as a helper function in
|
|
||||||
// ToValue.
|
|
||||||
Value::Integer Uint128ToValueInteger(absl::uint128 input);
|
|
||||||
|
|
||||||
// Checks if the given value is in range of T, and if so, returns it converted
|
|
||||||
// to T.
|
|
||||||
//
|
|
||||||
// Otherwise returns INVALID_ARGUMENT.
|
|
||||||
template <typename T, typename = absl::enable_if_t<is_unsigned_integer_v<T>>>
|
|
||||||
absl::StatusOr<T> Uint128To(absl::uint128 in) {
|
|
||||||
// Check whether value is in range if it's smaller than 128 bits.
|
|
||||||
if (!std::is_same<T, absl::uint128>::value &&
|
|
||||||
absl::Uint128Low64(in) >
|
|
||||||
static_cast<uint64_t>(std::numeric_limits<T>::max())) {
|
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
|
||||||
"Value (= ", absl::Uint128Low64(in),
|
|
||||||
") too large for the given type T (size ", sizeof(T), ")"));
|
|
||||||
}
|
|
||||||
return static_cast<T>(in);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implementation of ValueTypeHelper for integers.
|
|
||||||
template <typename T>
|
|
||||||
struct ValueTypeHelper<T, absl::enable_if_t<is_unsigned_integer_v<T>>> {
|
|
||||||
static constexpr bool IsSupportedType() { return true; }
|
|
||||||
|
|
||||||
static constexpr bool CanBeConvertedDirectly() { return true; }
|
|
||||||
|
|
||||||
static absl::StatusOr<T> FromValue(const Value& value) {
|
|
||||||
if (value.value_case() != Value::kInteger) {
|
|
||||||
return absl::InvalidArgumentError("The given Value is not an integer");
|
|
||||||
}
|
|
||||||
// We first parse the value into an absl::uint128, then check its range if
|
|
||||||
// it is supposed to be smaller than 128 bits.
|
|
||||||
absl::StatusOr<absl::uint128> value_128 =
|
|
||||||
ValueIntegerToUint128(value.integer());
|
|
||||||
if (!value_128.ok()) {
|
|
||||||
return value_128.status();
|
|
||||||
}
|
|
||||||
return Uint128To<T>(*value_128);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value ToValue(T input) {
|
|
||||||
Value result;
|
|
||||||
*(result.mutable_integer()) = Uint128ToValueInteger(input);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ValueType ToValueType() {
|
|
||||||
ValueType result;
|
|
||||||
result.mutable_integer()->set_bitsize(8 * sizeof(T));
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr int TotalBitSize() { return sizeof(T) * 8; }
|
|
||||||
|
|
||||||
static T DirectlyFromBytes(absl::string_view bytes) {
|
|
||||||
ABSL_CHECK(bytes.size() == sizeof(T));
|
|
||||||
T out{0};
|
|
||||||
#ifdef ABSL_IS_LITTLE_ENDIAN
|
|
||||||
std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out));
|
|
||||||
#else
|
|
||||||
for (int i = sizeof(T) - 1; i >= 0; --i) {
|
|
||||||
out |= absl::bit_cast<uint8_t>(bytes[i]);
|
|
||||||
out <<= 8;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the low sizeof(T) bytes of `block` as a T. When `update` is true,
// those low bytes of `block` are afterwards replaced by the next sizeof(T)
// bytes of `remaining_bytes`, and `remaining_bytes` is advanced past them.
static T SampleAndUpdateBytes(bool update, absl::uint128& block,
                              absl::string_view& remaining_bytes) {
  T sample = static_cast<T>(block);
  if (!update) {
    return sample;
  }

  // Clear the sizeof(T) low bytes that were just sampled. If T spans the whole
  // block, nothing survives and the block is simply reset.
  if (sizeof(T) >= sizeof(block)) {
    block = 0;
  } else {
    constexpr absl::uint128 kHighBytesMask =
        ~absl::uint128{std::numeric_limits<T>::max()};
    block &= kHighBytesMask;
  }

  // Refill the cleared low bytes from the input, then advance the input. The
  // conversion runs first so its size CHECK fires before the view is advanced.
  ABSL_DCHECK(remaining_bytes.size() >= sizeof(T));
  block |= DirectlyFromBytes(remaining_bytes.substr(0, sizeof(T)));
  remaining_bytes.remove_prefix(sizeof(T));
  return sample;
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// IntModN Helpers //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
|
|
||||||
struct ValueTypeHelper<
|
|
||||||
dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>, void> {
|
|
||||||
using IntModNType =
|
|
||||||
dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>;
|
|
||||||
|
|
||||||
static constexpr bool IsSupportedType() {
|
|
||||||
return is_unsigned_integer_v<BaseInteger> &&
|
|
||||||
is_unsigned_integer_v<ModulusType>;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr bool CanBeConvertedDirectly() { return false; }
|
|
||||||
|
|
||||||
static absl::StatusOr<IntModNType> FromValue(const Value& value) {
|
|
||||||
if (value.value_case() != Value::kIntModN) {
|
|
||||||
return absl::InvalidArgumentError("The given Value is not an IntModN");
|
|
||||||
}
|
|
||||||
absl::StatusOr<absl::uint128> value_128 =
|
|
||||||
ValueIntegerToUint128(value.int_mod_n());
|
|
||||||
if (!value_128.ok()) {
|
|
||||||
return value_128.status();
|
|
||||||
}
|
|
||||||
if (*value_128 >= absl::uint128{kModulus}) {
|
|
||||||
return absl::InvalidArgumentError(absl::StrFormat(
|
|
||||||
"The given value (= %d) is larger than kModulus (= %d)", *value_128,
|
|
||||||
absl::uint128{kModulus}));
|
|
||||||
}
|
|
||||||
return IntModNType(static_cast<BaseInteger>(*value_128));
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value ToValue(IntModNType input) {
|
|
||||||
Value result;
|
|
||||||
*(result.mutable_int_mod_n()) = Uint128ToValueInteger(input.value());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ValueType ToValueType() {
|
|
||||||
ValueType result;
|
|
||||||
*(result.mutable_int_mod_n()->mutable_base_integer()) =
|
|
||||||
ValueTypeHelper<BaseInteger>::ToValueType().integer();
|
|
||||||
*(result.mutable_int_mod_n()->mutable_modulus()) =
|
|
||||||
ValueTypeHelper<ModulusType>::ToValue(kModulus).integer();
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static IntModNType SampleAndUpdateBytes(bool update, absl::uint128& block,
|
|
||||||
absl::string_view& remaining_bytes) {
|
|
||||||
// Optimization for native uint128. This is equivalent to what's done in
|
|
||||||
// int128.cc, but since division is not defined in the header, the compiler
|
|
||||||
// cannot optimize the division and modulus into a single operation.
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
absl::uint128 quotient = static_cast<unsigned __int128>(block) / kModulus,
|
|
||||||
remainder = static_cast<unsigned __int128>(block) % kModulus;
|
|
||||||
#else
|
|
||||||
absl::uint128 quotient = block / kModulus, remainder = block % kModulus;
|
|
||||||
#endif
|
|
||||||
IntModNType result(static_cast<BaseInteger>(remainder));
|
|
||||||
|
|
||||||
if (update) {
|
|
||||||
if (sizeof(BaseInteger) < sizeof(block)) {
|
|
||||||
block = quotient << (sizeof(BaseInteger) * 8);
|
|
||||||
} else {
|
|
||||||
block = 0;
|
|
||||||
}
|
|
||||||
block |= ValueTypeHelper<BaseInteger>::DirectlyFromBytes(
|
|
||||||
remaining_bytes.substr(0, sizeof(BaseInteger)));
|
|
||||||
remaining_bytes = remaining_bytes.substr(sizeof(BaseInteger));
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// Tuple Helpers //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
// Helper struct for computing the bit size of a tuple type at compile time
|
|
||||||
// without C++17 fold expressions.
|
|
||||||
template <typename FirstElementType, typename... ElementType>
|
|
||||||
struct TupleBitSizeHelper {
|
|
||||||
static constexpr int TotalBitSize() {
|
|
||||||
return TupleBitSizeHelper<FirstElementType>::TotalBitSize() +
|
|
||||||
TupleBitSizeHelper<ElementType...>::TotalBitSize();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
template <typename ElementType>
|
|
||||||
struct TupleBitSizeHelper<ElementType> {
|
|
||||||
static constexpr int TotalBitSize() {
|
|
||||||
return ValueTypeHelper<ElementType>::TotalBitSize();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename... ElementType>
|
|
||||||
struct ValueTypeHelper<Tuple<ElementType...>, void> {
|
|
||||||
using TupleType = Tuple<ElementType...>;
|
|
||||||
|
|
||||||
static constexpr bool IsSupportedType() {
|
|
||||||
return absl::conjunction<is_supported_type<ElementType>...>::value;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr bool CanBeConvertedDirectly() {
|
|
||||||
return absl::conjunction<can_be_converted_directly<ElementType>...>::value;
|
|
||||||
}
|
|
||||||
|
|
||||||
static absl::StatusOr<TupleType> FromValue(const Value& value) {
|
|
||||||
if (value.value_case() != Value::kTuple) {
|
|
||||||
return absl::InvalidArgumentError("The given Value is not a tuple");
|
|
||||||
}
|
|
||||||
constexpr auto tuple_size =
|
|
||||||
static_cast<int>(std::tuple_size<typename TupleType::Base>());
|
|
||||||
if (value.tuple().elements_size() != tuple_size) {
|
|
||||||
return absl::InvalidArgumentError(
|
|
||||||
"The tuple in the given Value has the wrong number of elements");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a Tuple by unpacking value.tuple().elements(). If we encounter an
|
|
||||||
// error, return it at the end.
|
|
||||||
absl::Status status = absl::OkStatus();
|
|
||||||
int element_index = 0;
|
|
||||||
// The braced initializer list ensures elements are created in the correct
|
|
||||||
// order (unlike std::make_tuple).
|
|
||||||
TupleType result = {[&value, &status, &element_index] {
|
|
||||||
if (status.ok()) {
|
|
||||||
absl::StatusOr<ElementType> element =
|
|
||||||
ValueTypeHelper<ElementType>::FromValue(
|
|
||||||
value.tuple().elements(element_index));
|
|
||||||
element_index++;
|
|
||||||
if (element.ok()) {
|
|
||||||
return *element;
|
|
||||||
} else {
|
|
||||||
status = element.status();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ElementType{};
|
|
||||||
}()...};
|
|
||||||
if (status.ok()) {
|
|
||||||
return result;
|
|
||||||
} else {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value ToValue(const TupleType& input) {
|
|
||||||
Value result;
|
|
||||||
absl::apply(
|
|
||||||
[&result](const ElementType&... elements) {
|
|
||||||
// Create an unused std::tuple to iterate over `elements` in its
|
|
||||||
// constructor. This can be replaced by a fold expression in C++17.
|
|
||||||
std::tuple<ElementType...>{
|
|
||||||
(*(result.mutable_tuple()->add_elements()) =
|
|
||||||
ValueTypeHelper<ElementType>::ToValue(elements),
|
|
||||||
ElementType{})...};
|
|
||||||
},
|
|
||||||
input.value());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ValueType ToValueType() {
|
|
||||||
ValueType result;
|
|
||||||
ValueType::Tuple* tuple = result.mutable_tuple();
|
|
||||||
// Create an unused std::tuple to iterate over `elements` in its
|
|
||||||
// constructor. This can be replaced by a fold expression in C++17.
|
|
||||||
std::tuple<ElementType...>{
|
|
||||||
(*(tuple->add_elements()) = ValueTypeHelper<ElementType>::ToValueType(),
|
|
||||||
ElementType{})...};
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr int TotalBitSize() {
|
|
||||||
// This helper can be replaced by a fold expression in C++17.
|
|
||||||
return TupleBitSizeHelper<ElementType...>::TotalBitSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
static TupleType DirectlyFromBytes(absl::string_view bytes) {
|
|
||||||
ABSL_CHECK(8 * bytes.size() >= TotalBitSize());
|
|
||||||
int offset = 0;
|
|
||||||
absl::Status status = absl::OkStatus();
|
|
||||||
// Braced-init-list ensures the elements are constructed in-order.
|
|
||||||
return TupleType{[&bytes, &offset, &status] {
|
|
||||||
constexpr int element_size_bytes =
|
|
||||||
(ValueTypeHelper<ElementType>::TotalBitSize() + 7) / 8;
|
|
||||||
ElementType element = ValueTypeHelper<ElementType>::DirectlyFromBytes(
|
|
||||||
bytes.substr(offset, element_size_bytes));
|
|
||||||
offset += element_size_bytes;
|
|
||||||
return element;
|
|
||||||
}()...};
|
|
||||||
}
|
|
||||||
|
|
||||||
static TupleType SampleAndUpdateBytes(bool update, absl::uint128& block,
|
|
||||||
absl::string_view& remaining_bytes) {
|
|
||||||
int element_counter = 0;
|
|
||||||
// Braced-init-list ensures the elements are constructed in-order.
|
|
||||||
return TupleType{[update, &element_counter, &block,
|
|
||||||
&remaining_bytes]() -> ElementType {
|
|
||||||
// If `update` is true, update after all elements. Otherwise, don't update
|
|
||||||
// after the last one.
|
|
||||||
constexpr int num_elements = std::tuple_size<typename TupleType::Base>();
|
|
||||||
bool update2 = update || (++element_counter < num_elements);
|
|
||||||
return ValueTypeHelper<ElementType>::SampleAndUpdateBytes(
|
|
||||||
update2, block, remaining_bytes);
|
|
||||||
}()...};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// XorWrapper Helpers //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct ValueTypeHelper<XorWrapper<T>, void> {
|
|
||||||
static constexpr bool IsSupportedType() {
|
|
||||||
return ValueTypeHelper<T>::IsSupportedType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr bool CanBeConvertedDirectly() {
|
|
||||||
return ValueTypeHelper<T>::CanBeConvertedDirectly();
|
|
||||||
}
|
|
||||||
|
|
||||||
static absl::StatusOr<XorWrapper<T>> FromValue(const Value& value) {
|
|
||||||
absl::StatusOr<absl::uint128> wrapped128 =
|
|
||||||
ValueIntegerToUint128(value.xor_wrapper());
|
|
||||||
if (!wrapped128.ok()) {
|
|
||||||
return wrapped128.status();
|
|
||||||
}
|
|
||||||
absl::StatusOr<T> wrapped = Uint128To<T>(*wrapped128);
|
|
||||||
if (!wrapped.ok()) {
|
|
||||||
return wrapped.status();
|
|
||||||
}
|
|
||||||
return XorWrapper<T>(*wrapped);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value ToValue(const XorWrapper<T>& input) {
|
|
||||||
Value result;
|
|
||||||
*(result.mutable_xor_wrapper()) = Uint128ToValueInteger(input.value());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ValueType ToValueType() {
|
|
||||||
ValueType result;
|
|
||||||
*(result.mutable_xor_wrapper()) =
|
|
||||||
ValueTypeHelper<T>::ToValueType().integer();
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr int TotalBitSize() {
|
|
||||||
return ValueTypeHelper<T>::TotalBitSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
static XorWrapper<T> DirectlyFromBytes(absl::string_view bytes) {
|
|
||||||
return XorWrapper<T>(ValueTypeHelper<T>::DirectlyFromBytes(bytes));
|
|
||||||
}
|
|
||||||
|
|
||||||
static XorWrapper<T> SampleAndUpdateBytes(
|
|
||||||
bool update, absl::uint128& block, absl::string_view& remaining_bytes) {
|
|
||||||
return XorWrapper<T>(ValueTypeHelper<T>::SampleAndUpdateBytes(
|
|
||||||
update, block, remaining_bytes));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/******************************************************************************/
|
|
||||||
// Free standing helpers. These should always come last. When adding //
|
|
||||||
// additional types, add them above. //
|
|
||||||
/******************************************************************************/
|
|
||||||
|
|
||||||
// Computes the number of values of type T that fit into an absl::uint128.
|
|
||||||
// Returns a value >= 1 if batching is supported, and 1 otherwise.
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
constexpr int ElementsPerBlock() {
|
|
||||||
if (TotalBitSize<T>() <= 128) {
|
|
||||||
return static_cast<int>(8 * sizeof(absl::uint128)) / TotalBitSize<T>();
|
|
||||||
}
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
constexpr int ElementsPerBlock() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a value of type T from the given `bytes`. If possible, converts bytes
|
|
||||||
// directly using DirectlyFromBytes. Otherwise, uses SampleAndUpdateBytes.
|
|
||||||
//
|
|
||||||
// Crashes if `bytes.size()` is too small for the output type.
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
T FromBytes(absl::string_view bytes) {
|
|
||||||
return ValueTypeHelper<T>::DirectlyFromBytes(bytes);
|
|
||||||
}
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
T FromBytes(absl::string_view bytes) {
|
|
||||||
absl::uint128 block =
|
|
||||||
FromBytes<absl::uint128>(bytes.substr(0, sizeof(absl::uint128)));
|
|
||||||
bytes = bytes.substr(sizeof(absl::uint128));
|
|
||||||
return ValueTypeHelper<T>::SampleAndUpdateBytes(false, block, bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts a `repeated Value` proto field to a std::array with element type T.
|
|
||||||
//
|
|
||||||
// Returns INVALID_ARGUMENT in case the input has the wrong size, or if the
|
|
||||||
// conversion fails.
|
|
||||||
template <typename T>
|
|
||||||
absl::StatusOr<std::array<T, ElementsPerBlock<T>()>> ValuesToArray(
|
|
||||||
const ::google::protobuf::RepeatedPtrField<Value>& values) {
|
|
||||||
if (values.size() != ElementsPerBlock<T>()) {
|
|
||||||
return absl::InvalidArgumentError(absl::StrCat(
|
|
||||||
"values.size() (= ", values.size(),
|
|
||||||
") does not match ElementsPerBlock<T>() (= ", ElementsPerBlock<T>(),
|
|
||||||
")"));
|
|
||||||
}
|
|
||||||
std::array<T, ElementsPerBlock<T>()> result;
|
|
||||||
for (int i = 0; i < ElementsPerBlock<T>(); ++i) {
|
|
||||||
absl::StatusOr<T> element = ValueTypeHelper<T>::FromValue(values[i]);
|
|
||||||
if (element.ok()) {
|
|
||||||
result[i] = std::move(*element);
|
|
||||||
} else {
|
|
||||||
return element.status();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts a given string to an array of exactly ElementsPerBlock<T>() elements
|
|
||||||
// of type T.
|
|
||||||
//
|
|
||||||
// Crashes if `bytes.size()` is too small for the output type.
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf(
|
|
||||||
absl::string_view bytes) {
|
|
||||||
std::array<T, ElementsPerBlock<T>()> out;
|
|
||||||
const int element_size_bytes = (TotalBitSize<T>() + 7) / 8;
|
|
||||||
ABSL_CHECK(bytes.size() >= ElementsPerBlock<T>() * element_size_bytes);
|
|
||||||
for (int i = 0; i < ElementsPerBlock<T>(); ++i) {
|
|
||||||
out[i] =
|
|
||||||
FromBytes<T>(bytes.substr(i * element_size_bytes, element_size_bytes));
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
template <typename T,
|
|
||||||
absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
|
|
||||||
std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf(
|
|
||||||
absl::string_view bytes) {
|
|
||||||
static_assert(ElementsPerBlock<T>() == 1,
|
|
||||||
"T does not support batching, but ElementsPerBlock<T> != 1");
|
|
||||||
return {FromBytes<T>(bytes)};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes the value correction word given two seeds `seed_a`, `seed_b` for
|
|
||||||
// parties a and b, such that the element at `block_index` is equal to `beta`.
|
|
||||||
// If `invert` is true, the result is multiplied element-wise by -1. Templated
|
|
||||||
// to use the correct integer type without needing modular reduction.
|
|
||||||
//
|
|
||||||
// Returns multiple values in case of packing, and a single value otherwise.
|
|
||||||
template <typename T>
|
|
||||||
absl::StatusOr<std::vector<Value>> ComputeValueCorrectionFor(
|
|
||||||
absl::string_view seed_a, absl::string_view seed_b, int block_index,
|
|
||||||
const Value& beta, bool invert) {
|
|
||||||
absl::StatusOr<T> beta_T = ValueTypeHelper<T>::FromValue(beta);
|
|
||||||
if (!beta_T.ok()) {
|
|
||||||
return beta_T.status();
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int elements_per_block = ElementsPerBlock<T>();
|
|
||||||
|
|
||||||
// Compute values from seeds. Both arrays will have multiple elements if T
|
|
||||||
// supports batching, and a single one otherwise.
|
|
||||||
std::array<T, elements_per_block> ints_a = ConvertBytesToArrayOf<T>(seed_a),
|
|
||||||
ints_b = ConvertBytesToArrayOf<T>(seed_b);
|
|
||||||
|
|
||||||
// Add beta to the right position.
|
|
||||||
ints_b[block_index] += *beta_T;
|
|
||||||
|
|
||||||
// Add up shares, invert if needed.
|
|
||||||
for (int i = 0; i < elements_per_block; i++) {
|
|
||||||
ints_b[i] = ints_b[i] - ints_a[i];
|
|
||||||
if (invert) {
|
|
||||||
ints_b[i] = -ints_b[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to a vector of Value protos and return.
|
|
||||||
std::vector<Value> result;
|
|
||||||
result.reserve(ints_b.size());
|
|
||||||
for (const T& element : ints_b) {
|
|
||||||
result.push_back(ValueTypeHelper<T>::ToValue(element));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes the number of pseudorandom bits needed to get a uniform element of
// the given `ValueType`.
//
// For types whose elements can be bijectively mapped to strings (e.g.,
// unsigned integers and tuples of integers), this is equivalent to the bit
// size of the value type. For all other types, returns the number of bits
// needed so that converting a uniform string with the given number of bits to
// an element of `value_type` results in a distribution with total variation
// distance < 2^(-`security_parameter`) from uniform.
//
// Returns INVALID_ARGUMENT in case value_type does not represent a known type
// (e.g. a default-constructed ValueType), or if sampling with the required
// security parameter is not possible.
absl::StatusOr<int> BitsNeeded(const ValueType& value_type,
                               double security_parameter);
|
|
||||||
|
|
||||||
// Returns `true` if `lhs` and `rhs` describe the same type, and `false`
// otherwise. The result does not depend on argument order.
//
// Returns INVALID_ARGUMENT if an error occurs while parsing either argument
// (e.g. when one of them is a default-constructed ValueType).
absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs,
                                        const ValueType& rhs);
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_
|
|
@ -1,359 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/internal/value_type_helpers.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <string>
|
|
||||||
#include <tuple>
|
|
||||||
|
|
||||||
#include "absl/base/config.h"
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "absl/status/status.h"
|
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "dpf/distributed_point_function.pb.h"
|
|
||||||
#include "dpf/int_mod_n.h"
|
|
||||||
#include "dpf/internal/status_matchers.h"
|
|
||||||
#include "dpf/tuple.h"
|
|
||||||
#include "gmock/gmock.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
namespace dpf_internal {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
// Security parameter passed to BitsNeeded() throughout these tests.
constexpr int kDefaultSecurityParameter{40};
|
|
||||||
|
|
||||||
TEST(ValueTypeHelperTest, ValueTypesAreEqualFailsOnInvalidValueTypes) {
  // Default-constructed ValueTypes describe no type at all.
  const ValueType empty_lhs;
  const ValueType empty_rhs;

  EXPECT_THAT(ValueTypesAreEqual(empty_lhs, empty_rhs),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Both arguments must be valid ValueTypes"));
}
|
|
||||||
|
|
||||||
TEST(ValueTypeHelperTest, BitsNeededFailsOnInvalidValueType) {
  const ValueType unset_type;
  auto expected_error =
      StatusIs(absl::StatusCode::kInvalidArgument,
               testing::StartsWith("BitsNeeded: Unsupported ValueType"));

  EXPECT_THAT(BitsNeeded(unset_type, kDefaultSecurityParameter),
              expected_error);
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class ValueTypeIntegerTest : public testing::Test {};
|
|
||||||
using IntegerTypes =
|
|
||||||
::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>;
|
|
||||||
TYPED_TEST_SUITE(ValueTypeIntegerTest, IntegerTypes);
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, ToValueTypeIntegers) {
  const auto expected_bits = sizeof(TypeParam) * 8;
  const ValueType value_type = ValueTypeHelper<TypeParam>::ToValueType();

  EXPECT_TRUE(value_type.has_integer());
  EXPECT_EQ(value_type.integer().bitsize(), expected_bits);
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreEqual) {
  // The helper's description must match a hand-built one of the same width.
  const ValueType from_helper = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType handwritten;
  handwritten.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8);

  // Equality is checked in both argument orders.
  DPF_ASSERT_OK_AND_ASSIGN(bool equal,
                           ValueTypesAreEqual(from_helper, handwritten));
  EXPECT_TRUE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal, ValueTypesAreEqual(handwritten, from_helper));
  EXPECT_TRUE(equal);
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreNotEqual) {
  // An integer type of twice the width must compare unequal.
  const ValueType from_helper = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType twice_as_wide;
  twice_as_wide.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8 * 2);

  // Inequality is checked in both argument orders.
  DPF_ASSERT_OK_AND_ASSIGN(bool equal,
                           ValueTypesAreEqual(from_helper, twice_as_wide));
  EXPECT_FALSE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal,
                           ValueTypesAreEqual(twice_as_wide, from_helper));
  EXPECT_FALSE(equal);
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfNotInteger) {
  Value tuple_value;
  tuple_value.mutable_tuple();  // Selects the tuple case, not the integer one.

  auto expected_error = StatusIs(absl::StatusCode::kInvalidArgument,
                                 "The given Value is not an integer");
  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(tuple_value),
              expected_error);
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfInvalidIntegerCase) {
  // Select the integer case but leave the integer itself without a value.
  Value empty_integer;
  empty_integer.mutable_integer();

  auto expected_error =
      StatusIs(absl::StatusCode::kInvalidArgument,
               "Unknown value case for the given integer Value");
  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(empty_integer),
              expected_error);
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfValueOutOfRange) {
  // 2^32 fits only into types of at least 64 bits.
  const uint64_t two_to_the_32 = uint64_t{1} << 32;
  Value value;
  value.mutable_integer()->set_value_uint64(two_to_the_32);

  const bool wide_enough = sizeof(TypeParam) >= sizeof(uint64_t);
  if (!wide_enough) {
    EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value),
                StatusIs(absl::StatusCode::kInvalidArgument,
                         absl::StrCat("Value (= ", two_to_the_32,
                                      ") too large for the given type T (size ",
                                      sizeof(TypeParam), ")")));
    return;
  }
  DPF_EXPECT_OK(ValueTypeHelper<TypeParam>::FromValue(value));
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class ValueTypeTupleTest : public testing::Test {};
|
|
||||||
|
|
||||||
template <typename T, int... bits>
|
|
||||||
struct TupleTestParam {
|
|
||||||
using Tuple = T;
|
|
||||||
static constexpr int ExpectedNumElements() { return sizeof...(bits); };
|
|
||||||
static constexpr std::array<int, ExpectedNumElements()> ExpectedBitSizes() {
|
|
||||||
return {bits...};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// We only test tuples consisting of integers here.
|
|
||||||
using TupleTypes = ::testing::Types<
|
|
||||||
TupleTestParam<Tuple<uint64_t>, 64>,
|
|
||||||
TupleTestParam<Tuple<uint64_t, uint64_t>, 64, 64>,
|
|
||||||
TupleTestParam<Tuple<uint32_t, absl::uint128, uint8_t>, 32, 128, 8>,
|
|
||||||
TupleTestParam<Tuple<uint8_t, uint8_t, uint8_t, uint8_t>, 8, 8, 8, 8>>;
|
|
||||||
TYPED_TEST_SUITE(ValueTypeTupleTest, TupleTypes);
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeTupleTest, ToValueTypeTuples) {
  using TupleType = typename TypeParam::Tuple;
  constexpr int kNumElements = TypeParam::ExpectedNumElements();
  const ValueType value_type = ValueTypeHelper<TupleType>::ToValueType();

  EXPECT_TRUE(value_type.has_tuple());
  // Guards against a mis-specified TupleTestParam rather than the helper.
  ASSERT_EQ(std::tuple_size<typename TupleType::Base>(), kNumElements);
  EXPECT_EQ(value_type.tuple().elements_size(), kNumElements);

  const std::array<int, kNumElements> bit_sizes =
      TypeParam::ExpectedBitSizes();
  for (int i = 0; i < kNumElements; ++i) {
    const ValueType& element = value_type.tuple().elements(i);
    EXPECT_TRUE(element.has_integer());
    EXPECT_EQ(element.integer().bitsize(), bit_sizes[i]);
  }
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeTupleTest, BitsNeededEqualsCompileTimeTypeSize) {
  // For integer tuples, the runtime bit count must equal the static bit size.
  using TupleType = typename TypeParam::Tuple;
  const ValueType tuple_type = ValueTypeHelper<TupleType>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(int bits_needed,
                           BitsNeeded(tuple_type, kDefaultSecurityParameter));
  EXPECT_EQ(bits_needed, TotalBitSize<TupleType>());
}
|
|
||||||
|
|
||||||
TYPED_TEST(ValueTypeTupleTest, ValueConversionFailsIfValueIsNotATuple) {
  // An integer Value must be rejected by every parameterized tuple type. The
  // previous version hard-coded Tuple<uint32_t> and so ignored TypeParam.
  Value value;
  value.mutable_integer();

  EXPECT_THAT(ValueTypeHelper<typename TypeParam::Tuple>::FromValue(value),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "The given Value is not a tuple"));
}
|
|
||||||
|
|
||||||
TEST(ValueTypeTupleTest, ValueConversionFailsIfValueSizeDoesntMatchTupleSize) {
  // A one-element tuple Value cannot be parsed as a two-element Tuple.
  using TupleType = Tuple<uint32_t, uint32_t>;
  Value one_element;
  auto* element = one_element.mutable_tuple()->add_elements();
  element->mutable_integer()->set_value_uint64(1234);

  auto expected_error = StatusIs(
      absl::StatusCode::kInvalidArgument,
      "The tuple in the given Value has the wrong number of elements");
  EXPECT_THAT(ValueTypeHelper<TupleType>::FromValue(one_element),
              expected_error);
}
|
|
||||||
|
|
||||||
TEST(ValueTypeTupleTest, TestValueTypesAreEqual) {
  // Two separately spelled but identical tuple types must compare equal.
  using Lhs = Tuple<uint32_t, absl::uint128, uint8_t>;
  using Rhs = Tuple<uint32_t, absl::uint128, uint8_t>;
  const ValueType lhs_type = ValueTypeHelper<Lhs>::ToValueType();
  const ValueType rhs_type = ValueTypeHelper<Rhs>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(bool equal, ValueTypesAreEqual(lhs_type, rhs_type));
  EXPECT_TRUE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal, ValueTypesAreEqual(rhs_type, lhs_type));
  EXPECT_TRUE(equal);
}
|
|
||||||
|
|
||||||
TEST(ValueTypeTupleTest, TestValueTypesAreNotEqual) {
  // Tuples differing only in the width of their last element are unequal.
  using Lhs = Tuple<uint32_t, absl::uint128, uint8_t>;
  using Rhs = Tuple<uint32_t, absl::uint128, uint16_t>;
  const ValueType lhs_type = ValueTypeHelper<Lhs>::ToValueType();
  const ValueType rhs_type = ValueTypeHelper<Rhs>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(bool equal, ValueTypesAreEqual(lhs_type, rhs_type));
  EXPECT_FALSE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal, ValueTypesAreEqual(rhs_type, lhs_type));
  EXPECT_FALSE(equal);
}
|
|
||||||
|
|
||||||
TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExample) {
  // 16 bytes split into two consecutive 8-byte chunks, one per element.
  std::string input = "A 128 bit string";
  auto parsed = FromBytes<Tuple<uint64_t, uint64_t>>(input);

  uint64_t expected_first = FromBytes<uint64_t>("A 128 bi");
  uint64_t expected_second = FromBytes<uint64_t>("t string");
  EXPECT_EQ(std::get<0>(parsed.value()), expected_first);
  EXPECT_EQ(std::get<1>(parsed.value()), expected_second);
}
|
|
||||||
|
|
||||||
TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExampleForIntModN) {
  constexpr uint32_t kModulus = 4294967291u;
  using MyIntModN = IntModN<uint32_t, kModulus>;
  std::string input = "A 128+32 bit string.";

  // Reproduce the sampling by hand: the first element is the 128-bit seed
  // reduced mod kModulus; the seed's quotient is then shifted up by 32 bits and
  // refilled with the next 4 bytes before the second reduction.
  absl::uint128 state = FromBytes<absl::uint128>("A 128+32 bit str");
  MyIntModN expected_first(static_cast<uint32_t>(state % kModulus));
  state = (state / kModulus) << (8 * sizeof(uint32_t));
  state |= FromBytes<uint32_t>("ing.");
  MyIntModN expected_second(static_cast<uint32_t>(state % kModulus));

  auto sampled = FromBytes<Tuple<MyIntModN, MyIntModN>>(input).value();
  EXPECT_EQ(std::get<0>(sampled), expected_first);
  EXPECT_EQ(std::get<1>(sampled), expected_second);
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class ValueTypeIntModNTest : public testing::Test {};
|
|
||||||
using IntModNTypes = ::testing::Types<
|
|
||||||
IntModN<uint32_t, 4>, IntModN<uint32_t, 4294967291u>,
|
|
||||||
IntModN<uint64_t, 4294967291ull>, IntModN<uint64_t, 1000000000000ull>
|
|
||||||
#ifdef ABSL_HAVE_INTRINSIC_INT128
|
|
||||||
,
|
|
||||||
IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128(
|
|
||||||
65535u, 18446744073709551551ull))> // 2**80-65
|
|
||||||
#endif
|
|
||||||
>;
|
|
||||||
TYPED_TEST_SUITE(ValueTypeIntModNTest, IntModNTypes);
|
|
||||||
|
|
||||||
// ToValueType() must describe an IntModN whose base-integer bit width and
// modulus match the C++ type it was generated from.
TYPED_TEST(ValueTypeIntModNTest, ToValueType) {
  using BaseInteger = typename TypeParam::Base;
  ValueType described = ValueTypeHelper<TypeParam>::ToValueType();

  EXPECT_TRUE(described.type_case() == ValueType::kIntModN);
  EXPECT_EQ(described.int_mod_n().base_integer().bitsize(),
            sizeof(BaseInteger) * 8);

  DPF_ASSERT_OK_AND_ASSIGN(
      absl::uint128 modulus,
      ValueIntegerToUint128(described.int_mod_n().modulus()));
  EXPECT_EQ(modulus, absl::uint128{TypeParam::modulus()});
}
|
|
||||||
|
|
||||||
// A ValueType assembled by hand from TypeParam's base-integer width and
// modulus must compare equal, in both argument orders, to the one produced by
// ValueTypeHelper<TypeParam>::ToValueType().
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreEqual) {
  ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(),
            value_type_2;

  // The bitsize describes the *base integer*, which is exactly what the
  // ToValueType test checks against; use its size rather than
  // sizeof(TypeParam), the size of the IntModN wrapper object.
  value_type_2.mutable_int_mod_n()->mutable_base_integer()->set_bitsize(
      sizeof(typename TypeParam::Base) * 8);
  *(value_type_2.mutable_int_mod_n()->mutable_modulus()) =
      Uint128ToValueInteger(TypeParam::modulus());

  DPF_ASSERT_OK_AND_ASSIGN(bool equal,
                           ValueTypesAreEqual(value_type_1, value_type_2));
  EXPECT_TRUE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal,
                           ValueTypesAreEqual(value_type_2, value_type_1));
  EXPECT_TRUE(equal);
}
|
|
||||||
|
|
||||||
// Changing only the base-integer width (to twice TypeParam's base width) must
// make the two value types compare unequal, in both argument orders.
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentBase) {
  ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(),
            value_type_2 = value_type_1;

  // Derive the width from the base integer (what ToValueType() reports), not
  // from sizeof(TypeParam), the size of the IntModN wrapper object.
  value_type_2.mutable_int_mod_n()->mutable_base_integer()->set_bitsize(
      sizeof(typename TypeParam::Base) * 8 * 2);

  DPF_ASSERT_OK_AND_ASSIGN(bool equal,
                           ValueTypesAreEqual(value_type_1, value_type_2));
  EXPECT_FALSE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal,
                           ValueTypesAreEqual(value_type_2, value_type_1));
  EXPECT_FALSE(equal);
}
|
|
||||||
|
|
||||||
// Changing only the modulus (to modulus - 1) must make the two value types
// compare unequal, in both argument orders.
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentModulus) {
  ValueType reference = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType altered = reference;
  *(altered.mutable_int_mod_n()->mutable_modulus()) =
      Uint128ToValueInteger(TypeParam::modulus() - 1);

  DPF_ASSERT_OK_AND_ASSIGN(bool equal, ValueTypesAreEqual(reference, altered));
  EXPECT_FALSE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal, ValueTypesAreEqual(altered, reference));
  EXPECT_FALSE(equal);
}

// Comparing against a ValueType whose IntModN has no modulus set must fail
// with kInvalidArgument instead of returning a boolean.
TYPED_TEST(ValueTypeIntModNTest, ValueTypesAreEqualFailsWhenModulusInvalid) {
  ValueType reference = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType without_modulus = reference;
  without_modulus.mutable_int_mod_n()->clear_modulus();

  EXPECT_THAT(ValueTypesAreEqual(reference, without_modulus),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Unknown value case for the given integer Value"));
}

// FromValue() must reject a Value that holds a tuple rather than an IntModN.
TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfNotInteger) {
  Value tuple_value;
  tuple_value.mutable_tuple();  // Make the Value hold an (empty) tuple.

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(tuple_value),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "The given Value is not an IntModN"));
}

// FromValue() must reject an IntModN Value equal to the modulus itself, since
// that is outside the representable range [0, modulus).
TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfTooLargeForModulus) {
  Value out_of_range;
  *(out_of_range.mutable_int_mod_n()) =
      Uint128ToValueInteger(TypeParam::modulus());

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(out_of_range),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       testing::HasSubstr("is larger than kModulus")));
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,51 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_

// std::move is used by the macro expansions below.
#include <utility>

// NOTE(review): the expansions also use ABSL_PREDICT_FALSE, which this header
// does not include; callers presumably already include an Abseil header that
// defines it (e.g. absl/base/optimization.h) — confirm.

// Evaluates `rexpr`, which must yield a StatusOr-like object (providing
// `ok()`, `status()` and `value()`). If it is not OK, returns its status from
// the enclosing function; otherwise moves its value into `lhs`. `lhs` may also
// be a declaration, e.g.:
//   DPF_ASSIGN_OR_RETURN(int x, ComputeX());
// Because such a declaration must stay in scope after the macro, this expands
// to several statements: do not use it as the unbraced body of `if`, `else`,
// `for` or `while`.
#define DPF_ASSIGN_OR_RETURN(lhs, rexpr) \
  DPF_ASSIGN_OR_RETURN_IMPL_(            \
      DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr)

// Internal helper. `statusor` is a temporary name made unique per line.
#define DPF_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \
  auto statusor = (rexpr);                               \
  if (ABSL_PREDICT_FALSE(!statusor.ok())) {              \
    return std::move(statusor).status();                 \
  }                                                      \
  lhs = std::move(statusor).value()

// Internal helpers for concatenating macro values. The extra indirection
// makes arguments such as __LINE__ expand before token pasting.
#define DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
#define DPF_STATUS_MACROS_IMPL_CONCAT_(x, y) \
  DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)

// Evaluates `expr`, which must yield a Status-like object providing `ok()`,
// and returns it from the enclosing function if it is not OK. Expands to a
// single statement, so it is safe as the unbraced body of `if`/`else`.
#define DPF_RETURN_IF_ERROR(expr) \
  DPF_RETURN_IF_ERROR_IMPL_(DPF_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), \
                            expr)

// Internal helper. The do/while(0) wrapper turns the expansion into one
// statement that still requires a trailing semicolon.
#define DPF_RETURN_IF_ERROR_IMPL_(status, expr) \
  do {                                          \
    auto status = (expr);                       \
    if (ABSL_PREDICT_FALSE(!status.ok())) {     \
      return status;                            \
    }                                           \
  } while (0)

#endif  // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_
|
|
@ -1,122 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_
|
|
||||||
|
|
||||||
#include <stddef.h>
|
|
||||||
|
|
||||||
#include <tuple>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
// A Tuple class with added element-wise addition, subtraction, and negation
|
|
||||||
// operators.
|
|
||||||
template <typename... T>
class Tuple {
 public:
  using Base = std::tuple<T...>;

  // Value-initializes every element (std::tuple's default constructor does).
  Tuple() = default;
  // Element-wise constructor. Intentionally implicit so that brace
  // initialization such as `Tuple<T...>{a, b}` works.
  constexpr Tuple(T... elements) : value_(std::move(elements)...) {}
  explicit constexpr Tuple(Base t) : value_(std::move(t)) {}

  // Copy and move operations are left implicit (Rule of Zero). Declaring only
  // the copy operations, even as `= default`, suppresses the implicit move
  // operations and silently turns every move into a copy.

  // Getters for the base tuple type.
  constexpr Base& value() { return value_; }
  constexpr const Base& value() const { return value_; }

 private:
  Base value_;
};
|
|
||||||
|
|
||||||
namespace dpf_internal {
|
|
||||||
|
|
||||||
// Implementation of addition and negation. See
|
|
||||||
// https://stackoverflow.com/a/50815143.
|
|
||||||
// We declare the templates here, but define them at the end of this header
|
|
||||||
// because the definitions need to make use of operator+ and operator-.
|
|
||||||
template <typename... T, std::size_t... I>
|
|
||||||
constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs,
|
|
||||||
std::index_sequence<I...>);
|
|
||||||
|
|
||||||
template <typename... T, std::size_t... I>
|
|
||||||
constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>);
|
|
||||||
|
|
||||||
} // namespace dpf_internal
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr Tuple<T...> operator+(const Tuple<T...>& lhs,
|
|
||||||
const Tuple<T...>& rhs) {
|
|
||||||
return dpf_internal::add(lhs, rhs, std::make_index_sequence<sizeof...(T)>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr Tuple<T...>& operator+=(Tuple<T...>& lhs, const Tuple<T...>& rhs) {
|
|
||||||
lhs = lhs + rhs;
|
|
||||||
return lhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr Tuple<T...> operator-(const Tuple<T...>& t) {
|
|
||||||
return dpf_internal::negate(t, std::make_index_sequence<sizeof...(T)>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr Tuple<T...> operator-(const Tuple<T...>& lhs,
|
|
||||||
const Tuple<T...>& rhs) {
|
|
||||||
return lhs + (-rhs);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr Tuple<T...>& operator-=(Tuple<T...>& lhs, const Tuple<T...>& rhs) {
|
|
||||||
lhs = lhs - rhs;
|
|
||||||
return lhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Equality and inequality operators.
|
|
||||||
template <typename... T>
|
|
||||||
constexpr bool operator==(const Tuple<T...>& lhs, const Tuple<T...>& rhs) {
|
|
||||||
return lhs.value() == rhs.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T>
|
|
||||||
constexpr bool operator!=(const Tuple<T...>& lhs, const Tuple<T...>& rhs) {
|
|
||||||
return lhs.value() != rhs.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace dpf_internal {
|
|
||||||
template <typename... T, std::size_t... I>
|
|
||||||
constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs,
|
|
||||||
std::index_sequence<I...>) {
|
|
||||||
return Tuple<T...>{std::get<I>(lhs.value()) + std::get<I>(rhs.value())...};
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... T, std::size_t... I>
|
|
||||||
constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>) {
|
|
||||||
return Tuple<T...>{
|
|
||||||
// Explicitly cast to T to avoid -Wnarrowing warnings for small integers.
|
|
||||||
T(-std::get<I>(t.value()))...};
|
|
||||||
}
|
|
||||||
} // namespace dpf_internal
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_
|
|
@ -1,97 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/tuple.h"
|
|
||||||
|
|
||||||
#include <tuple>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using T = Tuple<int, double, absl::uint128>;
|
|
||||||
|
|
||||||
TEST(TupleTest, TestAddition) {
|
|
||||||
T a(std::make_tuple(1, 2, 3));
|
|
||||||
T b(std::make_tuple(4, 5, 6));
|
|
||||||
|
|
||||||
T c = a + b;
|
|
||||||
|
|
||||||
EXPECT_EQ(std::get<0>(c.value()),
|
|
||||||
std::get<0>(a.value()) + std::get<0>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<1>(c.value()),
|
|
||||||
std::get<1>(a.value()) + std::get<1>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<2>(c.value()),
|
|
||||||
std::get<2>(a.value()) + std::get<2>(b.value()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(TupleTest, TestAdditionInplace) {
|
|
||||||
T a(std::make_tuple(1, 2, 3));
|
|
||||||
T b(std::make_tuple(4, 5, 6));
|
|
||||||
|
|
||||||
T a2 = a;
|
|
||||||
a += b;
|
|
||||||
|
|
||||||
EXPECT_EQ(std::get<0>(a.value()),
|
|
||||||
std::get<0>(a2.value()) + std::get<0>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<1>(a.value()),
|
|
||||||
std::get<1>(a2.value()) + std::get<1>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<2>(a.value()),
|
|
||||||
std::get<2>(a2.value()) + std::get<2>(b.value()));
|
|
||||||
}
|
|
||||||
TEST(TupleTest, TestSubtraction) {
|
|
||||||
T a(std::make_tuple(1, 2, 3));
|
|
||||||
T b(std::make_tuple(4, 5, 6));
|
|
||||||
|
|
||||||
T c = a - b;
|
|
||||||
|
|
||||||
EXPECT_EQ(std::get<0>(c.value()),
|
|
||||||
std::get<0>(a.value()) - std::get<0>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<1>(c.value()),
|
|
||||||
std::get<1>(a.value()) - std::get<1>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<2>(c.value()),
|
|
||||||
std::get<2>(a.value()) - std::get<2>(b.value()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(TupleTest, TestSubtractionInplace) {
|
|
||||||
T a(std::make_tuple(1, 2, 3));
|
|
||||||
T b(std::make_tuple(4, 5, 6));
|
|
||||||
|
|
||||||
T a2 = a;
|
|
||||||
a -= b;
|
|
||||||
|
|
||||||
EXPECT_EQ(std::get<0>(a.value()),
|
|
||||||
std::get<0>(a2.value()) - std::get<0>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<1>(a.value()),
|
|
||||||
std::get<1>(a2.value()) - std::get<1>(b.value()));
|
|
||||||
EXPECT_EQ(std::get<2>(a.value()),
|
|
||||||
std::get<2>(a2.value()) - std::get<2>(b.value()));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(TupleTest, TestNegation) {
|
|
||||||
T a(std::make_tuple(1, 2, 3));
|
|
||||||
|
|
||||||
T a2 = -a;
|
|
||||||
|
|
||||||
EXPECT_EQ(std::get<0>(a2.value()), -std::get<0>(a.value()));
|
|
||||||
EXPECT_EQ(std::get<1>(a2.value()), -std::get<1>(a.value()));
|
|
||||||
EXPECT_EQ(std::get<2>(a2.value()), -std::get<2>(a.value()));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,87 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2021 Google LLC
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_
|
|
||||||
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_
|
|
||||||
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
// Wraps the given type, replacing additions and subtractions by XOR.
|
|
||||||
template <typename T>
class XorWrapper {
 public:
  using WrappedType = T;

  // Value-initializes the wrapped object.
  constexpr XorWrapper() : wrapped_{} {}
  explicit constexpr XorWrapper(T wrapped) : wrapped_(std::move(wrapped)) {}

  // Copying and moving use the implicitly generated operations.

  // "Addition" of XOR shares: XORs `rhs` into the wrapped value.
  constexpr XorWrapper& operator+=(const XorWrapper& rhs) {
    wrapped_ ^= rhs.value();
    return *this;
  }
  // XOR is its own inverse, so "subtraction" is the same operation as
  // "addition".
  constexpr XorWrapper& operator-=(const XorWrapper& rhs) {
    return *this += rhs;
  }

  // Access to the wrapped object.
  constexpr T& value() { return wrapped_; }
  constexpr const T& value() const { return wrapped_; }

 private:
  T wrapped_;
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
constexpr XorWrapper<T> operator+(XorWrapper<T> a, const XorWrapper<T>& b) {
|
|
||||||
a += b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
constexpr XorWrapper<T> operator-(XorWrapper<T> a, const XorWrapper<T>& b) {
|
|
||||||
a -= b;
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Negation does nothing in XOR sharing, since -a = 0-a.
|
|
||||||
template <typename T>
|
|
||||||
constexpr XorWrapper<T> operator-(const XorWrapper<T>& a) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
constexpr bool operator==(const XorWrapper<T>& a, const XorWrapper<T>& b) {
|
|
||||||
return a.value() == b.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
constexpr bool operator!=(const XorWrapper<T>& a, const XorWrapper<T>& b) {
|
|
||||||
return !(a == b);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_
|
|
@ -1,72 +0,0 @@
|
|||||||
// Copyright 2021 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "dpf/xor_wrapper.h"
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
|
|
||||||
#include "absl/numeric/int128.h"
|
|
||||||
#include "gtest/gtest.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
class XorWrapperTest : public testing::Test {};
|
|
||||||
using XorWrapperTypes =
|
|
||||||
testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>;
|
|
||||||
|
|
||||||
TYPED_TEST_SUITE(XorWrapperTest, XorWrapperTypes);
|
|
||||||
|
|
||||||
TYPED_TEST(XorWrapperTest, TestConstructor) {
|
|
||||||
TypeParam value{42};
|
|
||||||
|
|
||||||
XorWrapper<TypeParam> wrapper(value);
|
|
||||||
|
|
||||||
EXPECT_EQ(wrapper.value(), value);
|
|
||||||
}
|
|
||||||
|
|
||||||
TYPED_TEST(XorWrapperTest, TestAddition) {
|
|
||||||
TypeParam a{42}, b{23};
|
|
||||||
XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b);
|
|
||||||
|
|
||||||
EXPECT_EQ((wrapped_a + wrapped_b).value(), a ^ b);
|
|
||||||
}
|
|
||||||
|
|
||||||
TYPED_TEST(XorWrapperTest, TestSubtraction) {
|
|
||||||
TypeParam a{42}, b{23};
|
|
||||||
XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b);
|
|
||||||
|
|
||||||
EXPECT_EQ((wrapped_a - wrapped_b).value(), a ^ b);
|
|
||||||
}
|
|
||||||
|
|
||||||
TYPED_TEST(XorWrapperTest, TestNegation) {
|
|
||||||
TypeParam value{42};
|
|
||||||
XorWrapper<TypeParam> wrapper(value);
|
|
||||||
|
|
||||||
EXPECT_EQ((-wrapper).value(), value);
|
|
||||||
}
|
|
||||||
|
|
||||||
TYPED_TEST(XorWrapperTest, TestEquality) {
|
|
||||||
TypeParam a{42}, b{23};
|
|
||||||
XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b);
|
|
||||||
|
|
||||||
EXPECT_EQ(wrapped_a, XorWrapper<TypeParam>(a));
|
|
||||||
EXPECT_NE(wrapped_a, XorWrapper<TypeParam>(b));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,9 +0,0 @@
|
|||||||
# Copyright 2024 The Chromium Authors
|
|
||||||
# Use of this source code is governed by a BSD-style license that can be
|
|
||||||
# found in the LICENSE file.
|
|
||||||
|
|
||||||
declare_args() {
  # Directory label of the Abseil checkout (presumably consumed by the DPF
  # build files).
  dpf_abseil_cpp_dir = "//third_party/abseil-cpp"

  # Directory label of the Highway checkout (presumably consumed by the DPF
  # build files).
  dpf_highway_cpp_dir = "//third_party/highway"

  # Enabled by default only in debug builds.
  use_distributed_point_functions = is_debug
}
|
|
@ -1,293 +0,0 @@
|
|||||||
// Copyright 2021 The Chromium Authors
|
|
||||||
// Use of this source code is governed by a BSD-style license that can be
|
|
||||||
// found in the LICENSE file.
|
|
||||||
|
|
||||||
#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <fuzzer/FuzzedDataProvider.h>

#include <algorithm>
#include <memory>
#include <vector>
|
|
||||||
|
|
||||||
// Aborts the process with a diagnostic if `x` is false. Unlike assert(), it
// has no NDEBUG guard, so it stays active in every build.
// - Wrapped in do/while(0) so it expands to a single statement and composes
//   with unbraced if/else.
// - Writes to stderr, which is unbuffered: abort() does not flush stdio
//   buffers, so a message printed to (possibly fully-buffered) stdout could be
//   lost.
#define DPF_FUZZER_ASSERT(x)                                                  \
  do {                                                                        \
    if (!(x)) {                                                               \
      fprintf(stderr, "DPF assertion failed: function %s, file %s, line %d.\n", \
              __PRETTY_FUNCTION__, __FILE__, __LINE__);                       \
      abort();                                                                \
    }                                                                         \
  } while (0)
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
const size_t UINT128_SIZE = 2 * sizeof(uint64_t);
|
|
||||||
|
|
||||||
// Constructs a `uint128` numeric value from two 64-bit unsigned integers
|
|
||||||
// consumed from the data provider.
|
|
||||||
absl::uint128 ConsumeUint128(FuzzedDataProvider& data_provider) {
|
|
||||||
uint64_t high = data_provider.ConsumeIntegral<uint64_t>();
|
|
||||||
uint64_t low = data_provider.ConsumeIntegral<uint64_t>();
|
|
||||||
return absl::MakeUint128(high, low);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the prefix of `index` for the domain of `hierarchy_level`.
|
|
||||||
// Adapted from `DpfEvaluationTest::GetPrefixForLevel()`.
|
|
||||||
absl::uint128 GetPrefixForLevel(
|
|
||||||
int hierarchy_level,
|
|
||||||
absl::uint128 index,
|
|
||||||
const std::vector<distributed_point_functions::DpfParameters>& parameters) {
|
|
||||||
absl::uint128 result = 0;
|
|
||||||
int shift_amount = parameters.back().log_domain_size() -
|
|
||||||
parameters[hierarchy_level].log_domain_size();
|
|
||||||
if (shift_amount < 128)
|
|
||||||
result = index >> shift_amount;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Evaluates both contexts `ctx0` and `ctx1` at `hierarchy level`, using the
|
|
||||||
// appropriate prefixes of `evaluation_points`. Checks that the expansion of
|
|
||||||
// both keys from correct DPF shares, i.e., they add up to
|
|
||||||
// `beta[ctx.hierarchy_level()]` under prefixes of `alpha`, and to 0 otherwise.
|
|
||||||
// Adapted from `DpfEvaluationTest::EvaluateAndCheckLevel()`.
|
|
||||||
template <typename T>
|
|
||||||
void EvaluateAndCheckLevel(
|
|
||||||
int hierarchy_level,
|
|
||||||
absl::Span<const absl::uint128> evaluation_points,
|
|
||||||
absl::uint128 alpha,
|
|
||||||
const std::vector<absl::uint128>& beta,
|
|
||||||
distributed_point_functions::EvaluationContext& ctx0,
|
|
||||||
distributed_point_functions::EvaluationContext& ctx1,
|
|
||||||
const std::vector<distributed_point_functions::DpfParameters>& parameters,
|
|
||||||
const distributed_point_functions::DistributedPointFunction& dpf) {
|
|
||||||
int previous_hierarchy_level = ctx0.previous_hierarchy_level();
|
|
||||||
int current_log_domain_size = parameters[hierarchy_level].log_domain_size();
|
|
||||||
int previous_log_domain_size = 0;
|
|
||||||
int num_expansions = 1;
|
|
||||||
bool is_first_evaluation = previous_hierarchy_level < 0;
|
|
||||||
// Generate prefixes if we're not on the first level.
|
|
||||||
std::vector<absl::uint128> prefixes;
|
|
||||||
if (!is_first_evaluation) {
|
|
||||||
num_expansions = static_cast<int>(evaluation_points.size());
|
|
||||||
prefixes.resize(evaluation_points.size());
|
|
||||||
previous_log_domain_size =
|
|
||||||
parameters[previous_hierarchy_level].log_domain_size();
|
|
||||||
for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i)
|
|
||||||
prefixes[i] = GetPrefixForLevel(previous_hierarchy_level,
|
|
||||||
evaluation_points[i], parameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Evaluating a key with N correction words leads to an O(2^N) malloc, which
|
|
||||||
// will unsurprisingly cause a fuzzer crash. See <https://crbug.com/1494260>.
|
|
||||||
constexpr size_t kMaxCorrectionWords = 30;
|
|
||||||
if (ctx0.key().correction_words().size() > kMaxCorrectionWords) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
absl::StatusOr<std::vector<T>> result_0 =
|
|
||||||
dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx0);
|
|
||||||
DPF_FUZZER_ASSERT(result_0.ok());
|
|
||||||
if (ctx1.key().correction_words().size() > kMaxCorrectionWords) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
absl::StatusOr<std::vector<T>> result_1 =
|
|
||||||
dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx1);
|
|
||||||
DPF_FUZZER_ASSERT(result_1.ok());
|
|
||||||
|
|
||||||
DPF_FUZZER_ASSERT(result_0->size() == result_1->size());
|
|
||||||
int64_t outputs_per_prefix =
|
|
||||||
int64_t{1} << (current_log_domain_size - previous_log_domain_size);
|
|
||||||
int64_t expected_output_size = num_expansions * outputs_per_prefix;
|
|
||||||
DPF_FUZZER_ASSERT(static_cast<int64_t>(result_0->size()) ==
|
|
||||||
expected_output_size);
|
|
||||||
|
|
||||||
// Iterator over the outputs and check that they sum up to 0 or to
|
|
||||||
// `beta[current_hierarchy_level]`;
|
|
||||||
absl::uint128 previous_alpha_prefix = 0;
|
|
||||||
if (!is_first_evaluation)
|
|
||||||
previous_alpha_prefix =
|
|
||||||
GetPrefixForLevel(previous_hierarchy_level, alpha, parameters);
|
|
||||||
|
|
||||||
absl::uint128 current_alpha_prefix =
|
|
||||||
GetPrefixForLevel(hierarchy_level, alpha, parameters);
|
|
||||||
for (int64_t i = 0; i < expected_output_size; ++i) {
|
|
||||||
int prefix_index = i / outputs_per_prefix;
|
|
||||||
int prefix_expansion_index = i % outputs_per_prefix;
|
|
||||||
// The output is on the path to `alpha`, if we're at the first level or
|
|
||||||
// under a prefix of `alpha`, and the current block in the expansion of the
|
|
||||||
// prefix is also on the path to `alpha`.
|
|
||||||
if ((is_first_evaluation ||
|
|
||||||
prefixes[prefix_index] == previous_alpha_prefix) &&
|
|
||||||
prefix_expansion_index == (current_alpha_prefix % outputs_per_prefix)) {
|
|
||||||
// We need to static_cast here since otherwise operator+ returns an
|
|
||||||
// unsigned int without doing a modular reduction, which causes the test
|
|
||||||
// to fail on types with sizeof(T) < sizeof(unsigned).
|
|
||||||
DPF_FUZZER_ASSERT(
|
|
||||||
absl::uint128{static_cast<T>((*result_0)[i] + (*result_1)[i])} ==
|
|
||||||
beta[hierarchy_level]);
|
|
||||||
} else {
|
|
||||||
DPF_FUZZER_ASSERT(static_cast<T>((*result_0)[i] + (*result_1)[i]) == 0U);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
|
|
||||||
// Use magic separator to split the input into two parts. The first part will
|
|
||||||
// generate alpha, and an array of parameters and betas. The second part will
|
|
||||||
// generate level step and an array of evaluation points.
|
|
||||||
const uint8_t separator[] = {0xDE, 0xAD, 0xBE, 0xEF};
|
|
||||||
|
|
||||||
const uint8_t* pos =
|
|
||||||
std::search(data, data + size, separator, separator + sizeof(separator));
|
|
||||||
|
|
||||||
const uint8_t* data1 = data;
|
|
||||||
size_t size1 = pos - data;
|
|
||||||
|
|
||||||
const uint8_t* data2 =
|
|
||||||
(pos == data + size) ? nullptr : pos + sizeof(separator);
|
|
||||||
size_t size2 = data2 ? (data + size) - (pos + sizeof(separator)) : 0;
|
|
||||||
|
|
||||||
FuzzedDataProvider data_provider1(data1, size1);
|
|
||||||
|
|
||||||
if (data_provider1.remaining_bytes() < UINT128_SIZE)
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
absl::uint128 alpha = ConsumeUint128(data_provider1);
|
|
||||||
|
|
||||||
std::vector<int32_t> log_domain_sizes;
|
|
||||||
std::vector<int32_t> element_bitsizes;
|
|
||||||
std::vector<distributed_point_functions::DpfParameters> parameters;
|
|
||||||
std::vector<absl::uint128> beta;
|
|
||||||
|
|
||||||
// log_domain_size(int32_t), element_bitsize(int32_t),
|
|
||||||
// beta(uint128)
|
|
||||||
while (data_provider1.remaining_bytes() >=
|
|
||||||
(2 * sizeof(int32_t) + UINT128_SIZE)) {
|
|
||||||
int32_t log_domain_size = data_provider1.ConsumeIntegral<int32_t>();
|
|
||||||
int32_t element_bitsize = data_provider1.ConsumeIntegral<int32_t>();
|
|
||||||
log_domain_sizes.push_back(log_domain_size);
|
|
||||||
element_bitsizes.push_back(element_bitsize);
|
|
||||||
|
|
||||||
distributed_point_functions::DpfParameters parameter;
|
|
||||||
parameter.set_log_domain_size(log_domain_size);
|
|
||||||
parameter.mutable_value_type()->mutable_integer()->set_bitsize(
|
|
||||||
element_bitsize);
|
|
||||||
parameters.push_back(parameter);
|
|
||||||
|
|
||||||
beta.push_back(ConsumeUint128(data_provider1));
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::StatusOr<
|
|
||||||
std::unique_ptr<distributed_point_functions::DistributedPointFunction>>
|
|
||||||
status_or_dpf = distributed_point_functions::DistributedPointFunction::
|
|
||||||
CreateIncremental(parameters);
|
|
||||||
|
|
||||||
size_t num_levels = parameters.size();
|
|
||||||
|
|
||||||
if (!status_or_dpf.ok()) {
|
|
||||||
// `log_domain_size` is expected to be in ascending order and
|
|
||||||
// `element_bitsize` is expected to be non-decreasing. As it is hard for the
|
|
||||||
// fuzzer to land upon a valid input, we sort the parameters and try again
|
|
||||||
// if the construction fails.
|
|
||||||
std::sort(log_domain_sizes.begin(), log_domain_sizes.end());
|
|
||||||
std::sort(element_bitsizes.begin(), element_bitsizes.end());
|
|
||||||
for (size_t i = 0; i < num_levels; ++i) {
|
|
||||||
parameters[i].set_log_domain_size(log_domain_sizes[i]);
|
|
||||||
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
|
|
||||||
element_bitsizes[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
status_or_dpf = distributed_point_functions::DistributedPointFunction::
|
|
||||||
CreateIncremental(parameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!status_or_dpf.ok())
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf =
|
|
||||||
std::move(status_or_dpf).value();
|
|
||||||
|
|
||||||
absl::StatusOr<std::pair<distributed_point_functions::DpfKey,
|
|
||||||
distributed_point_functions::DpfKey>>
|
|
||||||
status_or_keys = dpf->GenerateKeysIncremental(alpha, beta);
|
|
||||||
if (!status_or_keys.ok())
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
std::pair<distributed_point_functions::DpfKey,
|
|
||||||
distributed_point_functions::DpfKey>
|
|
||||||
keys = std::move(status_or_keys).value();
|
|
||||||
|
|
||||||
// Adapted from `DpfEvaluationTest.TestCorrectness()`.
|
|
||||||
absl::StatusOr<distributed_point_functions::EvaluationContext>
|
|
||||||
status_or_ctx0 = dpf->CreateEvaluationContext(keys.first);
|
|
||||||
DPF_FUZZER_ASSERT(status_or_ctx0.ok());
|
|
||||||
|
|
||||||
absl::StatusOr<distributed_point_functions::EvaluationContext>
|
|
||||||
status_or_ctx1 = dpf->CreateEvaluationContext(keys.second);
|
|
||||||
DPF_FUZZER_ASSERT(status_or_ctx1.ok());
|
|
||||||
|
|
||||||
distributed_point_functions::EvaluationContext ctx0 =
|
|
||||||
std::move(status_or_ctx0).value();
|
|
||||||
distributed_point_functions::EvaluationContext ctx1 =
|
|
||||||
std::move(status_or_ctx1).value();
|
|
||||||
|
|
||||||
// Generate evaluation points.
|
|
||||||
FuzzedDataProvider data_provider2(data2, size2);
|
|
||||||
if (data_provider2.remaining_bytes() < sizeof(int))
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
int level_step = data_provider2.ConsumeIntegralInRange<int>(1, 10);
|
|
||||||
|
|
||||||
std::vector<absl::uint128> evaluation_points;
|
|
||||||
while (data_provider2.remaining_bytes() >= UINT128_SIZE) {
|
|
||||||
evaluation_points.push_back(ConsumeUint128(data_provider2));
|
|
||||||
if (parameters.back().log_domain_size() < 128)
|
|
||||||
evaluation_points.back() %=
|
|
||||||
(absl::uint128{1} << parameters.back().log_domain_size());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always evaluate on alpha.
|
|
||||||
evaluation_points.push_back(alpha);
|
|
||||||
|
|
||||||
int32_t previous_log_domain_size = 0;
|
|
||||||
for (int i = level_step - 1; i < static_cast<int>(num_levels);
|
|
||||||
i += level_step) {
|
|
||||||
// If any gap in the log_domain_sizes used in successive evaluations is
|
|
||||||
// larger than 62, validation will fail in `EvaluateAndCheckLevel`.
|
|
||||||
int32_t current_log_domain_size = parameters[i].log_domain_size();
|
|
||||||
if (current_log_domain_size - previous_log_domain_size > 62)
|
|
||||||
return 0;
|
|
||||||
previous_log_domain_size = current_log_domain_size;
|
|
||||||
|
|
||||||
switch (parameters[i].value_type().integer().bitsize()) {
|
|
||||||
case 8:
|
|
||||||
EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, alpha, beta, ctx0,
|
|
||||||
ctx1, parameters, *dpf);
|
|
||||||
break;
|
|
||||||
case 16:
|
|
||||||
EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, alpha, beta, ctx0,
|
|
||||||
ctx1, parameters, *dpf);
|
|
||||||
break;
|
|
||||||
case 32:
|
|
||||||
EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, alpha, beta, ctx0,
|
|
||||||
ctx1, parameters, *dpf);
|
|
||||||
break;
|
|
||||||
case 64:
|
|
||||||
EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, alpha, beta, ctx0,
|
|
||||||
ctx1, parameters, *dpf);
|
|
||||||
break;
|
|
||||||
case 128:
|
|
||||||
EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, alpha, beta,
|
|
||||||
ctx0, ctx1, parameters, *dpf);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
// DPF construction should've failed if the parameters were invalid.
|
|
||||||
DPF_FUZZER_ASSERT(false);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -1,48 +0,0 @@
|
|||||||
# Copyright 2024 The Chromium Authors
|
|
||||||
# Use of this source code is governed by a BSD-style license that can be
|
|
||||||
# found in the LICENSE file.
|
|
||||||
|
|
||||||
import("//build/buildflag_header.gni")
|
|
||||||
import("//testing/test.gni")
|
|
||||||
import("//third_party/distributed_point_functions/features.gni")
|
|
||||||
|
|
||||||
source_set("shim") {
|
|
||||||
public_deps = [ ":buildflags" ]
|
|
||||||
|
|
||||||
if (use_distributed_point_functions) {
|
|
||||||
sources = [
|
|
||||||
"distributed_point_function_shim.cc",
|
|
||||||
"distributed_point_function_shim.h",
|
|
||||||
]
|
|
||||||
deps = [
|
|
||||||
"$dpf_abseil_cpp_dir:absl",
|
|
||||||
"//base",
|
|
||||||
"//third_party/distributed_point_functions:internal",
|
|
||||||
]
|
|
||||||
public_deps += [ "//third_party/distributed_point_functions:proto" ]
|
|
||||||
configs += [ "//third_party/distributed_point_functions:includes" ]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# External targets may depend on :buildflags directly without pulling in
|
|
||||||
# :distributed_point_functions. For instance, tests may set different
|
|
||||||
# expectations when the dpf library is omitted from the build.
|
|
||||||
buildflag_header("buildflags") {
|
|
||||||
header = "buildflags.h"
|
|
||||||
flags = [ "USE_DISTRIBUTED_POINT_FUNCTIONS=$use_distributed_point_functions" ]
|
|
||||||
}
|
|
||||||
|
|
||||||
test("distributed_point_functions_shim_unittests") {
|
|
||||||
deps = [
|
|
||||||
"//testing/gtest",
|
|
||||||
"//testing/gtest:gtest_main",
|
|
||||||
]
|
|
||||||
if (use_distributed_point_functions) {
|
|
||||||
sources = [ "distributed_point_function_shim_unittest.cc" ]
|
|
||||||
deps += [
|
|
||||||
":shim",
|
|
||||||
"$dpf_abseil_cpp_dir:absl",
|
|
||||||
"//third_party/protobuf:protobuf_lite",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,3 +0,0 @@
|
|||||||
include_rules = [
|
|
||||||
"+base",
|
|
||||||
]
|
|
@ -1,42 +0,0 @@
|
|||||||
// Copyright 2023 The Chromium Authors
|
|
||||||
// Use of this source code is governed by a BSD-style license that can be
|
|
||||||
// found in the LICENSE file.
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "base/check_op.h"
|
|
||||||
#include "base/logging.h"
|
|
||||||
#include "third_party/abseil-cpp/absl/numeric/int128.h"
|
|
||||||
#include "third_party/abseil-cpp/absl/status/status.h"
|
|
||||||
#include "third_party/abseil-cpp/absl/status/statusor.h"
|
|
||||||
#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h"
|
|
||||||
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"
|
|
||||||
#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
|
|
||||||
std::vector<DpfParameters> parameters,
|
|
||||||
absl::uint128 alpha,
|
|
||||||
std::vector<absl::uint128> beta) {
|
|
||||||
// absl::StatusOr is not allowed in the codebase, but this minimal usage is
|
|
||||||
// necessary to interact with //third_party/distributed_point_functions/.
|
|
||||||
absl::StatusOr<std::unique_ptr<DistributedPointFunction>> dpf_result =
|
|
||||||
DistributedPointFunction::CreateIncremental(std::move(parameters));
|
|
||||||
if (!dpf_result.ok()) {
|
|
||||||
LOG(ERROR) << "CreateIncremental() failed: " << dpf_result.status();
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
CHECK_NE(*dpf_result, nullptr);
|
|
||||||
|
|
||||||
absl::StatusOr<std::pair<DpfKey, DpfKey>> keys_result =
|
|
||||||
(*dpf_result)->GenerateKeysIncremental(alpha, std::move(beta));
|
|
||||||
if (!keys_result.ok()) {
|
|
||||||
LOG(ERROR) << "GenerateKeysIncremental() failed: " << keys_result.status();
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
return std::move(*keys_result);
|
|
||||||
}
|
|
||||||
} // namespace distributed_point_functions
|
|
@ -1,32 +0,0 @@
|
|||||||
// Copyright 2023 The Chromium Authors
|
|
||||||
// Use of this source code is governed by a BSD-style license that can be
|
|
||||||
// found in the LICENSE file.
|
|
||||||
|
|
||||||
#ifndef CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_
|
|
||||||
#define CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_
|
|
||||||
|
|
||||||
#include "third_party/distributed_point_functions/shim/buildflags.h"
|
|
||||||
|
|
||||||
static_assert(BUILDFLAG(USE_DISTRIBUTED_POINT_FUNCTIONS),
|
|
||||||
"This header must not be included when "
|
|
||||||
"distributed_point_functions is omitted from the build");
|
|
||||||
|
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "third_party/abseil-cpp/absl/numeric/int128.h"
|
|
||||||
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
// Generates a pair of keys for a DPF that evaluates to `beta` when given
|
|
||||||
// `alpha`. On failure, returns std::nullopt.
|
|
||||||
std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
|
|
||||||
std::vector<DpfParameters> parameters,
|
|
||||||
absl::uint128 alpha,
|
|
||||||
std::vector<absl::uint128> beta);
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
||||||
|
|
||||||
#endif // CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_
|
|
52
third_party/distributed_point_functions/shim/distributed_point_function_shim_unittest.cc
vendored
52
third_party/distributed_point_functions/shim/distributed_point_function_shim_unittest.cc
vendored
@ -1,52 +0,0 @@
|
|||||||
// Copyright 2023 The Chromium Authors
|
|
||||||
// Use of this source code is governed by a BSD-style license that can be
|
|
||||||
// found in the LICENSE file.
|
|
||||||
|
|
||||||
#include <stddef.h>
|
|
||||||
|
|
||||||
#include <optional>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "testing/gtest/include/gtest/gtest.h"
|
|
||||||
#include "third_party/abseil-cpp/absl/numeric/int128.h"
|
|
||||||
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"
|
|
||||||
#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h"
|
|
||||||
|
|
||||||
namespace distributed_point_functions {
|
|
||||||
|
|
||||||
// The shim's GenerateKeysIncremental() can return a value besides std::nullopt.
|
|
||||||
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalConstructsKeys) {
|
|
||||||
constexpr size_t kBitLength = 32;
|
|
||||||
std::vector<DpfParameters> parameters(kBitLength);
|
|
||||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
|
||||||
parameters[i].set_log_domain_size(i + 1);
|
|
||||||
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
|
|
||||||
parameters.size());
|
|
||||||
}
|
|
||||||
std::optional<std::pair<DpfKey, DpfKey>> maybe_dpf_keys =
|
|
||||||
GenerateKeysIncremental(
|
|
||||||
std::move(parameters),
|
|
||||||
/*alpha=*/absl::uint128{1},
|
|
||||||
/*beta=*/std::vector<absl::uint128>(kBitLength, absl::uint128{1}));
|
|
||||||
EXPECT_TRUE(maybe_dpf_keys.has_value());
|
|
||||||
}
|
|
||||||
|
|
||||||
// When DistributedPointFunction::CreateIncremental() fails, the shim's
|
|
||||||
// GenerateKeysIncremental() should return std::nullopt.
|
|
||||||
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalEmptyParameters) {
|
|
||||||
EXPECT_FALSE(GenerateKeysIncremental(/*parameters=*/{},
|
|
||||||
/*alpha=*/absl::uint128{}, /*beta=*/{}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// When the length of beta does not match the number of parameters, the internal
|
|
||||||
// call to DistributedPointFunction::GenerateKeysIncremental() will fail, and
|
|
||||||
// the shim's GenerateKeysIncremental() should return std::nullopt.
|
|
||||||
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalBetaWrongSize) {
|
|
||||||
std::vector<DpfParameters> parameters(3);
|
|
||||||
EXPECT_FALSE(
|
|
||||||
GenerateKeysIncremental(/*parameters=*/std::vector<DpfParameters>(3),
|
|
||||||
/*alpha=*/absl::uint128{}, /*beta=*/{1, 2, 3}));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace distributed_point_functions
|
|
3
third_party/highway/OWNERS
vendored
3
third_party/highway/OWNERS
vendored
@ -1 +1,2 @@
|
|||||||
file://third_party/distributed_point_functions/OWNERS
|
bikineev@chromium.org
|
||||||
|
file://third_party/blink/renderer/core/html/parser/OWNERS
|
||||||
|
Reference in New Issue
Block a user