0

Remove Distributed Point Functions library

This third-party library is no longer integrated with any Chromium code
following https://crrev.com/c/6505925, so we can remove the library.
Note that this also removes the associated fuzzer.

We can't also remove the third_party/highway library, as it is now
used in
//third_party/blink/renderer/core/html/parser/html_document_parser_fastpath.cc.
We update the OWNERS to align with that new usage.

Bug: 40178420
Change-Id: I879d718fbd83d6069adef18801de0d1c13ab3037
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6507840
Reviewed-by: Anton Bikineev <bikineev@chromium.org>
Reviewed-by: Andrew Grieve <agrieve@chromium.org>
Reviewed-by: Nan Lin <linnan@chromium.org>
Commit-Queue: Alex Turner <alexmt@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1456337}
This commit is contained in:
Alex Turner
2025-05-06 08:02:52 -07:00
committed by Chromium LUCI CQ
parent 23f70fe18f
commit a1a7802b3b
66 changed files with 2 additions and 11526 deletions

@ -114,7 +114,6 @@ group("gn_all") {
"//third_party/angle/src/tests:angle_end2end_tests", "//third_party/angle/src/tests:angle_end2end_tests",
"//third_party/angle/src/tests:angle_unittests", "//third_party/angle/src/tests:angle_unittests",
"//third_party/angle/src/tests:angle_white_box_tests", "//third_party/angle/src/tests:angle_white_box_tests",
"//third_party/distributed_point_functions/shim:distributed_point_functions_shim_unittests",
"//third_party/flatbuffers:flatbuffers_unittests", "//third_party/flatbuffers:flatbuffers_unittests",
"//third_party/highway:highway_tests", "//third_party/highway:highway_tests",
"//third_party/liburlpattern:liburlpattern_unittests", "//third_party/liburlpattern:liburlpattern_unittests",

@ -289,7 +289,6 @@ source_set("browser") {
"//third_party/blink/public/strings", "//third_party/blink/public/strings",
"//third_party/boringssl", "//third_party/boringssl",
"//third_party/brotli:dec", "//third_party/brotli:dec",
"//third_party/distributed_point_functions",
"//third_party/icu", "//third_party/icu",
"//third_party/inspector_protocol:crdtp", "//third_party/inspector_protocol:crdtp",
"//third_party/libyuv", "//third_party/libyuv",

@ -3252,7 +3252,6 @@ test("content_unittests") {
"//third_party/blink/public:test_support", "//third_party/blink/public:test_support",
"//third_party/blink/public/common:font_enumeration_table_proto", "//third_party/blink/public/common:font_enumeration_table_proto",
"//third_party/blink/public/common:headers", "//third_party/blink/public/common:headers",
"//third_party/distributed_point_functions/shim:buildflags",
"//third_party/icu", "//third_party/icu",
"//third_party/inspector_protocol:crdtp", "//third_party/inspector_protocol:crdtp",
"//third_party/inspector_protocol:crdtp_test", "//third_party/inspector_protocol:crdtp_test",

@ -481,7 +481,6 @@ third_party/crashpad/crashpad/third_party/linux 1 1
third_party/crashpad/crashpad/third_party/ninja 1 1 third_party/crashpad/crashpad/third_party/ninja 1 1
third_party/crashpad/crashpad/util/misc 1 1 third_party/crashpad/crashpad/util/misc 1 1
third_party/dav1d 2 2 third_party/dav1d 2 2
third_party/distributed_point_functions/code 2 1
third_party/expat 2 2 third_party/expat 2 2
third_party/fdlibm 1 1 third_party/fdlibm 1 1
third_party/fusejs/dist 3 1 third_party/fusejs/dist 3 1

@ -1,81 +0,0 @@
# Copyright 2021 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import("//testing/libfuzzer/fuzzer_test.gni")
import("//third_party/distributed_point_functions/features.gni")
import("//third_party/protobuf/proto_library.gni")
# This is Chromium's interface with the third-party distributed_point_functions
# library. Targets outside of //third_party/distributed_point_functions should
# depend on this target rather than using the source directly. This extra layer
# prevents macros from leaking into Chromium code via header includes.
source_set("distributed_point_functions") {
# This target lists no sources of its own; it only forwards the shim target
# to its dependents.
public_deps = [ "//third_party/distributed_point_functions/shim" ]
}
# Compiles the upstream DPF protobuf definition.
proto_library("proto") {
sources = [ "code/dpf/distributed_point_function.proto" ]
# NOTE(review): presumably the directory the generated files are written
# under, so they can be included with a "dpf/" prefix — confirm against
# proto_library.gni.
proto_out_dir = "third_party/distributed_point_functions/dpf"
# NOTE(review): "lite" presumably selects the protobuf-lite C++ runtime,
# matching the protobuf_lite dependency of :internal — confirm.
cc_generator_options = "lite"
}
# Fuzzer for the DPF library, built from fuzz/dpf_fuzzer.cc and depending on
# the :internal library target.
fuzzer_test("dpf_fuzzer") {
sources = [ "fuzz/dpf_fuzzer.cc" ]
deps = [ ":internal" ]
# Do not apply Chromium code rules to this third-party code.
suppressed_configs = [ "//build/config/compiler:chromium_code" ]
additional_configs = [ "//build/config/compiler:no_chromium_code" ]
# Also apply the include directories declared by the :includes config.
additional_configs += [ ":includes" ]
}
# Targets below this line are only visible within this file and shim/.
# ":*" matches every target declared in this file; the second pattern matches
# every target declared in the shim directory.
visibility = [
":*",
"//third_party/distributed_point_functions/shim:*",
]
# Include directories applied to both the :internal library and the
# dpf_fuzzer target.
config("includes") {
include_dirs = [
# Directory holding the library sources listed in :internal.
"code",
# NOTE(review): presumably needed so generated headers (e.g. from :proto)
# can be found — confirm.
"$target_gen_dir",
]
}
# The DPF library itself: the sources under code/dpf, built without Chromium
# code rules. Per the file-level visibility set above, only targets in this
# file and in shim/ may depend on it.
source_set("internal") {
sources = [
"code/dpf/aes_128_fixed_key_hash.cc",
"code/dpf/aes_128_fixed_key_hash.h",
"code/dpf/distributed_point_function.cc",
"code/dpf/distributed_point_function.h",
"code/dpf/int_mod_n.cc",
"code/dpf/int_mod_n.h",
"code/dpf/internal/evaluate_prg_hwy.cc",
"code/dpf/internal/evaluate_prg_hwy.h",
"code/dpf/internal/get_hwy_mode.cc",
"code/dpf/internal/get_hwy_mode.h",
"code/dpf/internal/proto_validator.cc",
"code/dpf/internal/proto_validator.h",
"code/dpf/internal/value_type_helpers.cc",
"code/dpf/internal/value_type_helpers.h",
"code/dpf/status_macros.h",
"code/dpf/tuple.h",
"code/dpf/xor_wrapper.h",
]
# Dependencies are public so that they are forwarded to targets depending on
# this one.
# NOTE(review): $dpf_abseil_cpp_dir and $dpf_highway_cpp_dir are not defined
# in this file; presumably they come from the imported features.gni —
# confirm.
public_deps = [
":proto",
"$dpf_abseil_cpp_dir:absl",
"$dpf_highway_cpp_dir:libhwy",
"//third_party/boringssl",
"//third_party/protobuf:protobuf_lite",
]
# Do not apply Chromium code rules to this third-party code.
configs -= [ "//build/config/compiler:chromium_code" ]
configs += [ "//build/config/compiler:no_chromium_code" ]
# Add the include directories declared by the :includes config.
configs += [ ":includes" ]
}

@ -1,11 +0,0 @@
# Include-path rules for this directory.
# NOTE(review): each "+prefix" entry presumably permits #include paths that
# begin with that prefix (Chromium checkdeps convention) — confirm.
include_rules = [
"+absl",
"+benchmark",
"+dpf",
"+gmock",
"+google/protobuf",
"+gtest",
"+testing",
"+hwy",
"+openssl",
]

@ -1,6 +0,0 @@
monorail: {
component: "Internals>AttributionReporting"
}
buganizer_public: {
component_id: 1456103
}

@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -1,3 +0,0 @@
alexmt@chromium.org
csharrison@chromium.org
linnan@chromium.org

@ -1,25 +0,0 @@
Name: The Incremental Distributed Point Functions library
Short Name: distributed_point_functions
URL: https://github.com/google/distributed_point_functions
Version: N/A
Revision: 2db593b64a99f178f682ef0db222d417c23e5bb5
Date: 2023-11-16
License: Apache-2.0
License File: LICENSE
Security Critical: Yes
Shipped: yes
CPEPrefix: unknown
Description:
This library contains an implementation of incremental distributed point
functions, based on the paper by Boneh et al.
Local Modifications:
The directory code/ is a copy of the source code, modified in three ways. First,
all top-level directories other than dpf/ have been removed as they are unused.
Second, a .clang-format file has been added to disable automatic code
formatting. Parts of code/dpf/distributed_point_function_test.cc are also
adapted for fuzzing in fuzz/dpf_fuzzer.cc.
Third, a missing absl/strings/str_cat.h include was backported from revision
c662ca975068bfa884cc4a96f3a1db40a7611e5e to fix a build error when compiled with
the latest version of abseil.

@ -1 +0,0 @@
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17

@ -1 +0,0 @@
DisableFormat: true

@ -1 +0,0 @@
experiments/data/* filter=lfs diff=lfs merge=lfs -text

@ -1,2 +0,0 @@
# Bazel generated symlinks
bazel-*

@ -1,9 +0,0 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
package(
default_visibility = [":allowlist"],
)
licenses(["notice"])
exports_files(["LICENSE"])

@ -1,93 +0,0 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of
experience, education, socio-economic status, nationality, personal appearance,
race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, or to ban temporarily or permanently any
contributor for other behaviors that they deem inappropriate, threatening,
offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when the Project
Steward has a reasonable belief that an individual's behavior may have a
negative impact on the project or its community.
## Conflict Resolution
We do not believe that all conflict is bad; healthy debate and disagreement
often yield positive results. However, it is never okay to be disrespectful or
to engage in behavior that violates the project's code of conduct.
If you see someone violating the code of conduct, you are encouraged to address
the behavior directly with those involved. Many issues can be resolved quickly
and easily, and this gives people more control over the outcome of their
dispute. If you are unable to resolve the matter for any reason, or if the
behavior is threatening or harassing, report it. We are dedicated to providing
an environment where participants feel welcome and safe.
Reports should be directed to *[PROJECT STEWARD NAME(s) AND EMAIL(s)]*, the
Project Steward(s) for *[PROJECT NAME]*. It is the Project Steward's duty to
receive and address reported violations of the code of conduct. They will then
work with a committee consisting of representatives from the Open Source
Programs Office and the Google Open Source Strategy team. If for any reason you
are uncomfortable reaching out to the Project Steward, please email
opensource@google.com.
We will investigate every complaint, but you may not receive a direct response.
We will use our discretion in determining when and how to follow up on reported
incidents, which may range from not taking action to permanent expulsion from
the project and project-sponsored spaces. We will notify the accused of the
report and provide them an opportunity to discuss it before any action is taken.
The identity of the reporter will be omitted from the details of the report
supplied to the accused. In potentially harmful situations, such as ongoing
harassment or threats to anyone's safety, we may take action without notice.
## Attribution
This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
available at
https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

@ -1,29 +0,0 @@
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).

@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -1,48 +0,0 @@
# An Implementation of Incremental Distributed Point Functions in C++ [![Build status](https://badge.buildkite.com/64bb7c0fcc8c11d630517356b2c3932d7e14850801a5f22c48.svg?branch=master)](https://buildkite.com/bazel/google-distributed-point-functions)
This library contains an implementation of incremental distributed point
functions, based on the following paper:
> Boneh, D., Boyle, E., Corrigan-Gibbs, H., Gilboa, N., & Ishai, Y. (2020).
> Lightweight Techniques for Private Heavy Hitters. arXiv preprint
> arXiv:2012.14884. https://arxiv.org/abs/2012.14884
## About Incremental Distributed Point Functions
A distributed point function (DPF) is parameterized by an index `alpha` and a
value `beta`. It consists of two algorithms: key generation and evaluation.
The key generation procedure produces two keys `k_a` and `k_b`, given `alpha`
and `beta`. Evaluating each key on any point `x` in the DPF domain results in an
additive secret share of `beta`, if `x == alpha`, and a share of 0 otherwise.
Incremental DPFs additionally can be evaluated on prefixes of the index domain.
More precisely, an incremental DPF is parameterized by a hierarchy of index
domains, each a power of two larger than the previous. Key generation now takes
a vector `beta`, one value `beta[i]` for each hierarchy level.
When evaluated on the `b`-bit prefix of `alpha`, where `b` is the log domain
size of the `i`-th hierarchy level, the incremental DPF returns a secret share
of `beta[i]`; when evaluated on any other `b`-bit prefix, it returns a share of 0.
For more details, see the above paper, as well as the
[`DistributedPointFunction` class documentation](dpf/distributed_point_function.h).
## Building/Running Tests
This repository requires Bazel. You can install Bazel by
following the instructions for your platform on the
[Bazel website](https://docs.bazel.build/versions/master/install.html).
Once you have installed Bazel you can clone this repository and run all tests
that are included by navigating into the root folder and running:
```bash
bazel test //...
```
## Security
To report a security issue, please read [SECURITY.md](SECURITY.md).
## Disclaimer
This is not an officially supported Google product. The code is provided as-is,
with no guarantees of correctness or security.

@ -1,5 +0,0 @@
# Security
To report a security issue, please use http://g.co/vulnz. We use
http://g.co/vulnz for our intake, and do coordination and disclosure here on
GitHub (including using GitHub Security Advisory). The Google Security Team will
respond within 5 working days of your report on g.co/vulnz.

@ -1,170 +0,0 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# rules_proto defines abstract rules for building Protocol Buffers.
# https://github.com/bazelbuild/rules_proto
http_archive(
name = "rules_proto",
sha256 = "0daa4fc5b2b820705fcbf239557515f9ab809be45a1e7c6dfaa1d465d5c615d4",
strip_prefix = "rules_proto-3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd",
urls = [
"https://github.com/bazelbuild/rules_proto/archive/3f1ab99b718e3e7dd86ebdc49c580aa6a126b1cd.zip",
],
)
load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
rules_proto_dependencies()
rules_proto_toolchains()
# rules_cc defines rules for generating C++ code from Protocol Buffers.
# https://github.com/bazelbuild/rules_cc
http_archive(
name = "rules_cc",
sha256 = "e17cca44563e0918a36a8ea2a50acb99ea9ad726bbd3cad8ba95a643a40121ab",
strip_prefix = "rules_cc-d7c11265cb157c9b962d87d9ab67b8c24e3a875f",
urls = [
"https://github.com/bazelbuild/rules_cc/archive/d7c11265cb157c9b962d87d9ab67b8c24e3a875f.zip",
],
)
load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")
rules_cc_dependencies()
# io_bazel_rules_go defines rules for building Go code (including Go protobuf libraries).
# https://github.com/bazelbuild/rules_go
http_archive(
name = "io_bazel_rules_go",
sha256 = "7c35e8515012279ef7bcbc39c4ef4b54a86756d853848cb621b7da49f156c82f",
strip_prefix = "rules_go-b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2",
urls = [
"https://github.com/bazelbuild/rules_go/archive/b397ab7ace3c4131f48b5f4d4d7e7e9e6809e0d2.zip",
],
)
load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
go_rules_dependencies()
go_register_toolchains(version = "1.19.3")
# Install gtest.
# https://github.com/google/googletest
http_archive(
name = "com_github_google_googletest",
sha256 = "3e91944af2d909a79f18ee9760765624810146ccfae8f1a8f990037a1677d44b",
strip_prefix = "googletest-ac7a126f39d5bcd909b78c9e69900c76659b1bbb",
urls = [
"https://github.com/google/googletest/archive/ac7a126f39d5bcd909b78c9e69900c76659b1bbb.zip",
],
)
# abseil-cpp
# https://github.com/abseil/abseil-cpp
http_archive(
name = "com_google_absl",
sha256 = "431c0c47217c36106f90e2ca4fcdf45af618ea21adde880804661b1ecb240056",
strip_prefix = "abseil-cpp-1fb3830b1cf685999bb2bbd0294be0a53c9440a6",
urls = [
"https://github.com/abseil/abseil-cpp/archive/1fb3830b1cf685999bb2bbd0294be0a53c9440a6.zip",
],
)
# BoringSSL
# https://github.com/google/boringssl
http_archive(
name = "boringssl",
sha256 = "88e4330f4f65ebfdf24847e4807c25f3eacfd5bf1a93f6629d3941196ff9b0b3",
strip_prefix = "boringssl-6347808f2a480a3792148bf7732232229db9b909",
urls = [
"https://github.com/google/boringssl/archive/6347808f2a480a3792148bf7732232229db9b909.zip",
],
)
# Benchmarks
# https://github.com/google/benchmark
http_archive(
name = "com_github_google_benchmark",
sha256 = "5f98b44165f3250f1d749b728018318d654f763ea0f4d7ea156e10e6e0cc678a",
strip_prefix = "benchmark-5e78bedfb07c615edb2b646d1e354980268c1728",
urls = [
"https://github.com/google/benchmark/archive/5e78bedfb07c615edb2b646d1e354980268c1728.zip",
],
)
# gflags needed for glog.
# https://github.com/gflags/gflags
http_archive(
name = "com_github_gflags_gflags",
sha256 = "017e0a91531bfc45be9eaf07e4d8fed33c488b90b58509dbd2e33a33b2648ae6",
strip_prefix = "gflags-a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd",
urls = [
"https://github.com/gflags/gflags/archive/a738fdf9338412f83ab3f26f31ac11ed3f3ec4bd.zip",
],
)
# glog for logging
# https://github.com/google/glog
http_archive(
name = "com_github_google_glog",
sha256 = "0f91ee6cc1edc3b1c53a286382e69a37e5d172ce208b7e5b305be8770d8c21b1",
strip_prefix = "glog-f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481",
urls = [
"https://github.com/google/glog/archive/f545ff5e7d7f3df95f6e86c8cb987d9d9d4bd481.zip",
],
)
# IREE for cc_embed_data.
# https://github.com/google/iree
http_archive(
name = "com_github_google_iree",
sha256 = "aa369b29a5c45ae9d7aa8bf49ea1308221d1711277222f0755df6e0a575f6879",
strip_prefix = "iree-7e6012468cbaafaaf30302748a2943771b40e2c3",
urls = [
"https://github.com/google/iree/archive/7e6012468cbaafaaf30302748a2943771b40e2c3.zip",
],
)
# riegeli for file IO
# https://github.com/google/riegeli
http_archive(
name = "com_github_google_riegeli",
sha256 = "3de21a222271a1e2c5d728e7f46b63ab4520da829c09ef9727a322e693c9ac18",
strip_prefix = "riegeli-43b7ef9f995469609b6ab07f6becc82186314bfb",
urls = [
"https://github.com/google/riegeli/archive/43b7ef9f995469609b6ab07f6becc82186314bfb.zip",
],
)
# rules_license needed for Highway
# https://github.com/bazelbuild/rules_license
http_archive(
name = "rules_license",
sha256 = "6157e1e68378532d0241ecd15d3c45f6e5cfd98fc10846045509fb2a7cc9e381",
urls = [
"https://github.com/bazelbuild/rules_license/releases/download/0.0.4/rules_license-0.0.4.tar.gz",
],
)
# Highway for SIMD operations.
# https://github.com/google/highway
http_archive(
name = "com_github_google_highway",
sha256 = "cdba0eb21796598dd50fa0a4aa3651fa466c0d37c39d149ee383f725434e4314",
strip_prefix = "highway-45c98184ab7f81cf592c07633070b75fced14a52",
urls = [
"https://github.com/google/highway/archive/45c98184ab7f81cf592c07633070b75fced14a52.zip",
],
)
# cppitertools: itertools-style range/iteration helpers for C++.
# https://github.com/ryanhaining/cppitertools
http_archive(
name = "com_github_ryanhaining_cppitertools",
sha256 = "1608ddbe3c12b0c6e653b992ff63b5dceab9af5347ad93be8714d05e5dc17afb",
strip_prefix = "cppitertools-add5acc932dea2c78acd80747bab71ec0b5bce27",
urls = [
"https://github.com/ryanhaining/cppitertools/archive/add5acc932dea2c78acd80747bab71ec0b5bce27.zip",
],
)

@ -1,235 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_proto//proto:defs.bzl", "proto_library")
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
cc_library(
name = "int_mod_n",
srcs = ["int_mod_n.cc"],
hdrs = ["int_mod_n.h"],
deps = [
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
cc_test(
name = "int_mod_n_test",
srcs = ["int_mod_n_test.cc"],
deps = [
":int_mod_n",
"//dpf/internal:status_matchers",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
cc_test(
name = "int_mod_n_benchmark",
srcs = ["int_mod_n_benchmark.cc"],
deps = [
":int_mod_n",
"@boringssl//:crypto",
"@com_github_google_benchmark//:benchmark",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "distributed_point_function",
srcs = ["distributed_point_function.cc"],
hdrs = ["distributed_point_function.h"],
deps = [
":aes_128_fixed_key_hash",
":distributed_point_function_cc_proto",
":status_macros",
"//dpf/internal:evaluate_prg_hwy",
"//dpf/internal:get_hwy_mode",
"//dpf/internal:maybe_deref_span",
"//dpf/internal:proto_validator",
"//dpf/internal:value_type_helpers",
"@boringssl//:crypto",
"@com_github_google_highway//:hwy",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
"@com_google_protobuf//:protobuf_lite",
],
)
cc_test(
name = "distributed_point_function_test",
size = "medium",
srcs = ["distributed_point_function_test.cc"],
deps = [
":distributed_point_function",
":distributed_point_function_cc_proto",
":xor_wrapper",
"//dpf/internal:proto_validator",
"//dpf/internal:status_matchers",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/random",
"@com_google_absl//absl/random:distributions",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/utility",
],
)
proto_library(
name = "distributed_point_function_proto",
srcs = ["distributed_point_function.proto"],
)
cc_proto_library(
name = "distributed_point_function_cc_proto",
deps = [":distributed_point_function_proto"],
)
go_proto_library(
name = "distributed_point_function_go_proto",
importpath = "github.com/google/distributed_point_functions/dpf/distributed_point_function_go_proto",
protos = [":distributed_point_function_proto"],
)
cc_test(
name = "distributed_point_function_benchmark",
srcs = [
"distributed_point_function_benchmark.cc",
],
tags = ["benchmark"],
deps = [
":distributed_point_function",
"@com_github_google_benchmark//:benchmark",
"@com_github_google_googletest//:gtest_main",
"@com_github_google_highway//:hwy",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/random",
"@com_google_absl//absl/random:distributions",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
],
)
cc_library(
name = "status_macros",
hdrs = ["status_macros.h"],
)
cc_library(
name = "aes_128_fixed_key_hash",
srcs = ["aes_128_fixed_key_hash.cc"],
hdrs = ["aes_128_fixed_key_hash.h"],
deps = [
"@boringssl//:crypto",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_test(
name = "aes_128_fixed_key_hash_test",
srcs = ["aes_128_fixed_key_hash_test.cc"],
deps = [
":aes_128_fixed_key_hash",
"//dpf/internal:status_matchers",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tuple",
hdrs = ["tuple.h"],
)
cc_test(
name = "tuple_test",
srcs = [
"tuple_test.cc",
],
deps = [
":tuple",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/numeric:int128",
],
)
cc_library(
name = "xor_wrapper",
hdrs = ["xor_wrapper.h"],
)
cc_test(
name = "xor_wrapper_test",
srcs = [
"xor_wrapper_test.cc",
],
deps = [
":xor_wrapper",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/numeric:int128",
],
)

@ -1,102 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dpf/aes_128_fixed_key_hash.h"
#include <stdint.h>
#include <algorithm>
#include <array>
#include <string>
#include <utility>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "openssl/err.h"
namespace distributed_point_functions {
Aes128FixedKeyHash::Aes128FixedKeyHash(
bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx, absl::uint128 key)
: cipher_ctx_(std::move(cipher_ctx)), key_(key) {}
// Builds an Aes128FixedKeyHash keyed with `key`.
//
// Returns INTERNAL if the cipher context cannot be allocated or if
// initializing it for AES-128 fails.
absl::StatusOr<Aes128FixedKeyHash> Aes128FixedKeyHash::Create(
    absl::uint128 key) {
  bssl::UniquePtr<EVP_CIPHER_CTX> ctx(EVP_CIPHER_CTX_new());
  if (ctx == nullptr) {
    return absl::InternalError("Failed to allocate AES context");
  }

  // ECB mode is chosen deliberately: the hash is evaluated on many independent
  // 128-bit blocks at once, and ECB lets a single cipher call process them all
  // in parallel. This batched evaluation must not be confused with encrypting
  // an array of data, for which ECB would be insecure.
  const auto* key_bytes = reinterpret_cast<const uint8_t*>(&key);
  const bool init_ok =
      EVP_EncryptInit_ex(ctx.get(), EVP_aes_128_ecb(), /*impl=*/nullptr,
                         key_bytes, /*iv=*/nullptr) == 1;
  if (!init_ok) {
    return absl::InternalError("Failed to set up AES context");
  }

  return Aes128FixedKeyHash(std::move(ctx), key);
}
// Hashes every block of `in` into the corresponding slot of `out`:
//
//   out[i] = AES(key, sigma(in[i])) ^ sigma(in[i]),
//   sigma(x) = (x.high64 ^ x.low64, x.high64).
//
// Returns INVALID_ARGUMENT if the two spans differ in length, and INTERNAL if
// the cipher call fails. Empty input is a no-op.
absl::Status Aes128FixedKeyHash::Evaluate(absl::Span<const absl::uint128> in,
                                          absl::Span<absl::uint128> out) const {
  if (out.size() != in.size()) {
    return absl::InvalidArgumentError("Input and output sizes don't match");
  }
  if (in.empty()) {
    // Nothing to hash.
    return absl::OkStatus();
  }

  const auto total_blocks = static_cast<int64_t>(in.size());
  // Scratch buffer holding sigma(in[...]) for the batch currently in flight.
  std::array<absl::uint128, kBatchSize> sigma;

  int64_t offset = 0;
  while (offset < total_blocks) {
    const int64_t count =
        std::min<int64_t>(total_blocks - offset, kBatchSize);

    // Apply the orthomorphism sigma to each input block of this batch.
    for (int64_t i = 0; i < count; ++i) {
      const uint64_t high = absl::Uint128High64(in[offset + i]);
      const uint64_t low = absl::Uint128Low64(in[offset + i]);
      sigma[i] = absl::MakeUint128(high ^ low, high);
    }

    // EVP_Cipher is used rather than EVP_EncryptUpdate because in ECB mode it
    // does not mutate the context, which keeps this const method thread-safe.
    auto* dst = reinterpret_cast<uint8_t*>(out.data() + offset);
    const auto* src = reinterpret_cast<const uint8_t*>(sigma.data());
    const auto num_bytes = static_cast<int>(count * sizeof(absl::uint128));
    if (EVP_Cipher(cipher_ctx_.get(), dst, src, num_bytes) != 1) {
      char buf[256];
      ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
      return absl::InternalError(
          absl::StrCat("AES encryption failed: ", std::string(buf)));
    }

    // Feed-forward: XOR the cipher output with the cipher input.
    for (int64_t i = 0; i < count; ++i) {
      out[offset + i] ^= sigma[i];
    }

    offset += kBatchSize;
  }
  return absl::OkStatus();
}
} // namespace distributed_point_functions

@ -1,86 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "openssl/cipher.h"
namespace distributed_point_functions {
// Aes128FixedKeyHash is a circular correlation-robust hash function based on
// AES. For key `key`, input `in` and output `out`, the hash function is defined
// as
//
// out[i] = AES.Encrypt(key, sigma(in[i])) ^ sigma(in[i]),
//
// where sigma(x) = (x.high64 ^ x.low64, x.high64). This is the
// circular correlation-robust MMO construction from
// https://eprint.iacr.org/2019/074.pdf (pp. 18-19). Note that unlike
// cryptographic hash functions such as SHA-256, this hash function is *not*
// compressing and is not designed to provide any security guarantees beyond
// circular correlation-robustness. Use with appropriate caution.
class Aes128FixedKeyHash {
public:
// Creates a new Aes128FixedKeyHash with the given `key`.
//
// Returns INTERNAL in case of allocation failures or OpenSSL errors.
static absl::StatusOr<Aes128FixedKeyHash> Create(absl::uint128 key);
// Computes hash values of each block in `in`, writing the output to `out`.
// `in` and `out` must have the same number of elements; empty spans are a
// no-op. It is safe for `in` and `out` to refer to the same memory (in-place
// evaluation).
// NOTE(review): inputs are processed in batches of `kBatchSize` blocks, so
// spans that only partially overlap and are longer than one batch do not look
// safe -- confirm before relying on that.
//
// Returns INVALID_ARGUMENT if sizes of `in` and `out` don't match, or INTERNAL
// in case of OpenSSL errors.
absl::Status Evaluate(absl::Span<const absl::uint128> in,
absl::Span<absl::uint128> out) const;
// Aes128FixedKeyHash is not copyable.
Aes128FixedKeyHash(const Aes128FixedKeyHash&) = delete;
Aes128FixedKeyHash& operator=(const Aes128FixedKeyHash&) = delete;
// Aes128FixedKeyHash is movable (it just wraps a bssl::UniquePtr).
Aes128FixedKeyHash(Aes128FixedKeyHash&&) = default;
Aes128FixedKeyHash& operator=(Aes128FixedKeyHash&&) = default;
// Returns the key used to construct this hash function.
// DO NOT SEND THIS TO ANY OTHER PARTY!
const absl::uint128& key() const { return key_; }
// The maximum number of AES blocks encrypted at once. Chosen to pipeline AES
// as much as possible, while still allowing both source and destination to
// comfortably fit in the L1 CPU cache.
static constexpr int kBatchSize = 64;
private:
// Called by `Create`. Takes ownership of `cipher_ctx`.
Aes128FixedKeyHash(bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx,
absl::uint128 key);
// The OpenSSL encryption context used by `Evaluate`.
bssl::UniquePtr<EVP_CIPHER_CTX> cipher_ctx_;
// The key used to construct this hash function.
absl::uint128 key_;
};
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_H_

@ -1,178 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dpf/aes_128_fixed_key_hash.h"
#include <thread> // NOLINT(build/c++11)
#include <vector>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "dpf/internal/status_matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace {
using dpf_internal::StatusIs;
// Fixed 128-bit blocks shared by the tests below: two hash keys, two input
// seeds, and two filler values used to pre-initialize output buffers.
constexpr absl::uint128 kKey0{
    absl::MakeUint128(0x0000000000000000, 0x0000000000000000)};
constexpr absl::uint128 kKey1{
    absl::MakeUint128(0x1111111111111111, 0x1111111111111111)};
constexpr absl::uint128 kSeed0{
    absl::MakeUint128(0x0123012301230123, 0x0123012301230123)};
constexpr absl::uint128 kSeed1{
    absl::MakeUint128(0x4567456745674567, 0x4567456745674567)};
constexpr absl::uint128 kSeed2{
    absl::MakeUint128(0x89ab89ab89ab89ab, 0x89ab89ab89ab89ab)};
constexpr absl::uint128 kSeed3{
    absl::MakeUint128(0xcdefcdefcdefcdef, 0xcdefcdefcdefcdef)};
// Constructing a hash from a 128-bit key must report success.
TEST(Aes128FixedKeyHashTest, CreateSucceeds) {
  const auto created = Aes128FixedKeyHash::Create(kKey0);
  DPF_EXPECT_OK(created);
}
// Two independently created hashes with the same key must map the same input
// to the same output.
TEST(Aes128FixedKeyHashTest, SameKeysAndSeedsGenerateSameOutput) {
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash first_hash,
                           Aes128FixedKeyHash::Create(kKey0));
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash second_hash,
                           Aes128FixedKeyHash::Create(kKey0));
  const std::vector<absl::uint128> input{kSeed0};

  // The buffers start out different, so equality afterwards can only come
  // from `Evaluate` overwriting them.
  std::vector<absl::uint128> first_out(input.size(), kSeed2);
  std::vector<absl::uint128> second_out(input.size(), kSeed3);

  DPF_EXPECT_OK(first_hash.Evaluate(input, absl::MakeSpan(first_out)));
  DPF_EXPECT_OK(second_hash.Evaluate(input, absl::MakeSpan(second_out)));

  EXPECT_THAT(first_out, testing::ElementsAreArray(second_out));
}
// Hashes created with different keys must map the same input to different
// outputs.
TEST(Aes128FixedKeyHashTest, DifferentKeysGenerateDifferentOutput) {
  const std::vector<absl::uint128> input{kSeed0};
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash_key0,
                           Aes128FixedKeyHash::Create(kKey0));
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash_key1,
                           Aes128FixedKeyHash::Create(kKey1));

  // The buffers start out identical, so a difference afterwards can only come
  // from `Evaluate`.
  std::vector<absl::uint128> out_key0(input.size(), kSeed2);
  std::vector<absl::uint128> out_key1(input.size(), kSeed2);

  DPF_EXPECT_OK(hash_key0.Evaluate(input, absl::MakeSpan(out_key0)));
  DPF_EXPECT_OK(hash_key1.Evaluate(input, absl::MakeSpan(out_key1)));

  EXPECT_THAT(out_key0, testing::Not(testing::ElementsAreArray(out_key1)));
}
// A single hash must map different inputs to different outputs.
TEST(Aes128FixedKeyHashTest, DifferentSeedsGenerateDifferentOutput) {
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash,
                           Aes128FixedKeyHash::Create(kKey0));
  const std::vector<absl::uint128> first_input{kSeed0};
  const std::vector<absl::uint128> second_input{kSeed1};

  // The buffers start out identical, so a difference afterwards can only come
  // from `Evaluate`.
  std::vector<absl::uint128> first_out(first_input.size(), kSeed2);
  std::vector<absl::uint128> second_out(second_input.size(), kSeed2);

  DPF_EXPECT_OK(hash.Evaluate(first_input, absl::MakeSpan(first_out)));
  DPF_EXPECT_OK(hash.Evaluate(second_input, absl::MakeSpan(second_out)));

  EXPECT_THAT(first_out, testing::Not(testing::ElementsAreArray(second_out)));
}
// Hashing two blocks in one call must give the same results as hashing each
// block in its own call.
TEST(Aes128FixedKeyHashTest, BatchedEvaluationEqualsBlockWiseEvaluation) {
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash,
                           Aes128FixedKeyHash::Create(kKey0));
  const std::vector<absl::uint128> only_seed0{kSeed0};
  const std::vector<absl::uint128> only_seed1{kSeed1};
  const std::vector<absl::uint128> both_seeds{kSeed0, kSeed1};

  std::vector<absl::uint128> out_seed0(only_seed0.size());
  std::vector<absl::uint128> out_seed1(only_seed1.size());
  std::vector<absl::uint128> out_both(both_seeds.size());

  DPF_EXPECT_OK(hash.Evaluate(only_seed0, absl::MakeSpan(out_seed0)));
  DPF_EXPECT_OK(hash.Evaluate(only_seed1, absl::MakeSpan(out_seed1)));
  DPF_EXPECT_OK(hash.Evaluate(both_seeds, absl::MakeSpan(out_both)));

  EXPECT_THAT(out_both, testing::ElementsAre(out_seed0[0], out_seed1[0]));
}
// Pins the exact outputs for two keys and two seeds, so that any change to
// the hash construction is detected.
TEST(Aes128FixedKeyHashTest, TestSpecificOutputValues) {
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash_key0,
                           Aes128FixedKeyHash::Create(kKey0));
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash_key1,
                           Aes128FixedKeyHash::Create(kKey1));
  const std::vector<absl::uint128> input{kSeed0, kSeed1};
  std::vector<absl::uint128> out_key0(input.size());
  std::vector<absl::uint128> out_key1(input.size());

  DPF_EXPECT_OK(hash_key0.Evaluate(input, absl::MakeSpan(out_key0)));
  DPF_EXPECT_OK(hash_key1.Evaluate(input, absl::MakeSpan(out_key1)));

  EXPECT_THAT(out_key0,
              testing::ElementsAre(
                  absl::MakeUint128(0x73c2dc14812be4ef, 0xeac64d09c8adf8ed),
                  absl::MakeUint128(0xb8f33653a53a8436, 0xaedf39b62de91d95)));
  EXPECT_THAT(out_key1,
              testing::ElementsAre(
                  absl::MakeUint128(0x934704aff58fa233, 0xd3c20d1b9cc18d8f),
                  absl::MakeUint128(0x530098817046d284, 0x43e61d3273a04f7c)));
}
// `Evaluate` must reject an output span whose length differs from the input.
TEST(Aes128FixedKeyHashTest, EvaluateFailsWhenSizesDontMatch) {
  const std::vector<absl::uint128> input{kSeed0};
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash,
                           Aes128FixedKeyHash::Create(kKey0));
  // One element more than the input.
  std::vector<absl::uint128> oversized_out(input.size() + 1);

  const absl::Status status =
      hash.Evaluate(input, absl::MakeSpan(oversized_out));

  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Input and output sizes don't match"));
}
// Calls `Evaluate` on one shared hash object from many threads at once; every
// call must succeed.
TEST(Aes128FixedKeyHashTest, TestThreadSafety) {
  constexpr int kNumThreads = 1024;
  const std::vector<absl::uint128> input{kSeed0};
  DPF_ASSERT_OK_AND_ASSIGN(Aes128FixedKeyHash hash,
                           Aes128FixedKeyHash::Create(kKey0));

  std::vector<std::thread> workers;
  workers.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; ++t) {
    // Each worker writes into its own local output block, so the hash object
    // and the read-only input are the only state shared between threads.
    workers.emplace_back([&hash, &input]() {
      absl::uint128 result;
      DPF_ASSERT_OK(hash.Evaluate(input, absl::MakeSpan(&result, 1)));
    });
  }

  for (std::thread& worker : workers) {
    worker.join();
  }
}
} // namespace
} // namespace distributed_point_functions

@ -1,732 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/distributed_point_function.h"
#include <algorithm>
#include <array>
#include <cstddef>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "dpf/internal/evaluate_prg_hwy.h"
#include "dpf/internal/get_hwy_mode.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/value_type_helpers.h"
#include "dpf/status_macros.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "hwy/aligned_allocator.h"
#include "openssl/rand.h"
namespace distributed_point_functions {
namespace {

// Fixed AES keys for the three PRGs. `kPrgKeyLeft` and `kPrgKeyRight` expand
// seeds when computing seed correction words; `kPrgKeyValue` is used for the
// correction words of the incremental DPF values. Each key is the first half
// of the SHA256 sum of the constant's qualified name, e.g.
// `echo "DistributedPointFunction::kPrgKeyLeft" | sha256sum`.
constexpr absl::uint128 kPrgKeyLeft{
    absl::MakeUint128(0x5be037ccf6a03de5ULL, 0x935f08d0a5b6a2fdULL)};
constexpr absl::uint128 kPrgKeyRight{
    absl::MakeUint128(0xef94b6aedebb026cULL, 0xe2ea1fe0f66f4d0bULL)};
constexpr absl::uint128 kPrgKeyValue{
    absl::MakeUint128(0x05a5d1588c5423e3ULL, 0x46a31101b21d1c98ULL)};

}  // namespace
// Constructs a DistributedPointFunction from already-built components.
//
// Every argument is taken by value and moved into the corresponding member.
// The parameter list and the tree/hierarchy level mappings are read from the
// validator after it has been stored in `proto_validator_`.
DistributedPointFunction::DistributedPointFunction(
    std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
    std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
    Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
    absl::flat_hash_map<std::string, ValueCorrectionFunction>
        value_correction_functions)
    : proto_validator_(std::move(proto_validator)),
      // These read through the member `proto_validator_`, not the moved-from
      // parameter. NOTE(review): this relies on `proto_validator_` being
      // declared before these members in the class -- confirm in the header.
      parameters_(proto_validator_->parameters()),
      tree_levels_needed_(proto_validator_->tree_levels_needed()),
      tree_to_hierarchy_(proto_validator_->tree_to_hierarchy()),
      hierarchy_to_tree_(proto_validator_->hierarchy_to_tree()),
      blocks_needed_(std::move(blocks_needed)),
      prg_left_(std::move(prg_left)),
      prg_right_(std::move(prg_right)),
      prg_value_(std::move(prg_value)),
      // Move the by-value map like the other sink parameters, instead of
      // copying every entry a second time.
      value_correction_functions_(std::move(value_correction_functions)) {}
// Computes the value correction word(s) for `hierarchy_level`.
//
// `seeds` must hold exactly two seeds, one per party. Each seed `s` is
// expanded to `s+0, ..., s+k-1` (k = blocks needed at this level), hashed with
// `prg_value_`, and the two hashed expansions are handed, together with the
// block index of `alpha`, `beta` and `invert`, to the value correction
// function selected for this level's parameters.
//
// Propagates any error from the PRG evaluation or from looking up the value
// correction function.
absl::StatusOr<std::vector<Value>>
DistributedPointFunction::ComputeValueCorrection(
    int hierarchy_level, absl::Span<const absl::uint128> seeds,
    absl::uint128 alpha, const Value& beta, bool invert) const {
  const int num_blocks = blocks_needed_[hierarchy_level];
  const std::size_t num_bytes = num_blocks * sizeof(absl::uint128);

  // One contiguous buffer holds both parties' expansions (party A in the
  // first half, party B in the second), so a single PRG call covers both.
  std::vector<absl::uint128> prg_buffer(2 * num_blocks);
  absl::Span<absl::uint128> share_a(prg_buffer.data(), num_blocks);
  absl::Span<absl::uint128> share_b(prg_buffer.data() + num_blocks,
                                    num_blocks);

  ABSL_DCHECK(seeds.size() == 2);
  std::iota(share_a.begin(), share_a.end(), seeds[0]);
  std::iota(share_b.begin(), share_b.end(), seeds[1]);

  // Hash in place: `Evaluate` copies each batch of inputs into a local buffer
  // before it writes the corresponding outputs.
  DPF_RETURN_IF_ERROR(
      prg_value_.Evaluate(prg_buffer, absl::MakeSpan(prg_buffer)));

  // Position of `alpha` inside its block at this hierarchy level.
  const int index_in_block = DomainToBlockIndex(alpha, hierarchy_level);

  // The concrete implementation depends on this level's parameters.
  DPF_ASSIGN_OR_RETURN(
      ValueCorrectionFunction func,
      GetValueCorrectionFunction(parameters_[hierarchy_level]));

  // Views one party's hashed expansion as raw bytes.
  const auto as_bytes = [num_bytes](absl::Span<const absl::uint128> blocks) {
    return absl::string_view(reinterpret_cast<const char*>(blocks.data()),
                             num_bytes);
  };
  return func(as_bytes(share_a), as_bytes(share_b), index_in_block, beta,
              invert);
}
// Expands the PRG seeds at the next `tree_level`, updates `seeds` and
// `control_bits`, and writes the next correction word to `keys`.
//
// Arguments:
//   tree_level:   The tree level to generate a correction word for.
//   alpha:        The index of the point at the last hierarchy level.
//   beta:         One value per hierarchy level; `beta[h]` is used when
//                 `tree_level - 1` maps to hierarchy level `h`.
//   seeds:        The two parties' current seeds. Overwritten with the seeds
//                 for `tree_level`.
//   control_bits: The two parties' current control bits. Overwritten with the
//                 control bits for `tree_level`.
//   keys:         The two keys being generated. One correction word is
//                 appended to each of them.
//
// Returns an error if computing the value correction or evaluating either PRG
// fails.
absl::Status DistributedPointFunction::GenerateNext(
int tree_level, absl::uint128 alpha, absl::Span<const Value> beta,
absl::Span<absl::uint128> seeds, absl::Span<bool> control_bits,
absl::Span<DpfKey> keys) const {
// As in `GenerateKeysIncremental`, we annotate code with the corresponding
// lines from https://arxiv.org/pdf/2012.14884.pdf#figure.caption.12.
//
// Lines 13 & 14: Compute value correction word if there is a value on the
// current level. This is done here already, since we use the "PRG evaluation
// optimization" described in Appendix C.2 of the paper. Since we are using
// fixed-key AES as PRG, which can have arbitrary stretch, this optimization
// works even for large output groups.
//
// The correction word is assembled in keys[0] and copied to keys[1] at the
// end of this function.
CorrectionWord* correction_word = keys[0].add_correction_words();
if (tree_to_hierarchy_.contains(tree_level - 1)) {
int hierarchy_level = tree_to_hierarchy_.at(tree_level - 1);
// Prefix of `alpha` at `hierarchy_level`. Shifting a 128-bit integer by 128
// or more bits is not allowed, so the prefix stays 0 in that case.
absl::uint128 alpha_prefix = 0;
int shift_amount = parameters_.back().log_domain_size() -
parameters_[hierarchy_level].log_domain_size();
if (shift_amount < 128) {
alpha_prefix = alpha >> shift_amount;
}
DPF_ASSIGN_OR_RETURN(
std::vector<Value> value_correction,
ComputeValueCorrection(hierarchy_level, seeds, alpha_prefix,
beta[hierarchy_level], control_bits[1]));
for (const Value& value : value_correction) {
*(correction_word->add_value_correction()) = value;
}
}
// Line 5: Expand seeds from previous level.
// Indexing is expanded_seeds[branch][party], with branch 0 = left (prg_left_)
// and branch 1 = right (prg_right_).
std::array<std::array<absl::uint128, 2>, 2> expanded_seeds;
DPF_RETURN_IF_ERROR(
prg_left_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[0])));
DPF_RETURN_IF_ERROR(
prg_right_.Evaluate(seeds, absl::MakeSpan(expanded_seeds[1])));
// The lowest bit of every expanded seed is split off as its control bit.
std::array<std::array<bool, 2>, 2> expanded_control_bits;
expanded_control_bits[0][0] =
dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][0]);
expanded_control_bits[0][1] =
dpf_internal::ExtractAndClearLowestBit(expanded_seeds[0][1]);
expanded_control_bits[1][0] =
dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][0]);
expanded_control_bits[1][1] =
dpf_internal::ExtractAndClearLowestBit(expanded_seeds[1][1]);
// Lines 6-8: Assign keep/lose branch depending on current bit of `alpha`.
// The bit is only extracted when its position is below 128; otherwise it is
// treated as 0.
bool current_bit = 0;
if (parameters_.back().log_domain_size() - tree_level < 128) {
current_bit =
(alpha & (absl::uint128{1}
<< (parameters_.back().log_domain_size() - tree_level))) != 0;
}
bool keep = current_bit, lose = !current_bit;
// Line 9: Compute seed correction word.
absl::uint128 seed_correction =
expanded_seeds[lose][0] ^ expanded_seeds[lose][1];
// Line 10: Compute control bit correction words.
std::array<bool, 2> control_bit_correction;
control_bit_correction[0] = expanded_control_bits[0][0] ^
expanded_control_bits[0][1] ^ current_bit ^ 1;
control_bit_correction[1] =
expanded_control_bits[1][0] ^ expanded_control_bits[1][1] ^ current_bit;
// We swap lines 11 and 12, since we first need to use the previous level's
// control bits before updating them.
// Line 12: Update seeds. Note that there is a typo in the paper: The
// multiplication / AND needs to be done with the control bit of iteration
// l-1, not l. Note that unlike the original algorithm, we are using the
// corrected seed directly for the next iteration. This is secure as we're
// using AES with a different key (kPrgKeyValue) to compute the value
// correction word below.
seeds[0] = expanded_seeds[keep][0];
seeds[1] = expanded_seeds[keep][1];
if (control_bits[0]) {
seeds[0] ^= seed_correction;
}
if (control_bits[1]) {
seeds[1] ^= seed_correction;
}
// Line 11: Update control bits. Again, same typo as in Line 12.
control_bits[0] = expanded_control_bits[keep][0] ^
(control_bits[0] && control_bit_correction[keep]);
control_bits[1] = expanded_control_bits[keep][1] ^
(control_bits[1] && control_bit_correction[keep]);
// Line 15: Assemble correction word and add it to keys[0].
correction_word->mutable_seed()->set_high(
absl::Uint128High64(seed_correction));
correction_word->mutable_seed()->set_low(absl::Uint128Low64(seed_correction));
correction_word->set_control_left(control_bit_correction[0]);
correction_word->set_control_right(control_bit_correction[1]);
// Copy correction word to second key.
*(keys[1].add_correction_words()) = *correction_word;
return absl::OkStatus();
}
// Converts `domain_index` at `hierarchy_level` into a tree index by dropping
// its lowest (log_domain_size - tree level) bits, i.e. the bits that
// DomainToBlockIndex() returns.
absl::uint128 DistributedPointFunction::DomainToTreeIndex(
    absl::uint128 domain_index, int hierarchy_level) const {
  const int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  const int bits_within_block =
      log_domain_size - hierarchy_to_tree_[hierarchy_level];
  // A shift by 128 bits or more would not be valid for a 128-bit integer.
  ABSL_DCHECK_LT(bits_within_block, 128);
  return domain_index >> bits_within_block;
}
// Returns the lowest (log_domain_size - tree level) bits of `domain_index` at
// `hierarchy_level`, i.e. the bits that DomainToTreeIndex() drops.
int DistributedPointFunction::DomainToBlockIndex(absl::uint128 domain_index,
                                                 int hierarchy_level) const {
  const int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  const int bits_within_block =
      log_domain_size - hierarchy_to_tree_[hierarchy_level];
  // A shift by 128 bits or more would not be valid for a 128-bit integer.
  ABSL_DCHECK_LT(bits_within_block, 128);
  const absl::uint128 low_bits_mask =
      (absl::uint128{1} << bits_within_block) - 1;
  return static_cast<int>(domain_index & low_bits_mask);
}
// Evaluates every element of `seeds` / `control_bits` through
// `correction_words.size()` levels and writes the results to `seeds_out` and
// `control_bits_out`. `paths` is passed through to
// dpf_internal::EvaluateSeeds, which does the actual work.
//
// All spans except `correction_words` must have the same size, otherwise an
// INVALID_ARGUMENT error is returned. If there are no seeds or no correction
// words, the function returns OK without touching the outputs.
absl::Status DistributedPointFunction::EvaluateSeeds(
    absl::Span<const absl::uint128> seeds, absl::Span<const bool> control_bits,
    absl::Span<const absl::uint128> paths,
    absl::Span<const CorrectionWord* const> correction_words,
    absl::Span<absl::uint128> seeds_out,
    absl::Span<bool> control_bits_out) const {
  const bool sizes_match = seeds.size() == control_bits.size() &&
                           seeds.size() == paths.size() &&
                           seeds.size() == seeds_out.size() &&
                           seeds.size() == control_bits_out.size();
  if (!sizes_match) {
    return absl::InvalidArgumentError(
        "`seeds`, `control_bits`, `paths`, `seeds_out`, and `control_bits_out` "
        "must all have equal sizes");
  }

  const auto num_seeds = static_cast<int64_t>(seeds.size());
  const auto num_levels = static_cast<int>(correction_words.size());
  if (num_seeds == 0 || num_levels == 0) {
    return absl::OkStatus();  // Nothing to do.
  }

  // Unpack the correction words into flat per-level arrays.
  auto correction_seeds = hwy::AllocateAligned<absl::uint128>(num_levels);
  if (correction_seeds == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }
  BitVector correction_controls_left(num_levels);
  BitVector correction_controls_right(num_levels);
  for (int level = 0; level < num_levels; ++level) {
    const CorrectionWord* const word = correction_words[level];
    const auto& seed = word->seed();
    correction_seeds[level] = absl::MakeUint128(seed.high(), seed.low());
    correction_controls_left[level] = word->control_left();
    correction_controls_right[level] = word->control_right();
  }

  // Sanity-check all buffer sizes before handing out raw pointers.
  ABSL_DCHECK(seeds.size() == num_seeds);
  ABSL_DCHECK(control_bits.size() == num_seeds);
  ABSL_DCHECK(correction_controls_left.size() == num_levels);
  ABSL_DCHECK(correction_controls_right.size() == num_levels);
  ABSL_DCHECK(seeds_out.size() == num_seeds);
  ABSL_DCHECK(control_bits_out.size() == num_seeds);

  DPF_RETURN_IF_ERROR(dpf_internal::EvaluateSeeds(
      num_seeds, num_levels, num_levels, seeds.data(), control_bits.data(),
      paths.data(), 0, correction_seeds.get(), correction_controls_left.data(),
      correction_controls_right.data(), prg_left_, prg_right_,
      seeds_out.data(), control_bits_out.data()));
  return absl::OkStatus();
}
// Fully expands every seed in `partial_evaluations` through
// `correction_words.size()` levels, one level per correction word.
//
// Each level doubles the number of seeds: the children of element k are
// stored at positions 2k (left) and 2k+1 (right) of the next level. The
// returned expansion therefore holds
// `partial_evaluations.control_bits.size() << correction_words.size()` seeds
// and control bits.
//
// Returns RESOURCE_EXHAUSTED if a buffer cannot be allocated, or the error of
// a failing PRG evaluation.
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ExpandSeeds(
const DpfExpansion& partial_evaluations,
absl::Span<const CorrectionWord* const> correction_words) const {
int num_expansions = static_cast<int>(correction_words.size());
// Check that the output size fits in a size_t. This should already be checked
// by the caller, so using ABSL_DCHECK here is enough.
ABSL_DCHECK_LT(num_expansions, 63);
auto current_level_size =
static_cast<int64_t>(partial_evaluations.control_bits.size());
absl::uint128 output_size_128 = absl::uint128{current_level_size}
<< num_expansions;
ABSL_DCHECK_LE(output_size_128, std::numeric_limits<size_t>::max() / 2);
size_t output_size = static_cast<size_t>(output_size_128);
// Allocate buffers with the correct size to avoid reallocations.
// The PRG is evaluated in batches of at most `max_batch_size` seeds.
int64_t max_batch_size = Aes128FixedKeyHash::kBatchSize;
std::vector<absl::uint128> prg_buffer_left(max_batch_size),
prg_buffer_right(max_batch_size);
// Copy seeds and control bits. We will swap these after every expansion.
// Both `expansion` and `next_level_expansion` are sized for the final level,
// so that either one can hold the result.
DpfExpansion expansion;
expansion.seeds = hwy::AllocateAligned<absl::uint128>(output_size);
if (expansion.seeds == nullptr) {
return absl::ResourceExhaustedError("Memory allocation error");
}
std::copy_n(partial_evaluations.seeds.get(), current_level_size,
expansion.seeds.get());
expansion.control_bits = partial_evaluations.control_bits;
expansion.control_bits.reserve(output_size);
DpfExpansion next_level_expansion;
next_level_expansion.seeds = hwy::AllocateAligned<absl::uint128>(output_size);
if (next_level_expansion.seeds == nullptr) {
return absl::ResourceExhaustedError("Memory allocation error");
}
next_level_expansion.control_bits.reserve(output_size);
// We use an iterative expansion here to pipeline AES as much as possible.
for (int i = 0; i < num_expansions; ++i) {
// Control bits of the next level are appended with push_back below, so start
// from an empty vector (the capacity reserved above is kept).
next_level_expansion.control_bits.resize(0);
absl::uint128 correction_seed = absl::MakeUint128(
correction_words[i]->seed().high(), correction_words[i]->seed().low());
bool correction_control_left = correction_words[i]->control_left();
bool correction_control_right = correction_words[i]->control_right();
// Expand PRG.
for (int64_t start_block = 0; start_block < current_level_size;
start_block += max_batch_size) {
// The last batch of a level may be smaller than `max_batch_size`.
int64_t batch_size =
std::min<int64_t>(current_level_size - start_block, max_batch_size);
DPF_RETURN_IF_ERROR(prg_left_.Evaluate(
absl::MakeConstSpan(expansion.seeds.get() + start_block, batch_size),
absl::MakeSpan(prg_buffer_left).subspan(0, batch_size)));
DPF_RETURN_IF_ERROR(prg_right_.Evaluate(
absl::MakeConstSpan(expansion.seeds.get() + start_block, batch_size),
absl::MakeSpan(prg_buffer_right).subspan(0, batch_size)));
// Merge results into next level of seeds and perform correction.
for (int64_t j = 0; j < batch_size; ++j) {
const int64_t index_expanded = 2 * (start_block + j);
// The seed correction is only applied if the parent's control bit is set.
if (expansion.control_bits[start_block + j]) {
prg_buffer_left[j] ^= correction_seed;
prg_buffer_right[j] ^= correction_seed;
}
next_level_expansion.seeds[index_expanded] = prg_buffer_left[j];
next_level_expansion.seeds[index_expanded + 1] = prg_buffer_right[j];
// The lowest bit of each (corrected) child seed becomes its control bit and
// is cleared from the stored seed.
next_level_expansion.control_bits.push_back(
dpf_internal::ExtractAndClearLowestBit(
next_level_expansion.seeds[index_expanded]));
next_level_expansion.control_bits.push_back(
dpf_internal::ExtractAndClearLowestBit(
next_level_expansion.seeds[index_expanded + 1]));
// Likewise, the control bit corrections are only applied if the parent's
// control bit is set.
if (expansion.control_bits[start_block + j]) {
next_level_expansion.control_bits[index_expanded] ^=
correction_control_left;
next_level_expansion.control_bits[index_expanded + 1] ^=
correction_control_right;
}
}
}
// The freshly computed level becomes the input of the next iteration.
std::swap(expansion, next_level_expansion);
current_level_size *= 2;
}
return expansion;
}
// Computes the seed and control bit for each element of `prefixes` at the
// tree level corresponding to `hierarchy_level`.
//
// If `ctx` holds partial evaluations from a tree level at or before the
// target level, evaluation resumes from those; otherwise it starts from the
// seed and party of `ctx.key()`.
//
// Side effects on `ctx`: the stored partial evaluations are always cleared
// and `partial_evaluations_level` is set to `hierarchy_level`. The newly
// computed evaluations are stored in `ctx` only if `update_ctx` is true.
//
// Returns:
//   INVALID_ARGUMENT if `ctx.partial_evaluations()` contains the same prefix
//     twice with different seeds or control bits, or if the ancestor of an
//     element of `prefixes` is missing from `ctx.partial_evaluations()`.
//   RESOURCE_EXHAUSTED if the seed buffer cannot be allocated.
//   Any error returned by EvaluateSeeds().
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ComputePartialEvaluations(
    absl::Span<const absl::uint128> prefixes, int hierarchy_level,
    bool update_ctx, EvaluationContext& ctx) const {
  int64_t num_prefixes = static_cast<int64_t>(prefixes.size());
  DpfExpansion partial_evaluations;
  int start_level = hierarchy_to_tree_[ctx.partial_evaluations_level()];
  int stop_level = hierarchy_to_tree_[hierarchy_level];
  if (ctx.partial_evaluations_size() > 0 && start_level <= stop_level) {
    // We have partial evaluations from a tree level before the current one.
    // Parse `ctx.partial_evaluations` into a btree_map for quick lookups up by
    // prefix. We use a btree_map because `ctx.partial_evaluations()` will
    // usually be sorted.
    absl::btree_map<absl::uint128, std::pair<absl::uint128, bool>>
        previous_partial_evaluations;
    for (const PartialEvaluation& element : ctx.partial_evaluations()) {
      absl::uint128 prefix =
          absl::MakeUint128(element.prefix().high(), element.prefix().low());
      // Try inserting `(seed, control_bit)` at `prefix` into
      // partial_evaluations. Return an error if `prefix` is already present
      // with a different seed or control bit.
      auto value = std::make_pair(
          absl::MakeUint128(element.seed().high(), element.seed().low()),
          element.control_bit());
      auto it = previous_partial_evaluations.try_emplace(
          previous_partial_evaluations.end(), prefix, value);
      if (it->second != value) {
        return absl::InvalidArgumentError(
            "Duplicate prefix in `ctx.partial_evaluations()` with mismatching "
            "seed or control bit");
      }
    }
    // Now select all partial evaluations from the map that correspond to
    // `prefixes`.
    partial_evaluations.seeds =
        hwy::AllocateAligned<absl::uint128>(num_prefixes);
    if (partial_evaluations.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    partial_evaluations.control_bits.reserve(num_prefixes);
    for (int64_t i = 0; i < num_prefixes; ++i) {
      // The ancestor of prefixes[i] at `start_level` is obtained by dropping
      // the bits added between `start_level` and `stop_level`. A shift by 128
      // or more bits is not allowed, so the ancestor stays 0 in that case.
      absl::uint128 previous_prefix = 0;
      if (stop_level - start_level < 128) {
        previous_prefix = prefixes[i] >> (stop_level - start_level);
      }
      auto it = previous_partial_evaluations.find(previous_prefix);
      if (it == previous_partial_evaluations.end()) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Prefix not present in ctx.partial_evaluations at hierarchy level ",
            hierarchy_level));
      }
      const std::pair<absl::uint128, bool>& partial_evaluation = it->second;
      partial_evaluations.seeds[partial_evaluations.control_bits.size()] =
          partial_evaluation.first;
      partial_evaluations.control_bits.push_back(partial_evaluation.second);
    }
  } else {
    // No partial evaluations in `ctx` -> Start from the beginning.
    partial_evaluations.seeds =
        hwy::AllocateAligned<absl::uint128>(num_prefixes);
    if (partial_evaluations.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(), num_prefixes);
    std::fill(
        seeds.begin(), seeds.end(),
        absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low()));
    partial_evaluations.control_bits.resize(
        num_prefixes, static_cast<bool>(ctx.key().party()));
    start_level = 0;
  }
  // Evaluate the DPF up to current_tree_level. Inputs and outputs alias each
  // other, i.e. `partial_evaluations` is updated in place.
  auto seeds = absl::MakeSpan(partial_evaluations.seeds.get(),
                              partial_evaluations.control_bits.size());
  DPF_RETURN_IF_ERROR(
      EvaluateSeeds(seeds, partial_evaluations.control_bits, prefixes,
                    absl::MakeConstSpan(ctx.key().correction_words())
                        .subspan(start_level, stop_level - start_level),
                    seeds, absl::MakeSpan(partial_evaluations.control_bits)));
  // Update `partial_evaluations` in `ctx` if there are more evaluations to
  // come.
  ctx.clear_partial_evaluations();
  if (update_ctx) {
    // Only reserve space when elements are actually going to be added.
    ctx.mutable_partial_evaluations()->Reserve(num_prefixes);
    for (int64_t i = 0; i < num_prefixes; ++i) {
      PartialEvaluation* current_element = ctx.add_partial_evaluations();
      current_element->mutable_prefix()->set_high(
          absl::Uint128High64(prefixes[i]));
      current_element->mutable_prefix()->set_low(
          absl::Uint128Low64(prefixes[i]));
      current_element->mutable_seed()->set_high(
          absl::Uint128High64(partial_evaluations.seeds[i]));
      current_element->mutable_seed()->set_low(
          absl::Uint128Low64(partial_evaluations.seeds[i]));
      current_element->set_control_bit(partial_evaluations.control_bits[i]);
    }
  }
  ctx.set_partial_evaluations_level(hierarchy_level);
  return partial_evaluations;
}
// Computes the full expansion at `hierarchy_level` below each of `prefixes`
// and records `hierarchy_level` as the previous hierarchy level in `ctx`.
//
// An empty `prefixes` means this is the first expansion, which starts from the
// seed and party of `ctx.key()`. Otherwise the starting points are obtained
// from ComputePartialEvaluations() at `ctx.previous_hierarchy_level()`.
absl::StatusOr<DistributedPointFunction::DpfExpansion>
DistributedPointFunction::ExpandAndUpdateContext(
    int hierarchy_level, absl::Span<const absl::uint128> prefixes,
    EvaluationContext& ctx) const {
  DpfExpansion starting_points;
  int first_tree_level = 0;
  if (!prefixes.empty()) {
    // Second or later expansion: resume from the seeds belonging to
    // `prefixes`. `ctx` only needs to remember them if this is not the last
    // hierarchy level.
    const bool is_last_level =
        hierarchy_level >= static_cast<int>(parameters_.size()) - 1;
    ABSL_DCHECK(ctx.previous_hierarchy_level() >= 0);
    DPF_ASSIGN_OR_RETURN(
        starting_points,
        ComputePartialEvaluations(prefixes, ctx.previous_hierarchy_level(),
                                  !is_last_level, ctx));
    first_tree_level = hierarchy_to_tree_[ctx.previous_hierarchy_level()];
  } else {
    // First expansion: the only starting point is the root, i.e. the key's
    // seed, with the party as control bit.
    starting_points.seeds = hwy::AllocateAligned<absl::uint128>(1);
    if (starting_points.seeds == nullptr) {
      return absl::ResourceExhaustedError("Memory allocation error");
    }
    starting_points.seeds[0] =
        absl::MakeUint128(ctx.key().seed().high(), ctx.key().seed().low());
    starting_points.control_bits = {static_cast<bool>(ctx.key().party())};
  }

  // Expand from the starting tree level down to the one of `hierarchy_level`.
  const int last_tree_level = hierarchy_to_tree_[hierarchy_level];
  const auto remaining_correction_words =
      absl::MakeConstSpan(ctx.key().correction_words())
          .subspan(first_tree_level, last_tree_level - first_tree_level);
  DPF_ASSIGN_OR_RETURN(
      DpfExpansion expansion,
      ExpandSeeds(starting_points, remaining_correction_words));

  // Remember in `ctx` which hierarchy level was evaluated.
  ctx.set_previous_hierarchy_level(hierarchy_level);
  return expansion;
}
// Hashes every seed in `expansion` with `prg_value_`, producing
// `blocks_needed_[hierarchy_level]` output blocks per seed.
//
// The blocks belonging to `expansion[i]` are stored contiguously, starting at
// index `i * blocks_needed_[hierarchy_level]` of the result. Returns
// RESOURCE_EXHAUSTED if the output buffer cannot be allocated.
absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>>
DistributedPointFunction::HashExpandedSeeds(
    int hierarchy_level, absl::Span<const absl::uint128> expansion) const {
  const int blocks_per_seed = blocks_needed_[hierarchy_level];
  const int64_t total_blocks =
      static_cast<int64_t>(expansion.size()) * blocks_per_seed;
  auto hashed = hwy::AllocateAligned<absl::uint128>(total_blocks);
  if (hashed == nullptr) {
    return absl::ResourceExhaustedError("Memory allocation error");
  }

  // Fill the buffer with the PRG inputs: seed+0, ..., seed+blocks_per_seed-1
  // for every seed, in order.
  int64_t next_slot = 0;
  for (const absl::uint128& seed : expansion) {
    for (int offset = 0; offset < blocks_per_seed; ++offset) {
      hashed[next_slot++] = seed + offset;
    }
  }

  // Input and output alias each other here; this is safe as `Evaluate`
  // creates a copy of the input.
  const absl::Span<absl::uint128> in_out(hashed.get(), total_blocks);
  DPF_RETURN_IF_ERROR(prg_value_.Evaluate(in_out, in_out));
  return hashed;
}
// Serializes `value_type` to a string with deterministic serialization
// enabled. Returns an INTERNAL error if serialization fails.
absl::StatusOr<std::string>
DistributedPointFunction::SerializeValueTypeDeterministically(
    const ValueType& value_type) {
  // Deterministic serialization has to be requested on a CodedOutputStream,
  // so the streams are set up by hand.
  std::string output;
  bool serialized_ok = false;
  {
    // The streams live in their own scope so that their destructors have run
    // by the time `output` is returned.
    ::google::protobuf::io::StringOutputStream string_stream(&output);
    ::google::protobuf::io::CodedOutputStream coded_stream(&string_stream);
    coded_stream.SetSerializationDeterministic(true);
    serialized_ok = value_type.SerializeToCodedStream(&coded_stream);
  }
  if (!serialized_ok) {
    return absl::InternalError("Serializing value_type to string failed");
  }
  return output;
}
// Looks up the value correction function registered for the value type of
// `parameters`. The lookup key is the deterministic serialization of
// `parameters.value_type()`.
//
// Returns FAILED_PRECONDITION if no function is registered for that type, or
// the error of a failed serialization.
absl::StatusOr<DistributedPointFunction::ValueCorrectionFunction>
DistributedPointFunction::GetValueCorrectionFunction(
    const DpfParameters& parameters) const {
  DPF_ASSIGN_OR_RETURN(
      const std::string lookup_key,
      SerializeValueTypeDeterministically(parameters.value_type()));
  const auto entry = value_correction_functions_.find(lookup_key);
  if (entry != value_correction_functions_.end()) {
    return entry->second;
  }
  return absl::FailedPreconditionError(absl::StrCat(
      "No value correction function known for the following parameters:\n",
      parameters.DebugString(),
      "Did you call RegisterValueType<T>() with your value type?"));
}
// Creates a DPF with a single hierarchy level described by `parameters`, by
// delegating to CreateIncremental() with a one-element span.
absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
DistributedPointFunction::Create(const DpfParameters& parameters) {
  const absl::Span<const DpfParameters> single_level(&parameters, 1);
  return CreateIncremental(single_level);
}
// Creates an incremental DPF with one hierarchy level per element of
// `parameters`.
//
// Returns an error if `parameters` fails validation, if the number of bits
// needed for a value type cannot be computed, if one of the PRG hash functions
// cannot be created, or if registering a default value type fails.
absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
DistributedPointFunction::CreateIncremental(
    absl::Span<const DpfParameters> parameters) {
  // Log Highway mode for debugging.
  ABSL_LOG_FIRST_N(INFO, 1)
      << "Highway is in " << dpf_internal::GetHwyModeAsString() << " mode";

  // Validating `parameters` yields a validator that the DPF keeps for later.
  DPF_ASSIGN_OR_RETURN(
      std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
      dpf_internal::ProtoValidator::Create(parameters));

  // For each hierarchy level, work out how many 128-bit blocks are needed for
  // the value correction (number of bits, rounded up to whole blocks).
  std::vector<int> blocks_needed;
  blocks_needed.reserve(parameters.size());
  for (const DpfParameters& level_parameters : parameters) {
    DPF_ASSIGN_OR_RETURN(
        int bits_needed,
        dpf_internal::BitsNeeded(level_parameters.value_type(),
                                 level_parameters.security_parameter()));
    blocks_needed.push_back((bits_needed + 127) / 128);
  }

  // One fixed-key AES hash per PRG output: left child, right child, and value.
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_left,
                       Aes128FixedKeyHash::Create(kPrgKeyLeft));
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_right,
                       Aes128FixedKeyHash::Create(kPrgKeyRight));
  DPF_ASSIGN_OR_RETURN(Aes128FixedKeyHash prg_value,
                       Aes128FixedKeyHash::Create(kPrgKeyValue));

  // For backwards compatibility, all single unsigned integer types are
  // registered as value types up front.
  absl::flat_hash_map<std::string, ValueCorrectionFunction> registered_types;
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint8_t>(registered_types));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint16_t>(registered_types));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint32_t>(registered_types));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<uint64_t>(registered_types));
  DPF_RETURN_IF_ERROR(RegisterValueTypeImpl<absl::uint128>(registered_types));

  // Hand everything over to the new DPF.
  return absl::WrapUnique(new DistributedPointFunction(
      std::move(proto_validator), std::move(blocks_needed), std::move(prg_left),
      std::move(prg_right), std::move(prg_value),
      std::move(registered_types)));
}
// Generates a pair of DPF keys for the point `alpha`, with `beta[i]` being
// the value at hierarchy level `i`.
//
// Returns:
//   INVALID_ARGUMENT if `beta` does not contain exactly one value per
//     hierarchy level, if one of the values fails validation, or if `alpha`
//     lies outside the domain of the last hierarchy level.
//   INTERNAL if sampling the random seeds fails.
//   Any error returned while generating the correction words.
absl::StatusOr<std::pair<DpfKey, DpfKey>>
DistributedPointFunction::GenerateKeysIncremental(
    absl::uint128 alpha, absl::Span<const Value> beta) {
  // Check validity of beta.
  if (beta.size() != parameters_.size()) {
    return absl::InvalidArgumentError(
        "`beta` has to have the same size as `parameters` passed at "
        "construction");
  }
  for (int i = 0; i < static_cast<int>(parameters_.size()); ++i) {
    DPF_RETURN_IF_ERROR(proto_validator_->ValidateValue(beta[i], i));
  }

  // Check validity of alpha. For a domain of 2^128 (or more) elements, every
  // 128-bit `alpha` is valid, and the shift below would not be allowed.
  int last_level_log_domain_size = parameters_.back().log_domain_size();
  if (last_level_log_domain_size < 128 &&
      alpha >= (absl::uint128{1} << last_level_log_domain_size)) {
    return absl::InvalidArgumentError(
        "`alpha` must be smaller than the output domain size");
  }

  std::array<DpfKey, 2> keys;
  keys[0].set_party(0);
  keys[1].set_party(1);

  // We will annotate the following code with the corresponding lines from the
  // pseudocode in the Incremental DPF paper
  // (https://arxiv.org/pdf/2012.14884.pdf, Figure 11).
  //
  // There are two possible dimensions for each variable at each level: Parties
  // (0 or 1) and branches (left or right). For two-dimensional arrays, we use
  // the outer dimension for the branch, and the inner dimension for the party.
  //
  // Line 2: Sample random seeds for each party. The keys must never be built
  // from seeds that were not actually randomized, so a failing RNG is treated
  // as an error instead of being ignored.
  std::array<absl::uint128, 2> seeds;
  for (absl::uint128& seed : seeds) {
    if (RAND_bytes(reinterpret_cast<uint8_t*>(&seed), sizeof(absl::uint128)) !=
        1) {
      return absl::InternalError("Failed to sample random seed for DPF key");
    }
  }
  keys[0].mutable_seed()->set_high(absl::Uint128High64(seeds[0]));
  keys[0].mutable_seed()->set_low(absl::Uint128Low64(seeds[0]));
  keys[1].mutable_seed()->set_high(absl::Uint128High64(seeds[1]));
  keys[1].mutable_seed()->set_low(absl::Uint128Low64(seeds[1]));

  // Line 3: Initialize control bits.
  std::array<bool, 2> control_bits{0, 1};

  // Line 4: Compute correction words for each level after the first one.
  keys[0].mutable_correction_words()->Reserve(tree_levels_needed_ - 1);
  keys[1].mutable_correction_words()->Reserve(tree_levels_needed_ - 1);
  for (int i = 1; i < tree_levels_needed_; i++) {
    DPF_RETURN_IF_ERROR(GenerateNext(i, alpha, beta, absl::MakeSpan(seeds),
                                     absl::MakeSpan(control_bits),
                                     absl::MakeSpan(keys)));
  }

  // Compute output correction word for last layer. Both keys receive the same
  // value correction.
  DPF_ASSIGN_OR_RETURN(
      std::vector<Value> last_level_value_correction,
      ComputeValueCorrection(parameters_.size() - 1, seeds, alpha, beta.back(),
                             control_bits[1]));
  for (const Value& value : last_level_value_correction) {
    *(keys[0].add_last_level_value_correction()) = value;
    *(keys[1].add_last_level_value_correction()) = value;
  }
  return std::make_pair(std::move(keys[0]), std::move(keys[1]));
}
// Creates a fresh EvaluationContext for `key`, containing a copy of this
// DPF's parameters. Returns an error if `key` fails validation.
absl::StatusOr<EvaluationContext>
DistributedPointFunction::CreateEvaluationContext(DpfKey key) const {
  // Reject invalid keys before building anything.
  DPF_RETURN_IF_ERROR(proto_validator_->ValidateDpfKey(key));

  EvaluationContext context;
  for (const auto& level_parameters : parameters_) {
    *context.add_parameters() = level_parameters;
  }
  *context.mutable_key() = std::move(key);
  // A previous hierarchy level of -1 marks a context that has not been
  // evaluated at all.
  context.set_previous_hierarchy_level(-1);
  return context;
}
} // namespace distributed_point_functions

File diff suppressed because it is too large Load Diff

@ -1,171 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package distributed_point_functions;
// For faster allocations of sub-messages.
option cc_enable_arenas = true;
// Describes the type of a single DPF output value. Any additional types added
// here should also be supported in internal/value_type_helpers.h.
// LINT.IfChange
message ValueType {
// Describes an integer modulo 2^l. Maps to the C++ types `uint8_t`,
// `uint16_t`, `uint32_t`, `uint64_t`, and `absl::uint128`.
message Integer {
// Number of bits per integer. Must be a power of 2 and at most 128.
int32 bitsize = 1;
}
// Describes a tuple of value types.
message Tuple {
// The types of the tuple's elements. Since `ValueType` is used recursively
// here, tuples may be nested.
repeated ValueType elements = 1;
}
// Describes an integer ring modulo `modulus`.
message IntModN {
// The underlying integer type used to represent elements in the ring.
Integer base_integer = 1;
// The modulus.
Value.Integer modulus = 2;
}
// Exactly one of the following describes the type.
oneof type {
// A single integer modulo 2^l.
Integer integer = 1;
// A tuple of values.
Tuple tuple = 2;
// An integer with a custom modulus.
IntModN int_mod_n = 3;
// An XOR-wrapped integer. Corresponds to the XorWrapper C++ class.
Integer xor_wrapper = 4;
}
// Do not add fields outside of the `oneof` above, to ensure that messages
// with known ValueTypes are serialized deterministically. See the
// documentation of `value_correction_functions_` in
// distributed_point_function.h for details.
}
// Used to correct output values to the desired DPF magnitude. Holds the values
// corresponding to the types defined in `ValueType`.
message Value {
// An integer value, stored in one of two representations depending on its
// size.
message Integer {
oneof value {
// Any value up to 64 bits.
uint64 value_uint64 = 1;
// 128-bit values, stored as a `Block` (two 64-bit halves).
Block value_uint128 = 2;
}
}
// A tuple of values.
message Tuple {
repeated Value elements = 1;
}
// The field names and numbers mirror those of the `type` oneof in
// `ValueType`.
oneof value {
Integer integer = 1;
Tuple tuple = 2;
Integer int_mod_n =
3; // The value of an IntModN is represented by its base_integer type.
Integer xor_wrapper = 4;
}
}
// LINT.ThenChange(
// internal/value_type_helpers.h,
// internal/value_type_helpers.cc
// )
// Parameters of a single hierarchy level of a distributed point function (DPF).
message DpfParameters {
// Field number 2 is reserved and must not be reused.
reserved 2;
// Base-2 logarithm of the number of elements, i.e. the domain at this
// hierarchy level has 2^log_domain_size elements.
int32 log_domain_size = 1;
// Describes the type of output values at this hierarchy level.
ValueType value_type = 3;
// The negative logarithm of the total variation distance from uniform that an
// evaluation at a *single point* at this hierarchy level is allowed to have.
// The correct value for this parameter depends on the maximum number of
// points at which this hierarchy level is evaluated. It should be at least 40
// + log2(number_of_evaluation_points). Defaults to
// ProtoValidator::kDefaultSecurityParameter + log_domain_size.
double security_parameter = 4;
}
// A single 128-bit AES block, split into two 64-bit halves.
message Block {
// The upper 64 bits of the block.
uint64 high = 1;
// The lower 64 bits of the block.
uint64 low = 2;
}
// A correction word used to evaluate a single layer in the DPF evaluation tree.
message CorrectionWord {
// Block used to correct the new seeds after PRG evaluation.
Block seed = 1;
// Correction bits for the left and right control bits, i.e. the control bits
// of the left and right child seed, respectively.
bool control_left = 2;
bool control_right = 3;
// Reserved for deprecated value correction field.
reserved 4;
// Used to correct the output value at the previous tree layer. Only included
// if the previous tree layer is an output layer. Repeated to capture the case
// where multiple correction values are needed due to packing.
repeated Value value_correction = 5;
}
// A key of a distributed point function (DPF). Keys are generated in pairs,
// one for each of the two parties.
message DpfKey {
// Initial seed at the first level.
Block seed = 1;
// Correction words for each level after expansion. Both keys of a pair hold
// the same correction words.
repeated CorrectionWord correction_words = 2;
// Party this DpfKey belongs to (0 or 1). Also used as the initial control bit
// when evaluating the key.
int32 party = 3;
// Deprecated last level value correction.
reserved 4;
// Output correction for the last level of the evaluation tree. Both keys of a
// pair hold the same values.
repeated Value last_level_value_correction = 5;
}
// Maps a single prefix of a DPF index to a PRG seed. Used to store partial
// evaluation state between hierarchy levels in `EvaluationContext`.
message PartialEvaluation {
// Prefix in the FSS evaluation tree. Does not necessarily coincide with the
// corresponding prefix of the output domain at this hierarchy level.
Block prefix = 1;
// Seed for the next evaluation.
Block seed = 2;
// Control bit for the correction in the next evaluation.
bool control_bit = 3;
}
// An EvaluationContext holds the state of a partially evaluated incremental
// DPF.
message EvaluationContext {
// The parameters of the DPF being evaluated. One set of parameters for each
// hierarchy level of the incremental DPF.
repeated DpfParameters parameters = 1;
// The DPF key being evaluated.
DpfKey key = 2;
// The hierarchy level that this EvaluationContext was last evaluated on. A
// value of -1 means that this context has not been evaluated at all.
int32 previous_hierarchy_level = 3;
// Maps prefixes from an earlier hierarchy level to PRG seeds, which are used
// to continue the evaluation under each prefix. Uses a repeated message field
// since Protobuf doesn't allow messages (such as `Block`) as map keys.
repeated PartialEvaluation partial_evaluations = 4;
// The hierarchy level `partial_evaluations` corresponds to. Ignored when
// `partial_evaluations` is empty.
int32 partial_evaluations_level = 5;
}

@ -1,418 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/btree_set.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/int128.h"
#include "absl/random/random.h"
#include "absl/random/uniform_int_distribution.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "benchmark/benchmark.h"
#include "dpf/distributed_point_function.h"
#include "google/protobuf/arena.h"
#include "hwy/aligned_allocator.h"
namespace distributed_point_functions {
namespace {
// Benchmarks a regular (single hierarchy level) DPF evaluation with value type
// `T`. The first range argument of the benchmark is used as the log domain
// size of the DPF.
//
// Key generation and creation of the initial evaluation context happen once,
// outside the timed loop. Each iteration copies that context into a fresh
// arena and evaluates it.
template <typename T>
void BM_EvaluateRegularDpf(benchmark::State& state) {
  DpfParameters dpf_parameters;
  dpf_parameters.set_log_domain_size(state.range(0));
  *dpf_parameters.mutable_value_type() = ToValueType<T>();
  std::unique_ptr<DistributedPointFunction> dpf =
      DistributedPointFunction::Create(dpf_parameters).value();

  // The benchmark uses the point 0 with a value-initialized value.
  absl::uint128 point = 0;
  T value{};
  ABSL_CHECK(dpf->RegisterValueType<T>().ok());
  std::pair<DpfKey, DpfKey> key_pair = dpf->GenerateKeys(point, value).value();
  EvaluationContext initial_ctx =
      dpf->CreateEvaluationContext(key_pair.first).value();

  for (auto _ : state) {
    google::protobuf::Arena arena;
    EvaluationContext* arena_ctx =
        google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
    *arena_ctx = initial_ctx;
    std::vector<T> evaluation = dpf->EvaluateNext<T>({}, *arena_ctx).value();
    // Keep the compiler from optimizing the evaluation away.
    benchmark::DoNotOptimize(evaluation);
  }
}
// Registers BM_EvaluateRegularDpf for the plain unsigned integer value types.
// The DenseRange arguments are passed to the benchmark as the log domain size.
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint8_t)->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint16_t)->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint32_t)->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, uint64_t)->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, absl::uint128)->DenseRange(12, 24, 2);
// Tuple value types with two to six integer components.
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint32_t>)
->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint32_t, uint64_t>)
->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, Tuple<uint64_t, uint64_t>)
->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
Tuple<uint32_t, uint32_t, uint32_t, uint32_t>)
->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>)
->DenseRange(12, 24, 2);
BENCHMARK_TEMPLATE(
BM_EvaluateRegularDpf,
Tuple<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t>)
->DenseRange(12, 24, 2);
// Five-component tuples of integers modulo a fixed 32-bit modulus.
using MyIntModN = IntModN<uint32_t, 4294967291u>; // 2**32 - 5.
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf,
Tuple<MyIntModN, MyIntModN, MyIntModN, MyIntModN, MyIntModN>)
->DenseRange(12, 24, 2);
// Same with a 64-bit modulus. Note that this registration stops at 22 rather
// than 24.
using MyIntModN64 = IntModN<uint64_t, 18446744073709551557ull>; // 2**64 - 59.
BENCHMARK_TEMPLATE(
BM_EvaluateRegularDpf,
Tuple<MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64, MyIntModN64>)
->DenseRange(12, 22, 2);
// XorWrapper<absl::uint128> is benchmarked at every log domain size from 1 to
// 24.
BENCHMARK_TEMPLATE(BM_EvaluateRegularDpf, XorWrapper<absl::uint128>)
->DenseRange(1, 24, 1);
// Benchmarks evaluating every hierarchy level of an incremental DPF, passing
// all values 0 .. 2**(previous log domain size) - 1 as prefixes at each level
// after the first. The first range argument is the number of hierarchy
// levels. The log domain size of the last level is fixed to 20.
template <typename T>
void BM_EvaluateHierarchicalFull(benchmark::State& state) {
  constexpr int kMaxLogDomainSize = 20;
  const int num_levels = state.range(0);

  // Level i gets log domain size floor((i + 1) / num_levels * 20), with
  // integer values of sizeof(T) * 8 bits.
  std::vector<DpfParameters> parameters(num_levels);
  for (int level = 0; level < num_levels; ++level) {
    const double fraction = static_cast<double>(level + 1) / num_levels;
    parameters[level].set_log_domain_size(
        static_cast<int>(fraction * kMaxLogDomainSize));
    parameters[level].mutable_value_type()->mutable_integer()->set_bitsize(
        sizeof(T) * 8);
  }
  std::unique_ptr<DistributedPointFunction> dpf =
      DistributedPointFunction::CreateIncremental(parameters).value();

  // Keys for alpha = 12345, with beta = i at hierarchy level i.
  absl::uint128 alpha = 12345;
  std::vector<absl::uint128> beta(num_levels);
  std::iota(beta.begin(), beta.end(), absl::uint128{0});
  std::pair<DpfKey, DpfKey> keys =
      dpf->GenerateKeysIncremental(alpha, beta).value();

  // The first level is evaluated without prefixes. Every later level uses
  // all elements of the previous level's domain as prefixes.
  EvaluationContext initial_ctx =
      dpf->CreateEvaluationContext(keys.first).value();
  std::vector<std::vector<absl::uint128>> prefixes(num_levels);
  for (int level = 1; level < num_levels; ++level) {
    std::vector<absl::uint128>& level_prefixes = prefixes[level];
    level_prefixes.resize(1 << parameters[level - 1].log_domain_size());
    std::iota(level_prefixes.begin(), level_prefixes.end(), absl::uint128{0});
  }

  for (auto _ : state) {
    // Each iteration evaluates on its own arena-allocated copy of the
    // initial context.
    google::protobuf::Arena arena;
    auto* ctx =
        google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
    *ctx = initial_ctx;
    for (int level = 0; level < num_levels; ++level) {
      std::vector<T> result =
          dpf->EvaluateNext<T>(prefixes[level], *ctx).value();
      benchmark::DoNotOptimize(result);
    }
    benchmark::DoNotOptimize(*ctx);
  }
}
// Registers BM_EvaluateHierarchicalFull for the plain unsigned integer value
// types. The DenseRange arguments are passed to the benchmark as the number
// of hierarchy levels.
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint8_t)->DenseRange(1, 16, 2);
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint16_t)->DenseRange(1, 16, 2);
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint32_t)->DenseRange(1, 16, 2);
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, uint64_t)->DenseRange(1, 16, 2);
BENCHMARK_TEMPLATE(BM_EvaluateHierarchicalFull, absl::uint128)
->DenseRange(1, 16, 2);
// Generates random, sorted prefixes for the given set of `parameters`.
//
// The first hierarchy level gets no prefixes. Level `i > 0` gets
// `num_nonzeros[i - 1]` prefixes. For `i == 1` these are uniform values of
// `parameters[0].log_domain_size()` bits. For `i > 1` each prefix extends a
// randomly chosen prefix of level `i - 1` by the number of bits the domain
// grew between levels `i - 2` and `i - 1`.
//
// The random extension is drawn from a 32-bit distribution, so at most 32 new
// bits are randomized per level.
std::vector<std::vector<absl::uint128>> GenerateRandomPrefixes(
    absl::Span<const DpfParameters> parameters,
    absl::Span<const int> num_nonzeros) {
  const int num_hierarchy_levels = static_cast<int>(parameters.size());
  std::vector<std::vector<absl::uint128>> prefixes(parameters.size());
  absl::BitGen rng;
  // Prefixes must be empty for the first level, so start at level 1.
  for (int i = 1; i < num_hierarchy_levels; ++i) {
    prefixes[i].resize(num_nonzeros[i - 1]);

    // Difference between the previous log domain size and the one before
    // that. This is the number of bits we have to shift prefixes from the
    // previous level to append the new bits.
    int previous_domain_size_difference = parameters[i - 1].log_domain_size();
    if (i > 1) {
      previous_domain_size_difference -= parameters[i - 2].log_domain_size();
    }

    // Compute the largest extension value in unsigned arithmetic. The `int`
    // expression `1 << n` overflows for n == 31 and is undefined for n >= 32,
    // so clamp to the full 32-bit range in that case.
    const uint32_t max_value =
        previous_domain_size_difference >= 32
            ? ~uint32_t{0}
            : (uint32_t{1} << previous_domain_size_difference) - 1;
    absl::uniform_int_distribution<uint32_t> dist_value(0, max_value);

    // Only levels above the second one have previous prefixes to extend.
    absl::uniform_int_distribution<uint32_t> dist_index;
    if (i > 1) {
      dist_index = absl::uniform_int_distribution<uint32_t>(
          0, prefixes[i - 1].size() - 1);
    }

    absl::uint128 prefix = 0;
    for (int j = 0; j < num_nonzeros[i - 1]; ++j) {
      if (i > 1) {
        // Choose a random prefix from the previous level to extend.
        prefix = prefixes[i - 1][dist_index(rng)]
                 << previous_domain_size_difference;
      }
      prefixes[i][j] = prefix | dist_value(rng);
    }
    std::sort(prefixes[i].begin(), prefixes[i].end());
  }
  return prefixes;
}
// Benchmarks a two-level hierarchy with log domain sizes 12 and 25, 32-bit
// integer values, and 32 random prefixes at the second level. This is the
// example from
// https://github.com/abetterinternet/prio-documents/issues/18#issuecomment-801248636
void BM_IsrgExampleHierarchy(benchmark::State& state) {
  constexpr int kNumHierarchyLevels = 2;
  constexpr int kLogDomainSizes[kNumHierarchyLevels] = {12, 25};

  std::vector<DpfParameters> parameters(kNumHierarchyLevels);
  for (int level = 0; level < kNumHierarchyLevels; ++level) {
    parameters[level].set_log_domain_size(kLogDomainSizes[level]);
    parameters[level].mutable_value_type()->mutable_integer()->set_bitsize(32);
  }
  // Number of prefixes to evaluate at each level after the first.
  std::vector<int> num_nonzeros(kNumHierarchyLevels - 1, 32);
  std::unique_ptr<DistributedPointFunction> dpf =
      DistributedPointFunction::CreateIncremental(parameters).value();

  // Keys for alpha = 1234567 with beta = 1 at every level.
  absl::uint128 alpha = 1234567;
  std::vector<absl::uint128> beta(kNumHierarchyLevels, 1);
  std::pair<DpfKey, DpfKey> keys =
      dpf->GenerateKeysIncremental(alpha, beta).value();

  // Random prefixes with the requested number of nonzeros per level.
  std::vector<std::vector<absl::uint128>> prefixes =
      GenerateRandomPrefixes(parameters, num_nonzeros);

  EvaluationContext initial_ctx =
      dpf->CreateEvaluationContext(keys.first).value();
  for (auto _ : state) {
    // Each iteration evaluates on its own arena-allocated copy of the
    // initial context.
    google::protobuf::Arena arena;
    auto* ctx =
        google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
    *ctx = initial_ctx;
    for (int level = 0; level < kNumHierarchyLevels; ++level) {
      std::vector<uint32_t> result =
          dpf->EvaluateNext<uint32_t>(prefixes[level], *ctx).value();
      benchmark::DoNotOptimize(result);
    }
    benchmark::DoNotOptimize(*ctx);
  }
}
BENCHMARK(BM_IsrgExampleHierarchy);
// Benchmarks the time needed to generate keys. The log domain size is read from
// the first range argument. If `direct_evaluation` is true, a single hierarchy
// level will be used. Otherwise, the number of hierarchy levels is equal to
// the log domain size (i.e., one level per bit in the domain).
//
// The serialized size of the last generated key is reported as the benchmark
// label.
template <bool direct_evaluation>
void BM_KeyGeneration(benchmark::State& state) {
  int last_level_log_domain_size = state.range(0);
  std::vector<DpfParameters> parameters(1);
  if (direct_evaluation) {
    parameters[0].set_log_domain_size(last_level_log_domain_size);
    parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32);
  } else {
    parameters.resize(last_level_log_domain_size);
    for (int i = 0; i < last_level_log_domain_size; ++i) {
      parameters[i].set_log_domain_size(i + 1);
      parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(32);
    }
  }
  // Use .value() like the other benchmarks in this file, rather than
  // dereferencing the StatusOr without a check.
  std::unique_ptr<DistributedPointFunction> dpf =
      DistributedPointFunction::CreateIncremental(parameters).value();
  std::vector<absl::uint128> beta(parameters.size(), 23);
  absl::BitGen rng;
  absl::uniform_int_distribution<uint64_t> dist;
  // Mask selecting the bits of alpha that lie inside the domain. This
  // benchmark is registered with log domain sizes up to 128, and shifting an
  // absl::uint128 by 128 bits is undefined, so the full-width case keeps the
  // all-ones mask.
  const int log_domain_size = parameters.back().log_domain_size();
  absl::uint128 alpha_mask = absl::Uint128Max();
  if (log_domain_size < 128) {
    alpha_mask = (absl::uint128{1} << log_domain_size) - 1;
  }
  std::pair<DpfKey, DpfKey> result;
  for (auto s : state) {
    // Sample alpha randomly, so we don't rely on any structure here.
    absl::uint128 alpha = absl::MakeUint128(dist(rng), dist(rng)) & alpha_mask;
    result = dpf->GenerateKeysIncremental(alpha, beta).value();
    benchmark::DoNotOptimize(result);
  }
  state.SetLabel(absl::StrCat("key_size: ", result.first.ByteSizeLong()));
}
BENCHMARK_TEMPLATE(BM_KeyGeneration, true)->RangeMultiplier(2)->Range(1, 128);
BENCHMARK_TEMPLATE(BM_KeyGeneration, false)->RangeMultiplier(2)->Range(1, 128);
// Samples `num_nonzeros` distinct uniform indices in the domain of the
// second-to-last entry of `parameters` and derives the matching prefixes for
// all earlier hierarchy levels. Entry `i` of the result holds the sorted,
// de-duplicated prefixes to pass when evaluating level `i`; entry 0 is always
// empty.
//
// Returns an empty result per level if there are fewer than two levels, and
// INVALID_ARGUMENT if `num_nonzeros` exceeds the size of the second-to-last
// domain.
absl::StatusOr<std::vector<std::vector<absl::uint128>>> GenerateUniformPrefixes(
    absl::Span<const DpfParameters> parameters, int num_nonzeros) {
  const int num_levels = static_cast<int>(parameters.size());
  std::vector<std::vector<absl::uint128>> result(num_levels);
  if (num_levels <= 1) {
    return result;
  }
  const int last_prefix_bits = parameters[num_levels - 2].log_domain_size();
  if (std::log2(num_nonzeros) > last_prefix_bits) {
    return absl::InvalidArgumentError("num_nonzeros out of range");
  }

  // Draw random values until we have collected `num_nonzeros` distinct
  // prefixes for the last level. The set keeps them sorted and unique.
  absl::BitGen rng;
  absl::uniform_int_distribution<uint64_t> dist;
  absl::btree_set<absl::uint128> sampled;
  while (static_cast<int>(sampled.size()) < num_nonzeros) {
    const absl::uint128 mask = (absl::uint128{1} << last_prefix_bits) - 1;
    sampled.insert(absl::MakeUint128(dist(rng), dist(rng)) & mask);
  }
  result.back().assign(sampled.begin(), sampled.end());

  // Walk backwards through the levels. The prefixes for level `level - 1` are
  // the prefixes for `level` with the bits added between the two
  // corresponding domains shifted away.
  for (int level = num_levels - 1; level > 1; --level) {
    const int dropped_bits = parameters[level - 1].log_domain_size() -
                             parameters[level - 2].log_domain_size();
    absl::btree_set<absl::uint128> shortened;
    for (const absl::uint128& longer : result[level]) {
      shortened.insert(longer >> dropped_bits);
    }
    result[level - 1].assign(shortened.begin(), shortened.end());
  }
  return result;
}
// Benchmarks a bit-wise hierarchy as in https://github.com/henrycg/heavyhitters.
// The first range argument is the number of hierarchy levels; level i has log
// domain size i + 1 and 64-bit integer values. Prefixes are derived from 10000
// uniform non-zeros, and every level is evaluated.
void BM_HeavyHitters(benchmark::State& state) {
  constexpr int kNumNonzeros = 10000;
  const int num_levels = state.range(0);

  // One hierarchy level per bit of the domain.
  std::vector<DpfParameters> parameters(num_levels);
  int log_domain_size = 0;
  for (DpfParameters& level_parameters : parameters) {
    level_parameters.set_log_domain_size(++log_domain_size);
    level_parameters.mutable_value_type()->mutable_integer()->set_bitsize(64);
  }
  std::unique_ptr<DistributedPointFunction> dpf =
      *(DistributedPointFunction::CreateIncremental(parameters));

  // Key for alpha = 42 with beta = 23 at every level.
  std::vector<absl::uint128> beta(num_levels, 23);
  absl::uint128 alpha = 42;
  DpfKey key = dpf->GenerateKeysIncremental(alpha, beta).value().first;
  std::vector<std::vector<absl::uint128>> prefixes =
      GenerateUniformPrefixes(parameters, kNumNonzeros).value();

  EvaluationContext initial_ctx = dpf->CreateEvaluationContext(key).value();
  for (auto _ : state) {
    // Each iteration evaluates on its own arena-allocated copy of the
    // initial context.
    google::protobuf::Arena arena;
    auto* ctx =
        google::protobuf::Arena::CreateMessage<EvaluationContext>(&arena);
    *ctx = initial_ctx;
    for (int level = 0; level < num_levels; ++level) {
      std::vector<uint64_t> result =
          dpf->EvaluateNext<uint64_t>(prefixes[level], *ctx).value();
      benchmark::DoNotOptimize(result);
    }
    benchmark::DoNotOptimize(*ctx);
  }
}
BENCHMARK(BM_HeavyHitters)->RangeMultiplier(2)->Range(16, 128);
// Benchmarks batch evaluation of multiple DPF keys, each at its own set of
// random evaluation points. The first range argument is the number of keys and
// the second is the number of evaluation points per key. The log domain size
// is fixed at compile time (kLogDomainSize).
template <typename T>
void BM_BatchEvaluation(benchmark::State& state) {
  const int num_keys = state.range(0);
  const int evaluation_points_per_key = state.range(1);
  constexpr int kLogDomainSize = 63 - 7;
  // Shifting an absl::uint128 by 128 bits is undefined, so the full-width
  // case keeps the all-ones mask.
  absl::uint128 domain_mask = absl::Uint128Max();
  if (kLogDomainSize < 128) {
    domain_mask = (absl::uint128{1} << kLogDomainSize) - 1;
  }
  DpfParameters parameters;
  parameters.set_log_domain_size(kLogDomainSize);
  *(parameters.mutable_value_type()) = ToValueType<T>();
  std::unique_ptr<DistributedPointFunction> dpf =
      DistributedPointFunction::Create(parameters).value();
  absl::BitGen rng;
  google::protobuf::Arena arena;
  // One arena-owned key per index, plus a flat array that stores the points
  // of key `i` at [i * evaluation_points_per_key,
  // (i + 1) * evaluation_points_per_key).
  std::vector<const DpfKey*> keys(num_keys);
  auto evaluation_points =
      hwy::AllocateAligned<absl::uint128>(num_keys * evaluation_points_per_key);
  ABSL_CHECK(evaluation_points != nullptr);
  for (int i = 0; i < num_keys; ++i) {
    absl::uint128 alpha = absl::MakeUint128(absl::Uniform<uint64_t>(rng),
                                            absl::Uniform<uint64_t>(rng)) &
                          domain_mask;
    T beta{};
    DpfKey* key = google::protobuf::Arena::CreateMessage<DpfKey>(&arena);
    *key = dpf->GenerateKeys(alpha, beta).value().first;
    keys[i] = key;
    for (int j = 0; j < evaluation_points_per_key; ++j) {
      evaluation_points[i * evaluation_points_per_key + j] =
          absl::MakeUint128(absl::Uniform<uint64_t>(rng),
                            absl::Uniform<uint64_t>(rng)) &
          domain_mask;
    }
  }
  for (auto s : state) {
    for (int i = 0; i < num_keys; ++i) {
      // Evaluate key `i` at its own points. Indexing a per-point key array
      // with `i` here would select the first key for almost every iteration.
      std::vector<T> result =
          dpf->EvaluateAt<T>(
                 *(keys[i]), 0,
                 absl::MakeConstSpan(
                     evaluation_points.get() + i * evaluation_points_per_key,
                     evaluation_points_per_key))
              .value();
      benchmark::DoNotOptimize(result);
    }
  }
}
BENCHMARK_TEMPLATE(BM_BatchEvaluation, XorWrapper<absl::uint128>)
    ->ArgPair(1, 400000)
    ->ArgPair(10, 40000)
    ->ArgPair(100, 4000);
} // namespace
} // namespace distributed_point_functions

File diff suppressed because it is too large Load Diff

@ -1,88 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/int_mod_n.h"
#include <cmath>
#include <string>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Returns 128 + 3 - (log2(modulus) + log2(num_samples) +
// log2(num_samples + 1)), the security level (in bits) achievable when
// sampling `num_samples` elements modulo `modulus`.
double IntModNBase::GetSecurityLevel(int num_samples, absl::uint128 modulus) {
  // Convert to double before adding 1: `num_samples + 1` evaluated in `int`
  // overflows when num_samples == INT_MAX.
  const double samples = static_cast<double>(num_samples);
  return 128 + 3 -
         (std::log2(static_cast<double>(modulus)) + std::log2(samples) +
          std::log2(samples + 1));
}
// Validates the parameters for sampling `num_samples` integers modulo
// `modulus` from a base integer type of `base_integer_bitsize` bits.
//
// Returns INVALID_ARGUMENT if `num_samples` or `base_integer_bitsize` is not
// positive, if `base_integer_bitsize` exceeds 128, if `modulus` is larger than
// 2**base_integer_bitsize, or if the achievable security level is below
// `security_parameter`. Returns OK otherwise.
absl::Status IntModNBase::CheckParameters(int num_samples,
                                          int base_integer_bitsize,
                                          absl::uint128 modulus,
                                          double security_parameter) {
  if (num_samples <= 0) {
    return absl::InvalidArgumentError("num_samples must be positive");
  }
  if (base_integer_bitsize <= 0) {
    return absl::InvalidArgumentError("base_integer_bitsize must be positive");
  }
  if (base_integer_bitsize > 128) {
    return absl::InvalidArgumentError(
        "base_integer_bitsize must be at most 128");
  }

  // The range check is skipped for 128-bit base integers, where the shift
  // below would not be valid.
  const bool modulus_out_of_range =
      base_integer_bitsize < 128 &&
      modulus > (absl::uint128{1} << base_integer_bitsize);
  if (modulus_out_of_range) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "kModulus %d out of range for base_integer_bitsize = %d", modulus,
        base_integer_bitsize));
  }

  // Fail if the security level we can achieve is lower than requested.
  const double achievable_security = GetSecurityLevel(num_samples, modulus);
  if (achievable_security < security_parameter) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "For num_samples = %d and kModulus = %d this approach can only "
        "provide "
        "%f bits of statistical security. You can try calling this function "
        "several times with smaller values of num_samples.",
        num_samples, modulus, achievable_security));
  }
  return absl::OkStatus();
}
// Returns the number of bytes needed to sample `num_samples` integers modulo
// `modulus`, or the error from CheckParameters if the parameters are invalid.
absl::StatusOr<int> IntModNBase::GetNumBytesRequired(
    int num_samples, int base_integer_bitsize, absl::uint128 modulus,
    double security_parameter) {
  absl::Status parameter_check = CheckParameters(
      num_samples, base_integer_bitsize, modulus, security_parameter);
  if (!parameter_check.ok()) {
    return parameter_check;
  }
  // Sampling starts from one 128-bit (16-byte) block, see function
  // `SampleFromBytes`. Every sample after the first needs one more base
  // integer, rounded up to whole bytes.
  const int bytes_per_base_integer = (base_integer_bitsize + 7) / 8;
  const int num_additional_samples = num_samples - 1;
  return 16 + bytes_per_base_integer * num_additional_samples;
}
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,282 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_
#include <algorithm>
#include <string>
#include <type_traits>
#include "absl/base/config.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Base class holding common functions of IntModN that are independent of the
// template parameter.
class IntModNBase {
 public:
  // Computes the security level achievable when sampling `num_samples`
  // elements with the given `modulus`.
  static double GetSecurityLevel(int num_samples, absl::uint128 modulus);

  // Checks if the given parameters are consistent and valid for an IntModN.
  //
  // Returns OK for valid parameters, and INVALID_ARGUMENT otherwise.
  static absl::Status CheckParameters(int num_samples, int base_integer_bitsize,
                                      absl::uint128 modulus,
                                      double security_parameter);

  // Computes the number of bytes required to sample `num_samples` integers
  // modulo `modulus` with an underlying integer type of
  // `base_integer_bitsize` bits.
  //
  // Returns INVALID_ARGUMENT if the achievable security level with the given
  // parameters is less than `security_parameter`, or if the parameters are
  // invalid.
  static absl::StatusOr<int> GetNumBytesRequired(int num_samples,
                                                 int base_integer_bitsize,
                                                 absl::uint128 modulus,
                                                 double security_parameter);

  // Creates a value of type T from the given `bytes`, using little-endian
  // encoding. Called by SampleFromBytes. Crashes if bytes.size() != sizeof(T).
  //
  // This is a reimplementation of dpf_internal::ConvertBytesTo for integers,
  // to avoid depending on value_type_helpers here.
  template <typename T>
  static T ConvertBytesTo(absl::string_view bytes) {
    ABSL_CHECK(bytes.size() == sizeof(T));
    T out{0};
#ifdef ABSL_IS_LITTLE_ENDIAN
    // The in-memory representation already matches the encoding.
    std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out));
#else
    // Assemble the value starting from the most significant byte, which is
    // the last one in little-endian encoding. The shift has to happen before
    // the OR. Shifting afterwards moves the finished value one byte too far,
    // dropping its top byte and leaving the lowest byte zero.
    for (int i = sizeof(T) - 1; i >= 0; --i) {
      out <<= 8;
      out |= static_cast<unsigned char>(bytes[i]);
    }
#endif
    return out;
  }
};
// An integer modulo the compile-time constant `kModulus`, stored in a
// `BaseInteger`. `ModulusType` is the type used to pass `kModulus` as a
// template argument.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
class IntModNImpl : public IntModNBase {
  static_assert(sizeof(BaseInteger) <= sizeof(absl::uint128),
                "BaseInteger may be at most 128 bits large");
  static_assert(
      std::is_same<BaseInteger, absl::uint128>::value ||
#ifdef ABSL_HAVE_INTRINSIC_INT128
          // std::is_unsigned_v<unsigned __int128> is not true everywhere:
          // https://quuxplusone.github.io/blog/2019/02/28/is-int128-integral/#signedness
          std::is_same<BaseInteger, unsigned __int128>::value ||
#endif
          std::is_unsigned<BaseInteger>::value,
      "BaseInteger must be unsigned");
  static_assert(kModulus <= ModulusType(BaseInteger(-1)),
                "kModulus must fit in BaseInteger");

 public:
  using Base = BaseInteger;

  // Constructs the value 0.
  constexpr IntModNImpl() : value_(0) {}
  // Constructs `value` reduced modulo kModulus.
  explicit constexpr IntModNImpl(BaseInteger value)
      : value_(value % kModulus) {}

  // Copyable.
  constexpr IntModNImpl(const IntModNImpl& a) = default;
  constexpr IntModNImpl& operator=(const IntModNImpl& a) = default;

  // Assignment operators.
  constexpr IntModNImpl& operator=(const BaseInteger& a) {
    value_ = a % kModulus;
    return *this;
  }

  constexpr IntModNImpl& operator+=(const IntModNImpl& a) {
    AddBaseInteger(a.value_);
    return *this;
  }

  constexpr IntModNImpl& operator-=(const IntModNImpl& a) {
    SubtractBaseInteger(a.value_);
    return *this;
  }

  // Returns the underlying representation as a BaseInteger.
  constexpr BaseInteger value() const { return value_; }

  // Returns the modulus of this IntModNImpl type.
  static constexpr BaseInteger modulus() { return kModulus; }

  // Returns the number of (pseudo)random bytes required to extract
  // `num_samples` samples r1, ..., rn
  // so that the stream r1, ..., rn is close to a truly (pseudo) random
  // sequence up to total variation distance < 2^(-`security_parameter`)
  static absl::StatusOr<int> GetNumBytesRequired(int num_samples,
                                                 double security_parameter) {
    return IntModNBase::GetNumBytesRequired(
        num_samples, 8 * sizeof(BaseInteger), kModulus, security_parameter);
  }

  // Extracts `samples.size()` samples r1, ..., rn so that the stream r1, ...,
  // rn is close to a truly (pseudo) random sequence up to total variation
  // distance < 2^(-`security_parameter`). Returns r1, ..., rn in `samples`.
  // Does nothing if `samples` is empty.
  //
  // The optional template argument allows users to specify the number of
  // samples at compile time, which can save heap allocations.
  //
  // Caution: For performance reasons, this function does not check whether
  // `bytes` is long enough for the required number of samples and security
  // parameter. Use `GetNumBytesRequired` or `SampleFromBytes` if such checks
  // are needed.
  //
  template <int kCompiledNumSamples = 1>
  static void UnsafeSampleFromBytes(absl::string_view bytes,
                                    double security_parameter,
                                    absl::Span<IntModNImpl> samples) {
    static_assert(kCompiledNumSamples >= 1,
                  "kCompiledNumSamples must be positive");
    if (samples.empty()) {
      // `samples.size() - 1` below is unsigned and would wrap around to a
      // huge element count for an empty span.
      return;
    }
    // The first sample is taken from a full 128-bit block.
    absl::uint128 r = ConvertBytesTo<absl::uint128>(bytes.substr(0, 16));
    // Each further sample mixes in one additional BaseInteger of randomness.
    absl::InlinedVector<BaseInteger, std::max(1, kCompiledNumSamples - 1)>
        randomness(samples.size() - 1);
    for (int i = 0; i < static_cast<int>(randomness.size()); ++i) {
      randomness[i] = ConvertBytesTo<BaseInteger>(
          bytes.substr(16 + i * sizeof(BaseInteger), sizeof(BaseInteger)));
    }
    for (int i = 0; i < static_cast<int>(samples.size()); ++i) {
      samples[i] = IntModNImpl(static_cast<BaseInteger>(r % kModulus));
      if (i < static_cast<int>(randomness.size())) {
        // Keep the quotient and append fresh randomness in the low bits.
        r /= kModulus;
        if (sizeof(BaseInteger) < sizeof(absl::uint128)) {
          r <<= (sizeof(BaseInteger) * 8);
        }
        r |= randomness[i];
      }
    }
  }

  // Checks that length(`bytes`) is enough to extract
  // `samples.size()` samples r1, ..., rn
  // so that the stream r1, ..., rn is close to a truly (pseudo) random
  // sequence up to total variation distance < 2^(-`security_parameter`) and
  // fails if that is not the case.
  // Otherwise returns r1, ..., rn in `samples`.
  static absl::Status SampleFromBytes(absl::string_view bytes,
                                      double security_parameter,
                                      absl::Span<IntModNImpl> samples) {
    if (samples.empty()) {
      return absl::InvalidArgumentError(
          "The number of samples required must be > 0");
    }
    absl::StatusOr<int> num_bytes_lower_bound = GetNumBytesRequired(
        static_cast<int>(samples.size()), security_parameter);
    if (!num_bytes_lower_bound.ok()) {
      return num_bytes_lower_bound.status();
    }
    // Compare as unsigned values to avoid a mixed-sign comparison.
    if (static_cast<size_t>(*num_bytes_lower_bound) > bytes.size()) {
      return absl::InvalidArgumentError(
          absl::StrCat("The number of bytes provided (", bytes.size(),
                       ") is insufficient for the required "
                       "statistical security and number of samples."));
    }
    UnsafeSampleFromBytes(bytes, security_parameter, samples);
    return absl::OkStatus();
  }

 private:
  // Subtracts `a` from value_ modulo kModulus.
  constexpr void SubtractBaseInteger(const BaseInteger& a) {
    if (value_ >= a) {
      value_ -= a;
    } else {
      value_ = kModulus - a + value_;
    }
  }

  // Adds `a` to value_ modulo kModulus by subtracting its additive inverse,
  // so the sum is never formed directly.
  constexpr void AddBaseInteger(const BaseInteger& a) {
    SubtractBaseInteger(kModulus - a);
  }

  BaseInteger value_;
};
// Returns the sum of `a` and `b`.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator+(
    IntModNImpl<BaseInteger, ModulusType, kModulus> a,
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
  // `a` is a by-value copy, so it can be updated in place.
  return a += b;
}
// Returns the difference of `a` and `b`.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-(
    IntModNImpl<BaseInteger, ModulusType, kModulus> a,
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
  // `a` is a by-value copy, so it can be updated in place.
  return a -= b;
}
// Returns the additive inverse of `a`, computed as 0 - a.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
constexpr IntModNImpl<BaseInteger, ModulusType, kModulus> operator-(
    IntModNImpl<BaseInteger, ModulusType, kModulus> a) {
  using Element = IntModNImpl<BaseInteger, ModulusType, kModulus>;
  Element negated(BaseInteger{0});
  negated -= a;
  return negated;
}
// Two elements are equal iff their underlying values are equal.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
constexpr bool operator==(
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& a,
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
  return b.value() == a.value();
}
// Negation of operator== for IntModNImpl.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
constexpr bool operator!=(
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& a,
    const IntModNImpl<BaseInteger, ModulusType, kModulus>& b) {
  const bool equal = (a == b);
  return !equal;
}
} // namespace dpf_internal
// Public alias for integers modulo `kModulus` stored in a `BaseInteger`.
//
// Since `absl::uint128` is not an alias to `unsigned __int128`, but a struct,
// we cannot use it as a template parameter type. So if we have an intrinsic
// int128, we always use that as the modulus type. Otherwise, the modulus type
// is the same as BaseInteger.
#ifdef ABSL_HAVE_INTRINSIC_INT128
template <typename BaseInteger, unsigned __int128 kModulus>
using IntModN =
dpf_internal::IntModNImpl<BaseInteger, unsigned __int128, kModulus>;
#else
template <typename BaseInteger, BaseInteger kModulus>
using IntModN = dpf_internal::IntModNImpl<BaseInteger, BaseInteger, kModulus>;
#endif
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_INT_MOD_N_H_

@ -1,54 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdint.h>
#include <cmath>
#include <vector>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "benchmark/benchmark.h"
#include "dpf/int_mod_n.h"
#include "openssl/rand.h"
namespace distributed_point_functions {
namespace {
using MyInt = IntModN<uint32_t, 4294967291u>;  // 2**32 - 5.
constexpr int kNumSamples = 5;

// Benchmarks UnsafeSampleFromBytes. The first range argument is the number of
// sampling calls per benchmark iteration; each call extracts kNumSamples
// values from the same random bytes.
void BM_Sample(benchmark::State& state) {
  const int num_iterations = state.range(0);
  const double security_parameter = 40 + std::log2(num_iterations);

  // Fill a buffer of exactly the required size with random bytes.
  const int num_bytes =
      MyInt::GetNumBytesRequired(kNumSamples, security_parameter).value();
  std::vector<uint8_t> bytes(num_bytes);
  RAND_bytes(bytes.data(), bytes.size());
  const absl::string_view randomness(
      reinterpret_cast<const char*>(bytes.data()), bytes.size());

  // One group of kNumSamples output slots per sampling call.
  std::vector<MyInt> output(num_iterations * kNumSamples);
  for (auto _ : state) {
    for (int call = 0; call < num_iterations; ++call) {
      MyInt* const group = &output[call * kNumSamples];
      MyInt::UnsafeSampleFromBytes<kNumSamples>(
          randomness, security_parameter,
          absl::MakeSpan(group, kNumSamples));
    }
    benchmark::DoNotOptimize(output);
  }
}
BENCHMARK(BM_Sample)->Range(1, 1 << 20);
} // namespace
} // namespace distributed_point_functions

@ -1,258 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/int_mod_n.h"
#include <cstdint>
#include <string>
#include <vector>
#include "absl/base/config.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "dpf/internal/status_matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace {
// Security parameter the tests below expect to be achievable for all tested
// types when sampling five values.
constexpr double kFeasibleSecurityParameter = 40;
// Security parameter the tests below expect to be rejected for all tested
// types when sampling kNumSamples values.
constexpr double kUnfeasibleSecurityParameter = 95;
constexpr int kNumSamples = 5;
// Typed test fixture; each test runs once per entry of IntModNTypes.
template <typename T>
class IntModNTest : public testing::Test {};
// The 128-bit base integer case is only included when an intrinsic 128-bit
// integer type is available.
using IntModNTypes = ::testing::Types<
IntModN<uint32_t, 4294967291u>, // 2**32-5
IntModN<uint64_t, 18446744073709551557ull> // 2**64-59
#ifdef ABSL_HAVE_INTRINSIC_INT128
,
IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128(
65535u, 18446744073709551551ull))> // 2**80-65
#endif
>;
TYPED_TEST_SUITE(IntModNTest, IntModNTypes);
// A default-constructed element holds the value 0.
TYPED_TEST(IntModNTest, DefaultValueIsZero) {
  TypeParam default_constructed;
  EXPECT_EQ(default_constructed.value(), 0);
}
// Assigning a base integer replaces the stored value.
TYPED_TEST(IntModNTest, SetValueWorks) {
  TypeParam element;
  EXPECT_EQ(element.value(), 0);

  element = 23;
  EXPECT_EQ(element.value(), 23);
}
// Adding values whose sum stays below the modulus yields the plain sum.
TYPED_TEST(IntModNTest, AdditionWithoutWrapAroundWorks) {
  TypeParam sum;
  TypeParam addend;

  // 0 + 0.
  sum += addend;
  EXPECT_EQ(sum.value(), 0);

  // 0 + 23.
  addend = 23;
  sum += addend;
  EXPECT_EQ(sum.value(), 23);

  // 23 + 4294967200 is still below 2**32 - 5, the smallest tested modulus.
  addend = 4294967200;
  sum += addend;
  EXPECT_EQ(sum.value(), 4294967223);
}
// Adding values whose sum reaches the modulus wraps around.
TYPED_TEST(IntModNTest, AdditionWithWrapAroundWorks) {
  TypeParam sum;
  TypeParam addend;

  // 0 + 0.
  sum += addend;
  EXPECT_EQ(sum.value(), 0);

  // 0 + 23.
  addend = 23;
  sum += addend;
  EXPECT_EQ(sum.value(), 23);

  // 23 + (modulus - 10) is congruent to 13.
  addend = TypeParam::modulus() - 10;
  sum += addend;
  EXPECT_EQ(sum.value(), 13);
}
// An element plus its negation is zero.
TYPED_TEST(IntModNTest, NegationWorks) {
  TypeParam element(10);
  TypeParam negated = -element;
  TypeParam zero(0);
  EXPECT_EQ(element + negated, zero);
}
// Requesting the unfeasible security parameter is rejected with an
// INVALID_ARGUMENT error that names the sample count and modulus.
TYPED_TEST(IntModNTest, GetNumBytesRequiredFailsIfUnfeasible) {
  const std::string expected_message_prefix =
      absl::StrFormat("For num_samples = 5 and kModulus = %d",
                      absl::uint128(TypeParam::modulus()));
  absl::StatusOr<int> result =
      TypeParam::GetNumBytesRequired(kNumSamples, kUnfeasibleSecurityParameter);
  EXPECT_THAT(result,
              dpf_internal::StatusIs(
                  absl::StatusCode::kInvalidArgument,
                  testing::StartsWith(expected_message_prefix)));
}
// Requesting the feasible security parameter for five samples succeeds.
TYPED_TEST(IntModNTest, GetNumBytesRequiredSucceedsIfFeasible) {
  absl::StatusOr<int> num_bytes =
      TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  EXPECT_TRUE(num_bytes.ok());
}
// SampleFromBytes rejects an input that is shorter than GetNumBytesRequired
// asks for.
TYPED_TEST(IntModNTest, SampleFailsIfUnfeasible) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // Fatal assertion: r_getnum.value() below must not run on an error status.
  ASSERT_TRUE(r_getnum.ok());
  // 16 bytes only cover the first sample.
  std::string bytes = std::string(16, '#');
  EXPECT_GT(r_getnum.value(), bytes.size());
  std::vector<TypeParam> samples(5);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_FALSE(r_sample.ok());
  EXPECT_THAT(
      r_sample,
      dpf_internal::StatusIs(
          absl::StatusCode::kInvalidArgument,
          "The number of bytes provided (16) is insufficient for the required "
          "statistical security and number of samples."));
}
// SampleFromBytes succeeds when given exactly the required number of bytes.
TYPED_TEST(IntModNTest, SampleSucceedsIfFeasible) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // Fatal assertion: r_getnum.value() below must not run on an error status.
  ASSERT_TRUE(r_getnum.ok());
  std::string bytes = std::string(r_getnum.value(), '#');
  std::vector<TypeParam> samples(5);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_TRUE(r_sample.ok());
}
// The first sample equals the first 16 input bytes, read as a 128-bit
// integer, reduced modulo the modulus.
TYPED_TEST(IntModNTest, FirstEntryOfSamplesIsAsExpected) {
  absl::StatusOr<int> r_getnum =
      TypeParam::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // Fatal assertion: r_getnum.value() below must not run on an error status.
  ASSERT_TRUE(r_getnum.ok());
  std::string bytes = std::string(r_getnum.value(), '#');
  std::vector<TypeParam> samples(5);
  absl::Status r_sample = TypeParam::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_TRUE(r_sample.ok());
  EXPECT_EQ(
      samples[0].value(),
      TypeParam::template ConvertBytesTo<absl::uint128>(bytes.substr(0, 16)) %
          TypeParam::modulus());
}
// Concrete instantiation with a 32-bit base integer and modulus 2**32 - 5,
// used by the non-typed tests and constexpr checks in this file.
using BaseInteger = uint32_t;
constexpr BaseInteger kModulus32 = 4294967291u; // 2**32 - 5
using MyIntModN = IntModN<BaseInteger, kModulus32>;
// Recomputes the five expected samples by hand for a fixed 32-byte input and
// checks that SampleFromBytes produces exactly those values.
TEST(IntModNTest, SampleFromBytesWorksInConcreteExample) {
  absl::StatusOr<int> r_getnum =
      MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // Fatal assertion: r_getnum is dereferenced on the next line.
  ASSERT_TRUE(r_getnum.ok());
  EXPECT_EQ(*r_getnum, 32);
  std::string bytes = "this is a length 32 test string.";
  EXPECT_EQ(bytes.size(), 32);
  std::vector<MyIntModN> samples(5);
  absl::Status r_sample = MyIntModN::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_TRUE(r_sample.ok());
  // The first sample comes from the leading 16 bytes.
  absl::uint128 r =
      MyIntModN::ConvertBytesTo<absl::uint128>("this is a length");
  EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus());
  // Every following sample: divide the running value by the modulus, shift it
  // up by one Base width and OR in the next sizeof(Base) input bytes.
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" 32 ");
  EXPECT_EQ(samples[1].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("test");
  EXPECT_EQ(samples[2].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" str");
  EXPECT_EQ(samples[3].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("ing.");
  EXPECT_EQ(samples[4].value(), r % MyIntModN::modulus());
}
// Same hand computation as the positive concrete example, except that the
// final 4-byte chunk used for the expected value is altered ("ing#" instead
// of "ing."); the last sample must then differ from the hand-computed value.
TEST(IntModNTest, SampleFromBytesFailsAsExpectedInConcreteExample) {
  absl::StatusOr<int> r_getnum =
      MyIntModN::GetNumBytesRequired(5, kFeasibleSecurityParameter);
  // Fatal assertion: r_getnum is dereferenced on the next line.
  ASSERT_TRUE(r_getnum.ok());
  EXPECT_EQ(*r_getnum, 32);
  std::string bytes = "this is a length 32 test string.";
  EXPECT_EQ(bytes.size(), 32);
  std::vector<MyIntModN> samples(5);
  absl::Status r_sample = MyIntModN::SampleFromBytes(
      bytes, kFeasibleSecurityParameter, absl::MakeSpan(samples));
  EXPECT_TRUE(r_sample.ok());
  absl::uint128 r =
      MyIntModN::ConvertBytesTo<absl::uint128>("this is a length");
  EXPECT_EQ(samples[0].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" 32 ");
  EXPECT_EQ(samples[1].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("test");
  EXPECT_EQ(samples[2].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>(" str");
  EXPECT_EQ(samples[3].value(), r % MyIntModN::modulus());
  r /= MyIntModN::modulus();
  r <<= (sizeof(MyIntModN::Base) * 8);
  r |= MyIntModN::ConvertBytesTo<MyIntModN::Base>("ing#");  // # instead of .
  EXPECT_NE(samples[4].value(), r % MyIntModN::modulus());
}
// Test if IntModN operators are in fact constexpr. This will fail to compile
// otherwise.
// Each helper below is evaluated inside a static_assert, so both its
// constexpr-ness and its resulting value are checked at compile time.
constexpr MyIntModN TestAddition() { return MyIntModN(2) + MyIntModN(5); }
static_assert(TestAddition().value() == 7,
"constexpr addition of IntModNs incorrect");
constexpr MyIntModN TestSubtraction() { return MyIntModN(5) - MyIntModN(2); }
static_assert(TestSubtraction().value() == 3,
"constexpr subtraction of IntModNs incorrect");
// Exercises assignment from a plain integer literal inside a constexpr
// function.
constexpr MyIntModN TestAssignment() {
MyIntModN x(0);
x = 5;
return x;
}
static_assert(TestAssignment().value() == 5,
"constexpr assignment to IntModN incorrect");
#ifdef ABSL_HAVE_INTRINSIC_INT128
// 128-bit modulus 2**128 - 159. The all-ones value is 2**128 - 1, so 158 is
// subtracted from it to reach the documented modulus.
constexpr unsigned __int128 kModulus128 =
    ~static_cast<unsigned __int128>(0) - 158;  // 2**128 - 159
using MyIntModN128 = IntModN<unsigned __int128, kModulus128>;
// Compile-time check that addition is constexpr for the 128-bit base type.
constexpr MyIntModN128 TestAddition128() {
  return MyIntModN128(2) + MyIntModN128(5);
}
static_assert(TestAddition128().value() == 7,
              "constexpr addition of IntModNs incorrect");
#endif
} // namespace
} // namespace distributed_point_functions

@ -1,238 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@com_github_google_iree//build_tools/embed_data:build_defs.bzl", "cc_embed_data")
package(
default_visibility = ["//:__subpackages__"],
)
licenses(["notice"])
# C++ library built from value_type_helpers.cc / value_type_helpers.h.
cc_library(
name = "value_type_helpers",
srcs = ["value_type_helpers.cc"],
hdrs = ["value_type_helpers.h"],
deps = [
"//dpf:distributed_point_function_cc_proto",
"//dpf:int_mod_n",
"//dpf:status_macros",
"//dpf:tuple",
"//dpf:xor_wrapper",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/utility",
"@com_google_protobuf//:protobuf_lite",
],
)
# Unit test for :value_type_helpers; linked against gtest_main.
cc_test(
name = "value_type_helpers_test",
srcs = ["value_type_helpers_test.cc"],
deps = [
":status_matchers",
":value_type_helpers",
"//dpf:distributed_point_function_cc_proto",
"//dpf:int_mod_n",
"//dpf:tuple",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
# Test-only library (testonly = 1) built from status_matchers.cc/.h; it
# depends on GoogleTest and the absl status libraries.
cc_library(
name = "status_matchers",
testonly = 1,
srcs = [
"status_matchers.cc",
],
hdrs = ["status_matchers.h"],
deps = [
"//dpf:status_macros",
"@com_github_google_googletest//:gtest",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
# C++ library built from proto_validator.cc/.h; depends on
# :value_type_helpers and the DPF proto library.
cc_library(
name = "proto_validator",
srcs = [
"proto_validator.cc",
],
hdrs = [
"proto_validator.h",
],
deps = [
":value_type_helpers",
"//dpf:distributed_point_function_cc_proto",
"//dpf:status_macros",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf_lite",
],
)
# Runs cc_embed_data over proto_validator_test.textproto, naming the generated
# .cc/.h outputs and the C++ namespace they are placed in. The resulting
# target is a dependency of :proto_validator_test.
cc_embed_data(
name = "proto_validator_test_textproto_embed",
srcs = [
"proto_validator_test.textproto",
],
cc_file_output = "proto_validator_test_textproto_embed.cc",
cpp_namespace = "distributed_point_functions::dpf_internal",
h_file_output = "proto_validator_test_textproto_embed.h",
)
# Unit test for :proto_validator. The textproto fixture is supplied both as a
# runtime data file and through the embedded-data target.
cc_test(
name = "proto_validator_test",
srcs = [
"proto_validator_test.cc",
],
data = [
"proto_validator_test.textproto",
],
deps = [
":proto_validator",
":proto_validator_test_textproto_embed",
":status_matchers",
"//dpf:distributed_point_function_cc_proto",
"//dpf:tuple",
"@com_github_google_googletest//:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_protobuf//:protobuf",
],
)
# C++ library built from evaluate_prg_hwy.cc/.h; depends on the Highway-based
# AES hash header target, BoringSSL crypto and Highway.
cc_library(
name = "evaluate_prg_hwy",
srcs = ["evaluate_prg_hwy.cc"],
hdrs = ["evaluate_prg_hwy.h"],
deps = [
":aes_128_fixed_key_hash_hwy",
"//dpf:aes_128_fixed_key_hash",
"//dpf:status_macros",
"@boringssl//:crypto",
"@com_github_google_highway//:hwy",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
],
)
# Unit test for :evaluate_prg_hwy; uses Highway's test utilities and
# gtest_main.
cc_test(
name = "evaluate_prg_hwy_test",
srcs = [
"evaluate_prg_hwy_test.cc",
],
deps = [
":evaluate_prg_hwy",
":status_matchers",
"//dpf:aes_128_fixed_key_hash",
"@com_github_google_googletest//:gtest_main",
"@com_github_google_highway//:hwy",
"@com_github_google_highway//:hwy_test_util",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
# C++ library built from get_hwy_mode.cc/.h; depends on Highway and
# absl/strings.
cc_library(
name = "get_hwy_mode",
srcs = ["get_hwy_mode.cc"],
hdrs = ["get_hwy_mode.h"],
deps = [
"@com_github_google_highway//:hwy",
"@com_google_absl//absl/strings",
],
)
# Header-only target (no srcs) exposing aes_128_fixed_key_hash_hwy.h.
cc_library(
name = "aes_128_fixed_key_hash_hwy",
hdrs = [
"aes_128_fixed_key_hash_hwy.h",
],
deps = [
"@com_github_google_highway//:hwy",
"@com_google_absl//absl/numeric:int128",
],
)
# Header-only target (no srcs) exposing maybe_deref_span.h.
cc_library(
name = "maybe_deref_span",
hdrs = ["maybe_deref_span.h"],
deps = [
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],
)
# Unit test for :aes_128_fixed_key_hash_hwy; also depends on :get_hwy_mode,
# BoringSSL crypto and Highway's test utilities.
cc_test(
name = "aes_128_fixed_key_hash_hwy_test",
srcs = [
"aes_128_fixed_key_hash_hwy_test.cc",
],
deps = [
":aes_128_fixed_key_hash_hwy",
":get_hwy_mode",
":status_matchers",
"//dpf:aes_128_fixed_key_hash",
"@boringssl//:crypto",
"@com_github_google_googletest//:gtest_main",
"@com_github_google_highway//:hwy",
"@com_github_google_highway//:hwy_test_util",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/numeric:int128",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)
# Unit test for :maybe_deref_span; linked against gtest_main.
cc_test(
name = "maybe_deref_span_test",
srcs = ["maybe_deref_span_test.cc"],
deps = [
":maybe_deref_span",
"@com_github_google_googletest//:gtest_main",
],
)

@ -1,237 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Highway-specific include guard, ensuring the header can get included once per
// target architecture.
#if defined( \
DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_) == \
defined(HWY_TARGET_TOGGLE)
#ifdef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
#undef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
#else
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_
#endif
#include <limits>
#include "absl/numeric/int128.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {
// There is no AES support on HWY_SCALAR, but we still want to be able to
// include this header when compiling for HWY_SCALAR. The caller has to make
// sure to only call the functions defined here when not on HWY_SCALAR.
#if HWY_TARGET != HWY_SCALAR
namespace hn = hwy::HWY_NAMESPACE;
constexpr int kAesBlockSize = 16;
// Helper to convert a Highway tag to a tag for vectors of the same bit size,
// but with 64-bit lanes.
template <typename D>
constexpr auto To64(D d) {
return hn::Repartition<uint64_t, D>();
}
// The following macros define parts of the fixed-key AES hash function
// implementation. We use macros here since Highway doesn't allow creating
// arrays of vectors/SIMD registers. That way, we can access each register by a
// unique variable name. All inputs and outputs are assumed to be of type
// hn::ScalableTag<uint8_t>.
// Loads the AES round key for the given round and key_index.
// Expects the tag `d` and the pointer `round_keys_<key_index>` to be in scope
// at the expansion site, and declares `round_<round>_key_<key_index>`. Round
// keys are read at a stride of kAesBlockSize bytes.
#define DPF_AES_LOAD_ROUND_KEY(key_index, round) \
const auto round_##round##_key_##key_index = \
hn::LoadDup128(d, round_keys_##key_index + kAesBlockSize * round);
// Selects key_0 or key_1 for the given block_index and round, depending on the
// bits in `mask`. Keys are first converted to 64-bit vectors to apply the more
// efficient 64 bit masks.
// Expects `mask_<block_index>`, `round_<round>_key_0` and
// `round_<round>_key_1` to be in scope; declares
// `selected_round_<round>_key_<block_index>`. Lanes whose mask is set take
// key_1, all other lanes take key_0.
#define DPF_AES_SELECT_KEY(block_index, round) \
const auto selected_round_##round##_key_##block_index = hn::BitCast( \
d, hn::IfThenElse(mask_##block_index, \
hn::BitCast(To64(d), round_##round##_key_1), \
hn::BitCast(To64(d), round_##round##_key_0)));
// Load mask for computing {0, x.high64}, for computing sigma(x) below.
// kSigmaMask has all bits of the high 64-bit half set and the low half zero.
HWY_ALIGN constexpr absl::uint128 kSigmaMask =
absl::MakeUint128(std::numeric_limits<uint64_t>::max(), 0);
// Declares `sigma_mask`: kSigmaMask loaded into every 128-bit block of a
// 64-bit-lane vector. Expects the tag `d` to be in scope.
#define DPF_AES_LOAD_SIGMA_MASK() \
const auto sigma_mask = \
hn::LoadDup128(To64(d), reinterpret_cast<const uint64_t*>(&kSigmaMask));
// Compute sigma(x) = {x.high64, x.high64^x.low64} (in little-endian notation).
// Expects `in_<block_index>` and `sigma_mask` to be in scope; declares
// `in_<block_index>_64` (the input viewed as 64-bit lanes) and
// `sigma_<block_index>`.
#define DPF_AES_COMPUTE_SIGMA(block_index) \
const auto in_##block_index##_64 = hn::BitCast(To64(d), in_##block_index); \
const auto sigma_##block_index = \
hn::BitCast(d, hn::Xor(hn::Shuffle01(in_##block_index##_64), \
hn::And(sigma_mask, in_##block_index##_64)));
// Performs the first round of AES for the given block_index, using sigma as the
// input.
// This is a plain XOR with the selected round-0 key. Unlike the other round
// macros it has no trailing semicolon, so the caller must supply one.
#define DPF_AES_FIRST_ROUND(block_index) \
out_##block_index = \
hn::Xor(sigma_##block_index, selected_round_0_key_##block_index)
// Performs a middle round of AES for the given block_index.
// Updates `out_<block_index>` in place using the key selected for `round`.
#define DPF_AES_MIDDLE_ROUND(block_index, round) \
out_##block_index = hn::AESRound( \
out_##block_index, selected_round_##round##_key_##block_index);
// Performs the last round of AES for the given block_index.
// Always uses the selected key for round 10.
#define DPF_AES_LAST_ROUND(block_index) \
out_##block_index = hn::AESLastRound(out_##block_index, \
selected_round_10_key_##block_index);
// Finalize the hash by XORing with sigma.
#define DPF_AES_FINALIZE_HASH(block_index) \
out_##block_index = hn::Xor(out_##block_index, sigma_##block_index);
// Helper macro for hashing a single vector.
// Loads both round keys for `round`, selects between them for block 0 and
// applies one middle AES round to out_0.
#define DPF_AES_MIDDLE_ROUND_1(round) \
DPF_AES_LOAD_ROUND_KEY(0, round); \
DPF_AES_LOAD_ROUND_KEY(1, round); \
DPF_AES_SELECT_KEY(0, round); \
DPF_AES_MIDDLE_ROUND(0, round);
// Hashes a vector `in_0`, writing the output to `out_0`. Each block is hashed
// using either `round_keys_0` or `round_keys_1`, which both must point to a
// byte array containing two expanded AES keys. Which key is used for each block
// depends on `mask_0`: If the mask is 0, then `round_keys_0` is used, otherwise
// `round_keys_1`. Note that the masks are masks on 64 bit integers, so there
// are two mask bits per AES block. The caller is responsible for making sure
// that the masks for the two halves of any given block have the same value.
//
// `d` is the tag the DPF_AES_* macros use for loads and bit casts. The body
// reads round keys 0 through 10 (kAesBlockSize bytes each) from both key
// pointers, runs the AES rounds on sigma(in_0) and finally XORs the result
// with sigma(in_0).
template <typename V, typename D, typename M>
void HashOneWithKeyMask(D d, V in_0, M mask_0,
const uint8_t* HWY_RESTRICT round_keys_0,
const uint8_t* HWY_RESTRICT round_keys_1, V& out_0) {
// Compute sigma(in_0)
DPF_AES_LOAD_SIGMA_MASK();
DPF_AES_COMPUTE_SIGMA(0);
// First AES round.
DPF_AES_LOAD_ROUND_KEY(0, 0);
DPF_AES_LOAD_ROUND_KEY(1, 0);
DPF_AES_SELECT_KEY(0, 0);
DPF_AES_FIRST_ROUND(0);
// Middle AES rounds.
DPF_AES_MIDDLE_ROUND_1(1);
DPF_AES_MIDDLE_ROUND_1(2);
DPF_AES_MIDDLE_ROUND_1(3);
DPF_AES_MIDDLE_ROUND_1(4);
DPF_AES_MIDDLE_ROUND_1(5);
DPF_AES_MIDDLE_ROUND_1(6);
DPF_AES_MIDDLE_ROUND_1(7);
DPF_AES_MIDDLE_ROUND_1(8);
DPF_AES_MIDDLE_ROUND_1(9);
// Last AES round.
DPF_AES_LOAD_ROUND_KEY(0, 10);
DPF_AES_LOAD_ROUND_KEY(1, 10);
// No semicolon needed here: the macro expansion already ends in one.
DPF_AES_SELECT_KEY(0, 10)
DPF_AES_LAST_ROUND(0);
// Finalize hash.
DPF_AES_FINALIZE_HASH(0);
}
// Helper macros for hashing four vectors in parallel.
// Selects the key for `round` for each of the four blocks 0..3. Requires the
// two round keys for `round` to have been loaded already.
#define DPF_AES_SELECT_KEY_4(round) \
DPF_AES_SELECT_KEY(0, round); \
DPF_AES_SELECT_KEY(1, round); \
DPF_AES_SELECT_KEY(2, round); \
DPF_AES_SELECT_KEY(3, round);
// Loads both round keys for `round` once, selects per block, then applies one
// middle AES round to each of out_0..out_3.
#define DPF_AES_MIDDLE_ROUND_4(round) \
DPF_AES_LOAD_ROUND_KEY(0, round); \
DPF_AES_LOAD_ROUND_KEY(1, round); \
DPF_AES_SELECT_KEY_4(round); \
DPF_AES_MIDDLE_ROUND(0, round); \
DPF_AES_MIDDLE_ROUND(1, round); \
DPF_AES_MIDDLE_ROUND(2, round); \
DPF_AES_MIDDLE_ROUND(3, round);
// Hashes four vectors `in_0, ..., in_3`, writing the results to `out_0, ...,
// out_3`. This improves pipelining of AES instructions, and improves
// performance by about 10%. Each block is hashed using either `round_keys_0` or
// `round_keys_1`, which both must point to a byte array containing two expanded
// AES keys. Which key is used for each block depends on `mask_0, ... mask_3`:
// If the mask is 0, then `round_keys_0` is used, otherwise `round_keys_1`. Note
// that the masks are masks on 64 bit integers, so there are two mask bits per
// AES block. The caller is responsible for making sure that the masks for the
// two halves of any given block have the same value.
//
// The round keys for each round are loaded once and shared by all four
// blocks; only the key selection and the AES round itself are per block.
template <typename V, typename D, typename M>
void HashFourWithKeyMask(D d, V in_0, V in_1, V in_2, V in_3, M mask_0,
M mask_1, M mask_2, M mask_3,
const uint8_t* HWY_RESTRICT round_keys_0,
const uint8_t* HWY_RESTRICT round_keys_1, V& out_0,
V& out_1, V& out_2, V& out_3) {
// Compute sigma(in_0), ..., sigma(in_3)
DPF_AES_LOAD_SIGMA_MASK();
DPF_AES_COMPUTE_SIGMA(0);
DPF_AES_COMPUTE_SIGMA(1);
DPF_AES_COMPUTE_SIGMA(2);
DPF_AES_COMPUTE_SIGMA(3);
// First AES round.
DPF_AES_LOAD_ROUND_KEY(0, 0);
DPF_AES_LOAD_ROUND_KEY(1, 0);
// No semicolon needed here: the macro expansion already ends in one.
DPF_AES_SELECT_KEY_4(0)
DPF_AES_FIRST_ROUND(0);
DPF_AES_FIRST_ROUND(1);
DPF_AES_FIRST_ROUND(2);
DPF_AES_FIRST_ROUND(3);
// Middle AES rounds.
DPF_AES_MIDDLE_ROUND_4(1);
DPF_AES_MIDDLE_ROUND_4(2);
DPF_AES_MIDDLE_ROUND_4(3);
DPF_AES_MIDDLE_ROUND_4(4);
DPF_AES_MIDDLE_ROUND_4(5);
DPF_AES_MIDDLE_ROUND_4(6);
DPF_AES_MIDDLE_ROUND_4(7);
DPF_AES_MIDDLE_ROUND_4(8);
DPF_AES_MIDDLE_ROUND_4(9);
// Last AES round.
DPF_AES_LOAD_ROUND_KEY(0, 10);
DPF_AES_LOAD_ROUND_KEY(1, 10);
DPF_AES_SELECT_KEY_4(10)
DPF_AES_LAST_ROUND(0);
DPF_AES_LAST_ROUND(1);
DPF_AES_LAST_ROUND(2);
DPF_AES_LAST_ROUND(3);
// Finalize hash.
DPF_AES_FINALIZE_HASH(0);
DPF_AES_FINALIZE_HASH(1);
DPF_AES_FINALIZE_HASH(2);
DPF_AES_FINALIZE_HASH(3);
}
#endif // HWY_TARGET != HWY_SCALAR
} // namespace HWY_NAMESPACE
} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_AES_128_FIXED_KEY_HASH_HWY_H_

@ -1,232 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <limits>
#include <memory>
#include <vector>
#include "absl/flags/parse.h"
#include "absl/log/absl_log.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/internal/get_hwy_mode.h"
#include "dpf/internal/status_matchers.h"
#include "gtest/gtest.h"
#include "hwy/aligned_allocator.h"
#include "hwy/detect_targets.h"
#include "openssl/aes.h"
// clang-format off
#define HWY_IS_TEST 1
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dpf/internal/aes_128_fixed_key_hash_hwy_test.cc" // NOLINT
#include "hwy/foreach_target.h"
// clang-format on
#include "dpf/internal/aes_128_fixed_key_hash_hwy.h"
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {
#if HWY_TARGET == HWY_SCALAR
// HWY_SCALAR doesn't support AES instructions, so there is nothing to test on
// this target and the function is intentionally empty.
void TestAllAes() {}
#else
namespace hn = hwy::HWY_NAMESPACE;
// Two fixed 128-bit AES keys; the per-block mask selects which one is used.
constexpr absl::uint128 kKey0 =
absl::MakeUint128(0x0000000000000000, 0x0000000000000000);
constexpr absl::uint128 kKey1 =
absl::MakeUint128(0x1111111111111111, 0x1111111111111111);
// Number of 128-bit blocks hashed per run. The four-vector evaluation loop
// advances 4 * hn::Lanes(d) bytes per iteration, so it is the byte count
// kNumBytes that has to be a multiple of that stride.
constexpr int kNumBlocks = 128; // kNumBytes must be divisible by (4 * hn::Lanes(d)).
constexpr int kNumBytes = kNumBlocks * sizeof(absl::uint128);
class TestOutputMatchesOpenSSL {
public:
template <typename T, typename D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
Reset();
EvaluateOne(d);
CheckResult();
Reset();
EvaluateFour(d);
CheckResult();
}
private:
void Reset() {
inputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks);
ASSERT_NE(inputs_, nullptr);
masks_ = hwy::AllocateAligned<uint64_t>(2 * kNumBlocks);
ASSERT_NE(masks_, nullptr);
for (int i = 0; i < kNumBlocks; ++i) {
inputs_[i] = absl::MakeUint128(i, i + 1);
masks_[2 * i] = masks_[2 * i + 1] =
(i % 3 == 0) ? std::numeric_limits<uint64_t>::max() : 0;
}
outputs_ = hwy::AllocateAligned<absl::uint128>(kNumBlocks);
ASSERT_NE(outputs_, nullptr);
ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey0),
128, &expanded_key_0_));
ASSERT_EQ(0, AES_set_encrypt_key(reinterpret_cast<const uint8_t*>(&kKey1),
128, &expanded_key_1_));
input_ptr_ = reinterpret_cast<const uint8_t*>(inputs_.get());
output_ptr_ = reinterpret_cast<uint8_t*>(outputs_.get());
}
void CheckResult() {
// Check the result by comparing with OpenSSL-based AES hash.
DPF_ASSERT_OK_AND_ASSIGN(
distributed_point_functions::Aes128FixedKeyHash hash_0,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
DPF_ASSERT_OK_AND_ASSIGN(
distributed_point_functions::Aes128FixedKeyHash hash_1,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks);
DPF_ASSERT_OK(
hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
absl::MakeSpan(wanted_0)));
DPF_ASSERT_OK(
hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
absl::MakeSpan(wanted_1)));
for (int i = 0; i < kNumBlocks; ++i) {
if (i % 3 == 0) {
EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i;
} else {
EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i;
}
}
}
template <typename D>
void EvaluateOne(D d) {
hn::Repartition<uint64_t, D> d64;
for (int i = 0; i + hn::Lanes(d) <= kNumBytes; i += hn::Lanes(d)) {
const auto in = hn::Load(d, input_ptr_ + i);
const auto mask =
hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t)));
auto out = hn::Undefined(d);
HashOneWithKeyMask(
d, in, mask, reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key),
reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out);
hn::Store(out, d, output_ptr_ + i);
}
}
template <typename D>
void EvaluateFour(D d) {
hn::Repartition<uint64_t, D> d64;
// Evaluate four vectors at once. Assumes kNumBytes is divisible by (4 *
// hn::Lanes(d)).
for (int i = 0; i < kNumBytes; i += 4 * hn::Lanes(d)) {
const auto in_0 = hn::Load(d, input_ptr_ + i);
const auto in_1 = hn::Load(d, input_ptr_ + i + 1 * hn::Lanes(d));
const auto in_2 = hn::Load(d, input_ptr_ + i + 2 * hn::Lanes(d));
const auto in_3 = hn::Load(d, input_ptr_ + i + 3 * hn::Lanes(d));
const auto mask_0 =
hn::MaskFromVec(hn::Load(d64, masks_.get() + i / sizeof(uint64_t)));
const auto mask_1 = hn::MaskFromVec(hn::Load(
d64, masks_.get() + (i + 1 * hn::Lanes(d)) / sizeof(uint64_t)));
const auto mask_2 = hn::MaskFromVec(hn::Load(
d64, masks_.get() + (i + 2 * hn::Lanes(d)) / sizeof(uint64_t)));
const auto mask_3 = hn::MaskFromVec(hn::Load(
d64, masks_.get() + (i + 3 * hn::Lanes(d)) / sizeof(uint64_t)));
auto out_0 = hn::Undefined(d);
auto out_1 = hn::Undefined(d);
auto out_2 = hn::Undefined(d);
auto out_3 = hn::Undefined(d);
HashFourWithKeyMask(
d, in_0, in_1, in_2, in_3, mask_0, mask_1, mask_2, mask_3,
reinterpret_cast<const uint8_t*>(expanded_key_0_.rd_key),
reinterpret_cast<const uint8_t*>(expanded_key_1_.rd_key), out_0,
out_1, out_2, out_3);
hn::Store(out_0, d, output_ptr_ + i);
hn::Store(out_1, d, output_ptr_ + i + 1 * hn::Lanes(d));
hn::Store(out_2, d, output_ptr_ + i + 2 * hn::Lanes(d));
hn::Store(out_3, d, output_ptr_ + i + 3 * hn::Lanes(d));
}
// Check the result by comparing with OpenSSL-based AES hash.
DPF_ASSERT_OK_AND_ASSIGN(
distributed_point_functions::Aes128FixedKeyHash hash_0,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
DPF_ASSERT_OK_AND_ASSIGN(
distributed_point_functions::Aes128FixedKeyHash hash_1,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
std::vector<absl::uint128> wanted_0(kNumBlocks), wanted_1(kNumBlocks);
DPF_ASSERT_OK(
hash_0.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
absl::MakeSpan(wanted_0)));
DPF_ASSERT_OK(
hash_1.Evaluate(absl::MakeConstSpan(inputs_.get(), kNumBlocks),
absl::MakeSpan(wanted_1)));
for (int i = 0; i < kNumBlocks; ++i) {
if (i % 3 == 0) {
EXPECT_EQ(wanted_1[i], outputs_.get()[i]) << "i=" << i;
} else {
EXPECT_EQ(wanted_0[i], outputs_.get()[i]) << "i=" << i;
}
}
}
hwy::AlignedFreeUniquePtr<absl::uint128[]> inputs_, outputs_;
hwy::AlignedFreeUniquePtr<uint64_t[]> masks_;
const uint8_t* input_ptr_;
uint8_t* output_ptr_;
HWY_ALIGN AES_KEY expanded_key_0_, expanded_key_1_;
};
void TestAllAes() {
hn::ForGE128Vectors<TestOutputMatchesOpenSSL>()(uint8_t{0});
}
#endif // HWY_TARGET == HWY_SCALAR
} // namespace HWY_NAMESPACE
} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace distributed_point_functions {
namespace dpf_internal {
// Registers TestAllAes as a parameterized Highway test under the
// Aes128FixedKeyHashHwyTest suite.
HWY_BEFORE_TEST(Aes128FixedKeyHashHwyTest);
HWY_EXPORT_AND_TEST_P(Aes128FixedKeyHashHwyTest, TestAllAes);
// Makes no assertions; only logs the string returned by GetHwyModeAsString().
TEST(Aes128FixedKeyHashHwy, LogHwyMode) {
ABSL_LOG(INFO) << "Highway is in " << GetHwyModeAsString() << " mode";
}
} // namespace dpf_internal
} // namespace distributed_point_functions
// Test entry point: initializes GoogleTest from the command line and returns
// the result of running every registered test.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}
#endif

@ -1,662 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/evaluate_prg_hwy.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>
#include "absl/base/config.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/status_macros.h"
#include "hwy/aligned_allocator.h"
#include "openssl/aes.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy.cc"
#include "hwy/foreach_target.h"
// clang-format on
#include "dpf/internal/aes_128_fixed_key_hash_hwy.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
#if HWY_TARGET == HWY_SCALAR
// HWY_SCALAR target: this branch has no vectorized implementation, so every
// argument is forwarded unchanged to EvaluateSeedsNoHwy. The parameter list
// matches the definition used for the other targets in this file (including
// `num_correction_words` and `paths_rightshift`) so that all per-target
// definitions of EvaluateSeedsHwy share one signature.
absl::Status EvaluateSeedsHwy(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  return EvaluateSeedsNoHwy(
      num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in,
      paths, paths_rightshift, correction_seeds, correction_controls_left,
      correction_controls_right, prg_left, prg_right, seeds_out,
      control_bits_out);
}
#else
// Converts a bool array to a block-level mask suitable for vectors described by
// `d`. The mask value for each integer in the i-th block is set to input[i].
// If `max_blocks > 0`, returns after reading `max_blocks` bools from `input`.
template <typename D>
auto MaskFromBools(D d, const bool* input, int max_blocks = 0) {
using T = hn::TFromD<D>;
constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T);
constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64);
uint8_t mask_bits[buffer_size] = {0};
for (int i = 0; i < hn::Lanes(d); ++i) {
int block_idx = i / ints_per_block;
if (max_blocks > 0 && block_idx >= max_blocks) {
break;
}
if (input[block_idx]) {
mask_bits[i / 8] |= uint8_t{1} << (i % 8);
}
}
return hn::LoadMaskBits(d, mask_bits);
}
// Converts a mask for types `d` to a bool array. Assumes that the mask value
// for all integers in the i-th block is equal, and writes that value to
// output[i]. If `max_blocks > 0`, writes at most `max_blocks` bools to
// `output`; the count is additionally capped at the number of 128-bit blocks
// a vector of `d` holds, so no bits beyond the mask are ever read.
template <typename D, typename M>
void BoolsFromMask(D d, M mask, bool* output, int max_blocks = 0) {
  using T = hn::TFromD<D>;
  constexpr size_t ints_per_block = sizeof(absl::uint128) / sizeof(T);
  int num_outputs = hn::Lanes(d) / ints_per_block;
  if (max_blocks > 0) {
    // Clamp rather than overwrite: a larger max_blocks must not push the loop
    // below past the mask bits that StoreMaskBits produced.
    num_outputs = std::min(num_outputs, max_blocks);
  }
  constexpr int buffer_size = std::max(HWY_MAX_BYTES / 8, 64);
  uint8_t mask_bits[buffer_size] = {0};
  hn::StoreMaskBits(d, mask, mask_bits);
  for (int i = 0; i < num_outputs; ++i) {
    // Only the first lane of each block is inspected.
    int mask_idx = i * ints_per_block;
    output[i] = (mask_bits[mask_idx / 8] & (uint8_t{1} << (mask_idx % 8))) != 0;
  }
}
// Mask-level select: yields `true_value` where `condition` is set and
// `false_value` elsewhere, computed as (c AND t) OR (NOT c AND f).
template <typename M>
M IfThenElseMask(M condition, M true_value, M false_value) {
  const M when_set = hn::And(condition, true_value);
  const M when_clear = hn::And(hn::Not(condition), false_value);
  return hn::Or(when_set, when_clear);
}
// Returns a mask that is `true` on all blocks where `input[i] & (1 << index)`
// is nonzero. The mask is a 64-bit-level mask, suitable for AES hashing.
template <typename V, typename D>
auto IsBitSet(D d, const V input, int index) {
// First create a 128-bit block with the `index`-th bit set.
HWY_ALIGN absl::uint128 mask = 0;
if (index < 128) {
mask = absl::uint128{1} << index;
}
// Now load it into a vector of 64-bit integers. Note that every second
// element of that vector will be 0.
const hn::Repartition<uint64_t, D> d64;
static_assert(ABSL_IS_LITTLE_ENDIAN);
const auto mask_64 =
hn::LoadDup128(d64, reinterpret_cast<const uint64_t*>(&mask));
// Compute input AND mask_64 on 64-bit integers.
auto input_64 = hn::BitCast(d64, input);
input_64 = hn::And(input_64, mask_64);
// Take the OR of every two adjacent 64-bit integers. This ensures that each
// half of an 128-bit block is nonzero iff at least one half was nonzero.
input_64 = hn::Or(input_64, hn::Shuffle01(input_64));
// Compute a 64-bit mask that checks which integers are nonzero.
return hn::Ne(input_64, hn::Zero(d64));
}
// Dummy struct to get HWY_ALIGN as a number, for testing if an array of
// absl::uint128 is aligned.
// Only alignof(Aligned128) is used; the member exists solely to give the
// struct a 128-bit payload.
struct HWY_ALIGN Aligned128 {
absl::uint128 _;
};
// SIMD implementation of DPF seed evaluation for the current Highway target.
//
// Behaviour visible in this function:
// - Returns OK immediately if `num_seeds` or `num_levels` is 0.
// - Falls back to EvaluateSeedsNoHwy if `seeds_in`, `paths`,
//   `correction_seeds` or `seeds_out` is not aligned to alignof(Aligned128),
//   or if the byte-vector width is smaller than 16 or not a multiple of 16.
// - Expands the AES keys of `prg_left` and `prg_right`; returns INTERNAL if
//   AES_set_encrypt_key fails.
// - Processes the seeds in three phases: four vectors per iteration, then one
//   full vector per iteration, then a final partial vector that goes through
//   an aligned scratch buffer (RESOURCE_EXHAUSTED if that allocation fails).
//
// `num_correction_words / num_levels` selects the correction mode: a value of
// 1 broadcasts one correction word per level to all seeds; otherwise there is
// one correction word per (level, seed) pair.
//
// NOTE(review): IsBitSet, MaskFromBools, BoolsFromMask, IfThenElseMask and
// HashOne/FourWithKeyMask are defined outside this block; the comments below
// describe them only as far as their use here suggests.
absl::Status EvaluateSeedsHwy(
int64_t num_seeds, int num_levels, int num_correction_words,
const absl::uint128* seeds_in, const bool* control_bits_in,
const absl::uint128* paths, int paths_rightshift,
const absl::uint128* correction_seeds, const bool* correction_controls_left,
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
bool* control_bits_out) {
// Exit early if inputs are empty.
if (num_seeds == 0 || num_levels == 0) {
return absl::OkStatus();
}
// Check if inputs and outputs are aligned.
constexpr size_t kHwyAlignment = alignof(Aligned128);
const bool is_aligned =
(reinterpret_cast<uintptr_t>(seeds_in) % kHwyAlignment == 0) &&
(reinterpret_cast<uintptr_t>(paths) % kHwyAlignment == 0) &&
(reinterpret_cast<uintptr_t>(correction_seeds) % kHwyAlignment == 0) &&
(reinterpret_cast<uintptr_t>(seeds_out) % kHwyAlignment == 0);
// Vector type used throughout this function: Largest byte vector available.
const hn::ScalableTag<uint8_t> d8;
// Only run the highway version if
// - the inputs are aligned,
// - the number of bytes in a vector is at least 16, and
// - the number of bytes in a vector is a multiple of 16.
if (ABSL_PREDICT_FALSE(!is_aligned || hn::Lanes(d8) < 16 ||
hn::Lanes(d8) % 16 != 0)) {
return EvaluateSeedsNoHwy(
num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in,
paths, paths_rightshift, correction_seeds, correction_controls_left,
correction_controls_right, prg_left, prg_right, seeds_out,
control_bits_out);
}
// Do AES key schedule.
HWY_ALIGN AES_KEY expanded_key_0;
HWY_ALIGN AES_KEY expanded_key_1;
int openssl_status = AES_set_encrypt_key(
reinterpret_cast<const uint8_t*>(&prg_left.key()), 128, &expanded_key_0);
if (openssl_status != 0) {
return absl::InternalError("Failed to set up AES key");
}
openssl_status = AES_set_encrypt_key(
reinterpret_cast<const uint8_t*>(&prg_right.key()), 128, &expanded_key_1);
if (openssl_status != 0) {
return absl::InternalError("Failed to set up AES key");
}
// Helper variables.
const hn::Repartition<uint64_t, decltype(d8)> d64;
// Per-128-bit-block mask with every bit set except bit 0; used to clear the
// bit that is extracted as the next control bit.
HWY_ALIGN absl::uint128 clear_lowest_bit_128 = ~absl::uint128{1};
const auto clear_lowest_bit = hn::LoadDup128(
d8, reinterpret_cast<const uint8_t*>(&clear_lowest_bit_128));
const auto mask_all_zero = hn::FirstN(d64, 0);
const auto mask_all_one = hn::Not(mask_all_zero);
const int64_t num_bytes = num_seeds * sizeof(absl::uint128);
const int bytes_per_vec = hn::Lanes(d8);
const int blocks_per_vec = bytes_per_vec / sizeof(absl::uint128);
// 1 means a single correction word is shared by all seeds at a level;
// otherwise each level has its own row of per-seed correction words.
const int64_t correction_words_per_level = num_correction_words / num_levels;
// Pointer aliases for reading and writing data.
const uint8_t* seeds_in_ptr = reinterpret_cast<const uint8_t*>(seeds_in);
const uint8_t* paths_ptr = reinterpret_cast<const uint8_t*>(paths);
uint8_t* seeds_out_ptr = reinterpret_cast<uint8_t*>(seeds_out);
// Four vectors at a time.
// `i` is a byte offset into the seed arrays; `start_block` is the matching
// index in units of 128-bit blocks.
int64_t i = 0;
for (; i + 4 * bytes_per_vec <= num_bytes; i += 4 * bytes_per_vec) {
const int64_t start_block = i / sizeof(absl::uint128);
// Load initial seeds and paths into vectors.
auto vec_0 = hn::Load(d8, seeds_in_ptr + i);
auto vec_1 = hn::Load(d8, seeds_in_ptr + i + 1 * bytes_per_vec);
auto vec_2 = hn::Load(d8, seeds_in_ptr + i + 2 * bytes_per_vec);
auto vec_3 = hn::Load(d8, seeds_in_ptr + i + 3 * bytes_per_vec);
const auto path_0 = hn::Load(d8, paths_ptr + i);
const auto path_1 = hn::Load(d8, paths_ptr + i + 1 * bytes_per_vec);
const auto path_2 = hn::Load(d8, paths_ptr + i + 2 * bytes_per_vec);
const auto path_3 = hn::Load(d8, paths_ptr + i + 3 * bytes_per_vec);
// Incoming control bits, converted from bool arrays to lane masks.
auto control_mask_0 = MaskFromBools(d64, control_bits_in + start_block);
auto control_mask_1 =
MaskFromBools(d64, control_bits_in + start_block + 1 * blocks_per_vec);
auto control_mask_2 =
MaskFromBools(d64, control_bits_in + start_block + 2 * blocks_per_vec);
auto control_mask_3 =
MaskFromBools(d64, control_bits_in + start_block + 3 * blocks_per_vec);
for (int j = 0; j < num_levels; ++j) {
// Convert path bits to masks and evaluate PRG.
// Level j reads path bits from the most significant of the `num_levels`
// relevant bits down to the least significant one.
const int bit_index = num_levels - j - 1 + paths_rightshift;
const auto path_mask_0 = IsBitSet(d8, path_0, bit_index);
const auto path_mask_1 = IsBitSet(d8, path_1, bit_index);
const auto path_mask_2 = IsBitSet(d8, path_2, bit_index);
const auto path_mask_3 = IsBitSet(d8, path_3, bit_index);
// Presumably hashes each block with key 0 or key 1 as selected by its path
// mask and writes the result back into vec_0..vec_3 -- confirm against the
// helper's definition.
HashFourWithKeyMask(
d8, vec_0, vec_1, vec_2, vec_3, path_mask_0, path_mask_1, path_mask_2,
path_mask_3, reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec_0, vec_1,
vec_2, vec_3);
// Apply correction.
if (correction_words_per_level == 1) {
const auto correction_seed = hn::LoadDup128(
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
vec_0 = hn::Xor(vec_0,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_0,
correction_seed)));
vec_1 = hn::Xor(vec_1,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_1,
correction_seed)));
vec_2 = hn::Xor(vec_2,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_2,
correction_seed)));
vec_3 = hn::Xor(vec_3,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_3,
correction_seed)));
} else { // correction_words_per_level == num_seeds.
const uint8_t* correction_seeds_ptr = reinterpret_cast<const uint8_t*>(
correction_seeds + j * correction_words_per_level);
hn::Vec<decltype(d64)> correction_seed_0, correction_seed_1,
correction_seed_2, correction_seed_3;
// The row for level j starts at a vector-aligned address only for j == 0 or
// when the row length is a multiple of the vector width; otherwise an
// unaligned load is required.
if (ABSL_PREDICT_TRUE(
correction_words_per_level % blocks_per_vec == 0 || j == 0)) {
correction_seed_0 =
hn::BitCast(d64, hn::Load(d8, correction_seeds_ptr + i));
correction_seed_1 = hn::BitCast(
d64, hn::Load(d8, correction_seeds_ptr + i + 1 * bytes_per_vec));
correction_seed_2 = hn::BitCast(
d64, hn::Load(d8, correction_seeds_ptr + i + 2 * bytes_per_vec));
correction_seed_3 = hn::BitCast(
d64, hn::Load(d8, correction_seeds_ptr + i + 3 * bytes_per_vec));
} else {
correction_seed_0 =
hn::BitCast(d64, hn::LoadU(d8, correction_seeds_ptr + i));
correction_seed_1 = hn::BitCast(
d64, hn::LoadU(d8, correction_seeds_ptr + i + 1 * bytes_per_vec));
correction_seed_2 = hn::BitCast(
d64, hn::LoadU(d8, correction_seeds_ptr + i + 2 * bytes_per_vec));
correction_seed_3 = hn::BitCast(
d64, hn::LoadU(d8, correction_seeds_ptr + i + 3 * bytes_per_vec));
}
vec_0 = hn::Xor(vec_0,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_0,
correction_seed_0)));
vec_1 = hn::Xor(vec_1,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_1,
correction_seed_1)));
vec_2 = hn::Xor(vec_2,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_2,
correction_seed_2)));
vec_3 = hn::Xor(vec_3,
hn::BitCast(d8, hn::IfThenElseZero(control_mask_3,
correction_seed_3)));
}
// Extract control bit for next level.
const auto next_control_mask_0 = IsBitSet(d8, vec_0, 0);
const auto next_control_mask_1 = IsBitSet(d8, vec_1, 0);
const auto next_control_mask_2 = IsBitSet(d8, vec_2, 0);
const auto next_control_mask_3 = IsBitSet(d8, vec_3, 0);
vec_0 = hn::And(vec_0, clear_lowest_bit);
vec_1 = hn::And(vec_1, clear_lowest_bit);
vec_2 = hn::And(vec_2, clear_lowest_bit);
vec_3 = hn::And(vec_3, clear_lowest_bit);
// Perform control bit correction.
auto correction_control_mask_0 = mask_all_zero,
correction_control_mask_1 = mask_all_zero,
correction_control_mask_2 = mask_all_zero,
correction_control_mask_3 = mask_all_zero;
if (correction_words_per_level == 1) {
const auto correction_control_mask_left =
correction_controls_left[j] ? mask_all_one : mask_all_zero;
const auto correction_control_mask_right =
correction_controls_right[j] ? mask_all_one : mask_all_zero;
correction_control_mask_0 =
IfThenElseMask(path_mask_0, correction_control_mask_right,
correction_control_mask_left);
correction_control_mask_1 =
IfThenElseMask(path_mask_1, correction_control_mask_right,
correction_control_mask_left);
correction_control_mask_2 =
IfThenElseMask(path_mask_2, correction_control_mask_right,
correction_control_mask_left);
correction_control_mask_3 =
IfThenElseMask(path_mask_3, correction_control_mask_right,
correction_control_mask_left);
} else { // correction_words_per_level == num_seeds.
const bool* correction_controls_left_j =
correction_controls_left + j * correction_words_per_level +
start_block;
const bool* correction_controls_right_j =
correction_controls_right + j * correction_words_per_level +
start_block;
correction_control_mask_0 = IfThenElseMask(
path_mask_0, MaskFromBools(d64, correction_controls_right_j),
MaskFromBools(d64, correction_controls_left_j));
correction_control_mask_1 = IfThenElseMask(
path_mask_1,
MaskFromBools(d64,
correction_controls_right_j + 1 * blocks_per_vec),
MaskFromBools(d64,
correction_controls_left_j + 1 * blocks_per_vec));
correction_control_mask_2 = IfThenElseMask(
path_mask_2,
MaskFromBools(d64,
correction_controls_right_j + 2 * blocks_per_vec),
MaskFromBools(d64,
correction_controls_left_j + 2 * blocks_per_vec));
correction_control_mask_3 = IfThenElseMask(
path_mask_3,
MaskFromBools(d64,
correction_controls_right_j + 3 * blocks_per_vec),
MaskFromBools(d64,
correction_controls_left_j + 3 * blocks_per_vec));
}
// New control bit = extracted bit XOR (old control bit AND the correction
// control bit for the chosen direction).
control_mask_0 =
hn::Xor(next_control_mask_0,
(hn::And(control_mask_0, correction_control_mask_0)));
control_mask_1 =
hn::Xor(next_control_mask_1,
(hn::And(control_mask_1, correction_control_mask_1)));
control_mask_2 =
hn::Xor(next_control_mask_2,
(hn::And(control_mask_2, correction_control_mask_2)));
control_mask_3 =
hn::Xor(next_control_mask_3,
(hn::And(control_mask_3, correction_control_mask_3)));
}
// Write the evaluated outputs to memory.
hn::Store(vec_0, d8, seeds_out_ptr + i);
hn::Store(vec_1, d8, seeds_out_ptr + i + 1 * bytes_per_vec);
hn::Store(vec_2, d8, seeds_out_ptr + i + 2 * bytes_per_vec);
hn::Store(vec_3, d8, seeds_out_ptr + i + 3 * bytes_per_vec);
BoolsFromMask(d64, control_mask_0, control_bits_out + start_block);
BoolsFromMask(d64, control_mask_1,
control_bits_out + start_block + 1 * blocks_per_vec);
BoolsFromMask(d64, control_mask_2,
control_bits_out + start_block + 2 * blocks_per_vec);
BoolsFromMask(d64, control_mask_3,
control_bits_out + start_block + 3 * blocks_per_vec);
}
ABSL_DCHECK_GT(i + 4 * bytes_per_vec, num_bytes);
// Single full vectors.
// Same per-level steps as the four-vector loop above, for one vector.
for (; i + bytes_per_vec <= num_bytes; i += bytes_per_vec) {
const int64_t start_block = i / sizeof(absl::uint128);
auto vec = hn::Load(d8, seeds_in_ptr + i);
const auto path = hn::Load(d8, paths_ptr + i);
auto control_mask = MaskFromBools(d64, control_bits_in + start_block);
for (int j = 0; j < num_levels; ++j) {
const int bit_index = num_levels - j - 1 + paths_rightshift;
const auto path_mask = IsBitSet(d8, path, bit_index);
HashOneWithKeyMask(
d8, vec, path_mask,
reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec);
// Apply correction.
hn::Vec<decltype(d64)> correction_seed;
if (correction_words_per_level == 1) {
correction_seed = hn::LoadDup128(
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
} else {
const uint64_t* correction_seeds_ptr =
reinterpret_cast<const uint64_t*>(correction_seeds +
j * correction_words_per_level +
start_block);
// Aligned load only when this level's row starts vector-aligned.
if (ABSL_PREDICT_TRUE(
correction_words_per_level % blocks_per_vec == 0 || j == 0)) {
correction_seed = hn::Load(d64, correction_seeds_ptr);
} else {
correction_seed = hn::LoadU(d64, correction_seeds_ptr);
}
}
vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask,
correction_seed)));
// Extract control bit for next level.
const auto next_control_mask = IsBitSet(d8, vec, 0);
vec = hn::And(vec, clear_lowest_bit);
// Perform control bit correction.
auto correction_control_mask = mask_all_zero;
if (correction_words_per_level == 1) {
const auto correction_control_mask_left =
correction_controls_left[j] ? mask_all_one : mask_all_zero;
const auto correction_control_mask_right =
correction_controls_right[j] ? mask_all_one : mask_all_zero;
correction_control_mask =
IfThenElseMask(path_mask, correction_control_mask_right,
correction_control_mask_left);
} else {
const bool* correction_controls_left_j =
correction_controls_left + j * correction_words_per_level +
start_block;
const bool* correction_controls_right_j =
correction_controls_right + j * correction_words_per_level +
start_block;
correction_control_mask = IfThenElseMask(
path_mask, MaskFromBools(d64, correction_controls_right_j),
MaskFromBools(d64, correction_controls_left_j));
}
control_mask = hn::Xor(next_control_mask,
(hn::And(control_mask, correction_control_mask)));
}
hn::Store(vec, d8, seeds_out_ptr + i);
BoolsFromMask(d64, control_mask, control_bits_out + start_block);
}
ABSL_DCHECK_GT(i + bytes_per_vec, num_bytes);
// Elements less than a full vector.
int remaining_blocks = num_seeds - i / sizeof(absl::uint128);
if (remaining_blocks > 0) {
const int64_t start_block = i / sizeof(absl::uint128);
const int remaining_bytes = num_bytes - i;
// Copy to a buffer first, to ensure we have at least bytes_per_vec bytes
// to read. Calling MaskedLoad directly instead might lead to out-of-bounds
// accesses.
// Layout: the first `blocks_per_vec` blocks hold seeds (and are later reused
// for per-seed correction words and the output); the second half holds paths.
auto buffer = hwy::AllocateAligned<absl::uint128>(2 * blocks_per_vec);
if (buffer == nullptr) {
return absl::ResourceExhaustedError("Memory allocation error");
}
auto buffer_ptr = reinterpret_cast<uint8_t*>(buffer.get());
std::copy_n(seeds_in + start_block, remaining_blocks, buffer.get());
std::copy_n(paths + start_block, remaining_blocks,
buffer.get() + blocks_per_vec);
const auto load_mask = hn::FirstN(d8, remaining_bytes);
auto vec = hn::MaskedLoad(load_mask, d8, buffer_ptr);
const auto path = hn::MaskedLoad(load_mask, d8, buffer_ptr + bytes_per_vec);
auto control_mask =
MaskFromBools(d64, control_bits_in + start_block, remaining_blocks);
for (int j = 0; j < num_levels; ++j) {
const int bit_index = num_levels - j - 1 + paths_rightshift;
const auto path_mask = IsBitSet(d8, path, bit_index);
HashOneWithKeyMask(
d8, vec, path_mask,
reinterpret_cast<const uint8_t*>(expanded_key_0.rd_key),
reinterpret_cast<const uint8_t*>(expanded_key_1.rd_key), vec);
// Perform seed correction.
hn::Vec<decltype(d64)> correction_seed;
if (correction_words_per_level == 1) {
correction_seed = hn::LoadDup128(
d64, reinterpret_cast<const uint64_t*>(correction_seeds + j));
} else {
// The seed half of the buffer is free to reuse here: the seeds already live
// in `vec`.
std::copy_n(
correction_seeds + j * correction_words_per_level + start_block,
remaining_blocks, buffer.get());
correction_seed =
hn::BitCast(d64, hn::MaskedLoad(load_mask, d8, buffer_ptr));
}
vec = hn::Xor(vec, hn::BitCast(d8, hn::IfThenElseZero(control_mask,
correction_seed)));
const auto next_control_mask = IsBitSet(d8, vec, 0);
vec = hn::And(vec, clear_lowest_bit);
// Perform control bit correction.
auto correction_control_mask = mask_all_zero;
if (correction_words_per_level == 1) {
const auto correction_control_mask_left =
correction_controls_left[j] ? mask_all_one : mask_all_zero;
const auto correction_control_mask_right =
correction_controls_right[j] ? mask_all_one : mask_all_zero;
correction_control_mask =
IfThenElseMask(path_mask, correction_control_mask_right,
correction_control_mask_left);
} else {
const bool* correction_controls_left_j =
correction_controls_left + j * correction_words_per_level +
start_block;
const bool* correction_controls_right_j =
correction_controls_right + j * correction_words_per_level +
start_block;
correction_control_mask = IfThenElseMask(
path_mask,
MaskFromBools(d64, correction_controls_right_j, remaining_blocks),
MaskFromBools(d64, correction_controls_left_j, remaining_blocks));
}
control_mask = hn::Xor(next_control_mask,
(hn::And(control_mask, correction_control_mask)));
}
// Store back into buffer, then copy to seeds_out.
hn::Store(vec, d8, buffer_ptr);
std::copy_n(buffer.get(), remaining_blocks, seeds_out + start_block);
BoolsFromMask(d64, control_mask, control_bits_out + start_block,
remaining_blocks);
}
return absl::OkStatus();
}
#endif // HWY_TARGET == HWY_SCALAR
} // namespace HWY_NAMESPACE
} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();
#if HWY_ONCE || HWY_IDE
namespace distributed_point_functions {
namespace dpf_internal {
// Portable (non-SIMD) DPF seed evaluation.
//
// Processes the seeds in batches of at most Aes128FixedKeyHash::kBatchSize.
// For each batch and each level:
//   1. both PRGs are evaluated on the current seeds (`seeds_in` at level 0,
//      `seeds_out` afterwards),
//   2. the left or right expansion is selected per seed from the path bit at
//      position `num_levels - level - 1 + paths_rightshift` (bit positions
//      >= 128 are treated as 0),
//   3. if the seed's control bit is set, the correction seed is XORed in and
//      the extracted lowest bit is XORed with the left/right correction
//      control bit.
//
// If `num_correction_words > num_levels`, correction words are indexed as
// `level * num_seeds + seed`; otherwise the word at index `level` is used for
// every seed. No validation of `num_correction_words` is performed here.
//
// Returns OK on success, or the first error returned by a PRG evaluation.
absl::Status EvaluateSeedsNoHwy(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  using BitVector =
      absl::InlinedVector<bool,
                          std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;
  constexpr int64_t max_batch_size = Aes128FixedKeyHash::kBatchSize;

  // More correction words than levels means one word per (level, seed) pair.
  // This is loop-invariant, so decide it once.
  const bool per_seed_correction = num_correction_words > num_levels;

  // Allocate buffers.
  std::vector<absl::uint128> buffer_left(max_batch_size);
  std::vector<absl::uint128> buffer_right(max_batch_size);
  BitVector path_bits(max_batch_size), control_bits(max_batch_size);

  // Perform DPF evaluation in blocks.
  for (int64_t start_block = 0; start_block < num_seeds;
       start_block += max_batch_size) {
    const int64_t current_batch_size =
        std::min<int64_t>(num_seeds - start_block, max_batch_size);
    for (int level = 0; level < num_levels; ++level) {
      // Evaluate PRG. We evaluate both left and right expansions, but only use
      // one of them (depending on path_bits). This seems to be faster than
      // first sorting the seeds by path_bits and then expanding.
      absl::Span<const absl::uint128> seeds =
          absl::MakeConstSpan((level == 0 ? seeds_in : seeds_out) + start_block,
                              current_batch_size);
      DPF_RETURN_IF_ERROR(prg_left.Evaluate(
          seeds, absl::MakeSpan(buffer_left).subspan(0, current_batch_size)));
      DPF_RETURN_IF_ERROR(prg_right.Evaluate(
          seeds, absl::MakeSpan(buffer_right).subspan(0, current_batch_size)));

      // Merge back into result.
      const int bit_index = num_levels - level - 1 + paths_rightshift;
      for (int64_t i = 0; i < current_batch_size; ++i) {
        // Path bits beyond the 128-bit path are treated as 0 (go left).
        path_bits[i] =
            bit_index < 128 &&
            ((paths[start_block + i]) & (absl::uint128{1} << bit_index)) != 0;
        seeds_out[start_block + i] =
            path_bits[i] ? buffer_right[i] : buffer_left[i];
      }

      // Compute correction. Making a copy here improves pipelining by not
      // updating result.control_bits in place. Do benchmarks before removing
      // this.
      std::copy_n(
          &(level == 0 ? control_bits_in : control_bits_out)[start_block],
          current_batch_size, &control_bits[0]);
      for (int64_t i = 0; i < current_batch_size; ++i) {
        const int64_t seed_index = start_block + i;
        // Kept in 64 bits: `level * num_seeds + seed_index` is computed as
        // int64_t and must not be narrowed to int.
        const int64_t correction_index =
            per_seed_correction ? level * num_seeds + seed_index : level;
        if (control_bits[i]) {
          seeds_out[seed_index] ^= correction_seeds[correction_index];
        }
        bool current_control_bit =
            ExtractAndClearLowestBit(seeds_out[seed_index]);
        if (control_bits[i]) {
          current_control_bit ^=
              path_bits[i] ? correction_controls_right[correction_index]
                           : correction_controls_left[correction_index];
        }
        control_bits_out[seed_index] = current_control_bit;
      }
    }
  }
  return absl::OkStatus();
}
HWY_EXPORT(EvaluateSeedsHwy);
// Public entry point for DPF seed evaluation. Rejects an inconsistent
// `num_correction_words` with INVALID_ARGUMENT, then forwards all arguments
// unchanged to the EvaluateSeedsHwy variant selected by Highway's dynamic
// dispatch.
absl::Status EvaluateSeeds(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  // Valid layouts: one correction word per level, or one per (level, seed).
  const bool correction_word_count_ok =
      num_correction_words == num_levels ||
      num_correction_words == num_levels * num_seeds;
  if (!correction_word_count_ok) {
    return absl::InvalidArgumentError(
        "`num_correction_words` must be equal to `num_levels` or `num_levels * "
        "num_seeds`");
  }

  return HWY_DYNAMIC_DISPATCH(EvaluateSeedsHwy)(
      num_seeds, num_levels, num_correction_words, seeds_in, control_bits_in,
      paths, paths_rightshift, correction_seeds, correction_controls_left,
      correction_controls_right, prg_left, prg_right, seeds_out,
      control_bits_out);
}
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif

@ -1,92 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_
#include <stdint.h>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "dpf/aes_128_fixed_key_hash.h"
namespace distributed_point_functions {
namespace dpf_internal {
using distributed_point_functions::Aes128FixedKeyHash;
// Returns whether the least significant bit of `x` was set, and clears that
// bit in `x`. All other bits of `x` are left untouched.
inline bool ExtractAndClearLowestBit(absl::uint128& x) {
  const absl::uint128 lowest_bit_mask{1};
  const bool was_set = (x & lowest_bit_mask) != 0;
  x &= ~lowest_bit_mask;
  return was_set;
}
// Performs DPF evaluation of the seeds given in `seeds_in` using `prg_left` or
// `prg_right`, the given `control_bits_in`, and the correction words given by
// `correction_seeds`, `correction_controls_left`, and
// `correction_controls_right`. At each level `l < num_levels`, the evaluation
// for the i-th seed continues along the left or right path depending on the
// l-th most significant bit among the lowest `num_levels` bits of `paths[i]`,
// after right-shifting each `paths[i]` by `paths_rightshift`.
//
// This function takes raw pointers instead of absl::Span for performance
// reasons. No bounds checks are performed, so it is the caller's responsibility
// to ensure that
// - `seeds_in`, `control_bits_in`, `paths`, `seeds_out`, and
//   `control_bits_out` have at least `num_seeds` elements, and
// - `correction_seeds`, `correction_controls_left`, and
//   `correction_controls_right` have at least `num_correction_words`
//   elements.
//
// If the inputs are aligned (e.g. using HWY_ALIGN, or hwy::AllocateAligned),
// and if SIMD operations are supported, then the evaluation will be done using
// SIMD operations. Otherwise, falls back to `EvaluateSeedsNoHwy`, which is at
// least 2x slower.
//
// `num_correction_words` can either be equal to `num_levels`, or equal to
// `num_seeds * num_levels`. In the first case, the same correction word is used
// for every seed at a given level. In the second case, the correction word at
// index `j * num_seeds + i` is used to correct seed `i` at level `j`.
// If `num_correction_words == num_seeds * num_levels`, then `num_seeds` should
// be smaller than or divisible by the size of a SIMD vector for optimal
// performance.
//
// Returns OK on success, INVALID_ARGUMENT in case num_correction_words is not
// equal to `num_levels` or `num_seeds * num_levels`, INTERNAL in case of
// OpenSSL errors, and RESOURCE_EXHAUSTED if the SIMD implementation fails to
// allocate its scratch buffer.
absl::Status EvaluateSeeds(
int64_t num_seeds, int num_levels, int num_correction_words,
const absl::uint128* seeds_in, const bool* control_bits_in,
const absl::uint128* paths, int paths_rightshift,
const absl::uint128* correction_seeds, const bool* correction_controls_left,
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
bool* control_bits_out);
// As `EvaluateSeeds`, but does not require any SIMD support and has no
// alignment requirements on its pointer arguments.
//
// Unlike `EvaluateSeeds`, this function does not validate
// `num_correction_words`: any value greater than `num_levels` is interpreted
// as "one correction word per (level, seed) pair", anything else as "one
// correction word per level". Returns OK on success, or the error returned by
// a failed PRG evaluation.
absl::Status EvaluateSeedsNoHwy(
int64_t num_seeds, int num_levels, int num_correction_words,
const absl::uint128* seeds_in, const bool* control_bits_in,
const absl::uint128* paths, int paths_rightshift,
const absl::uint128* correction_seeds, const bool* correction_controls_left,
const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
bool* control_bits_out);
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_EXPAND_SEEDS_HWY_H_

@ -1,257 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/evaluate_prg_hwy.h"
#include <memory>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/internal/status_matchers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "hwy/aligned_allocator.h"
// clang-format off
#define HWY_IS_TEST 1;
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy_test.cc" // NOLINT
#include "hwy/foreach_target.h"
// clang-format on
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {
using ::testing::HasSubstr;
constexpr absl::uint128 kKey0 =
absl::MakeUint128(0x0000000000000000, 0x0000000000000000);
constexpr absl::uint128 kKey1 =
absl::MakeUint128(0x1111111111111111, 0x1111111111111111);
void TestOutputMatchesNoHwyVersion(int num_seeds, int num_levels,
int num_correction_words,
int paths_rightshift) {
// Generate seeds.
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths;
hwy::AlignedFreeUniquePtr<bool[]> control_bits_in;
if (num_seeds > 0) {
seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(seeds_in, nullptr);
paths = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(paths, nullptr);
control_bits_in = hwy::AllocateAligned<bool>(num_seeds);
ASSERT_NE(control_bits_in, nullptr);
}
for (int i = 0; i < num_seeds; ++i) {
// All of these are arbitrary.
seeds_in[i] = absl::MakeUint128(i, i + 1);
paths[i] = absl::MakeUint128(23 * i + 42, 42 * i + 23);
control_bits_in[i] = (i % 7 == 0);
}
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out;
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out;
if (num_seeds > 0) {
seeds_out = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(seeds_out, nullptr);
control_bits_out = hwy::AllocateAligned<bool>(num_seeds);
ASSERT_NE(control_bits_out, nullptr);
}
// Generate correction words.
hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds;
hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left,
correction_controls_right;
if (num_correction_words > 0) {
correction_seeds =
hwy::AllocateAligned<absl::uint128>(num_correction_words);
ASSERT_NE(correction_seeds, nullptr);
correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words);
ASSERT_NE(correction_controls_left, nullptr);
correction_controls_right =
hwy::AllocateAligned<bool>(num_correction_words);
ASSERT_NE(correction_controls_right, nullptr);
}
for (int i = 0; i < num_correction_words; ++i) {
correction_seeds[i] = absl::MakeUint128(i + 1, i);
correction_controls_left[i] = (i % 23 == 0);
correction_controls_right[i] = (i % 42 != 0);
}
// Set up PRGs.
DPF_ASSERT_OK_AND_ASSIGN(
auto prg_left,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
DPF_ASSERT_OK_AND_ASSIGN(
auto prg_right,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
// Evaluate with Highway enabled.
DPF_ASSERT_OK(
EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(),
control_bits_in.get(), paths.get(), paths_rightshift,
correction_seeds.get(), correction_controls_left.get(),
correction_controls_right.get(), prg_left, prg_right,
seeds_out.get(), control_bits_out.get()));
// Evaluate without highway.
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted;
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted;
if (num_seeds > 0) {
seeds_out_wanted = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(seeds_out_wanted, nullptr);
control_bits_out_wanted = hwy::AllocateAligned<bool>(num_seeds);
ASSERT_NE(control_bits_out_wanted, nullptr);
}
DPF_ASSERT_OK(EvaluateSeedsNoHwy(
num_seeds, num_levels, num_correction_words, seeds_in.get(),
control_bits_in.get(), paths.get(), paths_rightshift,
correction_seeds.get(), correction_controls_left.get(),
correction_controls_right.get(), prg_left, prg_right,
seeds_out_wanted.get(), control_bits_out_wanted.get()));
// Check that both evaluations are equal, if there was anything to evaluate.
if (num_levels > 0) {
for (int i = 0; i < num_seeds; ++i) {
EXPECT_EQ(seeds_out[i], seeds_out_wanted[i]);
EXPECT_EQ(control_bits_out[i], control_bits_out_wanted[i]);
}
}
// Evaluate without paths_rightshift
if (paths_rightshift != 0) {
hwy::AlignedFreeUniquePtr<absl::uint128[]> paths_in2;
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_out_wanted2;
hwy::AlignedFreeUniquePtr<bool[]> control_bits_out_wanted2;
if (num_seeds > 0) {
paths_in2 = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(paths_in2, nullptr);
seeds_out_wanted2 = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(seeds_out_wanted2, nullptr);
control_bits_out_wanted2 = hwy::AllocateAligned<bool>(num_seeds);
ASSERT_NE(control_bits_out_wanted2, nullptr);
}
for (int i = 0; i < num_seeds; ++i) {
paths_in2[i] = 0;
if (paths_rightshift < 128) {
paths_in2[i] = paths[i] >> paths_rightshift;
}
}
DPF_ASSERT_OK(EvaluateSeedsNoHwy(
num_seeds, num_levels, num_correction_words, seeds_in.get(),
control_bits_in.get(), paths_in2.get(), 0, correction_seeds.get(),
correction_controls_left.get(), correction_controls_right.get(),
prg_left, prg_right, seeds_out_wanted2.get(),
control_bits_out_wanted2.get()));
// Check that both evaluations are equal, if there was anything to evaluate.
if (num_levels > 0) {
for (int i = 0; i < num_seeds; ++i) {
EXPECT_EQ(seeds_out[i], seeds_out_wanted2[i]);
EXPECT_EQ(control_bits_out[i], control_bits_out_wanted2[i]);
}
}
}
}
// Runs the Highway-vs-scalar comparison over a grid of seed counts and level
// counts, once with one correction word per level and once with one
// correction word per (level, seed) pair. No path right shift is applied.
void TestAll() {
  constexpr int kSeedCounts[] = {0, 1, 2, 101, 128, 1000};
  constexpr int kLevelCounts[] = {0, 1, 2, 32, 63, 64, 127, 128};
  for (const int num_seeds : kSeedCounts) {
    for (const int num_levels : kLevelCounts) {
      // One correction word per level.
      TestOutputMatchesNoHwyVersion(num_seeds, num_levels, num_levels, 0);
      // One correction word per level and seed.
      TestOutputMatchesNoHwyVersion(num_seeds, num_levels,
                                    num_levels * num_seeds, 0);
    }
  }
}
// Runs the Highway-vs-scalar comparison with 128 levels (one correction word
// per level) for every path right shift from 0 to 128 inclusive.
void TestPathsRightshift() {
  constexpr int num_levels = 128;
  constexpr int kMaxRightshift = 128;
  for (const int num_seeds : {0, 1, 101}) {
    int paths_rightshift = 0;
    while (paths_rightshift <= kMaxRightshift) {
      TestOutputMatchesNoHwyVersion(num_seeds, num_levels, num_levels,
                                    paths_rightshift);
      ++paths_rightshift;
    }
  }
}
void FailsIfNumCorrectionWordsIsWrong() {
constexpr int num_seeds = 1000;
constexpr int num_levels = 10;
constexpr int num_correction_words = 12;
hwy::AlignedFreeUniquePtr<absl::uint128[]> seeds_in, paths;
hwy::AlignedFreeUniquePtr<bool[]> control_bits_in;
seeds_in = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(seeds_in, nullptr);
paths = hwy::AllocateAligned<absl::uint128>(num_seeds);
ASSERT_NE(paths, nullptr);
control_bits_in = hwy::AllocateAligned<bool>(num_seeds);
ASSERT_NE(control_bits_in, nullptr);
hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds;
hwy::AlignedFreeUniquePtr<bool[]> correction_controls_left,
correction_controls_right;
correction_seeds = hwy::AllocateAligned<absl::uint128>(num_correction_words);
ASSERT_NE(correction_seeds, nullptr);
correction_controls_left = hwy::AllocateAligned<bool>(num_correction_words);
ASSERT_NE(correction_controls_left, nullptr);
correction_controls_right = hwy::AllocateAligned<bool>(num_correction_words);
ASSERT_NE(correction_controls_right, nullptr);
DPF_ASSERT_OK_AND_ASSIGN(
auto prg_left,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey0));
DPF_ASSERT_OK_AND_ASSIGN(
auto prg_right,
distributed_point_functions::Aes128FixedKeyHash::Create(kKey1));
EXPECT_THAT(
EvaluateSeeds(num_seeds, num_levels, num_correction_words, seeds_in.get(),
control_bits_in.get(), paths.get(), 0,
correction_seeds.get(), correction_controls_left.get(),
correction_controls_right.get(), prg_left, prg_right,
seeds_in.get(), control_bits_in.get()),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("num_correction_words")));
}
} // namespace HWY_NAMESPACE
} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace distributed_point_functions {
namespace dpf_internal {
HWY_BEFORE_TEST(EvaluatePrgHwyTest);
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestAll);
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, TestPathsRightshift);
HWY_EXPORT_AND_TEST_P(EvaluatePrgHwyTest, FailsIfNumCorrectionWordsIsWrong);
} // namespace dpf_internal
} // namespace distributed_point_functions
// Test binary entry point: hands the command line to GoogleTest and returns
// the result of running all registered tests as the process exit code.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}
#endif

@ -1,48 +0,0 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/get_hwy_mode.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dpf/internal/get_hwy_mode.cc"
#include "absl/strings/string_view.h"
#include "hwy/foreach_target.h"
// clang-format on
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {
// Per-target implementation: returns the name Highway reports for the target
// (HWY_TARGET) this copy of the function is compiled for.
const absl::string_view GetHwyModeAsString() {
  const absl::string_view target_name = hwy::TargetName(HWY_TARGET);
  return target_name;
}
} // namespace HWY_NAMESPACE
#if HWY_ONCE || HWY_IDE
HWY_EXPORT(GetHwyModeAsString);
// Public entry point: forwards to the per-target GetHwyModeAsString that
// HWY_DYNAMIC_DISPATCH selects and returns whatever name it reports.
const absl::string_view GetHwyModeAsString() {
  const absl::string_view selected_mode =
      HWY_DYNAMIC_DISPATCH(GetHwyModeAsString)();
  return selected_mode;
}
#endif
} // namespace dpf_internal
} // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();

@ -1,32 +0,0 @@
/*
* Copyright 2022 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_
#include "absl/strings/string_view.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Utility function for printing the mode selected by Highway. Used for
// debugging.
//
// Returns the name of the Highway target that dynamic dispatch picks, as a
// string view; takes no arguments.
const absl::string_view GetHwyModeAsString();
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_GET_HWY_MODE_H_

@ -1,92 +0,0 @@
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_
// A class that can serve the purpose of both absl::Span<T> and absl::Span<T*>
// at the same time. Introduces the run-time overhead of a std::variant check.
//
// Note that this class DOES NOT provide common container features, such as
// iterators. It is not intended to be used by users of this library. Any
// function that takes a MaybeDerefSpan<T> should be called with either an
// absl::Span<T> or an absl::Span<T*>.
#include <type_traits>
#include "absl/meta/type_traits.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
namespace distributed_point_functions {
namespace dpf_internal {
// View over a sequence of `T` that is backed either by an absl::Span<T>
// (elements stored directly) or by an absl::Span<T* const> (elements reached
// through pointers). Element access hides the difference: operator[] always
// yields a `T&`.
template <typename T>
class MaybeDerefSpan {
private:
// SFINAE helper: names `U` only when the element type `T` is const, and is
// ill-formed otherwise.
template <typename U>
using EnableIfValueIsConst =
typename absl::enable_if_t<std::is_const<T>::value, U>;
// SFINAE helper: names `U` only when `U` converts implicitly to one of the two
// span types this class can hold.
template <typename U>
using EnableIfValueIsConvertibleToSpan = typename absl::enable_if_t<
absl::disjunction<std::is_convertible<U, absl::Span<T>>,
std::is_convertible<U, absl::Span<T* const>>>::value,
U>;
public:
// Implicit constructors from the underlying absl::Span.
MaybeDerefSpan(absl::Span<T> span)
: span_(span) {} // NOLINT(runtime/explicit)
MaybeDerefSpan(absl::Span<T* const> span)
: span_(span) {} // NOLINT(runtime/explicit)
// Implicit constructor of a const MaybeDerefSpan from a non-const one.
// Only participates in overload resolution when `T` is const; the stored
// variant of `other` is converted to this class's variant type.
template <typename T2 = T, typename = EnableIfValueIsConst<T2>>
MaybeDerefSpan(
const MaybeDerefSpan<typename std::remove_const<T>::type>& other)
: span_(absl::ConvertVariantTo<decltype(span_)>(other.span_)) {
} // NOLINT(runtime/explicit)
// Implicit constructor of a const MaybeDerefSpan from anything that is
// convertible to one of the underlying spans.
// Enabled only when `T` is const; the argument is wrapped with
// absl::MakeConstSpan, so the referenced storage is not copied.
template <typename V, typename = EnableIfValueIsConst<V>,
typename = EnableIfValueIsConvertibleToSpan<V>>
MaybeDerefSpan(const V& span)
: span_(absl::MakeConstSpan(span)) {} // NOLINT(runtime/explicit)
// Returns a reference to element `index`. When the pointer-span alternative is
// held, the stored pointer is dereferenced; otherwise the element is returned
// directly. No bounds checking is performed in this function itself.
inline constexpr T& operator[](size_t index) const {
if (absl::holds_alternative<absl::Span<T* const>>(span_)) {
return *absl::get<absl::Span<T* const>>(span_)[index];
}
return absl::get<absl::Span<T>>(span_)[index];
}
// Returns the number of elements of whichever span alternative is held.
inline constexpr size_t size() const {
return absl::visit([](auto v) { return v.size(); }, span_);
}
private:
// Other instantiations need access to span_ for the converting constructor
// above.
template <typename U>
friend class MaybeDerefSpan;
// The wrapped view: either direct elements or pointers to elements.
absl::variant<absl::Span<T>, absl::Span<T* const>> span_;
};
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_ANY_SPAN_H_

@ -1,195 +0,0 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/maybe_deref_span.h"
#include <type_traits>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace dpf_internal {
namespace {
using T = int;
// A MaybeDerefSpan built from a mutable absl::Span<T> must alias the vector:
// sizes, values and addresses agree, and writes land in the vector.
TEST(MaybeDerefSpanTest, TestExplicitMutableDirectSpan) {
  std::vector<T> values = {1, 2};
  absl::Span<T> direct_span(values);
  MaybeDerefSpan<T> view(direct_span);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);

  // Writing through the view must modify the underlying vector.
  view[0] = 3;
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(values[0], 3);
}
// Despite the test name, this case wraps a *const* absl::Span over a const
// vector and checks that the resulting read-only view aliases the vector.
TEST(MaybeDerefSpanTest, TestExplicitMutableSpan) {
  const std::vector<T> values = {1, 2};
  absl::Span<const T> const_span(values);
  MaybeDerefSpan<const T> view(const_span);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);
}
// A MaybeDerefSpan built from a span of mutable pointers must dereference
// them transparently, for reads as well as for writes.
TEST(MaybeDerefSpanTest, TestExplicitMutablePointerSpan) {
  std::vector<T> values = {1, 2};
  std::vector<T*> pointers = {&values[0], &values[1]};
  absl::Span<T*> pointer_span(pointers);
  MaybeDerefSpan<T> view(pointer_span);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);

  // Writing through the view must reach the pointed-to element.
  view[0] = 3;
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(values[0], 3);
}
// Same as the pointer-span case, but the pointer container itself is const
// (absl::Span<T* const>); the pointed-to elements stay mutable.
TEST(MaybeDerefSpanTest, TestExplicitMutablePointerConstSpan) {
  std::vector<T> values = {1, 2};
  const std::vector<T*> pointers = {&values[0], &values[1]};
  absl::Span<T* const> pointer_span(pointers);
  MaybeDerefSpan<T> view(pointer_span);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);
}
// Fully const variant: const pointers to const elements wrapped in a
// MaybeDerefSpan<const T>.
TEST(MaybeDerefSpanTest, TestExplicitConstPointerConstSpan) {
  const std::vector<T> values = {1, 2};
  const std::vector<const T*> pointers = {&values[0], &values[1]};
  absl::Span<const T* const> pointer_span(pointers);
  MaybeDerefSpan<const T> view(pointer_span);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);
}
// A mutable MaybeDerefSpan<T> must convert to MaybeDerefSpan<const T> and
// keep referring to the same elements.
TEST(MaybeDerefSpanTest, TestMutableSpanToConstSpan) {
  std::vector<T> values = {1, 2};
  absl::Span<T> direct_span(values);
  MaybeDerefSpan<T> mutable_view(direct_span);
  MaybeDerefSpan<const T> const_view(mutable_view);

  EXPECT_EQ(const_view.size(), values.size());
  EXPECT_EQ(const_view[0], values[0]);
  EXPECT_EQ(const_view[1], values[1]);
  EXPECT_EQ(&const_view[0], &values[0]);
  EXPECT_EQ(&const_view[1], &values[1]);
}
// A const vector of elements must be usable directly as the constructor
// argument of MaybeDerefSpan<const T>, with no explicit absl::Span.
TEST(MaybeDerefSpanTest, TestImplicitConstSpan) {
  const std::vector<T> values = {1, 2};
  MaybeDerefSpan<const T> view(values);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);
}
// A const vector of const pointers must be usable directly as the constructor
// argument of MaybeDerefSpan<const T>.
TEST(MaybeDerefSpanTest, TestImplicitPointerConstSpan) {
  const std::vector<T> values = {1, 2};
  const std::vector<const T*> pointers = {&values[0], &values[1]};
  MaybeDerefSpan<const T> view(pointers);

  EXPECT_EQ(view.size(), values.size());
  EXPECT_EQ(view[0], values[0]);
  EXPECT_EQ(view[1], values[1]);
  EXPECT_EQ(&view[0], &values[0]);
  EXPECT_EQ(&view[1], &values[1]);
}
// Checks that `span` views exactly the first two elements of `vector`: same
// size, same values and same addresses. Taking `span` by value lets callers
// exercise the implicit conversions to MaybeDerefSpan<const T>.
void TestEq(MaybeDerefSpan<const T> span, const std::vector<T>& vector) {
  EXPECT_EQ(span.size(), vector.size());
  // Values first, then addresses, for the two elements every caller provides.
  for (size_t index = 0; index < 2; ++index) {
    EXPECT_EQ(span[index], vector[index]);
  }
  for (size_t index = 0; index < 2; ++index) {
    EXPECT_EQ(&span[index], &vector[index]);
  }
}
// A mutable vector of elements converts implicitly at a call site.
TEST(MaybeDerefSpanTest, TestFunctionCallMutableVector) {
  std::vector<T> values = {1, 2};
  TestEq(values, values);
}
// A mutable vector of mutable pointers converts implicitly at a call site.
TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerVector) {
  std::vector<T> values = {1, 2};
  std::vector<T*> pointers = {&values[0], &values[1]};
  TestEq(pointers, values);
}
// A const vector of elements converts implicitly at a call site.
TEST(MaybeDerefSpanTest, TestFunctionCallConstVector) {
  const std::vector<T> values = {1, 2};
  TestEq(values, values);
}
// A const vector of mutable pointers converts implicitly at a call site.
TEST(MaybeDerefSpanTest, TestFunctionCallMutablePointerConstVector) {
  std::vector<T> values = {1, 2};
  const std::vector<T*> pointers = {&values[0], &values[1]};
  TestEq(pointers, values);
}
// A const vector of const pointers converts implicitly at a call site.
TEST(MaybeDerefSpanTest, TestFunctionCallConstPointerConstVector) {
  const std::vector<T> values = {1, 2};
  const std::vector<const T*> pointers = {&values[0], &values[1]};
  TestEq(pointers, values);
}
// Detection idiom taken from
// https://en.cppreference.com/w/cpp/types/is_convertible.
//
// Preferred overload: well-formed only if a value of type `Source` can be
// passed to a function taking `Target` by value, i.e. if `Source` converts
// implicitly to `Target`. Declared only; meant for use inside decltype.
template <class Source, class Target>
auto test_implicitly_convertible(int)
    -> decltype(void(std::declval<void (&)(Target)>()(std::declval<Source>())),
                std::true_type{});
// Fallback overload, chosen when the conversion above is not possible.
template <class, class>
auto test_implicitly_convertible(...) -> std::false_type;
// Compile-time check: a std::vector<T> must NOT convert implicitly to a
// mutable MaybeDerefSpan<T> ...
static_assert(
    !decltype(test_implicitly_convertible<std::vector<T>,
                                          MaybeDerefSpan<T>>(0))::value);
// ... but it must convert implicitly to a MaybeDerefSpan<const T>.
static_assert(
    decltype(test_implicitly_convertible<std::vector<T>,
                                         MaybeDerefSpan<const T>>(0))::value);
} // namespace
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,336 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/proto_validator.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/value_type_helpers.h"
#include "dpf/status_macros.h"
#include "google/protobuf/repeated_field.h"
namespace distributed_point_functions {
namespace dpf_internal {
namespace {
// Returns the security parameter used when `parameters` does not set one:
// ProtoValidator::kDefaultSecurityParameter plus the level's log_domain_size.
inline double GetDefaultSecurityParameter(const DpfParameters& parameters) {
  const double base = ProtoValidator::kDefaultSecurityParameter;
  return base + parameters.log_domain_size();
}
// Returns true if `a` and `b` differ by at most
// ProtoValidator::kSecurityParameterEpsilon.
inline bool AlmostEqual(double a, double b) {
  const double difference = std::abs(a - b);
  return difference <= ProtoValidator::kSecurityParameterEpsilon;
}
// Decides whether two DpfParameters describe the same hierarchy level.
//
// They match when the log domain sizes are identical, the security parameters
// are equivalent (see below), and ValueTypesAreEqual() reports equal value
// types. The result of ValueTypesAreEqual() is returned as-is.
absl::StatusOr<bool> ParametersAreEqual(const DpfParameters& lhs,
                                        const DpfParameters& rhs) {
  if (lhs.log_domain_size() != rhs.log_domain_size()) {
    return false;
  }

  const auto lhs_security = lhs.security_parameter();
  const auto rhs_security = rhs.security_parameter();

  // Security parameters are equivalent in any of three situations:
  // 1. both hold (almost) the same number;
  const bool both_match = AlmostEqual(lhs_security, rhs_security);
  // 2. lhs is zero and rhs carries its default value;
  const bool lhs_zero_rhs_default =
      lhs_security == 0 &&
      AlmostEqual(rhs_security, GetDefaultSecurityParameter(rhs));
  // 3. rhs is zero and lhs carries its default value.
  const bool rhs_zero_lhs_default =
      rhs_security == 0 &&
      AlmostEqual(lhs_security, GetDefaultSecurityParameter(lhs));

  if (!both_match && !lhs_zero_rhs_default && !rhs_zero_lhs_default) {
    return false;
  }
  return ValueTypesAreEqual(lhs.value_type(), rhs.value_type());
}
// Checks an integer value type: its bitsize must be a power of two within
// [1, 128]. Returns INVALID_ARGUMENT naming the violated rule otherwise.
absl::Status ValidateIntegerType(const ValueType::Integer& type) {
  const int bits = type.bitsize();
  if (bits < 1) {
    return absl::InvalidArgumentError("`bitsize` must be positive");
  }
  if (bits > 128) {
    return absl::InvalidArgumentError(
        "`bitsize` must be less than or equal to 128");
  }
  // A positive power of two has exactly one bit set, so clearing the lowest
  // set bit leaves zero.
  const bool is_power_of_two = (bits & (bits - 1)) == 0;
  if (!is_power_of_two) {
    return absl::InvalidArgumentError("`bitsize` must be a power of 2");
  }
  return absl::OkStatus();
}
// Checks that `value` fits into an integer of `type.bitsize()` bits.
//
// For bitsizes of 128 or more no range check is done and OK is returned
// without inspecting `value`. Otherwise `value` is converted with
// ValueIntegerToUint128() (errors are propagated) and must be strictly below
// 2^bitsize.
absl::Status ValidateIntegerValue(const Value::Integer& value,
                                  const ValueType::Integer& type) {
  if (type.bitsize() >= 128) {
    return absl::OkStatus();
  }
  DPF_ASSIGN_OR_RETURN(absl::uint128 numeric_value,
                       ValueIntegerToUint128(value));
  const absl::uint128 exclusive_upper_bound = absl::uint128{1}
                                              << type.bitsize();
  if (numeric_value >= exclusive_upper_bound) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Value (= %d) too large for ValueType with bitsize = %d",
        numeric_value, type.bitsize()));
  }
  return absl::OkStatus();
}
} // namespace
// Stores the given parameters, tree height and hierarchy<->tree level mappings
// by moving them into the corresponding members. Performs no validation
// itself; Create() validates the parameters before calling this.
ProtoValidator::ProtoValidator(std::vector<DpfParameters> parameters,
int tree_levels_needed,
absl::flat_hash_map<int, int> tree_to_hierarchy,
std::vector<int> hierarchy_to_tree)
: parameters_(std::move(parameters)),
tree_levels_needed_(tree_levels_needed),
tree_to_hierarchy_(std::move(tree_to_hierarchy)),
hierarchy_to_tree_(std::move(hierarchy_to_tree)) {}
// Validates `parameters_in` and builds a ProtoValidator from them.
//
// A copy of the parameters is stored in which every zero security_parameter
// has been replaced by its default. For each hierarchy level the matching
// level of the evaluation tree is computed, together with the total number of
// tree levels needed. Errors from ValidateParameters() and BitsNeeded() are
// returned unchanged.
absl::StatusOr<std::unique_ptr<ProtoValidator>> ProtoValidator::Create(
    absl::Span<const DpfParameters> parameters_in) {
  DPF_RETURN_IF_ERROR(ValidateParameters(parameters_in));

  // Fill in the default security_parameter wherever it was left at zero.
  std::vector<DpfParameters> parameters(parameters_in.begin(),
                                        parameters_in.end());
  for (DpfParameters& level : parameters) {
    if (level.security_parameter() == 0) {
      level.set_security_parameter(GetDefaultSecurityParameter(level));
    }
  }

  // Mappings between hierarchy levels and evaluation-tree levels (both
  // directions), plus the tree height required so far.
  absl::flat_hash_map<int, int> tree_to_hierarchy;
  std::vector<int> hierarchy_to_tree(parameters.size());
  int tree_levels_needed = 0;

  const int num_hierarchy_levels = static_cast<int>(parameters.size());
  for (int hierarchy_level = 0; hierarchy_level < num_hierarchy_levels;
       ++hierarchy_level) {
    const DpfParameters& level = parameters[hierarchy_level];
    DPF_ASSIGN_OR_RETURN(
        int bits_needed,
        BitsNeeded(level.value_type(), level.security_parameter()));
    const int log_bits_needed =
        static_cast<int>(std::ceil(std::log2(bits_needed)));

    // One AES block holds 128 = 2^7 bits. Elements that need fewer bits allow
    // the tree to be shallower than log_domain_size by the difference between
    // 7 and their log size.
    const int level_from_sizes =
        level.log_domain_size() - 7 + std::min(log_bits_needed, 7);
    // Never reuse a tree level already taken by an earlier hierarchy level
    // (this also keeps the level non-negative).
    const int tree_level = std::max(tree_levels_needed, level_from_sizes);

    tree_to_hierarchy[tree_level] = hierarchy_level;
    hierarchy_to_tree[hierarchy_level] = tree_level;
    tree_levels_needed = std::max(tree_levels_needed, tree_level + 1);
  }

  return absl::WrapUnique(new ProtoValidator(
      std::move(parameters), tree_levels_needed, std::move(tree_to_hierarchy),
      std::move(hierarchy_to_tree)));
}
// Validates a list of DpfParameters.
//
// Requirements, checked in this order for every entry: log_domain_size lies
// in [0, 128]; log_domain_size is strictly larger than that of the previous
// entry; a valid value_type is present; security_parameter is not NaN and
// lies in [0, 128]. An empty list is rejected up front.
// Returns OK on success and INVALID_ARGUMENT otherwise.
absl::Status ProtoValidator::ValidateParameters(
    absl::Span<const DpfParameters> parameters) {
  if (parameters.empty()) {
    return absl::InvalidArgumentError("`parameters` must not be empty");
  }

  // Domain size of the previous entry; only consulted from the second entry
  // onwards.
  int previous_log_domain_size = 0;
  bool is_first_entry = true;
  for (const DpfParameters& level : parameters) {
    const int log_domain_size = level.log_domain_size();
    if (log_domain_size < 0) {
      return absl::InvalidArgumentError(
          "`log_domain_size` must be non-negative");
    }
    if (log_domain_size > 128) {
      return absl::InvalidArgumentError("`log_domain_size` must be <= 128");
    }
    if (!is_first_entry && log_domain_size <= previous_log_domain_size) {
      return absl::InvalidArgumentError(
          "`log_domain_size` fields must be in ascending order in "
          "`parameters`");
    }
    is_first_entry = false;
    previous_log_domain_size = log_domain_size;

    if (!level.has_value_type()) {
      return absl::InvalidArgumentError("`value_type` is required");
    }
    DPF_RETURN_IF_ERROR(ValidateValueType(level.value_type()));

    const auto security_parameter = level.security_parameter();
    if (std::isnan(security_parameter)) {
      return absl::InvalidArgumentError("`security_parameter` must not be NaN");
    }
    // AES-128 is used for the PRG, so a security parameter above 128 cannot
    // be achieved.
    if (security_parameter < 0 || security_parameter > 128) {
      return absl::InvalidArgumentError(
          "`security_parameter` must be in [0, 128]");
    }
  }
  return absl::OkStatus();
}
// Checks that `key` is structurally valid for the parameters of this
// validator.
//
// The key must carry a seed and a non-empty last_level_value_correction, must
// contain exactly tree_levels_needed_ - 1 correction words, and every
// hierarchy level that does not map to the last tree level must have a
// non-empty value correction in its correction word.
// Returns OK on success and INVALID_ARGUMENT otherwise.
absl::Status ProtoValidator::ValidateDpfKey(const DpfKey& key) const {
  if (!key.has_seed()) {
    return absl::InvalidArgumentError("key.seed must be present");
  }
  if (key.last_level_value_correction().empty()) {
    return absl::InvalidArgumentError(
        "key.last_level_value_correction must be present");
  }

  // Index of the last tree level, which also equals the number of correction
  // words a well-formed key contains.
  const int last_tree_level = tree_levels_needed_ - 1;
  if (key.correction_words_size() != last_tree_level) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Malformed DpfKey: expected ", last_tree_level,
        " correction words, but got ", key.correction_words_size()));
  }

  const int num_hierarchy_levels = static_cast<int>(hierarchy_to_tree_.size());
  for (int i = 0; i < num_hierarchy_levels; ++i) {
    const int tree_level = hierarchy_to_tree_[i];
    if (tree_level == last_tree_level) {
      // The value correction of the last tree level lives in
      // last_level_value_correction, which was checked above.
      continue;
    }
    ABSL_DCHECK(hierarchy_to_tree_[i] < key.correction_words_size());
    const bool has_value_correction =
        !key.correction_words(tree_level).value_correction().empty();
    if (!has_value_correction) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Malformed DpfKey: expected correction_words[", tree_level,
          "] to contain the value correction of hierarchy level ", i));
    }
  }
  return absl::OkStatus();
}
// Checks that `ctx` is a usable evaluation context for this validator.
//
// The context must list the same parameters as this validator, contain a valid
// key, must not be fully evaluated already, and any partial evaluations must
// not belong to a level beyond previous_hierarchy_level.
// Returns OK on success and INVALID_ARGUMENT otherwise; errors from
// ParametersAreEqual() and ValidateDpfKey() are propagated.
absl::Status ProtoValidator::ValidateEvaluationContext(
    const EvaluationContext& ctx) const {
  const int num_parameters = ctx.parameters_size();
  if (num_parameters != static_cast<int>(parameters_.size())) {
    return absl::InvalidArgumentError(
        "Number of parameters in `ctx` doesn't match");
  }
  for (int level = 0; level < num_parameters; ++level) {
    DPF_ASSIGN_OR_RETURN(
        bool level_matches,
        ParametersAreEqual(parameters_[level], ctx.parameters(level)));
    if (!level_matches) {
      return absl::InvalidArgumentError(
          absl::StrCat("Parameter ", level, " in `ctx` doesn't match"));
    }
  }

  if (!ctx.has_key()) {
    return absl::InvalidArgumentError("ctx.key must be present");
  }
  DPF_RETURN_IF_ERROR(ValidateDpfKey(ctx.key()));

  // Once the last hierarchy level has been evaluated there is nothing left to
  // do with this context.
  const auto previous_level = ctx.previous_hierarchy_level();
  if (previous_level >= num_parameters - 1) {
    return absl::InvalidArgumentError(
        "This context has already been fully evaluated");
  }

  const bool has_partial_evaluations = !ctx.partial_evaluations().empty();
  if (has_partial_evaluations &&
      ctx.partial_evaluations_level() > previous_level) {
    return absl::InvalidArgumentError(
        "ctx.partial_evaluations_level must be less than or equal to "
        "ctx.previous_hierarchy_level");
  }
  return absl::OkStatus();
}
// Checks that `value_type` is well-formed, recursing into tuple elements.
//
// Integers and XOR wrappers need a valid bitsize; an IntModN needs a valid
// base integer and a modulus that fits into it. Any other type case is
// rejected. Returns OK on success and INVALID_ARGUMENT otherwise.
absl::Status ProtoValidator::ValidateValueType(const ValueType& value_type) {
  switch (value_type.type_case()) {
    case ValueType::kInteger: {
      return ValidateIntegerType(value_type.integer());
    }
    case ValueType::kTuple: {
      for (const ValueType& element : value_type.tuple().elements()) {
        DPF_RETURN_IF_ERROR(ValidateValueType(element));
      }
      return absl::OkStatus();
    }
    case ValueType::kIntModN: {
      const ValueType::Integer& base_integer =
          value_type.int_mod_n().base_integer();
      DPF_RETURN_IF_ERROR(ValidateIntegerType(base_integer));
      return ValidateIntegerValue(value_type.int_mod_n().modulus(),
                                  base_integer);
    }
    case ValueType::kXorWrapper: {
      return ValidateIntegerType(value_type.xor_wrapper());
    }
    default:
      break;
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "ValidateValueType: Unsupported ValueType:\n", value_type.DebugString()));
}
// Checks that `value` is a valid instance of `type`, recursing into tuples.
//
// Integers and XOR wrappers must fit into the type's bitsize, tuples must have
// the same arity as the type with every element valid, and an IntModN value
// must fit into the base integer and be strictly smaller than the modulus.
// Returns OK on success and INVALID_ARGUMENT otherwise.
absl::Status ProtoValidator::ValidateValue(const Value& value,
                                           const ValueType& type) {
  switch (type.type_case()) {
    case ValueType::kInteger: {
      if (value.value_case() != Value::kInteger) {
        return absl::InvalidArgumentError("Expected integer value");
      }
      return ValidateIntegerValue(value.integer(), type.integer());
    }
    case ValueType::kTuple: {
      if (value.value_case() != Value::kTuple) {
        return absl::InvalidArgumentError("Expected tuple value");
      }
      const int expected_size = type.tuple().elements_size();
      const int actual_size = value.tuple().elements_size();
      if (actual_size != expected_size) {
        return absl::InvalidArgumentError(
            absl::StrCat("Expected tuple value of size ", expected_size,
                         " but got size ", actual_size));
      }
      for (int element = 0; element < expected_size; ++element) {
        DPF_RETURN_IF_ERROR(ValidateValue(value.tuple().elements(element),
                                          type.tuple().elements(element)));
      }
      return absl::OkStatus();
    }
    case ValueType::kIntModN: {
      // NOTE(review): unlike the other cases there is no value_case() check
      // here; a mismatching `value` is presumably rejected by the integer
      // conversion below -- confirm, and consider an explicit check.
      DPF_RETURN_IF_ERROR(ValidateIntegerValue(
          value.int_mod_n(), type.int_mod_n().base_integer()));
      DPF_ASSIGN_OR_RETURN(absl::uint128 numeric_value,
                           ValueIntegerToUint128(value.int_mod_n()));
      DPF_ASSIGN_OR_RETURN(absl::uint128 modulus,
                           ValueIntegerToUint128(type.int_mod_n().modulus()));
      if (numeric_value >= modulus) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Value (= %d) is too large for modulus (= %d)",
                            numeric_value, modulus));
      }
      return absl::OkStatus();
    }
    case ValueType::kXorWrapper: {
      if (value.value_case() != Value::kXorWrapper) {
        return absl::InvalidArgumentError("Expected XorWrapper value");
      }
      return ValidateIntegerValue(value.xor_wrapper(), type.xor_wrapper());
    }
    default:
      break;
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "ValidateValue: Unsupported ValueType:\n", type.DebugString()));
}
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,120 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_
#include <memory>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "dpf/distributed_point_function.pb.h"
namespace distributed_point_functions {
namespace dpf_internal {
// ProtoValidator is used to validate protos for DPF parameters, keys, and
// evaluation contexts. Also holds information computed from the DPF parameters,
// such as the mappings between hierarchy and tree levels.
class ProtoValidator {
public:
// The negative logarithm of the total variation distance from uniform that a
// *full* evaluation of a hierarchy level is allowed to have. Used as the
// default value for DpfParameters that don't have an explicit per-element
// security parameter set. The default actually applied to a level is this
// constant plus the level's log_domain_size.
static constexpr double kDefaultSecurityParameter = 40;
// Security parameters that differ by at most this are considered equal.
static constexpr double kSecurityParameterEpsilon = 0.0001;
// Checks the validity of `parameters` and returns a ProtoValidator, which
// will be used to validate DPF keys and evaluation contexts afterwards.
// The validator keeps its own copy of `parameters`, in which a
// security_parameter of 0 is replaced by the default value.
//
// Returns INVALID_ARGUMENT if `parameters` are invalid.
static absl::StatusOr<std::unique_ptr<ProtoValidator>> Create(
absl::Span<const DpfParameters> parameters);
// Checks the validity of `parameters`.
// Returns OK on success, and INVALID_ARGUMENT otherwise.
static absl::Status ValidateParameters(
absl::Span<const DpfParameters> parameters);
// Checks that `key` is valid for the `parameters` passed at construction.
// Returns OK on success, and INVALID_ARGUMENT otherwise.
absl::Status ValidateDpfKey(const DpfKey& key) const;
// Checks that `ctx` is valid for the `parameters` passed at construction.
// Returns OK on success, and INVALID_ARGUMENT otherwise.
absl::Status ValidateEvaluationContext(const EvaluationContext& ctx) const;
// Checks that the given ValueType is valid.
// Returns OK on success and INVALID_ARGUMENT otherwise.
static absl::Status ValidateValueType(const ValueType& value_type);
// Checks that `value` is valid for `type`.
// Returns OK on success and INVALID_ARGUMENT otherwise.
static absl::Status ValidateValue(const Value& value, const ValueType& type);
// Checks that `value` is valid for `parameters[i]` passed at construction.
// Returns OK on success and INVALID_ARGUMENT otherwise.
// `i` is used as an index into the stored parameters without a bounds check,
// so callers must pass a valid hierarchy level.
inline absl::Status ValidateValue(const Value& value, int i) const {
return ValidateValue(value, parameters_[i].value_type());
}
// ProtoValidator is not copyable.
ProtoValidator(const ProtoValidator&) = delete;
ProtoValidator& operator=(const ProtoValidator&) = delete;
// ProtoValidator is movable.
ProtoValidator(ProtoValidator&&) = default;
ProtoValidator& operator=(ProtoValidator&&) = default;
// Getters.
// The parameters held by this validator (with defaults filled in).
absl::Span<const DpfParameters> parameters() const { return parameters_; }
// Number of levels in the evaluation tree.
int tree_levels_needed() const { return tree_levels_needed_; }
// Mapping from evaluation-tree level to hierarchy level.
const absl::flat_hash_map<int, int>& tree_to_hierarchy() const {
return tree_to_hierarchy_;
}
// Mapping from hierarchy level to evaluation-tree level.
const std::vector<int>& hierarchy_to_tree() const {
return hierarchy_to_tree_;
}
private:
// Only Create() constructs instances, after validating the parameters.
ProtoValidator(std::vector<DpfParameters> parameters, int tree_levels_needed,
absl::flat_hash_map<int, int> tree_to_hierarchy,
std::vector<int> hierarchy_to_tree);
// The DpfParameters passed at construction.
std::vector<DpfParameters> parameters_;
// Number of levels in the evaluation tree. As computed by Create(), this is
// at most the largest log_domain_size in parameters_ plus one. A valid DpfKey
// holds tree_levels_needed_ - 1 correction words.
int tree_levels_needed_;
// Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
// of parameters_).
absl::flat_hash_map<int, int> tree_to_hierarchy_;
// The inverse of tree_to_hierarchy_.
std::vector<int> hierarchy_to_tree_;
};
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_PROTO_VALIDATOR_H_

@ -1,417 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/proto_validator.h"
#include <stdint.h>
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/proto_validator_test_textproto_embed.h"
#include "dpf/internal/status_matchers.h"
#include "dpf/tuple.h"
#include "gmock/gmock.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace dpf_internal {
namespace {
using ::testing::Ne;
using ::testing::StartsWith;
// Fixture shared by the ProtoValidator tests. SetUp() parses the embedded
// text-format EvaluationContext and derives from it the parameters, the key
// and a ProtoValidator, which individual tests then modify.
class ProtoValidatorTest : public testing::Test {
 protected:
  void SetUp() override {
    // Parse the EvaluationContext embedded into the test binary.
    const auto* const toc = proto_validator_test_textproto_embed_create();
    const std::string textproto(toc->data, toc->size);
    ASSERT_TRUE(
        google::protobuf::TextFormat::ParseFromString(textproto, &ctx_));

    // Keep separate copies of the pieces so tests can alter them
    // independently of ctx_.
    parameters_.assign(ctx_.parameters().begin(), ctx_.parameters().end());
    dpf_key_ = ctx_.key();
    DPF_ASSERT_OK_AND_ASSIGN(proto_validator_,
                             ProtoValidator::Create(parameters_));
  }

  // Parameters copied from ctx_.
  std::vector<DpfParameters> parameters_;
  // Key copied from ctx_.
  DpfKey dpf_key_;
  // The parsed evaluation context.
  EvaluationContext ctx_;
  // Validator created from parameters_.
  std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
};
// Create() rejects an empty parameter list.
TEST_F(ProtoValidatorTest, CreateFailsWithoutParameters) {
  const auto result = ProtoValidator::Create({});
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`parameters` must not be empty"));
}
// Create() rejects parameters whose log_domain_size values decrease.
TEST_F(ProtoValidatorTest, CreateFailsWhenParametersNotSorted) {
  parameters_.resize(2);
  parameters_.front().set_log_domain_size(10);
  parameters_.back().set_log_domain_size(8);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`log_domain_size` fields must be in ascending order in "
                       "`parameters`"));
}
// Create() rejects a negative log_domain_size.
TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeNegative) {
  parameters_.resize(1);
  DpfParameters& only_parameters = parameters_[0];
  only_parameters.set_log_domain_size(-1);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`log_domain_size` must be non-negative"));
}
// Create() rejects a log_domain_size above 128.
TEST_F(ProtoValidatorTest, CreateFailsWhenDomainSizeTooLarge) {
  parameters_.resize(1);
  DpfParameters& only_parameters = parameters_[0];
  only_parameters.set_log_domain_size(129);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`log_domain_size` must be <= 128"));
}
// Create() rejects an integer value type with a negative bitsize.
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNegative) {
  parameters_.resize(1);
  auto* integer_type = parameters_[0].mutable_value_type()->mutable_integer();
  integer_type->set_bitsize(-1);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`bitsize` must be positive"));
}
// Create() rejects an integer value type with a bitsize of zero.
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeZero) {
  parameters_.resize(1);
  auto* integer_type = parameters_[0].mutable_value_type()->mutable_integer();
  integer_type->set_bitsize(0);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`bitsize` must be positive"));
}
// Create() rejects an integer value type with a bitsize above 128.
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeTooLarge) {
  parameters_.resize(1);
  auto* integer_type = parameters_[0].mutable_value_type()->mutable_integer();
  integer_type->set_bitsize(256);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`bitsize` must be less than or equal to 128"));
}
// Create() rejects an integer value type whose bitsize is not a power of two.
TEST_F(ProtoValidatorTest, CreateFailsWhenElementBitsizeNotAPowerOfTwo) {
  parameters_.resize(1);
  auto* integer_type = parameters_[0].mutable_value_type()->mutable_integer();
  integer_type->set_bitsize(23);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`bitsize` must be a power of 2"));
}
// Create() rejects a NaN security parameter.
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNaN) {
  parameters_.resize(1);
  DpfParameters& only_parameters = parameters_[0];
  only_parameters.set_security_parameter(std::nan(""));

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`security_parameter` must not be NaN"));
}
// Create() rejects a security parameter just below 0.
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsNegative) {
  parameters_.resize(1);
  DpfParameters& only_parameters = parameters_[0];
  only_parameters.set_security_parameter(-0.01);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`security_parameter` must be in [0, 128]"));
}
// Create() rejects a security parameter just above 128.
TEST_F(ProtoValidatorTest, CreateFailsIfSecurityParameterIsTooLarge) {
  parameters_.resize(1);
  DpfParameters& only_parameters = parameters_[0];
  only_parameters.set_security_parameter(128.01);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`security_parameter` must be in [0, 128]"));
}
// Create() accepts hierarchy levels whose integer bitsizes decrease.
TEST_F(ProtoValidatorTest, CreateWorksWhenElementBitsizesDecrease) {
  parameters_.resize(2);
  auto* first_type = parameters_[0].mutable_value_type()->mutable_integer();
  auto* second_type = parameters_[1].mutable_value_type()->mutable_integer();
  first_type->set_bitsize(64);
  second_type->set_bitsize(32);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, IsOkAndHolds(Ne(nullptr)));
}
// Create() accepts hierarchy levels with a large gap between domain sizes.
TEST_F(ProtoValidatorTest, CreateWorksWhenHierarchiesAreFarApart) {
  parameters_.resize(2);
  parameters_.front().set_log_domain_size(10);
  parameters_.back().set_log_domain_size(128);

  const auto result = ProtoValidator::Create(parameters_);
  EXPECT_THAT(result, IsOkAndHolds(Ne(nullptr)));
}
// ValidateDpfKey() rejects a key carrying one correction word too many.
TEST_F(ProtoValidatorTest,
       ValidateDpfKeyFailsIfNumberOfCorrectionWordsDoesntMatch) {
  // The fixture key has the expected number of correction words; adding one
  // more makes the count mismatch.
  const int expected_count = dpf_key_.correction_words_size();
  dpf_key_.add_correction_words();

  const std::string expected_message =
      absl::StrCat("Malformed DpfKey: expected ", expected_count,
                   " correction words, but got ", expected_count + 1);
  const auto status = proto_validator_->ValidateDpfKey(dpf_key_);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               expected_message));
}
// ValidateDpfKey() rejects a key without a seed.
TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfSeedIsMissing) {
  dpf_key_.clear_seed();

  const auto status = proto_validator_->ValidateDpfKey(dpf_key_);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "key.seed must be present"));
}
// ValidateDpfKey() rejects a key without last_level_value_correction.
TEST_F(ProtoValidatorTest,
       ValidateDpfKeyFailsIfLastLevelOutputCorrectionIsMissing) {
  dpf_key_.clear_last_level_value_correction();

  const auto status = proto_validator_->ValidateDpfKey(dpf_key_);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "key.last_level_value_correction must be present"));
}
// ValidateDpfKey() rejects a key whose correction words all lack a value
// correction.
TEST_F(ProtoValidatorTest, ValidateDpfKeyFailsIfOutputCorrectionIsMissing) {
  const int num_correction_words = dpf_key_.correction_words_size();
  for (int i = 0; i < num_correction_words; ++i) {
    dpf_key_.mutable_correction_words(i)->clear_value_correction();
  }

  const auto status = proto_validator_->ValidateDpfKey(dpf_key_);
  EXPECT_THAT(
      status,
      StatusIs(absl::StatusCode::kInvalidArgument,
               StartsWith("Malformed DpfKey: expected correction_words")));
}
// ValidateEvaluationContext() rejects a context without a key.
TEST_F(ProtoValidatorTest, ValidateEvaluationContextFailsIfKeyIsMissing) {
  ctx_.clear_key();

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "ctx.key must be present"));
}
// ValidateEvaluationContext() rejects a context with fewer parameters than the
// validator was created with.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextFailsIfParameterSizeDoesntMatch) {
  // Drop the last hierarchy level from the context only.
  const auto last_parameter = ctx_.parameters().end() - 1;
  ctx_.mutable_parameters()->erase(last_parameter);

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Number of parameters in `ctx` doesn't match"));
}
// ValidateEvaluationContext() rejects a context whose first log_domain_size
// differs from the validator's.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextFailsIfLogDomainSizeDoesntMatch) {
  auto* first_parameters = ctx_.mutable_parameters(0);
  first_parameters->set_log_domain_size(first_parameters->log_domain_size() +
                                        1);

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Parameter 0 in `ctx` doesn't match"));
}
// ValidateEvaluationContext() accepts a security parameter of 0 when both the
// validator and the context use it.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextSucceedsIfSecurityParameterIsDefault) {
  // Rebuild the validator with security_parameter 0 on the first level.
  parameters_.front().set_security_parameter(0);
  DPF_ASSERT_OK_AND_ASSIGN(proto_validator_,
                           ProtoValidator::Create(parameters_));

  ctx_.mutable_parameters(0)->set_security_parameter(0);
  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status, IsOk());
}
// ValidateEvaluationContext() rejects a context whose first security parameter
// differs from the validator's.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextFailsIfSecurityParameterDoesntMatch) {
  auto* first_parameters = ctx_.mutable_parameters(0);
  first_parameters->set_security_parameter(
      first_parameters->security_parameter() + 1);

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Parameter 0 in `ctx` doesn't match"));
}
// ValidateEvaluationContext() rejects a context already evaluated at the last
// hierarchy level.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextFailsIfContextFullyEvaluated) {
  const auto last_level = parameters_.size() - 1;
  ctx_.set_previous_hierarchy_level(last_level);

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "This context has already been fully evaluated"));
}
// ValidateEvaluationContext() rejects a partial_evaluations_level greater than
// previous_hierarchy_level.
TEST_F(ProtoValidatorTest,
       ValidateEvaluationContextFailsIfPartialEvaluationsLevelTooLarge) {
  ctx_.set_previous_hierarchy_level(0);
  ctx_.set_partial_evaluations_level(1);
  ctx_.add_partial_evaluations();

  const auto status = proto_validator_->ValidateEvaluationContext(ctx_);
  EXPECT_THAT(
      status,
      StatusIs(absl::StatusCode::kInvalidArgument,
               "ctx.partial_evaluations_level must be less than or equal to "
               "ctx.previous_hierarchy_level"));
}
// ValidateValue() rejects a tuple value when the type is an integer.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotInteger) {
  ValueType integer_type;
  integer_type.mutable_integer()->set_bitsize(32);

  Value tuple_value;
  auto* element = tuple_value.mutable_tuple()->add_elements();
  element->mutable_integer()->set_value_uint64(23);

  const auto status = proto_validator_->ValidateValue(tuple_value, integer_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Expected integer value"));
}
// ValidateValue() rejects an integer that does not fit in the type's bitsize.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfIntegerTooLarge) {
  const int bitsize = 32;
  // Smallest value that needs more than `bitsize` bits.
  const uint64_t too_large = uint64_t{1} << bitsize;

  ValueType integer_type;
  integer_type.mutable_integer()->set_bitsize(bitsize);
  Value integer_value;
  integer_value.mutable_integer()->set_value_uint64(too_large);

  const std::string expected_message = absl::StrFormat(
      "Value (= %d) too large for ValueType with bitsize = %d", too_large,
      bitsize);
  const auto status =
      proto_validator_->ValidateValue(integer_value, integer_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               expected_message));
}
// ValidateValue() rejects an integer value when the type is a tuple.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotTuple) {
  ValueType tuple_type;
  auto* element_type = tuple_type.mutable_tuple()->add_elements();
  element_type->mutable_integer()->set_bitsize(32);

  Value integer_value;
  integer_value.mutable_integer()->set_value_uint64(23);

  const auto status = proto_validator_->ValidateValue(integer_value, tuple_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Expected tuple value"));
}
// ValidateValue() rejects a two-element tuple value for a one-element tuple
// type.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTupleSizeDoesntMatch) {
  ValueType tuple_type;
  auto* element_type = tuple_type.mutable_tuple()->add_elements();
  element_type->mutable_integer()->set_bitsize(32);

  Value tuple_value;
  auto* tuple = tuple_value.mutable_tuple();
  tuple->add_elements()->mutable_integer()->set_value_uint64(23);
  tuple->add_elements()->mutable_integer()->set_value_uint64(42);

  const auto status = proto_validator_->ValidateValue(tuple_value, tuple_type);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Expected tuple value of size 1 but got size 2"));
}
// ValidateValue() rejects an IntModN value equal to the modulus.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueLargerThanModulus) {
  constexpr uint64_t kModulus = 3;

  ValueType mod_type;
  auto* int_mod_n_type = mod_type.mutable_int_mod_n();
  int_mod_n_type->mutable_base_integer()->set_bitsize(64);
  int_mod_n_type->mutable_modulus()->set_value_uint64(kModulus);

  Value mod_value;
  mod_value.mutable_int_mod_n()->set_value_uint64(kModulus);

  const auto status = proto_validator_->ValidateValue(mod_value, mod_type);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Value (= 3) is too large for modulus (= 3)"));
}
// ValidateValue() rejects an integer value when the type is an XorWrapper.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfTypeNotXorWrapper) {
  ValueType xor_type;
  xor_type.mutable_xor_wrapper()->set_bitsize(32);

  Value integer_value;
  integer_value.mutable_integer()->set_value_uint64(23);

  const auto status = proto_validator_->ValidateValue(integer_value, xor_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "Expected XorWrapper value"));
}
// ValidateValue() rejects a default-constructed type with no case set.
TEST_F(ProtoValidatorTest, ValidateValueFailsIfValueIsUnknown) {
  const ValueType unset_type;
  const Value unset_value;

  const auto status = proto_validator_->ValidateValue(unset_value, unset_type);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       StartsWith("ValidateValue: Unsupported ValueType:")));
}
// ValidateValueType() rejects an integer type with a bitsize of zero.
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPositive) {
  ValueType integer_type;
  integer_type.mutable_integer()->set_bitsize(0);

  const auto status = ProtoValidator::ValidateValueType(integer_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`bitsize` must be positive"));
}
// ValidateValueType() rejects an integer type with a bitsize above 128.
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeTooLarge) {
  ValueType integer_type;
  integer_type.mutable_integer()->set_bitsize(256);

  const auto status = ProtoValidator::ValidateValueType(integer_type);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`bitsize` must be less than or equal to 128"));
}
// ValidateValueType() rejects an integer type whose bitsize is not a power of
// two.
TEST(ProtoValidator, ValidateValueTypeFailsIfBitsizeNotPowerOfTwo) {
  ValueType integer_type;
  integer_type.mutable_integer()->set_bitsize(17);

  const auto status = ProtoValidator::ValidateValueType(integer_type);
  EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                               "`bitsize` must be a power of 2"));
}
// ValidateValueType() rejects a default-constructed type with no case set.
TEST(ProtoValidator, ValidateValueTypeFailsIfNoTypeChosen) {
  const ValueType unset_type;

  const auto status = ProtoValidator::ValidateValueType(unset_type);
  EXPECT_THAT(status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       StartsWith("ValidateValueType: Unsupported ValueType")));
}
} // namespace
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,108 +0,0 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# proto-file: dpf/distributed_point_function.proto
# proto-message: EvaluationContext
parameters {
log_domain_size: 4
value_type {
integer {
bitsize: 32
}
}
security_parameter: 44
}
parameters {
log_domain_size: 6
value_type {
integer {
bitsize: 32
}
}
security_parameter: 46
}
parameters {
log_domain_size: 8
value_type {
integer {
bitsize: 32
}
}
security_parameter: 48
}
key {
seed {
high: 11559904407150645412
low: 10793182457266619527
}
correction_words {
seed {
high: 17231204231811741091
low: 13184625655696690000
}
control_left: true
}
correction_words {
seed {
high: 3072212389250066354
low: 1361245143349174348
}
}
correction_words {
seed {
high: 2882988684359810666
low: 16992210518729579018
}
control_right: true
value_correction: {
integer: {
value_uint64: 536412310
}
}
}
correction_words {
seed {
high: 4993590839844520517
low: 13033365507284852634
}
control_right: true
}
correction_words {
seed {
high: 10673753674550143002
low: 3019916643383017704
}
control_left: true
control_right: true
value_correction: {
integer: {
value_uint64: 841224518
}
}
}
correction_words {
seed {
high: 2423099213299230757
low: 12788496417753523946
}
control_right: true
}
last_level_value_correction: {
integer: {
value_uint64: 8471844854
}
}
}
previous_hierarchy_level: -1

@ -1,63 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/status_matchers.h"
#include <ostream>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Writes a description of what a matching status looks like: the code matcher's
// description followed by the message matcher's description.
void StatusIsMatcherCommonImpl::DescribeTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << "has a status code that ";
  code_matcher_.DescribeTo(os);
  out << ", and has an error message that ";
  message_matcher_.DescribeTo(os);
}
// Writes a description of what a non-matching status looks like: either the
// code matcher's negation or the message matcher's negation applies.
void StatusIsMatcherCommonImpl::DescribeNegationTo(std::ostream* os) const {
  std::ostream& out = *os;
  out << "has a status code that ";
  code_matcher_.DescribeNegationTo(os);
  out << ", or has an error message that ";
  message_matcher_.DescribeNegationTo(os);
}
// Checks `status` against the code matcher first and the message matcher
// second. On a mismatch, writes an explanation to `result_listener` and
// returns false; returns true only when both matchers accept.
bool StatusIsMatcherCommonImpl::MatchAndExplain(
    const ::absl::Status& status,
    ::testing::MatchResultListener* result_listener) const {
  ::testing::StringMatchResultListener code_listener;
  const bool code_matches =
      code_matcher_.MatchAndExplain(status.code(), &code_listener);
  if (!code_matches) {
    // Prefer the code matcher's own explanation when it produced one.
    const std::string code_explanation = code_listener.str();
    if (code_explanation.empty()) {
      *result_listener << "whose status code is wrong";
    } else {
      *result_listener << "which has a status code " + code_explanation;
    }
    return false;
  }

  // The message matcher works on std::string, so copy the message view.
  const std::string message(status.message());
  if (message_matcher_.Matches(message)) {
    return true;
  }
  *result_listener << "whose error message is wrong: " << status.message();
  return false;
}
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,390 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Testing utilities for working with absl::Status and absl::StatusOr.
//
// Defines the following utilities:
//
// =================
// DPF_EXPECT_OK(s)
//
// DPF_ASSERT_OK(s)
// =================
// Convenience macros for `EXPECT_THAT(s, IsOk())`, where `s` is either
// a `Status` or a `StatusOr<T>`.
//
// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not
// provide much value (when they fail, they would just print the OK status,
// which conveys no more information than `EXPECT_FALSE(s.ok())`). You can
// of course use `EXPECT_THAT(s, Not(IsOk()))` if you prefer _THAT style.
//
// If you want to check for particular errors, better alternatives are:
// EXPECT_THAT(s, StatusIs(expected_error));
// EXPECT_THAT(s, StatusIs(_, HasSubstr("expected error")));
//
// ===============
// IsOkAndHolds(m)
// ===============
//
// This gMock matcher matches a StatusOr<T> value whose status is OK
// and whose inner value matches matcher m. Example:
//
// using ::testing::MatchesRegex;
// using distributed_point_functions::dpf_internal::IsOkAndHolds;
// ...
// absl::StatusOr<string> maybe_name = ...;
// EXPECT_THAT(maybe_name, IsOkAndHolds(MatchesRegex("John .*")));
//
// ===============================
// StatusIs(status_code_matcher,
// error_message_matcher)
// ===============================
//
// This gMock matcher matches a Status or StatusOr<T> value if all of the
// following are true:
//
// - the status's code() matches status_code_matcher, and
// - the status's message() matches error_message_matcher.
//
// Example:
//
// enum FooErrorCode {
// ...
// kServerError
// };
//
// using ::testing::HasSubstr;
// using ::testing::MatchesRegex;
// using ::testing::Ne;
// using ::testing::_;
// using distributed_point_functions::dpf_internal::StatusIs;
// absl::StatusOr<string> GetName(int id);
// ...
//
// // The status code must be kServerError; the error message can be
// // anything.
// EXPECT_THAT(GetName(42),
// StatusIs(kServerError, _));
// // The status code can be anything; the error message must match the
// // regex.
// EXPECT_THAT(GetName(43),
// StatusIs(_, MatchesRegex("server.*time-out")));
//
// // The status code should not be kServerError; the error message can be
// // anything with "client" in it.
// EXPECT_CALL(mock_env, HandleStatus(
// StatusIs(Ne(kServerError), HasSubstr("client"))));
//
// ===============================
// StatusIs(status_code_matcher)
// ===============================
//
// This is a shorthand for
// StatusIs(status_code_matcher,
// testing::_)
// In other words, it's like the two-argument StatusIs(), except that it
// ignores error message.
//
// ===============
// IsOk()
// ===============
//
// Matches an absl::Status or absl::StatusOr<T> value whose status value is
// absl::StatusCode::kOk. Equivalent to 'StatusIs(absl::StatusCode::kOk)'.
// Example:
// using distributed_point_functions::dpf_internal::IsOk;
// ...
// absl::StatusOr<string> maybe_name = ...;
// EXPECT_THAT(maybe_name, IsOk());
// Status s = ...;
// EXPECT_THAT(s, IsOk());
//
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "dpf/status_macros.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Returns the status itself; lets generic code treat Status and StatusOr<T>
// uniformly.
inline const absl::Status& GetStatus(const absl::Status& status) {
  const absl::Status& result = status;
  return result;
}

// Returns the status carried by a StatusOr<T>.
template <typename T>
inline const absl::Status& GetStatus(const absl::StatusOr<T>& status) {
  const absl::Status& result = status.status();
  return result;
}
////////////////////////////////////////////////////////////
// Implementation of IsOkAndHolds().
// Monomorphic matcher behind IsOkAndHolds(m). StatusOrType is a reference to
// StatusOr<T>; the held value is checked against the inner matcher.
template <typename StatusOrType>
class IsOkAndHoldsMatcherImpl
    : public ::testing::MatcherInterface<StatusOrType> {
 public:
  // Type of the value held by the StatusOr being matched.
  using value_type =
      typename std::remove_reference<StatusOrType>::type::value_type;

  // Converts `inner_matcher` to a matcher on `const value_type&`.
  template <typename InnerMatcher>
  explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher)
      : inner_matcher_(::testing::SafeMatcherCast<const value_type&>(
            std::forward<InnerMatcher>(inner_matcher))) {}

  void DescribeTo(std::ostream* os) const override {
    std::ostream& out = *os;
    out << "is OK and has a value that ";
    inner_matcher_.DescribeTo(os);
  }

  void DescribeNegationTo(std::ostream* os) const override {
    std::ostream& out = *os;
    out << "isn't OK or has a value that ";
    inner_matcher_.DescribeNegationTo(os);
  }

  // Fails with the status when `actual_value` is not OK; otherwise defers to
  // the inner matcher and forwards its explanation, if any.
  bool MatchAndExplain(
      StatusOrType actual_value,
      ::testing::MatchResultListener* result_listener) const override {
    if (!actual_value.ok()) {
      *result_listener << "which has status " << actual_value.status();
      return false;
    }

    ::testing::StringMatchResultListener value_listener;
    const bool value_matches =
        inner_matcher_.MatchAndExplain(*actual_value, &value_listener);
    const std::string value_explanation = value_listener.str();
    if (value_explanation.empty()) {
      return value_matches;
    }
    *result_listener << "which contains value "
                     << ::testing::PrintToString(*actual_value) << ", "
                     << value_explanation;
    return value_matches;
  }

 private:
  const ::testing::Matcher<const value_type&> inner_matcher_;
};
// Polymorphic wrapper for IsOkAndHolds(m): stores the inner matcher and
// produces a monomorphic matcher on demand.
template <typename InnerMatcher>
class IsOkAndHoldsMatcher {
 public:
  explicit IsOkAndHoldsMatcher(InnerMatcher inner_matcher)
      : inner_matcher_(std::move(inner_matcher)) {}

  // Converts this polymorphic matcher to a monomorphic matcher of the given
  // type. StatusOrType can be either StatusOr<T> or a reference to
  // StatusOr<T>.
  template <typename StatusOrType>
  operator ::testing::Matcher<StatusOrType>() const {  // NOLINT
    using Impl = IsOkAndHoldsMatcherImpl<const StatusOrType&>;
    return ::testing::Matcher<StatusOrType>(new Impl(inner_matcher_));
  }

 private:
  const InnerMatcher inner_matcher_;
};
////////////////////////////////////////////////////////////
// Implementation of StatusIs().
// StatusIs() is a polymorphic matcher. This class holds the code and message
// matchers and is the implementation shared by every type T for which
// StatusIs() can be used as a Matcher<T>.
class StatusIsMatcherCommonImpl {
 public:
  StatusIsMatcherCommonImpl(::testing::Matcher<absl::StatusCode> code,
                            ::testing::Matcher<const std::string&> message)
      : code_matcher_(std::move(code)), message_matcher_(std::move(message)) {}

  // Describes a matching status on `os`.
  void DescribeTo(std::ostream* os) const;

  // Describes a non-matching status on `os`.
  void DescribeNegationTo(std::ostream* os) const;

  // Returns whether `status` matches; explains a mismatch on
  // `result_listener`.
  bool MatchAndExplain(const absl::Status& status,
                       ::testing::MatchResultListener* result_listener) const;

 private:
  const ::testing::Matcher<absl::StatusCode> code_matcher_;
  const ::testing::Matcher<const std::string&> message_matcher_;
};
// Monomorphic implementation of matcher StatusIs() for a given type T. T can
// be Status, StatusOr<>, or a reference to either of them. All work is
// delegated to the shared StatusIsMatcherCommonImpl.
template <typename T>
class MonoStatusIsMatcherImpl : public ::testing::MatcherInterface<T> {
 public:
  explicit MonoStatusIsMatcherImpl(StatusIsMatcherCommonImpl common_impl)
      : common_impl_(std::move(common_impl)) {}

  void DescribeTo(std::ostream* os) const override {
    common_impl_.DescribeTo(os);
  }

  void DescribeNegationTo(std::ostream* os) const override {
    common_impl_.DescribeNegationTo(os);
  }

  // Extracts the status from `actual_value` and matches it.
  bool MatchAndExplain(
      T actual_value,
      ::testing::MatchResultListener* result_listener) const override {
    const absl::Status& status = GetStatus(actual_value);
    return common_impl_.MatchAndExplain(status, result_listener);
  }

 private:
  StatusIsMatcherCommonImpl common_impl_;
};
// Polymorphic wrapper for StatusIs(): casts the supplied code and message
// matchers to their canonical types and produces a monomorphic matcher on
// demand.
class StatusIsMatcher {
 public:
  template <typename StatusCodeMatcher, typename StatusMessageMatcher>
  StatusIsMatcher(StatusCodeMatcher&& code_matcher,
                  StatusMessageMatcher&& message_matcher)
      : common_impl_(
            ::testing::MatcherCast<absl::StatusCode>(
                std::forward<StatusCodeMatcher>(code_matcher)),
            ::testing::MatcherCast<const std::string&>(
                std::forward<StatusMessageMatcher>(message_matcher))) {}

  // Converts this polymorphic matcher to a monomorphic matcher of the given
  // type. T can be StatusOr<>, Status, or a reference to either of them.
  template <typename T>
  operator ::testing::Matcher<T>() const {  // NOLINT
    using Impl = MonoStatusIsMatcherImpl<const T&>;
    return ::testing::Matcher<T>(new Impl(common_impl_));
  }

 private:
  const StatusIsMatcherCommonImpl common_impl_;
};
// Monomorphic implementation of matcher IsOk() for a given type T.
// T can be Status, StatusOr<>, or a reference to either of them.
template <typename T>
class MonoIsOkMatcherImpl : public ::testing::MatcherInterface<T> {
 public:
  void DescribeTo(std::ostream* os) const override { *os << "is OK"; }

  void DescribeNegationTo(std::ostream* os) const override {
    *os << "is not OK";
  }

  // Matches when `actual_value` is OK. On failure the whole status is
  // streamed, so the explanation carries the status code as well as the
  // message (a status with an empty message would otherwise explain nothing).
  // This mirrors how IsOkAndHoldsMatcherImpl reports a non-OK status.
  bool MatchAndExplain(
      T actual_value,
      ::testing::MatchResultListener* result_listener) const override {
    if (actual_value.ok()) {
      return true;
    }
    *result_listener << "whose status is " << GetStatus(actual_value);
    return false;
  }
};
// Polymorphic wrapper for IsOk(): stateless, converts to a monomorphic
// MonoIsOkMatcherImpl for whatever type T is being matched.
class IsOkMatcher {
 public:
  template <typename T>
  operator ::testing::Matcher<T>() const {  // NOLINT
    using Impl = MonoIsOkMatcherImpl<const T&>;
    return ::testing::Matcher<T>(new Impl());
  }
};
// Macros for testing the results of functions that return absl::Status or
// absl::StatusOr<T> (for any type T).
//
// DPF_EXPECT_OK records a non-fatal failure (EXPECT_THAT) when `expression`
// is not OK; DPF_ASSERT_OK records a fatal failure (ASSERT_THAT). The matcher
// is spelled with its full namespace so the macros work from any scope.
#define DPF_EXPECT_OK(expression) \
EXPECT_THAT(expression, distributed_point_functions::dpf_internal::IsOk())
#define DPF_ASSERT_OK(expression) \
ASSERT_THAT(expression, distributed_point_functions::dpf_internal::IsOk())
// Executes an expression that returns an absl::StatusOr, and assigns the
// contained variable to lhs if the error code is OK.
// If the Status is non-OK, generates a test failure and returns from the
// current function, which must have a void return type.
//
// Example: Declaring and initializing a new value
// DPF_ASSERT_OK_AND_ASSIGN(const ValueType& value, MaybeGetValue(arg));
//
// Example: Assigning to an existing value
// ValueType value;
// DPF_ASSERT_OK_AND_ASSIGN(value, MaybeGetValue(arg));
//
// The value assignment example would expand into something like:
// auto status_or_value = MaybeGetValue(arg);
// DPF_ASSERT_OK(status_or_value);
// value = std::move(status_or_value).value();
//
// WARNING: Like ASSIGN_OR_RETURN, DPF_ASSERT_OK_AND_ASSIGN expands into
// multiple statements; it cannot be used in a single statement (e.g. as the
// body of an if statement without {})!
//
// The temporary's name is built from `_status_or_value` and __LINE__ via
// DPF_STATUS_MACROS_IMPL_CONCAT_, so two uses on different lines of the same
// scope get distinct names.
#define DPF_ASSERT_OK_AND_ASSIGN(lhs, rexpr) \
DPF_ASSERT_OK_AND_ASSIGN_IMPL_( \
DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr)
// Implementation detail of DPF_ASSERT_OK_AND_ASSIGN: evaluates `rexpr` once
// into `statusor`, asserts it is OK, then moves the value into `lhs`.
#define DPF_ASSERT_OK_AND_ASSIGN_IMPL_(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
DPF_ASSERT_OK(statusor); \
lhs = std::move(statusor).value();
// Returns a gMock matcher that matches a StatusOr<> whose status is OK and
// whose value matches `inner_matcher`. The inner matcher is stored by its
// decayed type.
template <typename InnerMatcher>
dpf_internal::IsOkAndHoldsMatcher<typename std::decay<InnerMatcher>::type>
IsOkAndHolds(InnerMatcher&& inner_matcher) {
  using StoredMatcher = typename std::decay<InnerMatcher>::type;
  return dpf_internal::IsOkAndHoldsMatcher<StoredMatcher>(
      std::forward<InnerMatcher>(inner_matcher));
}
// Returns a gMock matcher that matches a Status or StatusOr<> whose status
// code matches `code_matcher` and whose error message matches
// `message_matcher`.
template <typename StatusCodeMatcher, typename StatusMessageMatcher>
dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher,
                                       StatusMessageMatcher&& message_matcher) {
  dpf_internal::StatusIsMatcher matcher(
      std::forward<StatusCodeMatcher>(code_matcher),
      std::forward<StatusMessageMatcher>(message_matcher));
  return matcher;
}
// Returns a gMock matcher that matches a Status or StatusOr<> whose status
// code matches `code_matcher`; the error message is ignored.
template <typename StatusCodeMatcher>
dpf_internal::StatusIsMatcher StatusIs(StatusCodeMatcher&& code_matcher) {
  // `::testing::_` accepts any message.
  const auto& any_message = ::testing::_;
  return StatusIs(std::forward<StatusCodeMatcher>(code_matcher), any_message);
}
// Returns a gMock matcher that matches a Status or StatusOr<> which is OK.
inline dpf_internal::IsOkMatcher IsOk() {
  dpf_internal::IsOkMatcher matcher;
  return matcher;
}
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MATCHERS_H_

@ -1,169 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/value_type_helpers.h"
#include <stdint.h>
#include <cmath>
#include <string>
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/int_mod_n.h"
#include "dpf/status_macros.h"
namespace distributed_point_functions {
namespace dpf_internal {
// Compares two ValueTypes structurally.
//
// Returns InvalidArgument when either argument has no type set. Otherwise
// returns true iff both have the same type case and matching contents:
// equal bitsize for integers and XorWrappers, element-wise equality for
// tuples of the same size, and equal base bitsize and modulus for IntModN.
// Errors from comparing tuple elements or converting moduli are propagated.
absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs,
                                        const ValueType& rhs) {
  const auto lhs_case = lhs.type_case();
  const auto rhs_case = rhs.type_case();

  if (lhs_case == ValueType::TypeCase::TYPE_NOT_SET ||
      rhs_case == ValueType::TypeCase::TYPE_NOT_SET) {
    return absl::InvalidArgumentError(
        "Both arguments must be valid ValueTypes");
  }
  if (lhs_case != rhs_case) {
    return false;
  }

  if (lhs_case == ValueType::kInteger) {
    return lhs.integer().bitsize() == rhs.integer().bitsize();
  }

  if (lhs_case == ValueType::kTuple) {
    const int num_elements = static_cast<int>(lhs.tuple().elements_size());
    if (lhs.tuple().elements_size() != rhs.tuple().elements_size()) {
      return false;
    }
    // Every element pair is compared, even after a mismatch, so an invalid
    // element anywhere in the tuple still surfaces as an error.
    bool all_equal = true;
    for (int index = 0; index < num_elements; ++index) {
      DPF_ASSIGN_OR_RETURN(bool elements_equal,
                           ValueTypesAreEqual(lhs.tuple().elements(index),
                                              rhs.tuple().elements(index)));
      all_equal &= elements_equal;
    }
    return all_equal;
  }

  if (lhs_case == ValueType::kIntModN) {
    // Convert both moduli first so conversion errors take precedence over the
    // comparison result.
    DPF_ASSIGN_OR_RETURN(absl::uint128 lhs_modulus,
                         ValueIntegerToUint128(lhs.int_mod_n().modulus()));
    DPF_ASSIGN_OR_RETURN(absl::uint128 rhs_modulus,
                         ValueIntegerToUint128(rhs.int_mod_n().modulus()));
    const bool same_base = lhs.int_mod_n().base_integer().bitsize() ==
                           rhs.int_mod_n().base_integer().bitsize();
    return same_base && lhs_modulus == rhs_modulus;
  }

  if (lhs_case == ValueType::kXorWrapper) {
    return lhs.xor_wrapper().bitsize() == rhs.xor_wrapper().bitsize();
  }

  return false;
}
// Computes the number of bits BitsNeeded() reports for `value_type` at the
// given `security_parameter`.
//
// - Integer / XorWrapper: the declared bitsize.
// - IntModN: 8 times the byte count reported by
//   IntModNBase::GetNumBytesRequired for a single element.
// - Tuple: the sum over all non-IntModN elements (each evaluated with
//   `security_parameter + log2(number of such elements)`), plus the bits for
//   all IntModN elements taken together. All IntModN elements of one tuple
//   must have the same type; otherwise Unimplemented is returned.
//
// Returns InvalidArgument for an unsupported or unset ValueType, and
// propagates errors from nested calls.
absl::StatusOr<int> BitsNeeded(const ValueType& value_type,
                               double security_parameter) {
  if (value_type.type_case() == ValueType::kInteger) {
    return value_type.integer().bitsize();
  } else if (value_type.type_case() == ValueType::kTuple) {
    // We handle elements of type IntModN separately, since we can sample them
    // together.
    int num_ints_mod_n = 0;
    int num_other = 0;
    const ValueType* int_mod_n = nullptr;
    int bitsize_ints_mod_n = 0;
    int bitsize_other = 0;
    for (const ValueType& el : value_type.tuple().elements()) {
      if (el.type_case() == ValueType::kIntModN) {
        // Element is integer mod N -> check if it is the same as the others in
        // this tuple and increase counter.
        if (!int_mod_n) {
          int_mod_n = &el;
        } else {
          absl::StatusOr<bool> types_are_equal =
              ValueTypesAreEqual(el, *int_mod_n);
          if (!types_are_equal.ok()) {
            return types_are_equal.status();
          }
          if (!*types_are_equal) {
            return absl::UnimplementedError(
                "All elements of type IntModN in a tuple must be the same");
          }
        }
        ++num_ints_mod_n;
      } else {
        ++num_other;
      }
    }
    if (num_other > 0) {
      // Loop-invariant: depends only on the number of non-IntModN elements.
      const double per_element_security_parameter =
          security_parameter + std::log2(static_cast<double>(num_other));
      // Visit exactly the non-IntModN elements. Indexing the first
      // `num_other` elements instead would pick up IntModN elements (and miss
      // later non-IntModN ones) whenever an IntModN element does not come
      // last in the tuple.
      // NOTE(review): this changes the result for tuples in which an IntModN
      // element precedes a non-IntModN element; confirm nothing depends on
      // the previous count for such tuples.
      for (const ValueType& el : value_type.tuple().elements()) {
        if (el.type_case() == ValueType::kIntModN) {
          continue;
        }
        DPF_ASSIGN_OR_RETURN(int el_bitsize,
                             BitsNeeded(el, per_element_security_parameter));
        bitsize_other += el_bitsize;
      }
    }
    if (num_ints_mod_n > 0) {
      DPF_ASSIGN_OR_RETURN(
          absl::uint128 modulus,
          ValueIntegerToUint128(int_mod_n->int_mod_n().modulus()));
      DPF_ASSIGN_OR_RETURN(
          int64_t bytes_needed_ints_mod_n,
          dpf_internal::IntModNBase::GetNumBytesRequired(
              num_ints_mod_n, int_mod_n->int_mod_n().base_integer().bitsize(),
              modulus, security_parameter));
      bitsize_ints_mod_n = static_cast<int>(bytes_needed_ints_mod_n * 8);
    }
    return bitsize_ints_mod_n + bitsize_other;
  } else if (value_type.type_case() == ValueType::kIntModN) {
    DPF_ASSIGN_OR_RETURN(
        absl::uint128 modulus,
        ValueIntegerToUint128(value_type.int_mod_n().modulus()));
    DPF_ASSIGN_OR_RETURN(int64_t bytes_needed_ints_mod_n,
                         dpf_internal::IntModNBase::GetNumBytesRequired(
                             1, value_type.int_mod_n().base_integer().bitsize(),
                             modulus, security_parameter));
    return 8 * bytes_needed_ints_mod_n;
  } else if (value_type.type_case() == ValueType::kXorWrapper) {
    return value_type.xor_wrapper().bitsize();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "BitsNeeded: Unsupported ValueType:\n", value_type.DebugString()));
}
// Integer Helpers
// Packs `input` into a Value::Integer. Values whose upper 64 bits are zero
// are stored in the `value_uint64` field; everything else is stored as a
// high/low pair in `value_uint128`.
Value::Integer Uint128ToValueInteger(absl::uint128 input) {
  const uint64_t high_bits = absl::Uint128High64(input);
  const uint64_t low_bits = absl::Uint128Low64(input);

  Value::Integer converted;
  if (high_bits != 0) {
    Block* wide = converted.mutable_value_uint128();
    wide->set_high(high_bits);
    wide->set_low(low_bits);
    return converted;
  }
  converted.set_value_uint64(low_bits);
  return converted;
}
// Reads the integer stored in `in`, whichever of the two encodings
// (`value_uint64` or `value_uint128`) is set.
//
// Returns INVALID_ARGUMENT if neither encoding is set.
absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in) {
  switch (in.value_case()) {
    case Value::Integer::kValueUint128: {
      const auto& wide = in.value_uint128();
      return absl::MakeUint128(wide.high(), wide.low());
    }
    case Value::Integer::kValueUint64:
      return static_cast<absl::uint128>(in.value_uint64());
    default:
      return absl::InvalidArgumentError(
          "Unknown value case for the given integer Value");
  }
}
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,673 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_
#include <stdint.h>
#include <algorithm>
#include <array>
#include <limits>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
#include "absl/base/config.h"
#include "absl/log/absl_check.h"
#include "absl/meta/type_traits.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/utility/utility.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/int_mod_n.h"
#include "dpf/tuple.h"
#include "dpf/xor_wrapper.h"
#include "google/protobuf/repeated_field.h"
// Contains a collection of helper functions for different DPF value types. This
// includes functions for converting between Value protos and the corresponding
// C++ objects, as well as functions for sampling values from uniformly random
// byte strings.
//
// This file contains the templated declarations, instantiations for all
// supported types, as well as type-independent function declarations.
namespace distributed_point_functions {
namespace dpf_internal {
// A helper struct containing declarations for all templated functions we need.
// This is needed since C++ doesn't support partial function template
// specialization, and should be specialized for all supported types.
//
// The second template parameter defaults to `void` and is used by the
// specializations in this file for std::enable_if-style guards. This primary
// template is what unsupported types fall back to: apart from
// IsSupportedType(), its members are declared here but not defined.
template <typename T, typename = void>
struct ValueTypeHelper {
// General type traits and conversion functions. Should be implemented by all
// types.
// Type trait for all supported types. Used to provide meaningful error
// messages in std::enable_if template guards. The primary template reports
// `false`.
static constexpr bool IsSupportedType() { return false; }
// Checks if the template parameter can be converted directly from a string of
// bytes.
static constexpr bool CanBeConvertedDirectly();
// Converts a given Value to the template parameter T.
static absl::StatusOr<T> FromValue(const Value& value);
// Converts the argument to a Value proto.
static Value ToValue(const T& input);
// Returns a `ValueType` message describing T.
static ValueType ToValueType();
// Functions for converting from a byte string to T. There are two approaches:
// Either converting directly (i.e., each byte is copied 1-to-1 into the
// result), or by sampling (when a direct conversion is not possible). Types
// for which CanBeConvertedDirectly() can be true should implement the former,
// and all types should implement the latter (to support types composed of
// directly-convertible and not-directly-convertible types).
// Functions for direct conversions from bytes. Should be implemented when
// CanBeConvertedDirectly() can be true.
// Returns the total number of bits in a T.
static constexpr int TotalBitSize();
// Builds a T directly from `bytes`.
static T DirectlyFromBytes(absl::string_view bytes);
// Functions for sampling from a string of bytes. Should be implemented by all
// types.
// Converts `block` to type T. Then, if `update == true`, fills up `block`
// from `remaining_bytes` and advances `remaining_bytes` by the amount of
// bytes read.
static T SampleAndUpdateBytes(bool update, absl::uint128& block,
absl::string_view& remaining_bytes);
};
/******************************************************************************/
// Type traits //
/******************************************************************************/
// Type trait for all supported types. Used to provide meaningful error messages
// in std::enable_if template guards.
//
// `value` forwards to ValueTypeHelper<T>::IsSupportedType().
template <typename T>
struct is_supported_type {
static constexpr bool value =
dpf_internal::ValueTypeHelper<T>::IsSupportedType();
};
// Variable-template shorthand for is_supported_type<T>::value.
template <typename T>
constexpr bool is_supported_type_v = is_supported_type<T>::value;
// Checks if the template parameter can be converted directly from a string of
// bytes.
//
// `value` forwards to ValueTypeHelper<T>::CanBeConvertedDirectly().
template <typename T>
struct can_be_converted_directly {
static constexpr bool value =
dpf_internal::ValueTypeHelper<T>::CanBeConvertedDirectly();
};
// Variable-template shorthand for can_be_converted_directly<T>::value.
template <typename T>
constexpr bool can_be_converted_directly_v =
can_be_converted_directly<T>::value;
// Returns the total number of bits in a T.
//
// Only participates in overload resolution when T can be converted directly
// from bytes; forwards to ValueTypeHelper<T>::TotalBitSize().
template <typename T,
typename = absl::enable_if_t<can_be_converted_directly_v<T>>>
static constexpr int TotalBitSize() {
return ValueTypeHelper<T>::TotalBitSize();
}
/******************************************************************************/
// Integer Helpers //
/******************************************************************************/
// Type trait for all integer types we support, i.e., 8 to 128 bit types.
// Besides the fixed-width unsigned types and absl::uint128, this also accepts
// the compiler's native `unsigned __int128` when ABSL_HAVE_INTRINSIC_INT128 is
// defined.
template <typename T>
using is_unsigned_integer =
absl::disjunction<std::is_same<T, uint8_t>, std::is_same<T, uint16_t>,
std::is_same<T, uint32_t>, std::is_same<T, uint64_t>,
#ifdef ABSL_HAVE_INTRINSIC_INT128
std::is_same<T, unsigned __int128>,
#endif
std::is_same<T, absl::uint128>>;
// Variable-template shorthand for is_unsigned_integer<T>::value.
template <typename T>
constexpr bool is_unsigned_integer_v = is_unsigned_integer<T>::value;
// Converts the given Value::Integer to an absl::uint128. Used as a helper
// function by the `FromValue` implementations in this file, as well as by
// `ValueTypesAreEqual` and `BitsNeeded`.
//
// Returns INVALID_ARGUMENT if `in` has neither its `value_uint64` nor its
// `value_uint128` field set.
absl::StatusOr<absl::uint128> ValueIntegerToUint128(const Value::Integer& in);
// Converts an absl::uint128 to a Value::Integer. Used as a helper function in
// ToValue.
Value::Integer Uint128ToValueInteger(absl::uint128 input);
// Checks if the given value is in range of T, and if so, returns it converted
// to T.
//
// Otherwise returns INVALID_ARGUMENT. The check covers all 128 bits of `in`,
// so values that only differ from an in-range value in their upper 64 bits
// are rejected rather than silently truncated.
template <typename T, typename = absl::enable_if_t<is_unsigned_integer_v<T>>>
absl::StatusOr<T> Uint128To(absl::uint128 in) {
  // Only types narrower than 128 bits can be out of range.
  if (sizeof(T) < sizeof(absl::uint128) &&
      in > absl::uint128{std::numeric_limits<T>::max()}) {
    if (absl::Uint128High64(in) == 0) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Value (= ", absl::Uint128Low64(in),
          ") too large for the given type T (size ", sizeof(T), ")"));
    }
    // The value does not fit into 64 bits; print both halves in hexadecimal
    // so that the message shows the actual value.
    return absl::InvalidArgumentError(absl::StrCat(
        "Value (= 0x", absl::Hex(absl::Uint128High64(in)),
        absl::Hex(absl::Uint128Low64(in), absl::kZeroPad16),
        ") too large for the given type T (size ", sizeof(T), ")"));
  }
  return static_cast<T>(in);
}
// Implementation of ValueTypeHelper for integers.
//
// Unsigned integers are supported, and can be converted directly from a
// string of sizeof(T) bytes. Bytes are interpreted in little-endian order on
// all platforms.
template <typename T>
struct ValueTypeHelper<T, absl::enable_if_t<is_unsigned_integer_v<T>>> {
  static constexpr bool IsSupportedType() { return true; }

  static constexpr bool CanBeConvertedDirectly() { return true; }

  // Parses `value` as a T.
  //
  // Returns INVALID_ARGUMENT if `value` does not hold an integer, or if
  // ValueIntegerToUint128 or Uint128To<T> report an error.
  static absl::StatusOr<T> FromValue(const Value& value) {
    if (value.value_case() != Value::kInteger) {
      return absl::InvalidArgumentError("The given Value is not an integer");
    }
    // We first parse the value into an absl::uint128, then check its range if
    // it is supposed to be smaller than 128 bits.
    absl::StatusOr<absl::uint128> value_128 =
        ValueIntegerToUint128(value.integer());
    if (!value_128.ok()) {
      return value_128.status();
    }
    return Uint128To<T>(*value_128);
  }

  // Wraps `input` in a Value proto with the `integer` field set.
  static Value ToValue(T input) {
    Value result;
    *(result.mutable_integer()) = Uint128ToValueInteger(input);
    return result;
  }

  // Describes T as an integer ValueType with bitsize 8 * sizeof(T).
  static ValueType ToValueType() {
    ValueType result;
    result.mutable_integer()->set_bitsize(8 * sizeof(T));
    return result;
  }

  static constexpr int TotalBitSize() { return sizeof(T) * 8; }

  // Interprets `bytes` as the little-endian encoding of a T.
  //
  // Crashes if `bytes.size() != sizeof(T)`.
  static T DirectlyFromBytes(absl::string_view bytes) {
    ABSL_CHECK(bytes.size() == sizeof(T));
    T out{0};
#ifdef ABSL_IS_LITTLE_ENDIAN
    std::copy_n(bytes.begin(), sizeof(T), reinterpret_cast<char*>(&out));
#else
    // Assemble the value starting from the most significant byte. The shift
    // has to happen *before* inserting each byte; shifting afterwards would
    // push the most significant byte out and leave the lowest byte zero.
    for (int i = sizeof(T) - 1; i >= 0; --i) {
      out <<= 8;
      out |= static_cast<uint8_t>(bytes[i]);
    }
#endif
    return out;
  }

  // Returns the sizeof(T) least significant bytes of `block` as a T. If
  // `update` is true, additionally replaces those bytes of `block` with the
  // next sizeof(T) bytes of `remaining_bytes`, and advances `remaining_bytes`
  // past them.
  static T SampleAndUpdateBytes(bool update, absl::uint128& block,
                                absl::string_view& remaining_bytes) {
    T result = static_cast<T>(block);
    if (update) {
      // Set sizeof(T) least significant bytes to 0.
      if (sizeof(T) < sizeof(block)) {
        constexpr absl::uint128 mask =
            ~absl::uint128{std::numeric_limits<T>::max()};
        block &= mask;
      } else {
        block = 0;
      }
      // Fill up with `bytes` and advance `bytes` by sizeof(T).
      ABSL_DCHECK(remaining_bytes.size() >= sizeof(T));
      block |= DirectlyFromBytes(remaining_bytes.substr(0, sizeof(T)));
      remaining_bytes = remaining_bytes.substr(sizeof(T));
    }
    return result;
  }
};
/******************************************************************************/
// IntModN Helpers //
/******************************************************************************/
// Implementation of ValueTypeHelper for integers modulo a compile-time
// constant. Values of this type cannot be copied directly from bytes and are
// always obtained by sampling.
template <typename BaseInteger, typename ModulusType, ModulusType kModulus>
struct ValueTypeHelper<
    dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>, void> {
  using IntModNType =
      dpf_internal::IntModNImpl<BaseInteger, ModulusType, kModulus>;

  // Supported whenever both the base integer and the modulus type are
  // supported unsigned integers.
  static constexpr bool IsSupportedType() {
    return is_unsigned_integer_v<BaseInteger> &&
           is_unsigned_integer_v<ModulusType>;
  }

  static constexpr bool CanBeConvertedDirectly() { return false; }

  // Parses `value` as an integer in [0, kModulus).
  //
  // Returns INVALID_ARGUMENT if `value` is not an IntModN, cannot be read as
  // an integer, or is not smaller than kModulus.
  static absl::StatusOr<IntModNType> FromValue(const Value& value) {
    if (value.value_case() != Value::kIntModN) {
      return absl::InvalidArgumentError("The given Value is not an IntModN");
    }
    const absl::StatusOr<absl::uint128> parsed =
        ValueIntegerToUint128(value.int_mod_n());
    if (!parsed.ok()) {
      return parsed.status();
    }
    const absl::uint128 raw = *parsed;
    const absl::uint128 modulus{kModulus};
    if (raw < modulus) {
      return IntModNType(static_cast<BaseInteger>(raw));
    }
    return absl::InvalidArgumentError(absl::StrFormat(
        "The given value (= %d) is larger than kModulus (= %d)", raw,
        modulus));
  }

  // Wraps the underlying integer of `input` in a Value with the `int_mod_n`
  // field set.
  static Value ToValue(IntModNType input) {
    Value proto;
    *proto.mutable_int_mod_n() = Uint128ToValueInteger(input.value());
    return proto;
  }

  // Describes this type by its base integer type and its modulus.
  static ValueType ToValueType() {
    ValueType description;
    auto* int_mod_n = description.mutable_int_mod_n();
    *int_mod_n->mutable_base_integer() =
        ValueTypeHelper<BaseInteger>::ToValueType().integer();
    *int_mod_n->mutable_modulus() =
        ValueTypeHelper<ModulusType>::ToValue(kModulus).integer();
    return description;
  }

  // Returns `block` reduced modulo kModulus. If `update` is true, `block` is
  // then replaced by the quotient shifted left by the size of BaseInteger,
  // with the freed low bytes filled from `remaining_bytes`, which is advanced
  // accordingly.
  static IntModNType SampleAndUpdateBytes(bool update, absl::uint128& block,
                                          absl::string_view& remaining_bytes) {
    // Optimization for native uint128. This is equivalent to what's done in
    // int128.cc, but since division is not defined in the header, the compiler
    // cannot optimize the division and modulus into a single operation.
#ifdef ABSL_HAVE_INTRINSIC_INT128
    const unsigned __int128 native_block =
        static_cast<unsigned __int128>(block);
    const absl::uint128 quotient = native_block / kModulus;
    const absl::uint128 remainder = native_block % kModulus;
#else
    const absl::uint128 quotient = block / kModulus;
    const absl::uint128 remainder = block % kModulus;
#endif
    IntModNType sampled(static_cast<BaseInteger>(remainder));
    if (!update) {
      return sampled;
    }

    constexpr auto kBaseBytes = sizeof(BaseInteger);
    if (kBaseBytes < sizeof(block)) {
      block = quotient << (kBaseBytes * 8);
    } else {
      block = 0;
    }
    block |= ValueTypeHelper<BaseInteger>::DirectlyFromBytes(
        remaining_bytes.substr(0, kBaseBytes));
    remaining_bytes = remaining_bytes.substr(kBaseBytes);
    return sampled;
  }
};
/******************************************************************************/
// Tuple Helpers //
/******************************************************************************/
// Helper struct for computing the bit size of a tuple type at compile time
// without C++17 fold expressions.
template <typename FirstElementType, typename... ElementType>
struct TupleBitSizeHelper {
static constexpr int TotalBitSize() {
return TupleBitSizeHelper<FirstElementType>::TotalBitSize() +
TupleBitSizeHelper<ElementType...>::TotalBitSize();
}
};
template <typename ElementType>
struct TupleBitSizeHelper<ElementType> {
static constexpr int TotalBitSize() {
return ValueTypeHelper<ElementType>::TotalBitSize();
}
};
// Implementation of ValueTypeHelper for Tuple<ElementType...>. Every
// operation is delegated element-wise to the ValueTypeHelper of the
// corresponding element type.
template <typename... ElementType>
struct ValueTypeHelper<Tuple<ElementType...>, void> {
using TupleType = Tuple<ElementType...>;
// A tuple is supported iff all of its element types are supported.
static constexpr bool IsSupportedType() {
return absl::conjunction<is_supported_type<ElementType>...>::value;
}
// A tuple can be converted directly iff all of its element types can.
static constexpr bool CanBeConvertedDirectly() {
return absl::conjunction<can_be_converted_directly<ElementType>...>::value;
}
// Converts `value` to a TupleType.
//
// Returns INVALID_ARGUMENT if `value` is not a tuple or has the wrong number
// of elements, and otherwise the first error encountered while converting an
// element.
static absl::StatusOr<TupleType> FromValue(const Value& value) {
if (value.value_case() != Value::kTuple) {
return absl::InvalidArgumentError("The given Value is not a tuple");
}
constexpr auto tuple_size =
static_cast<int>(std::tuple_size<typename TupleType::Base>());
if (value.tuple().elements_size() != tuple_size) {
return absl::InvalidArgumentError(
"The tuple in the given Value has the wrong number of elements");
}
// Create a Tuple by unpacking value.tuple().elements(). If we encounter an
// error, return it at the end.
absl::Status status = absl::OkStatus();
int element_index = 0;
// The braced initializer list ensures elements are created in the correct
// order (unlike std::make_tuple).
// One lambda is invoked per element type. After the first failure, `status`
// is no longer OK, so all remaining elements are default-constructed without
// being parsed.
TupleType result = {[&value, &status, &element_index] {
if (status.ok()) {
absl::StatusOr<ElementType> element =
ValueTypeHelper<ElementType>::FromValue(
value.tuple().elements(element_index));
element_index++;
if (element.ok()) {
return *element;
} else {
status = element.status();
}
}
return ElementType{};
}()...};
if (status.ok()) {
return result;
} else {
return status;
}
}
// Converts `input` to a Value proto holding one tuple element per element of
// `input`, in order.
static Value ToValue(const TupleType& input) {
Value result;
absl::apply(
[&result](const ElementType&... elements) {
// Create an unused std::tuple to iterate over `elements` in its
// constructor. This can be replaced by a fold expression in C++17.
// Each initializer is a comma expression: the left operand appends the
// converted element to `result`, the right operand only provides a value
// for the unused tuple.
std::tuple<ElementType...>{
(*(result.mutable_tuple()->add_elements()) =
ValueTypeHelper<ElementType>::ToValue(elements),
ElementType{})...};
},
input.value());
return result;
}
// Returns a tuple ValueType with one entry per element type, in order.
static ValueType ToValueType() {
ValueType result;
ValueType::Tuple* tuple = result.mutable_tuple();
// Create an unused std::tuple to iterate over `elements` in its
// constructor. This can be replaced by a fold expression in C++17.
std::tuple<ElementType...>{
(*(tuple->add_elements()) = ValueTypeHelper<ElementType>::ToValueType(),
ElementType{})...};
return result;
}
// Returns the sum of the bit sizes of all element types.
static constexpr int TotalBitSize() {
// This helper can be replaced by a fold expression in C++17.
return TupleBitSizeHelper<ElementType...>::TotalBitSize();
}
// Builds a tuple by converting consecutive chunks of `bytes`, one chunk per
// element, each as long as the element's bit size rounded up to whole bytes.
//
// Crashes if `bytes` holds fewer than TotalBitSize() bits.
static TupleType DirectlyFromBytes(absl::string_view bytes) {
ABSL_CHECK(8 * bytes.size() >= TotalBitSize());
int offset = 0;
// NOTE(review): `status` is captured by the lambda below but never read or
// written in this function.
absl::Status status = absl::OkStatus();
// Braced-init-list ensures the elements are constructed in-order.
return TupleType{[&bytes, &offset, &status] {
constexpr int element_size_bytes =
(ValueTypeHelper<ElementType>::TotalBitSize() + 7) / 8;
ElementType element = ValueTypeHelper<ElementType>::DirectlyFromBytes(
bytes.substr(offset, element_size_bytes));
offset += element_size_bytes;
return element;
}()...};
}
// Samples all elements one after another from `block`, refilling `block`
// from `remaining_bytes` between elements. `block` is also refilled after
// the last element iff `update` is true.
static TupleType SampleAndUpdateBytes(bool update, absl::uint128& block,
absl::string_view& remaining_bytes) {
int element_counter = 0;
// Braced-init-list ensures the elements are constructed in-order.
return TupleType{[update, &element_counter, &block,
&remaining_bytes]() -> ElementType {
// If `update` is true, update after all elements. Otherwise, don't update
// after the last one.
constexpr int num_elements = std::tuple_size<typename TupleType::Base>();
// `element_counter` is only advanced when `update` is false, because `||`
// short-circuits; it is not needed otherwise.
bool update2 = update || (++element_counter < num_elements);
return ValueTypeHelper<ElementType>::SampleAndUpdateBytes(
update2, block, remaining_bytes);
}()...};
}
};
/******************************************************************************/
// XorWrapper Helpers //
/******************************************************************************/
// Implementation of ValueTypeHelper for XorWrapper<T>. All properties and
// byte conversions are forwarded to the ValueTypeHelper of the wrapped type.
template <typename T>
struct ValueTypeHelper<XorWrapper<T>, void> {
  static constexpr bool IsSupportedType() {
    return ValueTypeHelper<T>::IsSupportedType();
  }

  static constexpr bool CanBeConvertedDirectly() {
    return ValueTypeHelper<T>::CanBeConvertedDirectly();
  }

  // Converts `value` to an XorWrapper<T>.
  //
  // Returns INVALID_ARGUMENT if `value` is not an XorWrapper, if the wrapped
  // integer cannot be read, or if it is out of range for T.
  static absl::StatusOr<XorWrapper<T>> FromValue(const Value& value) {
    // Validate the value case up front, like the other helpers in this file
    // do, so that a Value of the wrong kind is reported as such.
    if (value.value_case() != Value::kXorWrapper) {
      return absl::InvalidArgumentError(
          "The given Value is not an XorWrapper");
    }
    absl::StatusOr<absl::uint128> wrapped128 =
        ValueIntegerToUint128(value.xor_wrapper());
    if (!wrapped128.ok()) {
      return wrapped128.status();
    }
    absl::StatusOr<T> wrapped = Uint128To<T>(*wrapped128);
    if (!wrapped.ok()) {
      return wrapped.status();
    }
    return XorWrapper<T>(*wrapped);
  }

  // Wraps the underlying integer of `input` in a Value with the `xor_wrapper`
  // field set.
  static Value ToValue(const XorWrapper<T>& input) {
    Value result;
    *(result.mutable_xor_wrapper()) = Uint128ToValueInteger(input.value());
    return result;
  }

  // Describes this type by the integer ValueType of the wrapped type.
  static ValueType ToValueType() {
    ValueType result;
    *(result.mutable_xor_wrapper()) =
        ValueTypeHelper<T>::ToValueType().integer();
    return result;
  }

  static constexpr int TotalBitSize() {
    return ValueTypeHelper<T>::TotalBitSize();
  }

  static XorWrapper<T> DirectlyFromBytes(absl::string_view bytes) {
    return XorWrapper<T>(ValueTypeHelper<T>::DirectlyFromBytes(bytes));
  }

  static XorWrapper<T> SampleAndUpdateBytes(
      bool update, absl::uint128& block, absl::string_view& remaining_bytes) {
    return XorWrapper<T>(ValueTypeHelper<T>::SampleAndUpdateBytes(
        update, block, remaining_bytes));
  }
};
/******************************************************************************/
// Free standing helpers. These should always come last. When adding //
// additional types, add them above. //
/******************************************************************************/
// Number of values of type T that are packed into one absl::uint128 block.
//
// For directly convertible types of at most 128 bits this is how many whole
// values fit into a block; for everything else it is 1.
template <typename T,
          absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
constexpr int ElementsPerBlock() {
  constexpr int block_bits = static_cast<int>(8 * sizeof(absl::uint128));
  constexpr int value_bits = TotalBitSize<T>();
  return value_bits <= 128 ? block_bits / value_bits : 1;
}

// Types that are not directly convertible never share a block.
template <typename T,
          absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
constexpr int ElementsPerBlock() {
  return 1;
}
// Builds a T from `bytes`. Directly convertible types are read with
// DirectlyFromBytes; all other types are sampled with SampleAndUpdateBytes.
//
// Crashes if `bytes.size()` is too small for the output type.
template <typename T,
          absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
T FromBytes(absl::string_view bytes) {
  using Helper = ValueTypeHelper<T>;
  return Helper::DirectlyFromBytes(bytes);
}

// Sampling overload: the first sizeof(absl::uint128) bytes seed the block to
// sample from, the rest is handed over as additional randomness.
template <typename T,
          absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
T FromBytes(absl::string_view bytes) {
  constexpr auto kBlockBytes = sizeof(absl::uint128);
  absl::uint128 seed_block =
      FromBytes<absl::uint128>(bytes.substr(0, kBlockBytes));
  absl::string_view rest = bytes.substr(kBlockBytes);
  return ValueTypeHelper<T>::SampleAndUpdateBytes(false, seed_block, rest);
}
// Parses a `repeated Value` proto field into a std::array of
// ElementsPerBlock<T>() elements of type T.
//
// Returns INVALID_ARGUMENT if `values` has a different number of elements,
// and the conversion error if any element cannot be converted to T.
template <typename T>
absl::StatusOr<std::array<T, ElementsPerBlock<T>()>> ValuesToArray(
    const ::google::protobuf::RepeatedPtrField<Value>& values) {
  constexpr int kExpectedSize = ElementsPerBlock<T>();
  if (values.size() != kExpectedSize) {
    return absl::InvalidArgumentError(absl::StrCat(
        "values.size() (= ", values.size(),
        ") does not match ElementsPerBlock<T>() (= ", kExpectedSize, ")"));
  }

  std::array<T, kExpectedSize> converted;
  for (int index = 0; index < kExpectedSize; ++index) {
    absl::StatusOr<T> parsed = ValueTypeHelper<T>::FromValue(values[index]);
    if (!parsed.ok()) {
      return parsed.status();
    }
    converted[index] = std::move(*parsed);
  }
  return converted;
}
// Splits `bytes` into exactly ElementsPerBlock<T>() consecutive chunks and
// converts each chunk to a T.
//
// Crashes if `bytes.size()` is too small for the output type.
template <typename T,
          absl::enable_if_t<can_be_converted_directly_v<T>, int> = 0>
std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf(
    absl::string_view bytes) {
  constexpr int kNumElements = ElementsPerBlock<T>();
  // Size of one element, rounded up to whole bytes.
  constexpr int kElementBytes = (TotalBitSize<T>() + 7) / 8;
  ABSL_CHECK(bytes.size() >= kNumElements * kElementBytes);

  std::array<T, kNumElements> elements;
  int offset = 0;
  for (T& element : elements) {
    element = FromBytes<T>(bytes.substr(offset, kElementBytes));
    offset += kElementBytes;
  }
  return elements;
}

// Types that have to be sampled are never batched, so the result holds the
// single value obtained from all of `bytes`.
template <typename T,
          absl::enable_if_t<!can_be_converted_directly_v<T>, int> = 0>
std::array<T, ElementsPerBlock<T>()> ConvertBytesToArrayOf(
    absl::string_view bytes) {
  static_assert(ElementsPerBlock<T>() == 1,
                "T does not support batching, but ElementsPerBlock<T> != 1");
  using ResultArray = std::array<T, ElementsPerBlock<T>()>;
  return ResultArray{FromBytes<T>(bytes)};
}
// Computes the value correction word given two seeds `seed_a`, `seed_b` for
// parties a and b, such that the element at `block_index` is equal to `beta`.
// If `invert` is true, the result is multiplied element-wise by -1. Templated
// to use the correct integer type without needing modular reduction.
//
// Returns multiple values in case of packing, and a single value otherwise.
//
// Returns INVALID_ARGUMENT if `beta` cannot be converted to T, or if
// `block_index` is not in [0, ElementsPerBlock<T>()). Crashes if either seed
// is too short to be converted to ElementsPerBlock<T>() values of type T.
template <typename T>
absl::StatusOr<std::vector<Value>> ComputeValueCorrectionFor(
    absl::string_view seed_a, absl::string_view seed_b, int block_index,
    const Value& beta, bool invert) {
  absl::StatusOr<T> beta_T = ValueTypeHelper<T>::FromValue(beta);
  if (!beta_T.ok()) {
    return beta_T.status();
  }

  constexpr int elements_per_block = ElementsPerBlock<T>();

  // `block_index` is used to index a fixed-size array below, so reject
  // out-of-range values instead of accessing the array out of bounds.
  if (block_index < 0 || block_index >= elements_per_block) {
    return absl::InvalidArgumentError(
        absl::StrCat("block_index (= ", block_index, ") must be in [0, ",
                     elements_per_block, ")"));
  }

  // Compute values from seeds. Both arrays will have multiple elements if T
  // supports batching, and a single one otherwise.
  std::array<T, elements_per_block> ints_a = ConvertBytesToArrayOf<T>(seed_a),
                                    ints_b = ConvertBytesToArrayOf<T>(seed_b);

  // Add beta to the right position.
  ints_b[block_index] += *beta_T;

  // Add up shares, invert if needed.
  for (int i = 0; i < elements_per_block; i++) {
    ints_b[i] = ints_b[i] - ints_a[i];
    if (invert) {
      ints_b[i] = -ints_b[i];
    }
  }

  // Convert to a vector of Value protos and return.
  std::vector<Value> result;
  result.reserve(ints_b.size());
  for (const T& element : ints_b) {
    result.push_back(ValueTypeHelper<T>::ToValue(element));
  }
  return result;
}
// Computes the number of pseudorandom bits needed to get a uniform element of
// the given `ValueType`. For types whose elements can be bijectively mapped to
// strings (e.g., unsigned integers and tuples of integers), this is equivalent
// to the bit size of the value type. For all other types, returns the number of
// bits needed so that converting a uniform string with the given number of bits
// to an element of `value_type` results in a distribution with total variation
// distance < 2^(-`security_parameter`) from uniform.
//
// Returns INVALID_ARGUMENT in case value_type does not represent a known type,
// or if sampling with the required security parameter is not possible.
// Returns UNIMPLEMENTED if `value_type` is a tuple containing IntModN elements
// that are not all of the same type.
absl::StatusOr<int> BitsNeeded(const ValueType& value_type,
double security_parameter);
// Returns `true` if `lhs` and `rhs` describe the same types, and `false`
// otherwise. Tuples are compared element by element, IntModN types by the bit
// size of their base integer and their modulus, and XorWrapper types by their
// bit size.
//
// Returns INVALID_ARGUMENT if an error occurs while parsing either argument.
absl::StatusOr<bool> ValueTypesAreEqual(const ValueType& lhs,
const ValueType& rhs);
} // namespace dpf_internal
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_INTERNAL_VALUE_TYPE_HELPERS_H_

@ -1,359 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/internal/value_type_helpers.h"
#include <stdint.h>
#include <array>
#include <string>
#include <tuple>
#include "absl/base/config.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/int_mod_n.h"
#include "dpf/internal/status_matchers.h"
#include "dpf/tuple.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace dpf_internal {
namespace {
// Security parameter passed to BitsNeeded() by the tests in this file.
constexpr int kDefaultSecurityParameter = 40;
// Comparing two default-constructed ValueTypes must be rejected.
TEST(ValueTypeHelperTest, ValueTypesAreEqualFailsOnInvalidValueTypes) {
  ValueType empty_lhs;
  ValueType empty_rhs;

  EXPECT_THAT(ValueTypesAreEqual(empty_lhs, empty_rhs),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Both arguments must be valid ValueTypes"));
}
// BitsNeeded must reject a default-constructed ValueType.
TEST(ValueTypeHelperTest, BitsNeededFailsOnInvalidValueType) {
  const auto bits_needed = BitsNeeded(ValueType{}, kDefaultSecurityParameter);

  EXPECT_THAT(
      bits_needed,
      StatusIs(absl::StatusCode::kInvalidArgument,
               testing::StartsWith("BitsNeeded: Unsupported ValueType")));
}
// Typed test fixture for the integer ValueTypeHelper specialization. The
// tests below are instantiated once for each type listed in `IntegerTypes`.
template <typename T>
class ValueTypeIntegerTest : public testing::Test {};
using IntegerTypes =
::testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>;
TYPED_TEST_SUITE(ValueTypeIntegerTest, IntegerTypes);
// ToValueType() must describe an integer with the bit size of TypeParam.
TYPED_TEST(ValueTypeIntegerTest, ToValueTypeIntegers) {
  constexpr auto kExpectedBits = sizeof(TypeParam) * 8;

  const ValueType described = ValueTypeHelper<TypeParam>::ToValueType();

  EXPECT_TRUE(described.has_integer());
  EXPECT_EQ(described.integer().bitsize(), kExpectedBits);
}
// A ValueType built by hand with the same bit size must compare equal to the
// one produced by ToValueType(), regardless of argument order.
TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreEqual) {
  const ValueType from_helper = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType hand_built;
  hand_built.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8);

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(from_helper, hand_built));
  EXPECT_TRUE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(hand_built, from_helper));
  EXPECT_TRUE(reverse_equal);
}
// A ValueType with twice the bit size must compare unequal to the one
// produced by ToValueType(), regardless of argument order.
TYPED_TEST(ValueTypeIntegerTest, TestValueTypesAreNotEqual) {
  const ValueType from_helper = ValueTypeHelper<TypeParam>::ToValueType();
  ValueType twice_as_wide;
  twice_as_wide.mutable_integer()->set_bitsize(sizeof(TypeParam) * 8 * 2);

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(from_helper, twice_as_wide));
  EXPECT_FALSE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(twice_as_wide, from_helper));
  EXPECT_FALSE(reverse_equal);
}
// Converting a Value that holds a tuple to an integer must fail.
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfNotInteger) {
  Value tuple_value;
  tuple_value.mutable_tuple();

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(tuple_value),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "The given Value is not an integer"));
}
// Converting an integer Value without any stored number must fail.
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfInvalidIntegerCase) {
  Value integer_without_number;
  integer_without_number.mutable_integer();

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(integer_without_number),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Unknown value case for the given integer Value"));
}
// 2^32 fits into types of at least 64 bits, and must be rejected for all
// narrower types.
TYPED_TEST(ValueTypeIntegerTest, ValueConversionFailsIfValueOutOfRange) {
  const uint64_t two_to_the_32 = uint64_t{1} << 32;
  Value value;
  value.mutable_integer()->set_value_uint64(two_to_the_32);

  if (sizeof(TypeParam) < sizeof(uint64_t)) {
    EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(value),
                StatusIs(absl::StatusCode::kInvalidArgument,
                         absl::StrCat("Value (= ", two_to_the_32,
                                      ") too large for the given type T (size ",
                                      sizeof(TypeParam), ")")));
  } else {
    DPF_EXPECT_OK(ValueTypeHelper<TypeParam>::FromValue(value));
  }
}
// Typed test fixture for the Tuple ValueTypeHelper specialization.
template <typename T>
class ValueTypeTupleTest : public testing::Test {};

// Test parameter bundling a tuple type `T` with the bit sizes its elements
// are expected to have, in order.
template <typename T, int... bits>
struct TupleTestParam {
  using Tuple = T;
  static constexpr int ExpectedNumElements() { return sizeof...(bits); }
  static constexpr std::array<int, ExpectedNumElements()> ExpectedBitSizes() {
    return {bits...};
  }
};

// We only test tuples consisting of integers here.
using TupleTypes = ::testing::Types<
    TupleTestParam<Tuple<uint64_t>, 64>,
    TupleTestParam<Tuple<uint64_t, uint64_t>, 64, 64>,
    TupleTestParam<Tuple<uint32_t, absl::uint128, uint8_t>, 32, 128, 8>,
    TupleTestParam<Tuple<uint8_t, uint8_t, uint8_t, uint8_t>, 8, 8, 8, 8>>;
TYPED_TEST_SUITE(ValueTypeTupleTest, TupleTypes);
// ToValueType() must describe a tuple whose elements are integers with the
// expected bit sizes, in order.
TYPED_TEST(ValueTypeTupleTest, ToValueTypeTuples) {
  using TupleUnderTest = typename TypeParam::Tuple;
  constexpr int kNumElements = TypeParam::ExpectedNumElements();
  const std::array<int, kNumElements> bit_sizes = TypeParam::ExpectedBitSizes();

  const ValueType described = ValueTypeHelper<TupleUnderTest>::ToValueType();

  EXPECT_TRUE(described.has_tuple());
  ASSERT_EQ(std::tuple_size<typename TupleUnderTest::Base>(),
            kNumElements);  // Sanity check for test parameters.
  EXPECT_EQ(described.tuple().elements_size(), kNumElements);
  for (int index = 0; index < kNumElements; ++index) {
    const ValueType& element = described.tuple().elements(index);
    EXPECT_TRUE(element.has_integer());
    EXPECT_EQ(element.integer().bitsize(), bit_sizes[index]);
  }
}
// For tuples of integers, the run-time BitsNeeded() must agree with the
// compile-time TotalBitSize().
TYPED_TEST(ValueTypeTupleTest, BitsNeededEqualsCompileTimeTypeSize) {
  using TupleUnderTest = typename TypeParam::Tuple;
  const ValueType described = ValueTypeHelper<TupleUnderTest>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(int bits_needed,
                           BitsNeeded(described, kDefaultSecurityParameter));

  EXPECT_EQ(bits_needed, TotalBitSize<TupleUnderTest>());
}
// Converting a Value that holds an integer to a tuple must fail. Uses the
// tuple type of the current test parameter, so that every type in
// `TupleTypes` is actually exercised.
TYPED_TEST(ValueTypeTupleTest, ValueConversionFailsIfValueIsNotATuple) {
  Value value;
  value.mutable_integer();
  EXPECT_THAT(ValueTypeHelper<typename TypeParam::Tuple>::FromValue(value),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "The given Value is not a tuple"));
}
// Converting a one-element tuple Value to a two-element tuple type must fail.
TEST(ValueTypeTupleTest, ValueConversionFailsIfValueSizeDoesntMatchTupleSize) {
  using TwoElementTuple = Tuple<uint32_t, uint32_t>;

  Value one_element_value;
  auto* only_element =
      one_element_value.mutable_tuple()->add_elements()->mutable_integer();
  only_element->set_value_uint64(1234);

  EXPECT_THAT(
      ValueTypeHelper<TwoElementTuple>::FromValue(one_element_value),
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          "The tuple in the given Value has the wrong number of elements"));
}
// ValueTypes of two identical tuple types must compare equal, regardless of
// argument order.
TEST(ValueTypeTupleTest, TestValueTypesAreEqual) {
  using FirstTuple = Tuple<uint32_t, absl::uint128, uint8_t>;
  using SecondTuple = Tuple<uint32_t, absl::uint128, uint8_t>;
  const ValueType first = ValueTypeHelper<FirstTuple>::ToValueType();
  const ValueType second = ValueTypeHelper<SecondTuple>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(first, second));
  EXPECT_TRUE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(second, first));
  EXPECT_TRUE(reverse_equal);
}
// ValueTypes of tuple types that differ in their last element must compare
// unequal, regardless of argument order.
TEST(ValueTypeTupleTest, TestValueTypesAreNotEqual) {
  using FirstTuple = Tuple<uint32_t, absl::uint128, uint8_t>;
  using SecondTuple = Tuple<uint32_t, absl::uint128, uint16_t>;
  const ValueType first = ValueTypeHelper<FirstTuple>::ToValueType();
  const ValueType second = ValueTypeHelper<SecondTuple>::ToValueType();

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(first, second));
  EXPECT_FALSE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(second, first));
  EXPECT_FALSE(reverse_equal);
}
// Parsing 16 bytes as a tuple of two uint64_t values yields the same elements
// as parsing the first and second 8-byte halves individually.
TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExample) {
  using Uint64Pair = Tuple<uint64_t, uint64_t>;

  std::string sixteen_bytes = "A 128 bit string";
  auto parsed = FromBytes<Uint64Pair>(sixteen_bytes);

  auto& elements = parsed.value();
  EXPECT_EQ(std::get<0>(elements), FromBytes<uint64_t>("A 128 bi"));
  EXPECT_EQ(std::get<1>(elements), FromBytes<uint64_t>("t string"));
}
// Checks FromBytes() for a tuple of two IntModN values against expected values
// computed by hand from a 20-byte input (16 bytes + 4 bytes).
TEST(ValueTypeTupleTest, TestFromBytesWithConcreteExampleForIntModN) {
// 4294967291 == 2^32 - 5.
constexpr uint32_t kModulus = 4294967291u;
using MyIntModN = IntModN<uint32_t, kModulus>;
std::string bytes = "A 128+32 bit string.";
// The first expected element is the leading 16 bytes, read as a 128-bit
// integer, reduced modulo kModulus.
absl::uint128 block = FromBytes<absl::uint128>("A 128+32 bit str");
MyIntModN expected_0(static_cast<uint32_t>(block % kModulus));
// The second expected element reuses the quotient of that division: it is
// shifted left by 32 bits, the remaining 4 input bytes are OR-ed into the low
// bits, and the result is reduced modulo kModulus again.
block /= kModulus;
block <<= (8 * sizeof(uint32_t));
block |= FromBytes<uint32_t>("ing.");
MyIntModN expected_1(static_cast<uint32_t>(block % kModulus));
auto tuple = FromBytes<Tuple<MyIntModN, MyIntModN>>(bytes).value();
EXPECT_EQ(std::get<0>(tuple), expected_0);
EXPECT_EQ(std::get<1>(tuple), expected_1);
}
// Typed-test fixture for the IntModN specialization of ValueTypeHelper. The
// fixture carries no state; it only ties the tests to the type list below.
template <typename T>
class ValueTypeIntModNTest : public testing::Test {};
// IntModN instantiations under test: 32- and 64-bit base integers with small
// and large moduli, plus a 128-bit base integer when the compiler provides an
// intrinsic 128-bit type (ABSL_HAVE_INTRINSIC_INT128).
using IntModNTypes = ::testing::Types<
IntModN<uint32_t, 4>, IntModN<uint32_t, 4294967291u>,
IntModN<uint64_t, 4294967291ull>, IntModN<uint64_t, 1000000000000ull>
#ifdef ABSL_HAVE_INTRINSIC_INT128
,
IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128(
65535u, 18446744073709551551ull))> // 2**80-65
#endif
>;
TYPED_TEST_SUITE(ValueTypeIntModNTest, IntModNTypes);
// ToValueType() for an IntModN must produce an int_mod_n ValueType whose base
// integer bit size and modulus match the C++ type.
TYPED_TEST(ValueTypeIntModNTest, ToValueType) {
  using BaseInteger = typename TypeParam::Base;

  ValueType converted = ValueTypeHelper<TypeParam>::ToValueType();
  EXPECT_TRUE(converted.type_case() == ValueType::kIntModN);

  const auto& int_mod_n = converted.int_mod_n();
  EXPECT_EQ(int_mod_n.base_integer().bitsize(), sizeof(BaseInteger) * 8);

  DPF_ASSERT_OK_AND_ASSIGN(absl::uint128 encoded_modulus,
                           ValueIntegerToUint128(int_mod_n.modulus()));
  EXPECT_EQ(encoded_modulus, absl::uint128{TypeParam::modulus()});
}
// A ValueType assembled by hand with the same base bit size and modulus must
// compare equal to the one produced by ValueTypeHelper, in both argument
// orders.
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreEqual) {
  ValueType from_helper = ValueTypeHelper<TypeParam>::ToValueType();

  ValueType built_by_hand;
  auto* int_mod_n = built_by_hand.mutable_int_mod_n();
  int_mod_n->mutable_base_integer()->set_bitsize(sizeof(TypeParam) * 8);
  *int_mod_n->mutable_modulus() = Uint128ToValueInteger(TypeParam::modulus());

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(from_helper, built_by_hand));
  EXPECT_TRUE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(built_by_hand, from_helper));
  EXPECT_TRUE(reverse_equal);
}
// Two int_mod_n ValueTypes with the same modulus but different base integer
// bit sizes must compare unequal, in both argument orders.
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentBase) {
  ValueType value_type_1 = ValueTypeHelper<TypeParam>::ToValueType(),
            value_type_2 = value_type_1;
  // Double the base bit size on the copy so that only the base differs.
  value_type_2.mutable_int_mod_n()->mutable_base_integer()->set_bitsize(
      sizeof(TypeParam) * 8 * 2);

  DPF_ASSERT_OK_AND_ASSIGN(bool equal,
                           ValueTypesAreEqual(value_type_1, value_type_2));
  EXPECT_FALSE(equal);
  DPF_ASSERT_OK_AND_ASSIGN(equal,
                           ValueTypesAreEqual(value_type_2, value_type_1));
  EXPECT_FALSE(equal);
}
// Two int_mod_n ValueTypes with the same base integer but different moduli
// must compare unequal, in both argument orders.
TYPED_TEST(ValueTypeIntModNTest, TestValueTypesAreDifferentModulus) {
  ValueType original = ValueTypeHelper<TypeParam>::ToValueType();

  // Copy the type and lower only its modulus by one.
  ValueType smaller_modulus = original;
  *smaller_modulus.mutable_int_mod_n()->mutable_modulus() =
      Uint128ToValueInteger(TypeParam::modulus() - 1);

  DPF_ASSERT_OK_AND_ASSIGN(bool forward_equal,
                           ValueTypesAreEqual(original, smaller_modulus));
  EXPECT_FALSE(forward_equal);

  DPF_ASSERT_OK_AND_ASSIGN(bool reverse_equal,
                           ValueTypesAreEqual(smaller_modulus, original));
  EXPECT_FALSE(reverse_equal);
}
// Comparing against an int_mod_n ValueType whose modulus has been cleared must
// fail with kInvalidArgument instead of returning a result.
TYPED_TEST(ValueTypeIntModNTest, ValueTypesAreEqualFailsWhenModulusInvalid) {
  ValueType well_formed = ValueTypeHelper<TypeParam>::ToValueType();

  ValueType missing_modulus = well_formed;
  missing_modulus.mutable_int_mod_n()->clear_modulus();

  EXPECT_THAT(ValueTypesAreEqual(well_formed, missing_modulus),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Unknown value case for the given integer Value"));
}
// Converting a Value that holds a tuple into an IntModN type must be rejected
// with kInvalidArgument.
TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfNotInteger) {
  Value tuple_value;
  tuple_value.mutable_tuple();  // Make the Value hold a tuple, not an IntModN.

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(tuple_value),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "The given Value is not an IntModN"));
}
// An int_mod_n Value equal to the modulus itself is out of range, so the
// conversion must be rejected with kInvalidArgument.
TYPED_TEST(ValueTypeIntModNTest, ValueConversionFailsIfTooLargeForModulus) {
  Value out_of_range;
  *out_of_range.mutable_int_mod_n() =
      Uint128ToValueInteger(TypeParam::modulus());

  EXPECT_THAT(ValueTypeHelper<TypeParam>::FromValue(out_of_range),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       testing::HasSubstr("is larger than kModulus")));
}
} // namespace
} // namespace dpf_internal
} // namespace distributed_point_functions

@ -1,51 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_
// Helper macro that evaluates the right hand side (rexpr) exactly once and
// checks whether it is a StatusOr with Status OK. If so, the contained value is
// moved into the left hand side (lhs); otherwise the enclosing function returns
// the error status. Because lhs is pasted textually in front of `=`, it may be
// an existing variable or a declaration. Example:
// DPF_ASSIGN_OR_RETURN(lhs, rexpression);
// NOTE: the expansion is several statements long, so it must not be used as
// the unbraced body of an `if`, `else`, or loop. The expansion also uses
// ABSL_PREDICT_FALSE and std::move, which must be visible at the point of use.
#define DPF_ASSIGN_OR_RETURN(lhs, rexpr) \
DPF_ASSIGN_OR_RETURN_IMPL_( \
DPF_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr)
// Internal helper. `statusor` is the name of the temporary that holds the
// evaluated expression; the public macro derives it from __LINE__.
#define DPF_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rexpr) \
auto statusor = (rexpr); \
if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
return std::move(statusor).status(); \
} \
lhs = std::move(statusor).value()
// Internal helper for concatenating macro values. The two-level definition
// lets arguments such as __LINE__ expand before they are pasted together.
#define DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
#define DPF_STATUS_MACROS_IMPL_CONCAT_(x, y) \
DPF_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
// Evaluates `expr` exactly once. If the resulting status object is not OK
// (its ok() member returns false), the enclosing function returns that status;
// otherwise execution continues with the next statement. Example:
//   DPF_RETURN_IF_ERROR(DoSomething());
#define DPF_RETURN_IF_ERROR(expr)                                              \
  DPF_RETURN_IF_ERROR_IMPL_(DPF_STATUS_MACROS_IMPL_CONCAT_(_status, __LINE__), \
                            expr)
// Internal helper. The body is wrapped in do { ... } while (0) so that the
// macro expands to exactly one statement: it can then be used as the unbraced
// body of an `if`/`else` or loop, and the temporary does not leak into the
// enclosing scope.
#define DPF_RETURN_IF_ERROR_IMPL_(status, expr) \
  do {                                          \
    auto status = (expr);                       \
    if (ABSL_PREDICT_FALSE(!status.ok())) {     \
      return status;                            \
    }                                           \
  } while (0)
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_UTIL_STATUS_MACROS_H_

@ -1,122 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_
#include <stddef.h>
#include <tuple>
#include <utility>
namespace distributed_point_functions {

// A thin wrapper around std::tuple<T...> that adds element-wise addition,
// subtraction, and negation operators, plus equality comparison.
template <typename... T>
class Tuple {
 public:
  using Base = std::tuple<T...>;

  // Value-initializes every element.
  constexpr Tuple() : value_() {}

  // Builds a Tuple from its individual elements. Intentionally implicit so
  // that existing brace-initialization call sites keep working. The by-value
  // parameters are moved into the underlying tuple to avoid a second copy.
  constexpr Tuple(T... elements) : value_(std::move(elements)...) {}

  // Wraps an existing std::tuple.
  constexpr explicit Tuple(Base t) : value_(std::move(t)) {}

  // Tuple is copyable and movable. The move operations are declared
  // explicitly because the user-declared copy operations would otherwise
  // suppress them, turning every move into a copy (and making tuples of
  // move-only elements unusable).
  Tuple(const Tuple& t) = default;
  Tuple& operator=(const Tuple& t) = default;
  Tuple(Tuple&& t) = default;
  Tuple& operator=(Tuple&& t) = default;

  // Getters for the base tuple type.
  constexpr Base& value() { return value_; }
  constexpr const Base& value() const { return value_; }

 private:
  Base value_;
};

namespace dpf_internal {

// Implementation of addition and negation. See
// https://stackoverflow.com/a/50815143.
// We declare the templates here, but define them at the end of this header
// because the definitions need to make use of operator+ and operator-.
template <typename... T, std::size_t... I>
constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs,
                          std::index_sequence<I...>);

template <typename... T, std::size_t... I>
constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>);

}  // namespace dpf_internal

// Element-wise sum of `lhs` and `rhs`.
template <typename... T>
constexpr Tuple<T...> operator+(const Tuple<T...>& lhs,
                                const Tuple<T...>& rhs) {
  return dpf_internal::add(lhs, rhs, std::make_index_sequence<sizeof...(T)>{});
}

// Adds `rhs` to `lhs` element-wise, in place.
template <typename... T>
constexpr Tuple<T...>& operator+=(Tuple<T...>& lhs, const Tuple<T...>& rhs) {
  lhs = lhs + rhs;
  return lhs;
}

// Element-wise negation of `t`.
template <typename... T>
constexpr Tuple<T...> operator-(const Tuple<T...>& t) {
  return dpf_internal::negate(t, std::make_index_sequence<sizeof...(T)>{});
}

// Element-wise difference, computed as lhs + (-rhs).
template <typename... T>
constexpr Tuple<T...> operator-(const Tuple<T...>& lhs,
                                const Tuple<T...>& rhs) {
  return lhs + (-rhs);
}

// Subtracts `rhs` from `lhs` element-wise, in place.
template <typename... T>
constexpr Tuple<T...>& operator-=(Tuple<T...>& lhs, const Tuple<T...>& rhs) {
  lhs = lhs - rhs;
  return lhs;
}

// Equality and inequality operators; both defer to std::tuple's comparison.
template <typename... T>
constexpr bool operator==(const Tuple<T...>& lhs, const Tuple<T...>& rhs) {
  return lhs.value() == rhs.value();
}

template <typename... T>
constexpr bool operator!=(const Tuple<T...>& lhs, const Tuple<T...>& rhs) {
  return lhs.value() != rhs.value();
}

namespace dpf_internal {

template <typename... T, std::size_t... I>
constexpr Tuple<T...> add(const Tuple<T...>& lhs, const Tuple<T...>& rhs,
                          std::index_sequence<I...>) {
  return Tuple<T...>{
      // Explicitly cast to T, exactly as negate() does: for integer types
      // narrower than int the sum is promoted to int, and passing that to a
      // braced initializer would be a narrowing conversion.
      T(std::get<I>(lhs.value()) + std::get<I>(rhs.value()))...};
}

template <typename... T, std::size_t... I>
constexpr Tuple<T...> negate(const Tuple<T...>& t, std::index_sequence<I...>) {
  return Tuple<T...>{
      // Explicitly cast to T to avoid -Wnarrowing warnings for small integers.
      T(-std::get<I>(t.value()))...};
}

}  // namespace dpf_internal
}  // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_TUPLE_H_

@ -1,97 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/tuple.h"
#include <tuple>
#include "absl/numeric/int128.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace {
// Tuple type shared by all tests in this file: one signed integer, one
// floating-point, and one 128-bit unsigned element.
using T = Tuple<int, double, absl::uint128>;

// operator+ adds the operands element by element.
TEST(TupleTest, TestAddition) {
  T lhs(std::make_tuple(1, 2, 3));
  T rhs(std::make_tuple(4, 5, 6));

  T sum = lhs + rhs;

  auto& l = lhs.value();
  auto& r = rhs.value();
  auto& s = sum.value();
  EXPECT_EQ(std::get<0>(s), std::get<0>(l) + std::get<0>(r));
  EXPECT_EQ(std::get<1>(s), std::get<1>(l) + std::get<1>(r));
  EXPECT_EQ(std::get<2>(s), std::get<2>(l) + std::get<2>(r));
}
// operator+= adds the right-hand operand element by element, in place.
TEST(TupleTest, TestAdditionInplace) {
  T accumulator(std::make_tuple(1, 2, 3));
  T addend(std::make_tuple(4, 5, 6));
  T before = accumulator;  // Snapshot taken before the in-place update.

  accumulator += addend;

  auto& after = accumulator.value();
  auto& old = before.value();
  auto& added = addend.value();
  EXPECT_EQ(std::get<0>(after), std::get<0>(old) + std::get<0>(added));
  EXPECT_EQ(std::get<1>(after), std::get<1>(old) + std::get<1>(added));
  EXPECT_EQ(std::get<2>(after), std::get<2>(old) + std::get<2>(added));
}
// Binary operator- subtracts the operands element by element.
TEST(TupleTest, TestSubtraction) {
  T lhs(std::make_tuple(1, 2, 3));
  T rhs(std::make_tuple(4, 5, 6));

  T difference = lhs - rhs;

  auto& l = lhs.value();
  auto& r = rhs.value();
  auto& d = difference.value();
  EXPECT_EQ(std::get<0>(d), std::get<0>(l) - std::get<0>(r));
  EXPECT_EQ(std::get<1>(d), std::get<1>(l) - std::get<1>(r));
  EXPECT_EQ(std::get<2>(d), std::get<2>(l) - std::get<2>(r));
}
// operator-= subtracts the right-hand operand element by element, in place.
TEST(TupleTest, TestSubtractionInplace) {
  T accumulator(std::make_tuple(1, 2, 3));
  T subtrahend(std::make_tuple(4, 5, 6));
  T before = accumulator;  // Snapshot taken before the in-place update.

  accumulator -= subtrahend;

  auto& after = accumulator.value();
  auto& old = before.value();
  auto& removed = subtrahend.value();
  EXPECT_EQ(std::get<0>(after), std::get<0>(old) - std::get<0>(removed));
  EXPECT_EQ(std::get<1>(after), std::get<1>(old) - std::get<1>(removed));
  EXPECT_EQ(std::get<2>(after), std::get<2>(old) - std::get<2>(removed));
}
// Unary operator- negates every element.
TEST(TupleTest, TestNegation) {
  T original(std::make_tuple(1, 2, 3));

  T negated = -original;

  auto& in = original.value();
  auto& out = negated.value();
  EXPECT_EQ(std::get<0>(out), -std::get<0>(in));
  EXPECT_EQ(std::get<1>(out), -std::get<1>(in));
  EXPECT_EQ(std::get<2>(out), -std::get<2>(in));
}
} // namespace
} // namespace distributed_point_functions

@ -1,87 +0,0 @@
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_
#include <utility>
namespace distributed_point_functions {

// Adapter around a value of type T for which "addition" and "subtraction" are
// both bitwise XOR. T must support operator^= and operator==.
template <typename T>
class XorWrapper {
 public:
  using WrappedType = T;

  // Value-initializes the wrapped object.
  constexpr XorWrapper() : wrapped_{} {}
  explicit constexpr XorWrapper(T wrapped) : wrapped_(std::move(wrapped)) {}

  // Copy and move operations are all defaulted.
  constexpr XorWrapper(const XorWrapper&) = default;
  constexpr XorWrapper(XorWrapper&&) = default;
  constexpr XorWrapper& operator=(const XorWrapper&) = default;
  constexpr XorWrapper& operator=(XorWrapper&&) = default;

  // Access to the wrapped object.
  constexpr const T& value() const { return wrapped_; }
  constexpr T& value() { return wrapped_; }

  // In-place "addition": XORs `rhs` into this wrapper.
  constexpr XorWrapper& operator+=(const XorWrapper& rhs) {
    wrapped_ ^= rhs.wrapped_;
    return *this;
  }

  // In-place "subtraction" is the same XOR, so it defers to operator+=.
  constexpr XorWrapper& operator-=(const XorWrapper& rhs) {
    return *this += rhs;
  }

 private:
  T wrapped_;
};

// Binary "addition": XOR of the two wrapped values.
template <typename T>
constexpr XorWrapper<T> operator+(XorWrapper<T> a, const XorWrapper<T>& b) {
  a.value() ^= b.value();
  return a;
}

// Binary "subtraction": also XOR of the two wrapped values.
template <typename T>
constexpr XorWrapper<T> operator-(XorWrapper<T> a, const XorWrapper<T>& b) {
  a.value() ^= b.value();
  return a;
}

// Unary minus returns its operand unchanged: with XOR as the group operation,
// every element is its own inverse.
template <typename T>
constexpr XorWrapper<T> operator-(const XorWrapper<T>& a) {
  return a;
}

// Two wrappers are equal exactly when their wrapped values are equal.
template <typename T>
constexpr bool operator==(const XorWrapper<T>& a, const XorWrapper<T>& b) {
  return a.value() == b.value();
}

template <typename T>
constexpr bool operator!=(const XorWrapper<T>& a, const XorWrapper<T>& b) {
  return !(a.value() == b.value());
}

}  // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_XOR_WRAPPER_H_

@ -1,72 +0,0 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dpf/xor_wrapper.h"
#include <stdint.h>
#include "absl/numeric/int128.h"
#include "gtest/gtest.h"
namespace distributed_point_functions {
namespace {
// Typed-test fixture for XorWrapper. It carries no state; it only ties the
// tests below to the list of wrapped types.
template <typename T>
class XorWrapperTest : public testing::Test {};
// Unsigned integer types that the XorWrapper tests are instantiated for.
using XorWrapperTypes =
testing::Types<uint8_t, uint16_t, uint32_t, uint64_t, absl::uint128>;
TYPED_TEST_SUITE(XorWrapperTest, XorWrapperTypes);
// A wrapper constructed from a value exposes that same value via value().
TYPED_TEST(XorWrapperTest, TestConstructor) {
  const TypeParam raw{42};

  const XorWrapper<TypeParam> wrapped(raw);

  EXPECT_EQ(wrapped.value(), raw);
}
// Adding two wrappers XORs their wrapped values.
TYPED_TEST(XorWrapperTest, TestAddition) {
  const TypeParam lhs{42};
  const TypeParam rhs{23};

  const XorWrapper<TypeParam> sum =
      XorWrapper<TypeParam>(lhs) + XorWrapper<TypeParam>(rhs);

  EXPECT_EQ(sum.value(), lhs ^ rhs);
}
// Subtracting two wrappers also XORs their wrapped values.
TYPED_TEST(XorWrapperTest, TestSubtraction) {
  const TypeParam lhs{42};
  const TypeParam rhs{23};

  const XorWrapper<TypeParam> difference =
      XorWrapper<TypeParam>(lhs) - XorWrapper<TypeParam>(rhs);

  EXPECT_EQ(difference.value(), lhs ^ rhs);
}
// Negating a wrapper leaves the wrapped value unchanged.
TYPED_TEST(XorWrapperTest, TestNegation) {
  const TypeParam raw{42};
  const XorWrapper<TypeParam> wrapped(raw);

  const XorWrapper<TypeParam> negated = -wrapped;

  EXPECT_EQ(negated.value(), raw);
}
// Wrappers compare equal exactly when their wrapped values are equal.
TYPED_TEST(XorWrapperTest, TestEquality) {
  TypeParam a{42}, b{23};
  XorWrapper<TypeParam> wrapped_a(a), wrapped_b(b);
  EXPECT_EQ(wrapped_a, XorWrapper<TypeParam>(a));
  EXPECT_NE(wrapped_a, XorWrapper<TypeParam>(b));
  // `wrapped_b` used to be declared but never read; compare the two named
  // wrappers as well so that the variable is actually exercised.
  EXPECT_NE(wrapped_a, wrapped_b);
  EXPECT_EQ(wrapped_b, XorWrapper<TypeParam>(b));
}
} // namespace
} // namespace distributed_point_functions

@ -1,9 +0,0 @@
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Build arguments for the distributed point functions (DPF) library.
declare_args() {
# Whether DPF-dependent sources are included in the build. Defaults to the
# value of is_debug, i.e. on for debug builds and off otherwise.
use_distributed_point_functions = is_debug
# GN directory of the Abseil checkout whose targets DPF code depends on.
dpf_abseil_cpp_dir = "//third_party/abseil-cpp"
# GN directory of the Highway checkout.
# NOTE(review): presumably consumed by the DPF library's own BUILD.gn - confirm.
dpf_highway_cpp_dir = "//third_party/highway"
}

@ -1,293 +0,0 @@
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h"
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <fuzzer/FuzzedDataProvider.h>
#include <algorithm>
#include <memory>
// Aborts the process with a diagnostic (function, file and line) when `x`
// evaluates to false, so that the fuzzer reports the input as a crash.
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement: a bare `if` would capture a following `else` and cannot be
// used safely as the unbraced body of another `if`.
#define DPF_FUZZER_ASSERT(x)                                           \
  do {                                                                 \
    if (!(x)) {                                                        \
      printf("DPF assertion failed: function %s, file %s, line %d.\n", \
             __PRETTY_FUNCTION__, __FILE__, __LINE__);                 \
      abort();                                                         \
    }                                                                  \
  } while (0)
namespace {
// Number of fuzz-input bytes needed to build one 128-bit value.
const size_t UINT128_SIZE = 2 * sizeof(uint64_t);

// Builds a 128-bit unsigned value out of two 64-bit integers taken from
// `data_provider`. The integer consumed first becomes the upper 64 bits.
absl::uint128 ConsumeUint128(FuzzedDataProvider& data_provider) {
  // Consume in two separate statements so that the order is well-defined.
  const uint64_t upper_half = data_provider.ConsumeIntegral<uint64_t>();
  const uint64_t lower_half = data_provider.ConsumeIntegral<uint64_t>();

  absl::uint128 combined = upper_half;
  combined <<= 64;
  combined |= lower_half;
  return combined;
}
// Computes the prefix of `index` that belongs to the domain of
// `hierarchy_level`: `index` is shifted right by the difference between the
// last level's log domain size and that of `hierarchy_level`. When that
// difference is 128 or more, the prefix is 0.
// Adapted from `DpfEvaluationTest::GetPrefixForLevel()`.
absl::uint128 GetPrefixForLevel(
    int hierarchy_level,
    absl::uint128 index,
    const std::vector<distributed_point_functions::DpfParameters>& parameters) {
  const int last_level_bits = parameters.back().log_domain_size();
  const int this_level_bits = parameters[hierarchy_level].log_domain_size();
  const int bits_to_drop = last_level_bits - this_level_bits;

  // Shifting a 128-bit value by its full width or more is not meaningful, so
  // that case is answered directly.
  if (bits_to_drop >= 128) {
    return 0;
  }
  return index >> bits_to_drop;
}
// Evaluates both contexts `ctx0` and `ctx1` at `hierarchy_level`, using the
// appropriate prefixes of `evaluation_points`. Checks that the expansions of
// both keys form correct DPF shares, i.e., they add up to
// `beta[hierarchy_level]` under prefixes of `alpha`, and to 0 otherwise.
// `T` is the integer type of the outputs at this level. Any violated
// expectation aborts the process through DPF_FUZZER_ASSERT. Returns early,
// without checking anything, when a key has more than kMaxCorrectionWords
// correction words.
// Adapted from `DpfEvaluationTest::EvaluateAndCheckLevel()`.
template <typename T>
void EvaluateAndCheckLevel(
int hierarchy_level,
absl::Span<const absl::uint128> evaluation_points,
absl::uint128 alpha,
const std::vector<absl::uint128>& beta,
distributed_point_functions::EvaluationContext& ctx0,
distributed_point_functions::EvaluationContext& ctx1,
const std::vector<distributed_point_functions::DpfParameters>& parameters,
const distributed_point_functions::DistributedPointFunction& dpf) {
int previous_hierarchy_level = ctx0.previous_hierarchy_level();
int current_log_domain_size = parameters[hierarchy_level].log_domain_size();
int previous_log_domain_size = 0;
int num_expansions = 1;
// A negative previous level is treated as "nothing evaluated yet".
bool is_first_evaluation = previous_hierarchy_level < 0;
// Generate prefixes if we're not on the first level.
std::vector<absl::uint128> prefixes;
if (!is_first_evaluation) {
num_expansions = static_cast<int>(evaluation_points.size());
prefixes.resize(evaluation_points.size());
previous_log_domain_size =
parameters[previous_hierarchy_level].log_domain_size();
for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i)
prefixes[i] = GetPrefixForLevel(previous_hierarchy_level,
evaluation_points[i], parameters);
}
// Evaluating a key with N correction words leads to an O(2^N) malloc, which
// will unsurprisingly cause a fuzzer crash. See <https://crbug.com/1494260>.
constexpr size_t kMaxCorrectionWords = 30;
if (ctx0.key().correction_words().size() > kMaxCorrectionWords) {
return;
}
absl::StatusOr<std::vector<T>> result_0 =
dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx0);
DPF_FUZZER_ASSERT(result_0.ok());
if (ctx1.key().correction_words().size() > kMaxCorrectionWords) {
return;
}
absl::StatusOr<std::vector<T>> result_1 =
dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx1);
DPF_FUZZER_ASSERT(result_1.ok());
DPF_FUZZER_ASSERT(result_0->size() == result_1->size());
// Each prefix (or the single root on the first evaluation) expands into
// 2^(current_log_domain_size - previous_log_domain_size) outputs.
int64_t outputs_per_prefix =
int64_t{1} << (current_log_domain_size - previous_log_domain_size);
int64_t expected_output_size = num_expansions * outputs_per_prefix;
DPF_FUZZER_ASSERT(static_cast<int64_t>(result_0->size()) ==
expected_output_size);
// Iterate over the outputs and check that they sum up to 0 or to
// `beta[hierarchy_level]`.
absl::uint128 previous_alpha_prefix = 0;
if (!is_first_evaluation)
previous_alpha_prefix =
GetPrefixForLevel(previous_hierarchy_level, alpha, parameters);
absl::uint128 current_alpha_prefix =
GetPrefixForLevel(hierarchy_level, alpha, parameters);
for (int64_t i = 0; i < expected_output_size; ++i) {
// Split the flat output index into (which prefix, position within that
// prefix's expansion).
int prefix_index = i / outputs_per_prefix;
int prefix_expansion_index = i % outputs_per_prefix;
// The output is on the path to `alpha`, if we're at the first level or
// under a prefix of `alpha`, and the current block in the expansion of the
// prefix is also on the path to `alpha`.
if ((is_first_evaluation ||
prefixes[prefix_index] == previous_alpha_prefix) &&
prefix_expansion_index == (current_alpha_prefix % outputs_per_prefix)) {
// We need to static_cast here since otherwise operator+ returns an
// unsigned int without doing a modular reduction, which causes the test
// to fail on types with sizeof(T) < sizeof(unsigned).
DPF_FUZZER_ASSERT(
absl::uint128{static_cast<T>((*result_0)[i] + (*result_1)[i])} ==
beta[hierarchy_level]);
} else {
DPF_FUZZER_ASSERT(static_cast<T>((*result_0)[i] + (*result_1)[i]) == 0U);
}
}
}
} // namespace
// libFuzzer entry point. Interprets the input as DPF parameters plus
// evaluation points, generates a key pair, and checks through
// EvaluateAndCheckLevel() that evaluating both keys yields correct shares.
// Always returns 0; inputs that do not describe a usable DPF are ignored, and
// violated expectations abort the process through DPF_FUZZER_ASSERT.
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
// Use magic separator to split the input into two parts. The first part will
// generate alpha, and an array of parameters and betas. The second part will
// generate level step and an array of evaluation points.
const uint8_t separator[] = {0xDE, 0xAD, 0xBE, 0xEF};
// std::search returns `data + size` when the separator does not occur; the
// second part is then empty (null pointer, size 0).
const uint8_t* pos =
std::search(data, data + size, separator, separator + sizeof(separator));
const uint8_t* data1 = data;
size_t size1 = pos - data;
const uint8_t* data2 =
(pos == data + size) ? nullptr : pos + sizeof(separator);
size_t size2 = data2 ? (data + size) - (pos + sizeof(separator)) : 0;
FuzzedDataProvider data_provider1(data1, size1);
if (data_provider1.remaining_bytes() < UINT128_SIZE)
return 0;
absl::uint128 alpha = ConsumeUint128(data_provider1);
// The raw sizes are kept in separate vectors so that they can be sorted
// independently if the first construction attempt fails (see below).
std::vector<int32_t> log_domain_sizes;
std::vector<int32_t> element_bitsizes;
std::vector<distributed_point_functions::DpfParameters> parameters;
std::vector<absl::uint128> beta;
// Each hierarchy level is described by, in this order:
// log_domain_size(int32_t), element_bitsize(int32_t),
// beta(uint128)
while (data_provider1.remaining_bytes() >=
(2 * sizeof(int32_t) + UINT128_SIZE)) {
int32_t log_domain_size = data_provider1.ConsumeIntegral<int32_t>();
int32_t element_bitsize = data_provider1.ConsumeIntegral<int32_t>();
log_domain_sizes.push_back(log_domain_size);
element_bitsizes.push_back(element_bitsize);
distributed_point_functions::DpfParameters parameter;
parameter.set_log_domain_size(log_domain_size);
parameter.mutable_value_type()->mutable_integer()->set_bitsize(
element_bitsize);
parameters.push_back(parameter);
beta.push_back(ConsumeUint128(data_provider1));
}
absl::StatusOr<
std::unique_ptr<distributed_point_functions::DistributedPointFunction>>
status_or_dpf = distributed_point_functions::DistributedPointFunction::
CreateIncremental(parameters);
size_t num_levels = parameters.size();
if (!status_or_dpf.ok()) {
// `log_domain_size` is expected to be in ascending order and
// `element_bitsize` is expected to be non-decreasing. As it is hard for the
// fuzzer to land upon a valid input, we sort the parameters and try again
// if the construction fails.
std::sort(log_domain_sizes.begin(), log_domain_sizes.end());
std::sort(element_bitsizes.begin(), element_bitsizes.end());
for (size_t i = 0; i < num_levels; ++i) {
parameters[i].set_log_domain_size(log_domain_sizes[i]);
parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
element_bitsizes[i]);
}
status_or_dpf = distributed_point_functions::DistributedPointFunction::
CreateIncremental(parameters);
}
// Still invalid after sorting: nothing to test for this input.
if (!status_or_dpf.ok())
return 0;
std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf =
std::move(status_or_dpf).value();
absl::StatusOr<std::pair<distributed_point_functions::DpfKey,
distributed_point_functions::DpfKey>>
status_or_keys = dpf->GenerateKeysIncremental(alpha, beta);
// Key generation may legitimately reject the fuzzed alpha/beta values.
if (!status_or_keys.ok())
return 0;
std::pair<distributed_point_functions::DpfKey,
distributed_point_functions::DpfKey>
keys = std::move(status_or_keys).value();
// Adapted from `DpfEvaluationTest.TestCorrectness()`.
// From here on the keys are known to be valid, so failures are treated as
// bugs in the library and abort the process.
absl::StatusOr<distributed_point_functions::EvaluationContext>
status_or_ctx0 = dpf->CreateEvaluationContext(keys.first);
DPF_FUZZER_ASSERT(status_or_ctx0.ok());
absl::StatusOr<distributed_point_functions::EvaluationContext>
status_or_ctx1 = dpf->CreateEvaluationContext(keys.second);
DPF_FUZZER_ASSERT(status_or_ctx1.ok());
distributed_point_functions::EvaluationContext ctx0 =
std::move(status_or_ctx0).value();
distributed_point_functions::EvaluationContext ctx1 =
std::move(status_or_ctx1).value();
// Generate evaluation points.
FuzzedDataProvider data_provider2(data2, size2);
if (data_provider2.remaining_bytes() < sizeof(int))
return 0;
// Only every `level_step`-th hierarchy level is evaluated (see loop below).
int level_step = data_provider2.ConsumeIntegralInRange<int>(1, 10);
std::vector<absl::uint128> evaluation_points;
while (data_provider2.remaining_bytes() >= UINT128_SIZE) {
evaluation_points.push_back(ConsumeUint128(data_provider2));
// Reduce each point into the domain of the last hierarchy level. For a
// log domain size of 128 or more the shift below would not be valid, and
// every 128-bit value is already in range.
if (parameters.back().log_domain_size() < 128)
evaluation_points.back() %=
(absl::uint128{1} << parameters.back().log_domain_size());
}
// Always evaluate on alpha.
evaluation_points.push_back(alpha);
int32_t previous_log_domain_size = 0;
for (int i = level_step - 1; i < static_cast<int>(num_levels);
i += level_step) {
// If any gap in the log_domain_sizes used in successive evaluations is
// larger than 62, validation will fail in `EvaluateAndCheckLevel`.
int32_t current_log_domain_size = parameters[i].log_domain_size();
if (current_log_domain_size - previous_log_domain_size > 62)
return 0;
previous_log_domain_size = current_log_domain_size;
// Dispatch on the element bit size to pick the matching output type.
switch (parameters[i].value_type().integer().bitsize()) {
case 8:
EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, alpha, beta, ctx0,
ctx1, parameters, *dpf);
break;
case 16:
EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, alpha, beta, ctx0,
ctx1, parameters, *dpf);
break;
case 32:
EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, alpha, beta, ctx0,
ctx1, parameters, *dpf);
break;
case 64:
EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, alpha, beta, ctx0,
ctx1, parameters, *dpf);
break;
case 128:
EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, alpha, beta,
ctx0, ctx1, parameters, *dpf);
break;
default:
// DPF construction should've failed if the parameters were invalid.
DPF_FUZZER_ASSERT(false);
break;
}
}
return 0;
}

@ -1,48 +0,0 @@
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import("//build/buildflag_header.gni")
import("//testing/test.gni")
import("//third_party/distributed_point_functions/features.gni")
# Chromium-side wrapper ("shim") around the distributed point functions
# library. The build flag header is always exported; the shim sources and
# their dependencies are only added when use_distributed_point_functions is
# true.
source_set("shim") {
public_deps = [ ":buildflags" ]
if (use_distributed_point_functions) {
sources = [
"distributed_point_function_shim.cc",
"distributed_point_function_shim.h",
]
deps = [
"$dpf_abseil_cpp_dir:absl",
"//base",
"//third_party/distributed_point_functions:internal",
]
# The shim header includes the generated proto header, so the proto target
# is a public dependency.
public_deps += [ "//third_party/distributed_point_functions:proto" ]
configs += [ "//third_party/distributed_point_functions:includes" ]
}
}
# External targets may depend on :buildflags directly without pulling in
# :distributed_point_functions. For instance, tests may set different
# expectations when the dpf library is omitted from the build.
# The generated header defines USE_DISTRIBUTED_POINT_FUNCTIONS, mirroring the
# use_distributed_point_functions build argument.
buildflag_header("buildflags") {
header = "buildflags.h"
flags = [ "USE_DISTRIBUTED_POINT_FUNCTIONS=$use_distributed_point_functions" ]
}
# Unit tests for the shim. The target always exists; when
# use_distributed_point_functions is false it has no sources of its own and
# only links the gtest main.
test("distributed_point_functions_shim_unittests") {
deps = [
"//testing/gtest",
"//testing/gtest:gtest_main",
]
if (use_distributed_point_functions) {
sources = [ "distributed_point_function_shim_unittest.cc" ]
deps += [
":shim",
"$dpf_abseil_cpp_dir:absl",
"//third_party/protobuf:protobuf_lite",
]
}
}

@ -1,3 +0,0 @@
# Include rules for this directory: in addition to whatever is inherited from
# parent directories, code here may include headers from //base.
include_rules = [
"+base",
]

@ -1,42 +0,0 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "base/check_op.h"
#include "base/logging.h"
#include "third_party/abseil-cpp/absl/numeric/int128.h"
#include "third_party/abseil-cpp/absl/status/status.h"
#include "third_party/abseil-cpp/absl/status/statusor.h"
#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h"
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"
#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h"
namespace distributed_point_functions {
std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
std::vector<DpfParameters> parameters,
absl::uint128 alpha,
std::vector<absl::uint128> beta) {
// absl::StatusOr is not allowed in the codebase, but this minimal usage is
// necessary to interact with //third_party/distributed_point_functions/.
absl::StatusOr<std::unique_ptr<DistributedPointFunction>> dpf_result =
DistributedPointFunction::CreateIncremental(std::move(parameters));
if (!dpf_result.ok()) {
LOG(ERROR) << "CreateIncremental() failed: " << dpf_result.status();
return std::nullopt;
}
CHECK_NE(*dpf_result, nullptr);
absl::StatusOr<std::pair<DpfKey, DpfKey>> keys_result =
(*dpf_result)->GenerateKeysIncremental(alpha, std::move(beta));
if (!keys_result.ok()) {
LOG(ERROR) << "GenerateKeysIncremental() failed: " << keys_result.status();
return std::nullopt;
}
return std::move(*keys_result);
}
} // namespace distributed_point_functions

@ -1,32 +0,0 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// The include guard is derived from this header's path,
// third_party/distributed_point_functions/shim/distributed_point_function_shim.h.
// It previously still named content/browser/aggregation_service.
#ifndef THIRD_PARTY_DISTRIBUTED_POINT_FUNCTIONS_SHIM_DISTRIBUTED_POINT_FUNCTION_SHIM_H_
#define THIRD_PARTY_DISTRIBUTED_POINT_FUNCTIONS_SHIM_DISTRIBUTED_POINT_FUNCTION_SHIM_H_

#include "third_party/distributed_point_functions/shim/buildflags.h"

// Fail at compile time, with a readable message, if this header is pulled
// into a build that omits the library.
static_assert(BUILDFLAG(USE_DISTRIBUTED_POINT_FUNCTIONS),
              "This header must not be included when "
              "distributed_point_functions is omitted from the build");

#include <optional>
#include <utility>
#include <vector>

#include "third_party/abseil-cpp/absl/numeric/int128.h"
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"

namespace distributed_point_functions {

// Generates a pair of keys for a DPF that evaluates to `beta` when given
// `alpha`. `parameters` describes the DPF to construct, and `beta` must have
// as many entries as `parameters`. On failure, returns std::nullopt.
std::optional<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
    std::vector<DpfParameters> parameters,
    absl::uint128 alpha,
    std::vector<absl::uint128> beta);

}  // namespace distributed_point_functions

#endif  // THIRD_PARTY_DISTRIBUTED_POINT_FUNCTIONS_SHIM_DISTRIBUTED_POINT_FUNCTION_SHIM_H_

@ -1,52 +0,0 @@
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stddef.h>
#include <optional>
#include <utility>
#include <vector>
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/numeric/int128.h"
#include "third_party/distributed_point_functions/dpf/distributed_point_function.pb.h"
#include "third_party/distributed_point_functions/shim/distributed_point_function_shim.h"
namespace distributed_point_functions {
// Success path: given well-formed parameters, the shim's
// GenerateKeysIncremental() returns a key pair rather than std::nullopt.
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalConstructsKeys) {
  constexpr size_t kBitLength = 32;

  // One parameter set per level: the first level has a log domain size of 1,
  // each following level is one larger, and every level uses integer values
  // whose bitsize equals the number of levels.
  std::vector<DpfParameters> parameters(kBitLength);
  size_t log_domain_size = 0;
  for (DpfParameters& level : parameters) {
    level.set_log_domain_size(++log_domain_size);
    level.mutable_value_type()->mutable_integer()->set_bitsize(kBitLength);
  }

  // One beta value per level.
  std::vector<absl::uint128> beta(kBitLength, absl::uint128{1});

  std::optional<std::pair<DpfKey, DpfKey>> maybe_dpf_keys =
      GenerateKeysIncremental(std::move(parameters),
                              /*alpha=*/absl::uint128{1}, std::move(beta));
  EXPECT_TRUE(maybe_dpf_keys.has_value());
}
// Failure in DistributedPointFunction::CreateIncremental() (exercised here with
// an empty parameter list) must surface from the shim's
// GenerateKeysIncremental() as std::nullopt.
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalEmptyParameters) {
  const std::optional<std::pair<DpfKey, DpfKey>> maybe_dpf_keys =
      GenerateKeysIncremental(/*parameters=*/{},
                              /*alpha=*/absl::uint128{},
                              /*beta=*/{});
  EXPECT_FALSE(maybe_dpf_keys.has_value());
}
// When the length of beta does not match the number of parameters, the internal
// call to DistributedPointFunction::GenerateKeysIncremental() will fail, and
// the shim's GenerateKeysIncremental() should return std::nullopt.
TEST(DistributedPointFunctionShimTest, GenerateKeysIncrementalBetaWrongSize) {
  constexpr size_t kNumParameters = 3;
  constexpr size_t kValueBitsize = 32;

  // Populate the parameters the same way as the
  // GenerateKeysIncrementalConstructsKeys test (increasing log domain sizes,
  // 32-bit integer values) so that CreateIncremental() is expected to succeed
  // and the failure under test comes from the key-generation step.
  // NOTE(review): validity of this 3-level configuration is inferred from the
  // 32-level success test; confirm against the DPF library's validation rules.
  std::vector<DpfParameters> parameters(kNumParameters);
  for (size_t i = 0; i < parameters.size(); ++i) {
    parameters[i].set_log_domain_size(i + 1);
    parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
        kValueBitsize);
  }

  // Deliberately one element short of `parameters`. The previous version
  // passed three values for three parameters, so the sizes actually matched.
  std::vector<absl::uint128> beta(kNumParameters - 1, absl::uint128{1});

  EXPECT_FALSE(GenerateKeysIncremental(std::move(parameters),
                                       /*alpha=*/absl::uint128{1},
                                       std::move(beta)));
}
} // namespace distributed_point_functions

@ -1 +1,2 @@
file://third_party/distributed_point_functions/OWNERS bikineev@chromium.org
file://third_party/blink/renderer/core/html/parser/OWNERS