webnn: Use TFLite DEQUANTIZE op when casting float16 to float32
A WebNN graph passing float16 weights to a float32 operator needs to use a cast() op; however, TFLite and its delegates expect that cast to be done with the DEQUANTIZE op. This change adds a special case for this scenario. Note that TFLite doesn't support using QUANTIZE for float32 to float16 conversions. Change-Id: I7aac5e2d594f30433362cfaf0476d3342cec00c9 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6557825 Commit-Queue: Reilly Grant <reillyg@chromium.org> Auto-Submit: Reilly Grant <reillyg@chromium.org> Reviewed-by: Phillis Tang <phillis@chromium.org> Cr-Commit-Position: refs/heads/main@{#1461671}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
6d6e5eb50f
commit
1500c9f33b
@ -2445,15 +2445,24 @@ auto GraphBuilderTflite::SerializeCastOperation(
|
||||
::tflite::TensorType input_tensor_type,
|
||||
TensorIndex output_tensor_index,
|
||||
::tflite::TensorType output_tensor_type) -> OperatorOffset {
|
||||
const auto cast_options = ::tflite::CreateCastOptions(
|
||||
builder_, input_tensor_type, output_tensor_type);
|
||||
|
||||
const OperatorCodeIndex operator_code_index =
|
||||
GetOperatorCodeIndex(::tflite::BuiltinOperator_CAST);
|
||||
const std::array<TensorIndex, 1> op_inputs = {input_tensor_index};
|
||||
const std::array<TensorIndex, 1> op_outputs = {output_tensor_index};
|
||||
|
||||
if (input_tensor_type == ::tflite::TensorType_FLOAT16 &&
|
||||
output_tensor_type == ::tflite::TensorType_FLOAT32) {
|
||||
// TFLite expects the DEQUANTIZE operator to be used to pass float16
|
||||
// weights to float32 operators, but WebNN represents this with the cast
|
||||
// operator.
|
||||
return ::tflite::CreateOperator(
|
||||
builder_, GetOperatorCodeIndex(::tflite::BuiltinOperator_DEQUANTIZE),
|
||||
builder_.CreateVector<TensorIndex>(op_inputs),
|
||||
builder_.CreateVector<TensorIndex>(op_outputs));
|
||||
}
|
||||
|
||||
const auto cast_options = ::tflite::CreateCastOptions(
|
||||
builder_, input_tensor_type, output_tensor_type);
|
||||
return ::tflite::CreateOperator(
|
||||
builder_, operator_code_index,
|
||||
builder_, GetOperatorCodeIndex(::tflite::BuiltinOperator_CAST),
|
||||
builder_.CreateVector<TensorIndex>(op_inputs),
|
||||
builder_.CreateVector<TensorIndex>(op_outputs),
|
||||
::tflite::BuiltinOptions_CastOptions, cast_options.Union());
|
||||
|
Reference in New Issue
Block a user