0

webnn: use mojo pipe to run CreateGraph for fuzzer

Now that the MojoLPM fuzzer is enabled on Linux, Mac and Windows, use the
Mojo pipe to call `CreateGraph`, which both builds and compiles the graph,
instead of calling the individual graph builder implementations to run a
partial CreateGraph.

Adds a custom seed corpus file to increase coverage.

Bug: 378956983
Change-Id: I7c33538ca1d1c5a66556be96dd6cb9d95481634c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6036173
Reviewed-by: Reilly Grant <reillyg@chromium.org>
Commit-Queue: Phillis Tang <phillis@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1385384}
This commit is contained in:
Phillis Tang
2024-11-20 01:36:50 +00:00
committed by Chromium LUCI CQ
parent e0345d4a88
commit f1c20e5559
3 changed files with 169 additions and 93 deletions
services/webnn
BUILD.gn
webnn_graph_mojolpm_fuzzer.cc
webnn_graph_mojolpm_fuzzer_seed_corpus

@ -247,9 +247,10 @@ mojolpm_fuzzer_test("webnn_graph_mojolpm_fuzzer") {
proto_source = "webnn_graph_mojolpm_fuzzer.proto"
proto_deps = [ "//services/webnn/public/mojom:mojom_mojolpm" ]
testcase_proto_kind = "services.fuzzing.webnn_graph.proto.Testcase"
seed_corpus_sources =
[ "webnn_graph_mojolpm_fuzzer_seed_corpus/simple.textproto" ]
deps = [
":tflite_graph_builder",
":webnn_service",
"//base",
"//base/test:test_support",
@ -257,8 +258,4 @@ mojolpm_fuzzer_test("webnn_graph_mojolpm_fuzzer") {
"//services/webnn/public/mojom",
"//third_party/libprotobuf-mutator",
]
if (is_posix) {
deps += [ ":coreml_graph_builder" ]
}
}

@ -8,39 +8,39 @@
#include "base/files/scoped_temp_dir.h"
#include "base/memory/raw_ref.h"
#include "base/notreached.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/allow_check_is_test_for_testing.h"
#include "base/test/bind.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/test/test_timeouts.h"
#include "content/test/fuzzer/mojolpm_fuzzer_support.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "services/webnn/public/mojom/features.mojom-features.h"
#include "services/webnn/public/mojom/webnn_context.mojom.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom.h"
#include "services/webnn/public/mojom/webnn_graph.mojom-mojolpm.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"
#include "services/webnn/tflite/context_impl_tflite.h"
#include "services/webnn/tflite/graph_builder_tflite.h"
#include "services/webnn/webnn_constant_operand.h"
#include "services/webnn/webnn_context_impl.h"
#include "services/webnn/webnn_context_provider_impl.h"
#include "services/webnn/webnn_graph_builder_impl.h"
#include "services/webnn/webnn_graph_impl.h"
#include "services/webnn/webnn_graph_mojolpm_fuzzer.pb.h"
#include "third_party/libprotobuf-mutator/src/src/libfuzzer/libfuzzer_macro.h"
#if BUILDFLAG(IS_MAC)
#include "base/mac/mac_util.h"
#endif // BUILDFLAG(IS_MAC)
#if BUILDFLAG(IS_WIN)
#include "services/webnn/dml/adapter.h"
#include "services/webnn/dml/context_impl_dml.h"
#include "services/webnn/dml/graph_builder_dml.h"
#include "services/webnn/dml/graph_impl_dml.h"
#endif
#if BUILDFLAG(IS_POSIX)
#include "services/webnn/coreml/graph_builder_coreml.h"
#endif
namespace {
struct InitGlobals {
InitGlobals() {
InitGlobals()
: scoped_feature_list_(
webnn::mojom::features::kWebMachineLearningNeuralNetwork) {
mojo::core::Init();
bool success = base::CommandLine::Init(0, nullptr);
CHECK(success);
@ -63,6 +63,7 @@ struct InitGlobals {
}
std::unique_ptr<base::test::TaskEnvironment> task_environment;
base::test::ScopedFeatureList scoped_feature_list_;
#if BUILDFLAG(IS_WIN)
scoped_refptr<webnn::dml::Adapter> adapter;
#endif
@ -70,13 +71,57 @@ struct InitGlobals {
InitGlobals* init_globals = new InitGlobals();
#if BUILDFLAG(IS_WIN)
// Returns a new reference to the DirectML adapter held by the file-global
// `init_globals` instance. Callers share that single adapter rather than
// creating their own.
scoped_refptr<webnn::dml::Adapter> GetAdapter() {
  const scoped_refptr<webnn::dml::Adapter>& shared_adapter =
      init_globals->adapter;
  return shared_adapter;
}
#endif
// Builds and compiles `graph_info` end-to-end through the public Mojo
// interfaces: WebNNContextProvider -> WebNNContext -> WebNNGraphBuilder ->
// CreateGraph().
//
// `device` selects the context device (GPU by default); the power preference
// is always kDefault. Failing to create a context is fatal (CHECK). The
// CreateGraph() outcome itself is not inspected. Either a reply or a
// disconnect of the graph builder pipe (surfaced as a kUnknownError result)
// ends the wait.
void BuildGraph(webnn::mojom::GraphInfoPtr graph_info,
                webnn::mojom::CreateContextOptions::Device device =
                    webnn::mojom::CreateContextOptions::Device::kGpu) {
  // All three endpoints live until this function returns, so the whole
  // provider -> context -> builder chain stays connected while CreateGraph()
  // runs. NOTE(review): the declaration order (and therefore the reverse
  // destruction order) is kept as-is on purpose. Confirm how the service ties
  // context lifetime to the provider connection before narrowing any scope.
  mojo::Remote<webnn::mojom::WebNNContextProvider> provider;
  mojo::Remote<webnn::mojom::WebNNContext> context;
  mojo::AssociatedRemote<webnn::mojom::WebNNGraphBuilder> graph_builder;

  webnn::WebNNContextProviderImpl::CreateForTesting(
      provider.BindNewPipeAndPassReceiver());

  // Step 1: ask the provider for a context and block until it answers.
  auto context_options = webnn::mojom::CreateContextOptions::New(
      device, webnn::mojom::CreateContextOptions::PowerPreference::kDefault);
  base::test::TestFuture<webnn::mojom::CreateContextResultPtr> context_future;
  provider->CreateWebNNContext(std::move(context_options),
                               context_future.GetCallback());
  webnn::mojom::CreateContextResultPtr context_result = context_future.Take();
  CHECK(context_result->is_success());
  context.Bind(std::move(context_result->get_success()->context_remote));
  EXPECT_TRUE(context.is_bound());

  // Step 2: obtain a graph builder from the context.
  context->CreateGraphBuilder(graph_builder.BindNewEndpointAndPassReceiver());

  // Step 3: build and compile. A dropped builder pipe is turned into an error
  // result so that Wait() is not left waiting for a reply that never comes.
  base::test::TestFuture<webnn::mojom::CreateGraphResultPtr> graph_future;
  graph_builder.set_disconnect_handler(base::BindLambdaForTesting([&] {
    graph_future.SetValue(webnn::mojom::CreateGraphResult::NewError(
        webnn::mojom::Error::New(webnn::mojom::Error::Code::kUnknownError,
                                 "Failed to create graph.")));
  }));
  graph_builder->CreateGraph(std::move(graph_info),
                             graph_future.GetCallback());
  ASSERT_TRUE(graph_future.Wait());
}
class WebnnGraphLPMFuzzer {
public:
explicit WebnnGraphLPMFuzzer(
@ -85,79 +130,11 @@ class WebnnGraphLPMFuzzer {
// Consumes one testcase action. It converts the action's `graph_info` proto
// into a mojom::GraphInfo and sends it through the real Mojo CreateGraph()
// path via BuildGraph(), which both builds and compiles the graph. The former
// per-backend (Core ML / DirectML / TFLite) partial builds are intentionally
// gone: the service picks and runs the backend itself.
void NextAction() {
  const auto& action = testcase_->actions(action_index_);
  // Advance before building. `action` refers into `testcase_`, so it remains
  // valid after the index moves on.
  ++action_index_;
  const auto& create_graph = action.create_graph();
  auto graph_info_ptr = webnn::mojom::GraphInfo::New();
  mojolpm::FromProto(create_graph.graph_info(), graph_info_ptr);
  BuildGraph(std::move(graph_info_ptr));
}
bool IsFinished() { return action_index_ >= testcase_->actions_size(); }
@ -170,6 +147,19 @@ class WebnnGraphLPMFuzzer {
DEFINE_BINARY_PROTO_FUZZER(
const services::fuzzing::webnn_graph::proto::Testcase& testcase) {
#if BUILDFLAG(IS_MAC)
if (base::mac::MacOSVersion() < 14'00'00) {
GTEST_SKIP() << "Skipping test because WebNN is not supported on Mac OS "
<< base::mac::MacOSVersion();
}
#endif
#if BUILDFLAG(IS_WIN)
CHECK(GetAdapter());
// Graph compilation relies on IDMLDevice1::CompileGraph introduced in
// DirectML version 1.2 (DML_FEATURE_LEVEL_2_1).
CHECK(GetAdapter()->IsDMLDeviceCompileGraphSupportedForTesting());
#endif
WebnnGraphLPMFuzzer webnn_graph_fuzzer_instance(testcase);
while (!webnn_graph_fuzzer_instance.IsFinished()) {
webnn_graph_fuzzer_instance.NextAction();

@ -0,0 +1,89 @@
actions {
create_graph {
graph_info {
new {
id: 1
m_id_to_operand_map: {
values: {
key: {
value: 1
}
value: {
value: {
new: {
id: 1
m_kind: 0
m_name: "input"
m_descriptor: {
new: {
id: 1
m_data_type:0
m_shape: {
values: {
value: 1
}
}
}
}
}
}
}
}
values: {
key: {
value: 2
}
value: {
value: {
new: {
id: 2
m_kind: 2
m_name: "output"
m_descriptor: {
new: {
id: 2
m_data_type:0
m_shape: {
values: {
value: 1
}
}
}
}
}
}
}
}
}
m_input_operands: {
values: {
value: 1
}
}
m_output_operands: {
values: {
value: 2
}
}
m_operations: {
values: {
value: {
new: {
id: 1
m_gelu: {
new: {
id: 1
m_input_operand_id: 1
m_output_operand_id: 2
m_label: "test"
}
}
}
}
}
}
m_constant_id_to_buffer_map: {}
}
}
}
}