0

[WebNN] Fuse QDQ for split on tflite

This CL supports fusing `dq->split->q` subgraph on tflite.

Input and output operands have to be dequantized from ints8. The scale
and zero point of the input and outputs have to be scalar. For the XNNPack
delegate, the number of outputs should be in the range of [2, 4], but
there is no limitation on the number of outputs for the TFLite kernel, so
relax the output number restriction in this CL.

Bug: 401281047
Change-Id: Iec60c9df31c40cb88502dceab551a3581a0adfde
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6532451
Reviewed-by: Phillis Tang <phillis@chromium.org>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Wei4 Wang <wei4.wang@intel.com>
Cr-Commit-Position: refs/heads/main@{#1460494}
This commit is contained in:
Wei Wang
2025-05-14 20:34:00 -07:00
committed by Chromium LUCI CQ
parent 656fbe59c0
commit dadde09297
3 changed files with 194 additions and 7 deletions
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests

@ -1790,6 +1790,53 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Reshape& reshape) {
return SerializeQuantizedOutput(*next_op);
}
std::optional<base::FixedArray<GraphBuilderTflite::TensorInfo>>
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Split& split) {
  // Fusion requires the split input to come from a dequantize op.
  if (!IsDequantizeOutput(split.input_operand_id)) {
    return std::nullopt;
  }
  // TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
  // For XNNPack delegate, the scale and zero point of input and output have to
  // be scalar, and the number of outputs should be in the range of [2, 4]. But
  // there is no limitation on the number of outputs for TFLite kernel, so relax
  // the output number restriction here.
  // https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=5558;drc=1379ddb0f0535ff846ce0fbad8ee49af303140c4
  const mojom::DequantizeLinear& dequantize =
      GetDequantizeOp(split.input_operand_id);
  if (!IsInts8AndScalarScale(dequantize)) {
    return std::nullopt;
  }
  const OperandDataType quantized_type =
      GetOperand(dequantize.input_operand_id).descriptor.data_type();
  const size_t output_count = split.output_operand_ids.size();
  // First pass: verify every split output feeds a compatible quantize op,
  // collecting them so nothing is serialized until all checks pass.
  base::FixedArray<std::pair<OperationId, QuantizateParametersOffset>>
      pending_quantizes(output_count);
  for (size_t i = 0; i < output_count; ++i) {
    std::optional<std::pair<OperationId, QuantizateParametersOffset>> quantize =
        IsNextOpQuantize(split.output_operand_ids[i], {quantized_type});
    if (!quantize.has_value()) {
      return std::nullopt;
    }
    if (!IsInts8AndScalarScale(GetQuantizeOp(quantize->first))) {
      return std::nullopt;
    }
    pending_quantizes[i] = std::move(*quantize);
  }
  // Second pass: all outputs are fusable, so serialize each quantized output.
  base::FixedArray<TensorInfo> fused_outputs(output_count);
  for (size_t i = 0; i < output_count; ++i) {
    fused_outputs[i] = SerializeQuantizedOutput(pending_quantizes[i]);
  }
  return fused_outputs;
}
std::optional<GraphBuilderTflite::TensorInfo>
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
const mojom::Transpose& transpose) {
@ -6657,17 +6704,31 @@ auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
/*buffer=*/std::array<int32_t, 1>{checked_axis.ValueOrDie()},
/*dimensions=*/{});
std::optional<base::FixedArray<TensorInfo>> quantized_outputs =
CanFuseQuantizeAndGetOutput(split);
const bool fuse_dequantize = quantized_outputs.has_value();
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
SerializeInputTensorInfo(
split.input_operand_id, /*quantize_params=*/0,
/*operation_supports_float16=*/false, fuse_dequantize));
// Serialize the split sizes tensor that specifies the sizes of each output
// tensor along the axis.
const size_t outputs_size = split.output_operand_ids.size();
base::FixedArray<int32_t> split_sizes(outputs_size);
base::FixedArray<int32_t> op_outputs(outputs_size);
for (size_t i = 0; i < outputs_size; ++i) {
const TensorInfo output_tensor_info =
SerializeOutputTensorInfo(split.output_operand_ids[i]);
CHECK_LT(split.axis, output_tensor_info.dimensions.size());
split_sizes[i] = output_tensor_info.dimensions[split.axis];
op_outputs[i] = output_tensor_info.index;
if (fuse_dequantize) {
CHECK_LT(split.axis, quantized_outputs->at(i).dimensions.size());
split_sizes[i] = quantized_outputs->at(i).dimensions[split.axis];
op_outputs[i] = quantized_outputs->at(i).index;
} else {
const TensorInfo output_tensor_info =
SerializeOutputTensorInfo(split.output_operand_ids[i]);
CHECK_LT(split.axis, output_tensor_info.dimensions.size());
split_sizes[i] = output_tensor_info.dimensions[split.axis];
op_outputs[i] = output_tensor_info.index;
}
}
const auto checked_split_size =
base::MakeCheckedNum<int32_t>(split_sizes.size());
@ -6683,8 +6744,6 @@ auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
const auto split_options = ::tflite::CreateSplitOptions(
builder_, /*num_splits=*/checked_split_size.ValueOrDie());
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
SerializeInputTensorInfo(split.input_operand_id));
const OperatorCodeIndex operator_code_index =
GetOperatorCodeIndex(::tflite::BuiltinOperator_SPLIT_V);
// The order of inputs is input, split sizes tensor and then axis tensor as

@ -726,6 +726,8 @@ class GraphBuilderTflite final {
const mojom::Pool2d& pool2d);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
const mojom::Reshape& reshape);
std::optional<base::FixedArray<TensorInfo>> CanFuseQuantizeAndGetOutput(
const mojom::Split& split);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
const mojom::Transpose& transpose);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(

@ -1471,6 +1471,132 @@ const subgraphTests = [
}
}
},
{
'name': 'quantized split',
'graph': {
'inputs': {
'input': {
'data': [
1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
-0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
],
'descriptor': {shape: [2, 3], dataType: 'float32'},
'constant': false
},
'inputScale': {
'data': [0.003921568859368563],
'descriptor': {shape: [1], dataType: 'float32'},
'constant': true
},
'inputZeroPoint': {
'data': [16],
'descriptor': {shape: [1], dataType: 'int8'},
'constant': true
},
'outputScale': {
'data': [0.003921568859368563],
'descriptor': {shape: [1], dataType: 'float32'},
'constant': true
},
'outputZeroPoint': {
'data': [16],
'descriptor': {shape: [1], dataType: 'int8'},
'constant': true
},
},
'operators': [
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'input'},
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
],
'outputs': 'quantizedInput'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedInput'},
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
],
'outputs': 'dequantizedInput'
},
{
'name': 'split',
'arguments': [{'input': 'dequantizedInput'}, {'splits': 3}, {'options': {'axis': 1}}],
'outputs': ['splitOutput 1', 'splitOutput 2', 'splitOutput 3'],
},
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'splitOutput 1'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'quantizedSplitOutput 1'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedSplitOutput 1'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'output 1'
},
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'splitOutput 2'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'quantizedSplitOutput 2'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedSplitOutput 2'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'output 2'
},
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'splitOutput 3'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'quantizedSplitOutput 3'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedSplitOutput 3'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'output 3'
}
],
'expectedOutputs': {
'output 1': {
'data': [
0.43529415130615234, -0.20000001788139343,
],
'descriptor': {shape: [2, 1], dataType: 'float32'}
},
'output 2': {
'data': [
0.0470588281750679, -0.003921568859368563,
],
'descriptor': {shape: [2, 1], dataType: 'float32'}
},
'output 3': {
'data': [
0.3333333432674408, -0.062745101749897,
],
'descriptor': {shape: [2, 1], dataType: 'float32'}
}
}
}
},
];
if (navigator.ml) {