[WebNN] Make SerializeOperand() and SerializeOutputTensorInfo() return TensorInfo
Operand dimensions have already been validated to be within int32_t, so `SerializeOperand()` should always succeed and can return TensorInfo instead of base::expected<TensorInfo, std::string>; the same applies to `SerializeOutputTensorInfo()`. `SerializeInputTensorInfo()` can't be changed for now, because it calls `SerializeDequantizeLinear()`, which may return `base::unexpected`. This CL also moves the call to `SerializeQuantizeParams()` into `IsNextOpQuantize()`, makes it return both the `OperationId` and the `QuantizateParametersOffset`, and lets `SerializeQuantizedOutput()` always succeed. Bug: 417373852 Change-Id: I9236c9d22652be8cdfca7502692baa30fff15011 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6539392 Commit-Queue: Wei4 Wang <wei4.wang@intel.com> Reviewed-by: Phillis Tang <phillis@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Cr-Commit-Position: refs/heads/main@{#1459853}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
105eb7e31b
commit
0c34688f73
services/webnn/tflite
File diff suppressed because it is too large
Load Diff
@@ -143,10 +143,10 @@ class GraphBuilderTflite final {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Serialize tensor for input, constant and output operand and return the
|
// Serialize tensor for input, constant and output operand and return the
|
||||||
// tensor information if it's successful. The `override_tensor_type` is used
|
// tensor information. The `override_tensor_type` is used to override the
|
||||||
// to override the tensor type, such as when dequantising a float16 operator
|
// tensor type, such as when dequantising a float16 operator to float32 before
|
||||||
// to float32 before serializing an operator which does not support float32.
|
// serializing an operator which does not support float32.
|
||||||
base::expected<TensorInfo, std::string> SerializeOperand(
|
TensorInfo SerializeOperand(
|
||||||
OperandId operand_id,
|
OperandId operand_id,
|
||||||
QuantizateParametersOffset quantize_params,
|
QuantizateParametersOffset quantize_params,
|
||||||
std::optional<::tflite::TensorType> override_tensor_type = std::nullopt);
|
std::optional<::tflite::TensorType> override_tensor_type = std::nullopt);
|
||||||
@@ -169,7 +169,7 @@ class GraphBuilderTflite final {
|
|||||||
// as input, for example the input data type has been overridden to float32 of
|
// as input, for example the input data type has been overridden to float32 of
|
||||||
// intermediate operands (Reshape), so the output tensor type should be
|
// intermediate operands (Reshape), so the output tensor type should be
|
||||||
// float32 with the argument.
|
// float32 with the argument.
|
||||||
base::expected<TensorInfo, std::string> SerializeOutputTensorInfo(
|
TensorInfo SerializeOutputTensorInfo(
|
||||||
OperandId operand_id,
|
OperandId operand_id,
|
||||||
QuantizateParametersOffset quantize_params = 0,
|
QuantizateParametersOffset quantize_params = 0,
|
||||||
bool operation_supports_float16 = false,
|
bool operation_supports_float16 = false,
|
||||||
@@ -735,12 +735,12 @@ class GraphBuilderTflite final {
|
|||||||
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
std::optional<TensorInfo> CanFuseQuantizeAndGetOutput(
|
||||||
const mojom::LeakyRelu& leaky_relu);
|
const mojom::LeakyRelu& leaky_relu);
|
||||||
// Helper for activation operations to check if specific fusion criteria
|
// Helper for activation operations to check if specific fusion criteria
|
||||||
// required by TFLite are met and return next quantizeLinear operation id
|
// required by TFLite are met and return next quantizeLinear operation
|
||||||
// if so.
|
// information if so.
|
||||||
// This is shared by `tanh`, `sigmoid` and `leakyRelu`.
|
// This is shared by `tanh`, `sigmoid` and `leakyRelu`.
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
std::optional<OperationId> CanFuseQuantizeForActivationOperation(
|
std::optional<std::pair<OperationId, QuantizateParametersOffset>>
|
||||||
const OpType& op);
|
CanFuseQuantizeForActivationOperation(const OpType& op);
|
||||||
bool IsDequantizeOutput(OperandId operand_id);
|
bool IsDequantizeOutput(OperandId operand_id);
|
||||||
// Get the dequantize op by its output operand id.
|
// Get the dequantize op by its output operand id.
|
||||||
const mojom::DequantizeLinear& GetDequantizeOp(OperandId operand_id);
|
const mojom::DequantizeLinear& GetDequantizeOp(OperandId operand_id);
|
||||||
@@ -757,15 +757,15 @@ class GraphBuilderTflite final {
|
|||||||
bool TrySerializeQuantizedInput(
|
bool TrySerializeQuantizedInput(
|
||||||
const mojom::DequantizeLinear& dequantize_linear,
|
const mojom::DequantizeLinear& dequantize_linear,
|
||||||
OperationId operation_index);
|
OperationId operation_index);
|
||||||
// Try to serialize `quantize_linear`'s output with quantization params and
|
// Serialize `quantize_linear`'s output with quantization params and
|
||||||
// mark the `quantize_linear` to be skipped.
|
// mark the `quantize_linear` to be skipped.
|
||||||
std::optional<TensorInfo> TrySerializeQuantizedOutput(
|
TensorInfo SerializeQuantizedOutput(
|
||||||
OperationId quantize_op_idx);
|
std::pair<OperationId, QuantizateParametersOffset> quantize_op_info);
|
||||||
// Check if next op is quantize, if so mark it to-be skipped and return the
|
// Check if next op is quantize and its parameters can be serialized, if so
|
||||||
// quantized output.
|
// mark it to-be skipped and return the quantized output.
|
||||||
std::optional<OperationId> IsNextOpQuantize(
|
std::optional<std::pair<OperationId, QuantizateParametersOffset>>
|
||||||
OperandId output_operand_id,
|
IsNextOpQuantize(OperandId output_operand_id,
|
||||||
SupportedDataTypes supported_quantized_types);
|
SupportedDataTypes supported_quantized_types);
|
||||||
// Check if the input of DequantizeLinear is (u)int8, the output of
|
// Check if the input of DequantizeLinear is (u)int8, the output of
|
||||||
// QuantizeLinear has been validated (u)int8 in `IsNextOpQuantize`, and its
|
// QuantizeLinear has been validated (u)int8 in `IsNextOpQuantize`, and its
|
||||||
// scale and zero point are scalar values.
|
// scale and zero point are scalar values.
|
||||||
|
Reference in New Issue
Block a user