[WebNN] Fuse QDQ for slice on tflite
This CL supports fusing the `dq->slice->q` subgraph on tflite. The input and output operands have to be dequantized from ints8, and the scale and zero point of both the input and the output have to be scalar. Bug: 401281047 Change-Id: I2da7afc60ddbfc992ee52059a322b32e6c9e4f92 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6551995 Commit-Queue: Wei4 Wang <wei4.wang@intel.com> Reviewed-by: Phillis Tang <phillis@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Cr-Commit-Position: refs/heads/main@{#1461912}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
5c6f4416a4
commit
3137f2f2e2
services/webnn/tflite
third_party/blink/web_tests/external/wpt/webnn/conformance_tests
@ -1790,6 +1790,38 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Reshape& reshape) {
|
||||
return SerializeQuantizedOutput(*next_op);
|
||||
}
|
||||
|
||||
std::optional<GraphBuilderTflite::TensorInfo>
|
||||
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Slice& slice) {
|
||||
if (!IsDequantizeOutput(slice.input_operand_id)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// TODO(crbug.com/413083273): Consider the restriction in GPU delegate.
|
||||
// For XNNPack delegate, the scale and zero point of input and output have to
|
||||
// be scaler.
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=5302;drc=02446d66622a0a811448be7bb4ac8939c5b00aa9
|
||||
const mojom::DequantizeLinear& input_dequantize =
|
||||
GetDequantizeOp(slice.input_operand_id);
|
||||
if (!IsInts8AndScalarScale(input_dequantize)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<std::pair<OperationId, QuantizateParametersOffset>> next_op =
|
||||
IsNextOpQuantize(slice.output_operand_id,
|
||||
{GetOperand(input_dequantize.input_operand_id)
|
||||
.descriptor.data_type()});
|
||||
if (!next_op) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
const mojom::QuantizeLinear& output_quantize = GetQuantizeOp(next_op->first);
|
||||
if (!IsInts8AndScalarScale(output_quantize)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return SerializeQuantizedOutput(*next_op);
|
||||
}
|
||||
|
||||
std::optional<base::FixedArray<GraphBuilderTflite::TensorInfo>>
|
||||
GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Split& split) {
|
||||
if (!IsDequantizeOutput(split.input_operand_id)) {
|
||||
@ -6546,10 +6578,18 @@ auto GraphBuilderTflite::SerializeSlice(const mojom::Slice& slice)
|
||||
slice_strides[i] = checked_stride.ValueOrDie();
|
||||
}
|
||||
|
||||
std::optional<TensorInfo> quantized_output =
|
||||
CanFuseQuantizeAndGetOutput(slice);
|
||||
const bool fuse_dequantize = quantized_output.has_value();
|
||||
ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
|
||||
SerializeInputTensorInfo(slice.input_operand_id));
|
||||
SerializeInputTensorInfo(
|
||||
slice.input_operand_id, /*quantize_params=*/0,
|
||||
/*operation_supports_float16=*/false, fuse_dequantize));
|
||||
|
||||
const TensorIndex output_tensor_index =
|
||||
SerializeOutputTensorInfo(slice.output_operand_id).index;
|
||||
fuse_dequantize
|
||||
? quantized_output->index
|
||||
: SerializeOutputTensorInfo(slice.output_operand_id).index;
|
||||
|
||||
auto checked_number = base::MakeCheckedNum<int32_t>(slice.ranges.size());
|
||||
if (!checked_number.IsValid()) {
|
||||
|
@ -726,6 +726,8 @@ class GraphBuilderTflite final {
|
||||
const mojom::Pool2d& pool2d);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Reshape& reshape);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Slice& slice);
|
||||
std::optional<base::FixedArray<TensorInfo>> CanFuseQuantizeAndGetOutput(
|
||||
const mojom::Split& split);
|
||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||
|
82
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
82
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
vendored
@ -1594,6 +1594,88 @@ const subgraphTests = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'quantized slice',
|
||||
'graph': {
|
||||
'inputs': {
|
||||
'input': {
|
||||
'data': [
|
||||
1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
|
||||
-0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
|
||||
],
|
||||
'descriptor': {shape: [2, 3], dataType: 'float32'},
|
||||
'constant': false
|
||||
},
|
||||
'inputScale': {
|
||||
'data': [0.003921568859368563],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'inputZeroPoint': {
|
||||
'data': [16],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
'outputScale': {
|
||||
'data': [0.003921568859368563],
|
||||
'descriptor': {shape: [1], dataType: 'float32'},
|
||||
'constant': true
|
||||
},
|
||||
'outputZeroPoint': {
|
||||
'data': [16],
|
||||
'descriptor': {shape: [1], dataType: 'int8'},
|
||||
'constant': true
|
||||
},
|
||||
},
|
||||
'operators': [
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'input'},
|
||||
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedInput'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedInput'},
|
||||
{'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
|
||||
],
|
||||
'outputs': 'dequantizedInput'
|
||||
},
|
||||
{
|
||||
'name': 'slice',
|
||||
'arguments': [{'input': 'dequantizedInput'}, {'starts': [0, 1]}, {'sizes': [1, 2]}],
|
||||
'outputs': 'sliceOutput'
|
||||
},
|
||||
{
|
||||
'name': 'quantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'sliceOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'quantizedSliceOutput'
|
||||
},
|
||||
{
|
||||
'name': 'dequantizeLinear',
|
||||
'arguments': [
|
||||
{'input': 'quantizedSliceOutput'},
|
||||
{'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
|
||||
],
|
||||
'outputs': 'output'
|
||||
}
|
||||
],
|
||||
'expectedOutputs': {
|
||||
'output': {
|
||||
'data': [
|
||||
0.0470588281750679, 0.3333333432674408,
|
||||
],
|
||||
'descriptor': {shape: [1, 2], dataType: 'float32'}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
if (navigator.ml) {
|
||||
|
Reference in New Issue
Block a user