webnn: fuse QDQ for element-wise max and min on tflite
Inputs and outputs must all have the same scale and zero_point [1]. [1] https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893 Bug: 401281047 Change-Id: Iff6d2b320c2e3d89ee49bb804012908c7c6c706f Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6538847 Reviewed-by: ningxin hu <ningxin.hu@intel.com> Reviewed-by: Phillis Tang <phillis@chromium.org> Commit-Queue: Junwei Fu <junwei.fu@intel.com> Cr-Commit-Position: refs/heads/main@{#1459861}
This commit is contained in:
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
@ -1589,10 +1589,8 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
||||
// For XNNPack delegate, there are some restriction for inputs and output
|
||||
// scale value.
|
||||
auto checked_lhs_scale_value =
|
||||
base::MakeCheckedNum<float>(lhs_scale_values[0]);
|
||||
auto checked_rhs_scale_value =
|
||||
base::MakeCheckedNum<float>(rhs_scale_values[0]);
|
||||
const float lhs_scale_value = lhs_scale_values[0];
|
||||
const float rhs_scale_value = rhs_scale_values[0];
|
||||
const float output_scale_value = output_scale_values[0];
|
||||
if (binary.kind == mojom::ElementWiseBinary::Kind::kAdd ||
|
||||
binary.kind == mojom::ElementWiseBinary::Kind::kSub) {
|
||||
@ -1602,14 +1600,15 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
const float scale_max = 256.0f;
|
||||
|
||||
auto checked_lhs_output_scale =
|
||||
checked_lhs_scale_value / output_scale_value;
|
||||
base::MakeCheckedNum<float>(lhs_scale_value);
|
||||
checked_lhs_output_scale /= output_scale_value;
|
||||
if (!checked_lhs_output_scale.IsValid() ||
|
||||
checked_lhs_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_lhs_output_scale.ValueOrDie() >= scale_max) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto checked_rhs_output_scale =
|
||||
checked_rhs_scale_value / output_scale_value;
|
||||
base::MakeCheckedNum<float>(rhs_scale_value) / output_scale_value;
|
||||
if (!checked_rhs_output_scale.IsValid() ||
|
||||
checked_rhs_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_rhs_output_scale.ValueOrDie() >= scale_max) {
|
||||
@ -1621,13 +1620,31 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
const float scale_min = 1.0f / 65536.0f;
|
||||
const float scale_max = 256.0f;
|
||||
auto checked_product_output_scale =
|
||||
(checked_lhs_scale_value * checked_rhs_scale_value) /
|
||||
(lhs_scale_value * base::MakeCheckedNum<float>(rhs_scale_value)) /
|
||||
output_scale_value;
|
||||
if (!checked_product_output_scale.IsValid() ||
|
||||
checked_product_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_product_output_scale.ValueOrDie() >= scale_max) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else if (binary.kind == mojom::ElementWiseBinary::Kind::kMax ||
|
||||
binary.kind == mojom::ElementWiseBinary::Kind::kMin) {
|
||||
// Inputs and output must have the same scale and zero_point.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893
|
||||
if (lhs_scale_value != rhs_scale_value ||
|
||||
lhs_scale_value != output_scale_value) {
|
||||
return std::nullopt;
|
||||
}
|
||||
base::FixedArray<int64_t> lhs_zero_point_values =
|
||||
GetConstantInt64Value(lhs_dequantize.zero_point_operand_id);
|
||||
base::FixedArray<int64_t> rhs_zero_point_values =
|
||||
GetConstantInt64Value(rhs_dequantize.zero_point_operand_id);
|
||||
base::FixedArray<int64_t> output_zero_point_values =
|
||||
GetConstantInt64Value(output_quantize.zero_point_operand_id);
|
||||
if (!std::ranges::equal(lhs_zero_point_values, rhs_zero_point_values) ||
|
||||
!std::ranges::equal(lhs_zero_point_values, output_zero_point_values)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else {
|
||||
NOTREACHED() << "Unsupported quantize operators";
|
||||
}
|
||||
@ -3339,11 +3356,13 @@ auto GraphBuilderTflite::SerializeElementWiseBinary(
|
||||
CHECK(context_properties_.data_type_limits.max_input.SupportsAll(
|
||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||
code = ::tflite::BuiltinOperator_MAXIMUM;
|
||||
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||
break;
|
||||
case mojom::ElementWiseBinary::Kind::kMin:
|
||||
CHECK(context_properties_.data_type_limits.min_input.SupportsAll(
|
||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||
code = ::tflite::BuiltinOperator_MINIMUM;
|
||||
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||
break;
|
||||
case mojom::ElementWiseBinary::Kind::kPow:
|
||||
CHECK(context_properties_.data_type_limits.pow_input.SupportsAll(
|
||||
|
Reference in New Issue
Block a user