0

webnn: Use TFLite DEQUANTIZE op when casting float16 to float32

A WebNN graph passing float16 weights to a float32 operator needs to
use a cast() op; however, TFLite and its delegates expect that cast to
be done with the DEQUANTIZE op. This change adds a special case so that
a float16 to float32 cast is serialized as DEQUANTIZE instead of CAST.
Note that TFLite doesn't support using QUANTIZE for float32 to float16
conversions, so casts in that direction still use CAST.

Change-Id: I7aac5e2d594f30433362cfaf0476d3342cec00c9
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6557825
Commit-Queue: Reilly Grant <reillyg@chromium.org>
Auto-Submit: Reilly Grant <reillyg@chromium.org>
Reviewed-by: Phillis Tang <phillis@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1461671}
This commit is contained in:
Reilly Grant
2025-05-16 15:52:08 -07:00
committed by Chromium LUCI CQ
parent 6d6e5eb50f
commit 1500c9f33b

@ -2445,15 +2445,24 @@ auto GraphBuilderTflite::SerializeCastOperation(
::tflite::TensorType input_tensor_type,
TensorIndex output_tensor_index,
::tflite::TensorType output_tensor_type) -> OperatorOffset {
const auto cast_options = ::tflite::CreateCastOptions(
builder_, input_tensor_type, output_tensor_type);
const OperatorCodeIndex operator_code_index =
GetOperatorCodeIndex(::tflite::BuiltinOperator_CAST);
const std::array<TensorIndex, 1> op_inputs = {input_tensor_index};
const std::array<TensorIndex, 1> op_outputs = {output_tensor_index};
if (input_tensor_type == ::tflite::TensorType_FLOAT16 &&
output_tensor_type == ::tflite::TensorType_FLOAT32) {
// TFLite expects the DEQUANTIZE operator to be used to pass float16
// weights to float32 operators, but WebNN represents this with the cast
// operator.
return ::tflite::CreateOperator(
builder_, GetOperatorCodeIndex(::tflite::BuiltinOperator_DEQUANTIZE),
builder_.CreateVector<TensorIndex>(op_inputs),
builder_.CreateVector<TensorIndex>(op_outputs));
}
const auto cast_options = ::tflite::CreateCastOptions(
builder_, input_tensor_type, output_tensor_type);
return ::tflite::CreateOperator(
builder_, operator_code_index,
builder_, GetOperatorCodeIndex(::tflite::BuiltinOperator_CAST),
builder_.CreateVector<TensorIndex>(op_inputs),
builder_.CreateVector<TensorIndex>(op_outputs),
::tflite::BuiltinOptions_CastOptions, cast_options.Union());