WebNN: add buffer usages for DML backend
Exposes MLBufferUsageFlags to MLBufferDescriptor and adds new usages to maximize device memory bandwidth. After this change, createBuffer() assumes "no usage" by default. To call readBuffer() or writeBuffer(), the web developer must specify the corresponding usage flag. Combining usages is allowed but could be inefficient. Usages are always validated even if a backend doesn't use them. https://github.com/webmachinelearning/webnn/issues/542 Bug: 343638938 Change-Id: I4d78e3f8bacd7cbabce3038c234c062c7c07b095 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5787041 Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com> Reviewed-by: Alex Gough <ajgo@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Reviewed-by: Austin Sullivan <asully@chromium.org> Cr-Commit-Position: refs/heads/main@{#1344910}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
ec42d9a27b
commit
9ee56b1031
services/webnn
third_party/blink
renderer
bindings
modules
web_tests
external
wpt
webnn
conformance_tests
validation_tests
@@ -378,12 +378,27 @@ void ContextImplDml::CreateBufferImpl(
|
||||
// CPU will directly read/write to this heap if the GPU isn't using it.
|
||||
ComPtr<ID3D12Resource> buffer;
|
||||
if (adapter_->IsUMA()) {
|
||||
// TODO(crbug.com/40278771): consider introducing buffer usages for INPUT or
|
||||
// OUTPUT since using upload-equivalent custom heaps everywhere could be
|
||||
// inefficient.
|
||||
hr = CreateCustomUploadBuffer(
|
||||
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
||||
L"WebNN_Custom_Upload_Buffer_External", buffer);
|
||||
// Create a buffer configured with memory properties based on
|
||||
// usage.
|
||||
if (buffer_info->usage.Has(MLBufferUsageFlags::kWriteTo)) {
|
||||
// Upload buffer is used when the buffer is mostly written by the CPU but
|
||||
// could also be read by the CPU. An upload buffer provides less bandwidth for CPU
|
||||
// reads in favor of GPU writes being optimal.
|
||||
hr = CreateCustomUploadBuffer(
|
||||
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
||||
L"WebNN_Custom_Upload_Buffer_External", buffer);
|
||||
} else if (buffer_info->usage.Has(MLBufferUsageFlags::kReadFrom)) {
|
||||
// Readback buffer is used when the buffer only requires CPU reads.
|
||||
hr = CreateCustomReadbackBuffer(
|
||||
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
||||
L"WebNN_Custom_Readback_Buffer_External", buffer);
|
||||
} else {
|
||||
// Default buffer is used when the buffer has no need for CPU access
|
||||
// in favor of any GPU access being optimal.
|
||||
hr = CreateDefaultBuffer(adapter_->d3d12_device(),
|
||||
aligned_buffer_byte_size,
|
||||
L"WebNN_Default_Buffer_External", buffer);
|
||||
}
|
||||
} else {
|
||||
// Create a default buffer that can be accessed only by GPU.
|
||||
// The CPU must use a staging buffer to read/write to this buffer.
|
||||
|
@@ -13,10 +13,14 @@ enum class MLBufferUsageFlags {
|
||||
// This buffer may be imported/rented to WebGPU.
|
||||
kWebGpuInterop,
|
||||
|
||||
// TODO(crbug.com/343638938): Add more usage flags.
|
||||
// This buffer can be used with readBuffer().
|
||||
kReadFrom,
|
||||
|
||||
// This buffer can be used with writeBuffer().
|
||||
kWriteTo,
|
||||
|
||||
kMinValue = kWebGpuInterop,
|
||||
kMaxValue = kWebGpuInterop,
|
||||
kMaxValue = kWriteTo,
|
||||
};
|
||||
|
||||
using MLBufferUsage = base::EnumSet<MLBufferUsageFlags,
|
||||
|
@@ -17,6 +17,14 @@ struct StructTraits<webnn::mojom::BufferUsageDataView, webnn::MLBufferUsage> {
|
||||
return usage.Has(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
||||
}
|
||||
|
||||
static bool write_to(const webnn::MLBufferUsage& usage) {
|
||||
return usage.Has(webnn::MLBufferUsageFlags::kWriteTo);
|
||||
}
|
||||
|
||||
static bool read_from(const webnn::MLBufferUsage& usage) {
|
||||
return usage.Has(webnn::MLBufferUsageFlags::kReadFrom);
|
||||
}
|
||||
|
||||
static bool Read(webnn::mojom::BufferUsageDataView data,
|
||||
webnn::MLBufferUsage* out) {
|
||||
out->Clear();
|
||||
@@ -24,6 +32,15 @@ struct StructTraits<webnn::mojom::BufferUsageDataView, webnn::MLBufferUsage> {
|
||||
if (data.web_gpu_interop()) {
|
||||
out->Put(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
||||
}
|
||||
|
||||
if (data.read_from()) {
|
||||
out->Put(webnn::MLBufferUsageFlags::kReadFrom);
|
||||
}
|
||||
|
||||
if (data.write_to()) {
|
||||
out->Put(webnn::MLBufferUsageFlags::kWriteTo);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
@@ -12,6 +12,10 @@ import "services/webnn/public/mojom/webnn_error.mojom";
|
||||
// the concept of an enum set (see https://crbug.com/40130879#comment11).
|
||||
struct BufferUsage {
|
||||
bool web_gpu_interop;
|
||||
// This buffer can be used with readBuffer().
|
||||
bool read_from;
|
||||
// This buffer can be used with writeBuffer().
|
||||
bool write_to;
|
||||
};
|
||||
|
||||
// Description of the WebNNBuffer to create.
|
||||
|
@@ -16,8 +16,8 @@ WebNNBufferImpl::WebNNBufferImpl(
|
||||
WebNNContextImpl* context,
|
||||
mojom::BufferInfoPtr buffer_info)
|
||||
: context_(context),
|
||||
// TODO(crbug.com/343638938): Use buffer_info->usage.
|
||||
descriptor_(std::move(buffer_info->descriptor)),
|
||||
usage_(std::move(buffer_info->usage)),
|
||||
receiver_(this, std::move(receiver)) {
|
||||
// Safe to use base::Unretained because `this` owns `receiver_`.
|
||||
receiver_.set_disconnect_handler(
|
||||
@@ -27,11 +27,21 @@ WebNNBufferImpl::WebNNBufferImpl(
|
||||
WebNNBufferImpl::~WebNNBufferImpl() = default;
|
||||
|
||||
void WebNNBufferImpl::ReadBuffer(ReadBufferCallback callback) {
|
||||
if (!usage().Has(MLBufferUsageFlags::kReadFrom)) {
|
||||
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||
return;
|
||||
}
|
||||
|
||||
// Call ReadBufferImpl() implemented by a backend.
|
||||
ReadBufferImpl(std::move(callback));
|
||||
}
|
||||
|
||||
void WebNNBufferImpl::WriteBuffer(mojo_base::BigBuffer src_buffer) {
|
||||
if (!usage().Has(MLBufferUsageFlags::kWriteTo)) {
|
||||
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(https://crbug.com/40278771): Generate error using MLContext.
|
||||
if (PackedByteLength() < src_buffer.size()) {
|
||||
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||
|
@@ -34,6 +34,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNBufferImpl
|
||||
|
||||
OperandDataType data_type() const { return descriptor_.data_type(); }
|
||||
const std::vector<uint32_t>& shape() const { return descriptor_.shape(); }
|
||||
MLBufferUsage usage() const { return usage_; }
|
||||
|
||||
size_t PackedByteLength() const { return descriptor_.PackedByteLength(); }
|
||||
size_t NumberOfElements() const { return descriptor_.NumberOfElements(); }
|
||||
@@ -70,6 +71,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNBufferImpl
|
||||
void OnDisconnect();
|
||||
|
||||
const OperandDescriptor descriptor_;
|
||||
const MLBufferUsage usage_;
|
||||
|
||||
mojo::AssociatedReceiver<mojom::WebNNBuffer> receiver_;
|
||||
|
||||
|
@@ -310,7 +310,8 @@ TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTest) {
|
||||
mojom::BufferInfo::New(
|
||||
*OperandDescriptor::Create(OperandDataType::kUint8,
|
||||
std::array<uint32_t, 2>{2, 2}),
|
||||
MLBufferUsage()));
|
||||
MLBufferUsage{MLBufferUsageFlags::kWriteTo,
|
||||
MLBufferUsageFlags::kReadFrom}));
|
||||
if (buffer_result.has_value()) {
|
||||
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
||||
}
|
||||
@@ -353,7 +354,7 @@ TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTooLargeTest) {
|
||||
mojom::BufferInfo::New(
|
||||
*OperandDescriptor::Create(OperandDataType::kUint8,
|
||||
std::array<uint32_t, 2>{2, 2}),
|
||||
MLBufferUsage()));
|
||||
MLBufferUsage{MLBufferUsageFlags::kWriteTo}));
|
||||
if (buffer_result.has_value()) {
|
||||
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
||||
}
|
||||
|
@@ -3069,6 +3069,8 @@ generated_namespace_sources_in_modules = [
|
||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_shader_stage.h",
|
||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.cc",
|
||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.h",
|
||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.cc",
|
||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h",
|
||||
]
|
||||
|
||||
if (enable_compute_pressure) {
|
||||
|
@@ -503,6 +503,7 @@ static_idl_files_in_modules = [
|
||||
"//third_party/blink/renderer/modules/ml/navigator_ml.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer_descriptor.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_graph.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl",
|
||||
"//third_party/blink/renderer/modules/ml/webnn/ml_operand.idl",
|
||||
|
@@ -644,9 +644,14 @@ ScriptPromise<MLBuffer> MLContext::createBuffer(
|
||||
return ScriptPromise<MLBuffer>();
|
||||
});
|
||||
|
||||
// TODO(crbug.com/343638938): Pass real buffer usages.
|
||||
auto buffer_info = webnn::mojom::blink::BufferInfo::New(
|
||||
validated_descriptor, webnn::MLBufferUsage());
|
||||
// WebNN bitfield values have the same value as enums.
|
||||
webnn::MLBufferUsage usage;
|
||||
if (descriptor->hasUsage()) {
|
||||
usage = webnn::MLBufferUsage::FromEnumBitmask(descriptor->usage());
|
||||
}
|
||||
|
||||
auto buffer_info =
|
||||
webnn::mojom::blink::BufferInfo::New(validated_descriptor, usage);
|
||||
|
||||
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLBuffer>>(
|
||||
script_state, exception_state.GetContext());
|
||||
@@ -657,7 +662,7 @@ ScriptPromise<MLBuffer> MLContext::createBuffer(
|
||||
std::move(buffer_info),
|
||||
WTF::BindOnce(&MLContext::DidCreateWebNNBuffer, WrapPersistent(this),
|
||||
std::move(scoped_trace), WrapPersistent(resolver),
|
||||
std::move(validated_descriptor)));
|
||||
std::move(validated_descriptor), usage));
|
||||
|
||||
return resolver->Promise();
|
||||
}
|
||||
@@ -726,6 +731,12 @@ ScriptPromise<DOMArrayBuffer> MLContext::readBuffer(
|
||||
return EmptyPromise();
|
||||
}
|
||||
|
||||
if (!src_buffer->Usage().Has(webnn::MLBufferUsageFlags::kReadFrom)) {
|
||||
exception_state.ThrowTypeError(
|
||||
"The source buffer doesn't have read access.");
|
||||
return EmptyPromise();
|
||||
}
|
||||
|
||||
return src_buffer->ReadBufferImpl(script_state, exception_state);
|
||||
}
|
||||
|
||||
@@ -788,6 +799,12 @@ void MLContext::WriteWebNNBuffer(ScriptState* script_state,
|
||||
return;
|
||||
}
|
||||
|
||||
if (!dst_buffer->Usage().Has(webnn::MLBufferUsageFlags::kWriteTo)) {
|
||||
exception_state.ThrowTypeError(
|
||||
"The destination buffer doesn't have write access.");
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t src_data_byte_length = src_data.size();
|
||||
if (src_element_offset > src_data_byte_length / src_data_type_size_bytes) {
|
||||
exception_state.ThrowTypeError(
|
||||
@@ -876,6 +893,7 @@ void MLContext::DidCreateWebNNBuffer(
|
||||
ScopedMLTrace scoped_trace,
|
||||
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
||||
webnn::OperandDescriptor validated_descriptor,
|
||||
webnn::MLBufferUsage usage,
|
||||
webnn::mojom::blink::CreateBufferResultPtr result) {
|
||||
pending_resolvers_.erase(resolver);
|
||||
|
||||
@@ -894,7 +912,7 @@ void MLContext::DidCreateWebNNBuffer(
|
||||
|
||||
auto* buffer = MakeGarbageCollected<MLBuffer>(
|
||||
resolver->GetExecutionContext(), this, std::move(validated_descriptor),
|
||||
std::move(result->get_success()), base::PassKey<MLContext>());
|
||||
usage, std::move(result->get_success()), base::PassKey<MLContext>());
|
||||
buffers_.insert(buffer);
|
||||
|
||||
resolver->Resolve(buffer);
|
||||
|
@@ -164,6 +164,7 @@ class MODULES_EXPORT MLContext : public ScriptWrappable {
|
||||
void DidCreateWebNNBuffer(ScopedMLTrace scoped_trace,
|
||||
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
||||
webnn::OperandDescriptor validated_descriptor,
|
||||
webnn::MLBufferUsage usage,
|
||||
webnn::mojom::blink::CreateBufferResultPtr result);
|
||||
|
||||
V8MLDeviceType device_type_;
|
||||
|
@@ -21,10 +21,12 @@ MLBuffer::MLBuffer(
|
||||
ExecutionContext* execution_context,
|
||||
MLContext* context,
|
||||
webnn::OperandDescriptor descriptor,
|
||||
webnn::MLBufferUsage usage,
|
||||
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
||||
base::PassKey<MLContext> /*pass_key*/)
|
||||
: ml_context_(context),
|
||||
descriptor_(std::move(descriptor)),
|
||||
usage_(usage),
|
||||
webnn_handle_(std::move(create_buffer_success->buffer_handle)),
|
||||
remote_buffer_(execution_context) {
|
||||
remote_buffer_.Bind(
|
||||
@@ -52,6 +54,10 @@ Vector<uint32_t> MLBuffer::shape() const {
|
||||
return Vector<uint32_t>(descriptor_.shape());
|
||||
}
|
||||
|
||||
uint32_t MLBuffer::usage() const {
|
||||
return static_cast<uint32_t>(usage_.ToEnumBitmask());
|
||||
}
|
||||
|
||||
void MLBuffer::destroy() {
|
||||
// Calling OnConnectionError() will disconnect and destroy the buffer in
|
||||
// the service. The remote buffer must remain unbound after calling
|
||||
@@ -71,6 +77,10 @@ const std::vector<uint32_t>& MLBuffer::Shape() const {
|
||||
return descriptor_.shape();
|
||||
}
|
||||
|
||||
const webnn::MLBufferUsage& MLBuffer::Usage() const {
|
||||
return usage_;
|
||||
}
|
||||
|
||||
uint64_t MLBuffer::PackedByteLength() const {
|
||||
return descriptor_.PackedByteLength();
|
||||
}
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h"
|
||||
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
|
||||
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
|
||||
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
|
||||
#include "third_party/blink/renderer/modules/modules_export.h"
|
||||
@@ -40,6 +41,7 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
||||
MLBuffer(ExecutionContext* execution_context,
|
||||
MLContext* context,
|
||||
webnn::OperandDescriptor descriptor,
|
||||
webnn::MLBufferUsage usage,
|
||||
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
||||
base::PassKey<MLContext> pass_key);
|
||||
MLBuffer(const MLBuffer&) = delete;
|
||||
@@ -52,6 +54,8 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
||||
// ml_buffer.idl
|
||||
V8MLOperandDataType dataType() const;
|
||||
Vector<uint32_t> shape() const;
|
||||
uint32_t usage() const;
|
||||
|
||||
void destroy();
|
||||
|
||||
// Convenience methods for accessing native types, which avoid a copy
|
||||
@@ -59,6 +63,7 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
||||
const webnn::OperandDescriptor& Descriptor() const;
|
||||
webnn::OperandDataType DataType() const;
|
||||
const std::vector<uint32_t>& Shape() const;
|
||||
const webnn::MLBufferUsage& Usage() const;
|
||||
|
||||
uint64_t PackedByteLength() const;
|
||||
|
||||
@@ -105,6 +110,9 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
||||
// Represents a valid MLBufferDescriptor.
|
||||
const webnn::OperandDescriptor descriptor_;
|
||||
|
||||
// Represents a valid MLBufferUsage.
|
||||
const webnn::MLBufferUsage usage_;
|
||||
|
||||
// Identifies this `WebNNBuffer` mojo instance in the service process.
|
||||
const blink::WebNNBufferToken webnn_handle_;
|
||||
|
||||
|
@@ -10,6 +10,7 @@
|
||||
] interface MLBuffer {
|
||||
readonly attribute MLOperandDataType dataType;
|
||||
readonly attribute FrozenArray<unsigned long> shape;
|
||||
readonly attribute MLBufferUsageFlags usage;
|
||||
|
||||
void destroy();
|
||||
};
|
@@ -7,5 +7,5 @@
|
||||
typedef [EnforceRange] unsigned long long MLSize64;
|
||||
|
||||
dictionary MLBufferDescriptor : MLOperandDescriptor {
|
||||
// TODO(crbug.com/343638938): Add buffer usage flags.
|
||||
MLBufferUsageFlags usage;
|
||||
};
|
11
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.h
vendored
Normal file
11
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.h
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright 2024 The Chromium Authors
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
||||
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
||||
|
||||
// This header is intentionally left blank
|
||||
// Needs to exist because it is included by v8_ml_buffer_usage.cc
|
||||
|
||||
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
18
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl
vendored
Normal file
18
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2024 The Chromium Authors
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
// https://www.w3.org/TR/webnn/#api-mlbuffer
|
||||
|
||||
typedef unsigned long MLFlagsConstant;
|
||||
|
||||
typedef [EnforceRange] unsigned long MLBufferUsageFlags;
|
||||
[
|
||||
RuntimeEnabled=MachineLearningNeuralNetwork,
|
||||
Exposed=(Window, DedicatedWorker),
|
||||
SecureContext
|
||||
] namespace MLBufferUsage {
|
||||
const MLFlagsConstant WEBGPU_INTEROP = 1;
|
||||
const MLFlagsConstant READ_FROM = 2;
|
||||
const MLFlagsConstant WRITE_TO = 4;
|
||||
};
|
@@ -39,6 +39,7 @@
|
||||
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
|
||||
#include "third_party/blink/renderer/bindings/core/v8/v8_dom_exception.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_compute_result.h"
|
||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
|
||||
@@ -776,6 +777,8 @@ MLBuffer* CreateMLBufferForOperand(V8TestingScope& scope,
|
||||
auto* desc = MLBufferDescriptor::Create();
|
||||
desc->setDataType(operand->dataType());
|
||||
desc->setDimensions(operand->shape());
|
||||
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo |
|
||||
V8MLBufferUsage::Constant::kReadFrom);
|
||||
|
||||
ScriptPromiseTester tester(
|
||||
scope.GetScriptState(),
|
||||
@@ -1324,6 +1327,8 @@ TEST_F(MLGraphTest, WriteWebNNBufferTest) {
|
||||
auto* desc = MLBufferDescriptor::Create();
|
||||
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
||||
desc->setDimensions(kBufferShape);
|
||||
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo |
|
||||
V8MLBufferUsage::Constant::kReadFrom);
|
||||
|
||||
ScriptPromiseTester buffer_tester(
|
||||
script_state,
|
||||
@@ -1417,6 +1422,7 @@ TEST_F(MLGraphTest, WriteWebNNBufferThenDestroyTest) {
|
||||
auto* desc = MLBufferDescriptor::Create();
|
||||
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
||||
desc->setDimensions({2, 2});
|
||||
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo);
|
||||
|
||||
ScriptPromiseTester buffer_tester(
|
||||
script_state,
|
||||
@@ -1459,6 +1465,7 @@ TEST_F(MLGraphTest, ReadWebNNBufferThenDestroyTest) {
|
||||
auto* desc = MLBufferDescriptor::Create();
|
||||
desc->setDataType(V8MLOperandDataType::Enum::kFloat32);
|
||||
desc->setDimensions({2, 2});
|
||||
desc->setUsage(V8MLBufferUsage::Constant::kReadFrom);
|
||||
|
||||
ScriptPromiseTester create_buffer_tester(
|
||||
script_state,
|
||||
|
@@ -33,7 +33,11 @@ const sizeOfDescriptor = (descriptor) => {
|
||||
};
|
||||
|
||||
const getDescriptorFromBuffer = (buffer) => {
|
||||
return {dataType: buffer.dataType, dimensions: buffer.shape};
|
||||
return {
|
||||
dataType: buffer.dataType,
|
||||
dimensions: buffer.shape,
|
||||
usage: buffer.usage
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -160,7 +164,11 @@ const testWriteWebNNBuffer = (testName) => {
|
||||
});
|
||||
|
||||
promise_test(async () => {
|
||||
const bufferDescriptor = {dataType: 'int32', dimensions: [1]};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO,
|
||||
};
|
||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||
|
||||
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
||||
@@ -205,7 +213,11 @@ const testWriteWebNNBuffer = (testName) => {
|
||||
}, `${testName} / error`);
|
||||
|
||||
promise_test(async () => {
|
||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 2]};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 2],
|
||||
usage: MLBufferUsage.WRITE_TO,
|
||||
};
|
||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||
|
||||
// Writing data to a destroyed MLBuffer should throw.
|
||||
@@ -218,7 +230,11 @@ const testWriteWebNNBuffer = (testName) => {
|
||||
}, `${testName} / destroy`);
|
||||
|
||||
promise_test(async () => {
|
||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 3]};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 3],
|
||||
usage: MLBufferUsage.WRITE_TO,
|
||||
};
|
||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||
|
||||
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
||||
@@ -233,8 +249,11 @@ const testWriteWebNNBuffer = (testName) => {
|
||||
}, `${testName} / context_mismatch`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
// Initialize the buffer.
|
||||
const inputData = Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]);
|
||||
@@ -253,7 +272,11 @@ const testWriteWebNNBuffer = (testName) => {
|
||||
}, `${testName} / zero_write`);
|
||||
|
||||
promise_test(async () => {
|
||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 2]};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 2],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||
|
||||
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
||||
@@ -300,8 +323,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
});
|
||||
|
||||
promise_test(async t => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 2],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
// Reading a destroyed MLBuffer should reject.
|
||||
mlBuffer.destroy();
|
||||
@@ -311,8 +337,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / read_after_destroy`);
|
||||
|
||||
promise_test(async t => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 3]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 3],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
let promise = mlContext.readBuffer(mlBuffer);
|
||||
let anotherPromise = mlContext.readBuffer(mlBuffer);
|
||||
@@ -324,16 +353,22 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / read_before_destroy`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1024]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1024],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
await assert_buffer_data_equals(
|
||||
mlContext, mlBuffer, new Uint32Array(1024));
|
||||
}, `${testName} / uninitialized`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.READ_FROM | MLBufferUsage.WRITE_TO,
|
||||
});
|
||||
|
||||
// Initialize the buffer.
|
||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||
@@ -345,8 +380,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / full_size`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
// Initialize the buffer.
|
||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||
@@ -360,8 +398,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / src_offset_only`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
// Initialize the buffer.
|
||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||
@@ -375,8 +416,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / src_offset_and_size`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
// Initialize the buffer.
|
||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||
@@ -390,8 +434,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / larger_src_data`);
|
||||
|
||||
promise_test(async () => {
|
||||
let mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
||||
let mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
|
||||
const inputData = [0xAA, 0xAA, 0xAA, 0xAA];
|
||||
|
||||
@@ -404,7 +451,11 @@ const testReadWebNNBuffer = (testName) => {
|
||||
}, `${testName} / no_src_offset`);
|
||||
|
||||
promise_test(async t => {
|
||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 3]};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 3],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
};
|
||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||
|
||||
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
||||
@@ -436,7 +487,11 @@ const testDispatchWebNNBuffer = (testName) => {
|
||||
}
|
||||
// Construct a simple graph: A = B + C, with two outputs.
|
||||
const builder = new MLGraphBuilder(mlContext);
|
||||
const bufferDescriptor = {dataType: 'float32', dimensions: shape};
|
||||
const bufferDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: shape,
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
const lhsOperand = builder.input('lhs', bufferDescriptor);
|
||||
const rhsOperand = builder.input('rhs', bufferDescriptor);
|
||||
const output1Operand = builder.add(lhsOperand, rhsOperand);
|
||||
|
21
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/byob_readbuffer.https.any.js
vendored
21
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/byob_readbuffer.https.any.js
vendored
@@ -29,8 +29,11 @@ promise_setup(async () => {
|
||||
}
|
||||
|
||||
try {
|
||||
mlBuffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 4]});
|
||||
mlBuffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 4],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
});
|
||||
} catch (e) {
|
||||
throw new AssertionError(
|
||||
`Unable to create buffer for ${variant} variant. ${e}`);
|
||||
@@ -135,8 +138,11 @@ promise_test(async () => {
|
||||
}, `readBuffer() with a larger TypedArray`);
|
||||
|
||||
promise_test(async (t) => {
|
||||
const buffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
||||
const buffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 2],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
const arrayBufferView = new Int32Array(2 * 2);
|
||||
const arrayBuffer = arrayBufferView.buffer;
|
||||
|
||||
@@ -150,8 +156,11 @@ promise_test(async (t) => {
|
||||
}, `readBuffer() rejects on a destroyed MLBuffer`);
|
||||
|
||||
promise_test(async (t) => {
|
||||
const buffer =
|
||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
||||
const buffer = await mlContext.createBuffer({
|
||||
dataType: 'int32',
|
||||
dimensions: [2, 2],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
const arrayBufferView = new Int32Array(2 * 2);
|
||||
const arrayBuffer = arrayBufferView.buffer;
|
||||
|
||||
|
54
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/parallel-dispatch.https.any.js
vendored
54
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/parallel-dispatch.https.any.js
vendored
@@ -30,7 +30,11 @@ function buildMulGraph(context, operandDescriptor, multiplier) {
|
||||
}
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
|
||||
const [mlGraph, inputBuffer1, inputBuffer2, outputBuffer] =
|
||||
await Promise.all([
|
||||
@@ -66,7 +70,11 @@ promise_test(async () => {
|
||||
}, 'dispatch queues behind readBuffer');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 3);
|
||||
|
||||
// write/dispatch/read, write/dispatch/read, ...
|
||||
@@ -90,7 +98,11 @@ promise_test(async () => {
|
||||
}, 'same graph: write/dispatch/read, write/dispatch/read, ...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 10);
|
||||
|
||||
// write/write...
|
||||
@@ -125,7 +137,11 @@ promise_test(async () => {
|
||||
}, 'same graph: write/write..., dispatch/read, dispatch/read, ...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 9);
|
||||
|
||||
// write/write...
|
||||
@@ -159,7 +175,11 @@ promise_test(async () => {
|
||||
}, 'same graph: write/write..., dispatch/dispatch..., read/read...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 2);
|
||||
|
||||
const buffers = await Promise.all([
|
||||
@@ -188,7 +208,11 @@ promise_test(async () => {
|
||||
}, 'same graph serial inputs: dispatch/dispatch..., read/read...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
|
||||
// write/write...
|
||||
const testInputs = [1, 2, 3, 4];
|
||||
@@ -223,7 +247,11 @@ promise_test(async () => {
|
||||
}, 'different graphs: write/write..., dispatch/read, dispatch/read, ...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
|
||||
// write/write...
|
||||
const testInputs = [1, 2, 3, 4];
|
||||
@@ -257,7 +285,11 @@ promise_test(async () => {
|
||||
}, 'different graphs: write/write..., dispatch/dispatch..., read/read...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
|
||||
const graphs = await Promise.all([3, 2].map(async (multiplier) => {
|
||||
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
||||
@@ -289,7 +321,11 @@ promise_test(async () => {
|
||||
}, 'different graphs serial inputs: dispatch/dispatch..., read/read...');
|
||||
|
||||
promise_test(async () => {
|
||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
||||
const operandDescriptor = {
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||
};
|
||||
|
||||
const graphs = await Promise.all([2, 3].map(async (multiplier) => {
|
||||
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
||||
|
21
third_party/blink/web_tests/external/wpt/webnn/validation_tests/destroyContext.https.any.js
vendored
21
third_party/blink/web_tests/external/wpt/webnn/validation_tests/destroyContext.https.any.js
vendored
@@ -132,16 +132,22 @@ promise_test(async t => {
|
||||
|
||||
promise_test(async t => {
|
||||
const context = await navigator.ml.createContext(contextOptions);
|
||||
const buffer =
|
||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
||||
const buffer = await context.createBuffer({
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
context.destroy();
|
||||
promise_rejects_dom(t, 'InvalidStateError', context.readBuffer(buffer));
|
||||
}, 'Destroyed context can not read buffer.');
|
||||
|
||||
promise_test(async t => {
|
||||
const context = await navigator.ml.createContext(contextOptions);
|
||||
const buffer =
|
||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
||||
const buffer = await context.createBuffer({
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.READ_FROM,
|
||||
});
|
||||
let promise = context.readBuffer(buffer);
|
||||
context.destroy();
|
||||
promise_rejects_dom(t, 'InvalidStateError', promise);
|
||||
@@ -152,8 +158,11 @@ promise_test(async t => {
|
||||
// Destroying another context doesn't impact the first context.
|
||||
const another_context = await navigator.ml.createContext(contextOptions);
|
||||
another_context.destroy();
|
||||
const buffer =
|
||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
||||
const buffer = await context.createBuffer({
|
||||
dataType: 'float32',
|
||||
dimensions: [1],
|
||||
usage: MLBufferUsage.WRITE_TO,
|
||||
});
|
||||
let arrayBuffer = new ArrayBuffer(4);
|
||||
context.destroy();
|
||||
assert_throws_dom('InvalidStateError', () => {
|
||||
|
Reference in New Issue
Block a user