diff --git a/services/webnn/tflite/graph_builder_tflite.cc b/services/webnn/tflite/graph_builder_tflite.cc
index 2504c4d365997..d4ef9f4c0ff46 100644
--- a/services/webnn/tflite/graph_builder_tflite.cc
+++ b/services/webnn/tflite/graph_builder_tflite.cc
@@ -1641,6 +1641,43 @@ GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
   return TrySerializeQuantizedOutput(*next_op);
 }
 
+std::optional<GraphBuilderTflite::TensorInfo>
+GraphBuilderTflite::CanFuseQuantizeAndGetOutput(const mojom::Elu& elu) {
+  if (!IsDequantizeOutput(elu.input_operand_id)) {
+    return std::nullopt;
+  }
+
+  // TODO(crbug.com/413083273): Consider the restrictions in the GPU delegate.
+  // For the XNNPACK delegate, the input must be dequantized from int8, and
+  // the input and output scales must be scalars.
+  // https://source.chromium.org/chromium/chromium/src/+/main:third_party/tflite/src/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc;l=4136;drc=f667feb8a5c6f227b49328ce78a062acc4f81187
+  const mojom::DequantizeLinear& input_dequantize =
+      GetDequantizeOp(elu.input_operand_id);
+  if (GetOperand(input_dequantize.input_operand_id).descriptor.data_type() !=
+      OperandDataType::kInt8) {
+    return std::nullopt;
+  }
+
+  if (GetOperand(input_dequantize.scale_operand_id)
+          .descriptor.NumberOfElements() != 1) {
+    return std::nullopt;
+  }
+
+  std::optional<size_t> next_op =
+      IsNextOpQuantize(elu.output_operand_id, {OperandDataType::kInt8});
+  if (!next_op) {
+    return std::nullopt;
+  }
+
+  const mojom::QuantizeLinear& output_quantize = GetQuantizeOp(*next_op);
+  if (GetOperand(output_quantize.scale_operand_id)
+          .descriptor.NumberOfElements() != 1) {
+    return std::nullopt;
+  }
+
+  return TrySerializeQuantizedOutput(*next_op);
+}
+
 std::optional<GraphBuilderTflite::TensorInfo>
 GraphBuilderTflite::CanFuseQuantizeAndGetOutput(
     const mojom::Transpose& transpose) {
@@ -3456,13 +3493,24 @@ auto GraphBuilderTflite::SerializeElu(const mojom::Elu& elu)
         "Setting a custom alpha is not supported in tflite schema.");
   }
 
+  std::optional<TensorInfo> quantized_output = CanFuseQuantizeAndGetOutput(elu);
+  const bool fuse_dequantize = quantized_output.has_value();
   ASSIGN_OR_RETURN(const TensorInfo& input_tensor_info,
-                   SerializeInputTensorInfo(elu.input_operand_id));
-  ASSIGN_OR_RETURN(const TensorInfo& output_tensor_info,
-                   SerializeOutputTensorInfo(elu.output_operand_id));
+                   SerializeInputTensorInfo(
+                       elu.input_operand_id, /*quantize_params=*/0,
+                       /*operation_supports_float16=*/false, fuse_dequantize));
+
+  TensorIndex output_tensor_index;
+  if (fuse_dequantize) {
+    output_tensor_index = quantized_output->index;
+  } else {
+    ASSIGN_OR_RETURN(const TensorInfo& output_tensor_info,
+                     SerializeOutputTensorInfo(elu.output_operand_id));
+    output_tensor_index = output_tensor_info.index;
+  }
+
   return SerializeUnaryOperation(::tflite::BuiltinOperator_ELU,
-                                 input_tensor_info.index,
-                                 output_tensor_info.index);
+                                 input_tensor_info.index, output_tensor_index);
 }
 
 auto GraphBuilderTflite::SerializeErf(const TensorInfo& input_tensor_info,
diff --git a/services/webnn/tflite/graph_builder_tflite.h b/services/webnn/tflite/graph_builder_tflite.h
index 45273dd344e31..bf39351dba956 100644
--- a/services/webnn/tflite/graph_builder_tflite.h
+++ b/services/webnn/tflite/graph_builder_tflite.h
@@ -721,6 +721,7 @@ class GraphBuilderTflite final {
       const mojom::Concat& concat);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
       const mojom::ElementWiseBinary& binary);
+  std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(const mojom::Elu& elu);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
       const mojom::Transpose& transpose);
   std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
diff --git a/third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js b/third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
index 3b59c3bb49d64..64f37761bbacf 100644
--- a/third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
+++ b/third_party/blink/web_tests/external/wpt/webnn/conformance_tests/qdq_subgraph.https.any.js
@@ -924,6 +924,89 @@ const subgraphTests = [
       }
     }
   },
+  {
+    'name': 'quantized elu',
+    'graph': {
+      'inputs': {
+        'input': {
+          'data': [
+            1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
+            -0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
+          ],
+          'descriptor': {shape: [2, 3], dataType: 'float32'},
+          'constant': false
+        },
+        'inputScale': {
+          'data': [0.003921568859368563],
+          'descriptor': {shape: [1], dataType: 'float32'},
+          'constant': true
+        },
+        'inputZeroPoint': {
+          'data': [0],
+          'descriptor': {shape: [1], dataType: 'int8'},
+          'constant': true
+        },
+        'outputScale': {
+          'data': [0.003921568859368563],
+          'descriptor': {shape: [1], dataType: 'float32'},
+          'constant': true
+        },
+        'outputZeroPoint': {
+          'data': [0],
+          'descriptor': {shape: [1], dataType: 'int8'},
+          'constant': true
+        },
+      },
+      'operators': [
+        {
+          'name': 'quantizeLinear',
+          'arguments': [
+            {'input': 'input'},
+            {'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
+          ],
+          'outputs': 'quantizedInput'
+        },
+        {
+          'name': 'dequantizeLinear',
+          'arguments': [
+            {'input': 'quantizedInput'},
+            {'scale': 'inputScale', 'zeroPoint': 'inputZeroPoint'}
+          ],
+          'outputs': 'dequantizedInput'
+        },
+        {
+          'name': 'elu',
+          'arguments': [{'input': 'dequantizedInput'}],
+          'outputs': 'eluOutput'
+        },
+        {
+          'name': 'quantizeLinear',
+          'arguments': [
+            {'input': 'eluOutput'},
+            {'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
+          ],
+          'outputs': 'quantizedEluOutput'
+        },
+        {
+          'name': 'dequantizeLinear',
+          'arguments': [
+            {'input': 'quantizedEluOutput'},
+            {'scale': 'outputScale', 'zeroPoint': 'outputZeroPoint'}
+          ],
+          'outputs': 'output'
+        }
+      ],
+      'expectedOutputs': {
+        'output': {
+          'data': [
+            0.49803924560546875, 0.0470588281750679, 0.3333333432674408,
+            -0.18039216101169586, -0.003921568859368563, -0.062745101749897,
+          ],
+          'descriptor': {shape: [2, 3], dataType: 'float32'}
+        }
+      }
+    }
+  },
 ];
 
 if (navigator.ml) {
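For reference, the expected values in the new 'quantized elu' WPT case follow from the quantizeLinear -> dequantizeLinear -> elu -> quantizeLinear -> dequantizeLinear chain, with a per-tensor scale of about 1/255 and a zero point of 0 on both sides. The snippet below is a rough standalone sanity check of that arithmetic, not part of the WPT harness or the graph builder; the quantize/dequantize/elu helpers are invented for illustration and assume int8 clamping with round-to-nearest.

```js
// Standalone sketch (assumed helpers, not the WPT harness) that approximately
// reproduces the 'expectedOutputs' of the quantized elu test, up to float32
// rounding of the intermediate values.
const scale = 0.003921568859368563;  // ~1/255, shared by input and output.
const zeroPoint = 0;
const quantize = (x) =>
    Math.min(127, Math.max(-128, Math.round(x / scale) + zeroPoint));
const dequantize = (q) => (q - zeroPoint) * scale;
const elu = (x) => (x >= 0 ? x : Math.exp(x) - 1);  // alpha = 1 (default).

const input = [
  1.6811466217041016, 0.0479511022567749, 0.33355462551116943,
  -0.1988269537687301, -0.0041167140007019, -0.0634240251779556,
];
// quantizeLinear -> dequantizeLinear -> elu -> quantizeLinear ->
// dequantizeLinear, applied elementwise.
const output =
    input.map((x) => dequantize(quantize(elu(dequantize(quantize(x))))));
console.log(output);
```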