0

Revert "WebNN: support constant tensors"

This reverts commit fd00312e80.

Reason for revert: See bug

Bug: 418078503
Original change's description:
> WebNN: support constant tensors
>
> Allows MLTensor to be input to constant() so weights can be
> reused on-device between multiple builds on the same builder
> or different builders. This eliminates the need to keep the original
> JS input data and lowers CPU memory usage.
>
> To keep the CL size in check, only the DML backend was enabled.
>
> More specifically:
> * Adds constant usage to MLTensor.
> * Allows tensors to be initialized from a supplied JS buffer.
> * Supports graph builds using weights from tensors.
>
> Restrictions:
> * Constant tensors cannot be dispatched.
> * Constant tensors must be initialized.
> * Constant tensors must remain static.
>
> https://github.com/webmachinelearning/webnn/issues/760
>
> Bug: 332350952
> Change-Id: Ib18dfe06ead6728172355f2a540e3faeec99917b
> Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6075601
> Reviewed-by: Alex Gough <ajgo@chromium.org>
> Reviewed-by: Reilly Grant <reillyg@chromium.org>
> Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
> Cr-Commit-Position: refs/heads/main@{#1460981}

Bug: 332350952
No-Presubmit: true
No-Tree-Checks: true
No-Try: true
Change-Id: I01c02f3fafef3d68f483f88e30fc5fc7a5d70740
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6552231
Bot-Commit: Rubber Stamper <rubber-stamper@appspot.gserviceaccount.com>
Owners-Override: Stefan Zager <szager@google.com>
Cr-Commit-Position: refs/heads/main@{#1461097}
This commit is contained in:
Stefan Zager
2025-05-15 17:46:06 -07:00
parent da58b857d8
commit 2740ff60b0
48 changed files with 89 additions and 642 deletions

@ -47,7 +47,6 @@ class API_AVAILABLE(macos(14.0)) ContextImplCoreml final
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) override;
void CreateTensorImpl(

@ -41,26 +41,17 @@ void ContextImplCoreml::CreateGraphImpl(
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) {
GraphImplCoreml::CreateAndBuild(
std::move(receiver), this, std::move(graph_info),
std::move(compute_resource_info), std::move(constant_operands),
std::move(constant_tensor_operands), options().Clone(), properties(),
std::move(callback));
options().Clone(), properties(), std::move(callback));
}
void ContextImplCoreml::CreateTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info,
CreateTensorImplCallback callback) {
// TODO(crbug.com/332350952): implement constant tensors for CoreML.
if (tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant)) {
std::move(callback).Run(base::unexpected(
mojom::Error::New(mojom::Error::Code::kNotSupportedError,
"Creation of constant tensors is not supported.")));
return;
}
std::move(callback).Run(TensorImplCoreml::Create(std::move(receiver), this,
std::move(tensor_info)));
}

@ -48,7 +48,6 @@ class API_AVAILABLE(macos(14.0)) GraphImplCoreml final : public WebNNGraphImpl {
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
mojom::CreateContextOptionsPtr context_options,
ContextProperties context_properties,
WebNNContextImpl::CreateGraphImplCallback callback);

@ -374,7 +374,6 @@ void GraphImplCoreml::CreateAndBuild(
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
mojom::CreateContextOptionsPtr context_options,
ContextProperties context_properties,
WebNNContextImpl::CreateGraphImplCallback callback) {

@ -613,7 +613,6 @@ void ContextImplDml::CreateGraphImpl(
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
WebNNContextImpl::CreateGraphImplCallback callback) {
if (g_backend_for_testing) {
g_backend_for_testing->CreateGraphImpl(std::move(receiver), this,
@ -625,8 +624,7 @@ void ContextImplDml::CreateGraphImpl(
GraphImplDml::CreateAndBuild(
std::move(receiver), adapter_, weak_factory_.GetWeakPtr(),
std::move(graph_info), std::move(compute_resource_info),
std::move(constant_operands), std::move(constant_tensor_operands),
std::move(callback),
std::move(constant_operands), std::move(callback),
gpu_feature_info_->IsWorkaroundEnabled(
gpu::DISABLE_DML_META_COMMANDS_FOR_GPU));
}

@ -90,7 +90,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) ContextImplDml final
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) override;
void CreateTensorImpl(

@ -123,7 +123,7 @@ CreateTensorSuccess CreateWebNNTensor(
*OperandDescriptor::Create(webnn::GetContextPropertiesForTesting(),
data_type, shape, "tensor"),
MLTensorUsage{MLTensorUsageFlags::kWrite, MLTensorUsageFlags::kRead}),
mojo_base::BigBuffer(0), create_tensor_future.GetCallback());
create_tensor_future.GetCallback());
mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();
mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;

@ -5970,7 +5970,6 @@ void GraphImplDml::OnCompilationComplete(
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
base::expected<ComPtr<IDMLCompiledOperator>, HRESULT> compilation_result) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::OnCompilationComplete");
@ -6114,23 +6113,6 @@ void GraphImplDml::OnCompilationComplete(
std::move(buffer_binding);
}
}
// The tensors used for constants must be bound during operator initialization
// and not during execution.
for (auto& [constant_id, constant_tensor] : constant_tensor_operands) {
TensorImplDml* constant_tensor_impl =
static_cast<TensorImplDml*>(constant_tensor);
// Get the graph input index with the constant id.
const auto graph_input_index_iterator =
constant_id_to_input_index_map.find(constant_id);
CHECK(graph_input_index_iterator != constant_id_to_input_index_map.end());
input_buffer_binding[graph_input_index_iterator->second] =
DML_BUFFER_BINDING{
.Buffer = constant_tensor_impl->buffer(),
.Offset = 0,
.SizeInBytes = constant_tensor_impl->PackedByteLength()};
}
DML_BUFFER_ARRAY_BINDING input_buffer_array_binding{
.BindingCount = base::checked_cast<uint32_t>(input_buffer_binding.size()),
.Bindings = input_buffer_binding.data()};
@ -6330,7 +6312,6 @@ base::expected<void, mojom::ErrorPtr> GraphImplDml::CreateAndBuildInternal(
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
const base::flat_map<OperandId, WebNNTensorImpl*>& constant_tensor_operands,
GraphBuilderDml& graph_builder,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map,
GraphBufferBindingInfo& graph_buffer_binding_info) {
@ -6363,19 +6344,6 @@ base::expected<void, mojom::ErrorPtr> GraphImplDml::CreateAndBuildInternal(
constant_id_to_input_index_map);
}
// Add constant tensors which are considered read-only inputs that must be
// bound during graph initialization.
for (const auto& [constant_id, tensor_impl] : constant_tensor_operands) {
const Node* node = graph_builder.CreateInputNode();
constant_id_to_input_index_map[constant_id] =
node->AsInputNode()->GetGraphInputIndex();
TensorDesc tensor_desc(GetTensorDataType(tensor_impl->data_type()),
DML_TENSOR_FLAG_OWNED_BY_DML, tensor_impl->shape());
const NodeOutput* output =
graph_builder.CreateNodeOutput(node, std::move(tensor_desc));
CHECK(id_to_node_output_map.try_emplace(constant_id, output).second);
}
// Fuse the operations in `mojom::GraphInfo` wherever possible to optimize the
// graph's compute performance.
//
@ -6805,7 +6773,6 @@ void GraphImplDml::CreateAndBuild(
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
WebNNContextImpl::CreateGraphImplCallback callback,
const bool disable_dml_meta_commands_for_gpu) {
TRACE_EVENT0("gpu", "dml::GraphImplDml::CreateAndBuild");
@ -6815,8 +6782,8 @@ void GraphImplDml::CreateAndBuild(
base::expected<void, mojom::ErrorPtr> create_operator_result =
GraphImplDml::CreateAndBuildInternal(
context->properties(), adapter, graph_info, constant_operands,
constant_tensor_operands, graph_builder,
constant_id_to_input_index_map, graph_buffer_binding_info);
graph_builder, constant_id_to_input_index_map,
graph_buffer_binding_info);
// TODO(crbug.com/349649099): Handle context lost for operator creation
// failures.
@ -6847,8 +6814,7 @@ void GraphImplDml::CreateAndBuild(
std::move(adapter), std::move(context), std::move(callback),
std::move(constant_id_to_input_index_map),
std::move(graph_buffer_binding_info),
std::move(compute_resource_info), std::move(constant_operands),
std::move(constant_tensor_operands)));
std::move(compute_resource_info), std::move(constant_operands)));
}
void GraphImplDml::HandleDispatchFailure(std::string_view error_message,

@ -80,8 +80,6 @@ class GraphImplDml final : public WebNNGraphImpl {
mojom::GraphInfoPtr& graph_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>&
constant_operands,
const base::flat_map<OperandId, WebNNTensorImpl*>&
constant_tensor_operands,
GraphBuilderDml& graph_builder,
absl::flat_hash_map<OperandId, uint32_t>& constant_id_to_input_index_map,
GraphBufferBindingInfo& graph_buffer_binding_info);
@ -100,7 +98,6 @@ class GraphImplDml final : public WebNNGraphImpl {
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
WebNNContextImpl::CreateGraphImplCallback callback,
bool disable_dml_meta_commands_for_gpu);
@ -231,7 +228,6 @@ class GraphImplDml final : public WebNNGraphImpl {
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
base::expected<Microsoft::WRL::ComPtr<IDMLCompiledOperator>, HRESULT>
compilation_result);

@ -19,11 +19,8 @@ enum class MLTensorUsageFlags {
// This tensor can be used with writeTensor().
kWrite,
// This tensor can be used with constant().
kGraphConstant,
kMinValue = kWebGpuInterop,
kMaxValue = kGraphConstant,
kMaxValue = kWrite,
};
using MLTensorUsage = base::EnumSet<MLTensorUsageFlags,

@ -25,10 +25,6 @@ struct StructTraits<webnn::mojom::TensorUsageDataView, webnn::MLTensorUsage> {
return usage.Has(webnn::MLTensorUsageFlags::kRead);
}
static bool graph_constant(const webnn::MLTensorUsage& usage) {
return usage.Has(webnn::MLTensorUsageFlags::kGraphConstant);
}
static bool Read(webnn::mojom::TensorUsageDataView data,
webnn::MLTensorUsage* out) {
out->Clear();
@ -45,13 +41,6 @@ struct StructTraits<webnn::mojom::TensorUsageDataView, webnn::MLTensorUsage> {
out->Put(webnn::MLTensorUsageFlags::kWrite);
}
if (data.graph_constant()) {
if (data.read() || data.write()) {
return false;
}
out->Put(webnn::MLTensorUsageFlags::kGraphConstant);
}
return true;
}
};

@ -4,7 +4,6 @@
module webnn.mojom;
import "mojo/public/mojom/base/big_buffer.mojom";
import "services/webnn/public/mojom/features.mojom";
import "services/webnn/public/mojom/webnn_tensor.mojom";
import "services/webnn/public/mojom/webnn_error.mojom";
@ -39,8 +38,5 @@ interface WebNNContext {
// Called by the renderer process to create `WebNNTensor` message pipe for
// creating platform specific tensors, the WebNN tensor will be validated and
// created. This method guarantees memory allocation on the device.
// Optionally, non-empty tensor data containing values to initialize contents.
// Valid for tensor data to be empty when not being used as graph constants.
CreateTensor(TensorInfo tensor_info, mojo_base.mojom.BigBuffer tensor_data)
=> (CreateTensorResult result);
CreateTensor(TensorInfo tensor_info) => (CreateTensorResult result);
};

@ -1427,10 +1427,6 @@ struct GraphInfo {
// which identifies the respective pending constant operand.
map<uint32, blink.mojom.WebNNPendingConstantToken>
constant_operand_ids_to_handles;
// A map of tensors used for graph constants. The key is the id of
// the constant operand, while the value is a handle to the tensor containing
// the weights.
map<uint32, blink.mojom.WebNNTensorToken> id_to_constant_tensor_operand_map;
};
// WebNNGraph runs in the GPU process and is called by the renderer process to

@ -16,9 +16,6 @@ struct TensorUsage {
bool read;
// This tensor can be used with writeTensor().
bool write;
// This tensor is only allowed to be used as a graph constant.
// A graph constant cannot be modified after it is created.
bool graph_constant;
};
// Description of the WebNNTensor to create.

@ -38,25 +38,16 @@ void ContextImplTflite::CreateGraphImpl(
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) {
std::move(callback).Run(GraphImplTflite::CreateAndBuild(
std::move(receiver), std::move(graph_info),
std::move(compute_resource_info), std::move(constant_operands),
std::move(constant_tensor_operands), this));
std::move(compute_resource_info), std::move(constant_operands), this));
}
void ContextImplTflite::CreateTensorImpl(
mojo::PendingAssociatedReceiver<mojom::WebNNTensor> receiver,
mojom::TensorInfoPtr tensor_info,
CreateTensorImplCallback callback) {
// TODO(crbug.com/332350952): implement constant tensors for TFLite.
if (tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant)) {
std::move(callback).Run(base::unexpected(
mojom::Error::New(mojom::Error::Code::kNotSupportedError,
"Creation of constant tensors is not supported.")));
return;
}
std::move(callback).Run(TensorImplTflite::Create(std::move(receiver), this,
std::move(tensor_info)));
}

@ -39,7 +39,6 @@ class ContextImplTflite final : public WebNNContextImpl {
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) override;
void CreateTensorImpl(

@ -289,7 +289,6 @@ GraphImplTflite::CreateAndBuild(
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
ContextImplTflite* context) {
ASSIGN_OR_RETURN(GraphBuilderTflite::Result result,
GraphBuilderTflite::CreateAndBuild(

@ -40,7 +40,6 @@ class GraphImplTflite final : public WebNNGraphImpl {
ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
ContextImplTflite* context);
GraphImplTflite(const GraphImplTflite&) = delete;

@ -86,7 +86,6 @@ void WebNNContextImpl::CreateGraphBuilder(
void WebNNContextImpl::CreateTensor(
mojom::TensorInfoPtr tensor_info,
mojo_base::BigBuffer tensor_data,
mojom::WebNNContext::CreateTensorCallback callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
@ -95,40 +94,17 @@ void WebNNContextImpl::CreateTensor(
return;
}
if (tensor_info->usage.Has(MLTensorUsageFlags::kGraphConstant)) {
const base::expected<OperandDescriptor, std::string> validated_descriptor =
webnn::OperandDescriptor::Create(
properties_, tensor_info->descriptor.data_type(),
tensor_info->descriptor.shape(), "WebNNGraphConstant");
if (!validated_descriptor.has_value()) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);
return;
}
if (!properties_.data_type_limits.constant.Has(
validated_descriptor->data_type())) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);
return;
}
if (tensor_data.size() != validated_descriptor->PackedByteLength()) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);
return;
}
}
mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote;
auto receiver = remote.InitWithNewEndpointAndPassReceiver();
CreateTensorImpl(std::move(receiver), std::move(tensor_info),
base::BindOnce(&WebNNContextImpl::DidCreateWebNNTensorImpl,
AsWeakPtr(), std::move(callback),
std::move(remote), std::move(tensor_data)));
CreateTensorImpl(
std::move(receiver), std::move(tensor_info),
base::BindOnce(&WebNNContextImpl::DidCreateWebNNTensorImpl, AsWeakPtr(),
std::move(callback), std::move(remote)));
}
void WebNNContextImpl::DidCreateWebNNTensorImpl(
mojom::WebNNContext::CreateTensorCallback callback,
mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote,
mojo_base::BigBuffer tensor_data,
base::expected<std::unique_ptr<WebNNTensorImpl>, mojom::ErrorPtr> result) {
if (!result.has_value()) {
std::move(callback).Run(
@ -136,13 +112,6 @@ void WebNNContextImpl::DidCreateWebNNTensorImpl(
return;
}
// Write the specified values into the tensor. If `tensor_data` is empty,
// the tensor should be left initialized to zero. The `tensor_data` size
// should have already been validated in CreateTensor().
if (tensor_data.size() > 0) {
result.value()->WriteTensorImpl(std::move(tensor_data));
}
auto success = mojom::CreateTensorSuccess::New(std::move(remote),
result.value()->handle());
std::move(callback).Run(

@ -18,7 +18,6 @@
#include "base/types/expected.h"
#include "base/types/optional_ref.h"
#include "base/types/pass_key.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/pending_associated_receiver.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
@ -105,7 +104,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands,
CreateGraphImplCallback callback) = 0;
// Pass ownership of a newly-created `graph_impl` to this context.
@ -143,7 +141,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl
mojo::PendingAssociatedReceiver<mojom::WebNNGraphBuilder> receiver)
override;
void CreateTensor(mojom::TensorInfoPtr tensor_info,
mojo_base::BigBuffer tensor_data,
CreateTensorCallback callback) override;
// This method will be called by `CreateTensor()` after the tensor info is
@ -157,7 +154,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNContextImpl
void DidCreateWebNNTensorImpl(
CreateTensorCallback callback,
mojo::PendingAssociatedRemote<mojom::WebNNTensor> remote,
mojo_base::BigBuffer tensor_data,
base::expected<std::unique_ptr<WebNNTensorImpl>, mojom::ErrorPtr> result);
SEQUENCE_CHECKER(sequence_checker_);

@ -24,7 +24,6 @@
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_pending_constant_operand.h"
#include "services/webnn/webnn_tensor_impl.h"
#include "services/webnn/webnn_utils.h"
// Evaluate `condition`, and if it returns false then return false.
@ -2693,11 +2692,9 @@ bool OperationValidationContext::ValidateOperation(
WebNNGraphBuilderImpl::ValidateGraphSuccessResult::ValidateGraphSuccessResult(
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands)
constant_operands)
: compute_resource_info(std::move(compute_resource_info)),
constant_operands(std::move(constant_operands)),
constant_tensor_operands(std::move(constant_tensor_operands)) {}
constant_operands(std::move(constant_operands)) {}
WebNNGraphBuilderImpl::ValidateGraphSuccessResult::ValidateGraphSuccessResult(
ValidateGraphSuccessResult&&) = default;
@ -2783,7 +2780,6 @@ void WebNNGraphBuilderImpl::CreateGraph(mojom::GraphInfoPtr graph_info,
std::move(receiver), std::move(graph_info),
std::move(validate_graph_result->compute_resource_info),
std::move(validate_graph_result->constant_operands),
std::move(validate_graph_result->constant_tensor_operands),
base::BindOnce(&WebNNGraphBuilderImpl::DidCreateGraph,
weak_factory_.GetWeakPtr(), std::move(callback),
std::move(remote)));
@ -2871,9 +2867,6 @@ WebNNGraphBuilderImpl::ValidateGraphImpl(
std::vector<std::pair<OperandId, std::unique_ptr<WebNNConstantOperand>>>
graph_constants;
graph_constants.reserve(graph_info.constant_operand_ids_to_handles.size());
std::vector<std::pair<OperandId, WebNNTensorImpl*>> graph_constant_tensors;
graph_constant_tensors.reserve(
graph_info.id_to_constant_tensor_operand_map.size());
for (size_t id = 0; id < graph_info.operands.size(); ++id) {
const mojom::OperandPtr& operand = graph_info.operands[id];
@ -2931,33 +2924,6 @@ WebNNGraphBuilderImpl::ValidateGraphImpl(
return std::nullopt;
}
// Constants using tensors for weights.
if (auto id_and_handle_it =
graph_info.id_to_constant_tensor_operand_map.find(id);
id_and_handle_it !=
graph_info.id_to_constant_tensor_operand_map.end()) {
// `id` must correspond to a handle known by the context...
base::optional_ref<WebNNTensorImpl> tensor_impl =
context_->GetWebNNTensorImpl(id_and_handle_it->second);
if (!tensor_impl.has_value()) {
return std::nullopt;
}
// ...whose tensor must have the correct usage.
if (!tensor_impl->usage().Has(MLTensorUsageFlags::kGraphConstant)) {
return std::nullopt;
}
// ...whose data must be compatible with what `operand` expects.
if (!tensor_impl->IsValidWithDescriptor(operand->descriptor)) {
return std::nullopt;
}
graph_constant_tensors.emplace_back(id, tensor_impl.as_ptr());
processed_operands.insert(id);
break;
}
// `id` must correspond to a pending constant operand handle...
auto id_and_handle_it =
graph_info.constant_operand_ids_to_handles.find(id);
@ -3034,11 +3000,6 @@ WebNNGraphBuilderImpl::ValidateGraphImpl(
return std::nullopt;
}
if (graph_constant_tensors.size() !=
graph_info.id_to_constant_tensor_operand_map.size()) {
return std::nullopt;
}
// Validate the operations which are sorted in the topological order.
std::optional<OperationValidationContext::ValidationResult> result =
OperationValidationContext::ValidateOperationsAndGetDependencies(
@ -3074,7 +3035,7 @@ WebNNGraphBuilderImpl::ValidateGraphImpl(
std::move(result->operand_to_dependent_operations),
std::move(result->operand_to_producing_operation),
base::PassKey<WebNNGraphBuilderImpl>()),
std::move(graph_constants), std::move(graph_constant_tensors)};
std::move(graph_constants)};
}
void WebNNGraphBuilderImpl::DestroySelf() {

@ -31,7 +31,6 @@ namespace webnn {
class WebNNConstantOperand;
class WebNNContextImpl;
class WebNNTensorImpl;
// Services-side connection to an `MLGraphBuilder`. Responsible for managing
// data associated with the graph builder.
@ -67,8 +66,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNGraphBuilderImpl
ValidateGraphSuccessResult(
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands,
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands);
constant_operands);
~ValidateGraphSuccessResult();
ValidateGraphSuccessResult(const ValidateGraphSuccessResult&) = delete;
@ -85,10 +83,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNGraphBuilderImpl
// `keep_builder_resources_for_testing` is false.
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
constant_operands;
// Constant tensors associated with this graph, which will be used during
// graph construction.
base::flat_map<OperandId, WebNNTensorImpl*> constant_tensor_operands;
};
// Transfer ownership of this builder's resources to a returned

@ -92,8 +92,6 @@ class FakeWebNNContextImpl final : public WebNNContextImpl {
WebNNGraphImpl::ComputeResourceInfo compute_resource_info,
base::flat_map<OperandId, std::unique_ptr<WebNNConstantOperand>>
/*constant_operands*/,
base::flat_map<OperandId, WebNNTensorImpl*>
/*constant_tensor_operands*/,
CreateGraphImplCallback callback) override {
// Asynchronously resolve `callback` so there's an opportunity for
// subsequent messages to be (illegally) sent from the `WebNNGraphBuilder`

@ -128,14 +128,6 @@ void WebNNGraphImpl::Dispatch(
if (!input_tensor.has_value()) {
return;
}
// Input MLTensor is always dispatchable, which isn't allowed when used as
// a graph constant.
if (input_tensor->usage().Has(MLTensorUsageFlags::kGraphConstant)) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);
return;
}
name_to_input_tensors.emplace_back(name, input_tensor.as_ptr());
}
base::flat_map<std::string_view, WebNNTensorImpl*> name_to_input_tensor_map(
@ -158,14 +150,6 @@ void WebNNGraphImpl::Dispatch(
if (!output_tensor.has_value()) {
return;
}
// Output MLTensor is always dispatchable, which isn't allowed when used as
// a graph constant.
if (output_tensor->usage().Has(MLTensorUsageFlags::kGraphConstant)) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);
return;
}
name_to_output_tensors.emplace_back(name, output_tensor.as_ptr());
}

@ -84,7 +84,7 @@ TensorRemoteAndHandle CreateTensor(
mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor_remote;
base::test::TestFuture<mojom::CreateTensorResultPtr> create_tensor_future;
context_remote->CreateTensor(std::move(tensor_info), mojo_base::BigBuffer(0),
context_remote->CreateTensor(std::move(tensor_info),
create_tensor_future.GetCallback());
mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();

@ -121,7 +121,6 @@ class FakeWebNNContextImpl final : public WebNNContextImpl {
base::flat_map<
OperandId,
std::unique_ptr<WebNNConstantOperand>> /*constant_operands*/,
base::flat_map<OperandId, WebNNTensorImpl*> /*constant_tensor_operands*/,
CreateGraphImplCallback callback) override {
FakeWebNNGraphImpl::CreateAndBuild(std::move(receiver), this, *graph_info,
std::move(compute_resource_info),
@ -176,7 +175,7 @@ CreateTensorSuccess CreateWebNNTensor(
mojom::TensorInfo::New(
OperandDescriptor::UnsafeCreateForTesting(data_type, shape),
MLTensorUsage()),
mojo_base::BigBuffer(0), create_tensor_future.GetCallback());
create_tensor_future.GetCallback());
mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();
mojo::AssociatedRemote<mojom::WebNNTensor> webnn_tensor;

@ -152,7 +152,6 @@ void BuildGraph(const mojolpm::webnn::mojom::GraphInfo& graph_info_proto,
base::test::TestFuture<webnn::mojom::CreateTensorResultPtr>
create_tensor_future;
webnn_context_remote->CreateTensor(std::move(tensor_info),
mojo_base::BigBuffer(0),
create_tensor_future.GetCallback());
webnn::mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();
@ -187,7 +186,6 @@ void BuildGraph(const mojolpm::webnn::mojom::GraphInfo& graph_info_proto,
base::test::TestFuture<webnn::mojom::CreateTensorResultPtr>
create_tensor_future;
webnn_context_remote->CreateTensor(std::move(tensor_info),
mojo_base::BigBuffer(0),
create_tensor_future.GetCallback());
webnn::mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();

@ -26,11 +26,6 @@ WebNNTensorImpl::WebNNTensorImpl(
WebNNTensorImpl::~WebNNTensorImpl() = default;
bool WebNNTensorImpl::IsValidWithDescriptor(
const OperandDescriptor& descriptor) const {
return descriptor_ == descriptor;
}
void WebNNTensorImpl::ReadTensor(ReadTensorCallback callback) {
if (!usage().Has(MLTensorUsageFlags::kRead)) {
receiver_.ReportBadMessage(kBadMessageInvalidTensor);

@ -43,13 +43,6 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNTensorImpl
return weak_factory_.GetWeakPtr();
}
bool IsValidWithDescriptor(const OperandDescriptor& descriptor) const;
// This method will be called by `WriteTensor()` after the write info is
// validated. A backend subclass should implement this method to write data
// to a platform specific buffer.
virtual void WriteTensorImpl(mojo_base::BigBuffer src_buffer) = 0;
protected:
// This method will be called by `ReadTensor()` after the read info is
// validated. A backend subclass should implement this method to read data
@ -57,6 +50,11 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNTensorImpl
virtual void ReadTensorImpl(
mojom::WebNNTensor::ReadTensorCallback callback) = 0;
// This method will be called by `WriteTensor()` after the write info is
// validated. A backend subclass should implement this method to write data
// to a platform specific buffer.
virtual void WriteTensorImpl(mojo_base::BigBuffer src_buffer) = 0;
// WebNNContextImpl owns this object.
const raw_ptr<WebNNContextImpl> context_;

@ -202,7 +202,6 @@ CreateWebNNTensor(mojo::Remote<mojom::WebNNContext>& webnn_context_remote,
mojom::TensorInfoPtr tensor_info) {
base::test::TestFuture<mojom::CreateTensorResultPtr> create_tensor_future;
webnn_context_remote->CreateTensor(std::move(tensor_info),
mojo_base::BigBuffer(0),
create_tensor_future.GetCallback());
mojom::CreateTensorResultPtr create_tensor_result =
create_tensor_future.Take();
@ -313,7 +312,7 @@ TEST_F(WebNNTensorImplBackendTest, MAYBE_CreateTooLargeTensorTest) {
mojom::TensorInfo::New(OperandDescriptor::UnsafeCreateForTesting(
OperandDataType::kUint8, large_shape),
MLTensorUsage{MLTensorUsageFlags::kWrite}),
mojo_base::BigBuffer(0), std::move(create_tensor_callback));
std::move(create_tensor_callback));
webnn_context_remote.FlushForTesting();
EXPECT_EQ(bad_message_helper.GetLastBadMessage(), kBadMessageInvalidTensor);

@ -1025,7 +1025,7 @@ ScriptPromise<MLTensor> MLContext::createTensor(
//
// This assertion protects against the usage flags changing without updating
// this mapping.
static_assert(base::to_underlying(webnn::MLTensorUsageFlags::kMaxValue) == 3);
static_assert(base::to_underlying(webnn::MLTensorUsageFlags::kMaxValue) == 2);
webnn::MLTensorUsage usage;
if (descriptor->exportableToGPU()) {
usage.Put(webnn::MLTensorUsageFlags::kWebGpuInterop);
@ -1037,9 +1037,6 @@ ScriptPromise<MLTensor> MLContext::createTensor(
usage.Put(webnn::MLTensorUsageFlags::kWrite);
}
// MLTensorUsageFlags::kGraphConstant is only assigned for
// createConstantTensor().
auto tensor_info =
webnn::mojom::blink::TensorInfo::New(validated_descriptor, usage);
@ -1049,84 +1046,7 @@ ScriptPromise<MLTensor> MLContext::createTensor(
// Use `WebNNContext` to create `WebNNTensor` message pipe.
context_remote_->CreateTensor(
std::move(tensor_info), mojo_base::BigBuffer(0),
WTF::BindOnce(&MLContext::DidCreateWebNNTensor, WrapPersistent(this),
std::move(scoped_trace), WrapPersistent(resolver),
std::move(validated_descriptor), usage));
return resolver->Promise();
}
ScriptPromise<MLTensor> MLContext::createConstantTensor(
ScriptState* script_state,
const MLOperandDescriptor* descriptor,
AllowSharedBufferSource* src_data,
ExceptionState& exception_state) {
webnn::ScopedTrace scoped_trace("MLContext::createConstantTensor");
if (!script_state->ContextIsValid()) {
exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError,
"Invalid script state");
return EmptyPromise();
}
if (!base::FeatureList::IsEnabled(
webnn::mojom::features::kWebMachineLearningNeuralNetwork)) {
exception_state.ThrowDOMException(DOMExceptionCode::kNotSupportedError,
"Not implemented");
return EmptyPromise();
}
if (!context_remote_.is_bound()) {
exception_state.ThrowDOMException(DOMExceptionCode::kInvalidStateError,
"Context is lost.");
return EmptyPromise();
}
ASSIGN_OR_RETURN(
webnn::OperandDescriptor validated_descriptor,
webnn::OperandDescriptor::Create(
properties_, FromBlinkDataType(descriptor->dataType().AsEnum()),
descriptor->shape(), "constant_tensor"),
[&exception_state](std::string error) {
exception_state.ThrowTypeError(String(error));
return ScriptPromise<MLTensor>();
});
RETURN_IF_ERROR(webnn::ValidateTensor(properties_, validated_descriptor),
[&exception_state](std::string error) {
exception_state.ThrowTypeError(String(error));
return ScriptPromise<MLTensor>();
});
base::span<const uint8_t> bytes = AsByteSpan(*src_data);
if (validated_descriptor.PackedByteLength() != bytes.size()) {
exception_state.ThrowTypeError(
String::Format("The source data byte length (%zu) doesn't match the "
"expected byte length (%zu).",
bytes.size(), validated_descriptor.PackedByteLength()));
return ScriptPromise<MLTensor>();
}
if (!properties_.data_type_limits.constant.Has(
validated_descriptor.data_type())) {
exception_state.ThrowTypeError(String(webnn::NotSupportedConstantTypeError(
validated_descriptor.data_type(),
properties_.data_type_limits.constant)));
return ScriptPromise<MLTensor>();
}
webnn::MLTensorUsage usage =
webnn::MLTensorUsage{webnn::MLTensorUsageFlags::kGraphConstant};
auto tensor_info =
webnn::mojom::blink::TensorInfo::New(validated_descriptor, usage);
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLTensor>>(
script_state, exception_state.GetContext());
pending_resolvers_.insert(resolver);
// Use `WebNNContext` to create `WebNNTensor` message pipe.
context_remote_->CreateTensor(
std::move(tensor_info), bytes,
std::move(tensor_info),
WTF::BindOnce(&MLContext::DidCreateWebNNTensor, WrapPersistent(this),
std::move(scoped_trace), WrapPersistent(resolver),
std::move(validated_descriptor), usage));

@ -77,12 +77,6 @@ class MODULES_EXPORT MLContext : public ScriptWrappable {
const MLTensorDescriptor* descriptor,
ExceptionState& exception_state);
ScriptPromise<MLTensor> createConstantTensor(
ScriptState* script_state,
const MLOperandDescriptor* descriptor,
AllowSharedBufferSource* src_data,
ExceptionState& exception_state);
void writeTensor(ScriptState* script_state,
MLTensor* dst_tensor,
AllowSharedBufferSource* src_data,

@ -299,14 +299,6 @@ typedef record<USVString, MLTensor> MLNamedTensors;
RaisesException
] Promise<MLTensor> createTensor(MLTensorDescriptor descriptor);
[
RuntimeEnabled=MachineLearningNeuralNetwork,
CallWith=ScriptState,
RaisesException
] Promise<MLTensor> createConstantTensor(
MLOperandDescriptor descriptor,
AllowSharedBufferSource sourceData);
[
RuntimeEnabled=MachineLearningNeuralNetwork,
CallWith=ScriptState,

@ -7,7 +7,6 @@
#include "services/webnn/public/mojom/webnn_graph.mojom-blink.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_tensor.h"
namespace blink {
@ -17,16 +16,9 @@ MLConstantOperand::MLConstantOperand(MLGraphBuilder* builder,
webnn::mojom::blink::Operand::Kind::kConstant,
std::move(descriptor)) {}
MLConstantOperand::MLConstantOperand(MLGraphBuilder* builder, MLTensor* tensor)
: MLOperand(builder,
webnn::mojom::blink::Operand::Kind::kConstant,
tensor->Descriptor()),
tensor_(tensor) {}
MLConstantOperand::~MLConstantOperand() = default;
void MLConstantOperand::Trace(Visitor* visitor) const {
visitor->Trace(tensor_);
MLOperand::Trace(visitor);
}

@ -14,7 +14,6 @@
namespace blink {
class MLGraphBuilder;
class MLTensor;
// Represents an `MLOperand` created from the `MLGraphBuilder.constant()`
// method. See https://www.w3.org/TR/webnn/#api-mlgraphbuilder-constant.
@ -27,9 +26,6 @@ class MODULES_EXPORT MLConstantOperand final : public MLOperand {
MLConstantOperand(MLGraphBuilder* builder,
webnn::OperandDescriptor descriptor);
// Similar to above but uses a tensor for weight data.
MLConstantOperand(MLGraphBuilder* builder, MLTensor* tensor);
MLConstantOperand(const MLConstantOperand&) = delete;
MLConstantOperand& operator=(const MLConstantOperand&) = delete;
@ -39,13 +35,9 @@ class MODULES_EXPORT MLConstantOperand final : public MLOperand {
const WebNNPendingConstantToken& handle() const { return handle_; }
const MLTensor* tensor() const { return tensor_; }
private:
// Identifies this constant operand in the WebNN service.
const WebNNPendingConstantToken handle_;
Member<MLTensor> tensor_;
};
} // namespace blink

@ -186,11 +186,6 @@ void MLGraph::Dispatch(webnn::ScopedTrace scoped_trace,
return;
}
if (input_tensor->Usage().Has(webnn::MLTensorUsageFlags::kGraphConstant)) {
exception_state.ThrowTypeError("Invalid input tensor usage");
return;
}
mojo_inputs.insert(name, input_tensor->handle());
}
@ -202,11 +197,6 @@ void MLGraph::Dispatch(webnn::ScopedTrace scoped_trace,
return;
}
if (output_tensor->Usage().Has(webnn::MLTensorUsageFlags::kGraphConstant)) {
exception_state.ThrowTypeError("Invalid output tensor usage");
return;
}
mojo_outputs.insert(name, output_tensor->handle());
}

@ -69,7 +69,6 @@
#include "third_party/blink/renderer/modules/ml/webnn/ml_graph_utils.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operand.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_operator.h"
#include "third_party/blink/renderer/modules/ml/webnn/ml_tensor.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
#include "third_party/blink/renderer/platform/bindings/script_state.h"
@ -1470,16 +1469,8 @@ blink_mojom::GraphInfoPtr BuildWebNNGraphInfo(
webnn::OperandId operand_id = AddOperand(
*graph_info,
mojo::ConvertTo<blink_mojom::OperandPtr>(operand.Get()));
// Build the map of constant operands for this graph with the id.
MLConstantOperand const* constant_operand =
operand->AsConstantOperand();
if (constant_operand->tensor()) {
graph_info->id_to_constant_tensor_operand_map.insert(
operand_id, constant_operand->tensor()->handle());
} else {
graph_info->constant_operand_ids_to_handles.insert(
operand_id, operand->AsConstantOperand()->handle());
}
graph_info->constant_operand_ids_to_handles.insert(
operand_id, operand->AsConstantOperand()->handle());
operand_to_id_map.insert(operand, operand_id);
break;
}
@ -1731,33 +1722,6 @@ MLOperand* MLGraphBuilder::constant(ScriptState* script_state,
return constant;
}
MLOperand* MLGraphBuilder::constant(ScriptState* script_state,
MLTensor* tensor,
ExceptionState& exception_state) {
THROW_AND_RETURN_IF_ERROR(ValidateGraphBuilderState(), nullptr);
if (tensor->context() != ml_context_) {
exception_state.ThrowTypeError(
"The tensor wasn't created with this context.");
return nullptr;
}
if (!tensor->IsValid()) {
exception_state.ThrowDOMException(
DOMExceptionCode::kInvalidStateError,
"Tensor has been destroyed or context is lost.");
return nullptr;
}
if (!tensor->Usage().Has(webnn::MLTensorUsageFlags::kGraphConstant)) {
exception_state.ThrowTypeError(
"Tensor was not created by createConstantTensor.");
return nullptr;
}
return MakeGarbageCollected<MLConstantOperand>(this, tensor);
}
MLOperand* MLGraphBuilder::argMin(MLOperand* input,
const uint32_t axis,
const MLArgMinMaxOptions* options,

@ -54,7 +54,6 @@ class MLReverseOptions;
class MLScatterOptions;
class MLSliceOptions;
class MLSplitOptions;
class MLTensor;
class MLTransposeOptions;
class MLTriangularOptions;
class MLOperand;
@ -102,9 +101,6 @@ class MODULES_EXPORT MLGraphBuilder final : public ScriptWrappable {
const MLOperandDescriptor* desc,
AllowSharedBufferSource* buffer,
ExceptionState& exception_state);
MLOperand* constant(ScriptState* script_state,
MLTensor* tensor,
ExceptionState& exception_state);
// The order of operations declaration is the same as spec.
MLOperand* argMin(MLOperand* input,

@ -241,11 +241,6 @@ dictionary MLTriangularOptions : MLOperatorOptions {
RaisesException
] MLOperand constant(MLOperandDescriptor desc, AllowSharedBufferSource buffer);
[
CallWith=ScriptState,
RaisesException
] MLOperand constant(MLTensor tensor);
[RaisesException] MLOperand argMin(MLOperand input, [EnforceRange] unsigned long axis, optional MLArgMinMaxOptions options = {});
[RaisesException] MLOperand argMax(MLOperand input, [EnforceRange] unsigned long axis, optional MLArgMinMaxOptions options = {});

@ -499,7 +499,6 @@ class FakeWebNNContext : public blink_mojom::WebNNContext {
}
void CreateTensor(blink_mojom::TensorInfoPtr tensor_info,
mojo_base::BigBuffer tensor_data,
CreateTensorCallback callback) override {
mojo::PendingAssociatedRemote<blink_mojom::WebNNTensor> blink_remote;
auto blink_receiver = blink_remote.InitWithNewEndpointAndPassReceiver();

@ -76,10 +76,6 @@ bool MLTensor::writable() const {
return usage_.Has(webnn::MLTensorUsageFlags::kWrite);
}
bool MLTensor::constant() const {
return usage_.Has(webnn::MLTensorUsageFlags::kGraphConstant);
}
void MLTensor::destroy() {
// Calling OnConnectionError() will disconnect and destroy the tensor in
// the service. The remote tensor must remain unbound after calling

@ -60,7 +60,6 @@ class MODULES_EXPORT MLTensor : public ScriptWrappable {
bool exportableToGPU() const;
bool readable() const;
bool writable() const;
bool constant() const;
void destroy();

@ -13,7 +13,6 @@
readonly attribute boolean exportableToGPU;
readonly attribute boolean readable;
readonly attribute boolean writable;
readonly attribute boolean constant;
void destroy();
};

@ -128,127 +128,6 @@ const testCreateTensorFails = (testName, tensorDescriptor) => {
}, `${testName} / ${tensorDescriptor.dataType}`);
};
/**
* WebNN create constant tensor test.
* @param {String} testName - The name of the test operation.
* @param {MLOperandDescriptor} descriptor - The intended operand specs.
*/
const testCreateConstantTensor = (testName, descriptor) => {
let mlContext;
let isConstantTensorSupported = false;
promise_setup(async () => {
try {
mlContext = await navigator.ml.createContext(contextOptions);
} catch (error) {
throw new AssertionError(
`Unable to create context for ${variant} variant. ${error}`);
}
// Check if WebNN has constant tensor support.
try {
await mlContext.createConstantTensor(
{
dataType: 'float32',
shape: [1],
},
new Float32Array([0xAA]));
isConstantTensorSupported = true;
} catch (error) {
if (error.name !== 'NotSupportedError') {
throw error;
}
}
});
promise_test(async t => {
if (!isConstantTensorSupported) {
return;
}
const inputData =
new TypedArrayDict[descriptor.dataType](sizeOfShape(descriptor.shape))
.fill(0xAA);
if (!mlContext.opSupportLimits().constant.dataTypes.includes(
descriptor.dataType)) {
await promise_rejects_js(
t, TypeError, mlContext.createConstantTensor(descriptor, inputData));
return;
}
const mlTensor =
await mlContext.createConstantTensor(descriptor, inputData);
assert_true(mlTensor.constant, 'constant tensors should be constant.');
assert_false(mlTensor.readable, 'constant tensors should not be readable.');
assert_false(mlTensor.writable, 'constant tensors should not be writable.');
}, `${testName} / ${descriptor.dataType}`);
promise_test(async t => {
if (!isConstantTensorSupported) {
return;
}
try {
const inputDataTooBig = new TypedArrayDict[descriptor.dataType](
sizeOfShape(descriptor.shape) + 1);
await promise_rejects_js(
t, TypeError,
mlContext.createConstantTensor(descriptor, inputDataTooBig));
} catch (error) {
if (error instanceof RangeError) {
return; // Skip test when dataType is too big.
} else {
throw error;
}
}
}, `${testName} / ${descriptor.dataType} / source data too big`);
promise_test(async t => {
if (!isConstantTensorSupported) {
return;
}
try {
const inputDataTooSmall = new TypedArrayDict[descriptor.dataType](
sizeOfShape(descriptor.shape) - 1);
await promise_rejects_js(
t, TypeError,
mlContext.createConstantTensor(descriptor, inputDataTooSmall));
} catch (error) {
if (error instanceof RangeError) {
return; // Skip test when dataType is too big.
} else {
throw error;
}
}
}, `${testName} / ${descriptor.dataType} / source data too small`);
};
/**
* Same as above, but expect constant tensor creation to fail.
* @param {String} testName - The name of the test operation.
* @param {MLOperandDescriptor} descriptor - The intended operand specs.
*/
const testCreateConstantTensorFails = (testName, descriptor) => {
let mlContext;
promise_setup(async () => {
try {
mlContext = await navigator.ml.createContext(contextOptions);
} catch (error) {
throw new AssertionError(
`Unable to create context for ${variant} variant. ${error}`);
}
});
promise_test(async t => {
await promise_rejects_js(
t, TypeError,
mlContext.createConstantTensor(
descriptor,
new TypedArrayDict[descriptor.dataType](
sizeOfShape(descriptor.shape))));
}, `${testName} / ${descriptor.dataType}`);
};
promise_test(async t => {
const tensorDescriptor = {
@ -545,7 +424,6 @@ const testDispatchTensor = (testName) => {
const shape = [3, 5];
let inputs = {};
let outputs = {};
let isConstantTensorSupported = false;
promise_setup(async () => {
try {
mlContext = await navigator.ml.createContext(contextOptions);
@ -553,22 +431,6 @@ const testDispatchTensor = (testName) => {
throw new AssertionError(
`Unable to create context for ${variant} variant. ${e}`);
}
// Check if WebNN has constant tensor support.
try {
await mlContext.createConstantTensor(
{
dataType: 'float32',
shape: [1],
},
new Float32Array([0xAA]));
isConstantTensorSupported = true;
} catch (error) {
if (error.name !== 'NotSupportedError') {
throw error;
}
}
// Construct a simple graph: A = B + C, with two outputs.
const builder = new MLGraphBuilder(mlContext);
const tensorDescriptor = {
@ -1227,98 +1089,6 @@ const testDispatchTensor = (testName) => {
mlContext, dispatchOutputs['output1'],
new Float32Array(sizeOfShape(shape)).fill(3));
}, `${testName} / same name diff outputs tensors destroy`);
promise_test(async () => {
if (!isConstantTensorSupported) {
return;
}
let constantTensor = await mlContext.createConstantTensor(
{
dataType: 'float32',
shape: shape,
},
new Float32Array(sizeOfShape(shape)).fill(3.0));
const builder = new MLGraphBuilder(mlContext);
const lhsConstantOperand = builder.constant(constantTensor);
const rhsConstantOperand = builder.constant(constantTensor);
const outputOperand = builder.add(lhsConstantOperand, rhsConstantOperand);
const graphWithOnlyConstants =
await builder.build({'output': outputOperand});
const outputTensor = await mlContext.createTensor(
getDescriptorFromTensor(outputs['output1']));
// Output = LHS + RHS = 3 + 3 = 6
mlContext.dispatch(graphWithOnlyConstants, {}, {'output': outputTensor});
await assert_tensor_data_equals(
mlContext, outputTensor,
new Float32Array(sizeOfShape(shape)).fill(6.0));
}, `${testName} / same constant same graph`);
promise_test(async () => {
if (!isConstantTensorSupported) {
return;
}
const rhsConstantTensor = await mlContext.createConstantTensor(
{
dataType: 'float32',
shape: shape,
},
new Float32Array(sizeOfShape(shape)).fill(3.0));
const lhsInputOperandDesc = {dataType: 'float32', shape};
let graphWithConstants;
{
const builder = new MLGraphBuilder(mlContext);
const lhsOperand = builder.input('lhs', lhsInputOperandDesc);
const rhsConstantOperand = builder.constant(rhsConstantTensor);
const outputOperand = builder.sub(lhsOperand, rhsConstantOperand);
graphWithConstants = await builder.build({'output': outputOperand});
}
const lhsTensor =
await mlContext.createTensor(getDescriptorFromTensor(inputs['lhs']));
mlContext.writeTensor(
lhsTensor, new Float32Array(sizeOfShape(shape)).fill(5.0));
const outputTensor = await mlContext.createTensor(
getDescriptorFromTensor(outputs['output1']));
// Output = LHS - RHS = 5 - 3 = 2
mlContext.dispatch(
graphWithConstants, {
'lhs': lhsTensor,
},
{'output': outputTensor});
// Create another graph reusing the same constants.
{
const builder = new MLGraphBuilder(mlContext);
const lhsOperand = builder.input('lhs', lhsInputOperandDesc);
const rhsConstantOperand = builder.constant(rhsConstantTensor);
const outputOperand = builder.sub(lhsOperand, rhsConstantOperand);
graphWithConstants = await builder.build({'output': outputOperand});
}
mlContext.writeTensor(
lhsTensor, new Float32Array(sizeOfShape(shape)).fill(4.0));
// Output = LHS - RHS = 4 - 3 = 1
mlContext.dispatch(
graphWithConstants, {
'lhs': lhsTensor,
},
{'output': outputTensor});
await assert_tensor_data_equals(
mlContext, outputTensor,
new Float32Array(sizeOfShape(shape)).fill(1.0));
}, `${testName} / same constant multiple graphs`);
};
if (navigator.ml) {
@ -1334,14 +1104,6 @@ if (navigator.ml) {
shape: [kMaxUnsignedLong, kMaxUnsignedLong, kMaxUnsignedLong]
});
testCreateConstantTensor('createConstant', {dataType: 'int32', shape: [4]});
testCreateConstantTensor(
'createConstant', {dataType: 'uint8', shape: [3, 2, 4]});
testCreateConstantTensorFails(
'createConstantFailsEmptyDimension',
{dataType: 'int32', shape: [2, 0, 3]});
testDestroyTensor('destroyTwice');
testReadTensor('read');
testWriteTensor('write');

@ -1,5 +1,19 @@
This is a testharness.js-based test.
Found 1 FAIL, 0 TIMEOUT, 0 NOTRUN.
Found 8 FAIL, 0 TIMEOUT, 0 NOTRUN.
[FAIL] MLContext interface: operation createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)
assert_own_property: interface prototype object missing non-static operation expected property "createConstantTensor" missing
[FAIL] MLContext interface: context must inherit property "createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)" with the proper type
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLContext interface: calling createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource) on context with too few arguments must throw TypeError
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLTensor interface: attribute constant
assert_true: The prototype object must have a property "constant" expected true got false
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDescriptor, AllowSharedBufferSource)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDataType, MLNumber)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLTensor)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation softmax(MLOperand, unsigned long, optional MLOperatorOptions)
assert_equals: property has wrong .length expected 2 but got 1
Harness: the test ran to completion.

@ -1,5 +1,19 @@
This is a testharness.js-based test.
Found 1 FAIL, 0 TIMEOUT, 0 NOTRUN.
Found 8 FAIL, 0 TIMEOUT, 0 NOTRUN.
[FAIL] MLContext interface: operation createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)
assert_own_property: interface prototype object missing non-static operation expected property "createConstantTensor" missing
[FAIL] MLContext interface: context must inherit property "createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)" with the proper type
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLContext interface: calling createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource) on context with too few arguments must throw TypeError
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLTensor interface: attribute constant
assert_true: The prototype object must have a property "constant" expected true got false
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDescriptor, AllowSharedBufferSource)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDataType, MLNumber)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLTensor)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation softmax(MLOperand, unsigned long, optional MLOperatorOptions)
assert_equals: property has wrong .length expected 2 but got 1
Harness: the test ran to completion.

@ -1,5 +1,19 @@
This is a testharness.js-based test.
Found 1 FAIL, 0 TIMEOUT, 0 NOTRUN.
Found 8 FAIL, 0 TIMEOUT, 0 NOTRUN.
[FAIL] MLContext interface: operation createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)
assert_own_property: interface prototype object missing non-static operation expected property "createConstantTensor" missing
[FAIL] MLContext interface: context must inherit property "createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)" with the proper type
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLContext interface: calling createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource) on context with too few arguments must throw TypeError
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLTensor interface: attribute constant
assert_true: The prototype object must have a property "constant" expected true got false
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDescriptor, AllowSharedBufferSource)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDataType, MLNumber)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLTensor)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation softmax(MLOperand, unsigned long, optional MLOperatorOptions)
assert_equals: property has wrong .length expected 2 but got 1
Harness: the test ran to completion.

@ -1,5 +1,19 @@
This is a testharness.js-based test.
Found 1 FAIL, 0 TIMEOUT, 0 NOTRUN.
Found 8 FAIL, 0 TIMEOUT, 0 NOTRUN.
[FAIL] MLContext interface: operation createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)
assert_own_property: interface prototype object missing non-static operation expected property "createConstantTensor" missing
[FAIL] MLContext interface: context must inherit property "createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource)" with the proper type
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLContext interface: calling createConstantTensor(MLOperandDescriptor, AllowSharedBufferSource) on context with too few arguments must throw TypeError
assert_inherits: property "createConstantTensor" not found in prototype chain
[FAIL] MLTensor interface: attribute constant
assert_true: The prototype object must have a property "constant" expected true got false
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDescriptor, AllowSharedBufferSource)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLOperandDataType, MLNumber)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation constant(MLTensor)
assert_equals: property has wrong .length expected 1 but got 2
[FAIL] MLGraphBuilder interface: operation softmax(MLOperand, unsigned long, optional MLOperatorOptions)
assert_equals: property has wrong .length expected 2 but got 1
Harness: the test ran to completion.