[WebNN] Fuse QDQ for slice on tflite

This CL adds support for fusing the `dq->slice->q` subgraph on tflite.

Input and output operands have to be dequantized from ints8, and the scale
and zero point of input and output have to be scalar.

Bug: 401281047
Change-Id: I2da7afc60ddbfc992ee52059a322b32e6c9e4f92
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6551995
Commit-Queue: Wei4 Wang <wei4.wang@intel.com>
Reviewed-by: Phillis Tang <phillis@chromium.org>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1461912}
Author: Wei Wang
Date: 2025-05-18 19:07:14 -07:00
Committed by: Chromium LUCI CQ
Parent: 5c6f4416a4
Commit: 3137f2f2e2
3 changed files with 126 additions and 2 deletions:
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
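
For context, the `dq->slice->q` pattern this CL targets corresponds to a WebNN
graph like the one below. This is a minimal sketch mirroring the 'quantized
slice' WPT subgraph test added in this CL; it assumes an existing MLContext
named `context` and an async caller, and is not part of the change itself.

// Sketch only: mirrors the 'quantized slice' WPT test added below. With int8
// operands and scalar scale/zeroPoint, the TFLite backend can now lower
// dequantizeLinear -> slice -> quantizeLinear into a single quantized slice
// instead of emitting the dequantize and quantize ops around it.
const builder = new MLGraphBuilder(context);
const input = builder.input('input', {dataType: 'float32', shape: [2, 3]});
const scale = builder.constant(
    {dataType: 'float32', shape: [1]},
    new Float32Array([0.003921568859368563]));
const zeroPoint =
    builder.constant({dataType: 'int8', shape: [1]}, new Int8Array([16]));

const quantizedInput = builder.quantizeLinear(input, scale, zeroPoint);
const dequantizedInput =
    builder.dequantizeLinear(quantizedInput, scale, zeroPoint);
const sliceOutput = builder.slice(dequantizedInput, [0, 1], [1, 2]);
const quantizedSliceOutput =
    builder.quantizeLinear(sliceOutput, scale, zeroPoint);
const output =
    builder.dequantizeLinear(quantizedSliceOutput, scale, zeroPoint);
const graph = await builder.build({'output': output});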

@@ -1790,6 +1790,38 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Reshape& reshape) {
return SerializeQuantizedOutput(*next_op);
}
std::optional<GraphBuilderTflite::TensorInfo>
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Slice& slice) {
if (!IsDequantizeOutput(slice.input_operand_id)) {
return std::nullopt;
}
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
// For XNNPack delegate, the scale and zero point of input and output have to
// be scalar.
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=5302;drc=02446d66622a0a811448be7bb4ac8939c5b00aa9
const mojom::DequantizeLinear& input_dequantize =
GetDequantizeOp(slice.input_operand_id);
if (!IsInts8AndScalarScale(input_dequantize)) {
return std::nullopt;
}
std::optional<std::pair<OperationId, QuantizateParametersOffset>> next_op =
IsNextOpQuantize(slice.output_operand_id,
{GetOperand(input_dequantize.input_operand_id)
.descriptor.data_type()});
if (!next_op) {
return std::nullopt;
}
const mojom::QuantizeLinear& output_quantize = GetQuantizeOp(next_op->first);
if (!IsInts8AndScalarScale(output_quantize)) {
return std::nullopt;
}
return SerializeQuantizedOutput(*next_op);
}
std::optional<base::FixedArray<GraphBuilderTflite::TensorInfo>>
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Split& split) {
if (!IsDequantizeOutput(split.input_operand_id)) {
@@ -6546,10 +6578,18 @@ auto GraphBuilderTflite::SerializeSlice(const mojom::Slice& slice)
slice_strides[i] = checked_stride.ValueOrDie();
}
std::optional<TensorInfo> quantized_output =
CanFuseQuantizeAndGetOutput(slice);
const bool fuse_dequantize = quantized_output.has_value();
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
SerializeInputTensorInfo(slice.input_operand_id));
SerializeInputTensorInfo(
slice.input_operand_id, /*quantize_params=*/0,
/*operation_supports_float16=*/false, fuse_dequantize));
const TensorIndex output_tensor_index =
SerializeOutputTensorInfo(slice.output_operand_id).index;
fuse_dequantize
? quantized_output->index
: SerializeOutputTensorInfo(slice.output_operand_id).index;
auto checked_number = base::MakeCheckedNum<int32_t>(slice.ranges.size());
if (!checked_number.IsValid()) {

@@ -726,6 +726,8 @@ class GraphBuilderTflite final {
const mojom::Pool2d& pool2d);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
const mojom::Reshape& reshape);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
const mojom::Slice& slice);
std::optional<base::FixedArray<TensorInfo>> CanFuseQuantizeAndGetOutput(
const mojom::Split& split);
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(

@@ -1594,6 +1594,88 @@ const subgraphTests = [
}
}
},
{
'name': 'quantized slice',
'graph': {
'inputs': {
'input': {
'data': [
1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
-0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
],
'descriptor': {shape: [2, 3], dataType: 'float32'},
'constant': false
},
'inputScale': {
'data': [0.003921568859368563],
'descriptor': {shape: [1], dataType: 'float32'},
'constant': true
},
'inputZeroPoint': {
'data': [16],
'descriptor': {shape: [1], dataType: 'int8'},
'constant': true
},
'outputScale': {
'data': [0.003921568859368563],
'descriptor': {shape: [1], dataType: 'float32'},
'constant': true
},
'outputZeroPoint': {
'data': [16],
'descriptor': {shape: [1], dataType: 'int8'},
'constant': true
},
},
'operators': [
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'input'},
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
],
'outputs': 'quantizedInput'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedInput'},
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
],
'outputs': 'dequantizedInput'
},
{
'name': 'slice',
'arguments': [{'input': 'dequantizedInput'}, {'starts': [0, 1]}, {'sizes': [1, 2]}],
'outputs': 'sliceOutput'
},
{
'name': 'quantizeLinear',
'arguments': [
{'input': 'sliceOutput'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'quantizedSliceOutput'
},
{
'name': 'dequantizeLinear',
'arguments': [
{'input': 'quantizedSliceOutput'},
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
],
'outputs': 'output'
}
],
'expectedOutputs': {
'output': {
'data': [
0.0470588281750679, 0.3333333432674408,
],
'descriptor': {shape: [1, 2], dataType: 'float32'}
}
}
}
},
];
if (navigator.ml) {
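
As a quick sanity check of the expected values in the new 'quantized slice'
test (not part of the CL), the int8 quantize/dequantize round trip for the two
sliced lanes can be reproduced with a few lines of script:

// Scale (1/255) and zeroPoint (16) taken from the test data above.
const scale = 0.003921568859368563;
const zeroPoint = 16;
const quantize = (x) =>
    Math.max(-128, Math.min(127, Math.round(x / scale) + zeroPoint));
const dequantize = (q) => (q - zeroPoint) * scale;
// Row 0, columns 1..2 of the 2x3 input: the region selected by the slice.
const sliced = [0.0479511022567749, 0.33355462551116943];
console.log(sliced.map((x) => dequantize(quantize(x))));
// Prints values matching 'expectedOutputs' up to float32 rounding:
// ~0.0470588 and ~0.3333333.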