diff --git a/BUILD.gn b/BUILD.gn index c0a9639d27232..e32e31565b7cc 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -114,7 +114,6 @@ group("gn_all") { "//third_party/angle/src/tests:angle_end2end_tests", "//third_party/angle/src/tests:angle_unittests", "//third_party/angle/src/tests:angle_white_box_tests", - "//third_party/distributed_point_functions/shim:distributed_point_functions_shim_unittests", "//third_party/flatbuffers:flatbuffers_unittests", "//third_party/highway:highway_tests", "//third_party/liburlpattern:liburlpattern_unittests", diff --git a/content/browser/BUILD.gn b/content/browser/BUILD.gn index 86e30304014f4..f2a64f86c4cc9 100644 --- a/content/browser/BUILD.gn +++ b/content/browser/BUILD.gn @@ -289,7 +289,6 @@ source_set("browser") { "//third_party/blink/public/strings", "//third_party/boringssl", "//third_party/brotli:dec", - "//third_party/distributed_point_functions", "//third_party/icu", "//third_party/inspector_protocol:crdtp", "//third_party/libyuv", diff --git a/content/test/BUILD.gn b/content/test/BUILD.gn index 2c8e8e9b42603..cc799cd5ad948 100644 --- a/content/test/BUILD.gn +++ b/content/test/BUILD.gn @@ -3252,7 +3252,6 @@ test("content_unittests") { "//third_party/blink/public:test_support", "//third_party/blink/public/common:font_enumeration_table_proto", "//third_party/blink/public/common:headers", - "//third_party/distributed_point_functions/shim:buildflags", "//third_party/icu", "//third_party/inspector_protocol:crdtp", "//third_party/inspector_protocol:crdtp_test", diff --git a/infra/inclusive_language_presubmit_exempt_dirs.txt b/infra/inclusive_language_presubmit_exempt_dirs.txt index fdaecbd641400..5f73e2af70d7d 100644 --- a/infra/inclusive_language_presubmit_exempt_dirs.txt +++ b/infra/inclusive_language_presubmit_exempt_dirs.txt @@ -481,7 +481,6 @@ third_party/crashpad/crashpad/third_party/linux 1 1 third_party/crashpad/crashpad/third_party/ninja 1 1 third_party/crashpad/crashpad/util/misc 1 1 third_party/dav1d 2 2 -third_party/distributed_point_functions/code 2 1 third_party/expat 2 2 third_party/fdlibm 1 1 third_party/fusejs/dist 3 1 diff --git a/third_party/distributed_point_functions/BUILD.gn b/third_party/distributed_point_functions/BUILD.gn deleted file mode 100644 index 7d7cdb930bb63..0000000000000 --- a/third_party/distributed_point_functions/BUILD.gn +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2021 The Chromium Authors -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -import("//testing/libfuzzer/fuzzer_test.gni") -import("//third_party/distributed_point_functions/features.gni") -import("//third_party/protobuf/proto_library.gni") - -# This is Chromium's interface with the third-party distributed_point_functions -# library. Targets outside of //third_party/distributed_point_functions should -# depend on this target rather than using the source directly. This extra layer -# prevents macros from leaking into Chromium code via header includes. -source_set("distributed_point_functions") { - public_deps = [ "//third_party/distributed_point_functions/shim" ] -} - -proto_library("proto") { - sources = [ "code/dpf/distributed_point_function.proto" ] - proto_out_dir = "third_party/distributed_point_functions/dpf" - cc_generator_options = "lite" -} - -fuzzer_test("dpf_fuzzer") { - sources = [ "fuzz/dpf_fuzzer.cc" ] - deps = [ ":internal" ] - - # Do not apply Chromium code rules to this third-party code. - suppressed_configs = [ "//build/config/compiler:chromium_code" ] - additional_configs = [ "//build/config/compiler:no_chromium_code" ] - - additional_configs += [ ":includes" ] -} - -# Targets below this line are only visible within this file and shim/. -visibility = [ - ":*", - "//third_party/distributed_point_functions/shim:*", -] - -config("includes") { - include_dirs = [ - "code", - "$target_gen_dir", - ] -} - -source_set("internal") { - sources = [ - "code/dpf/aes_128_fixed_key_hash.cc", - "code/dpf/aes_128_fixed_key_hash.h", - "code/dpf/distributed_point_function.cc", - "code/dpf/distributed_point_function.h", - "code/dpf/int_mod_n.cc", - "code/dpf/int_mod_n.h", - "code/dpf/internal/evaluate_prg_hwy.cc", - "code/dpf/internal/evaluate_prg_hwy.h", - "code/dpf/internal/get_hwy_mode.cc", - "code/dpf/internal/get_hwy_mode.h", - "code/dpf/internal/proto_validator.cc", - "code/dpf/internal/proto_validator.h", - "code/dpf/internal/value_type_helpers.cc", - "code/dpf/internal/value_type_helpers.h", - "code/dpf/status_macros.h", - "code/dpf/tuple.h", - "code/dpf/xor_wrapper.h", - ] - - public_deps = [ - ":proto", - "$dpf_abseil_cpp_dir:absl", - "$dpf_highway_cpp_dir:libhwy", - "//third_party/boringssl", - "//third_party/protobuf:protobuf_lite", - ] - - # Do not apply Chromium code rules to this third-party code. - configs -= [ "//build/config/compiler:chromium_code" ] - configs += [ "//build/config/compiler:no_chromium_code" ] - - configs += [ ":includes" ] -} diff --git a/third_party/distributed_point_functions/DEPS b/third_party/distributed_point_functions/DEPS deleted file mode 100644 index 85d3dcf51b8f1..0000000000000 --- a/third_party/distributed_point_functions/DEPS +++ /dev/null @@ -1,11 +0,0 @@ -include_rules = [ - "+absl", - "+benchmark", - "+dpf", - "+gmock", - "+google/protobuf", - "+gtest", - "+testing", - "+hwy", - "+openssl", -] diff --git a/third_party/distributed_point_functions/DIR_METADATA b/third_party/distributed_point_functions/DIR_METADATA deleted file mode 100644 index 71570707db8e4..0000000000000 --- a/third_party/distributed_point_functions/DIR_METADATA +++ /dev/null @@ -1,6 +0,0 @@ -monorail: { - component: "Internals>AttributionReporting" -} -buganizer_public: { - component_id: 1456103 -} diff --git a/third_party/distributed_point_functions/LICENSE b/third_party/distributed_point_functions/LICENSE deleted file mode 100644 index d645695673349..0000000000000 --- a/third_party/distributed_point_functions/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/third_party/distributed_point_functions/OWNERS b/third_party/distributed_point_functions/OWNERS deleted file mode 100644 index 53b8eade7e1a0..0000000000000 --- a/third_party/distributed_point_functions/OWNERS +++ /dev/null @@ -1,3 +0,0 @@ -alexmt@chromium.org -csharrison@chromium.org -linnan@chromium.org diff --git a/third_party/distributed_point_functions/README.chromium b/third_party/distributed_point_functions/README.chromium deleted file mode 100644 index 66ae1e9bfaa3a..0000000000000 --- a/third_party/distributed_point_functions/README.chromium +++ /dev/null @@ -1,25 +0,0 @@ -Name: The Incremental Distributed Point Functions library -Short Name: distributed_point_functions -URL: https://github.com/google/distributed_point_functions -Version: N/A -Revision: 2db593b64a99f178f682ef0db222d417c23e5bb5 -Date: 2023-11-16 -License: Apache-2.0 -License File: LICENSE -Security Critical: Yes -Shipped: yes -CPEPrefix: unknown - -Description: -This library contains an implementation of incremental distributed point -functions, based on the paper by Boneh et al. - -Local Modifications: -The directory code/ is a copy of the source code, modified in two ways. First, -all top-level directories other than dpf/ have been removed as they are unused. -Second, a .clang-format file has been added to disable automatic code -formatting. Parts of code/dpf/distributed_point_function_test.cc are also -adapted for fuzzing in fuzz/dpf_fuzzer.cc. -Third, a missing absl/strings/str_cat.h include backported from revision -c662ca975068bfa884cc4a96f3a1db40a7611e5e to fix build error when compiled with -latest version of abseil. diff --git a/third_party/distributed_point_functions/code/.bazelrc b/third_party/distributed_point_functions/code/.bazelrc deleted file mode 100644 index 53485cb9743ae..0000000000000 --- a/third_party/distributed_point_functions/code/.bazelrc +++ /dev/null @@ -1 +0,0 @@ -build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 diff --git a/third_party/distributed_point_functions/code/.clang-format b/third_party/distributed_point_functions/code/.clang-format deleted file mode 100644 index e3845288a2aec..0000000000000 --- a/third_party/distributed_point_functions/code/.clang-format +++ /dev/null @@ -1 +0,0 @@ -DisableFormat: true diff --git a/third_party/distributed_point_functions/code/.gitattributes b/third_party/distributed_point_functions/code/.gitattributes deleted file mode 100644 index 440063710088d..0000000000000 --- a/third_party/distributed_point_functions/code/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -experiments/data/* filter=lfs diff=lfs merge=lfs -text diff --git a/third_party/distributed_point_functions/code/.gitignore b/third_party/distributed_point_functions/code/.gitignore deleted file mode 100644 index b803df041501c..0000000000000 --- a/third_party/distributed_point_functions/code/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -# Bazel generated symlinks -bazel-* diff --git a/third_party/distributed_point_functions/code/BUILD b/third_party/distributed_point_functions/code/BUILD deleted file mode 100644 index 19400ec5b3859..0000000000000 --- a/third_party/distributed_point_functions/code/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package( - default_visibility = [":allowlist"], -) - -licenses(["notice"]) - -exports_files(["LICENSE"]) diff --git a/third_party/distributed_point_functions/code/CODE_OF_CONDUCT.md b/third_party/distributed_point_functions/code/CODE_OF_CONDUCT.md deleted file mode 100644 index dc079b4d66eb2..0000000000000 --- a/third_party/distributed_point_functions/code/CODE_OF_CONDUCT.md +++ /dev/null @@ -1,93 +0,0 @@ -# Code of Conduct - -## Our Pledge - -In the interest of fostering an open and welcoming environment, we as -contributors and maintainers pledge to making participation in our project and -our community a harassment-free experience for everyone, regardless of age, body -size, disability, ethnicity, gender identity and expression, level of -experience, education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -## Our Standards - -Examples of behavior that contributes to creating a positive environment -include: - -* Using welcoming and inclusive language -* Being respectful of differing viewpoints and experiences -* Gracefully accepting constructive criticism -* Focusing on what is best for the community -* Showing empathy towards other community members - -Examples of unacceptable behavior by participants include: - -* The use of sexualized language or imagery and unwelcome sexual attention or - advances -* Trolling, insulting/derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or electronic - address, without explicit permission -* Other conduct which could reasonably be considered inappropriate in a - professional setting - -## Our Responsibilities - -Project maintainers are responsible for clarifying the standards of acceptable -behavior and are expected to take appropriate and fair corrective action in -response to any instances of unacceptable behavior. - -Project maintainers have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions that are -not aligned to this Code of Conduct, or to ban temporarily or permanently any -contributor for other behaviors that they deem inappropriate, threatening, -offensive, or harmful. - -## Scope - -This Code of Conduct applies both within project spaces and in public spaces -when an individual is representing the project or its community. Examples of -representing a project or community include using an official project e-mail -address, posting via an official social media account, or acting as an appointed -representative at an online or offline event. Representation of a project may be -further defined and clarified by project maintainers. - -This Code of Conduct also applies outside the project spaces when the Project -Steward has a reasonable belief that an individual's behavior may have a -negative impact on the project or its community. - -## Conflict Resolution - -We do not believe that all conflict is bad; healthy debate and disagreement -often yield positive results. However, it is never okay to be disrespectful or -to engage in behavior that violates the project’s code of conduct. - -If you see someone violating the code of conduct, you are encouraged to address -the behavior directly with those involved. Many issues can be resolved quickly -and easily, and this gives people more control over the outcome of their -dispute. If you are unable to resolve the matter for any reason, or if the -behavior is threatening or harassing, report it. We are dedicated to providing -an environment where participants feel welcome and safe. - -Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the -Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward’s duty to -receive and address reported violations of the code of conduct. They will then -work with a committee consisting of representatives from the Open Source -Programs Office and the Google Open Source Strategy team. If for any reason you -are uncomfortable reaching out to the Project Steward, please email -opensource@google.com. - -We will investigate every complaint, but you may not receive a direct response. -We will use our discretion in determining when and how to follow up on reported -incidents, which may range from not taking action to permanent expulsion from -the project and project-sponsored spaces. We will notify the accused of the -report and provide them an opportunity to discuss it before any action is taken. -The identity of the reporter will be omitted from the details of the report -supplied to the accused. In potentially harmful situations, such as ongoing -harassment or threats to anyone's safety, we may take action without notice. - -## Attribution - -This Code of Conduct is adapted from the Contributor Covenant, version 1.4, -available at -https://www.contributor-covenant.org/version/1/4/code-of-conduct.html diff --git a/third_party/distributed_point_functions/code/CONTRIBUTING.md b/third_party/distributed_point_functions/code/CONTRIBUTING.md deleted file mode 100644 index 22b241cb732cc..0000000000000 --- a/third_party/distributed_point_functions/code/CONTRIBUTING.md +++ /dev/null @@ -1,29 +0,0 @@ -# How to Contribute - -We'd love to accept your patches and contributions to this project. There are -just a few small guidelines you need to follow. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement (CLA). You (or your employer) retain the copyright to your -contribution; this simply gives us permission to use and redistribute your -contributions as part of the project. Head over to -<https://cla.developers.google.com/> to see your current agreements on file or -to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. - -## Code reviews - -All submissions, including submissions by project members, require review. We -use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more -information on using pull requests. - -## Community Guidelines - -This project follows -[Google's Open Source Community Guidelines](https://opensource.google/conduct/). diff --git a/third_party/distributed_point_functions/code/LICENSE b/third_party/distributed_point_functions/code/LICENSE deleted file mode 100644 index d645695673349..0000000000000 --- a/third_party/distributed_point_functions/code/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/third_party/distributed_point_functions/code/README.md b/third_party/distributed_point_functions/code/README.md deleted file mode 100644 index 0d67344e4a25a..0000000000000 --- a/third_party/distributed_point_functions/code/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# An Implementation of Incremental Distributed Point Functions in C++ [](https://buildkite.com/bazel/google-distributed-point-functions) - -This library contains an implementation of incremental distributed point -functions, based on the following paper: -> Boneh, D., Boyle, E., Corrigan-Gibbs, H., Gilboa, N., & Ishai, Y. (2020). -Lightweight Techniques for Private Heavy Hitters. arXiv preprint -> arXiv:2012.14884. https://arxiv.org/abs/2012.14884 - -## About Incremental Distributed Point Functions - -A distributed point function (DPF) is parameterized by an index `alpha` and a -value `beta`. It consists of two algorithms: key generation and evaluation. -The key generation procedure produces two keys `k_a` and `k_b`, given `alpha` -and `beta`. Evaluating each key on any point `x` in the DPF domain results in an -additive secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. - -Incremental DPFs additionally can be evaluated on prefixes of the index domain. -More precisely, an incremental DPF is parameterized by a hierarchy of index -domains, each a power of two larger than the previous. Key generation now takes -a vector `beta`, one value `beta[i]` for each hierarchy level. -When evaluated on a `b`-bit prefix of `alpha`, where b is the log domain size of -the `i`-th hierarchy level, the incremental DPF returns a secret share of -`beta[i]`, otherwise a share of 0. - -For more details, see the above paper, as well as the -[`DistributedPointFunction` class documentation](dpf/distributed_point_function.h). - - -## Building/Running Tests - -This repository requires Bazel. You can install Bazel by -following the instructions for your platform on the -[Bazel website](https://docs.bazel.build/versions/master/install.html). - -Once you have installed Bazel you can clone this repository and run all tests -that are included by navigating into the root folder and running: - -```bash -bazel test //... -``` - -## Security -To report a security issue, please read [SECURITY.md](SECURITY.md). - -## Disclaimer - -This is not an officially supported Google product. The code is provided as-is, -with no guarantees of correctness or security. diff --git a/third_party/distributed_point_functions/code/SECURITY.md b/third_party/distributed_point_functions/code/SECURITY.md deleted file mode 100644 index 7465c8ba6587e..0000000000000 --- a/third_party/distributed_point_functions/code/SECURITY.md +++ /dev/null @@ -1,5 +0,0 @@ -# Security -To report a security issue, please use http://g.co/vulnz. We use -http://g.co/vulnz for our intake, and do coordination and disclosure here on -GitHub (including using GitHub Security Advisory). The Google Security Team will -respond within 5 working days of your report on g.co/vulnz. \ No newline at end of file diff --git a/third_party/distributed_point_functions/code/WORKSPACE.bazel b/third_party/distributed_point_functions/code/WORKSPACE.bazel deleted file mode 100644 index e9a16014e4019..0000000000000 --- a/third_party/distributed_point_functions/code/WORKSPACE.bazel +++ /dev/null @@ -1,170 +0,0 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -# rules_proto defines abstract rules for building Protocol Buffers. -# https://github.com/bazelbuild/rules_proto -http_archive( - name = "rules_proto", - sha256 = "0daa4fc5b2b820705fcbf239557515f9ab809be45a1e7c6dfaa1d465d5c615d4", - strip_prefix = "rules_proto-3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd", - urls = [ - "https://github.com/bazelbuild/rules_proto/archive/3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd.zip", - ], -) - -load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains") - -rules_proto_dependencies() - -rules_proto_toolchains() - -# rules_cc defines rules for generating C++ code from Protocol Buffers. -# https://github.com/bazelbuild/rules_cc -http_archive( - name = "rules_cc", - sha256 = "e17cca44563e0918a36a8ea2a50acb99ea9ad726bbd3cad8ba95a643a40121ab", - strip_prefix = "rules_cc-d7c11265cb157c9b962d87d9ab67b8c24e3a875f", - urls = [ - "https://github.com/bazelbuild/rules_cc/archive/d7c11265cb157c9b962d87d9ab67b8c24e3a875f.zip", - ], -) - -load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") - -rules_cc_dependencies() - -# io_bazel_rules_go defines rules for generating C++ code from Protocol Buffers. -# https://github.com/bazelbuild/rules_go -http_archive( - name = "io_bazel_rules_go", - sha256 = "7c35e8515012279ef7bcbc39c4ef4b54a86756d853848cb621b7da49f156c82f", - strip_prefix = "rules_go-b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2", - urls = [ - "https://github.com/bazelbuild/rules_go/archive/b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2.zip", - ], -) - -load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") - -go_rules_dependencies() - -go_register_toolchains(version = "1.19.3") - -# Install gtest. -# https://github.com/google/googletest -http_archive( - name = "com_github_google_googletest", - sha256 = "3e91944af2d909a79f18ee9760765624810146ccfae8f1a8f990037a1677d44b", - strip_prefix = "googletest-ac7a126f39d5bcd909b78c9e69900c76659b1bbb", - urls = [ - "https://github.com/google/googletest/archive/ac7a126f39d5bcd909b78c9e69900c76659b1bbb.zip", - ], -) - -# abseil-cpp -# https://github.com/abseil/abseil-cpp -http_archive( - name = "com_google_absl", - sha256 = "431c0c47217c36106f90e2ca4fcdf45af618ea21adde880804661b1ecb240056", - strip_prefix = "abseil-cpp-1fb3830b1cf685999bb2bbd0294be0a53c9440a6", - urls = [ - "https://github.com/abseil/abseil-cpp/archive/1fb3830b1cf685999bb2bbd0294be0a53c9440a6.zip", - ], -) - -# BoringSSL -# https://github.com/google/boringssl -http_archive( - name = "boringssl", - sha256 = "88e4330f4f65ebfdf24847e4807c25f3eacfd5bf1a93f6629d3941196ff9b0b3", - strip_prefix = "boringssl-6347808f2a480a3792148bf7732232229db9b909", - urls = [ - "https://github.com/google/boringssl/archive/6347808f2a480a3792148bf7732232229db9b909.zip", - ], -) - -# Benchmarks -# https://github.com/google/benchmark -http_archive( - name = "com_github_google_benchmark", - sha256 = "5f98b44165f3250f1d749b728018318d654f763ea0f4d7ea156e10e6e0cc678a", - strip_prefix = "benchmark-5e78bedfb07c615edb2b646d1e354980268c1728", - urls = [ - "https://github.com/google/benchmark/archive/5e78bedfb07c615edb2b646d1e354980268c1728.zip", - ], -) - -# gflags needed for glog. -# https://github.com/gflags/gflags -http_archive( - name = "com_github_gflags_gflags", - sha256 = "017e0a91531bfc45be9eaf07e4d8fed33c488b90b58509dbd2e33a33b2648ae6", - strip_prefix = "gflags-a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd", - urls = [ - "https://github.com/gflags/gflags/archive/a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd.zip", - ], -) - -# glog for logging -# https://github.com/google/glog -http_archive( - name = "com_github_google_glog", - sha256 = "0f91ee6cc1edc3b1c53a286382e69a37e5d172ce208b7e5b305be8770d8c21b1", - strip_prefix = "glog-f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481", - urls = [ - "https://github.com/google/glog/archive/f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481.zip", - ], -) - -# IREE for cc_embed_data. -# https://github.com/google/iree -http_archive( - name = "com_github_google_iree", - sha256 = "aa369b29a5c45ae9d7aa8bf49ea1308221d1711277222f0755df6e0a575f6879", - strip_prefix = "iree-7e6012468cbaafaaf30302748a2943771b40e2c3", - urls = [ - "https://github.com/google/iree/archive/7e6012468cbaafaaf30302748a2943771b40e2c3.zip", - ], -) - -# riegeli for file IO -# https://github.com/google/riegeli -http_archive( - name = "com_github_google_riegeli", - sha256 = "3de21a222271a1e2c5d728e7f46b63ab4520da829c09ef9727a322e693c9ac18", - strip_prefix = "riegeli-43b7ef9f995469609b6ab07f6becc82186314bfb", - urls = [ - "https://github.com/google/riegeli/archive/43b7ef9f995469609b6ab07f6becc82186314bfb.zip", - ], -) - -# rules_license needed for Highway -# https://github.com/bazelbuild/rules_license -http_archive( - name = "rules_license", - sha256 = "6157e1e68378532d0241ecd15d3c45f6e5cfd98fc10846045509fb2a7cc9e381", - urls = [ - "https://github.com/bazelbuild/rules_license/releases/download/0.0.4/rules_license-0.0.4.tar.gz", - ], -) - -# Highway for SIMD operations. -# https://github.com/google/highway -http_archive( - name = "com_github_google_highway", - sha256 = "cdba0eb21796598dd50fa0a4aa3651fa466c0d37c39d149ee383f725434e4314", - strip_prefix = "highway-45c98184ab7f81cf592c07633070b75fced14a52", - urls = [ - "https://github.com/google/highway/archive/45c98184ab7f81cf592c07633070b75fced14a52.zip", - ], -) - -# cppitertools for logging -# https://github.com/ryanhaining/cppitertools -http_archive( - name = "com_github_ryanhaining_cppitertools", - sha256 = "1608ddbe3c12b0c6e653b992ff63b5dceab9af5347ad93be8714d05e5dc17afb", - strip_prefix = "cppitertools-add5acc932dea2c78acd80747bab71ec0b5bce27", - urls = [ - "https://github.com/ryanhaining/cppitertools/archive/add5acc932dea2c78acd80747bab71ec0b5bce27.zip", - ], -) diff --git a/third_party/distributed_point_functions/code/dpf/BUILD b/third_party/distributed_point_functions/code/dpf/BUILD deleted file mode 100644 index 8ad71b85e436d..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/BUILD +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library") -load("@rules_proto//proto:defs.bzl", "proto_library") - -package( - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) - -cc_library( - name = "int_mod_n", - srcs = ["int_mod_n.cc"], - hdrs = ["int_mod_n.h"], - deps = [ - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "int_mod_n_test", - srcs = ["int_mod_n_test.cc"], - deps = [ - ":int_mod_n", - "//dpf/internal:status_matchers", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "int_mod_n_benchmark", - srcs = ["int_mod_n_benchmark.cc"], - deps = [ - ":int_mod_n", - "@boringssl//:crypto", - "@com_github_google_benchmark//:benchmark", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "distributed_point_function", - srcs = ["distributed_point_function.cc"], - hdrs = ["distributed_point_function.h"], - deps = [ - ":aes_128_fixed_key_hash", - ":distributed_point_function_cc_proto", - ":status_macros", - "//dpf/internal:evaluate_prg_hwy", - "//dpf/internal:get_hwy_mode", - "//dpf/internal:maybe_deref_span", - "//dpf/internal:proto_validator", - "//dpf/internal:value_type_helpers", - "@boringssl//:crypto", - "@com_github_google_highway//:hwy", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_protobuf//:protobuf", - "@com_google_protobuf//:protobuf_lite", - ], -) - -cc_test( - name = "distributed_point_function_test", - size = "medium", - srcs = ["distributed_point_function_test.cc"], - deps = [ - ":distributed_point_function", - ":distributed_point_function_cc_proto", - ":xor_wrapper", - "//dpf/internal:proto_validator", - "//dpf/internal:status_matchers", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/random", - "@com_google_absl//absl/random:distributions", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/utility", - ], -) - -proto_library( - name = "distributed_point_function_proto", - srcs = ["distributed_point_function.proto"], -) - -cc_proto_library( - name = "distributed_point_function_cc_proto", - deps = [":distributed_point_function_proto"], -) - -go_proto_library( - name = "distributed_point_function_go_proto", - importpath = "github.com/google/distributed_point_functions/dpf/distributed_point_function_go_proto", - protos = [":distributed_point_function_proto"], -) - -cc_test( - name = "distributed_point_function_benchmark", - srcs = [ - "distributed_point_function_benchmark.cc", - ], - tags = ["benchmark"], - deps = [ - ":distributed_point_function", - "@com_github_google_benchmark//:benchmark", - "@com_github_google_googletest//:gtest_main", - "@com_github_google_highway//:hwy", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/random", - "@com_google_absl//absl/random:distributions", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_protobuf//:protobuf", - ], -) - -cc_library( - name = "status_macros", - hdrs = ["status_macros.h"], -) - -cc_library( - name = "aes_128_fixed_key_hash", - srcs = ["aes_128_fixed_key_hash.cc"], - hdrs = ["aes_128_fixed_key_hash.h"], - deps = [ - "@boringssl//:crypto", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "aes_128_fixed_key_hash_test", - srcs = ["aes_128_fixed_key_hash_test.cc"], - deps = [ - ":aes_128_fixed_key_hash", - "//dpf/internal:status_matchers", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "tuple", - hdrs = ["tuple.h"], -) - -cc_test( - name = "tuple_test", - srcs = [ - "tuple_test.cc", - ], - deps = [ - ":tuple", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/numeric:int128", - ], -) - -cc_library( - name = "xor_wrapper", - hdrs = ["xor_wrapper.h"], -) - -cc_test( - name = "xor_wrapper_test", - srcs = [ - "xor_wrapper_test.cc", - ], - deps = [ - ":xor_wrapper", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/numeric:int128", - ], -) diff --git a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.cc b/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.cc deleted file mode 100644 index f7df231c7a838..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dpf/aes_128_fixed_key_hash.h" - -#include <stdint.h> - -#include <algorithm> -#include <array> -#include <string> -#include <utility> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "openssl/err.h" - -namespace distributed_point_functions { - -Aes128FixedKeyHash::Aes128FixedKeyHash( - bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx, absl::uint128 key) - : cipher_ctx_(std::move(cipher_ctx)), key_(key) {} - -absl::StatusOr<Aes128FixedKeyHash> Aes128FixedKeyHash::Create( - absl::uint128 key) { - bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx(EVP_CIPHER_CTX_new()); - if (!cipher_ctx) { - return absl::InternalError("Failed to allocate AES context"); - } - // Set up the OpenSSL encryption context. We want to evaluate the PRG in - // parallel on many seeds (see class comment in pseudorandom_generator.h), so - // we're using ECB mode here to achieve that. This batched evaluation is not - // to be confused with encryption of an array, for which ECB would be - // insecure. - int openssl_status = - EVP_EncryptInit_ex(cipher_ctx.get(), EVP_aes_128_ecb(), nullptr, - reinterpret_cast<const uint8_t*>(&key), nullptr); - if (openssl_status != 1) { - return absl::InternalError("Failed to set up AES context"); - } - return Aes128FixedKeyHash(std::move(cipher_ctx), key); -} - -absl::Status Aes128FixedKeyHash::Evaluate(absl::Span<const absl::uint128> in, - absl::Span<absl::uint128> out) const { - if (in.size() != out.size()) { - return absl::InvalidArgumentError("Input and output sizes don't match"); - } - if (in.empty()) { - // Nothing to do. - return absl::OkStatus(); - } - - // Compute orthomorphism sigma for each element in `in`, `kBatchSize` elements - // at a time. - auto in_size = static_cast<int64_t>(in.size()); - std::array<absl::uint128, kBatchSize> sigma_in; - for (int64_t start_block = 0; start_block < in_size; - start_block += kBatchSize) { - int64_t batch_size = std::min<int64_t>(in_size - start_block, kBatchSize); - for (int i = 0; i < batch_size; ++i) { - sigma_in[i] = - absl::MakeUint128(absl::Uint128High64(in[start_block + i]) ^ - absl::Uint128Low64(in[start_block + i]), - absl::Uint128High64(in[start_block + i])); - } - - // We use EVP_Cipher here instead of EVP_EncryptUpdate, since it doesn't - // mutate the context in ECB mode, and so this call is thread-safe. - int openssl_status = EVP_Cipher( - cipher_ctx_.get(), reinterpret_cast<uint8_t*>(out.data() + start_block), - reinterpret_cast<const uint8_t*>(sigma_in.data()), - static_cast<int>(batch_size * sizeof(absl::uint128))); - if (openssl_status != 1) { - char buf[256]; - ERR_error_string_n(ERR_get_error(), buf, sizeof(buf)); - return absl::InternalError( - absl::StrCat("AES encryption failed: ", std::string(buf))); - } - for (int64_t i = 0; i < batch_size; ++i) { - out[start_block + i] ^= sigma_in[i]; - } - } - return absl::OkStatus(); -} - -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.h b/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.h deleted file mode 100644 index 2af2061f05c11..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash.h +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_ - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "openssl/cipher.h" - -namespace distributed_point_functions { - -// Aes128FixedKeyHash is a circular correlation-robust hash function based on -// AES. For key `key`, input `in` and output `out`, the hash function is defined -// as -// -// out[i] = AES.Encrypt(key, sigma(in[i])) ^ sigma(in[i]), -// -// where sigma(x) = (x.high64 ^ x.low64, x.high64). This is the -// circular correlation-robust MMO construction from -// https://eprint.iacr.org/2019/074.pdf (pp. 18-19). Note that unlike -// cryptographic hash functions such as SHA-256, this hash function is *not* -// compressing and is not designed to provide any security guarantees beyond -// circular correlation-robustness. Use with appropriate caution. -class Aes128FixedKeyHash { - public: - // Creates a new Aes128FixedKeyHash with the given `key`. - // - // Returns INTERNAL in case of allocation failures or OpenSSL errors. - static absl::StatusOr<Aes128FixedKeyHash> Create(absl::uint128 key); - - // Computes hash values of each block in `in`, writing the output to `out`. - // It is safe to call this method if `in` and `out` overlap. - // - // Returns INVALID_ARGUMENT if sizes of `in` and `out` don't match or their - // sizes in bytes exceed an `int`, or INTERNAL in case of OpenSSL errors. - absl::Status Evaluate(absl::Span<const absl::uint128> in, - absl::Span<absl::uint128> out) const; - - // Aes128FixedKeyHash is not copyable. - Aes128FixedKeyHash(const Aes128FixedKeyHash&) = delete; - Aes128FixedKeyHash& operator=(const Aes128FixedKeyHash&) = delete; - - // Aes128FixedKeyHash is movable (it just wraps a bssl::UniquePtr). - Aes128FixedKeyHash(Aes128FixedKeyHash&&) = default; - Aes128FixedKeyHash& operator=(Aes128FixedKeyHash&&) = default; - - // Returns the key used to construct this hash function. - // DO NOT SEND THIS TO ANY OTHER PARTY! - const absl::uint128& key() const { return key_; } - - // The maximum number of AES blocks encrypted at once. Chosen to pipeline AES - // as much as possible, while still allowing both source and destination to - // comfortably fit in the L1 CPU cache. - static constexpr int kBatchSize = 64; - - private: - // Called by `Create`. - Aes128FixedKeyHash(bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx, - absl::uint128 key); - - // The OpenSSL encryption context used by `Evaluate`. - bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx_; - - // The key used to construct this hash function. - absl::uint128 key_; -}; - -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_ diff --git a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash_test.cc b/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash_test.cc deleted file mode 100644 index f1ce37c773afc..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/aes_128_fixed_key_hash_test.cc +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dpf/aes_128_fixed_key_hash.h" - -#include <thread> // NOLINT(build/c++11) -#include <vector> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "dpf/internal/status_matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace { - -using dpf_internal::StatusIs; - -// Test blocks for keys, inputs, and outputs. -constexpr absl::uint128 kKey0 = - absl::MakeUint128(0x0000000000000000, 0x0000000000000000); -constexpr absl::uint128 kKey1 = - absl::MakeUint128(0x1111111111111111, 0x1111111111111111); -constexpr absl::uint128 kSeed0 = - absl::MakeUint128(0x0123012301230123, 0x0123012301230123); -constexpr absl::uint128 kSeed1 = - absl::MakeUint128(0x4567456745674567, 0x4567456745674567); -constexpr absl::uint128 kSeed2 = - absl::MakeUint128(0x89ab89ab89ab89ab, 0x89ab89ab89ab89ab); -constexpr absl::uint128 kSeed3 = - absl::MakeUint128(0xcdefcdefcdefcdef, 0xcdefcdefcdefcdef); - -TEST(Aes128FixedKeyHashTest, CreateSucceeds) { - DPF_EXPECT_OK(Aes128FixedKeyHash::Create(kKey0)); -} - -TEST(Aes128FixedKeyHashTest, SameKeysAndSeedsGenerateSameOutput) { - std::vector<absl::uint128> in; - - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0, - Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1, - Aes128FixedKeyHash::Create(kKey0)); - in = {kSeed0}; - // Initialize output arrays with different values, to make sure they are the - // same afterwards. - std::vector<absl::uint128> out_0(in.size(), kSeed2), out_1(in.size(), kSeed3); - - DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0))); - DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1))); - EXPECT_THAT(out_0, testing::ElementsAreArray(out_1)); -} - -TEST(Aes128FixedKeyHashTest, DifferentKeysGenerateDifferentOutput) { - std::vector<absl::uint128> in{kSeed0}; - - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0, - Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1, - Aes128FixedKeyHash::Create(kKey1)); - // Initialize output arrays with the same values, to make sure they are - // different afterwards. - std::vector<absl::uint128> out_0(in.size(), kSeed2), out_1(in.size(), kSeed2); - - DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0))); - DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1))); - EXPECT_THAT(out_0, testing::Not(testing::ElementsAreArray(out_1))); -} - -TEST(Aes128FixedKeyHashTest, DifferentSeedsGenerateDifferentOutput) { - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg, - Aes128FixedKeyHash::Create(kKey0)); - std::vector<absl::uint128> in_0, in_1; - - in_0 = {kSeed0}; - in_1 = {kSeed1}; - // Initialize output arrays with the same values, to make sure they are - // different afterwards. - std::vector<absl::uint128> out_0(in_0.size(), kSeed2), - out_1(in_1.size(), kSeed2); - - DPF_EXPECT_OK(prg.Evaluate(in_0, absl::MakeSpan(out_0))); - DPF_EXPECT_OK(prg.Evaluate(in_1, absl::MakeSpan(out_1))); - EXPECT_THAT(out_0, testing::Not(testing::ElementsAreArray(out_1))); -} - -TEST(Aes128FixedKeyHashTest, BatchedEvaluationEqualsBlockWiseEvaluation) { - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg, - Aes128FixedKeyHash::Create(kKey0)); - std::vector<absl::uint128> in_0, in_1, in_2; - - in_0 = {kSeed0}; - in_1 = {kSeed1}; - in_2 = {kSeed0, kSeed1}; - std::vector<absl::uint128> out_0(in_0.size()), out_1(in_1.size()), - out_2(in_2.size()); - - DPF_EXPECT_OK(prg.Evaluate(in_0, absl::MakeSpan(out_0))); - DPF_EXPECT_OK(prg.Evaluate(in_1, absl::MakeSpan(out_1))); - DPF_EXPECT_OK(prg.Evaluate(in_2, absl::MakeSpan(out_2))); - EXPECT_THAT(out_2, testing::ElementsAre(out_0[0], out_1[0])); -} - -TEST(Aes128FixedKeyHashTest, TestSpecificOutputValues) { - std::vector<absl::uint128> in, out_0, out_1; - - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_0, - Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg_1, - Aes128FixedKeyHash::Create(kKey1)); - in = {kSeed0, kSeed1}; - out_0.resize(in.size()); - out_1.resize(in.size()); - - DPF_EXPECT_OK(prg_0.Evaluate(in, absl::MakeSpan(out_0))); - DPF_EXPECT_OK(prg_1.Evaluate(in, absl::MakeSpan(out_1))); - EXPECT_THAT(out_0, - testing::ElementsAre( - absl::MakeUint128(0x73c2dc14812be4ef, 0xeac64d09c8adf8ed), - absl::MakeUint128(0xb8f33653a53a8436, 0xaedf39b62de91d95))); - EXPECT_THAT(out_1, - testing::ElementsAre( - absl::MakeUint128(0x934704aff58fa233, 0xd3c20d1b9cc18d8f), - absl::MakeUint128(0x530098817046d284, 0x43e61d3273a04f7c))); -} - -TEST(Aes128FixedKeyHashTest, EvaluateFailsWhenSizesDontMatch) { - std::vector<absl::uint128> in{kSeed0}; - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg, - Aes128FixedKeyHash::Create(kKey0)); - - std::vector<absl::uint128> out(in.size() + 1); - - EXPECT_THAT(prg.Evaluate(in, absl::MakeSpan(out)), - StatusIs(absl::StatusCode::kInvalidArgument, - "Input and output sizes don't match")); -} - -TEST(Aes128FixedKeyHashTest, TestThreadSafety) { - std::vector<absl::uint128> in{kSeed0}; - DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash prg, - Aes128FixedKeyHash::Create(kKey0)); - constexpr int kNumThreads = 1024; - - auto do_evaluation = [&prg, &in]() { - absl::uint128 out; - DPF_ASSERT_OK(prg.Evaluate(in, absl::MakeSpan(&out, 1))); - }; - - std::vector<std::thread> threads; - threads.reserve(kNumThreads); - for (int i = 0; i < kNumThreads; ++i) { - threads.emplace_back(do_evaluation); - } - - for (auto& thread : threads) { - thread.join(); - } -} - -} // namespace -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/distributed_point_function.cc b/third_party/distributed_point_functions/code/dpf/distributed_point_function.cc deleted file mode 100644 index a3c42d0497dc5..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/distributed_point_function.cc +++ /dev/null @@ -1,732 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/distributed_point_function.h" - -#include <algorithm> -#include <array> -#include <cstddef> -#include <limits> -#include <memory> -#include <numeric> -#include <string> -#include <utility> -#include <vector> - -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/log/absl_log.h" -#include "absl/memory/memory.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "dpf/internal/evaluate_prg_hwy.h" -#include "dpf/internal/get_hwy_mode.h" -#include "dpf/internal/proto_validator.h" -#include "dpf/internal/value_type_helpers.h" -#include "dpf/status_macros.h" -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" -#include "hwy/aligned_allocator.h" -#include "openssl/rand.h" - -namespace distributed_point_functions { - -namespace { - -// PRG keys used to expand seeds using AES. The first two are used to compute -// correction words of seeds, while the last is used to compute correction -// words of the incremental DPF values. Values were computed by taking the -// first half of the SHA256 sum of the constant name, e.g., `echo -// "DistributedPointFunction::kPrgKeyLeft" | sha256sum` -constexpr absl::uint128 kPrgKeyLeft = - absl::MakeUint128(0x5be037ccf6a03de5ULL, 0x935f08d0a5b6a2fdULL); -constexpr absl::uint128 kPrgKeyRight = - absl::MakeUint128(0xef94b6aedebb026cULL, 0xe2ea1fe0f66f4d0bULL); -constexpr absl::uint128 kPrgKeyValue = - absl::MakeUint128(0x05a5d1588c5423e3ULL, 0x46a31101b21d1c98ULL); - -} // namespace - -DistributedPointFunction::DistributedPointFunction( - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator, - std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left, - Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value, - absl::flat_hash_map<std::string, ValueCorrectionFunction> - value_correction_functions) - : proto_validator_(std::move(proto_validator)), - parameters_(proto_validator_->parameters()), - tree_levels_needed_(proto_validator_->tree_levels_needed()), - tree_to_hierarchy_(proto_validator_->tree_to_hierarchy()), - hierarchy_to_tree_(proto_validator_->hierarchy_to_tree()), - blocks_needed_(std::move(blocks_needed)), - prg_left_(std::move(prg_left)), - prg_right_(std::move(prg_right)), - prg_value_(std::move(prg_value)), - value_correction_functions_(value_correction_functions) {} - -absl::StatusOr<std::vector<Value>> -DistributedPointFunction::ComputeValueCorrection( - int hierarchy_level, absl::Span<const absl::uint128> seeds, - absl::uint128 alpha, const Value& beta, bool invert) const { - // Compute value output component of the PRG on current seeds. To that end, we - // Compute x_0+0, ..., x_0+k-1, and x_1+0, ..., x_1+k-1, where x_i is the seed - // for helper i, and k is the number of blocks needed at the current hierarchy - // level. We use a single contiguous vector for both helpers, which allows us - // to use a single call to prg_value_.Evaluate. - int blocks_needed = blocks_needed_[hierarchy_level]; - std::vector<absl::uint128> expanded_seeds(2 * blocks_needed); - absl::Span<absl::uint128> expanded_seed_a(&expanded_seeds[0], blocks_needed); - absl::Span<absl::uint128> expanded_seed_b(&expanded_seeds[blocks_needed], - blocks_needed); - ABSL_DCHECK(seeds.size() == 2); - std::iota(expanded_seed_a.begin(), expanded_seed_a.end(), seeds[0]); - std::iota(expanded_seed_b.begin(), expanded_seed_b.end(), seeds[1]); - - // Evaluate PRG in place (this is safe as `Evaluate` creates a copy of the - // input). - DPF_RETURN_IF_ERROR( - prg_value_.Evaluate(expanded_seeds, absl::MakeSpan(expanded_seeds))); - - // Compute index in block for alpha at the current hierarchy level. - int index_in_block = DomainToBlockIndex(alpha, hierarchy_level); - - // Choose implementation depending on element_bitsize. - DPF_ASSIGN_OR_RETURN( - ValueCorrectionFunction func, - GetValueCorrectionFunction(parameters_[hierarchy_level])); - return func( - absl::string_view(reinterpret_cast<const char*>(expanded_seed_a.data()), - blocks_needed * sizeof(absl::uint128)), - absl::string_view(reinterpret_cast<const char*>(expanded_seed_b.data()), - blocks_needed * sizeof(absl::uint128)), - index_in_block, beta, invert); -} - -// Expands the PRG seeds at the next `tree_level`, updates `seeds` and -// `control_bits`, and writes the next correction word to `keys`. -absl::Status DistributedPointFunction::GenerateNext( - int tree_level, absl::uint128 alpha, absl::Span<const Value> beta, - absl::Span<absl::uint128> seeds, absl::Span<bool> control_bits, - absl::Span<DpfKey> keys) const { - // As in `GenerateKeysIncremental`, we annotate code with the corresponding - // lines from https://arxiv.org/pdf/2012.14884.pdf#figure.caption.12. - // - // Lines 13 & 14: Compute value correction word if there is a value on the - // current level. This is done here already, since we use the "PRG evaluation - // optimization" described in Appendix C.2 of the paper. Since we are using - // fixed-key AES as PRG, which can have arbitrary stretch, this optimization - // works even for large output groups. - CorrectionWord* correction_word = keys[0].add_correction_words(); - if (tree_to_hierarchy_.contains(tree_level - 1)) { - int hierarchy_level = tree_to_hierarchy_.at(tree_level - 1); - absl::uint128 alpha_prefix = 0; - int shift_amount = parameters_.back().log_domain_size() - - parameters_[hierarchy_level].log_domain_size(); - if (shift_amount < 128) { - alpha_prefix = alpha >> shift_amount; - } - DPF_ASSIGN_OR_RETURN( - std::vector<Value> value_correction, - ComputeValueCorrection(hierarchy_level, seeds, alpha_prefix, - beta[hierarchy_level], control_bits[1])); - for (const Value& value : value_correction) { - *(correction_word->add_value_correction()) = value; - } - } - - // Line 5: Expand seeds from previous level. - std::array<std::array<absl::uint128, 2>, 2> expanded_seeds; - DPF_RETURN_IF_ERROR( - prg_left_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[0]))); - DPF_RETURN_IF_ERROR( - prg_right_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[1]))); - std::array<std::array<bool, 2>, 2> expanded_control_bits; - expanded_control_bits[0][0] = - dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][0]); - expanded_control_bits[0][1] = - dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][1]); - expanded_control_bits[1][0] = - dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][0]); - expanded_control_bits[1][1] = - dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][1]); - - // Lines 6-8: Assign keep/lose branch depending on current bit of `alpha`. - bool current_bit = 0; - if (parameters_.back().log_domain_size() - tree_level < 128) { - current_bit = - (alpha & (absl::uint128{1} - << (parameters_.back().log_domain_size() - tree_level))) != 0; - } - bool keep = current_bit, lose = !current_bit; - - // Line 9: Compute seed correction word. - absl::uint128 seed_correction = - expanded_seeds[lose][0] ^ expanded_seeds[lose][1]; - - // Line 10: Compute control bit correction words. - std::array<bool, 2> control_bit_correction; - control_bit_correction[0] = expanded_control_bits[0][0] ^ - expanded_control_bits[0][1] ^ current_bit ^ 1; - control_bit_correction[1] = - expanded_control_bits[1][0] ^ expanded_control_bits[1][1] ^ current_bit; - - // We swap lines 11 and 12, since we first need to use the previous level's - // control bits before updating them. - - // Line 12: Update seeds. Note that there is a typo in the paper: The - // multiplication / AND needs to be done with the control bit of iteration - // l-1, not l. Note that unlike the original algorithm, we are using the - // corrected seed directly for the next iteration. This is secure as we're - // using AES with a different key (kPrgKeyValue) to compute the value - // correction word below. - seeds[0] = expanded_seeds[keep][0]; - seeds[1] = expanded_seeds[keep][1]; - if (control_bits[0]) { - seeds[0] ^= seed_correction; - } - if (control_bits[1]) { - seeds[1] ^= seed_correction; - } - - // Line 11: Update control bits. Again, same typo as in Line 12. - control_bits[0] = expanded_control_bits[keep][0] ^ - (control_bits[0] && control_bit_correction[keep]); - control_bits[1] = expanded_control_bits[keep][1] ^ - (control_bits[1] && control_bit_correction[keep]); - - // Line 15: Assemble correction word and add it to keys[0]. - correction_word->mutable_seed()->set_high( - absl::Uint128High64(seed_correction)); - correction_word->mutable_seed()->set_low(absl::Uint128Low64(seed_correction)); - correction_word->set_control_left(control_bit_correction[0]); - correction_word->set_control_right(control_bit_correction[1]); - - // Copy correction word to second key. - *(keys[1].add_correction_words()) = *correction_word; - - return absl::OkStatus(); -} - -absl::uint128 DistributedPointFunction::DomainToTreeIndex( - absl::uint128 domain_index, int hierarchy_level) const { - int block_index_bits = parameters_[hierarchy_level].log_domain_size() - - hierarchy_to_tree_[hierarchy_level]; - ABSL_DCHECK_LT(block_index_bits, 128); - return domain_index >> block_index_bits; -} - -int DistributedPointFunction::DomainToBlockIndex(absl::uint128 domain_index, - int hierarchy_level) const { - int block_index_bits = parameters_[hierarchy_level].log_domain_size() - - hierarchy_to_tree_[hierarchy_level]; - ABSL_DCHECK_LT(block_index_bits, 128); - return static_cast<int>(domain_index & - ((absl::uint128{1} << block_index_bits) - 1)); -} - -absl::Status DistributedPointFunction::EvaluateSeeds( - absl::Span<const absl::uint128> seeds, absl::Span<const bool> control_bits, - absl::Span<const absl::uint128> paths, - absl::Span<const CorrectionWord* const> correction_words, - absl::Span<absl::uint128> seeds_out, - absl::Span<bool> control_bits_out) const { - if (seeds.size() != control_bits.size() || seeds.size() != paths.size() || - seeds.size() != seeds_out.size() || - seeds.size() != control_bits_out.size()) { - return absl::InvalidArgumentError( - "`seeds`, `control_bits`, `paths`, `seeds_out`, and `control_bits_out` " - "must all have equal sizes"); - } - auto num_seeds = static_cast<int64_t>(seeds.size()); - auto num_levels = static_cast<int>(correction_words.size()); - if (num_seeds == 0 || num_levels == 0) { - return absl::OkStatus(); // Nothing to do. - } - - // Parse correction words for each level. - auto correction_seeds = hwy::AllocateAligned<absl::uint128>(num_levels); - if (correction_seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - BitVector correction_controls_left(num_levels), - correction_controls_right(num_levels); - for (int level = 0; level < num_levels; ++level) { - const CorrectionWord& correction = *(correction_words[level]); - correction_seeds[level] = - absl::MakeUint128(correction.seed().high(), correction.seed().low()); - correction_controls_left[level] = correction.control_left(); - correction_controls_right[level] = correction.control_right(); - } - - ABSL_DCHECK(seeds.size() == num_seeds); - ABSL_DCHECK(control_bits.size() == num_seeds); - ABSL_DCHECK(correction_controls_left.size() == num_levels); - ABSL_DCHECK(correction_controls_right.size() == num_levels); - ABSL_DCHECK(seeds_out.size() == num_seeds); - ABSL_DCHECK(control_bits_out.size() == num_seeds); - DPF_RETURN_IF_ERROR(dpf_internal::EvaluateSeeds( - num_seeds, num_levels, num_levels, seeds.data(), control_bits.data(), - paths.data(), 0, correction_seeds.get(), correction_controls_left.data(), - correction_controls_right.data(), prg_left_, prg_right_, seeds_out.data(), - control_bits_out.data())); - return absl::OkStatus(); -} - -absl::StatusOr<DistributedPointFunction::DpfExpansion> -DistributedPointFunction::ExpandSeeds( - const DpfExpansion& partial_evaluations, - absl::Span<const CorrectionWord* const> correction_words) const { - int num_expansions = static_cast<int>(correction_words.size()); - - // Check that the output size fits in a size_t. This should already be checked - // by the caller, so using ABSL_DCHECK here is enough. - ABSL_DCHECK_LT(num_expansions, 63); - auto current_level_size = - static_cast<int64_t>(partial_evaluations.control_bits.size()); - absl::uint128 output_size_128 = absl::uint128{current_level_size} - << num_expansions; - ABSL_DCHECK_LE(output_size_128, std::numeric_limits<size_t>::max() / 2); - size_t output_size = static_cast<size_t>(output_size_128); - - // Allocate buffers with the correct size to avoid reallocations. - int64_t max_batch_size = Aes128FixedKeyHash::kBatchSize; - std::vector<absl::uint128> prg_buffer_left(max_batch_size), - prg_buffer_right(max_batch_size); - - // Copy seeds and control bits. We will swap these after every expansion. - DpfExpansion expansion; - expansion.seeds = hwy::AllocateAligned<absl::uint128>(output_size); - if (expansion.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - std::copy_n(partial_evaluations.seeds.get(), current_level_size, - expansion.seeds.get()); - expansion.control_bits = partial_evaluations.control_bits; - expansion.control_bits.reserve(output_size); - DpfExpansion next_level_expansion; - next_level_expansion.seeds = hwy::AllocateAligned<absl::uint128>(output_size); - if (next_level_expansion.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - next_level_expansion.control_bits.reserve(output_size); - - // We use an iterative expansion here to pipeline AES as much as possible. - for (int i = 0; i < num_expansions; ++i) { - next_level_expansion.control_bits.resize(0); - absl::uint128 correction_seed = absl::MakeUint128( - correction_words[i]->seed().high(), correction_words[i]->seed().low()); - bool correction_control_left = correction_words[i]->control_left(); - bool correction_control_right = correction_words[i]->control_right(); - // Expand PRG. - for (int64_t start_block = 0; start_block < current_level_size; - start_block += max_batch_size) { - int64_t batch_size = - std::min<int64_t>(current_level_size - start_block, max_batch_size); - DPF_RETURN_IF_ERROR(prg_left_.Evaluate( - absl::MakeConstSpan(expansion.seeds.get() + start_block, batch_size), - absl::MakeSpan(prg_buffer_left).subspan(0, batch_size))); - DPF_RETURN_IF_ERROR(prg_right_.Evaluate( - absl::MakeConstSpan(expansion.seeds.get() + start_block, batch_size), - absl::MakeSpan(prg_buffer_right).subspan(0, batch_size))); - - // Merge results into next level of seeds and perform correction. - for (int64_t j = 0; j < batch_size; ++j) { - const int64_t index_expanded = 2 * (start_block + j); - if (expansion.control_bits[start_block + j]) { - prg_buffer_left[j] ^= correction_seed; - prg_buffer_right[j] ^= correction_seed; - } - next_level_expansion.seeds[index_expanded] = prg_buffer_left[j]; - next_level_expansion.seeds[index_expanded + 1] = prg_buffer_right[j]; - next_level_expansion.control_bits.push_back( - dpf_internal::ExtractAndClearLowestBit( - next_level_expansion.seeds[index_expanded])); - next_level_expansion.control_bits.push_back( - dpf_internal::ExtractAndClearLowestBit( - next_level_expansion.seeds[index_expanded + 1])); - if (expansion.control_bits[start_block + j]) { - next_level_expansion.control_bits[index_expanded] ^= - correction_control_left; - next_level_expansion.control_bits[index_expanded + 1] ^= - correction_control_right; - } - } - } - std::swap(expansion, next_level_expansion); - current_level_size *= 2; - } - return expansion; -} - -absl::StatusOr<DistributedPointFunction::DpfExpansion> -DistributedPointFunction::ComputePartialEvaluations( - absl::Span<const absl::uint128> prefixes, int hierarchy_level, - bool update_ctx, EvaluationContext& ctx) const { - int64_t num_prefixes = static_cast<int64_t>(prefixes.size()); - - DpfExpansion partial_evaluations; - int start_level = hierarchy_to_tree_[ctx.partial_evaluations_level()]; - int stop_level = hierarchy_to_tree_[hierarchy_level]; - if (ctx.partial_evaluations_size() > 0 && start_level <= stop_level) { - // We have partial evaluations from a tree level before the current one. - // Parse `ctx.partial_evaluations` into a btree_map for quick lookups up by - // prefix. We use a btree_map because `ctx.partial_evaluations()` will - // usually be sorted. - absl::btree_map<absl::uint128, std::pair<absl::uint128, bool>> - previous_partial_evaluations; - for (const PartialEvaluation& element : ctx.partial_evaluations()) { - absl::uint128 prefix = - absl::MakeUint128(element.prefix().high(), element.prefix().low()); - // Try inserting `(seed, control_bit)` at `prefix` into - // partial_evaluations. Return an error if `prefix` is already present - // with a different seed or control bit. - auto value = std::make_pair( - absl::MakeUint128(element.seed().high(), element.seed().low()), - element.control_bit()); - auto it = previous_partial_evaluations.try_emplace( - previous_partial_evaluations.end(), prefix, value); - if (it->second != value) { - return absl::InvalidArgumentError( - "Duplicate prefix in `ctx.partial_evaluations()` with mismatching " - "seed or control bit"); - } - } - // Now select all partial evaluations from the map that correspond to - // `prefixes`. - partial_evaluations.seeds = - hwy::AllocateAligned<absl::uint128>(num_prefixes); - if (partial_evaluations.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - partial_evaluations.control_bits.reserve(num_prefixes); - for (int64_t i = 0; i < num_prefixes; ++i) { - absl::uint128 previous_prefix = 0; - if (stop_level - start_level < 128) { - previous_prefix = prefixes[i] >> (stop_level - start_level); - } - auto it = previous_partial_evaluations.find(previous_prefix); - if (it == previous_partial_evaluations.end()) { - return absl::InvalidArgumentError(absl::StrCat( - "Prefix not present in ctx.partial_evaluations at hierarchy level ", - hierarchy_level)); - } - const std::pair<absl::uint128, bool>& partial_evaluation = it->second; - partial_evaluations.seeds[partial_evaluations.control_bits.size()] = - partial_evaluation.first; - partial_evaluations.control_bits.push_back(partial_evaluation.second); - } - } else { - // No partial evaluations in `ctx` -> Start from the beginning. - partial_evaluations.seeds = - hwy::AllocateAligned<absl::uint128>(num_prefixes); - if (partial_evaluations.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(), num_prefixes); - std::fill( - seeds.begin(), seeds.end(), - absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low())); - partial_evaluations.control_bits.resize( - num_prefixes, static_cast<bool>(ctx.key().party())); - start_level = 0; - } - - // Evaluate the DPF up to current_tree_level. - auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(), - partial_evaluations.control_bits.size()); - DPF_RETURN_IF_ERROR( - EvaluateSeeds(seeds, partial_evaluations.control_bits, prefixes, - absl::MakeConstSpan(ctx.key().correction_words()) - .subspan(start_level, stop_level - start_level), - seeds, absl::MakeSpan(partial_evaluations.control_bits))); - - // Update `partial_evaluations` in `ctx` if there are more evaluations to - // come. - ctx.clear_partial_evaluations(); - ctx.mutable_partial_evaluations()->Reserve(num_prefixes); - if (update_ctx) { - for (int64_t i = 0; i < num_prefixes; ++i) { - PartialEvaluation* current_element = ctx.add_partial_evaluations(); - current_element->mutable_prefix()->set_high( - absl::Uint128High64(prefixes[i])); - current_element->mutable_prefix()->set_low( - absl::Uint128Low64(prefixes[i])); - current_element->mutable_seed()->set_high( - absl::Uint128High64(partial_evaluations.seeds[i])); - current_element->mutable_seed()->set_low( - absl::Uint128Low64(partial_evaluations.seeds[i])); - current_element->set_control_bit(partial_evaluations.control_bits[i]); - } - } - ctx.set_partial_evaluations_level(hierarchy_level); - return partial_evaluations; -} - -absl::StatusOr<DistributedPointFunction::DpfExpansion> -DistributedPointFunction::ExpandAndUpdateContext( - int hierarchy_level, absl::Span<const absl::uint128> prefixes, - EvaluationContext& ctx) const { - // Expand seeds by expanding either the DPF key seed, or - // `ctx.partial_evaluations` for the given `prefixes`. - DpfExpansion selected_partial_evaluations; - int start_level = 0; - if (prefixes.empty()) { - // First expansion -> Expand seed of the DPF key. - selected_partial_evaluations.seeds = hwy::AllocateAligned<absl::uint128>(1); - if (selected_partial_evaluations.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - selected_partial_evaluations.seeds[0] = - absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low()); - selected_partial_evaluations.control_bits = { - static_cast<bool>(ctx.key().party())}; - } else { - // Second or later expansion -> Extract all seeds for `prefixes` from - // `ctx.partial_evaluations`. Update `ctx` if this is not the last - // evaluation. - bool update_ctx = - (hierarchy_level < static_cast<int>(parameters_.size()) - 1); - ABSL_DCHECK(ctx.previous_hierarchy_level() >= 0); - DPF_ASSIGN_OR_RETURN( - selected_partial_evaluations, - ComputePartialEvaluations(prefixes, ctx.previous_hierarchy_level(), - update_ctx, ctx)); - start_level = hierarchy_to_tree_[ctx.previous_hierarchy_level()]; - } - - // Expand up to the next hierarchy level. - int stop_level = hierarchy_to_tree_[hierarchy_level]; - DPF_ASSIGN_OR_RETURN( - DpfExpansion expansion, - ExpandSeeds(selected_partial_evaluations, - absl::MakeConstSpan(ctx.key().correction_words()) - .subspan(start_level, stop_level - start_level))); - - // Update hierarchy level in ctx. - ctx.set_previous_hierarchy_level(hierarchy_level); - return expansion; -} - -absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> -DistributedPointFunction::HashExpandedSeeds( - int hierarchy_level, absl::Span<const absl::uint128> expansion) const { - const auto expansion_size = static_cast<int64_t>(expansion.size()); - const int blocks_needed = blocks_needed_[hierarchy_level]; - auto hashed_expansion = - hwy::AllocateAligned<absl::uint128>(expansion_size * blocks_needed); - if (hashed_expansion == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - for (int64_t i = 0; i < expansion_size; ++i) { - for (int j = 0; j < blocks_needed; ++j) { - hashed_expansion[i * blocks_needed + j] = expansion[i] + j; - } - } - - // Evaluate PRG in place (this is safe as `Evaluate` creates a copy of the - // input). - absl::Span<absl::uint128> hashed_expansion_span( - hashed_expansion.get(), expansion_size * blocks_needed); - DPF_RETURN_IF_ERROR( - prg_value_.Evaluate(hashed_expansion_span, hashed_expansion_span)); - - return hashed_expansion; -} - -absl::StatusOr<std::string> -DistributedPointFunction::SerializeValueTypeDeterministically( - const ValueType& value_type) { - // We need to do serialization to a string by hand, in order to use - // deterministic serialization. - std::string serialized_value_type; - { // Start new block so that stream destructors are run before returning. - ::google::protobuf::io::StringOutputStream string_stream( - &serialized_value_type); - ::google::protobuf::io::CodedOutputStream coded_stream(&string_stream); - coded_stream.SetSerializationDeterministic(true); - if (!value_type.SerializeToCodedStream(&coded_stream)) { - return absl::InternalError("Serializing value_type to string failed"); - } - } - return serialized_value_type; -} - -absl::StatusOr<DistributedPointFunction::ValueCorrectionFunction> -DistributedPointFunction::GetValueCorrectionFunction( - const DpfParameters& parameters) const { - std::string serialized_value_type; - DPF_ASSIGN_OR_RETURN( - serialized_value_type, - SerializeValueTypeDeterministically(parameters.value_type())); - auto it = value_correction_functions_.find(serialized_value_type); - if (it == value_correction_functions_.end()) { - return absl::FailedPreconditionError(absl::StrCat( - "No value correction function known for the following parameters:\n", - parameters.DebugString(), - "Did you call RegisterValueType<T>() with your value type?")); - } - return it->second; -} - -absl::StatusOr<std::unique_ptr<DistributedPointFunction>> -DistributedPointFunction::Create(const DpfParameters& parameters) { - return CreateIncremental(absl::MakeConstSpan(¶meters, 1)); -} - -absl::StatusOr<std::unique_ptr<DistributedPointFunction>> -DistributedPointFunction::CreateIncremental( - absl::Span<const DpfParameters> parameters) { - // Log Highway mode for debugging. - ABSL_LOG_FIRST_N(INFO, 1) - << "Highway is in " << dpf_internal::GetHwyModeAsString() << " mode"; - - // Validate `parameters` and store validator for later. - DPF_ASSIGN_OR_RETURN( - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator, - dpf_internal::ProtoValidator::Create(parameters)); - - // Compute the number of value correction blocks needed for each hierarchy - // level. - std::vector<int> blocks_needed(parameters.size()); - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - DPF_ASSIGN_OR_RETURN( - int bits_needed, - dpf_internal::BitsNeeded(parameters[i].value_type(), - parameters[i].security_parameter())); - blocks_needed[i] = (bits_needed + 127) / 128; - } - - // Set up hash functions for PRG. - DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_left, - Aes128FixedKeyHash::Create(kPrgKeyLeft)); - DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_right, - Aes128FixedKeyHash::Create(kPrgKeyRight)); - DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_value, - Aes128FixedKeyHash::Create(kPrgKeyValue)); - - // For backwards compatibility, register all single unsigned integers as value - // types. - absl::flat_hash_map<std::string, ValueCorrectionFunction> - value_correction_functions; - DPF_RETURN_IF_ERROR( - RegisterValueTypeImpl<uint8_t>(value_correction_functions)); - DPF_RETURN_IF_ERROR( - RegisterValueTypeImpl<uint16_t>(value_correction_functions)); - DPF_RETURN_IF_ERROR( - RegisterValueTypeImpl<uint32_t>(value_correction_functions)); - DPF_RETURN_IF_ERROR( - RegisterValueTypeImpl<uint64_t>(value_correction_functions)); - DPF_RETURN_IF_ERROR( - RegisterValueTypeImpl<absl::uint128>(value_correction_functions)); - - // Copy parameters and return new DPF. - return absl::WrapUnique(new DistributedPointFunction( - std::move(proto_validator), std::move(blocks_needed), std::move(prg_left), - std::move(prg_right), std::move(prg_value), - std::move(value_correction_functions))); -} - -absl::StatusOr<std::pair<DpfKey, DpfKey>> -DistributedPointFunction::GenerateKeysIncremental( - absl::uint128 alpha, absl::Span<const Value> beta) { - // Check validity of beta. - if (beta.size() != parameters_.size()) { - return absl::InvalidArgumentError( - "`beta` has to have the same size as `parameters` passed at " - "construction"); - } - for (int i = 0; i < static_cast<int>(parameters_.size()); ++i) { - absl::Status status = proto_validator_->ValidateValue(beta[i], i); - if (!status.ok()) { - return status; - } - } - - // Check validity of alpha. - int last_level_log_domain_size = parameters_.back().log_domain_size(); - if (last_level_log_domain_size < 128 && - alpha >= (absl::uint128{1} << last_level_log_domain_size)) { - return absl::InvalidArgumentError( - "`alpha` must be smaller than the output domain size"); - } - - std::array<DpfKey, 2> keys; - keys[0].set_party(0); - keys[1].set_party(1); - - // We will annotate the following code with the corresponding lines from the - // pseudocode in the Incremental DPF paper - // (https://arxiv.org/pdf/2012.14884.pdf, Figure 11). - // - // There are two possible dimensions for each variable at each level: Parties - // (0 or 1) and branches (left or right). For two-dimensional arrays, we use - // the outer dimension for the branch, and the inner dimension for the party. - // - // Line 2: Sample random seeds for each party. - std::array<absl::uint128, 2> seeds; - RAND_bytes(reinterpret_cast<uint8_t*>(&seeds[0]), sizeof(absl::uint128)); - RAND_bytes(reinterpret_cast<uint8_t*>(&seeds[1]), sizeof(absl::uint128)); - keys[0].mutable_seed()->set_high(absl::Uint128High64(seeds[0])); - keys[0].mutable_seed()->set_low(absl::Uint128Low64(seeds[0])); - keys[1].mutable_seed()->set_high(absl::Uint128High64(seeds[1])); - keys[1].mutable_seed()->set_low(absl::Uint128Low64(seeds[1])); - - // Line 3: Initialize control bits. - std::array<bool, 2> control_bits{0, 1}; - - // Line 4: Compute correction words for each level after the first one. - keys[0].mutable_correction_words()->Reserve(tree_levels_needed_ - 1); - keys[1].mutable_correction_words()->Reserve(tree_levels_needed_ - 1); - for (int i = 1; i < tree_levels_needed_; i++) { - DPF_RETURN_IF_ERROR(GenerateNext(i, alpha, beta, absl::MakeSpan(seeds), - absl::MakeSpan(control_bits), - absl::MakeSpan(keys))); - } - - // Compute output correction word for last layer. - DPF_ASSIGN_OR_RETURN( - std::vector<Value> last_level_value_correction, - ComputeValueCorrection(parameters_.size() - 1, seeds, alpha, beta.back(), - control_bits[1])); - for (const Value& value : last_level_value_correction) { - *(keys[0].add_last_level_value_correction()) = value; - *(keys[1].add_last_level_value_correction()) = value; - } - - return std::make_pair(std::move(keys[0]), std::move(keys[1])); -} - -absl::StatusOr<EvaluationContext> -DistributedPointFunction::CreateEvaluationContext(DpfKey key) const { - // Check that `key` is valid. - DPF_RETURN_IF_ERROR(proto_validator_->ValidateDpfKey(key)); - - // Create new EvaluationContext with `parameters_` and `key`. - EvaluationContext result; - for (int i = 0; i < static_cast<int>(parameters_.size()); ++i) { - *(result.add_parameters()) = parameters_[i]; - } - *(result.mutable_key()) = std::move(key); - // previous_hierarchy_level = -1 means that this context has not been - // evaluated at all. - result.set_previous_hierarchy_level(-1); - return result; -} - -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/distributed_point_function.h b/third_party/distributed_point_functions/code/dpf/distributed_point_function.h deleted file mode 100644 index 6cd1c56a30836..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/distributed_point_function.h +++ /dev/null @@ -1,1211 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_ - -#include <algorithm> -#include <array> -#include <cstddef> -#include <cstdint> -#include <limits> -#include <memory> -#include <string> -#include <tuple> -#include <type_traits> -#include <utility> -#include <vector> - -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/meta/type_traits.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "dpf/aes_128_fixed_key_hash.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/internal/evaluate_prg_hwy.h" -#include "dpf/internal/maybe_deref_span.h" -#include "dpf/internal/proto_validator.h" -#include "dpf/internal/value_type_helpers.h" -#include "google/protobuf/repeated_ptr_field.h" -#include "hwy/aligned_allocator.h" - -namespace distributed_point_functions { - -// Type trait for all supported types. Used to provide meaningful error messages -// in std::enable_if template guards. -template <typename T> -using is_supported_type = dpf_internal::is_supported_type<T>; -template <typename T> -constexpr bool is_supported_type_v = is_supported_type<T>::value; - -// Converts a given Value to the template parameter T. -// -// Returns INVALID_ARGUMENT if the conversion fails. -template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>> -absl::StatusOr<T> FromValue(const Value& value) { - return dpf_internal::ValueTypeHelper<T>::FromValue(value); -} - -// ToValue Converts the argument to a Value. -template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>> -Value ToValue(const T& input) { - return dpf_internal::ValueTypeHelper<T>::ToValue(input); -} - -// ToValueType<T> Returns a `ValueType` message describing T. -template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>> -ValueType ToValueType() { - return dpf_internal::ValueTypeHelper<T>::ToValueType(); -} - -// Implements key generation and evaluation of distributed point functions. -// A distributed point function (DPF) is parameterized by an index `alpha` and a -// value `beta`. The key generation procedure produces two keys `k_a`, `k_b`. -// Evaluating each key on any point `x` in the DPF domain results in an additive -// secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. This -// class also supports *incremental* DPFs that can additionally be evaluated on -// prefixes of points, resulting in different values `beta_i`for each prefix of -// `alpha`. -class DistributedPointFunction { - public: - // Creates a new instance of a distributed point function that can be - // evaluated only at the output layer. - // - // Returns INVALID_ARGUMENT if the parameters are invalid. - static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> Create( - const DpfParameters& parameters); - - // Creates a new instance of an *incremental* DPF that can be evaluated at - // multiple layers. Each parameter set in `parameters` should specify the - // domain size and element size at one of the layers to be evaluated, in - // increasing domain size order. Element sizes must be non-decreasing. - // - // Returns INVALID_ARGUMENT if the parameters are invalid. - static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> - CreateIncremental(absl::Span<const DpfParameters> parameters); - - // DistributedPointFunction is neither copyable nor movable. - DistributedPointFunction(const DistributedPointFunction&) = delete; - DistributedPointFunction& operator=(const DistributedPointFunction&) = delete; - - // Converts the argument to a `Value` proto. Also registers the corresponding - // value type with the DPF by calling `RegisterValueType<T>()`. - template <typename T> - absl::StatusOr<Value> ToValue(const T& in) { - absl::Status status = RegisterValueType<T>(); - if (!status.ok()) { - return status; - } - return distributed_point_functions::ToValue(in); - } - - // Registers the template parameter type with this DPF. Note that it is rarely - // necessary to call this function by hand: It is called by `Create` and - // `CreateIncremental` for all unsigned integer types, including - // absl::uint128, and on every call to ToValue<T>. Only call this function - // when passing `Value`s created by other means than ToValue<T>. - // - // Returns OK on success and otherwise an INTERNAL status describing the - // failure. - template <typename T> - absl::Status RegisterValueType() { - return RegisterValueTypeImpl<T>(value_correction_functions_); - } - - // Generates a pair of keys for a DPF that evaluates to `beta` when evaluated - // `alpha`. The type of `beta` must match the ValueType passed in `parameters` - // at construction. - // - // This function provides three overloads: One with `absl::uint128` for - // `beta`, which implies the output type is a simple integer; One with a - // `Value` proto for `beta`, which can be used for all supported value types; - // And a templated version that computes the Value by calling ToValue<T> on - // the argument. - // - // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf): - // - // // Simple integer: - // dpf->GenerateKeys(23, 42); - // - // // Explicit `Value` proto: - // Value value; - // value[1]->mutable_tuple->add_elements() - // ->mutable_integer->set_value_uint64(12); - // value[1]->mutable_tuple->add_elements() - // ->mutable_integer->set_value_uint64(34); - // // Must be called once before calling GenerateKeys for any type that is - // // not a simple integer. The type should match the one in the - // // DpfParameters passed at construction. - // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>(); - // dpf->GenerateKeys(23, value); - // - // // Templated version (no call to RegisterValueType needed): - // dpf->GenerateKeys(23, Tuple<uint32_t, uint64_t>{12, 34}); - // - // Returns INVALID_ARGUMENT if used on an incremental DPF with more - // than one set of parameters, if `alpha` is outside of the domain specified - // at construction, or if `beta` does not match the value type passed at - // construction. - // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called - // for the type in the `DpfParameters` passed at construction. - - // Overload for simple integers. - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha, - absl::uint128 beta) { - return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1)); - } - - // Overload for explicit Value proto. - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha, - Value beta) { - return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1)); - } - - // Template for automatic conversion to Value proto. Disabled if the argument - // is convertible to `absl::uint128` or `Value` to make overloading - // unambiguous. - template <typename T, typename = absl::enable_if_t< - !std::is_convertible<T, absl::uint128>::value && - !std::is_convertible<T, Value>::value && - is_supported_type_v<T>>> - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha, - const T& beta) { - absl::StatusOr<Value> value = ToValue<T>(beta); - if (!value.ok()) { - return value.status(); - } - return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&(*value), 1)); - } - - // Generates a pair of keys for an incremental DPF. For each parameter i - // passed at construction, the DPF evaluates to `beta[i]` at the lowest - // `parameters_[i].log_domain_size()` bits of `alpha`. - // - // Similar to `GenerateKeys`, supports three overloads: One for simple - // integers, passed as an `absl::Span<const absl::uint128>`; One for a span of - // `Value` protos; And a variadic function template that automatically - // converts the passed arguments to a vector of `Value`s. - // - // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf): - // - // // Simple integers: - // std::vector<absl::uint128> beta{123, 456}; - // dpf->GenerateKeysIncremental(23, beta); - // - // // Explicit Value protos: - // std::vector<Value> beta(2); - // value[0]->mutable_integer()->set_value_uint128(42); - // value[1]->mutable_tuple->add_elements() - // ->mutable_integer->set_value_uint64(12); - // value[1]->mutable_tuple->add_elements() - // ->mutable_integer->set_value_uint64(34); - // // Must be called once before calling GenerateKeys for any type that is - // // not a simple integer. The type should match the one in the - // // DpfParameters passed at construction. - // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>(); - // dpf->GenerateKeysIncremental(23, beta); - // - // // Templated version (equivalent to the one above): - // dpf->GenerateKeysIncremental(23, 42, Tuple<uint32_t, uint64_t>{12, 34})); - // - // Returns INVALID_ARGUMENT if `beta.size() != parameters_.size()`, if `alpha` - // is outside of the domain specified at construction, or if `beta` does not - // match the element type passed at construction. - // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called - // for all types in the `DpfParameters` passed at construction. - - // Legacy interface for absl::uint128, which doesn't require explicitly - // converting to absl::Span<const absl::uint128>. - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - absl::uint128 alpha, const std::vector<absl::uint128>& beta) { - return GenerateKeysIncremental(alpha, absl::MakeConstSpan(beta)); - } - - // Templated version when all value types are equal. - template <typename T> - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - absl::uint128 alpha, absl::Span<const T> beta) { - std::vector<Value> values(beta.size()); - for (int i = 0; i < static_cast<int>(beta.size()); ++i) { - absl::StatusOr<Value> value = ToValue(beta[i]); - if (!value.ok()) { - return value.status(); - } - values[i] = std::move(*value); - } - return GenerateKeysIncremental(alpha, values); - } - - // Overload for Value protos. - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - absl::uint128 alpha, absl::Span<const Value> beta); - - // Variadic template version. Disabled if the first argument is convertible to - // a span of `absl::uint128`s or `Value`s to make overloading unambiguous. - template < - typename T0, typename... Tn, - typename = absl::enable_if_t< - !std::is_convertible<T0, absl::Span<const Value>>::value && - !std::is_convertible<T0, absl::Span<const absl::uint128>>::value && - absl::conjunction<is_supported_type<T0>, - is_supported_type<Tn>...>::value>> - absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - absl::uint128 alpha, T0&& beta_0, Tn&&... beta_n); - - // Returns an `EvaluationContext` for incrementally evaluating the given - // DpfKey. - // - // Returns INVALID_ARGUMENT if `key` doesn't match the parameters given at - // construction. - absl::StatusOr<EvaluationContext> CreateEvaluationContext(DpfKey key) const; - - // Evaluates the given `hierarchy_level` of the DPF under all `prefixes` - // passed to this function. If `prefixes` is empty, evaluation starts from the - // seed of `ctx.key`. Otherwise, each element of `prefixes` must fit in the - // domain size of `ctx.previous_hierarchy_level`. Further, `prefixes` may only - // contain extensions of the prefixes passed in the previous call. For - // example, in the following sequence of calls, for each element p2 of - // `prefixes2`, there must be an element p1 of `prefixes1` such that p1 is a - // prefix of p2: - // - // DPF_ASSIGN_OR_RETURN(std::unique_ptr<EvaluationContext> ctx, - // dpf->CreateEvaluationContext(key)); - // using T0 = ...; - // DPF_ASSIGN_OR_RETURN(std::vector<T0> evaluations0, - // dpf->EvaluateUntil(0, {}, *ctx)); - // - // std::vector<absl::uint128> prefixes1 = ...; - // using T1 = ...; - // DPF_ASSIGN_OR_RETURN(std::vector<T1> evaluations1, - // dpf->EvaluateUntil(1, prefixes1, *ctx)); - // ... - // std::vector<absl::uint128> prefixes2 = ...; - // using T2 = ...; - // DPF_ASSIGN_OR_RETURN(std::vector<T2> evaluations2, - // dpf->EvaluateUntil(3, prefixes2, *ctx)); - // - // The prefixes are read from the lowest-order bits of the corresponding - // absl::uint128. The number of bits used for each prefix depends on the - // output domain size of the previously evaluated hierarchy level. For - // example, if `ctx` was last evaluated on a hierarchy level with output - // domain size 2**20, then the 20 lowest-order bits of each element in - // `prefixes` are used. - // - // Returns `INVALID_ARGUMENT` if - // - any element of `prefixes` is larger than the next hierarchy level's - // log_domain_size, - // - `prefixes` contains elements that are not extensions of previous - // prefixes, or - // - the bit-size of T doesn't match the next hierarchy level's - // element_bitsize. - template <typename T> - absl::StatusOr<std::vector<T>> EvaluateUntil( - int hierarchy_level, absl::Span<const absl::uint128> prefixes, - EvaluationContext& ctx) const; - - template <typename T> - absl::StatusOr<std::vector<T>> EvaluateNext( - absl::Span<const absl::uint128> prefixes, EvaluationContext& ctx) const { - if (prefixes.empty()) { - return EvaluateUntil<T>(0, prefixes, ctx); - } else { - return EvaluateUntil<T>(ctx.previous_hierarchy_level() + 1, prefixes, - ctx); - } - } - - // Evaluates a single key at one or multiple points, up to the given - // `hierarchy_level`. Each element of `evaluation_points` must be within the - // domain of this DPF at `hierarchy_level`. - // - // Example: - // - // DpfKey key = ...; - // std::vector<absl::uint128> evaluation_points = {1, 23, 42}; - // // Evaluate `key` on {1, 23, 42}. - // DPF_ASSIGN_OR_RETURN(std::vector<T> result, - // dpf->EvaluateAt(key, 0, evaluation_points); - // - // Returns INVALID_ARGUMENT if `key` is malformed, or if `hierarchy_level` or - // any element of `evaluation_points` is out of range. - template <typename T> - absl::StatusOr<std::vector<T>> EvaluateAt( - const DpfKey& key, int hierarchy_level, - absl::Span<const absl::uint128> evaluation_points) const { - return EvaluateAtImpl<T>(key, hierarchy_level, evaluation_points, nullptr); - } - - // Evaluates a single key at one or multiple points, up to the given - // `hierarchy_level`. Each element of `evaluation_points` must be within the - // domain of this DPF at `hierarchy_level`. - // - // If `ctx.partial_evaluations_size() != 0`, uses the given partial - // evaluations as starting point of the DPF evaluation. Otherwise, the result - // is equivalent to calling `EvaluateAt(ctx.key(), hierarchy_level, - // evaluation_points)`. - // - // When successful, `ctx` is updated to include partial evaluations at - // `hierarchy_level`. The contents of `ctx` are undefined in case of an error. - // - // Returns INVALID_ARGUMENT if `ctx` is malformed, if `hierarchy_level` or - // any element of `evaluation_points` is out of range, or - // `ctx.partial_evaluations()` does not contain the prefixes of all - // `evaluation_points` at `ctx.partial_evaluations_level()`. - template <typename T> - absl::StatusOr<std::vector<T>> EvaluateAt( - int hierarchy_level, absl::Span<const absl::uint128> evaluation_points, - EvaluationContext& ctx) const { - return EvaluateAtImpl<T>(ctx.key(), hierarchy_level, evaluation_points, - &ctx); - } - - // Evaluates a span of DPF keys. The i-th key is evaluated at - // evaluation_points[i]. After each hierarchy level, calls `op` on the output - // at that hierarchy level. `op` must be callable with the following - // signature: - // - // op(int hierarchy_level, absl::Span<T> values) - // - // It should return a value that is implicitly convertible to `bool`. - // - // This method is intended for use cases similar to - // - // absl::StatusOr<std::vector<T>> EvaluateAt( - // int hierarchy_level, absl::Span<const absl::uint128> evaluation_points, - // EvaluationContext& ctx) - // - // but without the overhead of EvaluationContext. Instead, all operations on - // intermediate values, and obtaining the final result, should be done via - // `op`. - // - // Return absl::OkStatus() after successfully evaluating `op` on the last - // hierarchy level, or as soon as `op` returns `false`. Returns - // INVALID_ARGUMENT in case any `key` is malformed, or if any of the - // `evaluation_points` are out of range. - template <typename T, typename Fn> - absl::Status EvaluateAndApply( - dpf_internal::MaybeDerefSpan<const DpfKey>, - absl::Span<const absl::uint128> evaluation_points, Fn op, - int evaluation_points_rightshift = 0) const; - - // Returns the DpfParameters of this DPF. - inline absl::Span<const DpfParameters> parameters() const { - return parameters_; - } - - private: - // BitVector is a vector of bools. Allows for faster access times than - // std::vector<bool>, as well as inlining if the size is small. - using BitVector = - absl::InlinedVector<bool, - std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>; - - // Seeds and control bits resulting from a DPF expansion. This type is - // returned by `ExpandSeeds` and `ExpandAndUpdateContext`. - struct DpfExpansion { - // Ensures that seeds are aligned correctly for SIMD operations. - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds; - BitVector control_bits; - }; - - // A function for computing value corrections. Used as return type in - // `GetValueCorrectionFunction`. - using ValueCorrectionFunction = absl::StatusOr<std::vector<Value>> (*)( - absl::string_view, absl::string_view, int block_index, const Value&, - bool); - - // Private constructor, called by `CreateIncremental`. - DistributedPointFunction( - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator, - std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left, - Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value, - absl::flat_hash_map<std::string, ValueCorrectionFunction> - value_correction_functions); - - // Computes the value correction for the given `hierarchy_level`, `seeds`, - // index `alpha` and value `beta`. If `invert` is true, the individual values - // in the returned block are multiplied element-wise by -1. Expands `seeds` - // using `prg_ctx_value_`, then calls the function returned by - // `GetValueCorrectionFunction(parameters_[hierarchy_level])` to obtain the - // value correction words. - // - // Returns multiple values in the case of packing, and a single Value - // otherwise. - // - // Returns INTERNAL in case the PRG expansion fails, and UNIMPLEMENTED if - // `element_bitsize` is not supported. - absl::StatusOr<std::vector<Value>> ComputeValueCorrection( - int hierarchy_level, absl::Span<const absl::uint128> seeds, - absl::uint128 alpha, const Value& beta, bool invert) const; - - // Expands the PRG seeds at the next `tree_level` for an incremental DPF with - // index `alpha` and values `beta`, updates `seeds` and `control_bits`, and - // writes the next correction word to `keys`. Called from - // `GenerateKeysIncremental`. - absl::Status GenerateNext(int tree_level, absl::uint128 alpha, - absl::Span<const Value> beta, - absl::Span<absl::uint128> seeds, - absl::Span<bool> control_bits, - absl::Span<DpfKey> keys) const; - - // Computes the tree index (representing a path in the FSS tree) from the - // given `domain_index` and `hierarchy_level`. Does NOT check whether the - // given domain index fits in the domain at `hierarchy_level`. - absl::uint128 DomainToTreeIndex(absl::uint128 domain_index, - int hierarchy_level) const; - - // Computes the block index (pointing to an element in a batched 128-bit - // block) from the given `domain_index` and `hierarchy_level`. Does NOT check - // whether the given domain index fits in the domain at `hierarchy_level`. - int DomainToBlockIndex(absl::uint128 domain_index, int hierarchy_level) const; - - // Performs DPF evaluation of the given `seeds` using prg_ctx_left_ or - // prg_ctx_right_, and the given `control_bits` and `correction_words`. At - // each level `l < correction_words.size()`, the evaluation for the i-th seed - // in `partial_evaluations` continues along the left or right path depending - // on the l-th most significant bit among the lowest `correction_words.size()` - // bits of `paths[i]`. - // - // The output is written to `seeds_out` and `control_bits_out`. These may - // overlap with `seeds` and `control_bits`. We use output spans instead of a - // return value to allow the caller to pre-allocate aligned output arrays, - // which is necessary for the vectorized implementation. The output is - // undefined if `correction_words.size() == 0`. - // - // Returns INVALID_ARGUMENT if the input sizes don't match. - // Returns INTERNAL in case of OpenSSL errors. - absl::Status EvaluateSeeds( - absl::Span<const absl::uint128> seeds, - absl::Span<const bool> control_bits, - absl::Span<const absl::uint128> paths, - absl::Span<const CorrectionWord* const> correction_words, - absl::Span<absl::uint128> seeds_out, - absl::Span<bool> control_bits_out) const; - - // Performs DPF expansion of the given `partial_evaluations` using - // prg_ctx_left_ and prg_ctx_right_, and the given `correction_words`. In - // more detail, each of the partial evaluations is subjected to a full - // subtree expansion of `correction_words.size()` levels, and the - // concatenated result is provided in the response. The result contains - // `(partial_evaluations.size() * (2^correction_words.size())` evaluations - // in a single `DpfExpansion`. - // - // Returns INTERNAL in case of OpenSSL errors. - absl::StatusOr<DpfExpansion> ExpandSeeds( - const DpfExpansion& partial_evaluations, - absl::Span<const CorrectionWord* const> correction_words) const; - - // Computes partial evaluations of the paths to `prefixes` up to - // `hierarchy_level`, to be used as the starting point of the expansion of - // `ctx`. If `update_ctx - // == true`, saves the partial evaluations of `ctx.previous_hierarchy_level` - // to `ctx` and sets `ctx.partial_evaluations_level` to - // `ctx.previous_hierarchy_level`. Called by `ExpandAndUpdateContext`. - // - // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in - // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains - // duplicate prefixes with inconsistent seeds or control bits. - absl::StatusOr<DpfExpansion> ComputePartialEvaluations( - absl::Span<const absl::uint128> prefixes, int hierarchy_level, - bool update_ctx, EvaluationContext& ctx) const; - - // Extracts the seeds for the given `prefixes` from `ctx` and expands them as - // far as needed for the next hierarchy level. Returns the result as a - // `DpfExpansion`. Called by `EvaluateUntil`, where the expanded seeds are - // corrected to obtain output values. - // After expansion, `ctx.hierarchy_level()` is increased. If this isn't the - // last expansion, the expanded seeds are also saved in `ctx` for the next - // expansion. - // - // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in - // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains - // duplicate prefixes with inconsistent seeds or control bits. Returns - // INTERNAL in case of OpenSSL errors. - absl::StatusOr<DpfExpansion> ExpandAndUpdateContext( - int hierarchy_level, absl::Span<const absl::uint128> prefixes, - EvaluationContext& ctx) const; - - // Compute output PRG value of expanded seeds using prg_ctx_value_. - // Returns blocks_needed_[hierarchy_level] * expansion.seeds.size() blocks, - // where every blocks_needed_[hierarchy_level] correspond to the hash of an - // input seed. - // - // Returns INTERNAL in case of OpenSSL errors. - absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> HashExpandedSeeds( - int hierarchy_level, absl::Span<const absl::uint128> expansion) const; - - // Deterministically serializes the given value_type. - // - // Returns OK on success and INTERNAL in case serialization fails. - static absl::StatusOr<std::string> SerializeValueTypeDeterministically( - const ValueType& value_type); - - // Returns the value correction function for the given parameters. - // For all value types except unsigned integers, these functions have to be - // first registered using RegisterValueType<T>. - // - // Returns UNIMPLEMENTED if no matching function was registered. - absl::StatusOr<ValueCorrectionFunction> GetValueCorrectionFunction( - const DpfParameters& parameters) const; - - // Static implementation of RegisterValueType<T>, so we can call it from - // `Create`. - template <typename T> - static absl::Status RegisterValueTypeImpl( - absl::flat_hash_map<std::string, ValueCorrectionFunction>& - value_correction_functions); - - // For the given `key` and `hierarchy_level`, returns the value correction - // words as an array of integers, where the size of the array matches the - // number of batched elements per block. - template <typename T> - absl::StatusOr<std::array<T, dpf_internal::ElementsPerBlock<T>()>> - GetValueCorrectionAsArray(const DpfKey& key, int hierarchy_level) const; - - // Joint implementation of the two variants of `EvaluateAt<T>`. If `ctx != - // NULL`, `key` must point to `ctx->key()`, and `*ctx` will be updated with - // the partial evaluations at this `hierarchy_level`. - // - template <typename T> - absl::StatusOr<std::vector<T>> EvaluateAtImpl( - const DpfKey& key, int hierarchy_level, - absl::Span<const absl::uint128> evaluation_points, - EvaluationContext* ctx) const; - - // Used to validate DpfParameters, DpfKey and EvaluationContext protos. - const std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_; - - // DP parameters passed to the factory function. Contains the domain size and - // element size for hierarchy level of the incremental DPF. Owned by - // proto_validator_. - const absl::Span<const DpfParameters> parameters_; - - // Number of levels in the evaluation tree. This is always less than or equal - // to the largest log_domain_size in parameters_. - const int tree_levels_needed_; - - // Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements - // of parameters_). - const absl::flat_hash_map<int, int>& tree_to_hierarchy_; - - // The inverse of tree_to_hierarchy_. - const std::vector<int>& hierarchy_to_tree_; - - // Cached numbers of AES blocks needed for value correction at each hierarchy - // level. - const std::vector<int> blocks_needed_; - - // Pseudorandom generator used for seed expansion (left and right), and value - // correction. The PRG G(x) for hierarchy level i is defined as the - // concatenation of - // - // H_left(x), H_right(x), H_value(x + 0), ..., H_value(x + k-1) - // - // where k is equal to blocks_needed_[i], and H_*(x) is the evaluation of - // prg_*_ on input x. - const Aes128FixedKeyHash prg_left_; - const Aes128FixedKeyHash prg_right_; - const Aes128FixedKeyHash prg_value_; - - // Maps serialized `ValueType` messages to the correct value correction - // functions. Map values are instantiations of - // `dpf_internal::ComputeValueCorrectionFor`. Relies on protobuf's - // deterministic serialization feature. This has the caveat that messages with - // unknown fields are not supported. However, as long as `ValueType` consists - // of a single `oneof` field, this is fine, since we either know the value - // type and have deterministic serialization because the `ValueType` can only - // contain one field, or we don't know the type and wouldn't be able to - // correct values for it anyway. - absl::flat_hash_map<std::string, ValueCorrectionFunction> - value_correction_functions_; -}; - -//========================// -// Implementation Details // -//========================// - -template <typename T> -absl::Status DistributedPointFunction::RegisterValueTypeImpl( - absl::flat_hash_map<std::string, ValueCorrectionFunction>& - value_correction_functions) { - ValueType value_type = ToValueType<T>(); - absl::StatusOr<std::string> serialized_value_type = - SerializeValueTypeDeterministically(value_type); - if (!serialized_value_type.ok()) { - return serialized_value_type.status(); - } - value_correction_functions[*serialized_value_type] = - dpf_internal::ComputeValueCorrectionFor<T>; - return absl::OkStatus(); -} - -template <typename T0, typename... Tn, typename /*= absl::enable_if_t<...>*/> -absl::StatusOr<std::pair<DpfKey, DpfKey>> -DistributedPointFunction::GenerateKeysIncremental(absl::uint128 alpha, - T0&& beta_0, Tn&&... beta_n) { - // Convert the first element of beta. We need to treat it separately to be - // able to check its type in the enable_if above. - absl::StatusOr<Value> value = ToValue(beta_0); - if (!value.ok()) { - return value.status(); - } - std::vector<Value> values = {std::move(*value)}; - values.reserve(1 + sizeof...(beta_n)); - // Convert all values in the parameter pack, stopping at the first error. - absl::Status status = absl::OkStatus(); - // We create an unused std::tuple<Tn...> here, because its braced-initializer - // list constructor allows us to operate on beta_n in a well-defined order. In - // C++17, this could be replaced by a fold expression instead. - std::tuple<Tn...>{[this, &status, &values, &value](auto&& beta_i) -> Tn { - if (status.ok()) { - value = this->ToValue(beta_i); - if (value.ok()) { - values.push_back(std::move(*value)); - } else { - status = value.status(); - } - } - return Tn{}; - }(beta_n)...}; - // Return if there was an error during conversion, otherwise generate keys. - if (!status.ok()) { - return status; - } - return GenerateKeysIncremental(alpha, values); -} - -template <typename T> -absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil( - int hierarchy_level, absl::Span<const absl::uint128> prefixes, - EvaluationContext& ctx) const { - absl::Status status = proto_validator_->ValidateEvaluationContext(ctx); - if (!status.ok()) { - return status; - } - if (hierarchy_level < 0 || - hierarchy_level >= static_cast<int>(parameters_.size())) { - return absl::InvalidArgumentError( - "`hierarchy_level` must be non-negative and less than " - "parameters_.size()"); - } - absl::StatusOr<bool> types_are_equal = dpf_internal::ValueTypesAreEqual( - ToValueType<T>(), parameters_[hierarchy_level].value_type()); - if (!types_are_equal.ok()) { - return types_are_equal.status(); - } else if (!*types_are_equal) { - return absl::InvalidArgumentError( - "Value type T doesn't match parameters at `hierarchy_level`"); - } - if (hierarchy_level <= ctx.previous_hierarchy_level()) { - return absl::InvalidArgumentError( - "`hierarchy_level` must be greater than " - "`ctx.previous_hierarchy_level`"); - } - if ((ctx.previous_hierarchy_level() < 0) != (prefixes.empty())) { - return absl::InvalidArgumentError( - "`prefixes` must be empty if and only if this is the first call with " - "`ctx`."); - } - - int previous_log_domain_size = 0; - int previous_hierarchy_level = ctx.previous_hierarchy_level(); - if (!prefixes.empty()) { - ABSL_DCHECK_GE(ctx.previous_hierarchy_level(), 0); - previous_log_domain_size = - parameters_[previous_hierarchy_level].log_domain_size(); - for (absl::uint128 prefix : prefixes) { - if (previous_log_domain_size < 128 && - prefix >= (absl::uint128{1} << previous_log_domain_size)) { - return absl::InvalidArgumentError( - absl::StrFormat("Index %d out of range for hierarchy level %d", - prefix, previous_hierarchy_level)); - } - } - } - int64_t prefixes_size = static_cast<int64_t>(prefixes.size()); - - // Check that the output size is not too large. We first check that the - // domain size blowup fits in an int64_t, and then check that the total size - // of all elements doesn't over flow a size_t. - int log_domain_size = parameters_[hierarchy_level].log_domain_size(); - if (log_domain_size - previous_log_domain_size >= 63) { - return absl::InvalidArgumentError( - "Domain size gap too large. Please evaluate fewer hierarchy " - "levels at once, or insert intermediate hierarchy levels."); - } - int64_t outputs_per_prefix = int64_t{1} - << (log_domain_size - previous_log_domain_size); - if (absl::uint128{prefixes_size} * outputs_per_prefix > - std::numeric_limits<size_t>::max() / 2) { - return absl::InvalidArgumentError( - "Output size would be too large. Please evaluate fewer hierarchy " - "levels at once, insert intermediate hierarchy levels, or evaluate on " - "fewer prefixes at once."); - } - - // The `prefixes` passed in by the caller refer to the domain of the previous - // hierarchy level. However, because we batch multiple elements of type T in a - // single uint128 block, multiple prefixes can actually refer to the same - // block in the FSS evaluation tree. On a high level, our approach is as - // follows: - // - // 1. Split up each element of `prefixes` into a tree index, pointing to a - // block in the FSS tree, and a block index, pointing to an element of type - // T in that block. - // - // 2. Compute a list of unique `tree_indices`, and for each original prefix, - // remember the position of the corresponding tree index in `tree_indices`. - // - // 3. After expanding the unique `tree_indices`, use the positions saved in - // Step (2) together with the corresponding block index to retrieve the - // expanded values for each prefix, and return them in the same order as - // `prefixes`. - // - // `tree_indices` holds the unique tree indices from `prefixes`, to be passed - // to `ExpandAndUpdateContext`. - std::vector<absl::uint128> tree_indices; - tree_indices.reserve(prefixes_size); - // `tree_indices_inverse` is the inverse of `tree_indices`, used for - // deduplicating and constructing `prefix_map`. Use a btree_map because we - // expect `prefixes` (and thus `tree_indices`) to be sorted. - absl::btree_map<absl::uint128, int64_t> tree_indices_inverse; - // `prefix_map` maps each i < prefixes.size() to an element of `tree_indices` - // and a block index. Used to select which elements to return after the - // expansion, to ensure the result is ordered the same way as `prefixes`. - std::vector<std::pair<int64_t, int>> prefix_map; - prefix_map.reserve(prefixes_size); - for (int64_t i = 0; i < prefixes_size; ++i) { - absl::uint128 tree_index = - DomainToTreeIndex(prefixes[i], previous_hierarchy_level); - int block_index = DomainToBlockIndex(prefixes[i], previous_hierarchy_level); - - // Check if `tree_index` already exists in `tree_indices`. - size_t previous_size = tree_indices_inverse.size(); - auto it = tree_indices_inverse.try_emplace(tree_indices_inverse.end(), - tree_index, tree_indices.size()); - if (tree_indices_inverse.size() > previous_size) { - tree_indices.push_back(tree_index); - } - prefix_map.push_back(std::make_pair(it->second, block_index)); - } - - // Perform expansion of unique `tree_indices`. - absl::StatusOr<DpfExpansion> expansion = - ExpandAndUpdateContext(hierarchy_level, tree_indices, ctx); - if (!expansion.ok()) { - return expansion.status(); - } - const auto expansion_size = - static_cast<int64_t>(expansion->control_bits.size()); - auto seeds = absl::MakeConstSpan(expansion->seeds.get(), expansion_size); - - // Hash the expanded seeds. - absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> hashed_expansion = - HashExpandedSeeds(hierarchy_level, seeds); - if (!hashed_expansion.ok()) { - return hashed_expansion.status(); - } - - // Get output correction word from `ctx`. - constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>(); - const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr; - if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) { - value_correction = - &(ctx.key() - .correction_words(hierarchy_to_tree_[hierarchy_level]) - .value_correction()); - } else { - // Last level value correction is stored in an extra proto field, since we - // have one less correction word than tree levels. - value_correction = &(ctx.key().last_level_value_correction()); - } - - // Split output correction into elements of type T. - absl::StatusOr<std::array<T, elements_per_block>> correction_ints = - dpf_internal::ValuesToArray<T>(*value_correction); - if (!correction_ints.ok()) { - return correction_ints.status(); - } - - // Compute value corrections for each block in `expanded_seeds`. We have to - // account for the fact that blocks might not be full (i.e., have less than - // elements_per_block elements). - const int corrected_elements_per_block = - 1 << (parameters_[hierarchy_level].log_domain_size() - - hierarchy_to_tree_[hierarchy_level]); - const int blocks_needed = blocks_needed_[hierarchy_level]; - ABSL_DCHECK(corrected_elements_per_block <= elements_per_block); - std::vector<T> corrected_expansion(expansion_size * - corrected_elements_per_block); - for (int64_t i = 0; i < expansion_size; ++i) { - std::array<T, elements_per_block> current_elements = - dpf_internal::ConvertBytesToArrayOf<T>(absl::string_view( - reinterpret_cast<const char*>(hashed_expansion->get() + - i * blocks_needed), - blocks_needed * sizeof(absl::uint128))); - for (int j = 0; j < corrected_elements_per_block; ++j) { - if (expansion->control_bits[i]) { - current_elements[j] += (*correction_ints)[j]; - } - if (ctx.key().party() == 1) { - current_elements[j] = -current_elements[j]; - } - corrected_expansion[i * corrected_elements_per_block + j] = - current_elements[j]; - } - } - - if (prefixes.empty()) { - // If prefixes is empty (i.e., this is the first evaluation of `ctx`), just - // return the expansion. - ABSL_DCHECK(static_cast<int>(corrected_expansion.size()) == - outputs_per_prefix); - return corrected_expansion; - } else { - // Otherwise, only return elements under `prefixes`. - int blocks_per_tree_prefix = - expansion->control_bits.size() / tree_indices.size(); - std::vector<T> result(prefixes_size * outputs_per_prefix); - for (int64_t i = 0; i < prefixes_size; ++i) { - int64_t prefix_expansion_start = - prefix_map[i].first * blocks_per_tree_prefix * - corrected_elements_per_block + - prefix_map[i].second * outputs_per_prefix; - std::copy_n(&corrected_expansion[prefix_expansion_start], - outputs_per_prefix, &result[i * outputs_per_prefix]); - } - return result; - } -} - -template <typename T> -absl::StatusOr<std::array<T, dpf_internal::ElementsPerBlock<T>()>> -DistributedPointFunction::GetValueCorrectionAsArray(const DpfKey& key, - int hierarchy_level) const { - // Get output correction word from `key`. - const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr; - if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) { - value_correction = - &(key.correction_words(hierarchy_to_tree_[hierarchy_level]) - .value_correction()); - } else { - // Last level value correction is stored in an extra proto field, since we - // have one less correction word than tree levels. - value_correction = &(key.last_level_value_correction()); - } - - // Split output correction into elements of type T, and return it. - return dpf_internal::ValuesToArray<T>(*value_correction); -} - -template <typename T> -absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl( - const DpfKey& key, int hierarchy_level, - absl::Span<const absl::uint128> evaluation_points, - EvaluationContext* ctx) const { - if (ctx != nullptr) { - if (&key != &ctx->key()) { - return absl::InvalidArgumentError( - "`key` and `ctx->key()` must refer to the same object"); - } - } - if (hierarchy_level < 0) { - return absl::InvalidArgumentError("`hierarchy_level` must be non-negative"); - } - if (hierarchy_level >= static_cast<int>(parameters_.size())) { - return absl::InvalidArgumentError( - "`hierarchy_level` must be less than the number of parameters passed " - "at construction"); - } - const auto num_evaluation_points = - static_cast<int64_t>(evaluation_points.size()); - const int log_domain_size = parameters_[hierarchy_level].log_domain_size(); - absl::uint128 max_evaluation_point = absl::Uint128Max(); - if (log_domain_size < 128) { - max_evaluation_point = (absl::uint128{1} << log_domain_size) - 1; - } - // Check if `evaluation_points` are inside the domain. This has minimal (~ 1%) - // performance impact. - for (int64_t i = 0; i < num_evaluation_points; ++i) { - if (evaluation_points[i] > max_evaluation_point) { - return absl::InvalidArgumentError( - absl::StrCat("`evaluation_points[", i, - "]` larger than the domain size at hierarchy level ", - hierarchy_level)); - } - } - absl::Status status = proto_validator_->ValidateDpfKey(key); - if (!status.ok()) { - return status; - } - if (num_evaluation_points == 0) { - return std::vector<T>{}; // Nothing to do. - } - - // Split up evaluation_points into tree indices and block indices, if we're - // operating on a packed type. Otherwise set `tree_indices` to - // `evaluation_points`. - hwy::AlignedFreeUniquePtr<absl::uint128[]> maybe_recomputed_tree_indices; - constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>(); - absl::Span<const absl::uint128> tree_indices; - if (elements_per_block > 1) { - maybe_recomputed_tree_indices = - hwy::AllocateAligned<absl::uint128>(num_evaluation_points); - if (maybe_recomputed_tree_indices == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - for (int64_t i = 0; i < num_evaluation_points; ++i) { - maybe_recomputed_tree_indices[i] = - DomainToTreeIndex(evaluation_points[i], hierarchy_level); - } - tree_indices = absl::MakeConstSpan(maybe_recomputed_tree_indices.get(), - num_evaluation_points); - // Copy evaluation_points to new array if not aligned. - } else { - // This avoids copying the evaluation points when elements_per_block == 1. - tree_indices = evaluation_points; - } - - // Set up partial evaluations for the selected tree_indices. If we have a - // context `ctx`, Compute them from `ctx.partial_evaluations`, otherwise start - // from the beginning. - absl::StatusOr<DpfExpansion> selected_partial_evaluations = DpfExpansion(); - int start_level = 0; - if (!ctx) { - // No context or context was never evaluated -> start from the beginning. - absl::uint128 seed = absl::MakeUint128(key.seed().high(), key.seed().low()); - bool party = key.party(); - selected_partial_evaluations->seeds = - hwy::AllocateAligned<absl::uint128>(num_evaluation_points); - if (selected_partial_evaluations->seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - auto seeds = absl::MakeSpan(selected_partial_evaluations->seeds.get(), - num_evaluation_points); - std::fill(seeds.begin(), seeds.end(), seed); - selected_partial_evaluations->control_bits.resize(num_evaluation_points, - party); - } else { - // We have a context -> Use it to compute partial evaluations. Always update - // `ctx`, since unlike for full expansion the amount of proto data written - // will always be `tree_indices.size()` and should therefore be negligible. - selected_partial_evaluations = - ComputePartialEvaluations(tree_indices, hierarchy_level, - /*update_ctx=*/true, *ctx); - if (!selected_partial_evaluations.ok()) { - return selected_partial_evaluations.status(); - } - start_level = hierarchy_to_tree_[hierarchy_level]; - } - - // Evaluate DPFs. - const int stop_level = hierarchy_to_tree_[hierarchy_level]; - absl::Span<absl::uint128> seeds( - selected_partial_evaluations->seeds.get(), - selected_partial_evaluations->control_bits.size()); - auto correction_words = absl::MakeConstSpan(key.correction_words()) - .subspan(start_level, stop_level - start_level); - status = - EvaluateSeeds(seeds, selected_partial_evaluations->control_bits, - tree_indices, correction_words, seeds, - absl::MakeSpan(selected_partial_evaluations->control_bits)); - if (!status.ok()) { - return status; - } - ABSL_DCHECK(static_cast<int64_t>(seeds.size()) == num_evaluation_points); - - // Hash `seeds`. - absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> hashed_expansion = - HashExpandedSeeds(hierarchy_level, seeds); - if (!hashed_expansion.ok()) { - return hashed_expansion.status(); - } - - // Get value correction words. - absl::StatusOr<std::array<T, elements_per_block>> correction_ints = - GetValueCorrectionAsArray<T>(key, hierarchy_level); - if (!correction_ints.ok()) { - return correction_ints.status(); - } - - // Perform value correction. - std::vector<T> result(num_evaluation_points); - const int blocks_needed = blocks_needed_[hierarchy_level]; - for (int64_t i = 0; i < num_evaluation_points; ++i) { - std::array<T, elements_per_block> current_elements = - dpf_internal::ConvertBytesToArrayOf<T>(absl::string_view( - reinterpret_cast<const char*>(hashed_expansion->get() + - i * blocks_needed), - blocks_needed * sizeof(absl::uint128))); - int block_index = 0; - if (elements_per_block > 1) { - block_index = DomainToBlockIndex(evaluation_points[i], hierarchy_level); - } - result[i] = current_elements[block_index]; - if (selected_partial_evaluations->control_bits[i]) { - result[i] += (*correction_ints)[block_index]; - } - if (key.party() == 1) { - result[i] = -result[i]; - } - } - - if (ctx) { - ctx->set_previous_hierarchy_level(hierarchy_level); - } - - return result; -} - -template <typename T, typename Fn> -absl::Status DistributedPointFunction::EvaluateAndApply( - dpf_internal::MaybeDerefSpan<const DpfKey> keys, - absl::Span<const absl::uint128> evaluation_points, Fn op, - int evaluation_points_rightshift) const { - if (evaluation_points.size() != keys.size()) { - return absl::InvalidArgumentError( - "`keys.size()` != `evaluation_points.size()`"); - } - for (size_t i = 0; i < keys.size(); ++i) { - absl::Status status = proto_validator_->ValidateDpfKey(keys[i]); - if (!status.ok()) return status; - } - - const int64_t num_keys = keys.size(); - const int num_hierarchy_levels = parameters_.size(); - DpfExpansion eval; - eval.control_bits.resize(num_keys); - eval.seeds = hwy::AllocateAligned<absl::uint128>(num_keys); - if (eval.seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - absl::Span<absl::uint128> seeds(eval.seeds.get(), num_keys); - absl::Span<bool> control_bits(eval.control_bits); - hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds; - BitVector correction_control_bits_left, correction_control_bits_right; - std::vector<T> values(num_keys); - - // Initialize seeds and control bits. - for (int64_t i = 0; i < num_keys; ++i) { - seeds[i] = absl::MakeUint128(keys[i].seed().high(), keys[i].seed().low()); - control_bits[i] = keys[i].party(); - } - - int start_level = 0; - int stop_level = hierarchy_to_tree_[0]; - for (int hierarchy_level = 0; hierarchy_level < num_hierarchy_levels; - ++hierarchy_level) { - if (hierarchy_level > 0) { - start_level = stop_level; - stop_level = hierarchy_to_tree_[hierarchy_level]; - } - - // Compute index shifts for the current level. - const int domain_index_rightshift = - evaluation_points_rightshift + parameters_.back().log_domain_size() - - parameters_[hierarchy_level].log_domain_size(); - const int tree_index_rightshift = evaluation_points_rightshift + - parameters_.back().log_domain_size() - - hierarchy_to_tree_[hierarchy_level]; - - int num_tree_levels = stop_level - start_level; - if (num_tree_levels > 0) { - correction_seeds = - hwy::AllocateAligned<absl::uint128>(num_tree_levels * num_keys); - if (correction_seeds == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - correction_control_bits_left.resize(num_tree_levels * num_keys); - correction_control_bits_right.resize(num_tree_levels * num_keys); - for (int i = 0; i < num_tree_levels; ++i) { - for (int64_t j = 0; j < num_keys; ++j) { - const int64_t index = i * num_keys + j; - const CorrectionWord& cw = keys[j].correction_words(start_level + i); - correction_seeds[index] = - absl::MakeUint128(cw.seed().high(), cw.seed().low()); - correction_control_bits_left[index] = cw.control_left(); - correction_control_bits_right[index] = cw.control_right(); - } - } - - // Evaluate the current hierarchy level for all keys. - absl::Status status = dpf_internal::EvaluateSeeds( - seeds.size(), num_tree_levels, num_tree_levels * num_keys, - seeds.data(), control_bits.data(), evaluation_points.data(), - tree_index_rightshift, correction_seeds.get(), - correction_control_bits_left.data(), - correction_control_bits_right.data(), prg_left_, prg_right_, - seeds.data(), control_bits.data()); - if (!status.ok()) { - return status; - } - } - - // Hash `seeds`. - absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> - hashed_expansion = HashExpandedSeeds(hierarchy_level, seeds); - if (!hashed_expansion.ok()) { - return hashed_expansion.status(); - } - - // Compute value correction for the current level. - constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>(); - const int blocks_needed = blocks_needed_[hierarchy_level]; - for (int64_t i = 0; i < num_keys; ++i) { - std::array<T, elements_per_block> current_elements = - dpf_internal::ConvertBytesToArrayOf<T>(absl::string_view( - reinterpret_cast<const char*>(hashed_expansion->get() + - i * blocks_needed), - blocks_needed * sizeof(absl::uint128))); - absl::StatusOr<std::array<T, elements_per_block>> correction_ints = - GetValueCorrectionAsArray<T>(keys[i], hierarchy_level); - if (!correction_ints.ok()) { - return correction_ints.status(); - } - int block_index = 0; - if (elements_per_block > 1 && domain_index_rightshift < 128) { - block_index = DomainToBlockIndex( - evaluation_points[i] >> domain_index_rightshift, hierarchy_level); - } - values[i] = current_elements[block_index]; - if (control_bits[i]) { - values[i] += (*correction_ints)[block_index]; - } - if (keys[i].party() == 1) { - values[i] = -values[i]; - } - } - - // Call the callback with the values at the current level, and return if the - // result is `false`. - if (!op(values)) { - break; - } - } - return absl::OkStatus(); -} - -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_ diff --git a/third_party/distributed_point_functions/code/dpf/distributed_point_function.proto b/third_party/distributed_point_functions/code/dpf/distributed_point_function.proto deleted file mode 100644 index 058d759f90495..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/distributed_point_function.proto +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package distributed_point_functions; - -// For faster allocations of sub-messages. -option cc_enable_arenas = true; - -// Describes the type of a single DPF output value. Any additional types added -// here should also be supported in internal/value_type_helpers.h. -// LINT.IfChange -message ValueType { - // Describes an integer modulo 2^l. Maps to the C++ types `uint8_t`, - // `uint16_t`, `uint32_t`, `uint64_t`, and `absl::uint128`. - message Integer { - // Number of bits per integer. Must be a power of 2 and at most 128. - int32 bitsize = 1; - } - - // Describes a tuple of value types. - message Tuple { - repeated ValueType elements = 1; - } - - // Describes an integer ring modulo `modulus`. - message IntModN { - // The underlying integer type used to represent elements in the ring. - Integer base_integer = 1; - // The modulus. - Value.Integer modulus = 2; - } - - oneof type { - // A single integer modulo 2^l. - Integer integer = 1; - // A tuple of values. - Tuple tuple = 2; - // A integer with custom modulus. - IntModN int_mod_n = 3; - // An XOR-wrapped integer. Corresponds to the XorWrapper C++ class. - Integer xor_wrapper = 4; - } - // Do not add fields outside of the `oneof` above, to ensure that messages - // with known ValueTypes are serialized deterministically. See the - // documentation of `value_correction_functions_` in - // distributed_point_function.h for details. -} - -// Used to correct output values to the desired DPF magnitude. Holds the values -// corresponding to the types defined in `ValueType`. -message Value { - message Integer { - oneof value { - // Any value up to 64 bits. - uint64 value_uint64 = 1; - // 128-bit values. - Block value_uint128 = 2; - } - } - - message Tuple { - repeated Value elements = 1; - } - - oneof value { - Integer integer = 1; - Tuple tuple = 2; - Integer int_mod_n = - 3; // The value of an IntModN is represented by its base_integer type. - Integer xor_wrapper = 4; - } -} -// LINT.ThenChange( -// internal/value_type_helpers.h, -// internal/value_type_helpers.cc -// ) - -// Parameters of a single hierarchy level of a distributed point function (DPF). -message DpfParameters { - reserved 2; - // Base-2 logarithm of the number of elements. - int32 log_domain_size = 1; - // Describes the type of output values at this hierarchy level. - ValueType value_type = 3; - // The negative logarithm of the total variation distance from uniform that an - // evaluation at a *single point* at this hierarchy level is allowed to have. - // The correct value for this parameter depends on the maximum number of - // points at which this hierarchy level is evaluated. It should be at least 40 - // + log2(number_of_evaluation_points). Defaults to - // ProtoValidator::kDefaultSecurityParameter + log_domain_size. - double security_parameter = 4; -} - -// A single 128-bit AES block. -message Block { - uint64 high = 1; - uint64 low = 2; -} - -// A correction word used to evaluate a single layer in the DPF evaluation tree. -message CorrectionWord { - // Block used to correct the new seeds after PRG evaluation. - Block seed = 1; - // Correction bits for the left and right control bits. - bool control_left = 2; - bool control_right = 3; - // Reserved for deprecated value correction field. - reserved 4; - // Used to correct the output value at the previous tree layer. Only included - // if the previous tree layer is an output layer. Repeated to capture the case - // where multiple correction values are needed due to packing. - repeated Value value_correction = 5; -} - -// A key of a distributed point function (DPF). -message DpfKey { - // Initial seed at the first level. - Block seed = 1; - // Correction words for each level after expansion. - repeated CorrectionWord correction_words = 2; - // Party this DpfKey belongs to (0 or 1). - int32 party = 3; - // Deprecated last level value correction. - reserved 4; - // Output correction for the last level of the evaluation tree. - repeated Value last_level_value_correction = 5; -} - -// Maps a single prefix of a DPF index to a PRG seed. Used to store partial -// evaluation state between hierarchy levels in `EvaluationContext` -message PartialEvaluation { - // Prefix in the FSS evaluation tree. Does not necessarily coincide with the - // corresponding prefix of the output domain at this hierarchy level. - Block prefix = 1; - // Seed for the next evaluation. - Block seed = 2; - // Control bit for the correction in the next evaluation. - bool control_bit = 3; -} - -// An EvaluationContext holds the state of a partially evaluated incremental -// DPF. -message EvaluationContext { - // The parameters of the DPF being evaluated. One set of parameters for each - // hierarchy level of the incremental DPF. - repeated DpfParameters parameters = 1; - // The DPF key being evaluated. - DpfKey key = 2; - // The hierarchy level that this EvaluationContext was last evaluated on. - int32 previous_hierarchy_level = 3; - // Maps prefixes from an earlier hierarchy level to PRG seeds, which are used - // to continue the evaluation under each prefix. Uses a repeated message field - // since Protobuf doesn't allow messages (such as `Block`) as map keys. - repeated PartialEvaluation partial_evaluations = 4; - // The hierarchy level `partial_evaluations` corresponds to. Ignored when - // `partial_evaluations` is empty. - int32 partial_evaluations_level = 5; -} diff --git a/third_party/distributed_point_functions/code/dpf/distributed_point_function_benchmark.cc b/third_party/distributed_point_functions/code/dpf/distributed_point_function_benchmark.cc deleted file mode 100644 index 0deabc3ac8a66..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/distributed_point_function_benchmark.cc +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <algorithm> -#include <cmath> -#include <memory> -#include <numeric> -#include <string> -#include <utility> -#include <vector> - -#include "absl/container/btree_set.h" -#include "absl/log/absl_check.h" -#include "absl/numeric/int128.h" -#include "absl/random/random.h" -#include "absl/random/uniform_int_distribution.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "benchmark/benchmark.h" -#include "dpf/distributed_point_function.h" -#include "google/protobuf/arena.h" -#include "hwy/aligned_allocator.h" - -namespace distributed_point_functions { -namespace { - -// Benchmarks a regular DPF evaluation. Expects the first range argument to -// specify the output log domain size. -template <typename T> -void BM_EvaluateRegularDpf(benchmark::State& state) { - DpfParameters parameters; - parameters.set_log_domain_size(state.range(0)); - *(parameters.mutable_value_type()) = ToValueType<T>(); - std::unique_ptr<DistributedPointFunction> dpf = - DistributedPointFunction::Create(parameters).value(); - absl::uint128 alpha = 0; - T beta{}; - ABSL_CHECK(dpf->RegisterValueType<T>().ok()); - std::pair<DpfKey, DpfKey> keys = dpf->GenerateKeys(alpha, beta).value(); - EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value(); - for (auto s : state) { - google::protobuf::Arena arena; - EvaluationContext* ctx = - google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena); - *ctx = ctx_0; - std::vector<T> result = dpf->EvaluateNext<T>({}, *ctx).value(); - benchmark::DoNotOptimize(result); - } -} -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint8_t)->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint16_t)->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint32_t)->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint64_t)->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, absl::uint128)->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint32_t>) - ->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint64_t>) - ->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint64_t, uint64_t>) - ->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, - Tuple<uint32_t, uint32_t, uint32_t, uint32_t>) - ->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, - Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>) - ->DenseRange(12, 24, 2); -BENCHMARK_TEMPLATE( - BM_EvaluateRegularDpf, - Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>) - ->DenseRange(12, 24, 2); - -using MyIntModN = IntModN<uint32_t, 4294967291u>; // 2**32 - 5. -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, - Tuple<MyIntModN, MyIntModN, MyIntModN, MyIntModN, MyIntModN>) - ->DenseRange(12, 24, 2); -using MyIntModN64 = IntModN<uint64_t, 18446744073709551557ull>; // 2**64 - 59. -BENCHMARK_TEMPLATE( - BM_EvaluateRegularDpf, - Tuple<MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64>) - ->DenseRange(12, 22, 2); -BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, XorWrapper<absl::uint128>) - ->DenseRange(1, 24, 1); - -// Benchmarks full evaluation of all hierarchy levels. Expects the first range -// argument to specify the number of iterations. The output domain size is fixed -// to 2**20. -template <typename T> -void BM_EvaluateHierarchicalFull(benchmark::State& state) { - // Set up DPF with the given parameters. - const int kMaxLogDomainSize = 20; - int num_hierarchy_levels = state.range(0); - std::vector<DpfParameters> parameters(num_hierarchy_levels); - for (int i = 0; i < num_hierarchy_levels; ++i) { - parameters[i].set_log_domain_size(static_cast<int>( - static_cast<double>(i + 1) / num_hierarchy_levels * kMaxLogDomainSize)); - parameters[i].mutable_value_type()->mutable_integer()->set_bitsize( - sizeof(T) * 8); - } - std::unique_ptr<DistributedPointFunction> dpf = - DistributedPointFunction::CreateIncremental(parameters).value(); - - // Generate keys. - absl::uint128 alpha = 12345; - std::vector<absl::uint128> beta(num_hierarchy_levels); - for (int i = 0; i < num_hierarchy_levels; ++i) { - beta[i] = i; - } - std::pair<DpfKey, DpfKey> keys = - dpf->GenerateKeysIncremental(alpha, beta).value(); - - // Set up evaluation context and evaluation prefixes for each level. - EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value(); - std::vector<std::vector<absl::uint128>> prefixes(num_hierarchy_levels); - for (int i = 1; i < num_hierarchy_levels; ++i) { - prefixes[i].resize(1 << parameters[i - 1].log_domain_size()); - std::iota(prefixes[i].begin(), prefixes[i].end(), absl::uint128{0}); - } - - // Run hierarchical evaluation. - for (auto s : state) { - google::protobuf::Arena arena; - EvaluationContext* ctx = - google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena); - *ctx = ctx_0; - for (int i = 0; i < num_hierarchy_levels; ++i) { - std::vector<T> result = dpf->EvaluateNext<T>(prefixes[i], *ctx).value(); - benchmark::DoNotOptimize(result); - } - benchmark::DoNotOptimize(*ctx); - } -} -BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint8_t)->DenseRange(1, 16, 2); -BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint16_t)->DenseRange(1, 16, 2); -BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint32_t)->DenseRange(1, 16, 2); -BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint64_t)->DenseRange(1, 16, 2); -BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, absl::uint128) - ->DenseRange(1, 16, 2); - -// Generates random prefixes for the given set of `parameters`. Generates -// `num_nonzeros[i]` prefixes at hierarchy level `i`. -std::vector<std::vector<absl::uint128>> GenerateRandomPrefixes( - absl::Span<const DpfParameters> parameters, - absl::Span<const int> num_nonzeros) { - auto num_hierarchy_levels = static_cast<int>(parameters.size()); - std::vector<std::vector<absl::uint128>> prefixes(parameters.size()); - - absl::BitGen rng; - absl::uniform_int_distribution<uint32_t> dist_index, dist_value; - for (int i = 0; i < num_hierarchy_levels; ++i) { - if (i > 0) { // prefixes must be empty for the first level. - prefixes[i] = std::vector<absl::uint128>(num_nonzeros[i - 1]); - absl::uint128 prefix = 0; - // Difference between the previous domain size and the one before that. - // This is the amount of bits we have to shift prefixes from the previous - // level to append the current level. - int previous_domain_size_difference = parameters[i - 1].log_domain_size(); - if (i > 1) { - previous_domain_size_difference -= parameters[i - 2].log_domain_size(); - } - dist_value = absl::uniform_int_distribution<uint32_t>( - 0, (1 << previous_domain_size_difference) - 1); - if (i > 1) { - dist_index = absl::uniform_int_distribution<uint32_t>( - 0, prefixes[i - 1].size() - 1); - } - for (int j = 0; i > 0 && j < num_nonzeros[i - 1]; ++j) { - if (i > 1) { - // Choose a random prefix from the previous level to extend. - prefix = prefixes[i - 1][dist_index(rng)] - << previous_domain_size_difference; - } - prefixes[i][j] = prefix | dist_value(rng); - } - } - std::sort(prefixes[i].begin(), prefixes[i].end()); - } - return prefixes; -} - -// Benchmark the example used here: -// https://github.com/abetterinternet/prio-documents/issues/18#issuecomment-801248636 -void BM_IsrgExampleHierarchy(benchmark::State& state) { - const int kNumHierarchyLevels = 2; - std::vector<DpfParameters> parameters(kNumHierarchyLevels); - std::vector<int> num_nonzeros(kNumHierarchyLevels - 1); - - parameters[0].set_log_domain_size(12); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32); - num_nonzeros[0] = 32; - - parameters[1].set_log_domain_size(25); - parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(32); - - std::unique_ptr<DistributedPointFunction> dpf = - DistributedPointFunction::CreateIncremental(parameters).value(); - - // Create DPF keys. - absl::uint128 alpha = 1234567; - std::vector<absl::uint128> beta(kNumHierarchyLevels, 1); - std::pair<DpfKey, DpfKey> keys = - dpf->GenerateKeysIncremental(alpha, beta).value(); - - // Generate prefixes for evaluation with the appropriate number of nonzeros. - std::vector<std::vector<absl::uint128>> prefixes = - GenerateRandomPrefixes(parameters, num_nonzeros); - - // Run hierarchical evaluation. - EvaluationContext ctx_0 = dpf->CreateEvaluationContext(keys.first).value(); - for (auto s : state) { - google::protobuf::Arena arena; - EvaluationContext* ctx = - google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena); - *ctx = ctx_0; - for (int i = 0; i < kNumHierarchyLevels; ++i) { - std::vector<uint32_t> result = - dpf->EvaluateNext<uint32_t>(prefixes[i], *ctx).value(); - benchmark::DoNotOptimize(result); - } - benchmark::DoNotOptimize(*ctx); - } -} -BENCHMARK(BM_IsrgExampleHierarchy); - -// Benchmarks the time needed to generate keys. The log domain size is read from -// the first range argument. If `direct_evaluation` is true, a single hierarchy -// level will be used. Otherwise, the number of hierarchy levels is eqaual to -// the log domain size (i.e., one level per bit in the domain). -template <bool direct_evaluation> -void BM_KeyGeneration(benchmark::State& state) { - int last_level_log_domain_size = state.range(0); - std::vector<DpfParameters> parameters(1); - if (direct_evaluation) { - parameters[0].set_log_domain_size(last_level_log_domain_size); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32); - } else { - parameters.resize(last_level_log_domain_size); - for (int i = 0; i < last_level_log_domain_size; ++i) { - parameters[i].set_log_domain_size(i + 1); - parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(32); - } - } - std::unique_ptr<DistributedPointFunction> dpf = - *(DistributedPointFunction::CreateIncremental(parameters)); - - std::vector<absl::uint128> beta(parameters.size(), 23); - absl::BitGen rng; - absl::uniform_int_distribution<uint64_t> dist; - absl::uint128 alpha_mask = - (absl::uint128{1} << parameters.back().log_domain_size()) - 1; - std::pair<DpfKey, DpfKey> result; - for (auto s : state) { - // Sample alpha randomly, so we don't rely on any structure here. - absl::uint128 alpha = absl::MakeUint128(dist(rng), dist(rng)) & alpha_mask; - result = dpf->GenerateKeysIncremental(alpha, beta).value(); - benchmark::DoNotOptimize(result); - } - state.SetLabel(absl::StrCat("key_size: ", result.first.ByteSizeLong())); -} -BENCHMARK_TEMPLATE(BM_KeyGeneration, true)->RangeMultiplier(2)->Range(1, 128); -BENCHMARK_TEMPLATE(BM_KeyGeneration, false)->RangeMultiplier(2)->Range(1, 128); - -// Generates `num_nonzeros` uniform indices, and computes their prefixes for -// each hierarchy level in `parameters`. -absl::StatusOr<std::vector<std::vector<absl::uint128>>> GenerateUniformPrefixes( - absl::Span<const DpfParameters> parameters, int num_nonzeros) { - int num_parameters = static_cast<int>(parameters.size()); - std::vector<std::vector<absl::uint128>> result(num_parameters); - if (num_parameters <= 1) { - return result; - } - if (std::log2(num_nonzeros) > - parameters[num_parameters - 2].log_domain_size()) { - return absl::InvalidArgumentError("num_nonzeros out of range"); - } - - absl::BitGen rng; - absl::uniform_int_distribution<uint64_t> dist; - - // Generate prefixes for last level. - absl::btree_set<absl::uint128> last_level_prefixes; - while (static_cast<int>(last_level_prefixes.size()) < num_nonzeros) { - absl::uint128 mask = (absl::uint128{1} << parameters[parameters.size() - 2] - .log_domain_size()) - - 1; - last_level_prefixes.insert(absl::MakeUint128(dist(rng), dist(rng)) & mask); - } - result.back() = std::vector<absl::uint128>(last_level_prefixes.begin(), - last_level_prefixes.end()); - - // Iterate backwards through previous levels, computing prefixes by - // appropriately shifting the ones from higher levels. - for (int i = static_cast<int>(result.size()) - 1; i > 1; --i) { - absl::btree_set<absl::uint128> current_level_prefixes; - for (const auto& x : result[i]) { - absl::uint128 prefix = x >> (parameters[i - 1].log_domain_size() - - parameters[i - 2].log_domain_size()); - current_level_prefixes.insert(prefix); - } - result[i - 1] = std::vector<absl::uint128>(current_level_prefixes.begin(), - current_level_prefixes.end()); - } - return result; -} - -// Benchmark a bit-wise hierarchy as in https://github.com/henrycg/heavyhitters. -// Uses a variable domain size with 10000 uniform non-zeros at the last -// hierarchy level, and evaluate at every bit. -void BM_HeavyHitters(benchmark::State& state) { - int num_parameters = state.range(0); - const int kNumNonzeros = 10000; - std::vector<DpfParameters> parameters(num_parameters); - for (int i = 0; i < num_parameters; ++i) { - parameters[i].set_log_domain_size(i + 1); - parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(64); - } - std::unique_ptr<DistributedPointFunction> dpf = - *(DistributedPointFunction::CreateIncremental(parameters)); - - std::vector<absl::uint128> beta(num_parameters, 23); - absl::uint128 alpha = 42; - DpfKey key = dpf->GenerateKeysIncremental(alpha, beta).value().first; - std::vector<std::vector<absl::uint128>> prefixes = - GenerateUniformPrefixes(parameters, kNumNonzeros).value(); - - // Run hierarchical evaluation. - EvaluationContext ctx_0 = dpf->CreateEvaluationContext(key).value(); - for (auto s : state) { - google::protobuf::Arena arena; - EvaluationContext* ctx = - google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena); - *ctx = ctx_0; - for (int i = 0; i < num_parameters; ++i) { - std::vector<uint64_t> result = - dpf->EvaluateNext<uint64_t>(prefixes[i], *ctx).value(); - benchmark::DoNotOptimize(result); - } - benchmark::DoNotOptimize(*ctx); - } -} -BENCHMARK(BM_HeavyHitters)->RangeMultiplier(2)->Range(16, 128); - -// Benchmark batch evaluation of multiple DPF keys at a single point each. -// The first argument specifies the number of keys, the second the domain size, -// and the last the number of evaluation points per key. -template <typename T> -void BM_BatchEvaluation(benchmark::State& state) { - const int num_keys = state.range(0); - const int evaluation_points_per_key = state.range(1); - constexpr int kLogDomainSize = 63 - 7; - - absl::uint128 domain_mask = absl::Uint128Max(); - if (kLogDomainSize < 128) { - domain_mask = (absl::uint128{1} << kLogDomainSize) - 1; - } - - DpfParameters parameters; - parameters.set_log_domain_size(kLogDomainSize); - *(parameters.mutable_value_type()) = ToValueType<T>(); - - std::unique_ptr<DistributedPointFunction> dpf = - DistributedPointFunction::Create(parameters).value(); - - absl::BitGen rng; - google::protobuf::Arena arena; - std::vector<const DpfKey*> key_pointers(num_keys * evaluation_points_per_key); - auto evaluation_points = - hwy::AllocateAligned<absl::uint128>(num_keys * evaluation_points_per_key); - ABSL_CHECK(evaluation_points != nullptr); - for (int i = 0; i < num_keys; ++i) { - absl::uint128 alpha = absl::MakeUint128(absl::Uniform<uint64_t>(rng), - absl::Uniform<uint64_t>(rng)) & - domain_mask; - T beta{}; - DpfKey* key = google::protobuf::Arena::CreateMessage<DpfKey>(&arena); - *key = dpf->GenerateKeys(alpha, beta).value().first; - - for (int j = 0; j < evaluation_points_per_key; ++j) { - key_pointers[i * evaluation_points_per_key + j] = key; - evaluation_points[i * evaluation_points_per_key + j] = - absl::MakeUint128(absl::Uniform<uint64_t>(rng), - absl::Uniform<uint64_t>(rng)) & - domain_mask; - } - } - - for (auto s : state) { - for (int i = 0; i < num_keys; ++i) { - std::vector<T> result = - dpf->EvaluateAt<T>( - *(key_pointers[i]), 0, - absl::MakeConstSpan( - evaluation_points.get() + i * evaluation_points_per_key, - evaluation_points_per_key)) - .value(); - benchmark::DoNotOptimize(result); - } - } -} -BENCHMARK_TEMPLATE(BM_BatchEvaluation, XorWrapper<absl::uint128>) - ->ArgPair(1, 400000) - ->ArgPair(10, 40000) - ->ArgPair(100, 4000); - -} // namespace -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/distributed_point_function_test.cc b/third_party/distributed_point_functions/code/dpf/distributed_point_function_test.cc deleted file mode 100644 index c651818ba093f..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/distributed_point_function_test.cc +++ /dev/null @@ -1,1189 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/distributed_point_function.h" - -#include <memory> -#include <ostream> -#include <string> -#include <tuple> -#include <utility> -#include <vector> - -#include "absl/algorithm/container.h" -#include "absl/base/config.h" -#include "absl/numeric/int128.h" -#include "absl/random/random.h" -#include "absl/random/uniform_int_distribution.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/types/span.h" -#include "absl/utility/utility.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/internal/proto_validator.h" -#include "dpf/internal/status_matchers.h" -#include "dpf/xor_wrapper.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace { - -using dpf_internal::IsOk; -using dpf_internal::IsOkAndHolds; -using dpf_internal::StatusIs; -using ::testing::HasSubstr; -using ::testing::Ne; -using ::testing::StartsWith; - -TEST(DistributedPointFunction, TestCreate) { - for (int log_domain_size = 0; log_domain_size <= 128; ++log_domain_size) { - for (int element_bitsize = 1; element_bitsize <= 128; - element_bitsize *= 2) { - DpfParameters parameters; - - parameters.set_log_domain_size(log_domain_size); - parameters.mutable_value_type()->mutable_integer()->set_bitsize( - element_bitsize); - - EXPECT_THAT(DistributedPointFunction::Create(parameters), - IsOkAndHolds(Ne(nullptr))) - << "log_domain_size=" << log_domain_size - << " element_bitsize=" << element_bitsize; - } - } -} - -TEST(DistributedPointFunction, TestCreateIncrementalLargeDomain) { - std::vector<DpfParameters> parameters(2); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128); - parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128); - - // Test that creating an incremental DPF with a large total domain size works. - parameters[0].set_log_domain_size(10); - parameters[1].set_log_domain_size(100); - - EXPECT_THAT(DistributedPointFunction::CreateIncremental(parameters), - IsOkAndHolds(Ne(nullptr))); -} - -TEST(DistributedPointFunction, CreateFailsForTupleTypesWithDifferentIntModN) { - DpfParameters parameters; - parameters.set_log_domain_size(10); - *(parameters.mutable_value_type()) = - ToValueType<Tuple<IntModN<uint32_t, 3>, IntModN<uint64_t, 4>>>(); - - EXPECT_THAT( - DistributedPointFunction::Create(parameters), - StatusIs(absl::StatusCode::kUnimplemented, - "All elements of type IntModN in a tuple must be the same")); -} - -TEST(DistributedPointFunction, CreateFailsForMissingValueType) { - DpfParameters parameters; - parameters.set_log_domain_size(10); - - EXPECT_THAT( - DistributedPointFunction::Create(parameters), - StatusIs(absl::StatusCode::kInvalidArgument, "`value_type` is required")); -} - -TEST(DistributedPointFunction, CreateFailsForInvalidValueType) { - DpfParameters parameters; - parameters.set_log_domain_size(10); - *(parameters.mutable_value_type()) = ValueType{}; - - EXPECT_THAT(DistributedPointFunction::Create(parameters), - StatusIs(absl::StatusCode::kInvalidArgument, - StartsWith("ValidateValueType: Unsupported ValueType"))); -} - -TEST(DistributedPointFunction, TestGenerateKeysIncrementalVariadicTemplate) { - std::vector<DpfParameters> parameters(2); - - parameters[0].set_log_domain_size(10); - parameters[1].set_log_domain_size(20); - *(parameters[0].mutable_value_type()) = ToValueType<uint16_t>(); - *(parameters[1].mutable_value_type()) = - ToValueType<Tuple<uint32_t, uint64_t>>(); - DPF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DistributedPointFunction> dpf, - DistributedPointFunction::CreateIncremental(parameters)); - - absl::StatusOr<std::pair<DpfKey, DpfKey>> keys = dpf->GenerateKeysIncremental( - 23, uint16_t{42}, Tuple<uint32_t, uint64_t>{123, 456}); - EXPECT_THAT(keys, IsOk()); -} - -TEST(DistributedPointFunction, TestGenerateKeysIncrementalTemplate) { - std::vector<DpfParameters> parameters(2); - using T = XorWrapper<absl::uint128>; - - parameters[0].set_log_domain_size(10); - parameters[1].set_log_domain_size(20); - *(parameters[0].mutable_value_type()) = ToValueType<T>(); - *(parameters[1].mutable_value_type()) = ToValueType<T>(); - DPF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DistributedPointFunction> dpf, - DistributedPointFunction::CreateIncremental(parameters)); - - absl::StatusOr<std::pair<DpfKey, DpfKey>> keys = - dpf->GenerateKeysIncremental(23, T{42}, T{123}); - EXPECT_THAT(keys, IsOk()); -} - -TEST(DistributedPointFunction, KeyGenerationFailsIfValueTypeNotRegistered) { - DpfParameters parameters; - parameters.set_log_domain_size(10); - parameters.mutable_value_type() - ->mutable_tuple() - ->add_elements() - ->mutable_integer() - ->set_bitsize(32); - DPF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DistributedPointFunction> dpf, - DistributedPointFunction::Create(parameters)); - - // Tuple<uint32_t> should not be registered by default. - absl::uint128 alpha = 23; - Value beta; - beta.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(42); - - EXPECT_THAT(dpf->GenerateKeys(alpha, beta), - StatusIs(absl::StatusCode::kFailedPrecondition, - StartsWith("No value correction function known"))); -} - -TEST(DistributedPointFunction, EvaluationFailsIfDomainSizeGapTooLarge) { - std::vector<DpfParameters> parameters(2); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128); - parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128); - parameters[0].set_log_domain_size(10); - parameters[1].set_log_domain_size(100); - DPF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DistributedPointFunction> dpf, - DistributedPointFunction::CreateIncremental(parameters)); - - std::pair<DpfKey, DpfKey> keys; - DPF_ASSERT_OK_AND_ASSIGN(keys, dpf->GenerateKeysIncremental(123, 456u, 789u)); - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf->CreateEvaluationContext(keys.first)); - - EXPECT_THAT( - dpf->EvaluateUntil<absl::uint128>(1, {}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, StartsWith("Domain size"))); -} - -TEST(DistributedPointFunction, EvaluationFailsIfOutputSizeTooLarge) { - std::vector<DpfParameters> parameters(2); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128); - parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128); - parameters[0].set_log_domain_size(10); - parameters[1].set_log_domain_size(72); - DPF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DistributedPointFunction> dpf, - DistributedPointFunction::CreateIncremental(parameters)); - - std::pair<DpfKey, DpfKey> keys; - DPF_ASSERT_OK_AND_ASSIGN(keys, dpf->GenerateKeysIncremental(123, 456u, 789u)); - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf->CreateEvaluationContext(keys.first)); - - // Evaluate on 2**2 prefixes, bringing the output size to 2**(72-10+2) = - // 2**64, which overflows an int64_t. Assumes a size_t is at most 64 bits. - std::vector<absl::uint128> prefixes = {123, 456, 789, 1011}; - DPF_ASSERT_OK(dpf->EvaluateUntil<absl::uint128>(0, {}, ctx)); - EXPECT_THAT( - dpf->EvaluateUntil<absl::uint128>(1, prefixes, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, StartsWith("Output size"))); -} - -TEST(DistributedPointFunction, TestSinglePointPartialEvaluation) { - // Two hierarchy levels: The first will be evaluated with only a single - // prefix, the second will be fully evaluated. - std::vector<DpfParameters> parameters(2); - parameters[0].set_log_domain_size(108); - parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32); - parameters[1].set_log_domain_size(128); - parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(32); - - DPF_ASSERT_OK_AND_ASSIGN( - auto dpf, DistributedPointFunction::CreateIncremental(parameters)); - absl::uint128 prefix = 0xdeadbeef; - absl::uint128 suffix = 23; - absl::uint128 alpha = (prefix << 20) + suffix; - uint32_t beta = 42; - DpfKey key_a, key_b; - DPF_ASSERT_OK_AND_ASSIGN(std::tie(key_a, key_b), - dpf->GenerateKeysIncremental(alpha, {beta, beta})); - - // First evaluate directly up to `prefix` - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx_a, - dpf->CreateEvaluationContext(key_a)); - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx_b, - dpf->CreateEvaluationContext(key_b)); - std::vector<uint32_t> result_a, result_b; - DPF_ASSERT_OK_AND_ASSIGN(result_a, - dpf->EvaluateAt<uint32_t>(0, {prefix}, ctx_a)); - DPF_ASSERT_OK_AND_ASSIGN(result_b, - dpf->EvaluateAt<uint32_t>(0, {prefix}, ctx_b)); - EXPECT_EQ(result_a[0] + result_b[0], beta); - - // Now fully evaluate the second level. - DPF_ASSERT_OK_AND_ASSIGN(result_a, - dpf->EvaluateUntil<uint32_t>(1, {prefix}, ctx_a)); - DPF_ASSERT_OK_AND_ASSIGN(result_b, - dpf->EvaluateUntil<uint32_t>(1, {prefix}, ctx_b)); - EXPECT_EQ(result_a.size(), result_b.size()); - EXPECT_EQ(result_a.size(), 1 << 20); - for (int i = 0; i < static_cast<int>(result_a.size()); ++i) { - if (i == suffix) { - EXPECT_EQ(result_a[i] + result_b[i], beta); - } else { - EXPECT_EQ(result_a[i] + result_b[i], 0); - } - } -} - -class RegularDpfKeyGenerationTest - : public testing::TestWithParam<std::tuple<int, int>> { - public: - void SetUp() { - std::tie(log_domain_size_, element_bitsize_) = GetParam(); - DpfParameters parameters; - parameters.set_log_domain_size(log_domain_size_); - parameters.mutable_value_type()->mutable_integer()->set_bitsize( - element_bitsize_); - DPF_ASSERT_OK_AND_ASSIGN(dpf_, - DistributedPointFunction::Create(parameters)); - DPF_ASSERT_OK_AND_ASSIGN( - proto_validator_, dpf_internal::ProtoValidator::Create({parameters})); - } - - protected: - int log_domain_size_; - int element_bitsize_; - std::unique_ptr<DistributedPointFunction> dpf_; - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_; -}; - -TEST_P(RegularDpfKeyGenerationTest, KeyHasCorrectFormat) { - DpfKey key_a, key_b; - DPF_ASSERT_OK_AND_ASSIGN(std::tie(key_a, key_b), dpf_->GenerateKeys(0, 0)); - - // Check that party is set correctly. - EXPECT_EQ(key_a.party(), 0); - EXPECT_EQ(key_b.party(), 1); - // Check that keys are accepted by proto_validator_. - DPF_EXPECT_OK(proto_validator_->ValidateDpfKey(key_a)); - DPF_EXPECT_OK(proto_validator_->ValidateDpfKey(key_b)); -} - -TEST_P(RegularDpfKeyGenerationTest, FailsIfBetaHasTheWrongSize) { - EXPECT_THAT( - dpf_->GenerateKeysIncremental(0, {1, 2}), - StatusIs(absl::StatusCode::kInvalidArgument, - "`beta` has to have the same size as `parameters` passed at " - "construction")); -} - -TEST_P(RegularDpfKeyGenerationTest, FailsIfAlphaIsTooLarge) { - if (log_domain_size_ >= 128) { - // Alpha is an absl::uint128, so never too large in this case. - return; - } - - EXPECT_THAT(dpf_->GenerateKeys((absl::uint128{1} << log_domain_size_), 1), - StatusIs(absl::StatusCode::kInvalidArgument, - "`alpha` must be smaller than the output domain size")); -} - -TEST_P(RegularDpfKeyGenerationTest, FailsIfBetaIsTooLarge) { - if (element_bitsize_ >= 128) { - // Beta is an absl::uint128, so never too large in this case. - return; - } - - const auto beta = absl::uint128{1} << element_bitsize_; - - // Not testing error message, as it's an implementation detail of - // ProtoValidator. - EXPECT_THAT(dpf_->GenerateKeys(0, beta), - StatusIs(absl::StatusCode::kInvalidArgument)); -} - -INSTANTIATE_TEST_SUITE_P(VaryDomainAndElementSizes, RegularDpfKeyGenerationTest, - testing::Combine(testing::Values(0, 1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 32, 62), - testing::Values(8, 16, 32, 64, 128))); - -struct DpfTestParameters { - int log_domain_size; - int element_bitsize; - - friend std::ostream& operator<<(std::ostream& os, - const DpfTestParameters& parameters) { - return os << "{ log_domain_size: " << parameters.log_domain_size - << ", element_bitsize: " << parameters.element_bitsize << " }"; - } -}; - -class IncrementalDpfTest : public testing::TestWithParam< - std::tuple</*parameters*/ - std::vector<DpfTestParameters>, - /*alpha*/ absl::uint128, - /*beta*/ std::vector<absl::uint128>, - /*level_step*/ int, - /*single_point*/ bool>> { - protected: - void SetUp() { - const std::vector<DpfTestParameters>& parameters = std::get<0>(GetParam()); - parameters_.resize(parameters.size()); - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - parameters_[i].set_log_domain_size(parameters[i].log_domain_size); - parameters_[i].mutable_value_type()->mutable_integer()->set_bitsize( - parameters[i].element_bitsize); - } - DPF_ASSERT_OK_AND_ASSIGN( - dpf_, DistributedPointFunction::CreateIncremental(parameters_)); - alpha_ = std::get<1>(GetParam()); - beta_ = std::get<2>(GetParam()); - DPF_ASSERT_OK_AND_ASSIGN(keys_, - dpf_->GenerateKeysIncremental(alpha_, beta_)); - level_step_ = std::get<3>( - GetParam()); // Number of hierarchy level to evaluate at once. - single_point_ = std::get<4>(GetParam()); - DPF_ASSERT_OK_AND_ASSIGN(proto_validator_, - dpf_internal::ProtoValidator::Create(parameters_)); - } - - // Returns the prefix of `index` for the domain of `hierarchy_level`. - absl::uint128 GetPrefixForLevel(int hierarchy_level, absl::uint128 index) { - absl::uint128 result = 0; - int shift_amount = parameters_.back().log_domain_size() - - parameters_[hierarchy_level].log_domain_size(); - if (shift_amount < 128) { - result = index >> shift_amount; - } - return result; - } - - // Evaluates both contexts `ctx0` and `ctx1` at `hierarchy level`, using the - // appropriate prefixes of `evaluation_points`. Checks that the expansion of - // both keys form correct DPF shares, i.e., they add up to - // `beta_[ctx.hierarchy_level()]` under prefixes of `alpha_`, and to 0 - // otherwise. If `singl_point == true`, only evaluates at the prefixes of the - // given `evaluation_points`. Otherwise, fully expands the given - // `hierarchy_level`. - template <typename T> - void EvaluateAndCheckLevel(int hierarchy_level, - absl::Span<const absl::uint128> evaluation_points, - EvaluationContext& ctx0, EvaluationContext& ctx1) { - int previous_hierarchy_level = ctx0.previous_hierarchy_level(); - int current_log_domain_size = - parameters_[hierarchy_level].log_domain_size(); - int previous_log_domain_size = 0; - int num_expansions = 1; - bool is_first_evaluation = previous_hierarchy_level < 0; - std::vector<absl::uint128> prefixes; - if (single_point_) { - // Single point evaluation: Generate prefixes for the current hierarchy - // level. - prefixes.resize(evaluation_points.size()); - for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i) { - prefixes[i] = GetPrefixForLevel(hierarchy_level, evaluation_points[i]); - } - } else if (!is_first_evaluation) { - // Full expansion: Generate prefixes for the previous hierarchy level if - // we're not on the first level. - num_expansions = static_cast<int>(evaluation_points.size()); - prefixes.resize(evaluation_points.size()); - previous_log_domain_size = - parameters_[previous_hierarchy_level].log_domain_size(); - for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i) { - prefixes[i] = - GetPrefixForLevel(previous_hierarchy_level, evaluation_points[i]); - } - } - - absl::StatusOr<std::vector<T>> result_0, result_1; - if (single_point_) { - result_0 = dpf_->EvaluateAt<T>(hierarchy_level, prefixes, ctx0); - result_1 = dpf_->EvaluateAt<T>(hierarchy_level, prefixes, ctx1); - } else { - result_0 = dpf_->EvaluateUntil<T>(hierarchy_level, prefixes, ctx0); - result_1 = dpf_->EvaluateUntil<T>(hierarchy_level, prefixes, ctx1); - } - - // Check results are ok. - DPF_EXPECT_OK(result_0) - << "hierarchy_level=" << hierarchy_level << "\nparameters={\n" - << parameters_[hierarchy_level].DebugString() << "}\n"; - DPF_EXPECT_OK(result_1); - if (result_0.ok() && result_1.ok()) { - // Check output sizes match. - ASSERT_EQ(result_0->size(), result_1->size()); - - if (single_point_) { - absl::uint128 current_alpha_prefix = - GetPrefixForLevel(hierarchy_level, alpha_); - for (int i = 0; i < result_0->size(); ++i) { - if (prefixes[i] == current_alpha_prefix) { - EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), - beta_[hierarchy_level]) - << "i=" << i << ", hierarchy_level=" << hierarchy_level - << "\nparameters={\n" - << parameters_[hierarchy_level].DebugString() << "}\n" - << "current_alpha_prefix=" << current_alpha_prefix << "\n" - << "prefixes[" << i << "]=" << prefixes[i] << "\n" - << "evaluation_points[" << i << "]=" << evaluation_points[i] - << "\n"; - } else { - EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), 0) - << "i=" << i << ", hierarchy_level=" << hierarchy_level - << "\nparameters={\n" - << parameters_[hierarchy_level].DebugString() << "}\n" - << "current_alpha_prefix=" << current_alpha_prefix << "\n" - << "prefixes[" << i << "]=" << prefixes[i] << "\n" - << "evaluation_points[" << i << "]=" << evaluation_points[i] - << "\n"; - } - } - } else { - int64_t outputs_per_prefix = - int64_t{1} << (current_log_domain_size - previous_log_domain_size); - int64_t expected_output_size = num_expansions * outputs_per_prefix; - ASSERT_EQ(result_0->size(), expected_output_size); - - // Iterate over the outputs and check that they sum up to 0 or to - // beta_[current_hierarchy_level]. - absl::uint128 previous_alpha_prefix = 0; - if (!is_first_evaluation) { - previous_alpha_prefix = - GetPrefixForLevel(previous_hierarchy_level, alpha_); - } - absl::uint128 current_alpha_prefix = - GetPrefixForLevel(hierarchy_level, alpha_); - for (int64_t i = 0; i < expected_output_size; ++i) { - int prefix_index = i / outputs_per_prefix; - int prefix_expansion_index = i % outputs_per_prefix; - // The output is on the path to alpha, if we're at the first level or - // under a prefix of alpha, and the current block in the expansion of - // the prefix is also on the path to alpha. - if ((is_first_evaluation || - prefixes[prefix_index] == previous_alpha_prefix) && - prefix_expansion_index == - (current_alpha_prefix % outputs_per_prefix)) { - // We need to static_cast here since otherwise operator+ returns an - // unsigned int without doing a modular reduction, which causes the - // test to fail on types with sizeof(T) < sizeof(unsigned). - EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), - beta_[hierarchy_level]) - << "i=" << i << ", hierarchy_level=" << hierarchy_level - << "\nparameters={\n" - << parameters_[hierarchy_level].DebugString() << "}\n" - << "previous_alpha_prefix=" << previous_alpha_prefix << "\n" - << "current_alpha_prefix=" << current_alpha_prefix << "\n"; - } else { - EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), 0) - << "i=" << i << ", hierarchy_level=" << hierarchy_level - << "\nparameters={\n" - << parameters_[hierarchy_level].DebugString() << "}\n"; - } - } - } - } - } - - std::vector<DpfParameters> parameters_; - std::unique_ptr<DistributedPointFunction> dpf_; - absl::uint128 alpha_; - std::vector<absl::uint128> beta_; - std::pair<DpfKey, DpfKey> keys_; - int level_step_; - bool single_point_; - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_; -}; - -TEST_P(IncrementalDpfTest, CreateEvaluationContextCreatesValidContext) { - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - DPF_EXPECT_OK(proto_validator_->ValidateEvaluationContext(ctx)); -} - -TEST_P(IncrementalDpfTest, FailsIfPrefixNotPresentInCtx) { - if (parameters_.size() < 3 || parameters_[0].log_domain_size() < 1 || - parameters_[0].value_type().integer().bitsize() != 128 || - parameters_[1].value_type().integer().bitsize() != 128 || - parameters_[2].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - // Evaluate on prefixes 0 and 1, then delete partial evaluations for prefix 0. - DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({}, ctx)); - DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({0, 1}, ctx)); - ctx.mutable_partial_evaluations()->erase(ctx.partial_evaluations().begin()); - - // The missing prefix corresponds to hierarchy level 1, even though we - // evaluate at level 2. - EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({0}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "Prefix not present in ctx.partial_evaluations at " - "hierarchy level 1")); -} - -TEST_P(IncrementalDpfTest, FailsIfDuplicatePrefixInCtx) { - if (parameters_.size() < 3 || parameters_[0].log_domain_size() < 1 || - parameters_[0].value_type().integer().bitsize() != 128 || - parameters_[1].value_type().integer().bitsize() != 128 || - parameters_[2].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - // Evaluate on prefixes 0 and 1, then add duplicate prefix for 0 with - // different seed. - DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({}, ctx)); - DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({0, 1}, ctx)); - *(ctx.add_partial_evaluations()) = ctx.partial_evaluations(0); - Block changed_seed = ctx.partial_evaluations(0).seed(); - changed_seed.set_low(changed_seed.low() + 1); - *((ctx.mutable_partial_evaluations()->end() - 1)->mutable_seed()) = - changed_seed; - - // The missing prefix corresponds to hierarchy level 1, even though we - // evaluate at level 2. - EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({0}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "Duplicate prefix in `ctx.partial_evaluations()` with " - "mismatching seed or control bit")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsOnEmptyContext) { - if (parameters_[0].value_type().integer().bitsize() != 128) { - return; - } - EvaluationContext ctx; - - // We don't check the error message, since it depends on the ProtoValidator - // implementation which is tested in the corresponding unit test. - EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument)); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfHierarchyLevelNegative) { - if (parameters_[0].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(-1, {}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "`hierarchy_level` must be non-negative and less than " - "parameters_.size()")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfHierarchyLevelTooLarge) { - if (parameters_[0].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(parameters_.size(), {}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "`hierarchy_level` must be non-negative and less than " - "parameters_.size()")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfValueTypeDoesntMatch) { - using SomeStrangeType = Tuple<uint8_t, uint32_t, uint8_t, uint16_t, uint8_t>; - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - EXPECT_THAT( - dpf_->EvaluateUntil<SomeStrangeType>(0, {}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "Value type T doesn't match parameters at `hierarchy_level`")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfLevelAlreadyEvaluated) { - if (parameters_.size() < 2 || - parameters_[0].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - DPF_ASSERT_OK(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx)); - - EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - "`hierarchy_level` must be greater than " - "`ctx.previous_hierarchy_level`")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfPrefixesNotEmptyOnFirstCall) { - if (parameters_[0].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - EXPECT_THAT( - dpf_->EvaluateUntil<absl::uint128>(0, {0}, ctx), - StatusIs( - absl::StatusCode::kInvalidArgument, - "`prefixes` must be empty if and only if this is the first call with " - "`ctx`.")); -} - -TEST_P(IncrementalDpfTest, EvaluationFailsIfPrefixOutOfRange) { - if (parameters_.size() < 2 || - parameters_[0].value_type().integer().bitsize() != 128 || - parameters_[1].value_type().integer().bitsize() != 128) { - return; - } - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx, - dpf_->CreateEvaluationContext(keys_.first)); - - DPF_ASSERT_OK(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx)); - auto invalid_prefix = absl::uint128{1} << parameters_[0].log_domain_size(); - - EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(1, {invalid_prefix}, ctx), - StatusIs(absl::StatusCode::kInvalidArgument, - StrFormat("Index %d out of range for hierarchy level 0", - invalid_prefix))); -} - -TEST_P(IncrementalDpfTest, TestCorrectness) { - // Generate a random set of evaluation points. The library should be able to - // handle duplicates, so fixing the size to 1000 works even for smaller - // domains. - absl::BitGen rng; - absl::uniform_int_distribution<uint64_t> dist; - const int kNumEvaluationPoints = 1000; - std::vector<absl::uint128> evaluation_points(kNumEvaluationPoints); - for (int i = 0; i < kNumEvaluationPoints - 1; ++i) { - evaluation_points[i] = absl::MakeUint128(dist(rng), dist(rng)); - if (parameters_.back().log_domain_size() < 128) { - evaluation_points[i] %= absl::uint128{1} - << parameters_.back().log_domain_size(); - } - } - evaluation_points.back() = alpha_; // Always evaluate on alpha_. - - int num_levels = static_cast<int>(parameters_.size()); - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx0, - dpf_->CreateEvaluationContext(keys_.first)); - DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx1, - dpf_->CreateEvaluationContext(keys_.second)); - - for (int i = level_step_ - 1; i < num_levels; i += level_step_) { - switch (parameters_[i].value_type().integer().bitsize()) { - case 8: - EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, ctx0, ctx1); - break; - case 16: - EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, ctx0, ctx1); - break; - case 32: - EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, ctx0, ctx1); - break; - case 64: - EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, ctx0, ctx1); - break; - case 128: - EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, ctx0, ctx1); - break; - default: - ASSERT_TRUE(0) << "Unsupported element_bitsize"; - } - } -} - -INSTANTIATE_TEST_SUITE_P( - OneHierarchyLevelVaryElementSizes, IncrementalDpfTest, - testing::Combine( - // DPF parameters. - testing::Values( - // Vary element sizes, small domain size. - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 16}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 32}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 128}}, - // Vary element sizes, medium domain size. - std::vector<DpfTestParameters>{ - {.log_domain_size = 10, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 10, .element_bitsize = 16}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 10, .element_bitsize = 32}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 10, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 10, .element_bitsize = 128}}), - testing::Values(0, 1, 15), // alpha - testing::Values(std::vector<absl::uint128>(1, 1), - std::vector<absl::uint128>(1, 100), - std::vector<absl::uint128>(1, 255)), // beta - testing::Values(1), // level_step - testing::Values(false, true) // single_point - )); - -INSTANTIATE_TEST_SUITE_P( - OneHierarchyLevelVaryDomainSizes, IncrementalDpfTest, - testing::Combine( - // DPF parameters. - testing::Values( - // Vary domain sizes, small element size. - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 1, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 2, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 3, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 6, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 7, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 8, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 9, .element_bitsize = 8}}, - // Vary domain sizes, medium element size. - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 1, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 2, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 3, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 6, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 7, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 8, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 9, .element_bitsize = 64}}, - // Vary domain sizes, large element size. - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 1, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 2, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 3, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 6, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 7, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 8, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 9, .element_bitsize = 128}}), - testing::Values(0), // alpha - testing::Values(std::vector<absl::uint128>(1, 1), - std::vector<absl::uint128>(1, 100), - std::vector<absl::uint128>(1, 255)), // beta - testing::Values(1), // level_step - testing::Values(false, true) // single_point - )); - -INSTANTIATE_TEST_SUITE_P( - TwoHierarchyLevels, IncrementalDpfTest, - testing::Combine( - // DPF parameters. - testing::Values( - // Equal element sizes. - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 8}, - {.log_domain_size = 10, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 16}, - {.log_domain_size = 10, .element_bitsize = 16}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 32}, - {.log_domain_size = 10, .element_bitsize = 32}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 64}, - {.log_domain_size = 10, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 128}, - {.log_domain_size = 10, .element_bitsize = 128}}, - // First correction is in seed word. - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 8}, - {.log_domain_size = 10, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 16}, - {.log_domain_size = 10, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 32}, - {.log_domain_size = 10, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 64}, - {.log_domain_size = 10, .element_bitsize = 128}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 128}, - {.log_domain_size = 10, .element_bitsize = 128}}), - testing::Values(0, 1, 2, 100, 1023), // alpha - testing::Values(std::vector<absl::uint128>({1, 2}), - std::vector<absl::uint128>({80, 90}), - std::vector<absl::uint128>({255, 255})), // beta - testing::Values(1, 2), // level_step - testing::Values(false, true) // single_point - )); - -INSTANTIATE_TEST_SUITE_P( - ThreeHierarchyLevels, IncrementalDpfTest, - testing::Combine( - // DPF parameters. - testing::Values<std::vector<DpfTestParameters>>( - // Equal element sizes. - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 8}, - {.log_domain_size = 10, .element_bitsize = 8}, - {.log_domain_size = 15, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 16}, - {.log_domain_size = 10, .element_bitsize = 16}, - {.log_domain_size = 15, .element_bitsize = 16}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 32}, - {.log_domain_size = 10, .element_bitsize = 32}, - {.log_domain_size = 15, .element_bitsize = 32}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 64}, - {.log_domain_size = 10, .element_bitsize = 64}, - {.log_domain_size = 15, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 128}, - {.log_domain_size = 10, .element_bitsize = 128}, - {.log_domain_size = 15, .element_bitsize = 128}}, - // Varying element sizes - std::vector<DpfTestParameters>{ - {.log_domain_size = 5, .element_bitsize = 8}, - {.log_domain_size = 10, .element_bitsize = 16}, - {.log_domain_size = 15, .element_bitsize = 32}}, - // Small level distances. - std::vector<DpfTestParameters>{ - {.log_domain_size = 4, .element_bitsize = 8}, - {.log_domain_size = 5, .element_bitsize = 8}, - {.log_domain_size = 6, .element_bitsize = 8}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 3, .element_bitsize = 16}, - {.log_domain_size = 4, .element_bitsize = 16}, - {.log_domain_size = 5, .element_bitsize = 16}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 2, .element_bitsize = 32}, - {.log_domain_size = 3, .element_bitsize = 32}, - {.log_domain_size = 4, .element_bitsize = 32}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 1, .element_bitsize = 64}, - {.log_domain_size = 2, .element_bitsize = 64}, - {.log_domain_size = 3, .element_bitsize = 64}}, - std::vector<DpfTestParameters>{ - {.log_domain_size = 0, .element_bitsize = 128}, - {.log_domain_size = 1, .element_bitsize = 128}, - {.log_domain_size = 2, .element_bitsize = 128}}), - testing::Values(0, 1), // alpha - testing::Values(std::vector<absl::uint128>({1, 2, 3})), // beta - testing::Values(1, 2), // level_step - testing::Values(false, true) // single_point - )); - -INSTANTIATE_TEST_SUITE_P( - MaximumOutputDomainSize, IncrementalDpfTest, - testing::Combine( - // DPF parameters. We want to be able to evaluate at every bit, so this - // lambda returns a vector with 129 parameters with log domain sizes - // 0...128. - testing::Values([]() -> std::vector<DpfTestParameters> { - std::vector<DpfTestParameters> parameters(129); - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - parameters[i].log_domain_size = i; - parameters[i].element_bitsize = 64; - } - return parameters; - }()), - testing::Values(absl::MakeUint128(23, 42)), // alpha - testing::Values(std::vector<absl::uint128>(129, 1234567)), // beta - testing::Values(1, 2, 3, 5, 7), // level_step - testing::Values(false, true) // single_point - )); - -template <typename T> -class DpfEvaluationTest : public ::testing::Test { - protected: - void SetUp() { SetUp(10, 23); } - void SetUp(int log_domain_size, absl::uint128 alpha) { - return SetUp(absl::MakeConstSpan(&log_domain_size, 1), alpha); - } - void SetUp(absl::Span<const int> log_domain_size, absl::uint128 alpha) { - log_domain_size_.resize(log_domain_size.size()); - absl::c_copy(log_domain_size, log_domain_size_.begin()); - alpha_ = alpha; - beta_.resize(log_domain_size.size()); - for (T& beta : beta_) { - SetTo42(beta); - } - parameters_.resize(log_domain_size.size()); - for (int i = 0; i < parameters_.size(); ++i) { - parameters_[i].set_log_domain_size(log_domain_size_[i]); - parameters_[i].set_security_parameter(48); - *(parameters_[i].mutable_value_type()) = ToValueType<T>(); - } - DPF_ASSERT_OK_AND_ASSIGN( - dpf_, DistributedPointFunction::CreateIncremental(parameters_)); - DPF_ASSERT_OK(this->dpf_->template RegisterValueType<T>()); - DPF_ASSERT_OK_AND_ASSIGN( - keys_, this->dpf_->GenerateKeysIncremental( - this->alpha_, absl::MakeConstSpan(this->beta_))); - } - - // Helper function that recursively sets all elements of a tuple to 42. - template <typename T0> - static void SetTo42(T0& x) { - x = T0(42); - } - template <typename T0, typename... Tn> - static void SetTo42(T0& x0, Tn&... xn) { - SetTo42(x0); - SetTo42(xn...); - } - template <typename... Tn> - static void SetTo42(Tuple<Tn...>& x) { - absl::apply([](auto&... in) { SetTo42(in...); }, x.value()); - } - - std::vector<int> log_domain_size_; - absl::uint128 alpha_; - std::vector<T> beta_; - std::vector<DpfParameters> parameters_; - std::unique_ptr<DistributedPointFunction> dpf_; - std::pair<DpfKey, DpfKey> keys_; -}; - -using MyIntModN = IntModN<uint32_t, 4294967291u>; // 2**32 - 5. -using MyIntModN64 = IntModN<uint64_t, 18446744073709551557ull>; // 2**64 - 59. -#ifdef ABSL_HAVE_INTRINSIC_INT128 -using MyIntModN128 = - IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128( - 65535u, 18446744073709551551ull))>; // 2**80-65 -#endif -using DpfEvaluationTypes = ::testing::Types< - // Integers - uint8_t, uint32_t, uint64_t, absl::uint128, - // Tuple - Tuple<uint8_t>, Tuple<uint32_t>, Tuple<absl::uint128>, - Tuple<uint32_t, uint32_t>, Tuple<uint32_t, uint64_t>, - Tuple<uint64_t, uint64_t>, Tuple<uint8_t, uint16_t, uint32_t, uint64_t>, - Tuple<uint32_t, uint32_t, uint32_t, uint32_t>, - Tuple<uint32_t, Tuple<uint32_t, uint32_t>, uint32_t>, - Tuple<uint32_t, absl::uint128>, - // IntModN - MyIntModN, Tuple<MyIntModN>, Tuple<uint32_t, MyIntModN>, - Tuple<absl::uint128, MyIntModN>, Tuple<MyIntModN, Tuple<MyIntModN>>, - Tuple<MyIntModN, MyIntModN, MyIntModN, MyIntModN, MyIntModN>, - Tuple<MyIntModN64, MyIntModN64> -#ifdef ABSL_HAVE_INTRINSIC_INT128 - , - Tuple<MyIntModN128, MyIntModN128>, -#endif - // XorWrapper - XorWrapper<uint8_t>, XorWrapper<absl::uint128>, - Tuple<XorWrapper<uint32_t>, absl::uint128>>; -TYPED_TEST_SUITE(DpfEvaluationTest, DpfEvaluationTypes); - -TYPED_TEST(DpfEvaluationTest, TestRegularDpf) { - int log_domain_size = 10; - absl::uint128 alpha = 23; - this->SetUp(log_domain_size, alpha); - DPF_ASSERT_OK_AND_ASSIGN( - EvaluationContext ctx_1, - this->dpf_->CreateEvaluationContext(this->keys_.first)); - DPF_ASSERT_OK_AND_ASSIGN( - EvaluationContext ctx_2, - this->dpf_->CreateEvaluationContext(this->keys_.second)); - DPF_ASSERT_OK_AND_ASSIGN( - std::vector<TypeParam> output_1, - this->dpf_->template EvaluateNext<TypeParam>({}, ctx_1)); - DPF_ASSERT_OK_AND_ASSIGN( - std::vector<TypeParam> output_2, - this->dpf_->template EvaluateNext<TypeParam>({}, ctx_2)); - - EXPECT_EQ(output_1.size(), 1 << log_domain_size); - EXPECT_EQ(output_2.size(), 1 << log_domain_size); - for (int i = 0; i < (1 << log_domain_size); ++i) { - TypeParam sum = output_1[i] + output_2[i]; - if (i == this->alpha_) { - EXPECT_EQ(sum, this->beta_[0]); - } else { - EXPECT_EQ(sum, TypeParam{}); - } - } -} - -TYPED_TEST(DpfEvaluationTest, TestBatchSinglePointEvaluation) { - // Set Up with a large output domain, to make sure this works. - for (int log_domain_size : {0, 1, 2, 32, 128}) { - absl::uint128 max_evaluation_point = absl::Uint128Max(); - if (log_domain_size < 128) { - max_evaluation_point = (absl::uint128{1} << log_domain_size) - 1; - } - const absl::uint128 alpha = 23 & max_evaluation_point; - this->SetUp(log_domain_size, alpha); - for (int num_evaluation_points : {0, 1, 2, 100, 1000}) { - std::vector<absl::uint128> evaluation_points(num_evaluation_points); - for (int i = 0; i < num_evaluation_points; ++i) { - evaluation_points[i] = i & max_evaluation_point; - } - DPF_ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> output_1, - this->dpf_->template EvaluateAt<TypeParam>( - this->keys_.first, 0, evaluation_points)); - DPF_ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> output_2, - this->dpf_->template EvaluateAt<TypeParam>( - this->keys_.second, 0, evaluation_points)); - ASSERT_EQ(output_1.size(), output_2.size()); - ASSERT_EQ(output_1.size(), num_evaluation_points); - - for (int i = 0; i < num_evaluation_points; ++i) { - TypeParam sum = output_1[i] + output_2[i]; - if (evaluation_points[i] == this->alpha_) { - EXPECT_EQ(sum, this->beta_[0]) - << "i=" << i << ", log_domain_size=" << log_domain_size; - } else { - EXPECT_EQ(sum, TypeParam{}) - << "i=" << i << ", log_domain_size=" << log_domain_size; - } - } - } - } -} - -TYPED_TEST(DpfEvaluationTest, TestEvaluateAndApplySimpleAddition) { - std::vector<std::vector<int>> parameters = { - {0, 1, 2}, {8, 16, 32, 64}, {0, 128}, {128}, {/* filled below */}}; - for (int i = 0; i <= 128; ++i) { - parameters.back().push_back(i); - } - for (const auto& log_domain_sizes : parameters) { - absl::uint128 max_domain_element = absl::Uint128Max(); - if (log_domain_sizes.back() < 128) { - max_domain_element = (absl::uint128{1} << log_domain_sizes.back()) - 1; - } - absl::uint128 alpha = max_domain_element; - this->SetUp(log_domain_sizes, alpha); - - std::vector<absl::uint128> evaluation_points = {23, 42, 123, 0, - absl::Uint128Max()}; - for (auto& point : evaluation_points) { - point &= max_domain_element; - } - std::vector<const DpfKey*> keys = { - &(this->keys_.first), &(this->keys_.second), &(this->keys_.first), - &(this->keys_.second), &(this->keys_.first)}; - int num_levels = log_domain_sizes.size(); - int num_keys = keys.size(); - - std::vector<TypeParam> sum(num_keys, TypeParam{}); - int count = 0; - auto fn = [&sum, &count](absl::Span<const TypeParam> values) { - for (int i = 0; i < values.size(); ++i) { - sum[i] += values[i]; - } - ++count; - return true; - }; - - // Run evaluation level-by-level to compute the expected sum. - std::vector<TypeParam> expected(num_keys, TypeParam{}); - for (int hierarchy_level = 0; hierarchy_level < num_levels; - ++hierarchy_level) { - const int shift_amount = - (log_domain_sizes.back() - log_domain_sizes[hierarchy_level]); - for (int i = 0; i < num_keys; ++i) { - absl::uint128 prefix = 0; - if (shift_amount < 128) { - prefix = evaluation_points[i] >> shift_amount; - } - DPF_ASSERT_OK_AND_ASSIGN( - auto result, - this->dpf_->template EvaluateAt<TypeParam>( - *keys[i], hierarchy_level, absl::MakeConstSpan(&prefix, 1))); - expected[i] += result[0]; - } - } - - EXPECT_THAT(this->dpf_->template EvaluateAndApply<TypeParam>( - keys, evaluation_points, fn), - IsOk()); - EXPECT_EQ(sum, expected) - << "log_domain_sizes=" << absl::StrJoin(log_domain_sizes, " "); - EXPECT_EQ(count, num_levels); - } -} - -TYPED_TEST(DpfEvaluationTest, - EvaluateAndApplyFailsWithTooManyEvaluationPoints) { - std::vector<absl::uint128> evaluation_points = {0, 1}; - - EXPECT_THAT( - this->dpf_->template EvaluateAndApply<TypeParam>( - absl::MakeConstSpan(&(this->keys_.first), 1), evaluation_points, - [](absl::Span<const TypeParam>) { return true; }), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("evaluation_points"))); -} - -TYPED_TEST(DpfEvaluationTest, EvaluateAndApplyFailsWithInvalidKey) { - DpfKey key; - - EXPECT_THAT(this->dpf_->template EvaluateAndApply<TypeParam>( - absl::MakeConstSpan(&key, 1), {0}, - [](absl::Span<const TypeParam>) { return true; }), - StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("key"))); -} - -} // namespace -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/int_mod_n.cc b/third_party/distributed_point_functions/code/dpf/int_mod_n.cc deleted file mode 100644 index 3b5d9926907ae..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/int_mod_n.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/int_mod_n.h" - -#include <cmath> -#include <string> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" - -namespace distributed_point_functions { - -namespace dpf_internal { - -double IntModNBase::GetSecurityLevel(int num_samples, absl::uint128 modulus) { - return 128 + 3 - - (std::log2(static_cast<double>(modulus)) + - std::log2(static_cast<double>(num_samples)) + - std::log2(static_cast<double>(num_samples + 1))); -} - -absl::Status IntModNBase::CheckParameters(int num_samples, - int base_integer_bitsize, - absl::uint128 modulus, - double security_parameter) { - if (num_samples <= 0) { - return absl::InvalidArgumentError("num_samples must be positive"); - } - if (base_integer_bitsize <= 0) { - return absl::InvalidArgumentError("base_integer_bitsize must be positive"); - } - if (base_integer_bitsize > 128) { - return absl::InvalidArgumentError( - "base_integer_bitsize must be at most 128"); - } - if (base_integer_bitsize < 128 && - (absl::uint128{1} << base_integer_bitsize) < modulus) { - return absl::InvalidArgumentError(absl::StrFormat( - "kModulus %d out of range for base_integer_bitsize = %d", modulus, - base_integer_bitsize)); - } - - // Compute the level of security that we will get, and fail if it is - // insufficient. - const double sigma = GetSecurityLevel(num_samples, modulus); - if (security_parameter > sigma) { - return absl::InvalidArgumentError(absl::StrFormat( - "For num_samples = %d and kModulus = %d this approach can only " - "provide " - "%f bits of statistical security. You can try calling this function " - "several times with smaller values of num_samples.", - num_samples, modulus, sigma)); - } - return absl::OkStatus(); -} - -absl::StatusOr<int> IntModNBase::GetNumBytesRequired( - int num_samples, int base_integer_bitsize, absl::uint128 modulus, - double security_parameter) { - absl::Status status = CheckParameters(num_samples, base_integer_bitsize, - modulus, security_parameter); - if (!status.ok()) { - return status; - } - - const int base_integer_bytes = ((base_integer_bitsize + 7) / 8); - // We start the sampling by requiring a 128-bit (16 bytes) block, see - // function `SampleFromBytes`. - return 16 + base_integer_bytes * (num_samples - 1); -} - -} // namespace dpf_internal - -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/int_mod_n.h b/third_party/distributed_point_functions/code/dpf/int_mod_n.h deleted file mode 100644 index 0f61ed05d1dca..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/int_mod_n.h +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_ - -#include <algorithm> -#include <string> -#include <type_traits> - -#include "absl/base/config.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" - -namespace distributed_point_functions { - -namespace dpf_internal { - -// Base class holding common functions of IntModN that are independent of the -// template parameter. -class IntModNBase { - public: - // Computes the security level achievable when sampling `num_samples` elements - // with the given `kModulus`. - // - static double GetSecurityLevel(int num_samples, absl::uint128 modulus); - - // Checks if the given parameters are consistent and valid for an IntModN. - // - // Returns OK for valid parameters, and INVALID_ARGUMENT otherwise. - static absl::Status CheckParameters(int num_samples, int base_integer_bitsize, - absl::uint128 modulus, - double security_parameter); - - // Computes the number of bytes required to sample `num_samples` integers - // modulo `kModulus` with an underlying integer type of - // `base_integer_bitsize`. - // - // Returns INVALID_ARGUMENT if the achievable security level with the given - // parameters is less than `security_parameter`, or if the parameters are - // invalid. - static absl::StatusOr<int> GetNumBytesRequired(int num_samples, - int base_integer_bitsize, - absl::uint128 modulus, - double security_parameter); - - // Creates a value of type T from the given `bytes`, using little-endian - // encoding. Called by SampleFromBytes. Crashes if bytes.size() != sizeof(T). - // - // This is a reimplementation of dpf_internal::ConvertBytesTo for integers, - // to avoid depending on value_type_helpers here. - template <typename T> - static T ConvertBytesTo(absl::string_view bytes) { - ABSL_CHECK(bytes.size() == sizeof(T)); - T out{0}; -#ifdef ABSL_IS_LITTLE_ENDIAN - std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out)); -#else - for (int i = sizeof(T) - 1; i >= 0; --i) { - out |= absl::bit_cast<uint8_t>(bytes[i]); - out <<= 8; - } -#endif - return out; - } -}; - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -class IntModNImpl : public IntModNBase { - static_assert(sizeof(BaseInteger) <= sizeof(absl::uint128), - "BaseInteger may be at most 128 bits large"); - static_assert( - std::is_same<BaseInteger, absl::uint128>::value || -#ifdef ABSL_HAVE_INTRINSIC_INT128 - // std::is_unsigned_v<unsigned __int128> is not true everywhere: - // https://quuxplusone.github.io/blog/2019/02/28/is-int128-integral/#signedness - std::is_same<BaseInteger, unsigned __int128>::value || -#endif - std::is_unsigned<BaseInteger>::value, - "BaseInteger must be unsigned"); - static_assert(kModulus <= ModulusType(BaseInteger(-1)), - "kModulus must fit in BaseInteger"); - - public: - using Base = BaseInteger; - - constexpr IntModNImpl() : value_(0) {} - explicit constexpr IntModNImpl(BaseInteger value) - : value_(value % kModulus) {} - - // Copyable. - constexpr IntModNImpl(const IntModNImpl& a) = default; - - constexpr IntModNImpl& operator=(const IntModNImpl& a) = default; - - // Assignment operators. - constexpr IntModNImpl& operator=(const BaseInteger& a) { - value_ = a % kModulus; - return *this; - } - - constexpr IntModNImpl& operator+=(const IntModNImpl& a) { - AddBaseInteger(a.value_); - return *this; - } - - constexpr IntModNImpl& operator-=(const IntModNImpl& a) { - SubtractBaseInteger(a.value_); - return *this; - } - - // Returns the underlying representation as a BaseInteger. - constexpr BaseInteger value() const { return value_; } - - // Returns the modulus of this IntModNImpl type. - static constexpr BaseInteger modulus() { return kModulus; } - - // Returns the number of (pseudo)random bytes required to extract - // `num_samples` samples r1, ..., rn - // so that the stream r1, ..., rn is close to a truly (pseudo) random - // sequence up to total variation distance < 2^(-`security_parameter`) - static absl::StatusOr<int> GetNumBytesRequired(int num_samples, - double security_parameter) { - return IntModNBase::GetNumBytesRequired( - num_samples, 8 * sizeof(BaseInteger), kModulus, security_parameter); - } - - // Extracts `samples.size()` samples r1, ..., rn so that the stream r1, ..., - // rn is close to a truly (pseudo) random sequence up to total variation - // distance < 2^(-`security_parameter`). Returns r1, ..., rn in `samples`. - // - // The optional template argument allows users to specify the number of - // samples at compile time, which can save heap allocations. - // - // Caution: For performance reasons, this function does not check whether - // `bytes` is long enough for the required number of samples and security - // parameter. Use `GetNumBytesRequired` or `SampleFromBytes` if such checks - // are needed. - // - template <int kCompiledNumSamples = 1> - static void UnsafeSampleFromBytes(absl::string_view bytes, - double security_parameter, - absl::Span<IntModNImpl> samples) { - static_assert(kCompiledNumSamples >= 1, - "kCompiledNumSamples must be positive"); - absl::uint128 r = ConvertBytesTo<absl::uint128>(bytes.substr(0, 16)); - absl::InlinedVector<BaseInteger, std::max(1, kCompiledNumSamples - 1)> - randomness(samples.size() - 1); - for (int i = 0; i < static_cast<int>(randomness.size()); ++i) { - randomness[i] = ConvertBytesTo<BaseInteger>( - bytes.substr(16 + i * sizeof(BaseInteger), sizeof(BaseInteger))); - } - for (int i = 0; i < static_cast<int>(samples.size()); ++i) { - samples[i] = IntModNImpl(static_cast<BaseInteger>(r % kModulus)); - if (i < static_cast<int>(randomness.size())) { - r /= kModulus; - if (sizeof(BaseInteger) < sizeof(absl::uint128)) { - r <<= (sizeof(BaseInteger) * 8); - } - r |= randomness[i]; - } - } - } - - // Checks that length(`bytes`) is enough to extract - // `samples.size()` samples r1, ..., rn - // so that the stream r1, ..., rn is close to a truly (pseudo) random - // sequence up to total variation distance < 2^(-`security_parameter`) and - // fails if that is not the case. - // Otherwise returns r1, ..., rn in `samples`. - static absl::Status SampleFromBytes(absl::string_view bytes, - double security_parameter, - absl::Span<IntModNImpl> samples) { - if (samples.empty()) { - return absl::InvalidArgumentError( - "The number of samples required must be > 0"); - } - absl::StatusOr<int> num_bytes_lower_bound = - GetNumBytesRequired(samples.size(), security_parameter); - if (!num_bytes_lower_bound.ok()) { - return num_bytes_lower_bound.status(); - } - if (*num_bytes_lower_bound > bytes.size()) { - return absl::InvalidArgumentError( - absl::StrCat("The number of bytes provided (", bytes.size(), - ") is insufficient for the required " - "statistical security and number of samples.")); - } - UnsafeSampleFromBytes(bytes, security_parameter, samples); - return absl::OkStatus(); - } - - private: - constexpr void SubtractBaseInteger(const BaseInteger& a) { - if (value_ >= a) { - value_ -= a; - } else { - value_ = kModulus - a + value_; - } - } - - constexpr void AddBaseInteger(const BaseInteger& a) { - SubtractBaseInteger(kModulus - a); - } - - BaseInteger value_; -}; - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator+( - IntModNImpl<BaseInteger, ModulusType, kModulus> a, - const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) { - a += b; - return a; -} - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-( - IntModNImpl<BaseInteger, ModulusType, kModulus> a, - const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) { - a -= b; - return a; -} - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-( - IntModNImpl<BaseInteger, ModulusType, kModulus> a) { - IntModNImpl<BaseInteger, ModulusType, kModulus> result(BaseInteger{0}); - result -= a; - return result; -} - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -constexpr bool operator==( - const IntModNImpl<BaseInteger, ModulusType, kModulus>& a, - const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) { - return a.value() == b.value(); -} - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -constexpr bool operator!=( - const IntModNImpl<BaseInteger, ModulusType, kModulus>& a, - const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) { - return !(a == b); -} - -} // namespace dpf_internal - -// Since `absl::uint128` is not an alias to `unsigned __int128`, but a struct, -// we cannot use it as a template parameter type. So if we have an intrinsic -// int128, we always use that as the modulus type. Otherwise, the modulus type -// is the same as BaseInteger. -#ifdef ABSL_HAVE_INTRINSIC_INT128 -template <typename BaseInteger, unsigned __int128 kModulus> -using IntModN = - dpf_internal::IntModNImpl<BaseInteger, unsigned __int128, kModulus>; -#else -template <typename BaseInteger, BaseInteger kModulus> -using IntModN = dpf_internal::IntModNImpl<BaseInteger, BaseInteger, kModulus>; -#endif - -} // namespace distributed_point_functions -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_ diff --git a/third_party/distributed_point_functions/code/dpf/int_mod_n_benchmark.cc b/third_party/distributed_point_functions/code/dpf/int_mod_n_benchmark.cc deleted file mode 100644 index b095faa391abd..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/int_mod_n_benchmark.cc +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <stdint.h> - -#include <cmath> -#include <vector> - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "benchmark/benchmark.h" -#include "dpf/int_mod_n.h" -#include "openssl/rand.h" - -namespace distributed_point_functions { -namespace { - -using MyInt = IntModN<uint32_t, 4294967291u>; // 2**32 - 5. -constexpr int kNumSamples = 5; - -void BM_Sample(benchmark::State& state) { - int num_iterations = state.range(0); - double security_parameter = 40 + std::log2(num_iterations); - std::vector<uint8_t> bytes( - MyInt::GetNumBytesRequired(kNumSamples, security_parameter).value()); - RAND_bytes(bytes.data(), bytes.size()); - std::vector<MyInt> output(num_iterations * kNumSamples); - for (auto s : state) { - for (int i = 0; i < num_iterations; ++i) { - MyInt::UnsafeSampleFromBytes<kNumSamples>( - absl::string_view(reinterpret_cast<const char*>(bytes.data()), - bytes.size()), - security_parameter, - absl::MakeSpan(&output[i * kNumSamples], kNumSamples)); - } - benchmark::DoNotOptimize(output); - } -} -BENCHMARK(BM_Sample)->Range(1, 1 << 20); - -} // namespace -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/int_mod_n_test.cc b/third_party/distributed_point_functions/code/dpf/int_mod_n_test.cc deleted file mode 100644 index dd03d50364695..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/int_mod_n_test.cc +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/int_mod_n.h" - -#include <cstdint> -#include <string> -#include <vector> - -#include "absl/base/config.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/types/span.h" -#include "dpf/internal/status_matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace { - -constexpr double kFeasibleSecurityParameter = 40; -constexpr double kUnfeasibleSecurityParameter = 95; -constexpr int kNumSamples = 5; - -template <typename T> -class IntModNTest : public testing::Test {}; -using IntModNTypes = ::testing::Types< - IntModN<uint32_t, 4294967291u>, // 2**32-5 - IntModN<uint64_t, 18446744073709551557ull> // 2**64-59 -#ifdef ABSL_HAVE_INTRINSIC_INT128 - , - IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128( - 65535u, 18446744073709551551ull))> // 2**80-65 -#endif - >; -TYPED_TEST_SUITE(IntModNTest, IntModNTypes); - -TYPED_TEST(IntModNTest, DefaultValueIsZero) { - TypeParam a; - EXPECT_EQ(a.value(), 0); -} - -TYPED_TEST(IntModNTest, SetValueWorks) { - TypeParam a; - EXPECT_EQ(a.value(), 0); - a = 23; - EXPECT_EQ(a.value(), 23); -} - -TYPED_TEST(IntModNTest, AdditionWithoutWrapAroundWorks) { - TypeParam a; - TypeParam b; - a += b; - EXPECT_EQ(a.value(), 0); - b = 23; - a += b; - EXPECT_EQ(a.value(), 23); - b = 4294967200; - a += b; - EXPECT_EQ(a.value(), 4294967223); -} - -TYPED_TEST(IntModNTest, AdditionWithWrapAroundWorks) { - TypeParam a; - TypeParam b; - a += b; - EXPECT_EQ(a.value(), 0); - b = 23; - a += b; - EXPECT_EQ(a.value(), 23); - b = TypeParam::modulus() - 10; - a += b; - EXPECT_EQ(a.value(), 13); -} - -TYPED_TEST(IntModNTest, NegationWorks) { - TypeParam a(10); - TypeParam b = -a; - EXPECT_EQ(a + b, TypeParam(0)); -} - -TYPED_TEST(IntModNTest, GetNumBytesRequiredFailsIfUnfeasible) { - absl::StatusOr<int> result = - TypeParam::GetNumBytesRequired(kNumSamples, kUnfeasibleSecurityParameter); - EXPECT_THAT(result, dpf_internal::StatusIs( - absl::StatusCode::kInvalidArgument, - testing::StartsWith(absl::StrFormat( - "For num_samples = 5 and kModulus = %d", - absl::uint128(TypeParam::modulus()))))); -} - -TYPED_TEST(IntModNTest, GetNumBytesRequiredSucceedsIfFeasible) { - absl::StatusOr<int> result = - TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(result.ok(), true); -} - -TYPED_TEST(IntModNTest, SampleFailsIfUnfeasible) { - absl::StatusOr<int> r_getnum = - TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(r_getnum.ok(), true); - - std::string bytes = std::string(16, '#'); - EXPECT_GT(r_getnum.value(), bytes.size()); - std::vector<TypeParam> samples(5); - absl::Status r_sample = TypeParam::SampleFromBytes( - bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples)); - EXPECT_EQ(r_sample.ok(), false); - EXPECT_THAT( - r_sample, - dpf_internal::StatusIs( - absl::StatusCode::kInvalidArgument, - "The number of bytes provided (16) is insufficient for the required " - "statistical security and number of samples.")); -} - -TYPED_TEST(IntModNTest, SampleSucceedsIfFeasible) { - absl::StatusOr<int> r_getnum = - TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(r_getnum.ok(), true); - - std::string bytes = std::string(r_getnum.value(), '#'); - std::vector<TypeParam> samples(5); - absl::Status r_sample = TypeParam::SampleFromBytes( - bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples)); - EXPECT_EQ(r_sample.ok(), true); -} - -TYPED_TEST(IntModNTest, FirstEntryOfSamplesIsAsExpected) { - absl::StatusOr<int> r_getnum = - TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(r_getnum.ok(), true); - - std::string bytes = std::string(r_getnum.value(), '#'); - std::vector<TypeParam> samples(5); - absl::Status r_sample = TypeParam::SampleFromBytes( - bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples)); - EXPECT_EQ(r_sample.ok(), true); - EXPECT_EQ( - samples[0].value(), - TypeParam::template ConvertBytesTo<absl::uint128>(bytes.substr(0, 16)) % - TypeParam::modulus()); -} - -using BaseInteger = uint32_t; -constexpr BaseInteger kModulus32 = 4294967291u; // 2**32 - 5 -using MyIntModN = IntModN<BaseInteger, kModulus32>; - -TEST(IntModNTest, SampleFromBytesWorksInConcreteExample) { - absl::StatusOr<int> r_getnum = - MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(r_getnum.ok(), true); - EXPECT_EQ(*r_getnum, 32); - std::string bytes = "this is a length 32 test string."; - EXPECT_EQ(bytes.size(), 32); - - std::vector<MyIntModN> samples(5); - absl::Status r_sample = MyIntModN::SampleFromBytes( - bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples)); - EXPECT_EQ(r_sample.ok(), true); - absl::uint128 r = - MyIntModN::ConvertBytesTo<absl::uint128>("this is a length"); - EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" 32 "); - EXPECT_EQ(samples[1].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("test"); - EXPECT_EQ(samples[2].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" str"); - EXPECT_EQ(samples[3].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("ing."); - EXPECT_EQ(samples[4].value(), r % MyIntModN::modulus()); -} - -TEST(IntModNTest, SampleFromBytesFailsAsExpectedInConcreteExample) { - absl::StatusOr<int> r_getnum = - MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter); - EXPECT_EQ(r_getnum.ok(), true); - EXPECT_EQ(*r_getnum, 32); - std::string bytes = "this is a length 32 test string."; - EXPECT_EQ(bytes.size(), 32); - - std::vector<MyIntModN> samples(5); - absl::Status r_sample = MyIntModN::SampleFromBytes( - bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples)); - EXPECT_EQ(r_sample.ok(), true); - absl::uint128 r = - MyIntModN::ConvertBytesTo<absl::uint128>("this is a length"); - EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" 32 "); - EXPECT_EQ(samples[1].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("test"); - EXPECT_EQ(samples[2].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" str"); - EXPECT_EQ(samples[3].value(), r % MyIntModN::modulus()); - r /= MyIntModN::modulus(); - r <<= (sizeof(MyIntModN::Base) * 8); - r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("ing#"); // # instead of . - EXPECT_NE(samples[4].value(), r % MyIntModN::modulus()); -} - -// Test if IntModN operators are in fact constexpr. This will fail to compile -// otherwise. -constexpr MyIntModN TestAddition() { return MyIntModN(2) + MyIntModN(5); } -static_assert(TestAddition().value() == 7, - "constexpr addition of IntModNs incorrect"); - -constexpr MyIntModN TestSubtraction() { return MyIntModN(5) - MyIntModN(2); } -static_assert(TestSubtraction().value() == 3, - "constexpr subtraction of IntModNs incorrect"); - -constexpr MyIntModN TestAssignment() { - MyIntModN x(0); - x = 5; - return x; -} -static_assert(TestAssignment().value() == 5, - "constexpr assignment to IntModN incorrect"); - -#ifdef ABSL_HAVE_INTRINSIC_INT128 -constexpr unsigned __int128 kModulus128 = - (unsigned __int128)(-1); // 2**128 - 159 -using MyIntModN128 = IntModN<unsigned __int128, kModulus128>; -constexpr MyIntModN128 TestAddition128() { - return MyIntModN128(2) + MyIntModN128(5); -} -static_assert(TestAddition128().value() == 7, - "constexpr addition of IntModNs incorrect"); -#endif - -} // namespace -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/BUILD b/third_party/distributed_point_functions/code/dpf/internal/BUILD deleted file mode 100644 index cfbdc6acc111c..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/BUILD +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@com_github_google_iree//build_tools/embed_data:build_defs.bzl", "cc_embed_data") - -package( - default_visibility = ["//:__subpackages__"], -) - -licenses(["notice"]) - -cc_library( - name = "value_type_helpers", - srcs = ["value_type_helpers.cc"], - hdrs = ["value_type_helpers.h"], - deps = [ - "//dpf:distributed_point_function_cc_proto", - "//dpf:int_mod_n", - "//dpf:status_macros", - "//dpf:tuple", - "//dpf:xor_wrapper", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/utility", - "@com_google_protobuf//:protobuf_lite", - ], -) - -cc_test( - name = "value_type_helpers_test", - srcs = ["value_type_helpers_test.cc"], - deps = [ - ":status_matchers", - ":value_type_helpers", - "//dpf:distributed_point_function_cc_proto", - "//dpf:int_mod_n", - "//dpf:tuple", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "status_matchers", - testonly = 1, - srcs = [ - "status_matchers.cc", - ], - hdrs = ["status_matchers.h"], - deps = [ - "//dpf:status_macros", - "@com_github_google_googletest//:gtest", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "proto_validator", - srcs = [ - "proto_validator.cc", - ], - hdrs = [ - "proto_validator.h", - ], - deps = [ - ":value_type_helpers", - "//dpf:distributed_point_function_cc_proto", - "//dpf:status_macros", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_protobuf//:protobuf_lite", - ], -) - -cc_embed_data( - name = "proto_validator_test_textproto_embed", - srcs = [ - "proto_validator_test.textproto", - ], - cc_file_output = "proto_validator_test_textproto_embed.cc", - cpp_namespace = "distributed_point_functions::dpf_internal", - h_file_output = "proto_validator_test_textproto_embed.h", -) - -cc_test( - name = "proto_validator_test", - srcs = [ - "proto_validator_test.cc", - ], - data = [ - "proto_validator_test.textproto", - ], - deps = [ - ":proto_validator", - ":proto_validator_test_textproto_embed", - ":status_matchers", - "//dpf:distributed_point_function_cc_proto", - "//dpf:tuple", - "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_protobuf//:protobuf", - ], -) - -cc_library( - name = "evaluate_prg_hwy", - srcs = ["evaluate_prg_hwy.cc"], - hdrs = ["evaluate_prg_hwy.h"], - deps = [ - ":aes_128_fixed_key_hash_hwy", - "//dpf:aes_128_fixed_key_hash", - "//dpf:status_macros", - "@boringssl//:crypto", - "@com_github_google_highway//:hwy", - "@com_google_absl//absl/base:config", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "evaluate_prg_hwy_test", - srcs = [ - "evaluate_prg_hwy_test.cc", - ], - deps = [ - ":evaluate_prg_hwy", - ":status_matchers", - "//dpf:aes_128_fixed_key_hash", - "@com_github_google_googletest//:gtest_main", - "@com_github_google_highway//:hwy", - "@com_github_google_highway//:hwy_test_util", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - -cc_library( - name = "get_hwy_mode", - srcs = ["get_hwy_mode.cc"], - hdrs = ["get_hwy_mode.h"], - deps = [ - "@com_github_google_highway//:hwy", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "aes_128_fixed_key_hash_hwy", - hdrs = [ - "aes_128_fixed_key_hash_hwy.h", - ], - deps = [ - "@com_github_google_highway//:hwy", - "@com_google_absl//absl/numeric:int128", - ], -) - -cc_library( - name = "maybe_deref_span", - hdrs = ["maybe_deref_span.h"], - deps = [ - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], -) - -cc_test( - name = "aes_128_fixed_key_hash_hwy_test", - srcs = [ - "aes_128_fixed_key_hash_hwy_test.cc", - ], - deps = [ - ":aes_128_fixed_key_hash_hwy", - ":get_hwy_mode", - ":status_matchers", - "//dpf:aes_128_fixed_key_hash", - "@boringssl//:crypto", - "@com_github_google_googletest//:gtest_main", - "@com_github_google_highway//:hwy", - "@com_github_google_highway//:hwy_test_util", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/log:absl_log", - "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "maybe_deref_span_test", - srcs = ["maybe_deref_span_test.cc"], - deps = [ - ":maybe_deref_span", - "@com_github_google_googletest//:gtest_main", - ], -) diff --git a/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy.h b/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy.h deleted file mode 100644 index 3d1f8ec6b47f0..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy.h +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Highway-specific include guard, ensuring the header can get included once per -// target architecture. -#if defined( \ - DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_) == \ - defined(HWY_TARGET_TOGGLE) -#ifdef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_ -#undef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_ -#else -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_ -#endif - -#include <limits> - -#include "absl/numeric/int128.h" -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace distributed_point_functions { -namespace dpf_internal { -namespace HWY_NAMESPACE { - -// There is no AES support on HWY_SCALAR, but we still want to be able to -// include this header when compiling for HWY_SCALAR. The caller has to make -// sure to only call the functions defined here when not on HWY_SCALAR. -#if HWY_TARGET != HWY_SCALAR - -namespace hn = hwy::HWY_NAMESPACE; - -constexpr int kAesBlockSize = 16; - -// Helper to convert a Highway tag to a tag for vectors of the same bit size, -// but with 64-bit lanes. -template <typename D> -constexpr auto To64(D d) { - return hn::Repartition<uint64_t, D>(); -} - -// The following macros define parts of the fixed-key AES hash function -// implementation. We use macros here since Highway doesn't allow creating -// arrays of vectors/SIMD registers. That way, we can access each register by a -// unique variable name. All inputs and outputs are assumed to be of type -// hn::ScalableTag<uint8_t>. - -// Loads the AES round key for the given round and key_index. -#define DPF_AES_LOAD_ROUND_KEY(key_index, round) \ - const auto round_##round##_key_##key_index = \ - hn::LoadDup128(d, round_keys_##key_index + kAesBlockSize * round); - -// Selects key_0 or key_1 for the given block_index and round, depending on the -// bits in `mask`. Keys are first converted to 64-bit vectors to apply the more -// efficient 64 bit masks. -#define DPF_AES_SELECT_KEY(block_index, round) \ - const auto selected_round_##round##_key_##block_index = hn::BitCast( \ - d, hn::IfThenElse(mask_##block_index, \ - hn::BitCast(To64(d), round_##round##_key_1), \ - hn::BitCast(To64(d), round_##round##_key_0))); - -// Load mask for computing {0, x.high64}, for computing sigma(x) below. -HWY_ALIGN constexpr absl::uint128 kSigmaMask = - absl::MakeUint128(std::numeric_limits<uint64_t>::max(), 0); -#define DPF_AES_LOAD_SIGMA_MASK() \ - const auto sigma_mask = \ - hn::LoadDup128(To64(d), reinterpret_cast<const uint64_t*>(&kSigmaMask)); - -// Compute sigma(x) = {x.high64, x.high64^x.low64} (in little-endian notation). -#define DPF_AES_COMPUTE_SIGMA(block_index) \ - const auto in_##block_index##_64 = hn::BitCast(To64(d), in_##block_index); \ - const auto sigma_##block_index = \ - hn::BitCast(d, hn::Xor(hn::Shuffle01(in_##block_index##_64), \ - hn::And(sigma_mask, in_##block_index##_64))); - -// Performs the first round of AES for the given block_index, using sigma as the -// input. -#define DPF_AES_FIRST_ROUND(block_index) \ - out_##block_index = \ - hn::Xor(sigma_##block_index, selected_round_0_key_##block_index) - -// Performs a middle round of AES for the given block_index. -#define DPF_AES_MIDDLE_ROUND(block_index, round) \ - out_##block_index = hn::AESRound( \ - out_##block_index, selected_round_##round##_key_##block_index); - -// Performs the last round of AES for the given block_index. -#define DPF_AES_LAST_ROUND(block_index) \ - out_##block_index = hn::AESLastRound(out_##block_index, \ - selected_round_10_key_##block_index); - -// Finalize the hash by XORing with sigma. -#define DPF_AES_FINALIZE_HASH(block_index) \ - out_##block_index = hn::Xor(out_##block_index, sigma_##block_index); - -// Helper macro for hashing a single vector. -#define DPF_AES_MIDDLE_ROUND_1(round) \ - DPF_AES_LOAD_ROUND_KEY(0, round); \ - DPF_AES_LOAD_ROUND_KEY(1, round); \ - DPF_AES_SELECT_KEY(0, round); \ - DPF_AES_MIDDLE_ROUND(0, round); - -// Hashes a vector `in_0`, writing the output to `out_0`. Each block is hashed -// using either `round_keys_0` or `round_keys_1`, which both must point to a -// byte array containing two expanded AES keys. Which key is used for each block -// depends on `mask_0`: If the mask 0, then `round_keys_0` is used, otherwise -// `round_keys_1`. Note that the masks are masks on 64 bit integers, so there -// are two mask bits per AES block. The caller is responsible for making sure -// that the masks for the two halves of any given block have the same value. -template <typename V, typename D, typename M> -void HashOneWithKeyMask(D d, V in_0, M mask_0, - const uint8_t* HWY_RESTRICT round_keys_0, - const uint8_t* HWY_RESTRICT round_keys_1, V& out_0) { - // Compute sigma(in_0) - DPF_AES_LOAD_SIGMA_MASK(); - DPF_AES_COMPUTE_SIGMA(0); - - // First AES round. - DPF_AES_LOAD_ROUND_KEY(0, 0); - DPF_AES_LOAD_ROUND_KEY(1, 0); - DPF_AES_SELECT_KEY(0, 0); - DPF_AES_FIRST_ROUND(0); - - // Middle AES rounds. - DPF_AES_MIDDLE_ROUND_1(1); - DPF_AES_MIDDLE_ROUND_1(2); - DPF_AES_MIDDLE_ROUND_1(3); - DPF_AES_MIDDLE_ROUND_1(4); - DPF_AES_MIDDLE_ROUND_1(5); - DPF_AES_MIDDLE_ROUND_1(6); - DPF_AES_MIDDLE_ROUND_1(7); - DPF_AES_MIDDLE_ROUND_1(8); - DPF_AES_MIDDLE_ROUND_1(9); - - // Last AES round. - DPF_AES_LOAD_ROUND_KEY(0, 10); - DPF_AES_LOAD_ROUND_KEY(1, 10); - DPF_AES_SELECT_KEY(0, 10) - DPF_AES_LAST_ROUND(0); - - // Finalize hash. - DPF_AES_FINALIZE_HASH(0); -} - -// Helper macros for hashing four vectors in parallel. -#define DPF_AES_SELECT_KEY_4(round) \ - DPF_AES_SELECT_KEY(0, round); \ - DPF_AES_SELECT_KEY(1, round); \ - DPF_AES_SELECT_KEY(2, round); \ - DPF_AES_SELECT_KEY(3, round); -#define DPF_AES_MIDDLE_ROUND_4(round) \ - DPF_AES_LOAD_ROUND_KEY(0, round); \ - DPF_AES_LOAD_ROUND_KEY(1, round); \ - DPF_AES_SELECT_KEY_4(round); \ - DPF_AES_MIDDLE_ROUND(0, round); \ - DPF_AES_MIDDLE_ROUND(1, round); \ - DPF_AES_MIDDLE_ROUND(2, round); \ - DPF_AES_MIDDLE_ROUND(3, round); - -// Hashes four vectors `in_0, ..., in_3`, writing the results to `out_0, ..., -// out_3`. This improves pipelining of AES instructions, and improves -// performance by about 10%. Each block is hashed using either `round_keys_0` or -// `round_keys_1`, which both must point to a byte array containing two expanded -// AES keys. Which key is used for each block depends on `mask_0, ... mask_3`: -// If the mask 0, then `round_keys_0` is used, otherwise `round_keys_1`. Note -// that the masks are masks on 64 bit integers, so there are two mask bits per -// AES block. The caller is responsible for making sure that the masks for the -// two halves of any given block have the same value. -template <typename V, typename D, typename M> -void HashFourWithKeyMask(D d, V in_0, V in_1, V in_2, V in_3, M mask_0, - M mask_1, M mask_2, M mask_3, - const uint8_t* HWY_RESTRICT round_keys_0, - const uint8_t* HWY_RESTRICT round_keys_1, V& out_0, - V& out_1, V& out_2, V& out_3) { - // Compute sigma(in_0), ..., sigma(in_3) - DPF_AES_LOAD_SIGMA_MASK(); - DPF_AES_COMPUTE_SIGMA(0); - DPF_AES_COMPUTE_SIGMA(1); - DPF_AES_COMPUTE_SIGMA(2); - DPF_AES_COMPUTE_SIGMA(3); - - // First AES round. - DPF_AES_LOAD_ROUND_KEY(0, 0); - DPF_AES_LOAD_ROUND_KEY(1, 0); - DPF_AES_SELECT_KEY_4(0) - DPF_AES_FIRST_ROUND(0); - DPF_AES_FIRST_ROUND(1); - DPF_AES_FIRST_ROUND(2); - DPF_AES_FIRST_ROUND(3); - - // Middle AES rounds. - DPF_AES_MIDDLE_ROUND_4(1); - DPF_AES_MIDDLE_ROUND_4(2); - DPF_AES_MIDDLE_ROUND_4(3); - DPF_AES_MIDDLE_ROUND_4(4); - DPF_AES_MIDDLE_ROUND_4(5); - DPF_AES_MIDDLE_ROUND_4(6); - DPF_AES_MIDDLE_ROUND_4(7); - DPF_AES_MIDDLE_ROUND_4(8); - DPF_AES_MIDDLE_ROUND_4(9); - - // Last AES round. - DPF_AES_LOAD_ROUND_KEY(0, 10); - DPF_AES_LOAD_ROUND_KEY(1, 10); - DPF_AES_SELECT_KEY_4(10) - DPF_AES_LAST_ROUND(0); - DPF_AES_LAST_ROUND(1); - DPF_AES_LAST_ROUND(2); - DPF_AES_LAST_ROUND(3); - - // Finalize hash. - DPF_AES_FINALIZE_HASH(0); - DPF_AES_FINALIZE_HASH(1); - DPF_AES_FINALIZE_HASH(2); - DPF_AES_FINALIZE_HASH(3); -} - -#endif // HWY_TARGET != HWY_SCALAR - -} // namespace HWY_NAMESPACE -} // namespace dpf_internal -} // namespace distributed_point_functions -HWY_AFTER_NAMESPACE(); - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy_test.cc b/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy_test.cc deleted file mode 100644 index d581e1155dfc4..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/aes_128_fixed_key_hash_hwy_test.cc +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <limits> -#include <memory> -#include <vector> - -#include "absl/flags/parse.h" -#include "absl/log/absl_log.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "dpf/aes_128_fixed_key_hash.h" -#include "dpf/internal/get_hwy_mode.h" -#include "dpf/internal/status_matchers.h" -#include "gtest/gtest.h" -#include "hwy/aligned_allocator.h" -#include "hwy/detect_targets.h" -#include "openssl/aes.h" - -// clang-format off -#define HWY_IS_TEST 1 -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "dpf/internal/aes_128_fixed_key_hash_hwy_test.cc" // NOLINT -#include "hwy/foreach_target.h" -// clang-format on - -#include "dpf/internal/aes_128_fixed_key_hash_hwy.h" -#include "hwy/highway.h" -#include "hwy/tests/test_util-inl.h" - -HWY_BEFORE_NAMESPACE(); -namespace distributed_point_functions { -namespace dpf_internal { -namespace HWY_NAMESPACE { - -#if HWY_TARGET == HWY_SCALAR - -void TestAllAes() { - return; // HWY_SCALAR doesn't support AES instructions, so nothing to test. -} - -#else - -namespace hn = hwy::HWY_NAMESPACE; - -constexpr absl::uint128 kKey0 = - absl::MakeUint128(0x0000000000000000, 0x0000000000000000); -constexpr absl::uint128 kKey1 = - absl::MakeUint128(0x1111111111111111, 0x1111111111111111); -constexpr int kNumBlocks = 128; // Must be divisible by (4 * hn::Lanes(d)). -constexpr int kNumBytes = kNumBlocks * sizeof(absl::uint128); - -class TestOutputMatchesOpenSSL { - public: - template <typename T, typename D> - HWY_NOINLINE void operator()(T /*unused*/, D d) { - Reset(); - EvaluateOne(d); - CheckResult(); - Reset(); - EvaluateFour(d); - CheckResult(); - } - - private: - void Reset() { - inputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks); - ASSERT_NE(inputs_, nullptr); - masks_ = hwy::AllocateAligned<uint64_t>(2 * kNumBlocks); - ASSERT_NE(masks_, nullptr); - for (int i = 0; i < kNumBlocks; ++i) { - inputs_[i] = absl::MakeUint128(i, i + 1); - masks_[2 * i] = masks_[2 * i + 1] = - (i % 3 == 0) ? std::numeric_limits<uint64_t>::max() : 0; - } - outputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks); - ASSERT_NE(outputs_, nullptr); - ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey0), - 128, &expanded_key_0_)); - ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey1), - 128, &expanded_key_1_)); - input_ptr_ = reinterpret_cast<const uint8_t*>(inputs_.get()); - output_ptr_ = reinterpret_cast<uint8_t*>(outputs_.get()); - } - - void CheckResult() { - // Check the result by comparing with OpenSSL-based AES hash. - DPF_ASSERT_OK_AND_ASSIGN( - distributed_point_functions::Aes128FixedKeyHash hash_0, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN( - distributed_point_functions::Aes128FixedKeyHash hash_1, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey1)); - std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks); - DPF_ASSERT_OK( - hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks), - absl::MakeSpan(wanted_0))); - DPF_ASSERT_OK( - hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks), - absl::MakeSpan(wanted_1))); - for (int i = 0; i < kNumBlocks; ++i) { - if (i % 3 == 0) { - EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i; - } else { - EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i; - } - } - } - - template <typename D> - void EvaluateOne(D d) { - hn::Repartition<uint64_t, D> d64; - for (int i = 0; i + hn::Lanes(d) <= kNumBytes; i += hn::Lanes(d)) { - const auto in = hn::Load(d, input_ptr_ + i); - const auto mask = - hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t))); - auto out = hn::Undefined(d); - HashOneWithKeyMask( - d, in, mask, reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key), - reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out); - hn::Store(out, d, output_ptr_ + i); - } - } - - template <typename D> - void EvaluateFour(D d) { - hn::Repartition<uint64_t, D> d64; - // Evaluate four vectors at once. Assumes kNumBytes is divisible by (4 * - // hn::Lanes(d)). - for (int i = 0; i < kNumBytes; i += 4 * hn::Lanes(d)) { - const auto in_0 = hn::Load(d, input_ptr_ + i); - const auto in_1 = hn::Load(d, input_ptr_ + i + 1 * hn::Lanes(d)); - const auto in_2 = hn::Load(d, input_ptr_ + i + 2 * hn::Lanes(d)); - const auto in_3 = hn::Load(d, input_ptr_ + i + 3 * hn::Lanes(d)); - const auto mask_0 = - hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t))); - const auto mask_1 = hn::MaskFromVec(hn::Load( - d64, masks_.get() + (i + 1 * hn::Lanes(d)) / sizeof(uint64_t))); - const auto mask_2 = hn::MaskFromVec(hn::Load( - d64, masks_.get() + (i + 2 * hn::Lanes(d)) / sizeof(uint64_t))); - const auto mask_3 = hn::MaskFromVec(hn::Load( - d64, masks_.get() + (i + 3 * hn::Lanes(d)) / sizeof(uint64_t))); - auto out_0 = hn::Undefined(d); - auto out_1 = hn::Undefined(d); - auto out_2 = hn::Undefined(d); - auto out_3 = hn::Undefined(d); - HashFourWithKeyMask( - d, in_0, in_1, in_2, in_3, mask_0, mask_1, mask_2, mask_3, - reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key), - reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out_0, - out_1, out_2, out_3); - hn::Store(out_0, d, output_ptr_ + i); - hn::Store(out_1, d, output_ptr_ + i + 1 * hn::Lanes(d)); - hn::Store(out_2, d, output_ptr_ + i + 2 * hn::Lanes(d)); - hn::Store(out_3, d, output_ptr_ + i + 3 * hn::Lanes(d)); - } - - // Check the result by comparing with OpenSSL-based AES hash. - DPF_ASSERT_OK_AND_ASSIGN( - distributed_point_functions::Aes128FixedKeyHash hash_0, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN( - distributed_point_functions::Aes128FixedKeyHash hash_1, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey1)); - std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks); - DPF_ASSERT_OK( - hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks), - absl::MakeSpan(wanted_0))); - DPF_ASSERT_OK( - hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks), - absl::MakeSpan(wanted_1))); - for (int i = 0; i < kNumBlocks; ++i) { - if (i % 3 == 0) { - EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i; - } else { - EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i; - } - } - } - - hwy::AlignedFreeUniquePtr<absl::uint128[]> inputs_, outputs_; - hwy::AlignedFreeUniquePtr<uint64_t[]> masks_; - const uint8_t* input_ptr_; - uint8_t* output_ptr_; - HWY_ALIGN AES_KEY expanded_key_0_, expanded_key_1_; -}; - -void TestAllAes() { - hn::ForGE128Vectors<TestOutputMatchesOpenSSL>()(uint8_t{0}); -} - -#endif // HWY_TARGET == HWY_SCALAR - -} // namespace HWY_NAMESPACE -} // namespace dpf_internal -} // namespace distributed_point_functions -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE - -namespace distributed_point_functions { -namespace dpf_internal { -HWY_BEFORE_TEST(Aes128FixedKeyHashHwyTest); -HWY_EXPORT_AND_TEST_P(Aes128FixedKeyHashHwyTest, TestAllAes); - -TEST(Aes128FixedKeyHashHwy, LogHwyMode) { - ABSL_LOG(INFO) << "Highway is in " << GetHwyModeAsString() << " mode"; -} - -} // namespace dpf_internal -} // namespace distributed_point_functions - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - - return RUN_ALL_TESTS(); -} - -#endif diff --git a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.cc b/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.cc deleted file mode 100644 index f1e9ced63c8c6..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.cc +++ /dev/null @@ -1,662 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/evaluate_prg_hwy.h" - -#include <algorithm> -#include <cstdint> -#include <memory> -#include <vector> - -#include "absl/base/config.h" -#include "absl/base/optimization.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "dpf/aes_128_fixed_key_hash.h" -#include "dpf/status_macros.h" -#include "hwy/aligned_allocator.h" -#include "openssl/aes.h" - -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy.cc" -#include "hwy/foreach_target.h" -// clang-format on - -#include "dpf/internal/aes_128_fixed_key_hash_hwy.h" -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace distributed_point_functions { -namespace dpf_internal { -namespace HWY_NAMESPACE { - -namespace hn = hwy::HWY_NAMESPACE; - -#if HWY_TARGET == HWY_SCALAR - -absl::Status EvaluateSeedsHwy( - int64_t num_seeds, int num_levels, const absl::uint128* seeds_in, - const bool* control_bits_in, const absl::uint128* paths, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out) { - return EvaluateSeedsNoHwy(num_seeds, num_levels, seeds_in, control_bits_in, - paths, correction_seeds, correction_controls_left, - correction_controls_right, prg_left, prg_right, - seeds_out, control_bits_out); -} - -#else - -// Converts a bool array to a block-level mask suitable for vectors described by -// `d`. The mask value for each integer in the i-th block is set to input[i]. -// If `max_blocks > 0`, returns after reading `max_blocks` bools from `input`. -template <typename D> -auto MaskFromBools(D d, const bool* input, int max_blocks = 0) { - using T = hn::TFromD<D>; - constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T); - constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64); - uint8_t mask_bits[buffer_size] = {0}; - for (int i = 0; i < hn::Lanes(d); ++i) { - int block_idx = i / ints_per_block; - if (max_blocks > 0 && block_idx >= max_blocks) { - break; - } - if (input[block_idx]) { - mask_bits[i / 8] |= uint8_t{1} << (i % 8); - } - } - return hn::LoadMaskBits(d, mask_bits); -} - -// Converts a mask for types `d` to a bool array. Assumes that the mask value -// for all integers in the i-th block is equal, and writes that value to -// output[i]. If `max_blocks > 0`, returns after writing `max_blocks` bools to -// `output`. -template <typename D, typename M> -void BoolsFromMask(D d, M mask, bool* output, int max_blocks = 0) { - using T = hn::TFromD<D>; - constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T); - int num_outputs = hn::Lanes(d) / ints_per_block; - if (max_blocks > 0) { - num_outputs = max_blocks; - } - constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64); - uint8_t mask_bits[buffer_size] = {0}; - hn::StoreMaskBits(d, mask, mask_bits); - for (int i = 0; i < num_outputs; ++i) { - int mask_idx = i * ints_per_block; - output[i] = (mask_bits[mask_idx / 8] & (uint8_t{1} << (mask_idx % 8))) != 0; - } -} - -template <typename M> -M IfThenElseMask(M condition, M true_value, M false_value) { - return hn::Or(hn::And(condition, true_value), - hn::And(hn::Not(condition), false_value)); -} - -// Returns a mask that is `true` on all blocks where `input[i] & (1 << index)` -// is nonzero. The mask is a 64-bit-level mask, suitable for AES hashing. -template <typename V, typename D> -auto IsBitSet(D d, const V input, int index) { - // First create a 128-bit block with the `index`-th bit set. - HWY_ALIGN absl::uint128 mask = 0; - if (index < 128) { - mask = absl::uint128{1} << index; - } - - // Now load it into a vector of 64-bit integers. Note that every second - // element of that vector will be 0. - const hn::Repartition<uint64_t, D> d64; - static_assert(ABSL_IS_LITTLE_ENDIAN); - const auto mask_64 = - hn::LoadDup128(d64, reinterpret_cast<const uint64_t*>(&mask)); - - // Compute input AND mask_64 on 64-bit integers. - auto input_64 = hn::BitCast(d64, input); - input_64 = hn::And(input_64, mask_64); - - // Take the OR of every two adjacent 64-bit integers. This ensures that each - // half of an 128-bit block is nonzero iff at least one half was nonzero. - input_64 = hn::Or(input_64, hn::Shuffle01(input_64)); - - // Compute a 64-bit mask that checks which integers are nonzero. - return hn::Ne(input_64, hn::Zero(d64)); -} - -// Dummy struct to get HWY_ALIGN as a number, for testing if an array of -// absl::uint128 is aligned. -struct HWY_ALIGN Aligned128 { - absl::uint128 _; -}; - -absl::Status EvaluateSeedsHwy( - int64_t num_seeds, int num_levels, int num_correction_words, - const absl::uint128* seeds_in, const bool* control_bits_in, - const absl::uint128* paths, int paths_rightshift, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out) { - // Exit early if inputs are empty. - if (num_seeds == 0 || num_levels == 0) { - return absl::OkStatus(); - } - - // Check if inputs and outputs are aligned. - constexpr size_t kHwyAlignment = alignof(Aligned128); - const bool is_aligned = - (reinterpret_cast<uintptr_t>(seeds_in) % kHwyAlignment == 0) && - (reinterpret_cast<uintptr_t>(paths) % kHwyAlignment == 0) && - (reinterpret_cast<uintptr_t>(correction_seeds) % kHwyAlignment == 0) && - (reinterpret_cast<uintptr_t>(seeds_out) % kHwyAlignment == 0); - // Vector type used throughout this function: Largest byte vector available. - const hn::ScalableTag<uint8_t> d8; - // Only run the highway version if - // - the inputs are aligned, - // - the number of bytes in a vector is at least 16, and - // - the number of bytes in a vector is a multiple of 16. - if (ABSL_PREDICT_FALSE(!is_aligned || hn::Lanes(d8) < 16 || - hn::Lanes(d8) % 16 != 0)) { - return EvaluateSeedsNoHwy( - num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in, - paths, paths_rightshift, correction_seeds, correction_controls_left, - correction_controls_right, prg_left, prg_right, seeds_out, - control_bits_out); - } - - // Do AES key schedule. - HWY_ALIGN AES_KEY expanded_key_0; - HWY_ALIGN AES_KEY expanded_key_1; - int openssl_status = AES_set_encrypt_key( - reinterpret_cast<const uint8_t*>(&prg_left.key()), 128, &expanded_key_0); - if (openssl_status != 0) { - return absl::InternalError("Failed to set up AES key"); - } - openssl_status = AES_set_encrypt_key( - reinterpret_cast<const uint8_t*>(&prg_right.key()), 128, &expanded_key_1); - if (openssl_status != 0) { - return absl::InternalError("Failed to set up AES key"); - } - - // Helper variables. - const hn::Repartition<uint64_t, decltype(d8)> d64; - HWY_ALIGN absl::uint128 clear_lowest_bit_128 = ~absl::uint128{1}; - const auto clear_lowest_bit = hn::LoadDup128( - d8, reinterpret_cast<const uint8_t*>(&clear_lowest_bit_128)); - const auto mask_all_zero = hn::FirstN(d64, 0); - const auto mask_all_one = hn::Not(mask_all_zero); - const int64_t num_bytes = num_seeds * sizeof(absl::uint128); - const int bytes_per_vec = hn::Lanes(d8); - const int blocks_per_vec = bytes_per_vec / sizeof(absl::uint128); - const int64_t correction_words_per_level = num_correction_words / num_levels; - - // Pointer aliases for reading and writing data. - const uint8_t* seeds_in_ptr = reinterpret_cast<const uint8_t*>(seeds_in); - const uint8_t* paths_ptr = reinterpret_cast<const uint8_t*>(paths); - uint8_t* seeds_out_ptr = reinterpret_cast<uint8_t*>(seeds_out); - // Four vectors at a time. - int64_t i = 0; - for (; i + 4 * bytes_per_vec <= num_bytes; i += 4 * bytes_per_vec) { - const int64_t start_block = i / sizeof(absl::uint128); - // Load initial seeds and paths into vectors. - auto vec_0 = hn::Load(d8, seeds_in_ptr + i); - auto vec_1 = hn::Load(d8, seeds_in_ptr + i + 1 * bytes_per_vec); - auto vec_2 = hn::Load(d8, seeds_in_ptr + i + 2 * bytes_per_vec); - auto vec_3 = hn::Load(d8, seeds_in_ptr + i + 3 * bytes_per_vec); - const auto path_0 = hn::Load(d8, paths_ptr + i); - const auto path_1 = hn::Load(d8, paths_ptr + i + 1 * bytes_per_vec); - const auto path_2 = hn::Load(d8, paths_ptr + i + 2 * bytes_per_vec); - const auto path_3 = hn::Load(d8, paths_ptr + i + 3 * bytes_per_vec); - auto control_mask_0 = MaskFromBools(d64, control_bits_in + start_block); - auto control_mask_1 = - MaskFromBools(d64, control_bits_in + start_block + 1 * blocks_per_vec); - auto control_mask_2 = - MaskFromBools(d64, control_bits_in + start_block + 2 * blocks_per_vec); - auto control_mask_3 = - MaskFromBools(d64, control_bits_in + start_block + 3 * blocks_per_vec); - for (int j = 0; j < num_levels; ++j) { - // Convert path bits to masks and evaluate PRG. - const int bit_index = num_levels - j - 1 + paths_rightshift; - const auto path_mask_0 = IsBitSet(d8, path_0, bit_index); - const auto path_mask_1 = IsBitSet(d8, path_1, bit_index); - const auto path_mask_2 = IsBitSet(d8, path_2, bit_index); - const auto path_mask_3 = IsBitSet(d8, path_3, bit_index); - HashFourWithKeyMask( - d8, vec_0, vec_1, vec_2, vec_3, path_mask_0, path_mask_1, path_mask_2, - path_mask_3, reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key), - reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec_0, vec_1, - vec_2, vec_3); - - // Apply correction. - if (correction_words_per_level == 1) { - const auto correction_seed = hn::LoadDup128( - d64, reinterpret_cast<const uint64_t*>(correction_seeds + j)); - vec_0 = hn::Xor(vec_0, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_0, - correction_seed))); - vec_1 = hn::Xor(vec_1, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_1, - correction_seed))); - vec_2 = hn::Xor(vec_2, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_2, - correction_seed))); - vec_3 = hn::Xor(vec_3, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_3, - correction_seed))); - } else { // correction_words_per_level == num_seeds. - const uint8_t* correction_seeds_ptr = reinterpret_cast<const uint8_t*>( - correction_seeds + j * correction_words_per_level); - hn::Vec<decltype(d64)> correction_seed_0, correction_seed_1, - correction_seed_2, correction_seed_3; - if (ABSL_PREDICT_TRUE( - correction_words_per_level % blocks_per_vec == 0 || j == 0)) { - correction_seed_0 = - hn::BitCast(d64, hn::Load(d8, correction_seeds_ptr + i)); - correction_seed_1 = hn::BitCast( - d64, hn::Load(d8, correction_seeds_ptr + i + 1 * bytes_per_vec)); - correction_seed_2 = hn::BitCast( - d64, hn::Load(d8, correction_seeds_ptr + i + 2 * bytes_per_vec)); - correction_seed_3 = hn::BitCast( - d64, hn::Load(d8, correction_seeds_ptr + i + 3 * bytes_per_vec)); - } else { - correction_seed_0 = - hn::BitCast(d64, hn::LoadU(d8, correction_seeds_ptr + i)); - correction_seed_1 = hn::BitCast( - d64, hn::LoadU(d8, correction_seeds_ptr + i + 1 * bytes_per_vec)); - correction_seed_2 = hn::BitCast( - d64, hn::LoadU(d8, correction_seeds_ptr + i + 2 * bytes_per_vec)); - correction_seed_3 = hn::BitCast( - d64, hn::LoadU(d8, correction_seeds_ptr + i + 3 * bytes_per_vec)); - } - vec_0 = hn::Xor(vec_0, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_0, - correction_seed_0))); - vec_1 = hn::Xor(vec_1, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_1, - correction_seed_1))); - vec_2 = hn::Xor(vec_2, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_2, - correction_seed_2))); - vec_3 = hn::Xor(vec_3, - hn::BitCast(d8, hn::IfThenElseZero(control_mask_3, - correction_seed_3))); - } - - // Extract control bit for next level. - const auto next_control_mask_0 = IsBitSet(d8, vec_0, 0); - const auto next_control_mask_1 = IsBitSet(d8, vec_1, 0); - const auto next_control_mask_2 = IsBitSet(d8, vec_2, 0); - const auto next_control_mask_3 = IsBitSet(d8, vec_3, 0); - vec_0 = hn::And(vec_0, clear_lowest_bit); - vec_1 = hn::And(vec_1, clear_lowest_bit); - vec_2 = hn::And(vec_2, clear_lowest_bit); - vec_3 = hn::And(vec_3, clear_lowest_bit); - - // Perform control bit correction. - auto correction_control_mask_0 = mask_all_zero, - correction_control_mask_1 = mask_all_zero, - correction_control_mask_2 = mask_all_zero, - correction_control_mask_3 = mask_all_zero; - if (correction_words_per_level == 1) { - const auto correction_control_mask_left = - correction_controls_left[j] ? mask_all_one : mask_all_zero; - const auto correction_control_mask_right = - correction_controls_right[j] ? mask_all_one : mask_all_zero; - correction_control_mask_0 = - IfThenElseMask(path_mask_0, correction_control_mask_right, - correction_control_mask_left); - correction_control_mask_1 = - IfThenElseMask(path_mask_1, correction_control_mask_right, - correction_control_mask_left); - correction_control_mask_2 = - IfThenElseMask(path_mask_2, correction_control_mask_right, - correction_control_mask_left); - correction_control_mask_3 = - IfThenElseMask(path_mask_3, correction_control_mask_right, - correction_control_mask_left); - } else { // correction_words_per_level == num_seeds. - const bool* correction_controls_left_j = - correction_controls_left + j * correction_words_per_level + - start_block; - const bool* correction_controls_right_j = - correction_controls_right + j * correction_words_per_level + - start_block; - correction_control_mask_0 = IfThenElseMask( - path_mask_0, MaskFromBools(d64, correction_controls_right_j), - MaskFromBools(d64, correction_controls_left_j)); - correction_control_mask_1 = IfThenElseMask( - path_mask_1, - MaskFromBools(d64, - correction_controls_right_j + 1 * blocks_per_vec), - MaskFromBools(d64, - correction_controls_left_j + 1 * blocks_per_vec)); - correction_control_mask_2 = IfThenElseMask( - path_mask_2, - MaskFromBools(d64, - correction_controls_right_j + 2 * blocks_per_vec), - MaskFromBools(d64, - correction_controls_left_j + 2 * blocks_per_vec)); - correction_control_mask_3 = IfThenElseMask( - path_mask_3, - MaskFromBools(d64, - correction_controls_right_j + 3 * blocks_per_vec), - MaskFromBools(d64, - correction_controls_left_j + 3 * blocks_per_vec)); - } - - control_mask_0 = - hn::Xor(next_control_mask_0, - (hn::And(control_mask_0, correction_control_mask_0))); - control_mask_1 = - hn::Xor(next_control_mask_1, - (hn::And(control_mask_1, correction_control_mask_1))); - control_mask_2 = - hn::Xor(next_control_mask_2, - (hn::And(control_mask_2, correction_control_mask_2))); - control_mask_3 = - hn::Xor(next_control_mask_3, - (hn::And(control_mask_3, correction_control_mask_3))); - } - // Write the evaluated outputs to memory. - hn::Store(vec_0, d8, seeds_out_ptr + i); - hn::Store(vec_1, d8, seeds_out_ptr + i + 1 * bytes_per_vec); - hn::Store(vec_2, d8, seeds_out_ptr + i + 2 * bytes_per_vec); - hn::Store(vec_3, d8, seeds_out_ptr + i + 3 * bytes_per_vec); - BoolsFromMask(d64, control_mask_0, control_bits_out + start_block); - BoolsFromMask(d64, control_mask_1, - control_bits_out + start_block + 1 * blocks_per_vec); - BoolsFromMask(d64, control_mask_2, - control_bits_out + start_block + 2 * blocks_per_vec); - BoolsFromMask(d64, control_mask_3, - control_bits_out + start_block + 3 * blocks_per_vec); - } - ABSL_DCHECK_GT(i + 4 * bytes_per_vec, num_bytes); - - // Single full vectors. - for (; i + bytes_per_vec <= num_bytes; i += bytes_per_vec) { - const int64_t start_block = i / sizeof(absl::uint128); - auto vec = hn::Load(d8, seeds_in_ptr + i); - const auto path = hn::Load(d8, paths_ptr + i); - auto control_mask = MaskFromBools(d64, control_bits_in + start_block); - for (int j = 0; j < num_levels; ++j) { - const int bit_index = num_levels - j - 1 + paths_rightshift; - const auto path_mask = IsBitSet(d8, path, bit_index); - HashOneWithKeyMask( - d8, vec, path_mask, - reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key), - reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec); - - // Apply correction. - hn::Vec<decltype(d64)> correction_seed; - if (correction_words_per_level == 1) { - correction_seed = hn::LoadDup128( - d64, reinterpret_cast<const uint64_t*>(correction_seeds + j)); - } else { - const uint64_t* correction_seeds_ptr = - reinterpret_cast<const uint64_t*>(correction_seeds + - j * correction_words_per_level + - start_block); - if (ABSL_PREDICT_TRUE( - correction_words_per_level % blocks_per_vec == 0 || j == 0)) { - correction_seed = hn::Load(d64, correction_seeds_ptr); - } else { - correction_seed = hn::LoadU(d64, correction_seeds_ptr); - } - } - vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask, - correction_seed))); - - // Extract control bit for next level. - const auto next_control_mask = IsBitSet(d8, vec, 0); - vec = hn::And(vec, clear_lowest_bit); - - // Perform control bit correction. - auto correction_control_mask = mask_all_zero; - if (correction_words_per_level == 1) { - const auto correction_control_mask_left = - correction_controls_left[j] ? mask_all_one : mask_all_zero; - const auto correction_control_mask_right = - correction_controls_right[j] ? mask_all_one : mask_all_zero; - correction_control_mask = - IfThenElseMask(path_mask, correction_control_mask_right, - correction_control_mask_left); - } else { - const bool* correction_controls_left_j = - correction_controls_left + j * correction_words_per_level + - start_block; - const bool* correction_controls_right_j = - correction_controls_right + j * correction_words_per_level + - start_block; - correction_control_mask = IfThenElseMask( - path_mask, MaskFromBools(d64, correction_controls_right_j), - MaskFromBools(d64, correction_controls_left_j)); - } - control_mask = hn::Xor(next_control_mask, - (hn::And(control_mask, correction_control_mask))); - } - hn::Store(vec, d8, seeds_out_ptr + i); - BoolsFromMask(d64, control_mask, control_bits_out + start_block); - } - ABSL_DCHECK_GT(i + bytes_per_vec, num_bytes); - - // Elements less than a full vector. - int remaining_blocks = num_seeds - i / sizeof(absl::uint128); - if (remaining_blocks > 0) { - const int64_t start_block = i / sizeof(absl::uint128); - const int remaining_bytes = num_bytes - i; - // Copy to a buffer first, to ensure we have at least bytes_per_vec bytes - // to read. Calling MaskedLoad directly instead might lead to out-of-bounds - // accesses. - auto buffer = hwy::AllocateAligned<absl::uint128>(2 * blocks_per_vec); - if (buffer == nullptr) { - return absl::ResourceExhaustedError("Memory allocation error"); - } - auto buffer_ptr = reinterpret_cast<uint8_t*>(buffer.get()); - std::copy_n(seeds_in + start_block, remaining_blocks, buffer.get()); - std::copy_n(paths + start_block, remaining_blocks, - buffer.get() + blocks_per_vec); - const auto load_mask = hn::FirstN(d8, remaining_bytes); - auto vec = hn::MaskedLoad(load_mask, d8, buffer_ptr); - const auto path = hn::MaskedLoad(load_mask, d8, buffer_ptr + bytes_per_vec); - auto control_mask = - MaskFromBools(d64, control_bits_in + start_block, remaining_blocks); - for (int j = 0; j < num_levels; ++j) { - const int bit_index = num_levels - j - 1 + paths_rightshift; - const auto path_mask = IsBitSet(d8, path, bit_index); - HashOneWithKeyMask( - d8, vec, path_mask, - reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key), - reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec); - - // Perform seed correction. - hn::Vec<decltype(d64)> correction_seed; - if (correction_words_per_level == 1) { - correction_seed = hn::LoadDup128( - d64, reinterpret_cast<const uint64_t*>(correction_seeds + j)); - } else { - std::copy_n( - correction_seeds + j * correction_words_per_level + start_block, - remaining_blocks, buffer.get()); - correction_seed = - hn::BitCast(d64, hn::MaskedLoad(load_mask, d8, buffer_ptr)); - } - vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask, - correction_seed))); - const auto next_control_mask = IsBitSet(d8, vec, 0); - vec = hn::And(vec, clear_lowest_bit); - - // Perform control bit correction. - auto correction_control_mask = mask_all_zero; - if (correction_words_per_level == 1) { - const auto correction_control_mask_left = - correction_controls_left[j] ? mask_all_one : mask_all_zero; - const auto correction_control_mask_right = - correction_controls_right[j] ? mask_all_one : mask_all_zero; - correction_control_mask = - IfThenElseMask(path_mask, correction_control_mask_right, - correction_control_mask_left); - } else { - const bool* correction_controls_left_j = - correction_controls_left + j * correction_words_per_level + - start_block; - const bool* correction_controls_right_j = - correction_controls_right + j * correction_words_per_level + - start_block; - correction_control_mask = IfThenElseMask( - path_mask, - MaskFromBools(d64, correction_controls_right_j, remaining_blocks), - MaskFromBools(d64, correction_controls_left_j, remaining_blocks)); - } - control_mask = hn::Xor(next_control_mask, - (hn::And(control_mask, correction_control_mask))); - } - // Store back into buffer, then copy to seeds_out. - hn::Store(vec, d8, buffer_ptr); - std::copy_n(buffer.get(), remaining_blocks, seeds_out + start_block); - BoolsFromMask(d64, control_mask, control_bits_out + start_block, - remaining_blocks); - } - - return absl::OkStatus(); -} - -#endif // HWY_TARGET == HWY_SCALAR - -} // namespace HWY_NAMESPACE -} // namespace dpf_internal -} // namespace distributed_point_functions -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE || HWY_IDE -namespace distributed_point_functions { -namespace dpf_internal { - -absl::Status EvaluateSeedsNoHwy( - int64_t num_seeds, int num_levels, int num_correction_words, - const absl::uint128* seeds_in, const bool* control_bits_in, - const absl::uint128* paths, int paths_rightshift, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out) { - using BitVector = - absl::InlinedVector<bool, - std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>; - constexpr int64_t max_batch_size = Aes128FixedKeyHash::kBatchSize; - - // Allocate buffers. - std::vector<absl::uint128> buffer_left, buffer_right; - buffer_left.resize(max_batch_size); - buffer_right.resize(max_batch_size); - BitVector path_bits(max_batch_size), control_bits(max_batch_size); - - // Perform DPF evaluation in blocks. - for (int64_t start_block = 0; start_block < num_seeds; - start_block += max_batch_size) { - int64_t current_batch_size = - std::min<int64_t>(num_seeds - start_block, max_batch_size); - - for (int level = 0; level < num_levels; ++level) { - // Evaluate PRG. We evaluate both left and right expansions, but only use - // one of them (depending on path_bits). This seems to be faster than - // first sorting the seeds by path_bits and then expanding. - absl::Span<const absl::uint128> seeds = - absl::MakeConstSpan((level == 0 ? seeds_in : seeds_out) + start_block, - current_batch_size); - DPF_RETURN_IF_ERROR(prg_left.Evaluate( - seeds, absl::MakeSpan(buffer_left).subspan(0, current_batch_size))); - DPF_RETURN_IF_ERROR(prg_right.Evaluate( - seeds, absl::MakeSpan(buffer_right).subspan(0, current_batch_size))); - - // Merge back into result. - const int bit_index = num_levels - level - 1 + paths_rightshift; - for (int i = 0; i < current_batch_size; ++i) { - path_bits[i] = 0; - if (bit_index < 128) { - path_bits[i] = - ((paths[start_block + i]) & (absl::uint128{1} << bit_index)) != 0; - } - if (path_bits[i] == 0) { - seeds_out[start_block + i] = buffer_left[i]; - } else { - seeds_out[start_block + i] = buffer_right[i]; - } - } - - // Compute correction. Making a copy here a copy here improves pipelining - // by not updating result.control_bits in place. Do benchmarks before - // removing this. - std::copy_n( - &(level == 0 ? control_bits_in : control_bits_out)[start_block], - current_batch_size, &control_bits[0]); - int correction_index = level; - for (int i = 0; i < current_batch_size; ++i) { - if (num_correction_words > num_levels) { - // We have num_levels * num_seeds correction words. - correction_index = level * num_seeds + start_block + i; - } - if (control_bits[i]) { - seeds_out[start_block + i] ^= correction_seeds[correction_index]; - } - bool current_control_bit = - ExtractAndClearLowestBit(seeds_out[start_block + i]); - if (control_bits[i]) { - if (path_bits[i] == 0) { - current_control_bit ^= correction_controls_left[correction_index]; - } else { - current_control_bit ^= correction_controls_right[correction_index]; - } - } - control_bits_out[start_block + i] = current_control_bit; - } - } - } - - return absl::OkStatus(); -} - -HWY_EXPORT(EvaluateSeedsHwy); - -absl::Status EvaluateSeeds( - int64_t num_seeds, int num_levels, int num_correction_words, - const absl::uint128* seeds_in, const bool* control_bits_in, - const absl::uint128* paths, int paths_rightshift, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out) { - // Check that we either have one or `num_seeds` correction words per level. - if (num_correction_words != num_levels && - num_correction_words != num_levels * num_seeds) { - return absl::InvalidArgumentError( - "`num_correction_words` must be equal to `num_levels` or `num_levels * " - "num_seeds`"); - } - return HWY_DYNAMIC_DISPATCH(EvaluateSeedsHwy)( - num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in, - paths, paths_rightshift, correction_seeds, correction_controls_left, - correction_controls_right, prg_left, prg_right, seeds_out, - control_bits_out); -} - -} // namespace dpf_internal -} // namespace distributed_point_functions -#endif diff --git a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.h b/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.h deleted file mode 100644 index 866c7ebfab936..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_ - -#include <stdint.h> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "dpf/aes_128_fixed_key_hash.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -using distributed_point_functions::Aes128FixedKeyHash; - -// Extracts the lowest bit of `x` and sets it to 0 in `x`. -inline bool ExtractAndClearLowestBit(absl::uint128& x) { - bool bit = ((x & absl::uint128{1}) != 0); - x &= ~absl::uint128{1}; - return bit; -} - -// Performs DPF evaluation of the seeds given in `seeds_in` using `prg_left` or -// `prg_right, and the given `control_bits_in`, and correction words given by -// `correction_seeds`, `correction_controls_left`, and -// `correction_controls_right`. At each level `l < num_level`, the evaluation -// for the i-th seed continues along the left or right path depending on the -// l-th most significant bit among the lowest `num_levels` bits of `paths[i]`, -// after right-shifting each `paths[i]` by `paths_rightshift`. -// -// This function takes raw pointers instead of absl::Span for performance -// reasons. No bounds checks are performed, so it is the caller's responsibility -// to ensure that -// - `seeds_in`, `control_bits_in`, `seeds_out`, and `control_bits_out` have at -// least `num_seeds` elements, and -// - `correction_seeds`, `correction_controls_left`, and -// `correction_controls_right` have at least `num_levels` elements. -// -// If the inputs are aligned (e.g. using HWY_ALIGN, or hwy::AllocateAligned), -// and if SIMD operations are supported, then the evaluation will be done using -// SIMD operations. Otherwise, falls back to `EvaluateSeedsNoHwy`, which is at -// least 2x slower. -// -// `num_correction_words` can either be equal to `num_levels`, or equal to -// `num_seeds * num_levels`. In the first case, the same correction word is used -// for every seed at a given level. In the second case, correction word at index -// `i * num_seeds + j` is used to correct seed `i` at level `j`. -// If `num_correction_words == num_seeds * num_levels`, then `num_seeds` should -// be smaller than or divisible by the size of a SIMD vector for optimal -// performance. -// -// Returns OK on success, INVALID_ARGUMENT in case num_correction_words is not -// equal to `num_levels` or `num_seeds * num_levels`, and INTERNAL in case of -// OpenSSL errors. -absl::Status EvaluateSeeds( - int64_t num_seeds, int num_levels, int num_correction_words, - const absl::uint128* seeds_in, const bool* control_bits_in, - const absl::uint128* paths, int paths_rightshift, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out); - -// As `EvaluateSeeds`, but does not require any SIMD support. -absl::Status EvaluateSeedsNoHwy( - int64_t num_seeds, int num_levels, int num_correction_words, - const absl::uint128* seeds_in, const bool* control_bits_in, - const absl::uint128* paths, int paths_rightshift, - const absl::uint128* correction_seeds, const bool* correction_controls_left, - const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left, - const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out, - bool* control_bits_out); - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy_test.cc b/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy_test.cc deleted file mode 100644 index 8b05486250c25..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy_test.cc +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/evaluate_prg_hwy.h" - -#include <memory> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "dpf/aes_128_fixed_key_hash.h" -#include "dpf/internal/status_matchers.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "hwy/aligned_allocator.h" - -// clang-format off -#define HWY_IS_TEST 1; -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy_test.cc" // NOLINT -#include "hwy/foreach_target.h" -// clang-format on -#include "hwy/highway.h" -#include "hwy/tests/hwy_gtest.h" - -HWY_BEFORE_NAMESPACE(); -namespace distributed_point_functions { -namespace dpf_internal { -namespace HWY_NAMESPACE { - -using ::testing::HasSubstr; - -constexpr absl::uint128 kKey0 = - absl::MakeUint128(0x0000000000000000, 0x0000000000000000); -constexpr absl::uint128 kKey1 = - absl::MakeUint128(0x1111111111111111, 0x1111111111111111); - -void TestOutputMatchesNoHwyVersion(int num_seeds, int num_levels, - int num_correction_words, - int paths_rightshift) { - // Generate seeds. - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths; - hwy::AlignedFreeUniquePtr<bool[]> control_bits_in; - if (num_seeds > 0) { - seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(seeds_in, nullptr); - paths = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(paths, nullptr); - control_bits_in = hwy::AllocateAligned<bool>(num_seeds); - ASSERT_NE(control_bits_in, nullptr); - } - for (int i = 0; i < num_seeds; ++i) { - // All of these are arbitrary. - seeds_in[i] = absl::MakeUint128(i, i + 1); - paths[i] = absl::MakeUint128(23 * i + 42, 42 * i + 23); - control_bits_in[i] = (i % 7 == 0); - } - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out; - hwy::AlignedFreeUniquePtr<bool[]> control_bits_out; - if (num_seeds > 0) { - seeds_out = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(seeds_out, nullptr); - control_bits_out = hwy::AllocateAligned<bool>(num_seeds); - ASSERT_NE(control_bits_out, nullptr); - } - - // Generate correction words. - hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds; - hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left, - correction_controls_right; - if (num_correction_words > 0) { - correction_seeds = - hwy::AllocateAligned<absl::uint128>(num_correction_words); - ASSERT_NE(correction_seeds, nullptr); - correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words); - ASSERT_NE(correction_controls_left, nullptr); - correction_controls_right = - hwy::AllocateAligned<bool>(num_correction_words); - ASSERT_NE(correction_controls_right, nullptr); - } - for (int i = 0; i < num_correction_words; ++i) { - correction_seeds[i] = absl::MakeUint128(i + 1, i); - correction_controls_left[i] = (i % 23 == 0); - correction_controls_right[i] = (i % 42 != 0); - } - - // Set up PRGs. - DPF_ASSERT_OK_AND_ASSIGN( - auto prg_left, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN( - auto prg_right, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey1)); - - // Evaluate with Highway enabled. - DPF_ASSERT_OK( - EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(), - control_bits_in.get(), paths.get(), paths_rightshift, - correction_seeds.get(), correction_controls_left.get(), - correction_controls_right.get(), prg_left, prg_right, - seeds_out.get(), control_bits_out.get())); - - // Evaluate without highway. - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted; - hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted; - if (num_seeds > 0) { - seeds_out_wanted = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(seeds_out_wanted, nullptr); - control_bits_out_wanted = hwy::AllocateAligned<bool>(num_seeds); - ASSERT_NE(control_bits_out_wanted, nullptr); - } - DPF_ASSERT_OK(EvaluateSeedsNoHwy( - num_seeds, num_levels, num_correction_words, seeds_in.get(), - control_bits_in.get(), paths.get(), paths_rightshift, - correction_seeds.get(), correction_controls_left.get(), - correction_controls_right.get(), prg_left, prg_right, - seeds_out_wanted.get(), control_bits_out_wanted.get())); - - // Check that both evaluations are equal, if there was anything to evaluate. - if (num_levels > 0) { - for (int i = 0; i < num_seeds; ++i) { - EXPECT_EQ(seeds_out[i], seeds_out_wanted[i]); - EXPECT_EQ(control_bits_out[i], control_bits_out_wanted[i]); - } - } - - // Evaluate without paths_rightshift - if (paths_rightshift != 0) { - hwy::AlignedFreeUniquePtr<absl::uint128[]> paths_in2; - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted2; - hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted2; - if (num_seeds > 0) { - paths_in2 = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(paths_in2, nullptr); - seeds_out_wanted2 = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(seeds_out_wanted2, nullptr); - control_bits_out_wanted2 = hwy::AllocateAligned<bool>(num_seeds); - ASSERT_NE(control_bits_out_wanted2, nullptr); - } - for (int i = 0; i < num_seeds; ++i) { - paths_in2[i] = 0; - if (paths_rightshift < 128) { - paths_in2[i] = paths[i] >> paths_rightshift; - } - } - DPF_ASSERT_OK(EvaluateSeedsNoHwy( - num_seeds, num_levels, num_correction_words, seeds_in.get(), - control_bits_in.get(), paths_in2.get(), 0, correction_seeds.get(), - correction_controls_left.get(), correction_controls_right.get(), - prg_left, prg_right, seeds_out_wanted2.get(), - control_bits_out_wanted2.get())); - // Check that both evaluations are equal, if there was anything to evaluate. - if (num_levels > 0) { - for (int i = 0; i < num_seeds; ++i) { - EXPECT_EQ(seeds_out[i], seeds_out_wanted2[i]); - EXPECT_EQ(control_bits_out[i], control_bits_out_wanted2[i]); - } - } - } -} - -void TestAll() { - for (int num_seeds : {0, 1, 2, 101, 128, 1000}) { - for (int num_levels : {0, 1, 2, 32, 63, 64, 127, 128}) { - for (int num_correction_words : {num_levels, num_levels * num_seeds}) { - TestOutputMatchesNoHwyVersion(num_seeds, num_levels, - num_correction_words, 0); - } - } - } -} - -void TestPathsRightshift() { - constexpr int num_levels = 128; - for (int num_seeds : {0, 1, 101}) { - for (int paths_rightshift = 0; paths_rightshift <= 128; - ++paths_rightshift) { - TestOutputMatchesNoHwyVersion(num_seeds, num_levels, num_levels, - paths_rightshift); - } - } -} - -void FailsIfNumCorrectionWordsIsWrong() { - constexpr int num_seeds = 1000; - constexpr int num_levels = 10; - constexpr int num_correction_words = 12; - - hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths; - hwy::AlignedFreeUniquePtr<bool[]> control_bits_in; - seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(seeds_in, nullptr); - paths = hwy::AllocateAligned<absl::uint128>(num_seeds); - ASSERT_NE(paths, nullptr); - control_bits_in = hwy::AllocateAligned<bool>(num_seeds); - ASSERT_NE(control_bits_in, nullptr); - - hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds; - hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left, - correction_controls_right; - correction_seeds = hwy::AllocateAligned<absl::uint128>(num_correction_words); - ASSERT_NE(correction_seeds, nullptr); - correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words); - ASSERT_NE(correction_controls_left, nullptr); - correction_controls_right = hwy::AllocateAligned<bool>(num_correction_words); - ASSERT_NE(correction_controls_right, nullptr); - - DPF_ASSERT_OK_AND_ASSIGN( - auto prg_left, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey0)); - DPF_ASSERT_OK_AND_ASSIGN( - auto prg_right, - distributed_point_functions::Aes128FixedKeyHash::Create(kKey1)); - - EXPECT_THAT( - EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(), - control_bits_in.get(), paths.get(), 0, - correction_seeds.get(), correction_controls_left.get(), - correction_controls_right.get(), prg_left, prg_right, - seeds_in.get(), control_bits_in.get()), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("num_correction_words"))); -} - -} // namespace HWY_NAMESPACE -} // namespace dpf_internal -} // namespace distributed_point_functions -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE - -namespace distributed_point_functions { -namespace dpf_internal { -HWY_BEFORE_TEST(EvaluatePrgHwyTest); -HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestAll); -HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestPathsRightshift); -HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, FailsIfNumCorrectionWordsIsWrong); -} // namespace dpf_internal -} // namespace distributed_point_functions - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} - -#endif diff --git a/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.cc b/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.cc deleted file mode 100644 index b3d07729c7e00..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.cc +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/get_hwy_mode.h" - -// clang-format off -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "dpf/internal/get_hwy_mode.cc" -#include "absl/strings/string_view.h" -#include "hwy/foreach_target.h" -// clang-format on - -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace distributed_point_functions { -namespace dpf_internal { -namespace HWY_NAMESPACE { - -const absl::string_view GetHwyModeAsString() { - return hwy::TargetName(HWY_TARGET); -} - -} // namespace HWY_NAMESPACE - -#if HWY_ONCE || HWY_IDE - -HWY_EXPORT(GetHwyModeAsString); -const absl::string_view GetHwyModeAsString() { - return HWY_DYNAMIC_DISPATCH(GetHwyModeAsString)(); -} - -#endif - -} // namespace dpf_internal -} // namespace distributed_point_functions -HWY_AFTER_NAMESPACE(); diff --git a/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.h b/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.h deleted file mode 100644 index a123850924910..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/get_hwy_mode.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_ - -#include "absl/strings/string_view.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -// Utility function for printing the mode selected by Highway. Used for -// debugging. -const absl::string_view GetHwyModeAsString(); - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span.h b/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span.h deleted file mode 100644 index 27b5d74d9c30f..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2023 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_ - -// A class that can serve the purpose of both absl::Span<T> and absl::Span<T*> -// at the same time. Introduces the run-time overhead of a std::variant check. -// -// Note that this class DOES NOT provide common container features, such as -// iterators. It is not intended to be used by users of this library. Any -// function that takes a MaybeDerefSpan<T> should be called with either an -// absl::Span<T> or an absl::Span<T*>. - -#include <type_traits> - -#include "absl/meta/type_traits.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -template <typename T> -class MaybeDerefSpan { - private: - template <typename U> - using EnableIfValueIsConst = - typename absl::enable_if_t<std::is_const<T>::value, U>; - - template <typename U> - using EnableIfValueIsConvertibleToSpan = typename absl::enable_if_t< - absl::disjunction<std::is_convertible<U, absl::Span<T>>, - std::is_convertible<U, absl::Span<T* const>>>::value, - U>; - - public: - // Implicit constructors from the underlying absl::Span. - MaybeDerefSpan(absl::Span<T> span) - : span_(span) {} // NOLINT(runtime/explicit) - MaybeDerefSpan(absl::Span<T* const> span) - : span_(span) {} // NOLINT(runtime/explicit) - - // Implicit constructor of a const MaybeDerefSpan from a non-const one. - template <typename T2 = T, typename = EnableIfValueIsConst<T2>> - MaybeDerefSpan( - const MaybeDerefSpan<typename std::remove_const<T>::type>& other) - : span_(absl::ConvertVariantTo<decltype(span_)>(other.span_)) { - } // NOLINT(runtime/explicit) - - // Implicit constructor of a const MaybeDerefSpan from anything that is - // convertible to one of the underlying spans. - template <typename V, typename = EnableIfValueIsConst<V>, - typename = EnableIfValueIsConvertibleToSpan<V>> - MaybeDerefSpan(const V& span) - : span_(absl::MakeConstSpan(span)) {} // NOLINT(runtime/explicit) - - inline constexpr T& operator[](size_t index) const { - if (absl::holds_alternative<absl::Span<T* const>>(span_)) { - return *absl::get<absl::Span<T* const>>(span_)[index]; - } - return absl::get<absl::Span<T>>(span_)[index]; - } - - inline constexpr size_t size() const { - return absl::visit([](auto v) { return v.size(); }, span_); - } - - private: - template <typename U> - friend class MaybeDerefSpan; - - absl::variant<absl::Span<T>, absl::Span<T* const>> span_; -}; - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span_test.cc b/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span_test.cc deleted file mode 100644 index cbcbfcc9f7554..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/maybe_deref_span_test.cc +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/maybe_deref_span.h" - -#include <type_traits> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace dpf_internal { -namespace { - -using T = int; - -TEST(MaybeDerefSpanTest, TestExplicitMutableDirectSpan) { - std::vector<T> x = {1, 2}; - absl::Span<T> span(x); - MaybeDerefSpan<T> span2(span); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); - - span2[0] = 3; - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(x[0], 3); -} - -TEST(MaybeDerefSpanTest, TestExplicitMutableSpan) { - const std::vector<T> x = {1, 2}; - absl::Span<const T> span(x); - MaybeDerefSpan<const T> span2(span); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); -} - -TEST(MaybeDerefSpanTest, TestExplicitMutablePointerSpan) { - std::vector<T> x = {1, 2}; - std::vector<T*> x2 = {&x[0], &x[1]}; - absl::Span<T*> span(x2); - MaybeDerefSpan<T> span2(span); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); - - span2[0] = 3; - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(x[0], 3); -} - -TEST(MaybeDerefSpanTest, TestExplicitMutablePointerConstSpan) { - std::vector<T> x = {1, 2}; - const std::vector<T*> x2 = {&x[0], &x[1]}; - absl::Span<T* const> span(x2); - MaybeDerefSpan<T> span2(span); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); -} - -TEST(MaybeDerefSpanTest, TestExplicitConstPointerConstSpan) { - const std::vector<T> x = {1, 2}; - const std::vector<const T*> x2 = {&x[0], &x[1]}; - absl::Span<const T* const> span(x2); - MaybeDerefSpan<const T> span2(span); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); -} - -TEST(MaybeDerefSpanTest, TestMutableSpanToConstSpan) { - std::vector<T> x = {1, 2}; - absl::Span<T> span(x); - MaybeDerefSpan<T> span2(span); - MaybeDerefSpan<const T> span3(span2); - - EXPECT_EQ(span3.size(), x.size()); - EXPECT_EQ(span3[0], x[0]); - EXPECT_EQ(span3[1], x[1]); - EXPECT_EQ(&span3[0], &x[0]); - EXPECT_EQ(&span3[1], &x[1]); -} - -TEST(MaybeDerefSpanTest, TestImplicitConstSpan) { - const std::vector<T> x = {1, 2}; - MaybeDerefSpan<const T> span2(x); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); -} - -TEST(MaybeDerefSpanTest, TestImplicitPointerConstSpan) { - const std::vector<T> x = {1, 2}; - const std::vector<const T*> x2 = {&x[0], &x[1]}; - MaybeDerefSpan<const T> span2(x2); - - EXPECT_EQ(span2.size(), x.size()); - EXPECT_EQ(span2[0], x[0]); - EXPECT_EQ(span2[1], x[1]); - EXPECT_EQ(&span2[0], &x[0]); - EXPECT_EQ(&span2[1], &x[1]); -} - -void TestEq(MaybeDerefSpan<const T> span, const std::vector<T>& vector) { - EXPECT_EQ(span.size(), vector.size()); - EXPECT_EQ(span[0], vector[0]); - EXPECT_EQ(span[1], vector[1]); - EXPECT_EQ(&span[0], &vector[0]); - EXPECT_EQ(&span[1], &vector[1]); -} - -TEST(MaybeDerefSpanTest, TestFunctionCallMutableVector) { - std::vector<T> x = {1, 2}; - - TestEq(x, x); -} - -TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerVector) { - std::vector<T> x = {1, 2}; - std::vector<T*> x2 = {&x[0], &x[1]}; - - TestEq(x2, x); -} - -TEST(MaybeDerefSpanTest, TestFunctionCallConstVector) { - const std::vector<T> x = {1, 2}; - - TestEq(x, x); -} - -TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerConstVector) { - std::vector<T> x = {1, 2}; - const std::vector<T*> x2 = {&x[0], &x[1]}; - - TestEq(x2, x); -} - -TEST(MaybeDerefSpanTest, TestFunctionCallConstPointerConstVector) { - const std::vector<T> x = {1, 2}; - const std::vector<const T*> x2 = {&x[0], &x[1]}; - - TestEq(x2, x); -} - -// Taken from https://en.cppreference.com/w/cpp/types/is_convertible. -template <class From, class To> -auto test_implicitly_convertible(int) - -> decltype(void(std::declval<void (&)(To)>()(std::declval<From>())), - std::true_type{}); -template <class, class> -auto test_implicitly_convertible(...) -> std::false_type; - -// Test that vectors are convertible only to const spans. -static_assert( - decltype(test_implicitly_convertible<std::vector<T>, MaybeDerefSpan<T>>( - 0))::value == false); -static_assert(decltype(test_implicitly_convertible< - std::vector<T>, MaybeDerefSpan<const T>>(0))::value == - true); - -} // namespace -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/proto_validator.cc b/third_party/distributed_point_functions/code/dpf/internal/proto_validator.cc deleted file mode 100644 index 678b3de9770b0..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/proto_validator.cc +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/proto_validator.h" - -#include <algorithm> -#include <cmath> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "absl/container/flat_hash_map.h" -#include "absl/log/absl_check.h" -#include "absl/memory/memory.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/types/span.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/internal/value_type_helpers.h" -#include "dpf/status_macros.h" -#include "google/protobuf/repeated_field.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -namespace { - -inline double GetDefaultSecurityParameter(const DpfParameters& parameters) { - return ProtoValidator::kDefaultSecurityParameter + - parameters.log_domain_size(); -} - -inline bool AlmostEqual(double a, double b) { - return std::abs(a - b) <= ProtoValidator::kSecurityParameterEpsilon; -} - -absl::StatusOr<bool> ParametersAreEqual(const DpfParameters& lhs, - const DpfParameters& rhs) { - if (lhs.log_domain_size() != rhs.log_domain_size()) { - return false; - } - if (!( - // There are three ways that security parameters can be equivalent. - // Both equal. - AlmostEqual(lhs.security_parameter(), rhs.security_parameter()) || - // lhs is zero and rhs has the default value. - (lhs.security_parameter() == 0 && - AlmostEqual(rhs.security_parameter(), - GetDefaultSecurityParameter(rhs))) || - // rhs is zero and lhs has the default value. - (rhs.security_parameter() == 0 && - AlmostEqual(lhs.security_parameter(), - GetDefaultSecurityParameter(lhs))))) { - return false; - } - return ValueTypesAreEqual(lhs.value_type(), rhs.value_type()); -} - -absl::Status ValidateIntegerType(const ValueType::Integer& type) { - int bitsize = type.bitsize(); - if (bitsize < 1) { - return absl::InvalidArgumentError("`bitsize` must be positive"); - } - if (bitsize > 128) { - return absl::InvalidArgumentError( - "`bitsize` must be less than or equal to 128"); - } - if ((bitsize & (bitsize - 1)) != 0) { - return absl::InvalidArgumentError("`bitsize` must be a power of 2"); - } - return absl::OkStatus(); -} - -absl::Status ValidateIntegerValue(const Value::Integer& value, - const ValueType::Integer& type) { - if (type.bitsize() < 128) { - DPF_ASSIGN_OR_RETURN(absl::uint128 value_128, ValueIntegerToUint128(value)); - if (value_128 >= absl::uint128{1} << type.bitsize()) { - return absl::InvalidArgumentError(absl::StrFormat( - "Value (= %d) too large for ValueType with bitsize = %d", value_128, - type.bitsize())); - } - } - return absl::OkStatus(); -} - -} // namespace - -ProtoValidator::ProtoValidator(std::vector<DpfParameters> parameters, - int tree_levels_needed, - absl::flat_hash_map<int, int> tree_to_hierarchy, - std::vector<int> hierarchy_to_tree) - : parameters_(std::move(parameters)), - tree_levels_needed_(tree_levels_needed), - tree_to_hierarchy_(std::move(tree_to_hierarchy)), - hierarchy_to_tree_(std::move(hierarchy_to_tree)) {} - -absl::StatusOr<std::unique_ptr<ProtoValidator>> ProtoValidator::Create( - absl::Span<const DpfParameters> parameters_in) { - DPF_RETURN_IF_ERROR(ValidateParameters(parameters_in)); - - // Set default values of security_parameter for all parameters. - std::vector<DpfParameters> parameters(parameters_in.begin(), - parameters_in.end()); - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - if (parameters[i].security_parameter() == 0) { - parameters[i].set_security_parameter( - GetDefaultSecurityParameter(parameters[i])); - } - } - - // Map hierarchy levels to levels in the evaluation tree for value correction, - // and vice versa. - absl::flat_hash_map<int, int> tree_to_hierarchy; - std::vector<int> hierarchy_to_tree(parameters.size()); - // Also keep track of the height needed for the evaluation tree so far. - int tree_levels_needed = 0; - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - int log_bits_needed; - DPF_ASSIGN_OR_RETURN(int bits_needed, - BitsNeeded(parameters[i].value_type(), - parameters[i].security_parameter())); - log_bits_needed = static_cast<int>(std::ceil(std::log2(bits_needed))); - - // The tree level depends on the domain size and the element size. A single - // AES block can fit 128 = 2^7 bits, so usually tree_level == - // log_domain_size iff log_element_size >= 7. For smaller element sizes, we - // can reduce the tree_level (and thus the height of the tree) by the - // difference between log_element_size and 7. However, since the minimum - // tree level is 0, we have to ensure that no two hierarchy levels map to - // the same tree_level, hence the std::max. - int tree_level = - std::max(tree_levels_needed, parameters[i].log_domain_size() - 7 + - std::min(log_bits_needed, 7)); - tree_to_hierarchy[tree_level] = i; - hierarchy_to_tree[i] = tree_level; - tree_levels_needed = std::max(tree_levels_needed, tree_level + 1); - } - - return absl::WrapUnique(new ProtoValidator( - std::move(parameters), tree_levels_needed, std::move(tree_to_hierarchy), - std::move(hierarchy_to_tree))); -} - -absl::Status ProtoValidator::ValidateParameters( - absl::Span<const DpfParameters> parameters) { - // Check that parameters are valid. - if (parameters.empty()) { - return absl::InvalidArgumentError("`parameters` must not be empty"); - } - // Sentinel value for checking that domain sizes are increasing. - int previous_log_domain_size = 0; - for (int i = 0; i < static_cast<int>(parameters.size()); ++i) { - // Check log_domain_size. - int log_domain_size = parameters[i].log_domain_size(); - if (log_domain_size < 0) { - return absl::InvalidArgumentError( - "`log_domain_size` must be non-negative"); - } - if (log_domain_size > 128) { - return absl::InvalidArgumentError("`log_domain_size` must be <= 128"); - } - if (i > 0 && log_domain_size <= previous_log_domain_size) { - return absl::InvalidArgumentError( - "`log_domain_size` fields must be in ascending order in " - "`parameters`"); - } - previous_log_domain_size = log_domain_size; - - if (parameters[i].has_value_type()) { - DPF_RETURN_IF_ERROR(ValidateValueType(parameters[i].value_type())); - } else { - return absl::InvalidArgumentError("`value_type` is required"); - } - - if (std::isnan(parameters[i].security_parameter())) { - return absl::InvalidArgumentError("`security_parameter` must not be NaN"); - } - if (parameters[i].security_parameter() < 0 || - parameters[i].security_parameter() > 128) { - // Since we use AES-128 for the PRG, a security parameter of > 128 is not - // possible. - return absl::InvalidArgumentError( - "`security_parameter` must be in [0, 128]"); - } - } - return absl::OkStatus(); -} - -absl::Status ProtoValidator::ValidateDpfKey(const DpfKey& key) const { - // Check that `key` has the seed and last_level_output_correction set. - if (!key.has_seed()) { - return absl::InvalidArgumentError("key.seed must be present"); - } - if (key.last_level_value_correction().empty()) { - return absl::InvalidArgumentError( - "key.last_level_value_correction must be present"); - } - // Check that `key` is valid for the DPF defined by `parameters_`. - if (key.correction_words_size() != tree_levels_needed_ - 1) { - return absl::InvalidArgumentError(absl::StrCat( - "Malformed DpfKey: expected ", tree_levels_needed_ - 1, - " correction words, but got ", key.correction_words_size())); - } - for (int i = 0; i < static_cast<int>(hierarchy_to_tree_.size()); ++i) { - if (hierarchy_to_tree_[i] == tree_levels_needed_ - 1) { - // The output correction of the last tree level is always stored in - // last_level_output_correction. - continue; - } - ABSL_DCHECK(hierarchy_to_tree_[i] < key.correction_words_size()); - if (key.correction_words(hierarchy_to_tree_[i]) - .value_correction() - .empty()) { - return absl::InvalidArgumentError(absl::StrCat( - "Malformed DpfKey: expected correction_words[", hierarchy_to_tree_[i], - "] to contain the value correction of hierarchy level ", i)); - } - } - return absl::OkStatus(); -} - -absl::Status ProtoValidator::ValidateEvaluationContext( - const EvaluationContext& ctx) const { - if (ctx.parameters_size() != static_cast<int>(parameters_.size())) { - return absl::InvalidArgumentError( - "Number of parameters in `ctx` doesn't match"); - } - for (int i = 0; i < ctx.parameters_size(); ++i) { - DPF_ASSIGN_OR_RETURN(bool parameters_are_equal, - ParametersAreEqual(parameters_[i], ctx.parameters(i))); - if (!parameters_are_equal) { - return absl::InvalidArgumentError( - absl::StrCat("Parameter ", i, " in `ctx` doesn't match")); - } - } - if (!ctx.has_key()) { - return absl::InvalidArgumentError("ctx.key must be present"); - } - DPF_RETURN_IF_ERROR(ValidateDpfKey(ctx.key())); - if (ctx.previous_hierarchy_level() >= ctx.parameters_size() - 1) { - return absl::InvalidArgumentError( - "This context has already been fully evaluated"); - } - if (!ctx.partial_evaluations().empty() && - ctx.partial_evaluations_level() > ctx.previous_hierarchy_level()) { - return absl::InvalidArgumentError( - "ctx.partial_evaluations_level must be less than or equal to " - "ctx.previous_hierarchy_level"); - } - return absl::OkStatus(); -} - -absl::Status ProtoValidator::ValidateValueType(const ValueType& value_type) { - if (value_type.type_case() == ValueType::kInteger) { - return ValidateIntegerType(value_type.integer()); - } else if (value_type.type_case() == ValueType::kTuple) { - for (const ValueType& el : value_type.tuple().elements()) { - DPF_RETURN_IF_ERROR(ValidateValueType(el)); - } - return absl::OkStatus(); - } else if (value_type.type_case() == ValueType::kIntModN) { - const ValueType::Integer& base_integer = - value_type.int_mod_n().base_integer(); - DPF_RETURN_IF_ERROR(ValidateIntegerType(base_integer)); - return ValidateIntegerValue(value_type.int_mod_n().modulus(), base_integer); - } else if (value_type.type_case() == ValueType::kXorWrapper) { - return ValidateIntegerType(value_type.xor_wrapper()); - } - return absl::InvalidArgumentError(absl::StrCat( - "ValidateValueType: Unsupported ValueType:\n", value_type.DebugString())); -} - -absl::Status ProtoValidator::ValidateValue(const Value& value, - const ValueType& type) { - if (type.type_case() == ValueType::kInteger) { - // Integers. - if (value.value_case() != Value::kInteger) { - return absl::InvalidArgumentError("Expected integer value"); - } - return ValidateIntegerValue(value.integer(), type.integer()); - } else if (type.type_case() == ValueType::kTuple) { - // Tuples. - if (value.value_case() != Value::kTuple) { - return absl::InvalidArgumentError("Expected tuple value"); - } - if (value.tuple().elements_size() != type.tuple().elements_size()) { - return absl::InvalidArgumentError(absl::StrCat( - "Expected tuple value of size ", type.tuple().elements_size(), - " but got size ", value.tuple().elements_size())); - } - for (int i = 0; i < type.tuple().elements_size(); ++i) { - DPF_RETURN_IF_ERROR( - ValidateValue(value.tuple().elements(i), type.tuple().elements(i))); - } - return absl::OkStatus(); - } else if (type.type_case() == ValueType::kIntModN) { - DPF_RETURN_IF_ERROR(ValidateIntegerValue(value.int_mod_n(), - type.int_mod_n().base_integer())); - DPF_ASSIGN_OR_RETURN(absl::uint128 value_128, - ValueIntegerToUint128(value.int_mod_n())); - DPF_ASSIGN_OR_RETURN(absl::uint128 modulus_128, - ValueIntegerToUint128(type.int_mod_n().modulus())); - if (value_128 >= modulus_128) { - return absl::InvalidArgumentError( - absl::StrFormat("Value (= %d) is too large for modulus (= %d)", - value_128, modulus_128)); - } - return absl::OkStatus(); - } else if (type.type_case() == ValueType::kXorWrapper) { - if (value.value_case() != Value::kXorWrapper) { - return absl::InvalidArgumentError("Expected XorWrapper value"); - } - return ValidateIntegerValue(value.xor_wrapper(), type.xor_wrapper()); - } - return absl::InvalidArgumentError(absl::StrCat( - "ValidateValue: Unsupported ValueType:\n", type.DebugString())); -} - -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/proto_validator.h b/third_party/distributed_point_functions/code/dpf/internal/proto_validator.h deleted file mode 100644 index e6e63dd8ee9c2..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/proto_validator.h +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_ - -#include <memory> -#include <vector> - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "dpf/distributed_point_function.pb.h" - -namespace distributed_point_functions { -namespace dpf_internal { -// ProtoValidator is used to validate protos for DPF parameters, keys, and -// evaluation contexts. Also holds information computed from the DPF parameters, -// such as the mappings between hierarchy and tree levels. -class ProtoValidator { - public: - // The negative logarithm of the total variation distance from uniform that a - // *full* evaluation of a hierarchy level is allowed to have. Used as the - // default value for DpfParameters that don't have an explicit per-element - // security parameter set. - static constexpr double kDefaultSecurityParameter = 40; - - // Security parameters that differ by less than this are considered equal. - static constexpr double kSecurityParameterEpsilon = 0.0001; - - // Checks the validity of `parameters` and returns a ProtoValidator, which - // will be used to validate DPF keys and evaluation contexts afterwards. - // - // Returns INVALID_ARGUMENT if `parameters` are invalid. - static absl::StatusOr<std::unique_ptr<ProtoValidator>> Create( - absl::Span<const DpfParameters> parameters); - - // Checks the validity of `parameters`. - // Returns OK on success, and INVALID_ARGUMENT otherwise. - static absl::Status ValidateParameters( - absl::Span<const DpfParameters> parameters); - - // Checks that `key` is valid for the `parameters` passed at construction. - // Returns OK on success, and INVALID_ARGUMENT otherwise. - absl::Status ValidateDpfKey(const DpfKey& key) const; - - // Checks that `ctx` is valid for the `parameters` passed at construction. - // Returns OK on success, and INVALID_ARGUMENT otherwise. - absl::Status ValidateEvaluationContext(const EvaluationContext& ctx) const; - - // Checks that the given ValueType is valid. - // Returns OK on success and INVALID_ARGUMENT otherwise. - static absl::Status ValidateValueType(const ValueType& value_type); - - // Checks that `value` is valid for `type`. - // Returns OK on success and INVALID_ARGUMENT otherwise. - static absl::Status ValidateValue(const Value& value, const ValueType& type); - - // Checks that `value` is valid for `parameters[i]` passed at construction. - // Returns OK on success and INVALID_ARGUMENT otherwise. - inline absl::Status ValidateValue(const Value& value, int i) const { - return ValidateValue(value, parameters_[i].value_type()); - } - - // ProtoValidator is not copyable. - ProtoValidator(const ProtoValidator&) = delete; - ProtoValidator& operator=(const ProtoValidator&) = delete; - - // ProtoValidator is movable. - ProtoValidator(ProtoValidator&&) = default; - ProtoValidator& operator=(ProtoValidator&&) = default; - - // Getters. - absl::Span<const DpfParameters> parameters() const { return parameters_; } - int tree_levels_needed() const { return tree_levels_needed_; } - const absl::flat_hash_map<int, int>& tree_to_hierarchy() const { - return tree_to_hierarchy_; - } - const std::vector<int>& hierarchy_to_tree() const { - return hierarchy_to_tree_; - } - - private: - ProtoValidator(std::vector<DpfParameters> parameters, int tree_levels_needed, - absl::flat_hash_map<int, int> tree_to_hierarchy, - std::vector<int> hierarchy_to_tree); - - // The DpfParameters passed at construction. - std::vector<DpfParameters> parameters_; - - // Number of levels in the evaluation tree. This is always less than or equal - // to the largest log_domain_size in parameters_. - int tree_levels_needed_; - - // Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements - // of parameters_). - absl::flat_hash_map<int, int> tree_to_hierarchy_; - - // The inverse of tree_to_hierarchy_. - std::vector<int> hierarchy_to_tree_; -}; - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.cc b/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.cc deleted file mode 100644 index 62e9d7090b11c..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.cc +++ /dev/null @@ -1,417 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/proto_validator.h" - -#include <stdint.h> - -#include <cmath> -#include <memory> -#include <string> -#include <vector> - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/internal/proto_validator_test_textproto_embed.h" -#include "dpf/internal/status_matchers.h" -#include "dpf/tuple.h" -#include "gmock/gmock.h" -#include "google/protobuf/repeated_field.h" -#include "google/protobuf/text_format.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace dpf_internal { -namespace { - -using ::testing::Ne; -using ::testing::StartsWith; - -class ProtoValidatorTest : public testing::Test { - protected: - void SetUp() override { - const auto* const toc = proto_validator_test_textproto_embed_create(); - ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( - std::string(toc->data, toc->size), &ctx_)); - parameters_ = std::vector<DpfParameters>(ctx_.parameters().begin(), - ctx_.parameters().end()); - dpf_key_ = ctx_.key(); - DPF_ASSERT_OK_AND_ASSIGN(proto_validator_, - ProtoValidator::Create(parameters_)); - } - - std::vector<DpfParameters> parameters_; - DpfKey dpf_key_; - EvaluationContext ctx_; - std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_; -}; - -TEST_F(ProtoValidatorTest, CreateFailsWithoutParameters) { - EXPECT_THAT(ProtoValidator::Create({}), - StatusIs(absl::StatusCode::kInvalidArgument, - "`parameters` must not be empty")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenParametersNotSorted) { - parameters_.resize(2); - parameters_[0].set_log_domain_size(10); - parameters_[1].set_log_domain_size(8); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`log_domain_size` fields must be in ascending order in " - "`parameters`")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeNegative) { - parameters_.resize(1); - parameters_[0].set_log_domain_size(-1); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`log_domain_size` must be non-negative")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeTooLarge) { - parameters_.resize(1); - parameters_[0].set_log_domain_size(129); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`log_domain_size` must be <= 128")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNegative) { - parameters_.resize(1); - parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(-1); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be positive")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeZero) { - parameters_.resize(1); - parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(0); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be positive")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeTooLarge) { - parameters_.resize(1); - parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(256); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be less than or equal to 128")); -} - -TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNotAPowerOfTwo) { - parameters_.resize(1); - parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(23); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be a power of 2")); -} - -TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNaN) { - parameters_.resize(1); - parameters_[0].set_security_parameter(std::nan("")); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`security_parameter` must not be NaN")); -} - -TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNegative) { - parameters_.resize(1); - parameters_[0].set_security_parameter(-0.01); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`security_parameter` must be in [0, 128]")); -} - -TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsTooLarge) { - parameters_.resize(1); - parameters_[0].set_security_parameter(128.01); - - EXPECT_THAT(ProtoValidator::Create(parameters_), - StatusIs(absl::StatusCode::kInvalidArgument, - "`security_parameter` must be in [0, 128]")); -} - -TEST_F(ProtoValidatorTest, CreateWorksWhenElementBitsizesDecrease) { - parameters_.resize(2); - parameters_[0].mutable_value_type()->mutable_integer()->set_bitsize(64); - parameters_[1].mutable_value_type()->mutable_integer()->set_bitsize(32); - - EXPECT_THAT(ProtoValidator::Create(parameters_), IsOkAndHolds(Ne(nullptr))); -} - -TEST_F(ProtoValidatorTest, CreateWorksWhenHierarchiesAreFarApart) { - parameters_.resize(2); - parameters_[0].set_log_domain_size(10); - parameters_[1].set_log_domain_size(128); - - EXPECT_THAT(ProtoValidator::Create(parameters_), IsOkAndHolds(Ne(nullptr))); -} - -TEST_F(ProtoValidatorTest, - ValidateDpfKeyFailsIfNumberOfCorrectionWordsDoesntMatch) { - dpf_key_.add_correction_words(); - - EXPECT_THAT(proto_validator_->ValidateDpfKey(dpf_key_), - StatusIs(absl::StatusCode::kInvalidArgument, - absl::StrCat("Malformed DpfKey: expected ", - dpf_key_.correction_words_size() - 1, - " correction words, but got ", - dpf_key_.correction_words_size()))); -} - -TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfSeedIsMissing) { - dpf_key_.clear_seed(); - - EXPECT_THAT( - proto_validator_->ValidateDpfKey(dpf_key_), - StatusIs(absl::StatusCode::kInvalidArgument, "key.seed must be present")); -} - -TEST_F(ProtoValidatorTest, - ValidateDpfKeyFailsIfLastLevelOutputCorrectionIsMissing) { - dpf_key_.clear_last_level_value_correction(); - - EXPECT_THAT(proto_validator_->ValidateDpfKey(dpf_key_), - StatusIs(absl::StatusCode::kInvalidArgument, - "key.last_level_value_correction must be present")); -} - -TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfOutputCorrectionIsMissing) { - for (CorrectionWord& cw : *(dpf_key_.mutable_correction_words())) { - cw.clear_value_correction(); - } - - EXPECT_THAT( - proto_validator_->ValidateDpfKey(dpf_key_), - StatusIs(absl::StatusCode::kInvalidArgument, - StartsWith("Malformed DpfKey: expected correction_words"))); -} - -TEST_F(ProtoValidatorTest, ValidateEvaluationContextFailsIfKeyIsMissing) { - ctx_.clear_key(); - - EXPECT_THAT( - proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, "ctx.key must be present")); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextFailsIfParameterSizeDoesntMatch) { - ctx_.mutable_parameters()->erase(ctx_.parameters().end() - 1); - - EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, - "Number of parameters in `ctx` doesn't match")); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextFailsIfLogDomainSizeDoesntMatch) { - ctx_.mutable_parameters(0)->set_log_domain_size( - ctx_.parameters(0).log_domain_size() + 1); - - EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, - "Parameter 0 in `ctx` doesn't match")); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextSucceedsIfSecurityParameterIsDefault) { - parameters_[0].set_security_parameter(0); - DPF_ASSERT_OK_AND_ASSIGN(proto_validator_, - ProtoValidator::Create(parameters_)); - - ctx_.mutable_parameters(0)->set_security_parameter(0); - - EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), IsOk()); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextFailsIfSecurityParameterDoesntMatch) { - ctx_.mutable_parameters(0)->set_security_parameter( - ctx_.parameters(0).security_parameter() + 1); - - EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, - "Parameter 0 in `ctx` doesn't match")); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextFailsIfContextFullyEvaluated) { - ctx_.set_previous_hierarchy_level(parameters_.size() - 1); - - EXPECT_THAT(proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, - "This context has already been fully evaluated")); -} - -TEST_F(ProtoValidatorTest, - ValidateEvaluationContextFailsIfPartialEvaluationsLevelTooLarge) { - ctx_.set_previous_hierarchy_level(0); - ctx_.set_partial_evaluations_level(1); - ctx_.add_partial_evaluations(); - - EXPECT_THAT( - proto_validator_->ValidateEvaluationContext(ctx_), - StatusIs(absl::StatusCode::kInvalidArgument, - "ctx.partial_evaluations_level must be less than or equal to " - "ctx.previous_hierarchy_level")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotInteger) { - ValueType type; - type.mutable_integer()->set_bitsize(32); - Value value; - value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64( - 23); - - EXPECT_THAT( - proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, "Expected integer value")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfIntegerTooLarge) { - ValueType type; - Value value; - - int element_bitsize = 32; - type.mutable_integer()->set_bitsize(element_bitsize); - auto value_64 = uint64_t{1} << element_bitsize; - value.mutable_integer()->set_value_uint64(value_64); - - EXPECT_THAT( - proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Value (= %d) too large for ValueType with bitsize = %d", - value_64, element_bitsize))); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotTuple) { - ValueType type; - type.mutable_tuple()->add_elements()->mutable_integer()->set_bitsize(32); - Value value; - value.mutable_integer()->set_value_uint64(23); - - EXPECT_THAT( - proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, "Expected tuple value")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfTupleSizeDoesntMatch) { - ValueType type; - type.mutable_tuple()->add_elements()->mutable_integer()->set_bitsize(32); - Value value; - - value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64( - 23); - value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64( - 42); - - EXPECT_THAT(proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, - "Expected tuple value of size 1 but got size 2")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueLargerThanModulus) { - constexpr uint64_t kModulus = 3; - ValueType type; - type.mutable_int_mod_n()->mutable_base_integer()->set_bitsize(64); - type.mutable_int_mod_n()->mutable_modulus()->set_value_uint64(kModulus); - Value value; - - value.mutable_int_mod_n()->set_value_uint64(kModulus); - - EXPECT_THAT(proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, - "Value (= 3) is too large for modulus (= 3)")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotXorWrapper) { - ValueType type; - type.mutable_xor_wrapper()->set_bitsize(32); - Value value; - value.mutable_integer()->set_value_uint64(23); - - EXPECT_THAT(proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, - "Expected XorWrapper value")); -} - -TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueIsUnknown) { - ValueType type; - Value value; - - EXPECT_THAT( - proto_validator_->ValidateValue(value, type), - StatusIs(absl::StatusCode::kInvalidArgument, - testing::StartsWith("ValidateValue: Unsupported ValueType:"))); -} - -TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPositive) { - ValueType type; - - type.mutable_integer()->set_bitsize(0); - - EXPECT_THAT(ProtoValidator::ValidateValueType(type), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be positive")); -} - -TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeTooLarge) { - ValueType type; - - type.mutable_integer()->set_bitsize(256); - - EXPECT_THAT(ProtoValidator::ValidateValueType(type), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be less than or equal to 128")); -} - -TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPowerOfTwo) { - ValueType type; - - type.mutable_integer()->set_bitsize(17); - - EXPECT_THAT(ProtoValidator::ValidateValueType(type), - StatusIs(absl::StatusCode::kInvalidArgument, - "`bitsize` must be a power of 2")); -} - -TEST(ProtoValidator, ValidateValueTypeFailsIfNoTypeChosen) { - ValueType type; - - EXPECT_THAT(ProtoValidator::ValidateValueType(type), - StatusIs(absl::StatusCode::kInvalidArgument, - StartsWith("ValidateValueType: Unsupported ValueType"))); -} - -} // namespace -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.textproto b/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.textproto deleted file mode 100644 index 5e411d7eb7fbb..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/proto_validator_test.textproto +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# proto-file dpf/distributed_point_function.proto -# proto-message: EvaluationContext - -parameters { - log_domain_size: 4 - value_type { - integer { - bitsize: 32 - } - } - security_parameter: 44 -} -parameters { - log_domain_size: 6 - value_type { - integer { - bitsize: 32 - } - } - security_parameter: 46 -} -parameters { - log_domain_size: 8 - value_type { - integer { - bitsize: 32 - } - } - security_parameter: 48 -} -key { - seed { - high: 11559904407150645412 - low: 10793182457266619527 - } - correction_words { - seed { - high: 17231204231811741091 - low: 13184625655696690000 - } - control_left: true - } - correction_words { - seed { - high: 3072212389250066354 - low: 1361245143349174348 - } - } - correction_words { - seed { - high: 2882988684359810666 - low: 16992210518729579018 - } - control_right: true - value_correction: { - integer: { - value_uint64: 536412310 - } - } - } - correction_words { - seed { - high: 4993590839844520517 - low: 13033365507284852634 - } - control_right: true - } - correction_words { - seed { - high: 10673753674550143002 - low: 3019916643383017704 - } - control_left: true - control_right: true - value_correction: { - integer: { - value_uint64: 841224518 - } - } - } - correction_words { - seed { - high: 2423099213299230757 - low: 12788496417753523946 - } - control_right: true - } - last_level_value_correction: { - integer: { - value_uint64: 8471844854 - } - } -} -previous_hierarchy_level: -1 diff --git a/third_party/distributed_point_functions/code/dpf/internal/status_matchers.cc b/third_party/distributed_point_functions/code/dpf/internal/status_matchers.cc deleted file mode 100644 index 62317bc89817a..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/status_matchers.cc +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/status_matchers.h" - -#include <ostream> -#include <string> - -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -void StatusIsMatcherCommonImpl::DescribeTo(std::ostream* os) const { - *os << "has a status code that "; - code_matcher_.DescribeTo(os); - *os << ", and has an error message that "; - message_matcher_.DescribeTo(os); -} - -void StatusIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const { - *os << "has a status code that "; - code_matcher_.DescribeNegationTo(os); - *os << ", or has an error message that "; - message_matcher_.DescribeNegationTo(os); -} - -bool StatusIsMatcherCommonImpl::MatchAndExplain( - const ::absl::Status& status, - ::testing::MatchResultListener* result_listener) const { - ::testing::StringMatchResultListener inner_listener; - if (!code_matcher_.MatchAndExplain(status.code(), &inner_listener)) { - *result_listener << (inner_listener.str().empty() - ? "whose status code is wrong" - : "which has a status code " + - inner_listener.str()); - return false; - } - - if (!message_matcher_.Matches(std::string(status.message()))) { - *result_listener << "whose error message is wrong: " << status.message(); - return false; - } - - return true; -} - -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/status_matchers.h b/third_party/distributed_point_functions/code/dpf/internal/status_matchers.h deleted file mode 100644 index fa15a0243489b..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/status_matchers.h +++ /dev/null @@ -1,390 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Testing utilities for working with absl::Status and absl::StatusOr. -// -// Defines the following utilities: -// -// ================= -// DPF_EXPECT_OK(s) -// -// DPF_ASSERT_OK(s) -// ================= -// Convenience macros for `EXPECT_THAT(s, IsOk())`, where `s` is either -// a `Status` or a `StatusOr<T>`. -// -// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not -// provide much value (when they fail, they would just print the OK status -// which conveys no more information than `EXPECT_FALSE(s.ok())`. You can -// of course use `EXPECT_THAT(s, Not(IsOk()))` if you prefer _THAT style. -// -// If you want to check for particular errors, better alternatives are: -// EXPECT_THAT(s, StatusIs(expected_error)); -// EXPECT_THAT(s, StatusIs(_, HasSubstr("expected error"))); -// -// =============== -// IsOkAndHolds(m) -// =============== -// -// This gMock matcher matches a StatusOr<T> value whose status is OK -// and whose inner value matches matcher m. Example: -// -// using ::testing::MatchesRegex; -// using distributed_point_functions::IsOkAndHolds; -// ... -// absl::StatusOr<string> maybe_name = ...; -// EXPECT_THAT(maybe_name, IsOkAndHolds(MatchesRegex("John .*"))); -// -// =============================== -// StatusIs(status_code_matcher, -// error_message_matcher) -// =============================== -// -// This gMock matcher matches a Status or StatusOr<T> value if all of the -// following are true: -// -// - the status' error_code() matches status_code_matcher, and -// - the status' error_message() matches error_message_matcher. -// -// Example: -// -// enum FooErrorCode { -// ... -// kServerError -// }; -// -// using ::testing::HasSubstr; -// using ::testing::MatchesRegex; -// using ::testing::Ne; -// using ::testing::_; -// using distributed_point_functions::StatusIs; -// absl::StatusOr<string> GetName(int id); -// ... -// -// // The status code must be kServerError; the error message can be -// // anything. -// EXPECT_THAT(GetName(42), -// StatusIs(kServerError, _)); -// // The status code can be anything; the error message must match the -// // regex. -// EXPECT_THAT(GetName(43), -// StatusIs(_, MatchesRegex("server.*time-out"))); -// -// // The status code should not be kServerError; the error message can be -// // anything with "client" in it. -// EXPECT_CALL(mock_env, HandleStatus( -// StatusIs(Ne(kServerError), HasSubstr("client")))); -// -// =============================== -// StatusIs(status_code_matcher) -// =============================== -// -// This is a shorthand for -// StatusIs(status_code_matcher, -// testing::_) -// In other words, it's like the two-argument StatusIs(), except that it -// ignores error message. -// -// =============== -// IsOk() -// =============== -// -// Matches an absl::Status or absl::StatusOr<T> value whose status value is -// absl::StatusCode::kOk. Equivalent to 'StatusIs(absl::StatusCode::kOk)'. -// Example: -// using distributed_point_functions::IsOk; -// ... -// absl::StatusOr<string> maybe_name = ...; -// EXPECT_THAT(maybe_name, IsOk()); -// Status s = ...; -// EXPECT_THAT(s, IsOk()); -// - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_ - -#include <ostream> -#include <string> -#include <type_traits> -#include <utility> - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "dpf/status_macros.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -inline const absl::Status& GetStatus(const absl::Status& status) { - return status; -} - -template <typename T> -inline const absl::Status& GetStatus(const absl::StatusOr<T>& status) { - return status.status(); -} - -//////////////////////////////////////////////////////////// -// Implementation of IsOkAndHolds(). - -// Monomorphic implementation of matcher IsOkAndHolds(m). StatusOrType is a -// reference to StatusOr<T>. -template <typename StatusOrType> -class IsOkAndHoldsMatcherImpl - : public ::testing::MatcherInterface<StatusOrType> { - public: - typedef - typename std::remove_reference<StatusOrType>::type::value_type value_type; - - template <typename InnerMatcher> - explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) - : inner_matcher_(::testing::SafeMatcherCast<const value_type&>( - std::forward<InnerMatcher>(inner_matcher))) {} - - void DescribeTo(std::ostream* os) const override { - *os << "is OK and has a value that "; - inner_matcher_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - *os << "isn't OK or has a value that "; - inner_matcher_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - StatusOrType actual_value, - ::testing::MatchResultListener* result_listener) const override { - if (!actual_value.ok()) { - *result_listener << "which has status " << actual_value.status(); - return false; - } - - ::testing::StringMatchResultListener inner_listener; - const bool matches = - inner_matcher_.MatchAndExplain(*actual_value, &inner_listener); - const std::string inner_explanation = inner_listener.str(); - if (!inner_explanation.empty()) { - *result_listener << "which contains value " - << ::testing::PrintToString(*actual_value) << ", " - << inner_explanation; - } - return matches; - } - - private: - const ::testing::Matcher<const value_type&> inner_matcher_; -}; - -// Implements IsOkAndHolds(m) as a polymorphic matcher. -template <typename InnerMatcher> -class IsOkAndHoldsMatcher { - public: - explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher) - : inner_matcher_(std::move(inner_matcher)) {} - - // Converts this polymorphic matcher to a monomorphic matcher of the - // given type. StatusOrType can be either StatusOr<T> or a - // reference to StatusOr<T>. - template <typename StatusOrType> - operator ::testing::Matcher<StatusOrType>() const { // NOLINT - return ::testing::Matcher<StatusOrType>( - new IsOkAndHoldsMatcherImpl<const StatusOrType&>(inner_matcher_)); - } - - private: - const InnerMatcher inner_matcher_; -}; - -//////////////////////////////////////////////////////////// -// Implementation of StatusIs(). - -// StatusIs() is a polymorphic matcher. This class is the common -// implementation of it shared by all types T where StatusIs() can be -// used as a Matcher<T>. -class StatusIsMatcherCommonImpl { - public: - StatusIsMatcherCommonImpl( - ::testing::Matcher<absl::StatusCode> code_matcher, - ::testing::Matcher<const std::string&> message_matcher) - : code_matcher_(std::move(code_matcher)), - message_matcher_(std::move(message_matcher)) {} - - void DescribeTo(std::ostream* os) const; - - void DescribeNegationTo(std::ostream* os) const; - - bool MatchAndExplain(const absl::Status& status, - ::testing::MatchResultListener* result_listener) const; - - private: - const ::testing::Matcher<absl::StatusCode> code_matcher_; - const ::testing::Matcher<const std::string&> message_matcher_; -}; - -// Monomorphic implementation of matcher StatusIs() for a given type -// T. T can be Status, StatusOr<>, or a reference to either of them. -template <typename T> -class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface<T> { - public: - explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl) - : common_impl_(std::move(common_impl)) {} - - void DescribeTo(std::ostream* os) const override { - common_impl_.DescribeTo(os); - } - - void DescribeNegationTo(std::ostream* os) const override { - common_impl_.DescribeNegationTo(os); - } - - bool MatchAndExplain( - T actual_value, - ::testing::MatchResultListener* result_listener) const override { - return common_impl_.MatchAndExplain(GetStatus(actual_value), - result_listener); - } - - private: - StatusIsMatcherCommonImpl common_impl_; -}; - -// Implements StatusIs() as a polymorphic matcher. -class StatusIsMatcher { - public: - template <typename StatusCodeMatcher, typename StatusMessageMatcher> - StatusIsMatcher(StatusCodeMatcher&& code_matcher, - StatusMessageMatcher&& message_matcher) - : common_impl_(::testing::MatcherCast<absl::StatusCode>( - std::forward<StatusCodeMatcher>(code_matcher)), - ::testing::MatcherCast<const std::string&>( - std::forward<StatusMessageMatcher>(message_matcher))) { - } - - // Converts this polymorphic matcher to a monomorphic matcher of the - // given type. T can be StatusOr<>, Status, or a reference to - // either of them. - template <typename T> - operator ::testing::Matcher<T>() const { // NOLINT - return ::testing::Matcher<T>( - new MonoStatusIsMatcherImpl<const T&>(common_impl_)); - } - - private: - const StatusIsMatcherCommonImpl common_impl_; -}; - -// Monomorphic implementation of matcher IsOk() for a given type T. -// T can be Status, StatusOr<>, or a reference to either of them. -template <typename T> -class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> { - public: - void DescribeTo(std::ostream* os) const override { *os << "is OK"; } - void DescribeNegationTo(std::ostream* os) const override { - *os << "is not OK"; - } - bool MatchAndExplain( - T actual_value, - ::testing::MatchResultListener* result_listener) const override { - if (!actual_value.ok()) { - *result_listener << "whose status is " - << GetStatus(actual_value).message(); - return false; - } - return true; - } -}; - -// Implements IsOk() as a polymorphic matcher. -class IsOkMatcher { - public: - template <typename T> - operator ::testing::Matcher<T>() const { // NOLINT - return ::testing::Matcher<T>(new MonoIsOkMatcherImpl<const T&>()); - } -}; - -// Macros for testing the results of functions that return absl::Status or -// absl::StatusOr<T> (for any type T). -#define DPF_EXPECT_OK(expression) \ - EXPECT_THAT(expression, distributed_point_functions::dpf_internal::IsOk()) -#define DPF_ASSERT_OK(expression) \ - ASSERT_THAT(expression, distributed_point_functions::dpf_internal::IsOk()) - -// Executes an expression that returns an absl::StatusOr, and assigns the -// contained variable to lhs if the error code is OK. -// If the Status is non-OK, generates a test failure and returns from the -// current function, which must have a void return type. -// -// Example: Declaring and initializing a new value -// DPF_ASSERT_OK_AND_ASSIGN(const ValueType& value, MaybeGetValue(arg)); -// -// Example: Assigning to an existing value -// ValueType value; -// DPF_ASSERT_OK_AND_ASSIGN(value, MaybeGetValue(arg)); -// -// The value assignment example would expand into something like: -// auto status_or_value = MaybeGetValue(arg); -// DPF_ASSERT_OK(status_or_value.status()); -// value = std::move(status_or_value).ValueOrDie(); -// -// WARNING: Like ASSIGN_OR_RETURN, DPF_ASSERT_OK_AND_ASSIGN expands into -// multiple statements; it cannot be used in a single statement (e.g. as the -// body of an if statement without {})! -#define DPF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ - DPF_ASSERT_OK_AND_ASSIGN_IMPL_( \ - DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr) - -#define DPF_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - DPF_ASSERT_OK(statusor); \ - lhs = std::move(statusor).value(); - -// Returns a gMock matcher that matches a StatusOr<> whose status is -// OK and whose value matches the inner matcher. -template <typename InnerMatcher> -dpf_internal::IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type> -IsOkAndHolds(InnerMatcher&& inner_matcher) { - return dpf_internal::IsOkAndHoldsMatcher< - typename std::decay<InnerMatcher>::type>( - std::forward<InnerMatcher>(inner_matcher)); -} - -// Returns a gMock matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher, and whose error message matches message_matcher. -template <typename StatusCodeMatcher, typename StatusMessageMatcher> -dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher, - StatusMessageMatcher&& message_matcher) { - return dpf_internal::StatusIsMatcher( - std::forward<StatusCodeMatcher>(code_matcher), - std::forward<StatusMessageMatcher>(message_matcher)); -} - -// Returns a gMock matcher that matches a Status or StatusOr<> whose status code -// matches code_matcher. -template <typename StatusCodeMatcher> -dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher) { - return StatusIs(std::forward<StatusCodeMatcher>(code_matcher), ::testing::_); -} - -// Returns a gMock matcher that matches a Status or StatusOr<> which is OK. -inline dpf_internal::IsOkMatcher IsOk() { return dpf_internal::IsOkMatcher(); } - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.cc b/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.cc deleted file mode 100644 index 0704cceedaeaf..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.cc +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/value_type_helpers.h" - -#include <stdint.h> - -#include <cmath> -#include <string> - -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/int_mod_n.h" -#include "dpf/status_macros.h" - -namespace distributed_point_functions { -namespace dpf_internal { - -absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs, - const ValueType& rhs) { - if (lhs.type_case() == ValueType::TypeCase::TYPE_NOT_SET || - rhs.type_case() == ValueType::TypeCase::TYPE_NOT_SET) { - return absl::InvalidArgumentError( - "Both arguments must be valid ValueTypes"); - } else if (lhs.type_case() == ValueType::kInteger && - rhs.type_case() == ValueType::kInteger) { - return lhs.integer().bitsize() == rhs.integer().bitsize(); - } else if (lhs.type_case() == ValueType::kTuple && - rhs.type_case() == ValueType::kTuple && - lhs.tuple().elements_size() == rhs.tuple().elements_size()) { - bool result = true; - for (int i = 0; i < static_cast<int>(lhs.tuple().elements_size()); ++i) { - DPF_ASSIGN_OR_RETURN( - bool element_result, - ValueTypesAreEqual(lhs.tuple().elements(i), rhs.tuple().elements(i))); - result &= element_result; - } - return result; - } else if (lhs.type_case() == ValueType::kIntModN && - rhs.type_case() == ValueType::kIntModN) { - const Value::Integer &lhs_modulus = lhs.int_mod_n().modulus(), - &rhs_modulus = rhs.int_mod_n().modulus(); - DPF_ASSIGN_OR_RETURN(absl::uint128 lhs_modulus_128, - ValueIntegerToUint128(lhs_modulus)); - DPF_ASSIGN_OR_RETURN(absl::uint128 rhs_modulus_128, - ValueIntegerToUint128(rhs_modulus)); - return lhs.int_mod_n().base_integer().bitsize() == - rhs.int_mod_n().base_integer().bitsize() && - lhs_modulus_128 == rhs_modulus_128; - } else if (lhs.type_case() == ValueType::kXorWrapper && - rhs.type_case() == ValueType::kXorWrapper) { - return lhs.xor_wrapper().bitsize() == rhs.xor_wrapper().bitsize(); - } - return false; -} - -absl::StatusOr<int> BitsNeeded(const ValueType& value_type, - double security_parameter) { - if (value_type.type_case() == ValueType::kInteger) { - return value_type.integer().bitsize(); - } else if (value_type.type_case() == ValueType::kTuple) { - // We handle elements of type IntModN separately, since we can sample them - // together. - int num_ints_mod_n = 0; - int num_other = 0; - const ValueType* int_mod_n = nullptr; - int bitsize_ints_mod_n = 0; - int bitsize_other = 0; - for (const ValueType& el : value_type.tuple().elements()) { - if (el.type_case() == ValueType::kIntModN) { - // Element is integer mod N -> check if it is the same as the others in - // this tuple and increase counter. - if (!int_mod_n) { - int_mod_n = ⪙ - } else { - absl::StatusOr<bool> types_are_equal = - ValueTypesAreEqual(el, *int_mod_n); - if (!types_are_equal.ok()) { - return types_are_equal.status(); - } - if (!*types_are_equal) { - return absl::UnimplementedError( - "All elements of type IntModN in a tuple must be the same"); - } - } - ++num_ints_mod_n; - } else { - ++num_other; - } - } - if (num_other > 0) { - for (int i = 0; i < num_other; ++i) { - double per_element_security_parameter = - security_parameter + std::log2(static_cast<double>(num_other)); - DPF_ASSIGN_OR_RETURN(int el_bitsize, - BitsNeeded(value_type.tuple().elements(i), - per_element_security_parameter)); - bitsize_other += el_bitsize; - } - } - if (num_ints_mod_n > 0) { - DPF_ASSIGN_OR_RETURN( - absl::uint128 modulus, - ValueIntegerToUint128(int_mod_n->int_mod_n().modulus())); - DPF_ASSIGN_OR_RETURN( - int64_t bytes_needed_ints_mod_n, - dpf_internal::IntModNBase::GetNumBytesRequired( - num_ints_mod_n, int_mod_n->int_mod_n().base_integer().bitsize(), - modulus, security_parameter)); - bitsize_ints_mod_n = bytes_needed_ints_mod_n * 8; - } - return bitsize_ints_mod_n + bitsize_other; - } else if (value_type.type_case() == ValueType::kIntModN) { - DPF_ASSIGN_OR_RETURN( - absl::uint128 modulus, - ValueIntegerToUint128(value_type.int_mod_n().modulus())); - DPF_ASSIGN_OR_RETURN(int64_t bytes_needed_ints_mod_n, - dpf_internal::IntModNBase::GetNumBytesRequired( - 1, value_type.int_mod_n().base_integer().bitsize(), - modulus, security_parameter)); - return 8 * bytes_needed_ints_mod_n; - } else if (value_type.type_case() == ValueType::kXorWrapper) { - return value_type.xor_wrapper().bitsize(); - } - return absl::InvalidArgumentError(absl::StrCat( - "BitsNeeded: Unsupported ValueType:\n", value_type.DebugString())); -} - -// Integer Helpers - -Value::Integer Uint128ToValueInteger(absl::uint128 input) { - Value::Integer result; - if (absl::Uint128High64(input) == 0) { - result.set_value_uint64(absl::Uint128Low64(input)); - } else { - Block& block = *(result.mutable_value_uint128()); - block.set_high(absl::Uint128High64(input)); - block.set_low(absl::Uint128Low64(input)); - } - return result; -} - -absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in) { - if (in.value_case() == Value::Integer::kValueUint128) { - return absl::MakeUint128(in.value_uint128().high(), - in.value_uint128().low()); - } else if (in.value_case() == Value::Integer::kValueUint64) { - return in.value_uint64(); - } - return absl::InvalidArgumentError( - "Unknown value case for the given integer Value"); -} - -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.h b/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.h deleted file mode 100644 index 1e283487d07e3..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers.h +++ /dev/null @@ -1,673 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_ - -#include <stdint.h> - -#include <algorithm> -#include <array> -#include <limits> -#include <string> -#include <tuple> -#include <type_traits> -#include <vector> - -#include "absl/base/config.h" -#include "absl/log/absl_check.h" -#include "absl/meta/type_traits.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/utility/utility.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/int_mod_n.h" -#include "dpf/tuple.h" -#include "dpf/xor_wrapper.h" -#include "google/protobuf/repeated_field.h" - -// Contains a collection of helper functions for different DPF value types. This -// includes functions for converting between Value protos and the corresponding -// C++ objects, as well as functions for sampling values from uniformly random -// byte strings. -// -// This file contains the templated declarations, instantiations for all -// supported types, as well as type-independent function declarations. -namespace distributed_point_functions { -namespace dpf_internal { - -// A helper struct containing declarations for all templated functions we need. -// This is needed since C++ doesn't support partial function template -// specialization, and should be specialized for all supported types. -template <typename T, typename = void> -struct ValueTypeHelper { - // General type traits and conversion functions. Should be implemented by all - // types. - - // Type trait for all supported types. Used to provide meaningful error - // messages in std::enable_if template guards. - static constexpr bool IsSupportedType() { return false; } - - // Checks if the template parameter can be converted directly from a string of - // bytes. - static constexpr bool CanBeConvertedDirectly(); - - // Converts a given Value to the template parameter T. - static absl::StatusOr<T> FromValue(const Value& value); - - // ToValue Converts the argument to a Value proto. - static Value ToValue(const T& input); - - // ToValueType<T> Returns a `ValueType` message describing T. - static ValueType ToValueType(); - - // Functions for converting from a byte string to T. There are two approaches: - // Either converting directly (i.e., each byte is copied 1-to-1 into the - // result), or by sampling (when a direct conversion is not possible). Types - // for which CanBeConvertedDirectly() can be true should implement the former, - // and all types should implement the latter (to support types composed of - // directly-convertible and not-directly-convertible types). - - // Functions for direct conversions from bytes. Should be implemented when - // CanBeConvertedDirectly() can be true. - - // Returns the total number of bits in a T. - static constexpr int TotalBitSize(); - - static T DirectlyFromBytes(absl::string_view bytes); - - // Functions for sampling from a string of bytes. Should be implemented by all - // types. - - // Converts `block` to type T. Then, if `update == true`, fills up `block` - // from `remaining_bytes` and advances `remaining_bytes` by the amount of - // bytes read. - static T SampleAndUpdateBytes(bool update, absl::uint128& block, - absl::string_view& remaining_bytes); -}; - -/******************************************************************************/ -// Type traits // -/******************************************************************************/ - -// Type trait for all supported types. Used to provide meaningful error messages -// in std::enable_if template guards. -template <typename T> -struct is_supported_type { - static constexpr bool value = - dpf_internal::ValueTypeHelper<T>::IsSupportedType(); -}; -template <typename T> -constexpr bool is_supported_type_v = is_supported_type<T>::value; - -// Checks if the template parameter can be converted directly from a string of -// bytes. -template <typename T> -struct can_be_converted_directly { - static constexpr bool value = - dpf_internal::ValueTypeHelper<T>::CanBeConvertedDirectly(); -}; -template <typename T> -constexpr bool can_be_converted_directly_v = - can_be_converted_directly<T>::value; - -// Returns the total number of bits in a T. -template <typename T, - typename = absl::enable_if_t<can_be_converted_directly_v<T>>> -static constexpr int TotalBitSize() { - return ValueTypeHelper<T>::TotalBitSize(); -} - -/******************************************************************************/ -// Integer Helpers // -/******************************************************************************/ - -// Type trait for all integer types we support, i.e., 8 to 128 bit types. -template <typename T> -using is_unsigned_integer = - absl::disjunction<std::is_same<T, uint8_t>, std::is_same<T, uint16_t>, - std::is_same<T, uint32_t>, std::is_same<T, uint64_t>, -#ifdef ABSL_HAVE_INTRINSIC_INT128 - std::is_same<T, unsigned __int128>, -#endif - std::is_same<T, absl::uint128>>; -template <typename T> -constexpr bool is_unsigned_integer_v = is_unsigned_integer<T>::value; - -// Converts the given Value::Integer to an absl::uint128. Used as a helper -// function in `ConvertValueTo` and `ValueTypesAreEqual`. -// -// Returns INVALID_ARGUMENT if `in` is not a simple integer or IntModN. -absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in); - -// Converts an absl::uint128 to a Value::Integer. Used as a helper function in -// ToValue. -Value::Integer Uint128ToValueInteger(absl::uint128 input); - -// Checks if the given value is in range of T, and if so, returns it converted -// to T. -// -// Otherwise returns INVALID_ARGUMENT. -template <typename T, typename = absl::enable_if_t<is_unsigned_integer_v<T>>> -absl::StatusOr<T> Uint128To(absl::uint128 in) { - // Check whether value is in range if it's smaller than 128 bits. - if (!std::is_same<T, absl::uint128>::value && - absl::Uint128Low64(in) > - static_cast<uint64_t>(std::numeric_limits<T>::max())) { - return absl::InvalidArgumentError(absl::StrCat( - "Value (= ", absl::Uint128Low64(in), - ") too large for the given type T (size ", sizeof(T), ")")); - } - return static_cast<T>(in); -} - -// Implementation of ValueTypeHelper for integers. -template <typename T> -struct ValueTypeHelper<T, absl::enable_if_t<is_unsigned_integer_v<T>>> { - static constexpr bool IsSupportedType() { return true; } - - static constexpr bool CanBeConvertedDirectly() { return true; } - - static absl::StatusOr<T> FromValue(const Value& value) { - if (value.value_case() != Value::kInteger) { - return absl::InvalidArgumentError("The given Value is not an integer"); - } - // We first parse the value into an absl::uint128, then check its range if - // it is supposed to be smaller than 128 bits. - absl::StatusOr<absl::uint128> value_128 = - ValueIntegerToUint128(value.integer()); - if (!value_128.ok()) { - return value_128.status(); - } - return Uint128To<T>(*value_128); - } - - static Value ToValue(T input) { - Value result; - *(result.mutable_integer()) = Uint128ToValueInteger(input); - return result; - } - - static ValueType ToValueType() { - ValueType result; - result.mutable_integer()->set_bitsize(8 * sizeof(T)); - return result; - } - - static constexpr int TotalBitSize() { return sizeof(T) * 8; } - - static T DirectlyFromBytes(absl::string_view bytes) { - ABSL_CHECK(bytes.size() == sizeof(T)); - T out{0}; -#ifdef ABSL_IS_LITTLE_ENDIAN - std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out)); -#else - for (int i = sizeof(T) - 1; i >= 0; --i) { - out |= absl::bit_cast<uint8_t>(bytes[i]); - out <<= 8; - } -#endif - return out; - } - - static T SampleAndUpdateBytes(bool update, absl::uint128& block, - absl::string_view& remaining_bytes) { - T result = static_cast<T>(block); - - if (update) { - // Set sizeof(T) least significant bytes to 0. - if (sizeof(T) < sizeof(block)) { - constexpr absl::uint128 mask = - ~absl::uint128{std::numeric_limits<T>::max()}; - block &= mask; - } else { - block = 0; - } - - // Fill up with `bytes` and advance `bytes` by sizeof(T). - ABSL_DCHECK(remaining_bytes.size() >= sizeof(T)); - block |= DirectlyFromBytes(remaining_bytes.substr(0, sizeof(T))); - remaining_bytes = remaining_bytes.substr(sizeof(T)); - } - - return result; - } -}; - -/******************************************************************************/ -// IntModN Helpers // -/******************************************************************************/ - -template <typename BaseInteger, typename ModulusType, ModulusType kModulus> -struct ValueTypeHelper< - dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>, void> { - using IntModNType = - dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>; - - static constexpr bool IsSupportedType() { - return is_unsigned_integer_v<BaseInteger> && - is_unsigned_integer_v<ModulusType>; - } - - static constexpr bool CanBeConvertedDirectly() { return false; } - - static absl::StatusOr<IntModNType> FromValue(const Value& value) { - if (value.value_case() != Value::kIntModN) { - return absl::InvalidArgumentError("The given Value is not an IntModN"); - } - absl::StatusOr<absl::uint128> value_128 = - ValueIntegerToUint128(value.int_mod_n()); - if (!value_128.ok()) { - return value_128.status(); - } - if (*value_128 >= absl::uint128{kModulus}) { - return absl::InvalidArgumentError(absl::StrFormat( - "The given value (= %d) is larger than kModulus (= %d)", *value_128, - absl::uint128{kModulus})); - } - return IntModNType(static_cast<BaseInteger>(*value_128)); - } - - static Value ToValue(IntModNType input) { - Value result; - *(result.mutable_int_mod_n()) = Uint128ToValueInteger(input.value()); - return result; - } - - static ValueType ToValueType() { - ValueType result; - *(result.mutable_int_mod_n()->mutable_base_integer()) = - ValueTypeHelper<BaseInteger>::ToValueType().integer(); - *(result.mutable_int_mod_n()->mutable_modulus()) = - ValueTypeHelper<ModulusType>::ToValue(kModulus).integer(); - return result; - } - - static IntModNType SampleAndUpdateBytes(bool update, absl::uint128& block, - absl::string_view& remaining_bytes) { - // Optimization for native uint128. This is equivalent to what's done in - // int128.cc, but since division is not defined in the header, the compiler - // cannot optimize the division and modulus into a single operation. -#ifdef ABSL_HAVE_INTRINSIC_INT128 - absl::uint128 quotient = static_cast<unsigned __int128>(block) / kModulus, - remainder = static_cast<unsigned __int128>(block) % kModulus; -#else - absl::uint128 quotient = block / kModulus, remainder = block % kModulus; -#endif - IntModNType result(static_cast<BaseInteger>(remainder)); - - if (update) { - if (sizeof(BaseInteger) < sizeof(block)) { - block = quotient << (sizeof(BaseInteger) * 8); - } else { - block = 0; - } - block |= ValueTypeHelper<BaseInteger>::DirectlyFromBytes( - remaining_bytes.substr(0, sizeof(BaseInteger))); - remaining_bytes = remaining_bytes.substr(sizeof(BaseInteger)); - } - - return result; - } -}; - -/******************************************************************************/ -// Tuple Helpers // -/******************************************************************************/ - -// Helper struct for computing the bit size of a tuple type at compile time -// without C++17 fold expressions. -template <typename FirstElementType, typename... ElementType> -struct TupleBitSizeHelper { - static constexpr int TotalBitSize() { - return TupleBitSizeHelper<FirstElementType>::TotalBitSize() + - TupleBitSizeHelper<ElementType...>::TotalBitSize(); - } -}; -template <typename ElementType> -struct TupleBitSizeHelper<ElementType> { - static constexpr int TotalBitSize() { - return ValueTypeHelper<ElementType>::TotalBitSize(); - } -}; - -template <typename... ElementType> -struct ValueTypeHelper<Tuple<ElementType...>, void> { - using TupleType = Tuple<ElementType...>; - - static constexpr bool IsSupportedType() { - return absl::conjunction<is_supported_type<ElementType>...>::value; - } - - static constexpr bool CanBeConvertedDirectly() { - return absl::conjunction<can_be_converted_directly<ElementType>...>::value; - } - - static absl::StatusOr<TupleType> FromValue(const Value& value) { - if (value.value_case() != Value::kTuple) { - return absl::InvalidArgumentError("The given Value is not a tuple"); - } - constexpr auto tuple_size = - static_cast<int>(std::tuple_size<typename TupleType::Base>()); - if (value.tuple().elements_size() != tuple_size) { - return absl::InvalidArgumentError( - "The tuple in the given Value has the wrong number of elements"); - } - - // Create a Tuple by unpacking value.tuple().elements(). If we encounter an - // error, return it at the end. - absl::Status status = absl::OkStatus(); - int element_index = 0; - // The braced initializer list ensures elements are created in the correct - // order (unlike std::make_tuple). - TupleType result = {[&value, &status, &element_index] { - if (status.ok()) { - absl::StatusOr<ElementType> element = - ValueTypeHelper<ElementType>::FromValue( - value.tuple().elements(element_index)); - element_index++; - if (element.ok()) { - return *element; - } else { - status = element.status(); - } - } - return ElementType{}; - }()...}; - if (status.ok()) { - return result; - } else { - return status; - } - } - - static Value ToValue(const TupleType& input) { - Value result; - absl::apply( - [&result](const ElementType&... elements) { - // Create an unused std::tuple to iterate over `elements` in its - // constructor. This can be replaced by a fold expression in C++17. - std::tuple<ElementType...>{ - (*(result.mutable_tuple()->add_elements()) = - ValueTypeHelper<ElementType>::ToValue(elements), - ElementType{})...}; - }, - input.value()); - return result; - } - - static ValueType ToValueType() { - ValueType result; - ValueType::Tuple* tuple = result.mutable_tuple(); - // Create an unused std::tuple to iterate over `elements` in its - // constructor. This can be replaced by a fold expression in C++17. - std::tuple<ElementType...>{ - (*(tuple->add_elements()) = ValueTypeHelper<ElementType>::ToValueType(), - ElementType{})...}; - return result; - } - - static constexpr int TotalBitSize() { - // This helper can be replaced by a fold expression in C++17. - return TupleBitSizeHelper<ElementType...>::TotalBitSize(); - } - - static TupleType DirectlyFromBytes(absl::string_view bytes) { - ABSL_CHECK(8 * bytes.size() >= TotalBitSize()); - int offset = 0; - absl::Status status = absl::OkStatus(); - // Braced-init-list ensures the elements are constructed in-order. - return TupleType{[&bytes, &offset, &status] { - constexpr int element_size_bytes = - (ValueTypeHelper<ElementType>::TotalBitSize() + 7) / 8; - ElementType element = ValueTypeHelper<ElementType>::DirectlyFromBytes( - bytes.substr(offset, element_size_bytes)); - offset += element_size_bytes; - return element; - }()...}; - } - - static TupleType SampleAndUpdateBytes(bool update, absl::uint128& block, - absl::string_view& remaining_bytes) { - int element_counter = 0; - // Braced-init-list ensures the elements are constructed in-order. - return TupleType{[update, &element_counter, &block, - &remaining_bytes]() -> ElementType { - // If `update` is true, update after all elements. Otherwise, don't update - // after the last one. - constexpr int num_elements = std::tuple_size<typename TupleType::Base>(); - bool update2 = update || (++element_counter < num_elements); - return ValueTypeHelper<ElementType>::SampleAndUpdateBytes( - update2, block, remaining_bytes); - }()...}; - } -}; - -/******************************************************************************/ -// XorWrapper Helpers // -/******************************************************************************/ - -template <typename T> -struct ValueTypeHelper<XorWrapper<T>, void> { - static constexpr bool IsSupportedType() { - return ValueTypeHelper<T>::IsSupportedType(); - } - - static constexpr bool CanBeConvertedDirectly() { - return ValueTypeHelper<T>::CanBeConvertedDirectly(); - } - - static absl::StatusOr<XorWrapper<T>> FromValue(const Value& value) { - absl::StatusOr<absl::uint128> wrapped128 = - ValueIntegerToUint128(value.xor_wrapper()); - if (!wrapped128.ok()) { - return wrapped128.status(); - } - absl::StatusOr<T> wrapped = Uint128To<T>(*wrapped128); - if (!wrapped.ok()) { - return wrapped.status(); - } - return XorWrapper<T>(*wrapped); - } - - static Value ToValue(const XorWrapper<T>& input) { - Value result; - *(result.mutable_xor_wrapper()) = Uint128ToValueInteger(input.value()); - return result; - } - - static ValueType ToValueType() { - ValueType result; - *(result.mutable_xor_wrapper()) = - ValueTypeHelper<T>::ToValueType().integer(); - return result; - } - - static constexpr int TotalBitSize() { - return ValueTypeHelper<T>::TotalBitSize(); - } - - static XorWrapper<T> DirectlyFromBytes(absl::string_view bytes) { - return XorWrapper<T>(ValueTypeHelper<T>::DirectlyFromBytes(bytes)); - } - - static XorWrapper<T> SampleAndUpdateBytes( - bool update, absl::uint128& block, absl::string_view& remaining_bytes) { - return XorWrapper<T>(ValueTypeHelper<T>::SampleAndUpdateBytes( - update, block, remaining_bytes)); - } -}; - -/******************************************************************************/ -// Free standing helpers. These should always come last. When adding // -// additional types, add them above. // -/******************************************************************************/ - -// Computes the number of values of type T that fit into an absl::uint128. -// Returns a value >= 1 if batching is supported, and 1 otherwise. -template <typename T, - absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0> -constexpr int ElementsPerBlock() { - if (TotalBitSize<T>() <= 128) { - return static_cast<int>(8 * sizeof(absl::uint128)) / TotalBitSize<T>(); - } - return 1; -} -template <typename T, - absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0> -constexpr int ElementsPerBlock() { - return 1; -} - -// Creates a value of type T from the given `bytes`. If possible, converts bytes -// directly using DirectlyFromBytes. Otherwise, uses SampleAndUpdateBytes. -// -// Crashes if `bytes.size()` is too small for the output type. -template <typename T, - absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0> -T FromBytes(absl::string_view bytes) { - return ValueTypeHelper<T>::DirectlyFromBytes(bytes); -} -template <typename T, - absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0> -T FromBytes(absl::string_view bytes) { - absl::uint128 block = - FromBytes<absl::uint128>(bytes.substr(0, sizeof(absl::uint128))); - bytes = bytes.substr(sizeof(absl::uint128)); - return ValueTypeHelper<T>::SampleAndUpdateBytes(false, block, bytes); -} - -// Converts a `repeated Value` proto field to a std::array with element type T. -// -// Returns INVALID_ARGUMENT in case the input has the wrong size, or if the -// conversion fails. -template <typename T> -absl::StatusOr<std::array<T, ElementsPerBlock<T>()>> ValuesToArray( - const ::google::protobuf::RepeatedPtrField<Value>& values) { - if (values.size() != ElementsPerBlock<T>()) { - return absl::InvalidArgumentError(absl::StrCat( - "values.size() (= ", values.size(), - ") does not match ElementsPerBlock<T>() (= ", ElementsPerBlock<T>(), - ")")); - } - std::array<T, ElementsPerBlock<T>()> result; - for (int i = 0; i < ElementsPerBlock<T>(); ++i) { - absl::StatusOr<T> element = ValueTypeHelper<T>::FromValue(values[i]); - if (element.ok()) { - result[i] = std::move(*element); - } else { - return element.status(); - } - } - return result; -} - -// Converts a given string to an array of exactly ElementsPerBlock<T>() elements -// of type T. -// -// Crashes if `bytes.size()` is too small for the output type. -template <typename T, - absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0> -std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf( - absl::string_view bytes) { - std::array<T, ElementsPerBlock<T>()> out; - const int element_size_bytes = (TotalBitSize<T>() + 7) / 8; - ABSL_CHECK(bytes.size() >= ElementsPerBlock<T>() * element_size_bytes); - for (int i = 0; i < ElementsPerBlock<T>(); ++i) { - out[i] = - FromBytes<T>(bytes.substr(i * element_size_bytes, element_size_bytes)); - } - return out; -} -template <typename T, - absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0> -std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf( - absl::string_view bytes) { - static_assert(ElementsPerBlock<T>() == 1, - "T does not support batching, but ElementsPerBlock<T> != 1"); - return {FromBytes<T>(bytes)}; -} - -// Computes the value correction word given two seeds `seed_a`, `seed_b` for -// parties a and b, such that the element at `block_index` is equal to `beta`. -// If `invert` is true, the result is multiplied element-wise by -1. Templated -// to use the correct integer type without needing modular reduction. -// -// Returns multiple values in case of packing, and a single value otherwise. -template <typename T> -absl::StatusOr<std::vector<Value>> ComputeValueCorrectionFor( - absl::string_view seed_a, absl::string_view seed_b, int block_index, - const Value& beta, bool invert) { - absl::StatusOr<T> beta_T = ValueTypeHelper<T>::FromValue(beta); - if (!beta_T.ok()) { - return beta_T.status(); - } - - constexpr int elements_per_block = ElementsPerBlock<T>(); - - // Compute values from seeds. Both arrays will have multiple elements if T - // supports batching, and a single one otherwise. - std::array<T, elements_per_block> ints_a = ConvertBytesToArrayOf<T>(seed_a), - ints_b = ConvertBytesToArrayOf<T>(seed_b); - - // Add beta to the right position. - ints_b[block_index] += *beta_T; - - // Add up shares, invert if needed. - for (int i = 0; i < elements_per_block; i++) { - ints_b[i] = ints_b[i] - ints_a[i]; - if (invert) { - ints_b[i] = -ints_b[i]; - } - } - - // Convert to a vector of Value protos and return. - std::vector<Value> result; - result.reserve(ints_b.size()); - for (const T& element : ints_b) { - result.push_back(ValueTypeHelper<T>::ToValue(element)); - } - return result; -} - -// Computes the number of pseudorandom bits needed to get a uniform element of -// the given `ValueType`. For types whose elements can be bijectively mapped to -// strings (e.g., unsigned integers and tuples of integers), this is equivalent -// to the bit size of the value type. For all other types, returns the number of -// bits needed so that converting a uniform string with the given number of bits -// to an element of `value_type` results in a distribution with total variation -// distance < 2^(-`security_parameter`) from uniform. -// -// Returns INVALID_ARGUMENT in case value_type does not represent a known type, -// or if sampling with the required security parameter is not possible. -absl::StatusOr<int> BitsNeeded(const ValueType& value_type, - double security_parameter); - -// Returns `true` if `lhs` and `rhs` describe the same types, and `false` -// otherwise. -// -// Returns INVALID_ARGUMENT if an error occurs while parsing either argument. -absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs, - const ValueType& rhs); - -} // namespace dpf_internal -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_ diff --git a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers_test.cc b/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers_test.cc deleted file mode 100644 index 1abec0af4f45d..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/internal/value_type_helpers_test.cc +++ /dev/null @@ -1,359 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/internal/value_type_helpers.h" - -#include <stdint.h> - -#include <array> -#include <string> -#include <tuple> - -#include "absl/base/config.h" -#include "absl/numeric/int128.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "dpf/distributed_point_function.pb.h" -#include "dpf/int_mod_n.h" -#include "dpf/internal/status_matchers.h" -#include "dpf/tuple.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { -namespace dpf_internal { -namespace { - -constexpr int kDefaultSecurityParameter = 40; - -TEST(ValueTypeHelperTest, ValueTypesAreEqualFailsOnInvalidValueTypes) { - ValueType type1, type2; - - EXPECT_THAT(ValueTypesAreEqual(type1, type2), - StatusIs(absl::StatusCode::kInvalidArgument, - "Both arguments must be valid ValueTypes")); -} - -TEST(ValueTypeHelperTest, BitsNeededFailsOnInvalidValueType) { - EXPECT_THAT( - BitsNeeded(ValueType{}, kDefaultSecurityParameter), - StatusIs(absl::StatusCode::kInvalidArgument, - testing::StartsWith("BitsNeeded: Unsupported ValueType"))); -} - -template <typename T> -class ValueTypeIntegerTest : public testing::Test {}; -using IntegerTypes = - ::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>; -TYPED_TEST_SUITE(ValueTypeIntegerTest, IntegerTypes); - -TYPED_TEST(ValueTypeIntegerTest, ToValueTypeIntegers) { - ValueType value_type = ValueTypeHelper<TypeParam>::ToValueType(); - - EXPECT_TRUE(value_type.has_integer()); - EXPECT_EQ(value_type.integer().bitsize(), sizeof(TypeParam) * 8); -} - -TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreEqual) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2; - value_type_2.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_TRUE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_TRUE(equal); -} - -TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreNotEqual) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2; - value_type_2.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8 * 2); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_FALSE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_FALSE(equal); -} - -TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfNotInteger) { - Value value; - value.mutable_tuple(); - - EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - "The given Value is not an integer")); -} - -TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfInvalidIntegerCase) { - Value value; - value.mutable_integer(); - - EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - "Unknown value case for the given integer Value")); -} - -TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfValueOutOfRange) { - Value value; - auto value_64 = uint64_t{1} << 32; - value.mutable_integer()->set_value_uint64(value_64); - - if (sizeof(TypeParam) >= sizeof(uint64_t)) { - DPF_EXPECT_OK(ValueTypeHelper<TypeParam>::FromValue(value)); - } else { - EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - absl::StrCat("Value (= ", value_64, - ") too large for the given type T (size ", - sizeof(TypeParam), ")"))); - } -} - -template <typename T> -class ValueTypeTupleTest : public testing::Test {}; - -template <typename T, int... bits> -struct TupleTestParam { - using Tuple = T; - static constexpr int ExpectedNumElements() { return sizeof...(bits); }; - static constexpr std::array<int, ExpectedNumElements()> ExpectedBitSizes() { - return {bits...}; - } -}; - -// We only test tuples consisting of integers here. -using TupleTypes = ::testing::Types< - TupleTestParam<Tuple<uint64_t>, 64>, - TupleTestParam<Tuple<uint64_t, uint64_t>, 64, 64>, - TupleTestParam<Tuple<uint32_t, absl::uint128, uint8_t>, 32, 128, 8>, - TupleTestParam<Tuple<uint8_t, uint8_t, uint8_t, uint8_t>, 8, 8, 8, 8>>; -TYPED_TEST_SUITE(ValueTypeTupleTest, TupleTypes); - -TYPED_TEST(ValueTypeTupleTest, ToValueTypeTuples) { - ValueType value_type = - ValueTypeHelper<typename TypeParam::Tuple>::ToValueType(); - - constexpr int expected_num_elements = TypeParam::ExpectedNumElements(); - EXPECT_TRUE(value_type.has_tuple()); - ASSERT_EQ(std::tuple_size<typename TypeParam::Tuple::Base>(), - expected_num_elements); // Sanity check for test parameters. - EXPECT_EQ(value_type.tuple().elements_size(), expected_num_elements); - - std::array<int, expected_num_elements> expected_bit_sizes = - TypeParam::ExpectedBitSizes(); - for (int i = 0; i < expected_num_elements; ++i) { - EXPECT_TRUE(value_type.tuple().elements(i).has_integer()); - EXPECT_EQ(value_type.tuple().elements(i).integer().bitsize(), - expected_bit_sizes[i]); - } -} - -TYPED_TEST(ValueTypeTupleTest, BitsNeededEqualsCompileTimeTypeSize) { - ValueType value_type = - ValueTypeHelper<typename TypeParam::Tuple>::ToValueType(); - - DPF_ASSERT_OK_AND_ASSIGN(int bitsize, - BitsNeeded(value_type, kDefaultSecurityParameter)); - - EXPECT_EQ(bitsize, TotalBitSize<typename TypeParam::Tuple>()); -} - -TYPED_TEST(ValueTypeTupleTest, ValueConversionFailsIfValueIsNotATuple) { - Value value; - value.mutable_integer(); - - EXPECT_THAT(ValueTypeHelper<Tuple<uint32_t>>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - "The given Value is not a tuple")); -} - -TEST(ValueTypeTupleTest, ValueConversionFailsIfValueSizeDoesntMatchTupleSize) { - Value value; - value.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64( - 1234); - - using TupleType = Tuple<uint32_t, uint32_t>; - EXPECT_THAT( - ValueTypeHelper<TupleType>::FromValue(value), - StatusIs( - absl::StatusCode::kInvalidArgument, - "The tuple in the given Value has the wrong number of elements")); -} - -TEST(ValueTypeTupleTest, TestValueTypesAreEqual) { - using T1 = Tuple<uint32_t, absl::uint128, uint8_t>; - using T2 = Tuple<uint32_t, absl::uint128, uint8_t>; - - ValueType value_type_1 = ValueTypeHelper<T1>::ToValueType(); - ValueType value_type_2 = ValueTypeHelper<T2>::ToValueType(); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_TRUE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_TRUE(equal); -} - -TEST(ValueTypeTupleTest, TestValueTypesAreNotEqual) { - using T1 = Tuple<uint32_t, absl::uint128, uint8_t>; - using T2 = Tuple<uint32_t, absl::uint128, uint16_t>; - - ValueType value_type_1 = ValueTypeHelper<T1>::ToValueType(); - ValueType value_type_2 = ValueTypeHelper<T2>::ToValueType(); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_FALSE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_FALSE(equal); -} - -TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExample) { - std::string bytes = "A 128 bit string"; - - auto tuple = FromBytes<Tuple<uint64_t, uint64_t>>(bytes); - EXPECT_EQ(std::get<0>(tuple.value()), FromBytes<uint64_t>("A 128 bi")); - EXPECT_EQ(std::get<1>(tuple.value()), FromBytes<uint64_t>("t string")); -} - -TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExampleForIntModN) { - constexpr uint32_t kModulus = 4294967291u; - using MyIntModN = IntModN<uint32_t, kModulus>; - std::string bytes = "A 128+32 bit string."; - - absl::uint128 block = FromBytes<absl::uint128>("A 128+32 bit str"); - MyIntModN expected_0(static_cast<uint32_t>(block % kModulus)); - block /= kModulus; - block <<= (8 * sizeof(uint32_t)); - block |= FromBytes<uint32_t>("ing."); - MyIntModN expected_1(static_cast<uint32_t>(block % kModulus)); - - auto tuple = FromBytes<Tuple<MyIntModN, MyIntModN>>(bytes).value(); - EXPECT_EQ(std::get<0>(tuple), expected_0); - EXPECT_EQ(std::get<1>(tuple), expected_1); -} - -template <typename T> -class ValueTypeIntModNTest : public testing::Test {}; -using IntModNTypes = ::testing::Types< - IntModN<uint32_t, 4>, IntModN<uint32_t, 4294967291u>, - IntModN<uint64_t, 4294967291ull>, IntModN<uint64_t, 1000000000000ull> -#ifdef ABSL_HAVE_INTRINSIC_INT128 - , - IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128( - 65535u, 18446744073709551551ull))> // 2**80-65 -#endif - >; -TYPED_TEST_SUITE(ValueTypeIntModNTest, IntModNTypes); - -TYPED_TEST(ValueTypeIntModNTest, ToValueType) { - ValueType value_type = ValueTypeHelper<TypeParam>::ToValueType(); - - EXPECT_TRUE(value_type.type_case() == ValueType::kIntModN); - EXPECT_EQ(value_type.int_mod_n().base_integer().bitsize(), - sizeof(typename TypeParam::Base) * 8); - DPF_ASSERT_OK_AND_ASSIGN( - absl::uint128 modulus, - ValueIntegerToUint128(value_type.int_mod_n().modulus())); - EXPECT_EQ(modulus, absl::uint128{TypeParam::modulus()}); -} - -TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreEqual) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2; - - value_type_2.mutable_int_mod_n()->mutable_base_integer()->set_bitsize( - sizeof(TypeParam) * 8); - *(value_type_2.mutable_int_mod_n()->mutable_modulus()) = - Uint128ToValueInteger(TypeParam::modulus()); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_TRUE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_TRUE(equal); -} - -TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentBase) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2 = value_type_1; - - value_type_2.mutable_int_mod_n()->mutable_base_integer()->set_bitsize( - sizeof(TypeParam) * 8 * 2); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_FALSE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_FALSE(equal); -}; - -TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentModulus) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2 = value_type_1; - - *(value_type_2.mutable_int_mod_n()->mutable_modulus()) = - Uint128ToValueInteger(TypeParam::modulus() - 1); - - DPF_ASSERT_OK_AND_ASSIGN(bool equal, - ValueTypesAreEqual(value_type_1, value_type_2)); - EXPECT_FALSE(equal); - DPF_ASSERT_OK_AND_ASSIGN(equal, - ValueTypesAreEqual(value_type_2, value_type_1)); - EXPECT_FALSE(equal); -} - -TYPED_TEST(ValueTypeIntModNTest, ValueTypesAreEqualFailsWhenModulusInvalid) { - ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(), - value_type_2 = value_type_1; - - value_type_2.mutable_int_mod_n()->clear_modulus(); - - EXPECT_THAT(ValueTypesAreEqual(value_type_1, value_type_2), - StatusIs(absl::StatusCode::kInvalidArgument, - "Unknown value case for the given integer Value")); -} - -TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfNotInteger) { - Value value; - value.mutable_tuple(); - - EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - "The given Value is not an IntModN")); -} - -TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfTooLargeForModulus) { - Value value; - *(value.mutable_int_mod_n()) = Uint128ToValueInteger(TypeParam::modulus()); - - EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value), - StatusIs(absl::StatusCode::kInvalidArgument, - testing::HasSubstr("is larger than kModulus"))); -} - -} // namespace - -} // namespace dpf_internal -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/status_macros.h b/third_party/distributed_point_functions/code/dpf/status_macros.h deleted file mode 100644 index 9949ec417f505..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/status_macros.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_ - -// Helper macro that checks if the right hand side (rexpression) evaluates to a -// StatusOr with Status OK, and if so assigns the value to the value on the left -// hand side (lhs), otherwise returns the error status. Example: -// DPF_ASSIGN_OR_RETURN(lhs, rexpression); -#define DPF_ASSIGN_OR_RETURN(lhs, rexpr) \ - DPF_ASSIGN_OR_RETURN_IMPL_( \ - DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr) - -// Internal helper. -#define DPF_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ - return std::move(statusor).status(); \ - } \ - lhs = std::move(statusor).value() - -// Internal helper for concatenating macro values. -#define DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y -#define DPF_STATUS_MACROS_IMPL_CONCAT_(x, y) \ - DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) - -#define DPF_RETURN_IF_ERROR(expr) \ - DPF_RETURN_IF_ERROR_IMPL_(DPF_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), \ - expr) - -#define DPF_RETURN_IF_ERROR_IMPL_(status, expr) \ - auto status = (expr); \ - if (ABSL_PREDICT_FALSE(!status.ok())) { \ - return status; \ - } - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_ diff --git a/third_party/distributed_point_functions/code/dpf/tuple.h b/third_party/distributed_point_functions/code/dpf/tuple.h deleted file mode 100644 index 48627b38eba16..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/tuple.h +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_ - -#include <stddef.h> - -#include <tuple> -#include <utility> - -namespace distributed_point_functions { - -// A Tuple class with added element-wise addition, subtraction, and negation -// operators. -template <typename... T> -class Tuple { - public: - using Base = std::tuple<T...>; - - Tuple() {} - Tuple(T... elements) : value_(elements...) {} - explicit Tuple(Base t) : value_(std::move(t)) {} - - // Copy constructor. - Tuple(const Tuple& t) = default; - Tuple& operator=(const Tuple& t) = default; - - // Getters for the base tuple type. - Base& value() { return value_; } - const Base& value() const { return value_; } - - private: - Base value_; -}; - -namespace dpf_internal { - -// Implementation of addition and negation. See -// https://stackoverflow.com/a/50815143. -// We declare the templates here, but define them at the end of this header -// because the definitions need to make use of operator+ and operator-. -template <typename... T, std::size_t... I> -constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs, - std::index_sequence<I...>); - -template <typename... T, std::size_t... I> -constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>); - -} // namespace dpf_internal - -template <typename... T> -constexpr Tuple<T...> operator+(const Tuple<T...>& lhs, - const Tuple<T...>& rhs) { - return dpf_internal::add(lhs, rhs, std::make_index_sequence<sizeof...(T)>{}); -} - -template <typename... T> -constexpr Tuple<T...>& operator+=(Tuple<T...>& lhs, const Tuple<T...>& rhs) { - lhs = lhs + rhs; - return lhs; -} - -template <typename... T> -constexpr Tuple<T...> operator-(const Tuple<T...>& t) { - return dpf_internal::negate(t, std::make_index_sequence<sizeof...(T)>{}); -} - -template <typename... T> -constexpr Tuple<T...> operator-(const Tuple<T...>& lhs, - const Tuple<T...>& rhs) { - return lhs + (-rhs); -} - -template <typename... T> -constexpr Tuple<T...>& operator-=(Tuple<T...>& lhs, const Tuple<T...>& rhs) { - lhs = lhs - rhs; - return lhs; -} - -// Equality and inequality operators. -template <typename... T> -constexpr bool operator==(const Tuple<T...>& lhs, const Tuple<T...>& rhs) { - return lhs.value() == rhs.value(); -} - -template <typename... T> -constexpr bool operator!=(const Tuple<T...>& lhs, const Tuple<T...>& rhs) { - return lhs.value() != rhs.value(); -} - -namespace dpf_internal { -template <typename... T, std::size_t... I> -constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs, - std::index_sequence<I...>) { - return Tuple<T...>{std::get<I>(lhs.value()) + std::get<I>(rhs.value())...}; -} - -template <typename... T, std::size_t... I> -constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>) { - return Tuple<T...>{ - // Explicitly cast to T to avoid -Wnarrowing warnings for small integers. - T(-std::get<I>(t.value()))...}; -} -} // namespace dpf_internal - -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_ diff --git a/third_party/distributed_point_functions/code/dpf/tuple_test.cc b/third_party/distributed_point_functions/code/dpf/tuple_test.cc deleted file mode 100644 index 993a9984e070d..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/tuple_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/tuple.h" - -#include <tuple> - -#include "absl/numeric/int128.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { - -namespace { - -using T = Tuple<int, double, absl::uint128>; - -TEST(TupleTest, TestAddition) { - T a(std::make_tuple(1, 2, 3)); - T b(std::make_tuple(4, 5, 6)); - - T c = a + b; - - EXPECT_EQ(std::get<0>(c.value()), - std::get<0>(a.value()) + std::get<0>(b.value())); - EXPECT_EQ(std::get<1>(c.value()), - std::get<1>(a.value()) + std::get<1>(b.value())); - EXPECT_EQ(std::get<2>(c.value()), - std::get<2>(a.value()) + std::get<2>(b.value())); -} - -TEST(TupleTest, TestAdditionInplace) { - T a(std::make_tuple(1, 2, 3)); - T b(std::make_tuple(4, 5, 6)); - - T a2 = a; - a += b; - - EXPECT_EQ(std::get<0>(a.value()), - std::get<0>(a2.value()) + std::get<0>(b.value())); - EXPECT_EQ(std::get<1>(a.value()), - std::get<1>(a2.value()) + std::get<1>(b.value())); - EXPECT_EQ(std::get<2>(a.value()), - std::get<2>(a2.value()) + std::get<2>(b.value())); -} -TEST(TupleTest, TestSubtraction) { - T a(std::make_tuple(1, 2, 3)); - T b(std::make_tuple(4, 5, 6)); - - T c = a - b; - - EXPECT_EQ(std::get<0>(c.value()), - std::get<0>(a.value()) - std::get<0>(b.value())); - EXPECT_EQ(std::get<1>(c.value()), - std::get<1>(a.value()) - std::get<1>(b.value())); - EXPECT_EQ(std::get<2>(c.value()), - std::get<2>(a.value()) - std::get<2>(b.value())); -} - -TEST(TupleTest, TestSubtractionInplace) { - T a(std::make_tuple(1, 2, 3)); - T b(std::make_tuple(4, 5, 6)); - - T a2 = a; - a -= b; - - EXPECT_EQ(std::get<0>(a.value()), - std::get<0>(a2.value()) - std::get<0>(b.value())); - EXPECT_EQ(std::get<1>(a.value()), - std::get<1>(a2.value()) - std::get<1>(b.value())); - EXPECT_EQ(std::get<2>(a.value()), - std::get<2>(a2.value()) - std::get<2>(b.value())); -} - -TEST(TupleTest, TestNegation) { - T a(std::make_tuple(1, 2, 3)); - - T a2 = -a; - - EXPECT_EQ(std::get<0>(a2.value()), -std::get<0>(a.value())); - EXPECT_EQ(std::get<1>(a2.value()), -std::get<1>(a.value())); - EXPECT_EQ(std::get<2>(a2.value()), -std::get<2>(a.value())); -} - -} // namespace - -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/code/dpf/xor_wrapper.h b/third_party/distributed_point_functions/code/dpf/xor_wrapper.h deleted file mode 100644 index 31166be8e00c9..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/xor_wrapper.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_ -#define DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_ - -#include <utility> - -namespace distributed_point_functions { - -// Wraps the given type, replacing additions and subtractions by XOR. -template <typename T> -class XorWrapper { - public: - using WrappedType = T; - - constexpr XorWrapper() : wrapped_{} {} - explicit constexpr XorWrapper(T wrapped) : wrapped_(std::move(wrapped)) {} - - // XorWrapper is copyable and movable. - constexpr XorWrapper(const XorWrapper&) = default; - constexpr XorWrapper& operator=(const XorWrapper&) = default; - constexpr XorWrapper(XorWrapper&&) = default; - constexpr XorWrapper& operator=(XorWrapper&&) = default; - - // Assignment operators. - constexpr XorWrapper& operator+=(const XorWrapper& rhs) { - wrapped_ ^= rhs.value(); - return *this; - } - constexpr XorWrapper& operator-=(const XorWrapper& rhs) { - wrapped_ ^= rhs.value(); - return *this; - } - - // Returns a reference to the wrapped object. - constexpr T& value() { return wrapped_; } - constexpr const T& value() const { return wrapped_; } - - private: - T wrapped_; -}; - -template <typename T> -constexpr XorWrapper<T> operator+(XorWrapper<T> a, const XorWrapper<T>& b) { - a += b; - return a; -} - -template <typename T> -constexpr XorWrapper<T> operator-(XorWrapper<T> a, const XorWrapper<T>& b) { - a -= b; - return a; -} - -// Negation does nothing in XOR sharing, since -a = 0-a. -template <typename T> -constexpr XorWrapper<T> operator-(const XorWrapper<T>& a) { - return a; -} - -template <typename T> -constexpr bool operator==(const XorWrapper<T>& a, const XorWrapper<T>& b) { - return a.value() == b.value(); -} - -template <typename T> -constexpr bool operator!=(const XorWrapper<T>& a, const XorWrapper<T>& b) { - return !(a == b); -} - -} // namespace distributed_point_functions - -#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_ diff --git a/third_party/distributed_point_functions/code/dpf/xor_wrapper_test.cc b/third_party/distributed_point_functions/code/dpf/xor_wrapper_test.cc deleted file mode 100644 index b5fda9e7d9e51..0000000000000 --- a/third_party/distributed_point_functions/code/dpf/xor_wrapper_test.cc +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dpf/xor_wrapper.h" - -#include <stdint.h> - -#include "absl/numeric/int128.h" -#include "gtest/gtest.h" - -namespace distributed_point_functions { - -namespace { - -template <typename T> -class XorWrapperTest : public testing::Test {}; -using XorWrapperTypes = - testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>; - -TYPED_TEST_SUITE(XorWrapperTest, XorWrapperTypes); - -TYPED_TEST(XorWrapperTest, TestConstructor) { - TypeParam value{42}; - - XorWrapper<TypeParam> wrapper(value); - - EXPECT_EQ(wrapper.value(), value); -} - -TYPED_TEST(XorWrapperTest, TestAddition) { - TypeParam a{42}, b{23}; - XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b); - - EXPECT_EQ((wrapped_a + wrapped_b).value(), a ^ b); -} - -TYPED_TEST(XorWrapperTest, TestSubtraction) { - TypeParam a{42}, b{23}; - XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b); - - EXPECT_EQ((wrapped_a - wrapped_b).value(), a ^ b); -} - -TYPED_TEST(XorWrapperTest, TestNegation) { - TypeParam value{42}; - XorWrapper<TypeParam> wrapper(value); - - EXPECT_EQ((-wrapper).value(), value); -} - -TYPED_TEST(XorWrapperTest, TestEquality) { - TypeParam a{42}, b{23}; - XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b); - - EXPECT_EQ(wrapped_a, XorWrapper<TypeParam>(a)); - EXPECT_NE(wrapped_a, XorWrapper<TypeParam>(b)); -} - -} // namespace - -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/features.gni b/third_party/distributed_point_functions/features.gni deleted file mode 100644 index fd6f285fdeef8..0000000000000 --- a/third_party/distributed_point_functions/features.gni +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 The Chromium Authors -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -declare_args() { - use_distributed_point_functions = is_debug - dpf_abseil_cpp_dir = "//third_party/abseil-cpp" - dpf_highway_cpp_dir = "//third_party/highway" -} diff --git a/third_party/distributed_point_functions/fuzz/dpf_fuzzer.cc b/third_party/distributed_point_functions/fuzz/dpf_fuzzer.cc deleted file mode 100644 index 858993533c05e..0000000000000 --- a/third_party/distributed_point_functions/fuzz/dpf_fuzzer.cc +++ /dev/null @@ -1,293 +0,0 @@ -// Copyright 2021 The Chromium Authors -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h" - -#include <stddef.h> -#include <stdint.h> -#include <stdlib.h> - -#include <fuzzer/FuzzedDataProvider.h> - -#include <algorithm> -#include <memory> - -#define DPF_FUZZER_ASSERT(x) \ - if (!(x)) { \ - printf("DPF assertion failed: function %s, file %s, line %d.\n", \ - __PRETTY_FUNCTION__, __FILE__, __LINE__); \ - abort(); \ - } - -namespace { - -const size_t UINT128_SIZE = 2 * sizeof(uint64_t); - -// Constructs a `uint128` numeric value from two 64-bit unsigned integers -// consumed from the data provider. -absl::uint128 ConsumeUint128(FuzzedDataProvider& data_provider) { - uint64_t high = data_provider.ConsumeIntegral<uint64_t>(); - uint64_t low = data_provider.ConsumeIntegral<uint64_t>(); - return absl::MakeUint128(high, low); -} - -// Returns the prefix of `index` for the domain of `hierarchy_level`. -// Adapted from `DpfEvaluationTest::GetPrefixForLevel()`. -absl::uint128 GetPrefixForLevel( - int hierarchy_level, - absl::uint128 index, - const std::vector<distributed_point_functions::DpfParameters>& parameters) { - absl::uint128 result = 0; - int shift_amount = parameters.back().log_domain_size() - - parameters[hierarchy_level].log_domain_size(); - if (shift_amount < 128) - result = index >> shift_amount; - return result; -} - -// Evaluates both contexts `ctx0` and `ctx1` at `hierarchy level`, using the -// appropriate prefixes of `evaluation_points`. Checks that the expansion of -// both keys from correct DPF shares, i.e., they add up to -// `beta[ctx.hierarchy_level()]` under prefixes of `alpha`, and to 0 otherwise. -// Adapted from `DpfEvaluationTest::EvaluateAndCheckLevel()`. -template <typename T> -void EvaluateAndCheckLevel( - int hierarchy_level, - absl::Span<const absl::uint128> evaluation_points, - absl::uint128 alpha, - const std::vector<absl::uint128>& beta, - distributed_point_functions::EvaluationContext& ctx0, - distributed_point_functions::EvaluationContext& ctx1, - const std::vector<distributed_point_functions::DpfParameters>& parameters, - const distributed_point_functions::DistributedPointFunction& dpf) { - int previous_hierarchy_level = ctx0.previous_hierarchy_level(); - int current_log_domain_size = parameters[hierarchy_level].log_domain_size(); - int previous_log_domain_size = 0; - int num_expansions = 1; - bool is_first_evaluation = previous_hierarchy_level < 0; - // Generate prefixes if we're not on the first level. - std::vector<absl::uint128> prefixes; - if (!is_first_evaluation) { - num_expansions = static_cast<int>(evaluation_points.size()); - prefixes.resize(evaluation_points.size()); - previous_log_domain_size = - parameters[previous_hierarchy_level].log_domain_size(); - for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i) - prefixes[i] = GetPrefixForLevel(previous_hierarchy_level, - evaluation_points[i], parameters); - } - - // Evaluating a key with N correction words leads to an O(2^N) malloc, which - // will unsurprisingly cause a fuzzer crash. See <https://crbug.com/1494260>. - constexpr size_t kMaxCorrectionWords = 30; - if (ctx0.key().correction_words().size() > kMaxCorrectionWords) { - return; - } - absl::StatusOr<std::vector<T>> result_0 = - dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx0); - DPF_FUZZER_ASSERT(result_0.ok()); - if (ctx1.key().correction_words().size() > kMaxCorrectionWords) { - return; - } - absl::StatusOr<std::vector<T>> result_1 = - dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx1); - DPF_FUZZER_ASSERT(result_1.ok()); - - DPF_FUZZER_ASSERT(result_0->size() == result_1->size()); - int64_t outputs_per_prefix = - int64_t{1} << (current_log_domain_size - previous_log_domain_size); - int64_t expected_output_size = num_expansions * outputs_per_prefix; - DPF_FUZZER_ASSERT(static_cast<int64_t>(result_0->size()) == - expected_output_size); - - // Iterator over the outputs and check that they sum up to 0 or to - // `beta[current_hierarchy_level]`; - absl::uint128 previous_alpha_prefix = 0; - if (!is_first_evaluation) - previous_alpha_prefix = - GetPrefixForLevel(previous_hierarchy_level, alpha, parameters); - - absl::uint128 current_alpha_prefix = - GetPrefixForLevel(hierarchy_level, alpha, parameters); - for (int64_t i = 0; i < expected_output_size; ++i) { - int prefix_index = i / outputs_per_prefix; - int prefix_expansion_index = i % outputs_per_prefix; - // The output is on the path to `alpha`, if we're at the first level or - // under a prefix of `alpha`, and the current block in the expansion of the - // prefix is also on the path to `alpha`. - if ((is_first_evaluation || - prefixes[prefix_index] == previous_alpha_prefix) && - prefix_expansion_index == (current_alpha_prefix % outputs_per_prefix)) { - // We need to static_cast here since otherwise operator+ returns an - // unsigned int without doing a modular reduction, which causes the test - // to fail on types with sizeof(T) < sizeof(unsigned). - DPF_FUZZER_ASSERT( - absl::uint128{static_cast<T>((*result_0)[i] + (*result_1)[i])} == - beta[hierarchy_level]); - } else { - DPF_FUZZER_ASSERT(static_cast<T>((*result_0)[i] + (*result_1)[i]) == 0U); - } - } -} - -} // namespace - -extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - // Use magic separator to split the input into two parts. The first part will - // generate alpha, and an array of parameters and betas. The second part will - // generate level step and an array of evaluation points. - const uint8_t separator[] = {0xDE, 0xAD, 0xBE, 0xEF}; - - const uint8_t* pos = - std::search(data, data + size, separator, separator + sizeof(separator)); - - const uint8_t* data1 = data; - size_t size1 = pos - data; - - const uint8_t* data2 = - (pos == data + size) ? nullptr : pos + sizeof(separator); - size_t size2 = data2 ? (data + size) - (pos + sizeof(separator)) : 0; - - FuzzedDataProvider data_provider1(data1, size1); - - if (data_provider1.remaining_bytes() < UINT128_SIZE) - return 0; - - absl::uint128 alpha = ConsumeUint128(data_provider1); - - std::vector<int32_t> log_domain_sizes; - std::vector<int32_t> element_bitsizes; - std::vector<distributed_point_functions::DpfParameters> parameters; - std::vector<absl::uint128> beta; - - // log_domain_size(int32_t), element_bitsize(int32_t), - // beta(uint128) - while (data_provider1.remaining_bytes() >= - (2 * sizeof(int32_t) + UINT128_SIZE)) { - int32_t log_domain_size = data_provider1.ConsumeIntegral<int32_t>(); - int32_t element_bitsize = data_provider1.ConsumeIntegral<int32_t>(); - log_domain_sizes.push_back(log_domain_size); - element_bitsizes.push_back(element_bitsize); - - distributed_point_functions::DpfParameters parameter; - parameter.set_log_domain_size(log_domain_size); - parameter.mutable_value_type()->mutable_integer()->set_bitsize( - element_bitsize); - parameters.push_back(parameter); - - beta.push_back(ConsumeUint128(data_provider1)); - } - - absl::StatusOr< - std::unique_ptr<distributed_point_functions::DistributedPointFunction>> - status_or_dpf = distributed_point_functions::DistributedPointFunction:: - CreateIncremental(parameters); - - size_t num_levels = parameters.size(); - - if (!status_or_dpf.ok()) { - // `log_domain_size` is expected to be in ascending order and - // `element_bitsize` is expected to be non-decreasing. As it is hard for the - // fuzzer to land upon a valid input, we sort the parameters and try again - // if the construction fails. - std::sort(log_domain_sizes.begin(), log_domain_sizes.end()); - std::sort(element_bitsizes.begin(), element_bitsizes.end()); - for (size_t i = 0; i < num_levels; ++i) { - parameters[i].set_log_domain_size(log_domain_sizes[i]); - parameters[i].mutable_value_type()->mutable_integer()->set_bitsize( - element_bitsizes[i]); - } - - status_or_dpf = distributed_point_functions::DistributedPointFunction:: - CreateIncremental(parameters); - } - - if (!status_or_dpf.ok()) - return 0; - - std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf = - std::move(status_or_dpf).value(); - - absl::StatusOr<std::pair<distributed_point_functions::DpfKey, - distributed_point_functions::DpfKey>> - status_or_keys = dpf->GenerateKeysIncremental(alpha, beta); - if (!status_or_keys.ok()) - return 0; - - std::pair<distributed_point_functions::DpfKey, - distributed_point_functions::DpfKey> - keys = std::move(status_or_keys).value(); - - // Adapted from `DpfEvaluationTest.TestCorrectness()`. - absl::StatusOr<distributed_point_functions::EvaluationContext> - status_or_ctx0 = dpf->CreateEvaluationContext(keys.first); - DPF_FUZZER_ASSERT(status_or_ctx0.ok()); - - absl::StatusOr<distributed_point_functions::EvaluationContext> - status_or_ctx1 = dpf->CreateEvaluationContext(keys.second); - DPF_FUZZER_ASSERT(status_or_ctx1.ok()); - - distributed_point_functions::EvaluationContext ctx0 = - std::move(status_or_ctx0).value(); - distributed_point_functions::EvaluationContext ctx1 = - std::move(status_or_ctx1).value(); - - // Generate evaluation points. - FuzzedDataProvider data_provider2(data2, size2); - if (data_provider2.remaining_bytes() < sizeof(int)) - return 0; - - int level_step = data_provider2.ConsumeIntegralInRange<int>(1, 10); - - std::vector<absl::uint128> evaluation_points; - while (data_provider2.remaining_bytes() >= UINT128_SIZE) { - evaluation_points.push_back(ConsumeUint128(data_provider2)); - if (parameters.back().log_domain_size() < 128) - evaluation_points.back() %= - (absl::uint128{1} << parameters.back().log_domain_size()); - } - - // Always evaluate on alpha. - evaluation_points.push_back(alpha); - - int32_t previous_log_domain_size = 0; - for (int i = level_step - 1; i < static_cast<int>(num_levels); - i += level_step) { - // If any gap in the log_domain_sizes used in successive evaluations is - // larger than 62, validation will fail in `EvaluateAndCheckLevel`. - int32_t current_log_domain_size = parameters[i].log_domain_size(); - if (current_log_domain_size - previous_log_domain_size > 62) - return 0; - previous_log_domain_size = current_log_domain_size; - - switch (parameters[i].value_type().integer().bitsize()) { - case 8: - EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, alpha, beta, ctx0, - ctx1, parameters, *dpf); - break; - case 16: - EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, alpha, beta, ctx0, - ctx1, parameters, *dpf); - break; - case 32: - EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, alpha, beta, ctx0, - ctx1, parameters, *dpf); - break; - case 64: - EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, alpha, beta, ctx0, - ctx1, parameters, *dpf); - break; - case 128: - EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, alpha, beta, - ctx0, ctx1, parameters, *dpf); - break; - default: - // DPF construction should've failed if the parameters were invalid. - DPF_FUZZER_ASSERT(false); - break; - } - } - - return 0; -} diff --git a/third_party/distributed_point_functions/shim/BUILD.gn b/third_party/distributed_point_functions/shim/BUILD.gn deleted file mode 100644 index f2ce4cd72cc51..0000000000000 --- a/third_party/distributed_point_functions/shim/BUILD.gn +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 The Chromium Authors -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -import("//build/buildflag_header.gni") -import("//testing/test.gni") -import("//third_party/distributed_point_functions/features.gni") - -source_set("shim") { - public_deps = [ ":buildflags" ] - - if (use_distributed_point_functions) { - sources = [ - "distributed_point_function_shim.cc", - "distributed_point_function_shim.h", - ] - deps = [ - "$dpf_abseil_cpp_dir:absl", - "//base", - "//third_party/distributed_point_functions:internal", - ] - public_deps += [ "//third_party/distributed_point_functions:proto" ] - configs += [ "//third_party/distributed_point_functions:includes" ] - } -} - -# External targets may depend on :buildflags directly without pulling in -# :distributed_point_functions. For instance, tests may set different -# expectations when the dpf library is omitted from the build. -buildflag_header("buildflags") { - header = "buildflags.h" - flags = [ "USE_DISTRIBUTED_POINT_FUNCTIONS=$use_distributed_point_functions" ] -} - -test("distributed_point_functions_shim_unittests") { - deps = [ - "//testing/gtest", - "//testing/gtest:gtest_main", - ] - if (use_distributed_point_functions) { - sources = [ "distributed_point_function_shim_unittest.cc" ] - deps += [ - ":shim", - "$dpf_abseil_cpp_dir:absl", - "//third_party/protobuf:protobuf_lite", - ] - } -} diff --git a/third_party/distributed_point_functions/shim/DEPS b/third_party/distributed_point_functions/shim/DEPS deleted file mode 100644 index 5cd0867c848de..0000000000000 --- a/third_party/distributed_point_functions/shim/DEPS +++ /dev/null @@ -1,3 +0,0 @@ -include_rules = [ - "+base", -] diff --git a/third_party/distributed_point_functions/shim/distributed_point_function_shim.cc b/third_party/distributed_point_functions/shim/distributed_point_function_shim.cc deleted file mode 100644 index 4cbf927cddf66..0000000000000 --- a/third_party/distributed_point_functions/shim/distributed_point_function_shim.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2023 The Chromium Authors -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <memory> -#include <optional> -#include <utility> -#include <vector> - -#include "base/check_op.h" -#include "base/logging.h" -#include "third_party/abseil-cpp/absl/numeric/int128.h" -#include "third_party/abseil-cpp/absl/status/status.h" -#include "third_party/abseil-cpp/absl/status/statusor.h" -#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h" -#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h" -#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h" - -namespace distributed_point_functions { -std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - std::vector<DpfParameters> parameters, - absl::uint128 alpha, - std::vector<absl::uint128> beta) { - // absl::StatusOr is not allowed in the codebase, but this minimal usage is - // necessary to interact with //third_party/distributed_point_functions/. - absl::StatusOr<std::unique_ptr<DistributedPointFunction>> dpf_result = - DistributedPointFunction::CreateIncremental(std::move(parameters)); - if (!dpf_result.ok()) { - LOG(ERROR) << "CreateIncremental() failed: " << dpf_result.status(); - return std::nullopt; - } - CHECK_NE(*dpf_result, nullptr); - - absl::StatusOr<std::pair<DpfKey, DpfKey>> keys_result = - (*dpf_result)->GenerateKeysIncremental(alpha, std::move(beta)); - if (!keys_result.ok()) { - LOG(ERROR) << "GenerateKeysIncremental() failed: " << keys_result.status(); - return std::nullopt; - } - return std::move(*keys_result); -} -} // namespace distributed_point_functions diff --git a/third_party/distributed_point_functions/shim/distributed_point_function_shim.h b/third_party/distributed_point_functions/shim/distributed_point_function_shim.h deleted file mode 100644 index 9165d9c08beb3..0000000000000 --- a/third_party/distributed_point_functions/shim/distributed_point_function_shim.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2023 The Chromium Authors -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_ -#define CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_ - -#include "third_party/distributed_point_functions/shim/buildflags.h" - -static_assert(BUILDFLAG(USE_DISTRIBUTED_POINT_FUNCTIONS), - "This header must not be included when " - "distributed_point_functions is omitted from the build"); - -#include <optional> -#include <utility> -#include <vector> - -#include "third_party/abseil-cpp/absl/numeric/int128.h" -#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h" - -namespace distributed_point_functions { - -// Generates a pair of keys for a DPF that evaluates to `beta` when given -// `alpha`. On failure, returns std::nullopt. -std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental( - std::vector<DpfParameters> parameters, - absl::uint128 alpha, - std::vector<absl::uint128> beta); - -} // namespace distributed_point_functions - -#endif // CONTENT_BROWSER_AGGREGATION_SERVICE_DISTRIBUTED_POINT_FUNCTION_SHIM_H_ diff --git a/third_party/distributed_point_functions/shim/distributed_point_function_shim_unittest.cc b/third_party/distributed_point_functions/shim/distributed_point_function_shim_unittest.cc deleted file mode 100644 index d39dcb9911b5c..0000000000000 --- a/third_party/distributed_point_functions/shim/distributed_point_function_shim_unittest.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2023 The Chromium Authors -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <stddef.h> - -#include <optional> -#include <utility> -#include <vector> - -#include "testing/gtest/include/gtest/gtest.h" -#include "third_party/abseil-cpp/absl/numeric/int128.h" -#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h" -#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h" - -namespace distributed_point_functions { - -// The shim's GenerateKeysIncremental() can return a value besides std::nullopt. -TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalConstructsKeys) { - constexpr size_t kBitLength = 32; - std::vector<DpfParameters> parameters(kBitLength); - for (size_t i = 0; i < parameters.size(); ++i) { - parameters[i].set_log_domain_size(i + 1); - parameters[i].mutable_value_type()->mutable_integer()->set_bitsize( - parameters.size()); - } - std::optional<std::pair<DpfKey, DpfKey>> maybe_dpf_keys = - GenerateKeysIncremental( - std::move(parameters), - /*alpha=*/absl::uint128{1}, - /*beta=*/std::vector<absl::uint128>(kBitLength, absl::uint128{1})); - EXPECT_TRUE(maybe_dpf_keys.has_value()); -} - -// When DistributedPointFunction::CreateIncremental() fails, the shim's -// GenerateKeysIncremental() should return std::nullopt. -TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalEmptyParameters) { - EXPECT_FALSE(GenerateKeysIncremental(/*parameters=*/{}, - /*alpha=*/absl::uint128{}, /*beta=*/{})); -} - -// When the length of beta does not match the number of parameters, the internal -// call to DistributedPointFunction::GenerateKeysIncremental() will fail, and -// the shim's GenerateKeysIncremental() should return std::nullopt. -TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalBetaWrongSize) { - std::vector<DpfParameters> parameters(3); - EXPECT_FALSE( - GenerateKeysIncremental(/*parameters=*/std::vector<DpfParameters>(3), - /*alpha=*/absl::uint128{}, /*beta=*/{1, 2, 3})); -} - -} // namespace distributed_point_functions diff --git a/third_party/highway/OWNERS b/third_party/highway/OWNERS index 3216b486f8d99..cf40f20ccc679 100644 --- a/third_party/highway/OWNERS +++ b/third_party/highway/OWNERS @@ -1 +1,2 @@ -file://third_party/distributed_point_functions/OWNERS \ No newline at end of file +bikineev@chromium.org +file://third_party/blink/renderer/core/html/parser/OWNERS