Add GetCapabilities() to on device model service
This API allows querying what capabilities a model can support. This will be
used to determine if the model can be used for certain features.

This also updates chrome://on-device-internals to use this API to
automatically detect capabilities when starting the session.

Bug: 399200301
Change-Id: I5288bdd4b879f5d3a773d165f212940e7da3ece0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6353258
Reviewed-by: Austin Sullivan <asully@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Tom Sepez <tsepez@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1432311}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
22eaa3855b
commit
a05c688c9d
chrome/browser
resources
on_device_internals
ui
services/on_device_model
@@ -145,12 +145,6 @@
|
|||||||
<option value="kHighestQuality">Highest Quality</option>
|
<option value="kHighestQuality">Highest Quality</option>
|
||||||
<option value="kFastestInference">Fastest Inference</option>
|
<option value="kFastestInference">Fastest Inference</option>
|
||||||
</select>
|
</select>
|
||||||
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
|
|
||||||
Enable images
|
|
||||||
</cr-checkbox>
|
|
||||||
<cr-checkbox slot="suffix" checked="{{enableAudioInput_}}">
|
|
||||||
Enable audio
|
|
||||||
</cr-checkbox>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="model-text">
|
<div class="model-text">
|
||||||
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
|
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
|
||||||
@@ -188,7 +182,7 @@
|
|||||||
value="{{text_}}">
|
value="{{text_}}">
|
||||||
</cr-textarea>
|
</cr-textarea>
|
||||||
<div class="multimodal-buttons" >
|
<div class="multimodal-buttons" >
|
||||||
<div class="image-buttons" hidden$="[[!imagesEnabled_(model_, baseModel_)]]">
|
<div class="image-buttons" hidden$="[[!imagesEnabled_(capabilities_)]]">
|
||||||
<div class="image-error">[[imageError_]]</div>
|
<div class="image-error">[[imageError_]]</div>
|
||||||
<div hidden$="[[imageFile_]]">
|
<div hidden$="[[imageFile_]]">
|
||||||
<cr-button class="floating-button"
|
<cr-button class="floating-button"
|
||||||
@@ -205,7 +199,7 @@
|
|||||||
[[imageFile_.name]]
|
[[imageFile_.name]]
|
||||||
</cr-button>
|
</cr-button>
|
||||||
</div>
|
</div>
|
||||||
<div class="audio-buttons" hidden$="[[!audioEnabled_(model_, baseModel_)]]">
|
<div class="audio-buttons" hidden$="[[!audioEnabled_(capabilities_)]]">
|
||||||
<div class="audio-error">[[audioError_]]</div>
|
<div class="audio-error">[[audioError_]]</div>
|
||||||
<div hidden$="[[audioFile_]]">
|
<div hidden$="[[audioFile_]]">
|
||||||
<cr-button class="floating-button"
|
<cr-button class="floating-button"
|
||||||
|
@@ -16,7 +16,7 @@ import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js'
|
|||||||
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
|
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
|
||||||
|
|
||||||
import {BrowserProxy} from './browser_proxy.js';
|
import {BrowserProxy} from './browser_proxy.js';
|
||||||
import type {InputPiece, ResponseChunk, ResponseSummary,AudioData} from './on_device_model.mojom-webui.js';
|
import type {AudioData, Capabilities, InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
|
||||||
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
|
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
|
||||||
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
|
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
|
||||||
import {getTemplate} from './tools.html.js';
|
import {getTemplate} from './tools.html.js';
|
||||||
@@ -113,10 +113,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
type: Array,
|
type: Array,
|
||||||
value: () => [],
|
value: () => [],
|
||||||
},
|
},
|
||||||
baseModel_: {
|
|
||||||
type: Object,
|
|
||||||
value: null,
|
|
||||||
},
|
|
||||||
model_: {
|
model_: {
|
||||||
type: Object,
|
type: Object,
|
||||||
value: null,
|
value: null,
|
||||||
@@ -134,10 +130,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
value: 0,
|
value: 0,
|
||||||
},
|
},
|
||||||
contextText_: String,
|
contextText_: String,
|
||||||
enableImageInput_: {
|
|
||||||
type: Boolean,
|
|
||||||
value: false,
|
|
||||||
},
|
|
||||||
topK_: {
|
topK_: {
|
||||||
type: Number,
|
type: Number,
|
||||||
value: 1,
|
value: 1,
|
||||||
@@ -170,6 +162,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private capabilities_: Capabilities = {imageInput: false, audioInput: false};
|
||||||
private contextExpanded_: boolean;
|
private contextExpanded_: boolean;
|
||||||
private contextLength_: number;
|
private contextLength_: number;
|
||||||
private contextText_: string;
|
private contextText_: string;
|
||||||
@@ -179,7 +172,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
private loadModelDuration_: number;
|
private loadModelDuration_: number;
|
||||||
private loadModelStart_: number;
|
private loadModelStart_: number;
|
||||||
private modelPath_: string;
|
private modelPath_: string;
|
||||||
private baseModel_: OnDeviceModelRemote|null;
|
|
||||||
private model_: OnDeviceModelRemote|null;
|
private model_: OnDeviceModelRemote|null;
|
||||||
private performanceClassText_: string;
|
private performanceClassText_: string;
|
||||||
private responses_: Response[];
|
private responses_: Response[];
|
||||||
@@ -187,8 +179,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
private text_: string;
|
private text_: string;
|
||||||
private topK_: number;
|
private topK_: number;
|
||||||
private imageFile_: File|null;
|
private imageFile_: File|null;
|
||||||
private enableAudioInput_: boolean;
|
|
||||||
private enableImageInput_: boolean;
|
|
||||||
private audioFile_: File|null;
|
private audioFile_: File|null;
|
||||||
private audioError_: string;
|
private audioError_: string;
|
||||||
private performanceHint_: string;
|
private performanceHint_: string;
|
||||||
@@ -257,7 +247,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
}
|
}
|
||||||
this.error_ = 'Service crashed, please reload the model.';
|
this.error_ = 'Service crashed, please reload the model.';
|
||||||
this.model_ = null;
|
this.model_ = null;
|
||||||
this.baseModel_ = null;
|
|
||||||
this.modelPath_ = '';
|
this.modelPath_ = '';
|
||||||
this.loadModelStart_ = 0;
|
this.loadModelStart_ = 0;
|
||||||
this.$.modelInput.focus();
|
this.$.modelInput.focus();
|
||||||
@@ -284,16 +273,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
|
|
||||||
private async onModelSelected_() {
|
private async onModelSelected_() {
|
||||||
this.error_ = '';
|
this.error_ = '';
|
||||||
if (this.baseModel_) {
|
if (this.model_) {
|
||||||
this.baseModel_.$.close();
|
this.model_.$.close();
|
||||||
}
|
}
|
||||||
if (this.model_) {
|
if (this.model_) {
|
||||||
this.model_.$.close();
|
this.model_.$.close();
|
||||||
}
|
}
|
||||||
this.imageFile_ = null;
|
this.imageFile_ = null;
|
||||||
this.audioFile_ = null;
|
this.audioFile_ = null;
|
||||||
this.baseModel_ = null;
|
|
||||||
this.model_ = null;
|
this.model_ = null;
|
||||||
|
this.capabilities_ = {imageInput: false, audioInput: false};
|
||||||
this.loadModelStart_ = new Date().getTime();
|
this.loadModelStart_ = new Date().getTime();
|
||||||
const performanceHint = ModelPerformanceHint[(
|
const performanceHint = ModelPerformanceHint[(
|
||||||
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
|
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
|
||||||
@@ -305,35 +294,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
// <if expr="not is_win">
|
// <if expr="not is_win">
|
||||||
const processedPath = modelPath;
|
const processedPath = modelPath;
|
||||||
// </if>
|
// </if>
|
||||||
const baseModel = new OnDeviceModelRemote();
|
const newModel = new OnDeviceModelRemote();
|
||||||
let newModel = new OnDeviceModelRemote();
|
const {result, capabilities} = await this.proxy_.handler.loadModel(
|
||||||
let {result} = await this.proxy_.handler.loadModel(
|
|
||||||
{path: processedPath}, performanceHint,
|
{path: processedPath}, performanceHint,
|
||||||
baseModel.$.bindNewPipeAndPassReceiver());
|
newModel.$.bindNewPipeAndPassReceiver());
|
||||||
if (result === LoadModelResult.kSuccess &&
|
|
||||||
(this.enableImageInput_ || this.enableAudioInput_)) {
|
|
||||||
result = (await baseModel.loadAdaptation(
|
|
||||||
{
|
|
||||||
enableImageInput: this.enableImageInput_,
|
|
||||||
enableAudioInput: this.enableAudioInput_,
|
|
||||||
maxTokens: 0,
|
|
||||||
assets: {
|
|
||||||
weights: null,
|
|
||||||
weightsPath: null,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
newModel.$.bindNewPipeAndPassReceiver()))
|
|
||||||
.result;
|
|
||||||
} else {
|
|
||||||
// No adaptation needed, just use the base model.
|
|
||||||
newModel = baseModel;
|
|
||||||
}
|
|
||||||
if (result !== LoadModelResult.kSuccess) {
|
if (result !== LoadModelResult.kSuccess) {
|
||||||
this.error_ =
|
this.error_ =
|
||||||
'Unable to load model. Specify a correct and absolute path.';
|
'Unable to load model. Specify a correct and absolute path.';
|
||||||
} else {
|
} else {
|
||||||
this.baseModel_ = baseModel;
|
|
||||||
this.model_ = newModel;
|
this.model_ = newModel;
|
||||||
|
this.capabilities_ = capabilities;
|
||||||
this.model_.onConnectionError.addListener(() => {
|
this.model_.onConnectionError.addListener(() => {
|
||||||
this.onServiceCrashed_();
|
this.onServiceCrashed_();
|
||||||
});
|
});
|
||||||
@@ -364,8 +334,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
}
|
}
|
||||||
this.contextLength_ = 0;
|
this.contextLength_ = 0;
|
||||||
this.session_ = new SessionRemote();
|
this.session_ = new SessionRemote();
|
||||||
this.model_.startSession(
|
this.model_.startSession(this.session_.$.bindNewPipeAndPassReceiver(), {
|
||||||
this.session_.$.bindNewPipeAndPassReceiver(), null);
|
maxTokens: 0,
|
||||||
|
capabilities: {
|
||||||
|
imageInput: this.imagesEnabled_(),
|
||||||
|
audioInput: this.audioEnabled_(),
|
||||||
|
},
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private onCancelClick_() {
|
private onCancelClick_() {
|
||||||
@@ -508,11 +483,11 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private imagesEnabled_(): boolean {
|
private imagesEnabled_(): boolean {
|
||||||
return this.model_ !== this.baseModel_ && this.enableImageInput_;
|
return this.capabilities_.imageInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
private audioEnabled_(): boolean {
|
private audioEnabled_(): boolean {
|
||||||
return this.model_ !== this.baseModel_ && this.enableAudioInput_;
|
return this.capabilities_.audioInput;
|
||||||
}
|
}
|
||||||
|
|
||||||
private getModelText_(): string {
|
private getModelText_(): string {
|
||||||
|
@@ -51,7 +51,8 @@ interface OnDeviceInternalsPageHandler {
|
|||||||
LoadModel(mojo_base.mojom.FilePath model_path,
|
LoadModel(mojo_base.mojom.FilePath model_path,
|
||||||
on_device_model.mojom.ModelPerformanceHint performance_hint,
|
on_device_model.mojom.ModelPerformanceHint performance_hint,
|
||||||
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
|
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
|
||||||
(on_device_model.mojom.LoadModelResult result);
|
(on_device_model.mojom.LoadModelResult result,
|
||||||
|
on_device_model.mojom.Capabilities capabilities);
|
||||||
|
|
||||||
// Returns the performance class based on benchmarks run on the device.
|
// Returns the performance class based on benchmarks run on the device.
|
||||||
GetEstimatedPerformanceClass() =>
|
GetEstimatedPerformanceClass() =>
|
||||||
|
@@ -81,8 +81,14 @@ void OnDeviceInternalsPageHandler::LoadModel(
|
|||||||
on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
|
on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
GetService().LoadPlatformModel(uuid, std::move(model), mojo::NullRemote(),
|
GetService().LoadPlatformModel(
|
||||||
std::move(callback));
|
uuid, std::move(model), mojo::NullRemote(),
|
||||||
|
base::BindOnce(
|
||||||
|
[](LoadModelCallback callback,
|
||||||
|
on_device_model::mojom::LoadModelResult result) {
|
||||||
|
std::move(callback).Run(result, on_device_model::Capabilities());
|
||||||
|
},
|
||||||
|
std::move(callback)));
|
||||||
#else
|
#else
|
||||||
// Warm the service while assets load in the background.
|
// Warm the service while assets load in the background.
|
||||||
std::ignore = GetService();
|
std::ignore = GetService();
|
||||||
@@ -122,11 +128,26 @@ void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
|
|||||||
ml::ModelPerformanceHint performance_hint,
|
ml::ModelPerformanceHint performance_hint,
|
||||||
on_device_model::ModelAssets assets) {
|
on_device_model::ModelAssets assets) {
|
||||||
auto params = on_device_model::mojom::LoadModelParams::New();
|
auto params = on_device_model::mojom::LoadModelParams::New();
|
||||||
params->assets = std::move(assets);
|
params->assets = assets;
|
||||||
params->max_tokens = 4096;
|
params->max_tokens = 4096;
|
||||||
params->performance_hint = performance_hint;
|
params->performance_hint = performance_hint;
|
||||||
GetService().LoadModel(std::move(params), std::move(model),
|
GetService().LoadModel(
|
||||||
std::move(callback));
|
std::move(params), std::move(model),
|
||||||
|
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelLoaded,
|
||||||
|
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
|
||||||
|
std::move(assets)));
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnDeviceInternalsPageHandler::OnModelLoaded(
|
||||||
|
LoadModelCallback callback,
|
||||||
|
on_device_model::ModelAssets assets,
|
||||||
|
on_device_model::mojom::LoadModelResult result) {
|
||||||
|
if (result != on_device_model::mojom::LoadModelResult::kSuccess) {
|
||||||
|
std::move(callback).Run(result, on_device_model::Capabilities());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
GetService().GetCapabilities(std::move(assets),
|
||||||
|
base::BindOnce(std::move(callback), result));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@@ -45,6 +45,9 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
|
|||||||
LoadModelCallback callback,
|
LoadModelCallback callback,
|
||||||
ml::ModelPerformanceHint performance_hint,
|
ml::ModelPerformanceHint performance_hint,
|
||||||
on_device_model::ModelAssets assets);
|
on_device_model::ModelAssets assets);
|
||||||
|
void OnModelLoaded(LoadModelCallback callback,
|
||||||
|
on_device_model::ModelAssets assets,
|
||||||
|
on_device_model::mojom::LoadModelResult result);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// mojom::OnDeviceInternalsPageHandler:
|
// mojom::OnDeviceInternalsPageHandler:
|
||||||
|
@@ -38,6 +38,16 @@ std::string PieceToString(const ml::InputPiece& piece) {
|
|||||||
NOTREACHED();
|
NOTREACHED();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads the whole file referred to by `api_file` and returns its bytes as a
// string. Returns an empty string when the file length cannot be determined,
// the file is empty, or the read fails.
// NOTE(review): wrapping the handle in base::File presumably transfers
// ownership, so the handle is closed when this function returns — confirm.
std::string ReadFile(PlatformFile api_file) {
  base::File file(static_cast<base::PlatformFile>(api_file));
  // GetLength() reports failure with a negative value; never feed that to the
  // vector size, where it would wrap to an enormous allocation.
  const auto length = file.GetLength();
  if (length <= 0) {
    return std::string();
  }
  std::vector<uint8_t> contents(static_cast<size_t>(length));
  if (!file.ReadAndCheck(0, contents)) {
    return std::string();
  }
  return std::string(contents.begin(), contents.end());
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void InitDawnProcs(const DawnProcTable& procs) {}
|
void InitDawnProcs(const DawnProcTable& procs) {}
|
||||||
@@ -58,6 +68,13 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fake implementation of the ChromeML GetCapabilities entry point. A
// capability is reported when its keyword ("image" / "audio") appears anywhere
// in the contents of `file`. Always returns true.
bool GetCapabilities(PlatformFile file, ChromeMLCapabilities& capabilities) {
  const std::string model_data = ReadFile(file);
  const auto mentions = [&model_data](const char* keyword) {
    return model_data.find(keyword) != std::string::npos;
  };
  capabilities.image_input = mentions("image");
  capabilities.audio_input = mentions("audio");
  return true;
}
|
||||||
|
|
||||||
struct FakeModelInstance {
|
struct FakeModelInstance {
|
||||||
ml::ModelBackendType backend_type_;
|
ml::ModelBackendType backend_type_;
|
||||||
ml::ModelPerformanceHint performance_hint;
|
ml::ModelPerformanceHint performance_hint;
|
||||||
@@ -81,16 +98,6 @@ struct FakeCancelInstance {
|
|||||||
bool cancelled = false;
|
bool cancelled = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string ReadFile(PlatformFile api_file) {
|
|
||||||
base::File file(static_cast<base::PlatformFile>(api_file));
|
|
||||||
std::vector<uint8_t> contents;
|
|
||||||
contents.resize(file.GetLength());
|
|
||||||
if (!file.ReadAndCheck(0, contents)) {
|
|
||||||
return std::string();
|
|
||||||
}
|
|
||||||
return std::string(contents.begin(), contents.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
|
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
|
||||||
uintptr_t context,
|
uintptr_t context,
|
||||||
ChromeMLScheduleFn schedule) {
|
ChromeMLScheduleFn schedule) {
|
||||||
@@ -301,6 +308,7 @@ const ChromeMLAPI g_api = {
|
|||||||
.DestroyModel = &DestroyModel,
|
.DestroyModel = &DestroyModel,
|
||||||
.GetEstimatedPerformance = &GetEstimatedPerformance,
|
.GetEstimatedPerformance = &GetEstimatedPerformance,
|
||||||
.QueryGPUAdapter = &QueryGPUAdapter,
|
.QueryGPUAdapter = &QueryGPUAdapter,
|
||||||
|
.GetCapabilities = &GetCapabilities,
|
||||||
.SetFatalErrorNonGpuFn = &SetFatalErrorNonGpuFn,
|
.SetFatalErrorNonGpuFn = &SetFatalErrorNonGpuFn,
|
||||||
|
|
||||||
.SessionCreateModel = &SessionCreateModel,
|
.SessionCreateModel = &SessionCreateModel,
|
||||||
|
@@ -229,6 +229,12 @@ struct GpuConfig {
|
|||||||
WGPUBackendType backend_type;
|
WGPUBackendType backend_type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Describes the optional capabilities a model can have. Every capability
// defaults to "not supported".
struct ChromeMLCapabilities {
  // True when the model supports image input.
  bool image_input{false};
  // True when the model supports audio input.
  bool audio_input{false};
};
|
||||||
|
|
||||||
struct ChromeMLMetricsFns {
|
struct ChromeMLMetricsFns {
|
||||||
// Logs an exact sample for the named metric.
|
// Logs an exact sample for the named metric.
|
||||||
void (*RecordExactLinearHistogram)(const char* name,
|
void (*RecordExactLinearHistogram)(const char* name,
|
||||||
@@ -329,6 +335,10 @@ struct ChromeMLAPI {
|
|||||||
void* userdata),
|
void* userdata),
|
||||||
void* userdata);
|
void* userdata);
|
||||||
|
|
||||||
|
// Gets the model capabilities for the model pointed to by `file`.
|
||||||
|
bool (*GetCapabilities)(PlatformFile file,
|
||||||
|
ChromeMLCapabilities& capabilities);
|
||||||
|
|
||||||
// Same as SetFatalErrorFn(), but for fatal errors that occur outside of the
|
// Same as SetFatalErrorFn(), but for fatal errors that occur outside of the
|
||||||
// gpu.
|
// gpu.
|
||||||
void (*SetFatalErrorNonGpuFn)(ChromeMLFatalErrorFn error_fn);
|
void (*SetFatalErrorNonGpuFn)(ChromeMLFatalErrorFn error_fn);
|
||||||
|
@@ -391,6 +391,36 @@ OnDeviceModelExecutor::CreateWithResult(
|
|||||||
return base::unexpected(load_model_result);
|
return base::unexpected(load_model_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
DISABLE_CFI_DLSYM
|
||||||
|
on_device_model::Capabilities OnDeviceModelExecutor::GetCapabilities(
|
||||||
|
const ChromeML& chrome_ml,
|
||||||
|
on_device_model::ModelAssets assets) {
|
||||||
|
on_device_model::Capabilities result;
|
||||||
|
if (!chrome_ml.api().GetCapabilities) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
PlatformFile platform_file;
|
||||||
|
std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
|
||||||
|
if (assets.weights.IsValid()) {
|
||||||
|
platform_file = assets.weights.TakePlatformFile();
|
||||||
|
} else {
|
||||||
|
base::File file(assets.weights_path, base::File::FLAG_OPEN);
|
||||||
|
platform_file = file.TakePlatformFile();
|
||||||
|
}
|
||||||
|
ChromeMLCapabilities capabilities;
|
||||||
|
chrome_ml.api().GetCapabilities(platform_file, capabilities);
|
||||||
|
|
||||||
|
if (capabilities.image_input) {
|
||||||
|
result.Put(on_device_model::CapabilityFlags::kImageInput);
|
||||||
|
}
|
||||||
|
if (capabilities.audio_input) {
|
||||||
|
result.Put(on_device_model::CapabilityFlags::kAudioInput);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
|
std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
|
||||||
const ScopedAdaptation* adaptation,
|
const ScopedAdaptation* adaptation,
|
||||||
on_device_model::mojom::SessionParamsPtr params) {
|
on_device_model::mojom::SessionParamsPtr params) {
|
||||||
|
@@ -96,6 +96,10 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) OnDeviceModelExecutor final {
|
|||||||
on_device_model::mojom::LoadModelParamsPtr params,
|
on_device_model::mojom::LoadModelParamsPtr params,
|
||||||
base::OnceClosure on_complete);
|
base::OnceClosure on_complete);
|
||||||
|
|
||||||
|
static on_device_model::Capabilities GetCapabilities(
|
||||||
|
const ChromeML& chrome_ml,
|
||||||
|
on_device_model::ModelAssets assets);
|
||||||
|
|
||||||
std::unique_ptr<SessionImpl> CreateSession(
|
std::unique_ptr<SessionImpl> CreateSession(
|
||||||
const ScopedAdaptation* adaptation,
|
const ScopedAdaptation* adaptation,
|
||||||
on_device_model::mojom::SessionParamsPtr params);
|
on_device_model::mojom::SessionParamsPtr params);
|
||||||
|
@@ -385,6 +385,12 @@ void OnDeviceModelService::LoadModel(
|
|||||||
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the capabilities of the model described by `assets` via the
// executor and passes them to `callback`.
void OnDeviceModelService::GetCapabilities(ModelAssets assets,
                                           GetCapabilitiesCallback callback) {
  const auto capabilities = ml::OnDeviceModelExecutor::GetCapabilities(
      *chrome_ml_, std::move(assets));
  std::move(callback).Run(capabilities);
}
|
||||||
|
|
||||||
void OnDeviceModelService::GetEstimatedPerformanceClass(
|
void OnDeviceModelService::GetEstimatedPerformanceClass(
|
||||||
GetEstimatedPerformanceClassCallback callback) {
|
GetEstimatedPerformanceClassCallback callback) {
|
||||||
base::ElapsedTimer timer;
|
base::ElapsedTimer timer;
|
||||||
|
@@ -71,6 +71,8 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL) OnDeviceModelService
|
|||||||
void LoadModel(mojom::LoadModelParamsPtr params,
|
void LoadModel(mojom::LoadModelParamsPtr params,
|
||||||
mojo::PendingReceiver<mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback) override;
|
LoadModelCallback callback) override;
|
||||||
|
void GetCapabilities(ModelAssets assets,
|
||||||
|
GetCapabilitiesCallback callback) override;
|
||||||
void LoadTextSafetyModel(
|
void LoadTextSafetyModel(
|
||||||
mojom::TextSafetyModelParamsPtr params,
|
mojom::TextSafetyModelParamsPtr params,
|
||||||
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;
|
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;
|
||||||
|
@@ -605,5 +605,22 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
|
|||||||
ElementsAre("Fastest inference\n", "Context: foo\n"));
|
ElementsAre("Fastest inference\n", "Context: foo\n"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that GetCapabilities() reports the capabilities matching the
// contents of the weights file.
TEST_F(OnDeviceModelServiceTest, Capabilities) {
  const auto check = [&](const std::string& weights_data,
                         const Capabilities& want) {
    FakeFile weights_file(weights_data);
    ModelAssets model_assets;
    model_assets.weights = weights_file.Open();
    base::test::TestFuture<const Capabilities&> reported;
    service()->GetCapabilities(std::move(model_assets),
                               reported.GetCallback());
    EXPECT_EQ(want, reported.Take());
  };
  check("none", {});
  check("image", {CapabilityFlags::kImageInput});
  check("audio", {CapabilityFlags::kAudioInput});
  check("image audio",
        {CapabilityFlags::kImageInput, CapabilityFlags::kAudioInput});
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace on_device_model
|
} // namespace on_device_model
|
||||||
|
@@ -54,6 +54,19 @@ ModelAssetPaths::ModelAssetPaths(const ModelAssetPaths&) = default;
|
|||||||
ModelAssetPaths::~ModelAssetPaths() = default;
|
ModelAssetPaths::~ModelAssetPaths() = default;
|
||||||
|
|
||||||
ModelAssets::ModelAssets() = default;
|
ModelAssets::ModelAssets() = default;
|
||||||
|
|
||||||
|
ModelAssets::ModelAssets(const ModelAssets& other)
|
||||||
|
: weights(other.weights.Duplicate()),
|
||||||
|
weights_path(other.weights_path),
|
||||||
|
sp_model_path(other.sp_model_path) {}
|
||||||
|
|
||||||
|
// Copy assignment: replaces the weights file with a duplicate of `other`'s
// and copies both paths.
ModelAssets& ModelAssets::operator=(const ModelAssets& other) {
  auto duplicated_weights = other.weights.Duplicate();
  weights = std::move(duplicated_weights);
  weights_path = other.weights_path;
  sp_model_path = other.sp_model_path;
  return *this;
}
|
||||||
|
|
||||||
ModelAssets::ModelAssets(ModelAssets&&) = default;
|
ModelAssets::ModelAssets(ModelAssets&&) = default;
|
||||||
ModelAssets& ModelAssets::operator=(ModelAssets&&) = default;
|
ModelAssets& ModelAssets::operator=(ModelAssets&&) = default;
|
||||||
ModelAssets::~ModelAssets() = default;
|
ModelAssets::~ModelAssets() = default;
|
||||||
|
@@ -25,6 +25,8 @@ struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssetPaths {
|
|||||||
// execution.
|
// execution.
|
||||||
struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssets {
|
struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssets {
|
||||||
ModelAssets();
|
ModelAssets();
|
||||||
|
ModelAssets(const ModelAssets&);
|
||||||
|
ModelAssets& operator=(const ModelAssets&);
|
||||||
ModelAssets(ModelAssets&&);
|
ModelAssets(ModelAssets&&);
|
||||||
ModelAssets& operator=(ModelAssets&&);
|
ModelAssets& operator=(ModelAssets&&);
|
||||||
~ModelAssets();
|
~ModelAssets();
|
||||||
|
@@ -316,6 +316,20 @@ void FakeOnDeviceModelService::LoadModel(
|
|||||||
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fake capability detection: a capability is reported when its keyword
// ("image" / "audio") appears in the contents of the weights file.
void FakeOnDeviceModelService::GetCapabilities(
    ModelAssets assets,
    GetCapabilitiesCallback callback) {
  const std::string weights_text = ReadFile(assets.weights);
  const auto mentions = [&weights_text](const char* keyword) {
    return weights_text.find(keyword) != std::string::npos;
  };

  Capabilities detected;
  if (mentions("image")) {
    detected.Put(CapabilityFlags::kImageInput);
  }
  if (mentions("audio")) {
    detected.Put(CapabilityFlags::kAudioInput);
  }
  std::move(callback).Run(detected);
}
|
||||||
|
|
||||||
void FakeOnDeviceModelService::LoadTextSafetyModel(
|
void FakeOnDeviceModelService::LoadTextSafetyModel(
|
||||||
mojom::TextSafetyModelParamsPtr params,
|
mojom::TextSafetyModelParamsPtr params,
|
||||||
mojo::PendingReceiver<mojom::TextSafetyModel> model) {
|
mojo::PendingReceiver<mojom::TextSafetyModel> model) {
|
||||||
|
@@ -192,6 +192,8 @@ class FakeOnDeviceModelService : public mojom::OnDeviceModelService {
|
|||||||
void LoadModel(mojom::LoadModelParamsPtr params,
|
void LoadModel(mojom::LoadModelParamsPtr params,
|
||||||
mojo::PendingReceiver<mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback) override;
|
LoadModelCallback callback) override;
|
||||||
|
void GetCapabilities(ModelAssets assets,
|
||||||
|
GetCapabilitiesCallback callback) override;
|
||||||
void LoadTextSafetyModel(
|
void LoadTextSafetyModel(
|
||||||
mojom::TextSafetyModelParamsPtr params,
|
mojom::TextSafetyModelParamsPtr params,
|
||||||
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;
|
mojo::PendingReceiver<mojom::TextSafetyModel> model) override;
|
||||||
|
@@ -126,6 +126,10 @@ interface OnDeviceModelService {
|
|||||||
LoadModel(LoadModelParams params, pending_receiver<OnDeviceModel> model)
|
LoadModel(LoadModelParams params, pending_receiver<OnDeviceModel> model)
|
||||||
=> (LoadModelResult result);
|
=> (LoadModelResult result);
|
||||||
|
|
||||||
|
// Returns the capabilities for a model specified by `assets`. Note that this
|
||||||
|
// will not load the model, so is much cheaper than calling LoadModel().
|
||||||
|
GetCapabilities(ModelAssets assets) => (Capabilities capabilities);
|
||||||
|
|
||||||
// Initializes a new TextSafetyModel with the provided params.
|
// Initializes a new TextSafetyModel with the provided params.
|
||||||
// The model is disconnected on any errors with it.
|
// The model is disconnected on any errors with it.
|
||||||
LoadTextSafetyModel(
|
LoadTextSafetyModel(
|
||||||
|
Reference in New Issue
Block a user