WebNN: add buffer usages for DML backend
Exposes MLBufferUsageFlags to MLBufferDescriptor and adds new usages to maximize device memory bandwidth. After this change, createBuffer() assumes "no usage" by default. To call readBuffer() or writeBuffer(), the corresponding usage flag must be specified by the web developer. Combining usages is allowed but could be inefficient. Usages are always validated even if a backend doesn't use them. https://github.com/webmachinelearning/webnn/issues/542 Bug: 343638938 Change-Id: I4d78e3f8bacd7cbabce3038c234c062c7c07b095 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5787041 Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com> Reviewed-by: Alex Gough <ajgo@chromium.org> Reviewed-by: ningxin hu <ningxin.hu@intel.com> Reviewed-by: Austin Sullivan <asully@chromium.org> Cr-Commit-Position: refs/heads/main@{#1344910}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
ec42d9a27b
commit
9ee56b1031
services/webnn
third_party/blink
renderer
bindings
modules
web_tests
external
wpt
webnn
conformance_tests
validation_tests
@@ -378,12 +378,27 @@ void ContextImplDml::CreateBufferImpl(
|
|||||||
// CPU will directly read/write to this heap if the GPU isn't using it.
|
// CPU will directly read/write to this heap if the GPU isn't using it.
|
||||||
ComPtr<ID3D12Resource> buffer;
|
ComPtr<ID3D12Resource> buffer;
|
||||||
if (adapter_->IsUMA()) {
|
if (adapter_->IsUMA()) {
|
||||||
// TODO(crbug.com/40278771): consider introducing buffer usages for INPUT or
|
// Create a buffer configured with memory properties based on
|
||||||
// OUTPUT since using upload-equivelent custom heaps everywhere could be
|
// usage.
|
||||||
// inefficient.
|
if (buffer_info->usage.Has(MLBufferUsageFlags::kWriteTo)) {
|
||||||
hr = CreateCustomUploadBuffer(
|
// Upload buffer is used when the CPU mostly writes to the buffer but
|
||||||
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
// may also read from it. An upload buffer provides less bandwidth for CPU
|
||||||
L"WebNN_Custom_Upload_Buffer_External", buffer);
|
// reads in favor of GPU writes being optimal.
|
||||||
|
hr = CreateCustomUploadBuffer(
|
||||||
|
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
||||||
|
L"WebNN_Custom_Upload_Buffer_External", buffer);
|
||||||
|
} else if (buffer_info->usage.Has(MLBufferUsageFlags::kReadFrom)) {
|
||||||
|
// Readback buffer is used when the buffer only requires CPU reads.
|
||||||
|
hr = CreateCustomReadbackBuffer(
|
||||||
|
adapter_->d3d12_device(), aligned_buffer_byte_size,
|
||||||
|
L"WebNN_Custom_Readback_Buffer_External", buffer);
|
||||||
|
} else {
|
||||||
|
// Default buffer is used when the buffer has no need for CPU access
|
||||||
|
// in favor of any GPU access being optimal.
|
||||||
|
hr = CreateDefaultBuffer(adapter_->d3d12_device(),
|
||||||
|
aligned_buffer_byte_size,
|
||||||
|
L"WebNN_Default_Buffer_External", buffer);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Create a default buffer that can be accessed only by GPU.
|
// Create a default buffer that can be accessed only by GPU.
|
||||||
// The CPU must use a staging buffer to read/write to this buffer.
|
// The CPU must use a staging buffer to read/write to this buffer.
|
||||||
|
@@ -13,10 +13,14 @@ enum class MLBufferUsageFlags {
|
|||||||
// This buffer may be imported/rented to WebGPU.
|
// This buffer may be imported/rented to WebGPU.
|
||||||
kWebGpuInterop,
|
kWebGpuInterop,
|
||||||
|
|
||||||
// TODO(crbug.com/343638938): Add more usage flags.
|
// This buffer can be used with readBuffer().
|
||||||
|
kReadFrom,
|
||||||
|
|
||||||
|
// This buffer can be used with writeBuffer().
|
||||||
|
kWriteTo,
|
||||||
|
|
||||||
kMinValue = kWebGpuInterop,
|
kMinValue = kWebGpuInterop,
|
||||||
kMaxValue = kWebGpuInterop,
|
kMaxValue = kWriteTo,
|
||||||
};
|
};
|
||||||
|
|
||||||
using MLBufferUsage = base::EnumSet<MLBufferUsageFlags,
|
using MLBufferUsage = base::EnumSet<MLBufferUsageFlags,
|
||||||
|
@@ -17,6 +17,14 @@ struct StructTraits<webnn::mojom::BufferUsageDataView, webnn::MLBufferUsage> {
|
|||||||
return usage.Has(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
return usage.Has(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool write_to(const webnn::MLBufferUsage& usage) {
|
||||||
|
return usage.Has(webnn::MLBufferUsageFlags::kWriteTo);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool read_from(const webnn::MLBufferUsage& usage) {
|
||||||
|
return usage.Has(webnn::MLBufferUsageFlags::kReadFrom);
|
||||||
|
}
|
||||||
|
|
||||||
static bool Read(webnn::mojom::BufferUsageDataView data,
|
static bool Read(webnn::mojom::BufferUsageDataView data,
|
||||||
webnn::MLBufferUsage* out) {
|
webnn::MLBufferUsage* out) {
|
||||||
out->Clear();
|
out->Clear();
|
||||||
@@ -24,6 +32,15 @@ struct StructTraits<webnn::mojom::BufferUsageDataView, webnn::MLBufferUsage> {
|
|||||||
if (data.web_gpu_interop()) {
|
if (data.web_gpu_interop()) {
|
||||||
out->Put(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
out->Put(webnn::MLBufferUsageFlags::kWebGpuInterop);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (data.read_from()) {
|
||||||
|
out->Put(webnn::MLBufferUsageFlags::kReadFrom);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.write_to()) {
|
||||||
|
out->Put(webnn::MLBufferUsageFlags::kWriteTo);
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@@ -12,6 +12,10 @@ import "services/webnn/public/mojom/webnn_error.mojom";
|
|||||||
// the concept of an enum set (see https://crbug.com/40130879#comment11).
|
// the concept of an enum set (see https://crbug.com/40130879#comment11).
|
||||||
struct BufferUsage {
|
struct BufferUsage {
|
||||||
bool web_gpu_interop;
|
bool web_gpu_interop;
|
||||||
|
// This buffer can be used with readBuffer().
|
||||||
|
bool read_from;
|
||||||
|
// This buffer can be used with writeBuffer().
|
||||||
|
bool write_to;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Description of the WebNNBuffer to create.
|
// Description of the WebNNBuffer to create.
|
||||||
|
@@ -16,8 +16,8 @@ WebNNBufferImpl::WebNNBufferImpl(
|
|||||||
WebNNContextImpl* context,
|
WebNNContextImpl* context,
|
||||||
mojom::BufferInfoPtr buffer_info)
|
mojom::BufferInfoPtr buffer_info)
|
||||||
: context_(context),
|
: context_(context),
|
||||||
// TODO(crbug.com/343638938): Use buffer_info->usage.
|
|
||||||
descriptor_(std::move(buffer_info->descriptor)),
|
descriptor_(std::move(buffer_info->descriptor)),
|
||||||
|
usage_(std::move(buffer_info->usage)),
|
||||||
receiver_(this, std::move(receiver)) {
|
receiver_(this, std::move(receiver)) {
|
||||||
// Safe to use base::Unretained because `this` owns `receiver_`.
|
// Safe to use base::Unretained because `this` owns `receiver_`.
|
||||||
receiver_.set_disconnect_handler(
|
receiver_.set_disconnect_handler(
|
||||||
@@ -27,11 +27,21 @@ WebNNBufferImpl::WebNNBufferImpl(
|
|||||||
WebNNBufferImpl::~WebNNBufferImpl() = default;
|
WebNNBufferImpl::~WebNNBufferImpl() = default;
|
||||||
|
|
||||||
void WebNNBufferImpl::ReadBuffer(ReadBufferCallback callback) {
|
void WebNNBufferImpl::ReadBuffer(ReadBufferCallback callback) {
|
||||||
|
if (!usage().Has(MLBufferUsageFlags::kReadFrom)) {
|
||||||
|
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Call ReadBufferImpl() implemented by a backend.
|
// Call ReadBufferImpl() implemented by a backend.
|
||||||
ReadBufferImpl(std::move(callback));
|
ReadBufferImpl(std::move(callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
void WebNNBufferImpl::WriteBuffer(mojo_base::BigBuffer src_buffer) {
|
void WebNNBufferImpl::WriteBuffer(mojo_base::BigBuffer src_buffer) {
|
||||||
|
if (!usage().Has(MLBufferUsageFlags::kWriteTo)) {
|
||||||
|
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(https://crbug.com/40278771): Generate error using MLContext.
|
// TODO(https://crbug.com/40278771): Generate error using MLContext.
|
||||||
if (PackedByteLength() < src_buffer.size()) {
|
if (PackedByteLength() < src_buffer.size()) {
|
||||||
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
receiver_.ReportBadMessage(kBadMessageInvalidBuffer);
|
||||||
|
@@ -34,6 +34,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNBufferImpl
|
|||||||
|
|
||||||
OperandDataType data_type() const { return descriptor_.data_type(); }
|
OperandDataType data_type() const { return descriptor_.data_type(); }
|
||||||
const std::vector<uint32_t>& shape() const { return descriptor_.shape(); }
|
const std::vector<uint32_t>& shape() const { return descriptor_.shape(); }
|
||||||
|
MLBufferUsage usage() const { return usage_; }
|
||||||
|
|
||||||
size_t PackedByteLength() const { return descriptor_.PackedByteLength(); }
|
size_t PackedByteLength() const { return descriptor_.PackedByteLength(); }
|
||||||
size_t NumberOfElements() const { return descriptor_.NumberOfElements(); }
|
size_t NumberOfElements() const { return descriptor_.NumberOfElements(); }
|
||||||
@@ -70,6 +71,7 @@ class COMPONENT_EXPORT(WEBNN_SERVICE) WebNNBufferImpl
|
|||||||
void OnDisconnect();
|
void OnDisconnect();
|
||||||
|
|
||||||
const OperandDescriptor descriptor_;
|
const OperandDescriptor descriptor_;
|
||||||
|
const MLBufferUsage usage_;
|
||||||
|
|
||||||
mojo::AssociatedReceiver<mojom::WebNNBuffer> receiver_;
|
mojo::AssociatedReceiver<mojom::WebNNBuffer> receiver_;
|
||||||
|
|
||||||
|
@@ -310,7 +310,8 @@ TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTest) {
|
|||||||
mojom::BufferInfo::New(
|
mojom::BufferInfo::New(
|
||||||
*OperandDescriptor::Create(OperandDataType::kUint8,
|
*OperandDescriptor::Create(OperandDataType::kUint8,
|
||||||
std::array<uint32_t, 2>{2, 2}),
|
std::array<uint32_t, 2>{2, 2}),
|
||||||
MLBufferUsage()));
|
MLBufferUsage{MLBufferUsageFlags::kWriteTo,
|
||||||
|
MLBufferUsageFlags::kReadFrom}));
|
||||||
if (buffer_result.has_value()) {
|
if (buffer_result.has_value()) {
|
||||||
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
||||||
}
|
}
|
||||||
@@ -353,7 +354,7 @@ TEST_F(WebNNBufferImplBackendTest, WriteBufferImplTooLargeTest) {
|
|||||||
mojom::BufferInfo::New(
|
mojom::BufferInfo::New(
|
||||||
*OperandDescriptor::Create(OperandDataType::kUint8,
|
*OperandDescriptor::Create(OperandDataType::kUint8,
|
||||||
std::array<uint32_t, 2>{2, 2}),
|
std::array<uint32_t, 2>{2, 2}),
|
||||||
MLBufferUsage()));
|
MLBufferUsage{MLBufferUsageFlags::kWriteTo}));
|
||||||
if (buffer_result.has_value()) {
|
if (buffer_result.has_value()) {
|
||||||
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
webnn_buffer_remote = std::move(buffer_result.value().webnn_buffer_remote);
|
||||||
}
|
}
|
||||||
|
@@ -3069,6 +3069,8 @@ generated_namespace_sources_in_modules = [
|
|||||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_shader_stage.h",
|
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_shader_stage.h",
|
||||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.cc",
|
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.cc",
|
||||||
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.h",
|
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_gpu_texture_usage.h",
|
||||||
|
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.cc",
|
||||||
|
"$root_gen_dir/third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h",
|
||||||
]
|
]
|
||||||
|
|
||||||
if (enable_compute_pressure) {
|
if (enable_compute_pressure) {
|
||||||
|
@@ -503,6 +503,7 @@ static_idl_files_in_modules = [
|
|||||||
"//third_party/blink/renderer/modules/ml/navigator_ml.idl",
|
"//third_party/blink/renderer/modules/ml/navigator_ml.idl",
|
||||||
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer.idl",
|
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer.idl",
|
||||||
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer_descriptor.idl",
|
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer_descriptor.idl",
|
||||||
|
"//third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl",
|
||||||
"//third_party/blink/renderer/modules/ml/webnn/ml_graph.idl",
|
"//third_party/blink/renderer/modules/ml/webnn/ml_graph.idl",
|
||||||
"//third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl",
|
"//third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl",
|
||||||
"//third_party/blink/renderer/modules/ml/webnn/ml_operand.idl",
|
"//third_party/blink/renderer/modules/ml/webnn/ml_operand.idl",
|
||||||
|
@@ -644,9 +644,14 @@ ScriptPromise<MLBuffer> MLContext::createBuffer(
|
|||||||
return ScriptPromise<MLBuffer>();
|
return ScriptPromise<MLBuffer>();
|
||||||
});
|
});
|
||||||
|
|
||||||
// TODO(crbug.com/343638938): Pass real buffer usages.
|
// WebNN bitfield values have the same value as enums.
|
||||||
auto buffer_info = webnn::mojom::blink::BufferInfo::New(
|
webnn::MLBufferUsage usage;
|
||||||
validated_descriptor, webnn::MLBufferUsage());
|
if (descriptor->hasUsage()) {
|
||||||
|
usage = webnn::MLBufferUsage::FromEnumBitmask(descriptor->usage());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto buffer_info =
|
||||||
|
webnn::mojom::blink::BufferInfo::New(validated_descriptor, usage);
|
||||||
|
|
||||||
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLBuffer>>(
|
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver<MLBuffer>>(
|
||||||
script_state, exception_state.GetContext());
|
script_state, exception_state.GetContext());
|
||||||
@@ -657,7 +662,7 @@ ScriptPromise<MLBuffer> MLContext::createBuffer(
|
|||||||
std::move(buffer_info),
|
std::move(buffer_info),
|
||||||
WTF::BindOnce(&MLContext::DidCreateWebNNBuffer, WrapPersistent(this),
|
WTF::BindOnce(&MLContext::DidCreateWebNNBuffer, WrapPersistent(this),
|
||||||
std::move(scoped_trace), WrapPersistent(resolver),
|
std::move(scoped_trace), WrapPersistent(resolver),
|
||||||
std::move(validated_descriptor)));
|
std::move(validated_descriptor), usage));
|
||||||
|
|
||||||
return resolver->Promise();
|
return resolver->Promise();
|
||||||
}
|
}
|
||||||
@@ -726,6 +731,12 @@ ScriptPromise<DOMArrayBuffer> MLContext::readBuffer(
|
|||||||
return EmptyPromise();
|
return EmptyPromise();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!src_buffer->Usage().Has(webnn::MLBufferUsageFlags::kReadFrom)) {
|
||||||
|
exception_state.ThrowTypeError(
|
||||||
|
"The source buffer doesn't have read access.");
|
||||||
|
return EmptyPromise();
|
||||||
|
}
|
||||||
|
|
||||||
return src_buffer->ReadBufferImpl(script_state, exception_state);
|
return src_buffer->ReadBufferImpl(script_state, exception_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -788,6 +799,12 @@ void MLContext::WriteWebNNBuffer(ScriptState* script_state,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!dst_buffer->Usage().Has(webnn::MLBufferUsageFlags::kWriteTo)) {
|
||||||
|
exception_state.ThrowTypeError(
|
||||||
|
"The destination buffer doesn't have write access.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const size_t src_data_byte_length = src_data.size();
|
const size_t src_data_byte_length = src_data.size();
|
||||||
if (src_element_offset > src_data_byte_length / src_data_type_size_bytes) {
|
if (src_element_offset > src_data_byte_length / src_data_type_size_bytes) {
|
||||||
exception_state.ThrowTypeError(
|
exception_state.ThrowTypeError(
|
||||||
@@ -876,6 +893,7 @@ void MLContext::DidCreateWebNNBuffer(
|
|||||||
ScopedMLTrace scoped_trace,
|
ScopedMLTrace scoped_trace,
|
||||||
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
||||||
webnn::OperandDescriptor validated_descriptor,
|
webnn::OperandDescriptor validated_descriptor,
|
||||||
|
webnn::MLBufferUsage usage,
|
||||||
webnn::mojom::blink::CreateBufferResultPtr result) {
|
webnn::mojom::blink::CreateBufferResultPtr result) {
|
||||||
pending_resolvers_.erase(resolver);
|
pending_resolvers_.erase(resolver);
|
||||||
|
|
||||||
@@ -894,7 +912,7 @@ void MLContext::DidCreateWebNNBuffer(
|
|||||||
|
|
||||||
auto* buffer = MakeGarbageCollected<MLBuffer>(
|
auto* buffer = MakeGarbageCollected<MLBuffer>(
|
||||||
resolver->GetExecutionContext(), this, std::move(validated_descriptor),
|
resolver->GetExecutionContext(), this, std::move(validated_descriptor),
|
||||||
std::move(result->get_success()), base::PassKey<MLContext>());
|
usage, std::move(result->get_success()), base::PassKey<MLContext>());
|
||||||
buffers_.insert(buffer);
|
buffers_.insert(buffer);
|
||||||
|
|
||||||
resolver->Resolve(buffer);
|
resolver->Resolve(buffer);
|
||||||
|
@@ -164,6 +164,7 @@ class MODULES_EXPORT MLContext : public ScriptWrappable {
|
|||||||
void DidCreateWebNNBuffer(ScopedMLTrace scoped_trace,
|
void DidCreateWebNNBuffer(ScopedMLTrace scoped_trace,
|
||||||
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
ScriptPromiseResolver<blink::MLBuffer>* resolver,
|
||||||
webnn::OperandDescriptor validated_descriptor,
|
webnn::OperandDescriptor validated_descriptor,
|
||||||
|
webnn::MLBufferUsage usage,
|
||||||
webnn::mojom::blink::CreateBufferResultPtr result);
|
webnn::mojom::blink::CreateBufferResultPtr result);
|
||||||
|
|
||||||
V8MLDeviceType device_type_;
|
V8MLDeviceType device_type_;
|
||||||
|
@@ -21,10 +21,12 @@ MLBuffer::MLBuffer(
|
|||||||
ExecutionContext* execution_context,
|
ExecutionContext* execution_context,
|
||||||
MLContext* context,
|
MLContext* context,
|
||||||
webnn::OperandDescriptor descriptor,
|
webnn::OperandDescriptor descriptor,
|
||||||
|
webnn::MLBufferUsage usage,
|
||||||
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
||||||
base::PassKey<MLContext> /*pass_key*/)
|
base::PassKey<MLContext> /*pass_key*/)
|
||||||
: ml_context_(context),
|
: ml_context_(context),
|
||||||
descriptor_(std::move(descriptor)),
|
descriptor_(std::move(descriptor)),
|
||||||
|
usage_(usage),
|
||||||
webnn_handle_(std::move(create_buffer_success->buffer_handle)),
|
webnn_handle_(std::move(create_buffer_success->buffer_handle)),
|
||||||
remote_buffer_(execution_context) {
|
remote_buffer_(execution_context) {
|
||||||
remote_buffer_.Bind(
|
remote_buffer_.Bind(
|
||||||
@@ -52,6 +54,10 @@ Vector<uint32_t> MLBuffer::shape() const {
|
|||||||
return Vector<uint32_t>(descriptor_.shape());
|
return Vector<uint32_t>(descriptor_.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t MLBuffer::usage() const {
|
||||||
|
return static_cast<uint32_t>(usage_.ToEnumBitmask());
|
||||||
|
}
|
||||||
|
|
||||||
void MLBuffer::destroy() {
|
void MLBuffer::destroy() {
|
||||||
// Calling OnConnectionError() will disconnect and destroy the buffer in
|
// Calling OnConnectionError() will disconnect and destroy the buffer in
|
||||||
// the service. The remote buffer must remain unbound after calling
|
// the service. The remote buffer must remain unbound after calling
|
||||||
@@ -71,6 +77,10 @@ const std::vector<uint32_t>& MLBuffer::Shape() const {
|
|||||||
return descriptor_.shape();
|
return descriptor_.shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const webnn::MLBufferUsage& MLBuffer::Usage() const {
|
||||||
|
return usage_;
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t MLBuffer::PackedByteLength() const {
|
uint64_t MLBuffer::PackedByteLength() const {
|
||||||
return descriptor_.PackedByteLength();
|
return descriptor_.PackedByteLength();
|
||||||
}
|
}
|
||||||
|
@@ -11,6 +11,7 @@
|
|||||||
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h"
|
#include "services/webnn/public/mojom/webnn_buffer.mojom-blink.h"
|
||||||
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
|
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
|
||||||
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
|
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
|
||||||
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h"
|
||||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
|
||||||
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
|
#include "third_party/blink/renderer/modules/ml/ml_trace.h"
|
||||||
#include "third_party/blink/renderer/modules/modules_export.h"
|
#include "third_party/blink/renderer/modules/modules_export.h"
|
||||||
@@ -40,6 +41,7 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
|||||||
MLBuffer(ExecutionContext* execution_context,
|
MLBuffer(ExecutionContext* execution_context,
|
||||||
MLContext* context,
|
MLContext* context,
|
||||||
webnn::OperandDescriptor descriptor,
|
webnn::OperandDescriptor descriptor,
|
||||||
|
webnn::MLBufferUsage usage,
|
||||||
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
webnn::mojom::blink::CreateBufferSuccessPtr create_buffer_success,
|
||||||
base::PassKey<MLContext> pass_key);
|
base::PassKey<MLContext> pass_key);
|
||||||
MLBuffer(const MLBuffer&) = delete;
|
MLBuffer(const MLBuffer&) = delete;
|
||||||
@@ -52,6 +54,8 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
|||||||
// ml_buffer.idl
|
// ml_buffer.idl
|
||||||
V8MLOperandDataType dataType() const;
|
V8MLOperandDataType dataType() const;
|
||||||
Vector<uint32_t> shape() const;
|
Vector<uint32_t> shape() const;
|
||||||
|
uint32_t usage() const;
|
||||||
|
|
||||||
void destroy();
|
void destroy();
|
||||||
|
|
||||||
// Convenience methods for accessing native types, which avoid a copy
|
// Convenience methods for accessing native types, which avoid a copy
|
||||||
@@ -59,6 +63,7 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
|||||||
const webnn::OperandDescriptor& Descriptor() const;
|
const webnn::OperandDescriptor& Descriptor() const;
|
||||||
webnn::OperandDataType DataType() const;
|
webnn::OperandDataType DataType() const;
|
||||||
const std::vector<uint32_t>& Shape() const;
|
const std::vector<uint32_t>& Shape() const;
|
||||||
|
const webnn::MLBufferUsage& Usage() const;
|
||||||
|
|
||||||
uint64_t PackedByteLength() const;
|
uint64_t PackedByteLength() const;
|
||||||
|
|
||||||
@@ -105,6 +110,9 @@ class MODULES_EXPORT MLBuffer : public ScriptWrappable {
|
|||||||
// Represents a valid MLBufferDescriptor.
|
// Represents a valid MLBufferDescriptor.
|
||||||
const webnn::OperandDescriptor descriptor_;
|
const webnn::OperandDescriptor descriptor_;
|
||||||
|
|
||||||
|
// Represents a valid MLBufferUsage.
|
||||||
|
const webnn::MLBufferUsage usage_;
|
||||||
|
|
||||||
// Identifies this `WebNNBuffer` mojo instance in the service process.
|
// Identifies this `WebNNBuffer` mojo instance in the service process.
|
||||||
const blink::WebNNBufferToken webnn_handle_;
|
const blink::WebNNBufferToken webnn_handle_;
|
||||||
|
|
||||||
|
@@ -10,6 +10,7 @@
|
|||||||
] interface MLBuffer {
|
] interface MLBuffer {
|
||||||
readonly attribute MLOperandDataType dataType;
|
readonly attribute MLOperandDataType dataType;
|
||||||
readonly attribute FrozenArray<unsigned long> shape;
|
readonly attribute FrozenArray<unsigned long> shape;
|
||||||
|
readonly attribute MLBufferUsageFlags usage;
|
||||||
|
|
||||||
void destroy();
|
void destroy();
|
||||||
};
|
};
|
@@ -7,5 +7,5 @@
|
|||||||
typedef [EnforceRange] unsigned long long MLSize64;
|
typedef [EnforceRange] unsigned long long MLSize64;
|
||||||
|
|
||||||
dictionary MLBufferDescriptor : MLOperandDescriptor {
|
dictionary MLBufferDescriptor : MLOperandDescriptor {
|
||||||
// TODO(crbug.com/343638938): Add buffer usage flags.
|
MLBufferUsageFlags usage;
|
||||||
};
|
};
|
11
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.h
vendored
Normal file
11
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.h
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Copyright 2024 The Chromium Authors
|
||||||
|
// Use of this source code is governed by a BSD-style license that can be
|
||||||
|
// found in the LICENSE file.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
||||||
|
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
||||||
|
|
||||||
|
// This header is intentionally left blank
|
||||||
|
// Needs to exist because it is included by v8_ml_buffer_usage.cc
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_BUFFER_USAGE_H_
|
18
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl
vendored
Normal file
18
third_party/blink/renderer/modules/ml/webnn/ml_buffer_usage.idl
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
// Copyright 2024 The Chromium Authors
|
||||||
|
// Use of this source code is governed by a BSD-style license that can be
|
||||||
|
// found in the LICENSE file.
|
||||||
|
|
||||||
|
// https://www.w3.org/TR/webnn/#api-mlbuffer
|
||||||
|
|
||||||
|
typedef unsigned long MLFlagsConstant;
|
||||||
|
|
||||||
|
typedef [EnforceRange] unsigned long MLBufferUsageFlags;
|
||||||
|
[
|
||||||
|
RuntimeEnabled=MachineLearningNeuralNetwork,
|
||||||
|
Exposed=(Window, DedicatedWorker),
|
||||||
|
SecureContext
|
||||||
|
] namespace MLBufferUsage {
|
||||||
|
const MLFlagsConstant WEBGPU_INTEROP = 1;
|
||||||
|
const MLFlagsConstant READ_FROM = 2;
|
||||||
|
const MLFlagsConstant WRITE_TO = 4;
|
||||||
|
};
|
@@ -39,6 +39,7 @@
|
|||||||
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
|
#include "third_party/blink/renderer/bindings/core/v8/v8_binding_for_testing.h"
|
||||||
#include "third_party/blink/renderer/bindings/core/v8/v8_dom_exception.h"
|
#include "third_party/blink/renderer/bindings/core/v8/v8_dom_exception.h"
|
||||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h"
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_descriptor.h"
|
||||||
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_buffer_usage.h"
|
||||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_clamp_options.h"
|
||||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_compute_result.h"
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_compute_result.h"
|
||||||
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
|
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_context_options.h"
|
||||||
@@ -776,6 +777,8 @@ MLBuffer* CreateMLBufferForOperand(V8TestingScope& scope,
|
|||||||
auto* desc = MLBufferDescriptor::Create();
|
auto* desc = MLBufferDescriptor::Create();
|
||||||
desc->setDataType(operand->dataType());
|
desc->setDataType(operand->dataType());
|
||||||
desc->setDimensions(operand->shape());
|
desc->setDimensions(operand->shape());
|
||||||
|
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo |
|
||||||
|
V8MLBufferUsage::Constant::kReadFrom);
|
||||||
|
|
||||||
ScriptPromiseTester tester(
|
ScriptPromiseTester tester(
|
||||||
scope.GetScriptState(),
|
scope.GetScriptState(),
|
||||||
@@ -1324,6 +1327,8 @@ TEST_F(MLGraphTest, WriteWebNNBufferTest) {
|
|||||||
auto* desc = MLBufferDescriptor::Create();
|
auto* desc = MLBufferDescriptor::Create();
|
||||||
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
||||||
desc->setDimensions(kBufferShape);
|
desc->setDimensions(kBufferShape);
|
||||||
|
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo |
|
||||||
|
V8MLBufferUsage::Constant::kReadFrom);
|
||||||
|
|
||||||
ScriptPromiseTester buffer_tester(
|
ScriptPromiseTester buffer_tester(
|
||||||
script_state,
|
script_state,
|
||||||
@@ -1417,6 +1422,7 @@ TEST_F(MLGraphTest, WriteWebNNBufferThenDestroyTest) {
|
|||||||
auto* desc = MLBufferDescriptor::Create();
|
auto* desc = MLBufferDescriptor::Create();
|
||||||
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
desc->setDataType(V8MLOperandDataType::Enum::kUint8);
|
||||||
desc->setDimensions({2, 2});
|
desc->setDimensions({2, 2});
|
||||||
|
desc->setUsage(V8MLBufferUsage::Constant::kWriteTo);
|
||||||
|
|
||||||
ScriptPromiseTester buffer_tester(
|
ScriptPromiseTester buffer_tester(
|
||||||
script_state,
|
script_state,
|
||||||
@@ -1459,6 +1465,7 @@ TEST_F(MLGraphTest, ReadWebNNBufferThenDestroyTest) {
|
|||||||
auto* desc = MLBufferDescriptor::Create();
|
auto* desc = MLBufferDescriptor::Create();
|
||||||
desc->setDataType(V8MLOperandDataType::Enum::kFloat32);
|
desc->setDataType(V8MLOperandDataType::Enum::kFloat32);
|
||||||
desc->setDimensions({2, 2});
|
desc->setDimensions({2, 2});
|
||||||
|
desc->setUsage(V8MLBufferUsage::Constant::kReadFrom);
|
||||||
|
|
||||||
ScriptPromiseTester create_buffer_tester(
|
ScriptPromiseTester create_buffer_tester(
|
||||||
script_state,
|
script_state,
|
||||||
|
@@ -33,7 +33,11 @@ const sizeOfDescriptor = (descriptor) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const getDescriptorFromBuffer = (buffer) => {
|
const getDescriptorFromBuffer = (buffer) => {
|
||||||
return {dataType: buffer.dataType, dimensions: buffer.shape};
|
return {
|
||||||
|
dataType: buffer.dataType,
|
||||||
|
dimensions: buffer.shape,
|
||||||
|
usage: buffer.usage
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -160,7 +164,11 @@ const testWriteWebNNBuffer = (testName) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const bufferDescriptor = {dataType: 'int32', dimensions: [1]};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO,
|
||||||
|
};
|
||||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||||
|
|
||||||
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
||||||
@@ -205,7 +213,11 @@ const testWriteWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / error`);
|
}, `${testName} / error`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 2]};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 2],
|
||||||
|
usage: MLBufferUsage.WRITE_TO,
|
||||||
|
};
|
||||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||||
|
|
||||||
// Writing data to a destroyed MLBuffer should throw.
|
// Writing data to a destroyed MLBuffer should throw.
|
||||||
@@ -218,7 +230,11 @@ const testWriteWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / destroy`);
|
}, `${testName} / destroy`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 3]};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 3],
|
||||||
|
usage: MLBufferUsage.WRITE_TO,
|
||||||
|
};
|
||||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||||
|
|
||||||
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
||||||
@@ -233,8 +249,11 @@ const testWriteWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / context_mismatch`);
|
}, `${testName} / context_mismatch`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize the buffer.
|
// Initialize the buffer.
|
||||||
const inputData = Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]);
|
const inputData = Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]);
|
||||||
@@ -253,7 +272,11 @@ const testWriteWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / zero_write`);
|
}, `${testName} / zero_write`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 2]};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 2],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||||
|
|
||||||
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
const bufferByteLength = sizeOfDescriptor(bufferDescriptor);
|
||||||
@@ -300,8 +323,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
promise_test(async t => {
|
promise_test(async t => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 2],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
// Reading a destroyed MLBuffer should reject.
|
// Reading a destroyed MLBuffer should reject.
|
||||||
mlBuffer.destroy();
|
mlBuffer.destroy();
|
||||||
@@ -311,8 +337,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / read_after_destroy`);
|
}, `${testName} / read_after_destroy`);
|
||||||
|
|
||||||
promise_test(async t => {
|
promise_test(async t => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 3]});
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 3],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
let promise = mlContext.readBuffer(mlBuffer);
|
let promise = mlContext.readBuffer(mlBuffer);
|
||||||
let anotherPromise = mlContext.readBuffer(mlBuffer);
|
let anotherPromise = mlContext.readBuffer(mlBuffer);
|
||||||
@@ -324,16 +353,22 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / read_before_destroy`);
|
}, `${testName} / read_before_destroy`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1024]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1024],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
await assert_buffer_data_equals(
|
await assert_buffer_data_equals(
|
||||||
mlContext, mlBuffer, new Uint32Array(1024));
|
mlContext, mlBuffer, new Uint32Array(1024));
|
||||||
}, `${testName} / uninitialized`);
|
}, `${testName} / uninitialized`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.READ_FROM | MLBufferUsage.WRITE_TO,
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize the buffer.
|
// Initialize the buffer.
|
||||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||||
@@ -345,8 +380,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / full_size`);
|
}, `${testName} / full_size`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize the buffer.
|
// Initialize the buffer.
|
||||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||||
@@ -360,8 +398,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / src_offset_only`);
|
}, `${testName} / src_offset_only`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize the buffer.
|
// Initialize the buffer.
|
||||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||||
@@ -375,8 +416,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / src_offset_and_size`);
|
}, `${testName} / src_offset_and_size`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize the buffer.
|
// Initialize the buffer.
|
||||||
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
mlContext.writeBuffer(mlBuffer, Uint8Array.from([0xAA, 0xAA, 0xAA, 0xAA]));
|
||||||
@@ -390,8 +434,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / larger_src_data`);
|
}, `${testName} / larger_src_data`);
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
let mlBuffer =
|
let mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [1]});
|
dataType: 'int32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
|
|
||||||
const inputData = [0xAA, 0xAA, 0xAA, 0xAA];
|
const inputData = [0xAA, 0xAA, 0xAA, 0xAA];
|
||||||
|
|
||||||
@@ -404,7 +451,11 @@ const testReadWebNNBuffer = (testName) => {
|
|||||||
}, `${testName} / no_src_offset`);
|
}, `${testName} / no_src_offset`);
|
||||||
|
|
||||||
promise_test(async t => {
|
promise_test(async t => {
|
||||||
const bufferDescriptor = {dataType: 'int32', dimensions: [2, 3]};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 3],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
let mlBuffer = await mlContext.createBuffer(bufferDescriptor);
|
||||||
|
|
||||||
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
let anotherMLContext = await navigator.ml.createContext(contextOptions);
|
||||||
@@ -436,7 +487,11 @@ const testDispatchWebNNBuffer = (testName) => {
|
|||||||
}
|
}
|
||||||
// Construct a simple graph: A = B + C, with two outputs.
|
// Construct a simple graph: A = B + C, with two outputs.
|
||||||
const builder = new MLGraphBuilder(mlContext);
|
const builder = new MLGraphBuilder(mlContext);
|
||||||
const bufferDescriptor = {dataType: 'float32', dimensions: shape};
|
const bufferDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: shape,
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
const lhsOperand = builder.input('lhs', bufferDescriptor);
|
const lhsOperand = builder.input('lhs', bufferDescriptor);
|
||||||
const rhsOperand = builder.input('rhs', bufferDescriptor);
|
const rhsOperand = builder.input('rhs', bufferDescriptor);
|
||||||
const output1Operand = builder.add(lhsOperand, rhsOperand);
|
const output1Operand = builder.add(lhsOperand, rhsOperand);
|
||||||
|
21
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/byob_readbuffer.https.any.js
vendored
21
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/byob_readbuffer.https.any.js
vendored
@@ -29,8 +29,11 @@ promise_setup(async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
mlBuffer =
|
mlBuffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 4]});
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 4],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
throw new AssertionError(
|
throw new AssertionError(
|
||||||
`Unable to create buffer for ${variant} variant. ${e}`);
|
`Unable to create buffer for ${variant} variant. ${e}`);
|
||||||
@@ -135,8 +138,11 @@ promise_test(async () => {
|
|||||||
}, `readBuffer() with a larger TypedArray`);
|
}, `readBuffer() with a larger TypedArray`);
|
||||||
|
|
||||||
promise_test(async (t) => {
|
promise_test(async (t) => {
|
||||||
const buffer =
|
const buffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 2],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
const arrayBufferView = new Int32Array(2 * 2);
|
const arrayBufferView = new Int32Array(2 * 2);
|
||||||
const arrayBuffer = arrayBufferView.buffer;
|
const arrayBuffer = arrayBufferView.buffer;
|
||||||
|
|
||||||
@@ -150,8 +156,11 @@ promise_test(async (t) => {
|
|||||||
}, `readBuffer() rejects on a destroyed MLBuffer`);
|
}, `readBuffer() rejects on a destroyed MLBuffer`);
|
||||||
|
|
||||||
promise_test(async (t) => {
|
promise_test(async (t) => {
|
||||||
const buffer =
|
const buffer = await mlContext.createBuffer({
|
||||||
await mlContext.createBuffer({dataType: 'int32', dimensions: [2, 2]});
|
dataType: 'int32',
|
||||||
|
dimensions: [2, 2],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
const arrayBufferView = new Int32Array(2 * 2);
|
const arrayBufferView = new Int32Array(2 * 2);
|
||||||
const arrayBuffer = arrayBufferView.buffer;
|
const arrayBuffer = arrayBufferView.buffer;
|
||||||
|
|
||||||
|
54
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/parallel-dispatch.https.any.js
vendored
54
third_party/blink/web_tests/external/wpt/webnn/conformance_tests/parallel-dispatch.https.any.js
vendored
@@ -30,7 +30,11 @@ function buildMulGraph(context, operandDescriptor, multiplier) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
|
|
||||||
const [mlGraph, inputBuffer1, inputBuffer2, outputBuffer] =
|
const [mlGraph, inputBuffer1, inputBuffer2, outputBuffer] =
|
||||||
await Promise.all([
|
await Promise.all([
|
||||||
@@ -66,7 +70,11 @@ promise_test(async () => {
|
|||||||
}, 'dispatch queues behind readBuffer');
|
}, 'dispatch queues behind readBuffer');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 3);
|
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 3);
|
||||||
|
|
||||||
// write/dispatch/read, write/dispatch/read, ...
|
// write/dispatch/read, write/dispatch/read, ...
|
||||||
@@ -90,7 +98,11 @@ promise_test(async () => {
|
|||||||
}, 'same graph: write/dispatch/read, write/dispatch/read, ...');
|
}, 'same graph: write/dispatch/read, write/dispatch/read, ...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 10);
|
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 10);
|
||||||
|
|
||||||
// write/write...
|
// write/write...
|
||||||
@@ -125,7 +137,11 @@ promise_test(async () => {
|
|||||||
}, 'same graph: write/write..., dispatch/read, dispatch/read, ...');
|
}, 'same graph: write/write..., dispatch/read, dispatch/read, ...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 9);
|
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 9);
|
||||||
|
|
||||||
// write/write...
|
// write/write...
|
||||||
@@ -159,7 +175,11 @@ promise_test(async () => {
|
|||||||
}, 'same graph: write/write..., dispatch/dispatch..., read/read...');
|
}, 'same graph: write/write..., dispatch/dispatch..., read/read...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 2);
|
const mlGraph = await buildMulGraph(mlContext, operandDescriptor, 2);
|
||||||
|
|
||||||
const buffers = await Promise.all([
|
const buffers = await Promise.all([
|
||||||
@@ -188,7 +208,11 @@ promise_test(async () => {
|
|||||||
}, 'same graph serial inputs: dispatch/dispatch..., read/read...');
|
}, 'same graph serial inputs: dispatch/dispatch..., read/read...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
|
|
||||||
// write/write...
|
// write/write...
|
||||||
const testInputs = [1, 2, 3, 4];
|
const testInputs = [1, 2, 3, 4];
|
||||||
@@ -223,7 +247,11 @@ promise_test(async () => {
|
|||||||
}, 'different graphs: write/write..., dispatch/read, dispatch/read, ...');
|
}, 'different graphs: write/write..., dispatch/read, dispatch/read, ...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
|
|
||||||
// write/write...
|
// write/write...
|
||||||
const testInputs = [1, 2, 3, 4];
|
const testInputs = [1, 2, 3, 4];
|
||||||
@@ -257,7 +285,11 @@ promise_test(async () => {
|
|||||||
}, 'different graphs: write/write..., dispatch/dispatch..., read/read...');
|
}, 'different graphs: write/write..., dispatch/dispatch..., read/read...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
|
|
||||||
const graphs = await Promise.all([3, 2].map(async (multiplier) => {
|
const graphs = await Promise.all([3, 2].map(async (multiplier) => {
|
||||||
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
||||||
@@ -289,7 +321,11 @@ promise_test(async () => {
|
|||||||
}, 'different graphs serial inputs: dispatch/dispatch..., read/read...');
|
}, 'different graphs serial inputs: dispatch/dispatch..., read/read...');
|
||||||
|
|
||||||
promise_test(async () => {
|
promise_test(async () => {
|
||||||
const operandDescriptor = {dataType: 'float32', dimensions: [1]};
|
const operandDescriptor = {
|
||||||
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO | MLBufferUsage.READ_FROM,
|
||||||
|
};
|
||||||
|
|
||||||
const graphs = await Promise.all([2, 3].map(async (multiplier) => {
|
const graphs = await Promise.all([2, 3].map(async (multiplier) => {
|
||||||
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
return buildMulGraph(mlContext, operandDescriptor, multiplier);
|
||||||
|
21
third_party/blink/web_tests/external/wpt/webnn/validation_tests/destroyContext.https.any.js
vendored
21
third_party/blink/web_tests/external/wpt/webnn/validation_tests/destroyContext.https.any.js
vendored
@@ -132,16 +132,22 @@ promise_test(async t => {
|
|||||||
|
|
||||||
promise_test(async t => {
|
promise_test(async t => {
|
||||||
const context = await navigator.ml.createContext(contextOptions);
|
const context = await navigator.ml.createContext(contextOptions);
|
||||||
const buffer =
|
const buffer = await context.createBuffer({
|
||||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
context.destroy();
|
context.destroy();
|
||||||
promise_rejects_dom(t, 'InvalidStateError', context.readBuffer(buffer));
|
promise_rejects_dom(t, 'InvalidStateError', context.readBuffer(buffer));
|
||||||
}, 'Destroyed context can not read buffer.');
|
}, 'Destroyed context can not read buffer.');
|
||||||
|
|
||||||
promise_test(async t => {
|
promise_test(async t => {
|
||||||
const context = await navigator.ml.createContext(contextOptions);
|
const context = await navigator.ml.createContext(contextOptions);
|
||||||
const buffer =
|
const buffer = await context.createBuffer({
|
||||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.READ_FROM,
|
||||||
|
});
|
||||||
let promise = context.readBuffer(buffer);
|
let promise = context.readBuffer(buffer);
|
||||||
context.destroy();
|
context.destroy();
|
||||||
promise_rejects_dom(t, 'InvalidStateError', promise);
|
promise_rejects_dom(t, 'InvalidStateError', promise);
|
||||||
@@ -152,8 +158,11 @@ promise_test(async t => {
|
|||||||
// Destroying another context doesn't impact the first context.
|
// Destroying another context doesn't impact the first context.
|
||||||
const another_context = await navigator.ml.createContext(contextOptions);
|
const another_context = await navigator.ml.createContext(contextOptions);
|
||||||
another_context.destroy();
|
another_context.destroy();
|
||||||
const buffer =
|
const buffer = await context.createBuffer({
|
||||||
await context.createBuffer({dataType: 'float32', dimensions: [1]});
|
dataType: 'float32',
|
||||||
|
dimensions: [1],
|
||||||
|
usage: MLBufferUsage.WRITE_TO,
|
||||||
|
});
|
||||||
let arrayBuffer = new ArrayBuffer(4);
|
let arrayBuffer = new ArrayBuffer(4);
|
||||||
context.destroy();
|
context.destroy();
|
||||||
assert_throws_dom('InvalidStateError', () => {
|
assert_throws_dom('InvalidStateError', () => {
|
||||||
|
Reference in New Issue
Block a user