[WebNN] Fuse QDQ for split on tflite
This CL supports fusing the `dq->split->q` subgraph on TFLite. The input operand has to be dequantized from 8-bit integers and the output operands quantized to 8-bit integers, and the scale and zero point of the input and outputs have to be scalars. The XNNPack delegate requires the number of outputs to be in the range [2, 4], but the TFLite kernel has no limit on the number of outputs, so this CL relaxes the output-count restriction. Bug: 401281047 Change-Id: Iec60c9df31c40cb88502dceab551a3581a0adfde Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6532451 Reviewed-by: Phillis Tang <phillis@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Commit-Queue: Wei4 Wang <wei4.wang@intel.com> Cr-Commit-Position: refs/heads/main@{#1460494}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
656fbe59c0
commit
dadde09297
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
@ -1790,6 +1790,53 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Reshape& reshape) {
|
||||
return SerializeQuantizedOutput(*next_op);
|
||||
}
|
||||
|
||||
std::optional<base::FixedArray<GraphBuilderTflite::TensorInfo>>
|
||||
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Split& split) {
|
||||
if (!IsDequantizeOutput(split.input_operand_id)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
||||
// For XNNPack delegate, the scale and zero point of input and output have to
|
||||
// be scaler, and the number of outputs should be in the range of [2, 4]. But
|
||||
// there is no limitation on the number of outputs for TFLite kernel, so relax
|
||||
// the output number restriction here.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=5558;drc=1379ddb0f0535ff846ce0fbad8ee49af303140c4
|
||||
const mojom::DequantizeLinear& input_dequantize =
|
||||
GetDequantizeOp(split.input_operand_id);
|
||||
if (!IsInts8AndScalarScale(input_dequantize)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const size_t outputs_size = split.output_operand_ids.size();
|
||||
const OperandDataType quantized_type =
|
||||
GetOperand(input_dequantize.input_operand_id).descriptor.data_type();
|
||||
base::FixedArray<std::pair<OperationId, QuantizateParametersOffset>>
|
||||
quantize_ops(outputs_size);
|
||||
for (size_t i = 0; i < outputs_size; ++i) {
|
||||
std::optional<std::pair<OperationId, QuantizateParametersOffset>> next_op =
|
||||
IsNextOpQuantize(split.output_operand_ids[i], {quantized_type});
|
||||
if (!next_op) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
OperationId quantize_op_id = next_op->first;
|
||||
const mojom::QuantizeLinear& output_quantize =
|
||||
GetQuantizeOp(quantize_op_id);
|
||||
if (!IsInts8AndScalarScale(output_quantize)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
quantize_ops[i] = std::move(*next_op);
|
||||
}
|
||||
|
||||
base::FixedArray<TensorInfo> output_tensor_infos(outputs_size);
|
||||
for (size_t i = 0; i < outputs_size; ++i) {
|
||||
output_tensor_infos[i] = SerializeQuantizedOutput(quantize_ops[i]);
|
||||
}
|
||||
|
||||
return output_tensor_infos;
|
||||
}
|
||||
|
||||
std::optional<GraphBuilderTflite::TensorInfo>
|
||||
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Transpose& transpose) {
|
||||
@ -6657,17 +6704,31 @@ auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
|
||||
/*buffer=*/std::array<int32_t, 1>{checked_axis.ValueOrDie()},
|
||||
/*dimensions=*/{});
|
||||
|
||||
std::optional<base::FixedArray<TensorInfo>> quantized_outputs =
|
||||
CanFuseQuantizeAndGetOutput(split);
|
||||
const bool fuse_dequantize = quantized_outputs.has_value();
|
||||
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
|
||||
SerializeInputTensorInfo(
|
||||
split.input_operand_id, /*quantize_params=*/0,
|
||||
/*operation_supports_float16=*/false, fuse_dequantize));
|
||||
|
||||
// Serialize the split sizes tensor that specifies the sizes of each output
|
||||
// tensor along the axis.
|
||||
const size_t outputs_size = split.output_operand_ids.size();
|
||||
base::FixedArray<int32_t> split_sizes(outputs_size);
|
||||
base::FixedArray<int32_t> op_outputs(outputs_size);
|
||||
for (size_t i = 0; i < outputs_size; ++i) {
|
||||
const TensorInfo output_tensor_info =
|
||||
SerializeOutputTensorInfo(split.output_operand_ids[i]);
|
||||
CHECK_LT(split.axis, output_tensor_info.dimensions.size());
|
||||
split_sizes[i] = output_tensor_info.dimensions[split.axis];
|
||||
op_outputs[i] = output_tensor_info.index;
|
||||
if (fuse_dequantize) {
|
||||
CHECK_LT(split.axis, quantized_outputs->at(i).dimensions.size());
|
||||
split_sizes[i] = quantized_outputs->at(i).dimensions[split.axis];
|
||||
op_outputs[i] = quantized_outputs->at(i).index;
|
||||
} else {
|
||||
const TensorInfo output_tensor_info =
|
||||
SerializeOutputTensorInfo(split.output_operand_ids[i]);
|
||||
CHECK_LT(split.axis, output_tensor_info.dimensions.size());
|
||||
split_sizes[i] = output_tensor_info.dimensions[split.axis];
|
||||
op_outputs[i] = output_tensor_info.index;
|
||||
}
|
||||
}
|
||||
const auto checked_split_size =
|
||||
base::MakeCheckedNum<int32_t>(split_sizes.size());
|
||||
@ -6683,8 +6744,6 @@ auto GraphBuilderTflite::SerializeSplit(const mojom::Split& split)
|
||||
const auto split_options = ::tflite::CreateSplitOptions(
|
||||
builder_, /*num_splits=*/checked_split_size.ValueOrDie());
|
||||
|
||||
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
|
||||
SerializeInputTensorInfo(split.input_operand_id));
|
||||
const OperatorCodeIndex operator_code_index =
|
||||
GetOperatorCodeIndex(::tflite::BuiltinOperator_SPLIT_V);
|
||||
// The order of inputs is input, split sizes tensor and then axis tensor as
|
||||
|
@ -726,6 +726,8 @@ class GraphBuilderTflite final {
|
||||
const mojom::Pool2d& pool2d);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Reshape& reshape);
|
||||
// Split overload: returns one quantized output TensorInfo per split output
// (ordered like `split.output_operand_ids`) when the split's input comes from
// a dequantizeLinear and every output is consumed by a quantizeLinear that
// can be fused; otherwise returns std::nullopt.
std::optional<base::FixedArray<TensorInfo>> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Split& split);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Transpose& transpose);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
|
126
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
126
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
@ -1471,6 +1471,132 @@ const subgraphTests = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'quantized split',
|
||||
'graph': {
|
||||
'inputs': {
|
||||
'input': {
|
||||
'data': [
|
||||
1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
|
||||
-0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
|
||||
],
|
||||
'descriptor': {shape: [2, 3], dataType: 'float32'},
|
||||
'constant': false
|
||||
},
|
||||
'inputScale': {
|
||||
'data': [0.003921568859368563],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputZeroPoint': {
|
||||
'data': [16],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'outputScale': {
|
||||
'data': [0.003921568859368563],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'outputZeroPoint': {
|
||||
'data': [16],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
},
|
||||
'operators': [
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'input'},
|
||||
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedInput'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedInput'},
|
||||
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInput'
|
||||
},
|
||||
{
|
||||
'name': 'split',
|
||||
'arguments': [{'input': 'dequantizedInput'}, {'splits': 3}, {'options': {'axis': 1}}],
|
||||
'outputs': ['splitOutput 1', 'splitOutput 2', 'splitOutput 3'],
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'splitOutput 1'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedSplitOutput 1'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedSplitOutput 1'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output 1'
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'splitOutput 2'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedSplitOutput 2'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedSplitOutput 2'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output 2'
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'splitOutput 3'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedSplitOutput 3'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedSplitOutput 3'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output 3'
|
||||
}
|
||||
],
|
||||
'expectedOutputs': {
|
||||
'output 1': {
|
||||
'data': [
|
||||
0.43529415130615234, -0.20000001788139343,
|
||||
],
|
||||
'descriptor': {shape: [2, 1], dataType: 'float32'}
|
||||
},
|
||||
'output 2': {
|
||||
'data': [
|
||||
0.0470588281750679, -0.003921568859368563,
|
||||
],
|
||||
'descriptor': {shape: [2, 1], dataType: 'float32'}
|
||||
},
|
||||
'output 3': {
|
||||
'data': [
|
||||
0.3333333432674408, -0.062745101749897,
|
||||
],
|
||||
'descriptor': {shape: [2, 1], dataType: 'float32'}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
if (navigator.ml) {
|
||||
|
Reference in New Issue
Block a user