[WebNN] Fuse QDQ for elu on tflite

This CL supports fusing the `dq -> elu -> q` subgraph on TFLite. The input and
output operands have to be dequantized from int8, and the input and output
scales must be scalars, as validated in TFLite XNNPACK's
`CheckTensorFloat32OrQInt8Type()`.

Bug: 401281047
Change-Id: I3b58f5921cb1075088b8616c8a8e9dd5d4a2273f
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6514515
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Wei4 Wang <wei4.wang@intel.com>
Reviewed-by: Phillis Tang <phillis@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1456728}
Committed by: Chromium LUCI CQ
Parent: 94a696a1fa
Commit: 051ecbd34c

Changed paths:
  services/webnn/tflite
  third_party/blink/web_tests/external/wpt/webnn/conformance_tests
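For context, the pattern this CL fuses corresponds to a WebNN graph like the
sketch below, mirroring the WPT case added in this change. This is an
illustrative sketch only, not part of the CL; it assumes the standard WebNN
`MLGraphBuilder` methods `quantizeLinear`, `dequantizeLinear`, and `elu`, and
borrows the scalar scale (~1/255) and int8 zero point 0 from the test.

```js
// Illustrative sketch (not part of the CL): a WebNN graph whose
// dequantizeLinear -> elu -> quantizeLinear chain the TFLite backend can now
// serialize as a single quantized ELU. Run inside an async function.
const context = await navigator.ml.createContext();
const builder = new MLGraphBuilder(context);

const input = builder.input('input', {dataType: 'float32', shape: [2, 3]});
// Scalar quantization parameters (scale ~= 1/255, int8 zero point 0), matching
// the XNNPACK restriction checked by CanFuseQuantizeAndGetOutput() below.
const scale = builder.constant(
    {dataType: 'float32', shape: [1]}, new Float32Array([1 / 255]));
const zeroPoint = builder.constant(
    {dataType: 'int8', shape: [1]}, new Int8Array([0]));

const q = builder.quantizeLinear(input, scale, zeroPoint);    // float32 -> int8
const dq = builder.dequantizeLinear(q, scale, zeroPoint);     // int8 -> float32
const activated = builder.elu(dq);                            // default alpha = 1
const qOut = builder.quantizeLinear(activated, scale, zeroPoint);
const output = builder.dequantizeLinear(qOut, scale, zeroPoint);

const graph = await builder.build({'output': output});
```

With this CL, the `dq -> elu -> q` segment in the middle is fused on the
TFLite backend instead of being emitted as separate dequantize, ELU, and
quantize operations.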
@@ -1641,6 +1641,43 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
   return TrySerializeQuantizedOutput(*next_op);
 }
 
+std::optional<GraphBuilderTflite::TensorInfo>
+GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Elu& elu) {
+  if (!IsDequantizeOutput(elu.input_operand_id)) {
+    return std::nullopt;
+  }
+
+  // TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
+  // For the XNNPack delegate, the input must be dequantized from int8, and
+  // the input and output scales must be scalars.
+  // https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=4136;drc=f667feb8a5c6f227b49328ce78a062acc4f81187
+  const mojom::DequantizeLinear& input_dequantize =
+      GetDequantizeOp(elu.input_operand_id);
+  if (GetOperand(input_dequantize.input_operand_id).descriptor.data_type() !=
+      OperandDataType::kInt8) {
+    return std::nullopt;
+  }
+
+  if (GetOperand(input_dequantize.scale_operand_id)
+          .descriptor.NumberOfElements() != 1) {
+    return std::nullopt;
+  }
+
+  std::optional<size_t> next_op =
+      IsNextOpQuantize(elu.output_operand_id, {OperandDataType::kInt8});
+  if (!next_op) {
+    return std::nullopt;
+  }
+
+  const mojom::QuantizeLinear& output_quantize = GetQuantizeOp(*next_op);
+  if (GetOperand(output_quantize.scale_operand_id)
+          .descriptor.NumberOfElements() != 1) {
+    return std::nullopt;
+  }
+
+  return TrySerializeQuantizedOutput(*next_op);
+}
+
 std::optional<GraphBuilderTflite::TensorInfo>
 GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
     const mojom::Transpose& transpose) {
@@ -3456,13 +3493,24 @@ auto GraphBuilderTflite::SerializeElu(const mojom::Elu& elu)
         "Setting a custom alpha is not supported in tflite schema.");
   }
 
+  std::optional<TensorInfo> quantized_output = CanFuseQuantizeAndGetOutput(elu);
+  const bool fuse_dequantize = quantized_output.has_value();
   ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
-                   SerializeInputTensorInfo(elu.input_operand_id));
-  ASSIGN_OR_RETURN(const TensorInfo& output_tensor_info,
-                   SerializeOutputTensorInfo(elu.output_operand_id));
+                   SerializeInputTensorInfo(
+                       elu.input_operand_id, /*quantize_params=*/0,
+                       /*operation_supports_float16=*/false, fuse_dequantize));
+
+  TensorIndex output_tensor_index;
+  if (fuse_dequantize) {
+    output_tensor_index = quantized_output->index;
+  } else {
+    ASSIGN_OR_RETURN(const TensorInfo& output_tensor_info,
+                     SerializeOutputTensorInfo(elu.output_operand_id));
+    output_tensor_index = output_tensor_info.index;
+  }
 
   return SerializeUnaryOperation(::tflite::BuiltinOperator_ELU,
-                                 input_tensor_info.index,
-                                 output_tensor_info.index);
+                                 input_tensor_info.index, output_tensor_index);
 }
 
 auto GraphBuilderTflite::SerializeErf(const TensorInfo& input_tensor_info,
@@ -721,6 +721,7 @@ class GraphBuilderTflite final {
       const mojom::Concat& concat);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
       const mojom::ElementWiseBinary& binary);
+  std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(const mojom::Elu& elu);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
       const mojom::Transpose& transpose);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js (vendored, +83)
@@ -924,6 +924,89 @@ const subgraphTests = [
       }
     }
   },
+  {
+    'name': 'quantized elu',
+    'graph': {
+      'inputs': {
+        'input': {
+          'data': [
+            1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
+            -0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
+          ],
+          'descriptor': {shape: [2, 3], dataType: 'float32'},
+          'constant': false
+        },
+        'inputScale': {
+          'data': [0.003921568859368563],
+          'descriptor': {shape: [1], dataType: 'float32'},
+          'constant': true
+        },
+        'inputZeroPoint': {
+          'data': [0],
+          'descriptor': {shape: [1], dataType: 'int8'},
+          'constant': true
+        },
+        'outputScale': {
+          'data': [0.003921568859368563],
+          'descriptor': {shape: [1], dataType: 'float32'},
+          'constant': true
+        },
+        'outputZeroPoint': {
+          'data': [0],
+          'descriptor': {shape: [1], dataType: 'int8'},
+          'constant': true
+        },
+      },
+      'operators': [
+        {
+          'name': 'quantizeLinear',
+          'arguments': [
+            {'input': 'input'},
+            {'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
+          ],
+          'outputs': 'quantizedInput'
+        },
+        {
+          'name': 'dequantizeLinear',
+          'arguments': [
+            {'input': 'quantizedInput'},
+            {'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
+          ],
+          'outputs': 'dequantizedInput'
+        },
+        {
+          'name': 'elu',
+          'arguments': [{'input': 'dequantizedInput'}],
+          'outputs': 'eluOutput'
+        },
+        {
+          'name': 'quantizeLinear',
+          'arguments': [
+            {'input': 'eluOutput'},
+            {'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
+          ],
+          'outputs': 'quantizedeluOutput'
+        },
+        {
+          'name': 'dequantizeLinear',
+          'arguments': [
+            {'input': 'quantizedeluOutput'},
+            {'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
+          ],
+          'outputs': 'output'
+        }
+      ],
+      'expectedOutputs': {
+        'output': {
+          'data': [
+            0.49803924560546875, 0.0470588281750679, 0.3333333432674408,
+            -0.18039216101169586, -0.003921568859368563, -0.062745101749897,
+          ],
+          'descriptor': {shape: [2, 3], dataType: 'float32'}
+        }
+      }
+    }
+  },
 ];
 
 if (navigator.ml) {