0

Add GetCapabilities() to on device model service

This API allows querying what capabilities a model can support. This
will be used to determine if the model can be used for certain features.
This also updates chrome://on-device-internals to use this API to
automatically detect capabilities when starting the session.

Bug: 399200301
Change-Id: I5288bdd4b879f5d3a773d165f212940e7da3ece0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6353258
Reviewed-by: Austin Sullivan <asully@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Tom Sepez <tsepez@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1432311}
This commit is contained in:
Clark DuVall
2025-03-13 13:00:01 -07:00
committed by Chromium LUCI CQ
parent 22eaa3855b
commit a05c688c9d
17 changed files with 173 additions and 67 deletions

@ -145,12 +145,6 @@
<option value="kHighestQuality">Highest Quality</option>
<option value="kFastestInference">Fastest Inference</option>
</select>
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
Enable images
</cr-checkbox>
<cr-checkbox slot="suffix" checked="{{enableAudioInput_}}">
Enable audio
</cr-checkbox>
</div>
<div class="model-text">
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
@ -188,7 +182,7 @@
value="{{text_}}">
</cr-textarea>
<div class="multimodal-buttons" >
<div class="image-buttons" hidden$="[[!imagesEnabled_(model_, baseModel_)]]">
<div class="image-buttons" hidden$="[[!imagesEnabled_(capabilities_)]]">
<div class="image-error">[[imageError_]]</div>
<div hidden$="[[imageFile_]]">
<cr-button class="floating-button"
@ -205,7 +199,7 @@
[[imageFile_.name]]
</cr-button>
</div>
<div class="audio-buttons" hidden$="[[!audioEnabled_(model_, baseModel_)]]">
<div class="audio-buttons" hidden$="[[!audioEnabled_(capabilities_)]]">
<div class="audio-error">[[audioError_]]</div>
<div hidden$="[[audioFile_]]">
<cr-button class="floating-button"

@ -16,7 +16,7 @@ import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js'
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
import {BrowserProxy} from './browser_proxy.js';
import type {InputPiece, ResponseChunk, ResponseSummary,AudioData} from './on_device_model.mojom-webui.js';
import type {AudioData, Capabilities, InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
import {getTemplate} from './tools.html.js';
@ -113,10 +113,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
type: Array,
value: () => [],
},
baseModel_: {
type: Object,
value: null,
},
model_: {
type: Object,
value: null,
@ -134,10 +130,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
value: 0,
},
contextText_: String,
enableImageInput_: {
type: Boolean,
value: false,
},
topK_: {
type: Number,
value: 1,
@ -170,6 +162,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
}
private capabilities_: Capabilities = {imageInput: false, audioInput: false};
private contextExpanded_: boolean;
private contextLength_: number;
private contextText_: string;
@ -179,7 +172,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
private loadModelDuration_: number;
private loadModelStart_: number;
private modelPath_: string;
private baseModel_: OnDeviceModelRemote|null;
private model_: OnDeviceModelRemote|null;
private performanceClassText_: string;
private responses_: Response[];
@ -187,8 +179,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
private text_: string;
private topK_: number;
private imageFile_: File|null;
private enableAudioInput_: boolean;
private enableImageInput_: boolean;
private audioFile_: File|null;
private audioError_: string;
private performanceHint_: string;
@ -257,7 +247,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
}
this.error_ = 'Service crashed, please reload the model.';
this.model_ = null;
this.baseModel_ = null;
this.modelPath_ = '';
this.loadModelStart_ = 0;
this.$.modelInput.focus();
@ -284,16 +273,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
private async onModelSelected_() {
this.error_ = '';
if (this.baseModel_) {
this.baseModel_.$.close();
if (this.model_) {
this.model_.$.close();
}
if (this.model_) {
this.model_.$.close();
}
this.imageFile_ = null;
this.audioFile_ = null;
this.baseModel_ = null;
this.model_ = null;
this.capabilities_ = {imageInput: false, audioInput: false};
this.loadModelStart_ = new Date().getTime();
const performanceHint = ModelPerformanceHint[(
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
@ -305,35 +294,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
// <if expr="not is_win">
const processedPath = modelPath;
// </if>
const baseModel = new OnDeviceModelRemote();
let newModel = new OnDeviceModelRemote();
let {result} = await this.proxy_.handler.loadModel(
const newModel = new OnDeviceModelRemote();
const {result, capabilities} = await this.proxy_.handler.loadModel(
{path: processedPath}, performanceHint,
baseModel.$.bindNewPipeAndPassReceiver());
if (result === LoadModelResult.kSuccess &&
(this.enableImageInput_ || this.enableAudioInput_)) {
result = (await baseModel.loadAdaptation(
{
enableImageInput: this.enableImageInput_,
enableAudioInput: this.enableAudioInput_,
maxTokens: 0,
assets: {
weights: null,
weightsPath: null,
},
},
newModel.$.bindNewPipeAndPassReceiver()))
.result;
} else {
// No adaptation needed, just use the base model.
newModel = baseModel;
}
newModel.$.bindNewPipeAndPassReceiver());
if (result !== LoadModelResult.kSuccess) {
this.error_ =
'Unable to load model. Specify a correct and absolute path.';
} else {
this.baseModel_ = baseModel;
this.model_ = newModel;
this.capabilities_ = capabilities;
this.model_.onConnectionError.addListener(() => {
this.onServiceCrashed_();
});
@ -364,8 +334,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
}
this.contextLength_ = 0;
this.session_ = new SessionRemote();
this.model_.startSession(
this.session_.$.bindNewPipeAndPassReceiver(), null);
this.model_.startSession(this.session_.$.bindNewPipeAndPassReceiver(), {
maxTokens: 0,
capabilities: {
imageInput: this.imagesEnabled_(),
audioInput: this.audioEnabled_(),
},
});
}
private onCancelClick_() {
@ -508,11 +483,11 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
}
private imagesEnabled_(): boolean {
return this.model_ !== this.baseModel_ && this.enableImageInput_;
return this.capabilities_.imageInput;
}
private audioEnabled_(): boolean {
return this.model_ !== this.baseModel_ && this.enableAudioInput_;
return this.capabilities_.audioInput;
}
private getModelText_(): string {

@ -51,7 +51,8 @@ interface OnDeviceInternalsPageHandler {
LoadModel(mojo_base.mojom.FilePath model_path,
on_device_model.mojom.ModelPerformanceHint performance_hint,
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
(on_device_model.mojom.LoadModelResult result);
(on_device_model.mojom.LoadModelResult result,
on_device_model.mojom.Capabilities capabilities);
// Returns the performance class based on benchmarks run on the device.
GetEstimatedPerformanceClass() =>

@ -81,8 +81,14 @@ void OnDeviceInternalsPageHandler::LoadModel(
on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
return;
}
GetService().LoadPlatformModel(uuid, std::move(model), mojo::NullRemote(),
std::move(callback));
GetService().LoadPlatformModel(
uuid, std::move(model), mojo::NullRemote(),
base::BindOnce(
[](LoadModelCallback callback,
on_device_model::mojom::LoadModelResult result) {
std::move(callback).Run(result, on_device_model::Capabilities());
},
std::move(callback)));
#else
// Warm the service while assets load in the background.
std::ignore = GetService();
@ -122,11 +128,26 @@ void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets) {
auto params = on_device_model::mojom::LoadModelParams::New();
params->assets = std::move(assets);
params->assets = assets;
params->max_tokens = 4096;
params->performance_hint = performance_hint;
GetService().LoadModel(std::move(params), std::move(model),
std::move(callback));
GetService().LoadModel(
std::move(params), std::move(model),
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelLoaded,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
std::move(assets)));
}
// Completion handler for the service LoadModel() call. On success, chains a
// GetCapabilities() query so the page receives both the load result and the
// model's capabilities; on failure, reports empty capabilities immediately.
void OnDeviceInternalsPageHandler::OnModelLoaded(
    LoadModelCallback callback,
    on_device_model::ModelAssets assets,
    on_device_model::mojom::LoadModelResult result) {
  if (result != on_device_model::mojom::LoadModelResult::kSuccess) {
    // Load failed: surface the error with an empty capability set.
    std::move(callback).Run(result, on_device_model::Capabilities());
    return;
  }
  // `result` (kSuccess) is bound as the callback's first argument; the
  // service supplies the Capabilities as the second.
  GetService().GetCapabilities(std::move(assets),
                               base::BindOnce(std::move(callback), result));
}
#endif

@ -45,6 +45,9 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
LoadModelCallback callback,
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets);
void OnModelLoaded(LoadModelCallback callback,
on_device_model::ModelAssets assets,
on_device_model::mojom::LoadModelResult result);
#endif
// mojom::OnDeviceInternalsPageHandler:

@ -38,6 +38,16 @@ std::string PieceToString(const ml::InputPiece& piece) {
NOTREACHED();
}
// Reads the full contents of `api_file` into a string, wrapping the raw
// handle in a base::File. Returns an empty string if the read fails.
std::string ReadFile(PlatformFile api_file) {
  base::File file(static_cast<base::PlatformFile>(api_file));
  std::vector<uint8_t> buffer(file.GetLength());
  if (!file.ReadAndCheck(0, buffer)) {
    return std::string();
  }
  return std::string(buffer.begin(), buffer.end());
}
} // namespace
void InitDawnProcs(const DawnProcTable& procs) {}
@ -58,6 +68,13 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
return false;
}
// Fake implementation of the ChromeML GetCapabilities entry point: a modality
// is reported as supported when the weights file contents mention its name.
bool GetCapabilities(PlatformFile file, ChromeMLCapabilities& capabilities) {
  const std::string weights = ReadFile(file);
  const auto has_token = [&weights](const char* token) {
    return weights.find(token) != std::string::npos;
  };
  capabilities.image_input = has_token("image");
  capabilities.audio_input = has_token("audio");
  return true;
}
struct FakeModelInstance {
ml::ModelBackendType backend_type_;
ml::ModelPerformanceHint performance_hint;
@ -81,16 +98,6 @@ struct FakeCancelInstance {
bool cancelled = false;
};
// Reads the full contents of `api_file` into a string, wrapping the raw
// handle in a base::File. Returns an empty string if the read fails.
std::string ReadFile(PlatformFile api_file) {
  base::File file(static_cast<base::PlatformFile>(api_file));
  std::vector<uint8_t> contents;
  contents.resize(file.GetLength());
  if (!file.ReadAndCheck(0, contents)) {
    return std::string();
  }
  return std::string(contents.begin(), contents.end());
}
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
uintptr_t context,
ChromeMLScheduleFn schedule) {
@ -301,6 +308,7 @@ const ChromeMLAPI g_api = {
.DestroyModel = &DestroyModel,
.GetEstimatedPerformance = &GetEstimatedPerformance,
.QueryGPUAdapter = &QueryGPUAdapter,
.GetCapabilities = &GetCapabilities,
.SetFatalErrorNonGpuFn = &SetFatalErrorNonGpuFn,
.SessionCreateModel = &SessionCreateModel,

@ -229,6 +229,12 @@ struct GpuConfig {
WGPUBackendType backend_type;
};
// A set of capabilities that a model can have.
struct ChromeMLCapabilities {
  // True if the model accepts image input.
  bool image_input = false;
  // True if the model accepts audio input.
  bool audio_input = false;
};
struct ChromeMLMetricsFns {
// Logs an exact sample for the named metric.
void (*RecordExactLinearHistogram)(const char* name,
@ -329,6 +335,10 @@ struct ChromeMLAPI {
void* userdata),
void* userdata);
// Gets the model capabilities for the model pointed to by `model_data`.
bool (*GetCapabilities)(PlatformFile file,
ChromeMLCapabilities& capabilities);
// Same as SetFatalErrorFn(), but for fatal errors that occur outside of the
// gpu.
void (*SetFatalErrorNonGpuFn)(ChromeMLFatalErrorFn error_fn);

@ -391,6 +391,36 @@ OnDeviceModelExecutor::CreateWithResult(
return base::unexpected(load_model_result);
}
// static
// Queries the ChromeML library for the capabilities of the model whose
// weights are in `assets`, translating the C-struct result into the
// on_device_model::Capabilities flag set. Returns an empty set if the loaded
// library does not export GetCapabilities.
DISABLE_CFI_DLSYM
on_device_model::Capabilities OnDeviceModelExecutor::GetCapabilities(
    const ChromeML& chrome_ml,
    on_device_model::ModelAssets assets) {
  on_device_model::Capabilities result;
  if (!chrome_ml.api().GetCapabilities) {
    // Older library versions lack this entry point; report no capabilities.
    return result;
  }
  // Prefer an already-opened weights file; otherwise open the weights path.
  // NOTE(review): the fallback opens with FLAG_OPEN but not FLAG_READ —
  // confirm the library can actually read from this handle.
  PlatformFile platform_file;
  if (assets.weights.IsValid()) {
    platform_file = assets.weights.TakePlatformFile();
  } else {
    base::File file(assets.weights_path, base::File::FLAG_OPEN);
    platform_file = file.TakePlatformFile();
  }
  ChromeMLCapabilities capabilities;
  chrome_ml.api().GetCapabilities(platform_file, capabilities);
  if (capabilities.image_input) {
    result.Put(on_device_model::CapabilityFlags::kImageInput);
  }
  if (capabilities.audio_input) {
    result.Put(on_device_model::CapabilityFlags::kAudioInput);
  }
  return result;
}
std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
const ScopedAdaptation* adaptation,
on_device_model::mojom::SessionParamsPtr params) {

@ -96,6 +96,10 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) OnDeviceModelExecutor final {
on_device_model::mojom::LoadModelParamsPtr params,
base::OnceClosure on_complete);
static on_device_model::Capabilities GetCapabilities(
const ChromeML& chrome_ml,
on_device_model::ModelAssets assets);
std::unique_ptr<SessionImpl> CreateSession(
const ScopedAdaptation* adaptation,
on_device_model::mojom::SessionParamsPtr params);

@ -385,6 +385,12 @@ void OnDeviceModelService::LoadModel(
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
// mojom::OnDeviceModelService override: computes the capabilities for
// `assets` via the executor (without fully loading the model) and returns
// them through `callback`.
void OnDeviceModelService::GetCapabilities(ModelAssets assets,
                                           GetCapabilitiesCallback callback) {
  auto capabilities = ml::OnDeviceModelExecutor::GetCapabilities(
      *chrome_ml_, std::move(assets));
  std::move(callback).Run(std::move(capabilities));
}
void OnDeviceModelService::GetEstimatedPerformanceClass(
GetEstimatedPerformanceClassCallback callback) {
base::ElapsedTimer timer;

@ -71,6 +71,8 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL) OnDeviceModelService
void LoadModel(mojom::LoadModelParamsPtr params,
mojo::PendingReceiver<mojom::OnDeviceModel> model,
LoadModelCallback callback) override;
void GetCapabilities(ModelAssets assets,
GetCapabilitiesCallback callback) override;
void LoadTextSafetyModel(
mojom::TextSafetyModelParamsPtr params,
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;

@ -605,5 +605,22 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
ElementsAre("Fastest inference\n", "Context: foo\n"));
}
// Verifies GetCapabilities() derives capability flags from the fake weights
// file contents ("image"/"audio" markers).
TEST_F(OnDeviceModelServiceTest, Capabilities) {
  // Writes `data` as a fake weights file and asserts the service reports
  // `expected` capabilities for it.
  auto expect_capabilities = [&](const std::string& data,
                                 const Capabilities& expected) {
    FakeFile file(data);
    ModelAssets assets;
    assets.weights = file.Open();
    base::test::TestFuture<const Capabilities&> future;
    service()->GetCapabilities(std::move(assets), future.GetCallback());
    EXPECT_EQ(expected, future.Take());
  };
  expect_capabilities("none", {});
  expect_capabilities("image", {CapabilityFlags::kImageInput});
  expect_capabilities("audio", {CapabilityFlags::kAudioInput});
  expect_capabilities("image audio", {CapabilityFlags::kImageInput,
                                      CapabilityFlags::kAudioInput});
}
} // namespace
} // namespace on_device_model

@ -54,6 +54,19 @@ ModelAssetPaths::ModelAssetPaths(const ModelAssetPaths&) = default;
ModelAssetPaths::~ModelAssetPaths() = default;
ModelAssets::ModelAssets() = default;

// Copying duplicates the underlying weights file handle so that each
// ModelAssets owns an independent base::File.
ModelAssets::ModelAssets(const ModelAssets& other)
    : weights(other.weights.Duplicate()),
      weights_path(other.weights_path),
      sp_model_path(other.sp_model_path) {}

ModelAssets& ModelAssets::operator=(const ModelAssets& other) {
  // Self-assignment is harmless here: Duplicate() produces a fresh handle
  // before `weights` is overwritten.
  weights = other.weights.Duplicate();
  weights_path = other.weights_path;
  sp_model_path = other.sp_model_path;
  return *this;
}

ModelAssets::ModelAssets(ModelAssets&&) = default;
ModelAssets& ModelAssets::operator=(ModelAssets&&) = default;
ModelAssets::~ModelAssets() = default;

@ -25,6 +25,8 @@ struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssetPaths {
// execution.
struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssets {
ModelAssets();
ModelAssets(const ModelAssets&);
ModelAssets& operator=(const ModelAssets&);
ModelAssets(ModelAssets&&);
ModelAssets& operator=(ModelAssets&&);
~ModelAssets();

@ -316,6 +316,20 @@ void FakeOnDeviceModelService::LoadModel(
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
// Fake service implementation: reports a capability flag for each modality
// whose name appears in the weights file contents.
void FakeOnDeviceModelService::GetCapabilities(
    ModelAssets assets,
    GetCapabilitiesCallback callback) {
  const std::string weights_contents = ReadFile(assets.weights);
  const auto mentions = [&weights_contents](const char* token) {
    return weights_contents.find(token) != std::string::npos;
  };
  Capabilities detected;
  if (mentions("image")) {
    detected.Put(CapabilityFlags::kImageInput);
  }
  if (mentions("audio")) {
    detected.Put(CapabilityFlags::kAudioInput);
  }
  std::move(callback).Run(detected);
}
void FakeOnDeviceModelService::LoadTextSafetyModel(
mojom::TextSafetyModelParamsPtr params,
mojo::PendingReceiver<mojom::TextSafetyModel> model) {

@ -192,6 +192,8 @@ class FakeOnDeviceModelService : public mojom::OnDeviceModelService {
void LoadModel(mojom::LoadModelParamsPtr params,
mojo::PendingReceiver<mojom::OnDeviceModel> model,
LoadModelCallback callback) override;
void GetCapabilities(ModelAssets assets,
GetCapabilitiesCallback callback) override;
void LoadTextSafetyModel(
mojom::TextSafetyModelParamsPtr params,
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;

@ -126,6 +126,10 @@ interface OnDeviceModelService {
LoadModel(LoadModelParams params, pending_receiver<OnDeviceModel> model)
=> (LoadModelResult result);
// Returns the capabilities for a model specified by `params`. Note that this
// will not load the model, so is much cheaper than calling LoadModel().
GetCapabilities(ModelAssets assets) => (Capabilities capabilities);
// Initializes a new TextSafetyModel with the provided params.
// The model is disconnected on any errors with it.
LoadTextSafetyModel(