webnn: fuse QDQ for element-wise max and min on tflite
The inputs and output must all have the same scale and zero_point [1]. [1] https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893 Bug: 401281047 Change-Id: Iff6d2b320c2e3d89ee49bb804012908c7c6c706f Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6538847 Reviewed-by: ningxin hu <ningxin.hu@intel.com> Reviewed-by: Phillis Tang <phillis@chromium.org> Commit-Queue: Junwei Fu <junwei.fu@intel.com> Cr-Commit-Position: refs/heads/main@{#1459861}
This commit is contained in:
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
@ -1589,10 +1589,8 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
||||
// For XNNPack delegate, there are some restriction for inputs and output
|
||||
// scale value.
|
||||
auto checked_lhs_scale_value =
|
||||
base::MakeCheckedNum<float>(lhs_scale_values[0]);
|
||||
auto checked_rhs_scale_value =
|
||||
base::MakeCheckedNum<float>(rhs_scale_values[0]);
|
||||
const float lhs_scale_value = lhs_scale_values[0];
|
||||
const float rhs_scale_value = rhs_scale_values[0];
|
||||
const float output_scale_value = output_scale_values[0];
|
||||
if (binary.kind == mojom::ElementWiseBinary::Kind::kAdd ||
|
||||
binary.kind == mojom::ElementWiseBinary::Kind::kSub) {
|
||||
@ -1602,14 +1600,15 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
const float scale_max = 256.0f;
|
||||
|
||||
auto checked_lhs_output_scale =
|
||||
checked_lhs_scale_value / output_scale_value;
|
||||
base::MakeCheckedNum<float>(lhs_scale_value);
|
||||
checked_lhs_output_scale /= output_scale_value;
|
||||
if (!checked_lhs_output_scale.IsValid() ||
|
||||
checked_lhs_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_lhs_output_scale.ValueOrDie() >= scale_max) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto checked_rhs_output_scale =
|
||||
checked_rhs_scale_value / output_scale_value;
|
||||
base::MakeCheckedNum<float>(rhs_scale_value) / output_scale_value;
|
||||
if (!checked_rhs_output_scale.IsValid() ||
|
||||
checked_rhs_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_rhs_output_scale.ValueOrDie() >= scale_max) {
|
||||
@ -1621,13 +1620,31 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
const float scale_min = 1.0f / 65536.0f;
|
||||
const float scale_max = 256.0f;
|
||||
auto checked_product_output_scale =
|
||||
(checked_lhs_scale_value * checked_rhs_scale_value) /
|
||||
(lhs_scale_value * base::MakeCheckedNum<float>(rhs_scale_value)) /
|
||||
output_scale_value;
|
||||
if (!checked_product_output_scale.IsValid() ||
|
||||
checked_product_output_scale.ValueOrDie() < scale_min ||
|
||||
checked_product_output_scale.ValueOrDie() >= scale_max) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else if (binary.kind == mojom::ElementWiseBinary::Kind::kMax ||
|
||||
binary.kind == mojom::ElementWiseBinary::Kind::kMin) {
|
||||
// Inputs and output must have the same scale and zero_point.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/kernels/internal/optimized/optimized_ops.h;l=7101;drc=467a8e68f685f9cfa47ee3fbfca20c22f7f6e893
|
||||
if (lhs_scale_value != rhs_scale_value ||
|
||||
lhs_scale_value != output_scale_value) {
|
||||
return std::nullopt;
|
||||
}
|
||||
base::FixedArray<int64_t> lhs_zero_point_values =
|
||||
GetConstantInt64Value(lhs_dequantize.zero_point_operand_id);
|
||||
base::FixedArray<int64_t> rhs_zero_point_values =
|
||||
GetConstantInt64Value(rhs_dequantize.zero_point_operand_id);
|
||||
base::FixedArray<int64_t> output_zero_point_values =
|
||||
GetConstantInt64Value(output_quantize.zero_point_operand_id);
|
||||
if (!std::ranges::equal(lhs_zero_point_values, rhs_zero_point_values) ||
|
||||
!std::ranges::equal(lhs_zero_point_values, output_zero_point_values)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} else {
|
||||
NOTREACHED() << "Unsupported quantize operators";
|
||||
}
|
||||
@ -3339,11 +3356,13 @@ auto GraphBuilderTflite::SerializeElementWiseBinary(
|
||||
CHECK(context_properties_.data_type_limits.max_input.SupportsAll(
|
||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||
code = ::tflite::BuiltinOperator_MAXIMUM;
|
||||
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||
break;
|
||||
case mojom::ElementWiseBinary::Kind::kMin:
|
||||
CHECK(context_properties_.data_type_limits.min_input.SupportsAll(
|
||||
{lhs_operand_descriptor, rhs_operand_descriptor}));
|
||||
code = ::tflite::BuiltinOperator_MINIMUM;
|
||||
quantized_output = CanFuseQuantizeAndGetOutput(op);
|
||||
break;
|
||||
case mojom::ElementWiseBinary::Kind::kPow:
|
||||
CHECK(context_properties_.data_type_limits.pow_input.SupportsAll(
|
||||
|
216
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
216
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
@ -474,6 +474,222 @@ const subgraphTests = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'quantized element-wise binary max',
|
||||
'graph': {
|
||||
'inputs': {
|
||||
'inputA': {
|
||||
'data': [
|
||||
-2.549168109893799, -4.794857501983643,
|
||||
8.413617134094238, 6.108623504638672
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'float32'},
|
||||
'constant': false
|
||||
},
|
||||
'inputAScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputAZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'inputB': {
|
||||
'data': [
|
||||
12, 24, 35, 11,
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'inputBScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputBZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'outputScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'outputZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
},
|
||||
'operators': [
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'inputA'},
|
||||
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedInputA'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedInputA'},
|
||||
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInputA'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'inputB'},
|
||||
{'scale': 'inputBScale', 'zeroPoint': 'inputBZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInputB'
|
||||
},
|
||||
{
|
||||
'name': 'max',
|
||||
'arguments': [{'inputA': 'dequantizedInputA'}, {'inputB': 'dequantizedInputB'}],
|
||||
'outputs': 'maxOutput'
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'maxOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedMaxOutput'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedMaxOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output'
|
||||
}
|
||||
],
|
||||
'expectedOutputs': {
|
||||
'output': {
|
||||
'data': [
|
||||
48.03300857543945, 52.150123596191406,
|
||||
55.92414474487305, 47.68991470336914,
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'float32'}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'quantized element-wise binary min',
|
||||
'graph': {
|
||||
'inputs': {
|
||||
'inputA': {
|
||||
'data': [
|
||||
3.549168109893799, 4.794857501983643,
|
||||
8.413617134094238, 6.108623504638672
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'float32'},
|
||||
'constant': false
|
||||
},
|
||||
'inputAScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputAZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'inputB': {
|
||||
'data': [
|
||||
12, 24, 35, 11,
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'inputBScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputBZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'outputScale': {
|
||||
'data': [0.343092918395996],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'outputZeroPoint': {
|
||||
'data': [-128],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
},
|
||||
'operators': [
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'inputA'},
|
||||
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedInputA'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedInputA'},
|
||||
{'scale': 'inputAScale', 'zeroPoint': 'inputAZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInputA'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'inputB'},
|
||||
{'scale': 'inputBScale', 'zeroPoint': 'inputBZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInputB'
|
||||
},
|
||||
{
|
||||
'name': 'min',
|
||||
'arguments': [{'inputA': 'dequantizedInputA'}, {'inputB': 'dequantizedInputB'}],
|
||||
'outputs': 'minOutput'
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'minOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedMinOutput'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedMinOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output'
|
||||
}
|
||||
],
|
||||
'expectedOutputs': {
|
||||
'output': {
|
||||
'data': [
|
||||
3.430929183959961, 4.803300857543945,
|
||||
8.577322959899902, 6.17567253112793,
|
||||
],
|
||||
'descriptor': {shape: [2, 2], dataType: 'float32'}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'quantized transpose',
|
||||
'graph': {
|
||||
|
Reference in New Issue
Block a user