0

webnn: fuse QDQ for element-wise max and min on tflite

Inputs and output must all have the same scale and zero_point [1].

[1] https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893

Bug: 401281047
Change-Id: Iff6d2b320c2e3d89ee49bb804012908c7c6c706f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6538847
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Phillis Tang <phillis@chromium.org>
Commit-Queue: Junwei Fu <junwei.fu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1459861}
This commit is contained in:
junwei
2025-05-13 20:06:32 -07:00
committed by Chromium LUCI CQ
parent c9bc0ed04e
commit 0a11afc76a
2 changed files with 242 additions and 7 deletions
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests

@ -1589,10 +1589,8 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
// For XNNPack delegate, there are some restriction for inputs and output
// scale value.
auto checked_lhs_scale_value =
base::MakeCheckedNum<float>(lhs_scale_values[0]);
auto checked_rhs_scale_value =
base::MakeCheckedNum<float>(rhs_scale_values[0]);
const float lhs_scale_value = lhs_scale_values[0];
const float rhs_scale_value = rhs_scale_values[0];
const float output_scale_value = output_scale_values[0];
if (binary.kind == mojom::ElementWiseBinary::Kind::kAdd ||
binary.kind == mojom::ElementWiseBinary::Kind::kSub) {
@ -1602,14 +1600,15 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
const float scale_max = 256.0f;
auto checked_lhs_output_scale =
checked_lhs_scale_value / output_scale_value;
base::MakeCheckedNum<float>(lhs_scale_value);
checked_lhs_output_scale /= output_scale_value;
if (!checked_lhs_output_scale.IsValid() ||
checked_lhs_output_scale.ValueOrDie() < scale_min ||
checked_lhs_output_scale.ValueOrDie() >= scale_max) {
return std::nullopt;
}
auto checked_rhs_output_scale =
checked_rhs_scale_value / output_scale_value;
base::MakeCheckedNum<float>(rhs_scale_value) / output_scale_value;
if (!checked_rhs_output_scale.IsValid() ||
checked_rhs_output_scale.ValueOrDie() < scale_min ||
checked_rhs_output_scale.ValueOrDie() >= scale_max) {
@ -1621,13 +1620,31 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
const float scale_min = 1.0f / 65536.0f;
const float scale_max = 256.0f;
auto checked_product_output_scale =
(checked_lhs_scale_value * checked_rhs_scale_value) /
(lhs_scale_value * base::MakeCheckedNum<float>(rhs_scale_value)) /
output_scale_value;
if (!checked_product_output_scale.IsValid() ||
checked_product_output_scale.ValueOrDie() < scale_min ||
checked_product_output_scale.ValueOrDie() >= scale_max) {
return std::nullopt;
}
} else if (binary.kind == mojom::ElementWiseBinary::Kind::kMax ||
binary.kind == mojom::ElementWiseBinary::Kind::kMin) {
// Inputs and output must have the same scale and zero_point.
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893
if (lhs_scale_value != rhs_scale_value ||
lhs_scale_value != output_scale_value) {
return std::nullopt;
}
base::FixedArray<int64_t> lhs_zero_point_values =
GetConstantInt64Value(lhs_dequantize.zero_point_operand_id);
base::FixedArray<int64_t> rhs_zero_point_values =
GetConstantInt64Value(rhs_dequantize.zero_point_operand_id);
base::FixedArray<int64_t> output_zero_point_values =
GetConstantInt64Value(output_quantize.zero_point_operand_id);
if (!std::ranges::equal(lhs_zero_point_values, rhs_zero_point_values) ||
!std::ranges::equal(lhs_zero_point_values, output_zero_point_values)) {
return std::nullopt;
}
} else {
NOTREACHED() << "Unsupported quantize operators";
}
@ -3339,11 +3356,13 @@ auto GraphBuilderTflite::SerializeElementWiseBinary(
CHECK(context_properties_.data_type_limits.max_input.SupportsAll(
{lhs_operand_descriptor, rhs_operand_descriptor}));
code = ::tflite::BuiltinOperator_MAXIMUM;
quantized_output = CanFuseQuantizeAndGetOutput(op);
break;
case mojom::ElementWiseBinary::Kind::kMin:
CHECK(context_properties_.data_type_limits.min_input.SupportsAll(
{lhs_operand_descriptor, rhs_operand_descriptor}));
code = ::tflite::BuiltinOperator_MINIMUM;
quantized_output = CanFuseQuantizeAndGetOutput(op);
break;
case mojom::ElementWiseBinary::Kind::kPow:
CHECK(context_properties_.data_type_limits.pow_input.SupportsAll(