webnn: fuse QDQ for element-wise max and min on tflite
Input and outputs must all have same scale and zero_point [1]. [1] https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893 Bug: 401281047 Change-Id: Iff6d2b320c2e3d89ee49bb804012908c7c6c706f Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6538847 Reviewed-by: ningxin hu <ningxin.hu@intel.com> Reviewed-by: Phillis Tang <phillis@chromium.org> Commit-Queue: Junwei Fu <junwei.fu@intel.com> Cr-Commit-Position: refs/heads/main@{#1459861}
This commit is contained in:
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
@ -1589,10 +1589,8 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
|||||||
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
||||||
// For XNNPack delegate, there are some restriction for inputs and output
|
// For XNNPack delegate, there are some restriction for inputs and output
|
||||||
// scale value.
|
// scale value.
|
||||||
auto checked_lhs_scale_value =
|
const float lhs_scale_value = lhs_scale_values[0];
|
||||||
base::MakeCheckedNum<float>(lhs_scale_values[0]);
|
const float rhs_scale_value = rhs_scale_values[0];
|
||||||
auto checked_rhs_scale_value =
|
|
||||||
base::MakeCheckedNum<float>(rhs_scale_values[0]);
|
|
||||||
const float output_scale_value = output_scale_values[0];
|
const float output_scale_value = output_scale_values[0];
|
||||||
if (binary.kind == mojom::ElementWiseBinary::Kind::kAdd ||
|
if (binary.kind == mojom::ElementWiseBinary::Kind::kAdd ||
|
||||||
binary.kind == mojom::ElementWiseBinary::Kind::kSub) {
|
binary.kind == mojom::ElementWiseBinary::Kind::kSub) {
|
||||||
@ -1602,14 +1600,15 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
|||||||
const float scale_max = 256.0f;
|
const float scale_max = 256.0f;
|
||||||
|
|
||||||
auto checked_lhs_output_scale =
|
auto checked_lhs_output_scale =
|
||||||
checked_lhs_scale_value / output_scale_value;
|
base::MakeCheckedNum<float>(lhs_scale_value);
|
||||||
|
checked_lhs_output_scale /= output_scale_value;
|
||||||
if (!checked_lhs_output_scale.IsValid() ||
|
if (!checked_lhs_output_scale.IsValid() ||
|
||||||
checked_lhs_output_scale.ValueOrDie() < scale_min ||
|
checked_lhs_output_scale.ValueOrDie() < scale_min ||
|
||||||
checked_lhs_output_scale.ValueOrDie() >= scale_max) {
|
checked_lhs_output_scale.ValueOrDie() >= scale_max) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
auto checked_rhs_output_scale =
|
auto checked_rhs_output_scale =
|
||||||
checked_rhs_scale_value / output_scale_value;
|
base::MakeCheckedNum<float>(rhs_scale_value) / output_scale_value;
|
||||||
if (!checked_rhs_output_scale.IsValid() ||
|
if (!checked_rhs_output_scale.IsValid() ||
|
||||||
checked_rhs_output_scale.ValueOrDie() < scale_min ||
|
checked_rhs_output_scale.ValueOrDie() < scale_min ||
|
||||||
checked_rhs_output_scale.ValueOrDie() >= scale_max) {
|
checked_rhs_output_scale.ValueOrDie() >= scale_max) {
|
||||||
@ -1621,13 +1620,31 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
|||||||
const float scale_min = 1.0f / 65536.0f;
|
const float scale_min = 1.0f / 65536.0f;
|
||||||
const float scale_max = 256.0f;
|
const float scale_max = 256.0f;
|
||||||
auto checked_product_output_scale =
|
auto checked_product_output_scale =
|
||||||
(checked_lhs_scale_value * checked_rhs_scale_value) /
|
(lhs_scale_value * base::MakeCheckedNum<float>(rhs_scale_value)) /
|
||||||
output_scale_value;
|
output_scale_value;
|
||||||
if (!checked_product_output_scale.IsValid() ||
|
if (!checked_product_output_scale.IsValid() ||
|
||||||
checked_product_output_scale.ValueOrDie() < scale_min ||
|
checked_product_output_scale.ValueOrDie() < scale_min ||
|
||||||
checked_product_output_scale.ValueOrDie() >= scale_max) {
|
checked_product_output_scale.ValueOrDie() >= scale_max) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
} else if (binary.kind == mojom::ElementWiseBinary::Kind::kMax ||
|
||||||
|
binary.kind == mojom::ElementWiseBinary::Kind::kMin) {
|
||||||
|
// Inputs and output must have the same scale and zero_point.
|
||||||
|
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893
|
||||||
|
if (lhs_scale_value != rhs_scale_value ||
|
||||||
|
lhs_scale_value != output_scale_value) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
base::FixedArray<int64_t> lhs_zero_point_values =
|
||||||
|
GetConstantInt64Value(lhs_dequantize.zero_point_operand_id);
|
||||||
|
base::FixedArray<int64_t> rhs_zero_point_values =
|
||||||
|
GetConstantInt64Value(rhs_dequantize.zero_point_operand_id);
|
||||||
|
base::FixedArray<int64_t> output_zero_point_values =
|
||||||
|
GetConstantInt64Value(output_quantize.zero_point_operand_id);
|
||||||
|
if (!std::ranges::equal(lhs_zero_point_values, rhs_zero_point_values) ||
|
||||||
|
!std::ranges::equal(lhs_zero_point_values, output_zero_point_values)) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
NOTREACHED() << "Unsupported quantize operators";
|
NOTREACHED() << "Unsupported quantize operators";
|
||||||
}
|
}
|
||||||
@ -3339,11 +3356,13 @@ auto GraphBuilderTflite::SerializeElementWiseBinary(
|
|||||||
CHECK(context_properties_.data_type_limits.max_input.SupportsAll(
|
CHECK(context_properties_.data_type_limits.max_input.SupportsAll(
|
||||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||||
code = ::tflite::BuiltinOperator_MAXIMUM;
|
code = ::tflite::BuiltinOperator_MAXIMUM;
|
||||||
|
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||||
break;
|
break;
|
||||||
case mojom::ElementWiseBinary::Kind::kMin:
|
case mojom::ElementWiseBinary::Kind::kMin:
|
||||||
CHECK(context_properties_.data_type_limits.min_input.SupportsAll(
|
CHECK(context_properties_.data_type_limits.min_input.SupportsAll(
|
||||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||||
code = ::tflite::BuiltinOperator_MINIMUM;
|
code = ::tflite::BuiltinOperator_MINIMUM;
|
||||||
|
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||||
break;
|
break;
|
||||||
case mojom::ElementWiseBinary::Kind::kPow:
|
case mojom::ElementWiseBinary::Kind::kPow:
|
||||||
CHECK(context_properties_.data_type_limits.pow_input.SupportsAll(
|
CHECK(context_properties_.data_type_limits.pow_input.SupportsAll(
|
||||||
|
216
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
216
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
@ -474,6 +474,222 @@ const subgraphTests = [
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
'name': 'quantized element-wise binary max',
|
||||||
|
'graph': {
|
||||||
|
'inputs': {
|
||||||
|
'inputA': {
|
||||||
|
'data': [
|
||||||
|
-2.549168109893799, -4.794857501983643,
|
||||||
|
8.413617134094238, 6.108623504638672
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'float32'},
|
||||||
|
'constant': false
|
||||||
|
},
|
||||||
|
'inputAScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputAZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputB': {
|
||||||
|
'data': [
|
||||||
|
12, 24, 35, 11,
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputBScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputBZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'outputScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'outputZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'operators': [
|
||||||
|
{
|
||||||
|
'name': 'quantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'inputA'},
|
||||||
|
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'quantizedInputA'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'quantizedInputA'},
|
||||||
|
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'dequantizedInputA'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'inputB'},
|
||||||
|
{'scale': 'inputBScale', 'zeroPoint': 'inputBZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'dequantizedInputB'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'max',
|
||||||
|
'arguments': [{'inputA': 'dequantizedInputA'}, {'inputB': 'dequantizedInputB'}],
|
||||||
|
'outputs': 'maxOutput'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'quantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'maxOutput'},
|
||||||
|
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'quantizedMaxOutput'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'quantizedMaxOutput'},
|
||||||
|
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'output'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'expectedOutputs': {
|
||||||
|
'output': {
|
||||||
|
'data': [
|
||||||
|
48.03300857543945, 52.150123596191406,
|
||||||
|
55.92414474487305, 47.68991470336914,
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'float32'}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'quantized element-wise binary min',
|
||||||
|
'graph': {
|
||||||
|
'inputs': {
|
||||||
|
'inputA': {
|
||||||
|
'data': [
|
||||||
|
3.549168109893799, 4.794857501983643,
|
||||||
|
8.413617134094238, 6.108623504638672
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'float32'},
|
||||||
|
'constant': false
|
||||||
|
},
|
||||||
|
'inputAScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputAZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputB': {
|
||||||
|
'data': [
|
||||||
|
12, 24, 35, 11,
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputBScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'inputBZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'outputScale': {
|
||||||
|
'data': [0.343092918395996],
|
||||||
|
'descriptor': {shape: [1], dataType: 'float32'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
'outputZeroPoint': {
|
||||||
|
'data': [-128],
|
||||||
|
'descriptor': {shape: [1], dataType: 'int8'},
|
||||||
|
'constant': true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'operators': [
|
||||||
|
{
|
||||||
|
'name': 'quantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'inputA'},
|
||||||
|
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'quantizedInputA'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'quantizedInputA'},
|
||||||
|
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'dequantizedInputA'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'inputB'},
|
||||||
|
{'scale': 'inputBScale', 'zeroPoint': 'inputBZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'dequantizedInputB'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'min',
|
||||||
|
'arguments': [{'inputA': 'dequantizedInputA'}, {'inputB': 'dequantizedInputB'}],
|
||||||
|
'outputs': 'minOutput'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'quantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'minOutput'},
|
||||||
|
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'quantizedMinOutput'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'dequantizeLinear',
|
||||||
|
'arguments': [
|
||||||
|
{'input': 'quantizedMinOutput'},
|
||||||
|
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||||
|
],
|
||||||
|
'outputs': 'output'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'expectedOutputs': {
|
||||||
|
'output': {
|
||||||
|
'data': [
|
||||||
|
3.430929183959961, 4.803300857543945,
|
||||||
|
8.577322959899902, 6.17567253112793,
|
||||||
|
],
|
||||||
|
'descriptor': {shape: [2, 2], dataType: 'float32'}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
'name': 'quantized transpose',
|
'name': 'quantized transpose',
|
||||||
'graph': {
|
'graph': {
|
||||||
|
Reference in New Issue
Block a user