Add GetCapabilities() to on device model service
This API allows querying which capabilities a model supports. It will be
used to determine whether the model can be used for certain features. This
change also updates chrome://on-device-internals to use the API to detect
capabilities automatically when starting a session.

Bug: 399200301
Change-Id: I5288bdd4b879f5d3a773d165f212940e7da3ece0
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6353258
Reviewed-by: Austin Sullivan <asully@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Tom Sepez <tsepez@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1432311}
Commit: a05c688c9d
Parent: 22eaa3855b
Committed by: Chromium LUCI CQ
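To illustrate the new API at a glance, here is a minimal sketch (not part of the change) of how a browser-side caller holding a bound mojo::Remote to the service might query capabilities and gate a feature on the result. The function and variable names, and the exact include paths, are assumptions; Capabilities is treated as a flag set whose Has() mirrors the Put() calls in the service-side code below.

    #include "base/functional/bind.h"
    #include "base/logging.h"
    #include "mojo/public/cpp/bindings/remote.h"
    #include "services/on_device_model/public/cpp/model_assets.h"
    #include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"

    // Sketch: ask the service which inputs the weights support before
    // enabling image or audio features for this model.
    void QueryModelCapabilities(
        mojo::Remote<on_device_model::mojom::OnDeviceModelService>& service,
        on_device_model::ModelAssets assets) {
      service->GetCapabilities(
          std::move(assets),
          base::BindOnce([](const on_device_model::Capabilities& capabilities) {
            bool can_use_images =
                capabilities.Has(on_device_model::CapabilityFlags::kImageInput);
            bool can_use_audio =
                capabilities.Has(on_device_model::CapabilityFlags::kAudioInput);
            VLOG(1) << "image=" << can_use_images << " audio=" << can_use_audio;
          }));
    }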
@@ -145,12 +145,6 @@
 <option value="kHighestQuality">Highest Quality</option>
 <option value="kFastestInference">Fastest Inference</option>
 </select>
-<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
-Enable images
-</cr-checkbox>
-<cr-checkbox slot="suffix" checked="{{enableAudioInput_}}">
-Enable audio
-</cr-checkbox>
 </div>
 <div class="model-text">
 [[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
@@ -188,7 +182,7 @@
 value="{{text_}}">
 </cr-textarea>
 <div class="multimodal-buttons" >
-<div class="image-buttons" hidden$="[[!imagesEnabled_(model_, baseModel_)]]">
+<div class="image-buttons" hidden$="[[!imagesEnabled_(capabilities_)]]">
 <div class="image-error">[[imageError_]]</div>
 <div hidden$="[[imageFile_]]">
 <cr-button class="floating-button"
@@ -205,7 +199,7 @@
 [[imageFile_.name]]
 </cr-button>
 </div>
-<div class="audio-buttons" hidden$="[[!audioEnabled_(model_, baseModel_)]]">
+<div class="audio-buttons" hidden$="[[!audioEnabled_(capabilities_)]]">
 <div class="audio-error">[[audioError_]]</div>
 <div hidden$="[[audioFile_]]">
 <cr-button class="floating-button"

@@ -16,7 +16,7 @@ import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js'
 import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';

 import {BrowserProxy} from './browser_proxy.js';
-import type {InputPiece, ResponseChunk, ResponseSummary,AudioData} from './on_device_model.mojom-webui.js';
+import type {AudioData, Capabilities, InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
 import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
 import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
 import {getTemplate} from './tools.html.js';
@@ -113,10 +113,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 type: Array,
 value: () => [],
 },
-baseModel_: {
-type: Object,
-value: null,
-},
 model_: {
 type: Object,
 value: null,
@@ -134,10 +130,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 value: 0,
 },
 contextText_: String,
-enableImageInput_: {
-type: Boolean,
-value: false,
-},
 topK_: {
 type: Number,
 value: 1,
@@ -170,6 +162,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 }


+private capabilities_: Capabilities = {imageInput: false, audioInput: false};
 private contextExpanded_: boolean;
 private contextLength_: number;
 private contextText_: string;
@@ -179,7 +172,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 private loadModelDuration_: number;
 private loadModelStart_: number;
 private modelPath_: string;
-private baseModel_: OnDeviceModelRemote|null;
 private model_: OnDeviceModelRemote|null;
 private performanceClassText_: string;
 private responses_: Response[];
@@ -187,8 +179,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 private text_: string;
 private topK_: number;
 private imageFile_: File|null;
-private enableAudioInput_: boolean;
-private enableImageInput_: boolean;
 private audioFile_: File|null;
 private audioError_: string;
 private performanceHint_: string;
@@ -257,7 +247,6 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 }
 this.error_ = 'Service crashed, please reload the model.';
 this.model_ = null;
-this.baseModel_ = null;
 this.modelPath_ = '';
 this.loadModelStart_ = 0;
 this.$.modelInput.focus();
@@ -284,16 +273,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {

 private async onModelSelected_() {
 this.error_ = '';
-if (this.baseModel_) {
-this.baseModel_.$.close();
-if (this.model_) {
-this.model_.$.close();
-}
+if (this.model_) {
+this.model_.$.close();
 }
 this.imageFile_ = null;
 this.audioFile_ = null;
-this.baseModel_ = null;
 this.model_ = null;
+this.capabilities_ = {imageInput: false, audioInput: false};
 this.loadModelStart_ = new Date().getTime();
 const performanceHint = ModelPerformanceHint[(
 this.performanceHint_ as keyof typeof ModelPerformanceHint)];
@@ -305,35 +294,16 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 // <if expr="not is_win">
 const processedPath = modelPath;
 // </if>
-const baseModel = new OnDeviceModelRemote();
-let newModel = new OnDeviceModelRemote();
-let {result} = await this.proxy_.handler.loadModel(
+const newModel = new OnDeviceModelRemote();
+const {result, capabilities} = await this.proxy_.handler.loadModel(
 {path: processedPath}, performanceHint,
-baseModel.$.bindNewPipeAndPassReceiver());
-if (result === LoadModelResult.kSuccess &&
-(this.enableImageInput_ || this.enableAudioInput_)) {
-result = (await baseModel.loadAdaptation(
-{
-enableImageInput: this.enableImageInput_,
-enableAudioInput: this.enableAudioInput_,
-maxTokens: 0,
-assets: {
-weights: null,
-weightsPath: null,
-},
-},
-newModel.$.bindNewPipeAndPassReceiver()))
-.result;
-} else {
-// No adaptation needed, just use the base model.
-newModel = baseModel;
-}
+newModel.$.bindNewPipeAndPassReceiver());
 if (result !== LoadModelResult.kSuccess) {
 this.error_ =
 'Unable to load model. Specify a correct and absolute path.';
 } else {
-this.baseModel_ = baseModel;
 this.model_ = newModel;
+this.capabilities_ = capabilities;
 this.model_.onConnectionError.addListener(() => {
 this.onServiceCrashed_();
 });
@@ -364,8 +334,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 }
 this.contextLength_ = 0;
 this.session_ = new SessionRemote();
-this.model_.startSession(
-this.session_.$.bindNewPipeAndPassReceiver(), null);
+this.model_.startSession(this.session_.$.bindNewPipeAndPassReceiver(), {
+maxTokens: 0,
+capabilities: {
+imageInput: this.imagesEnabled_(),
+audioInput: this.audioEnabled_(),
+},
+});
 }

 private onCancelClick_() {
@@ -508,11 +483,11 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
 }

 private imagesEnabled_(): boolean {
-return this.model_ !== this.baseModel_ && this.enableImageInput_;
+return this.capabilities_.imageInput;
 }

 private audioEnabled_(): boolean {
-return this.model_ !== this.baseModel_ && this.enableAudioInput_;
+return this.capabilities_.audioInput;
 }

 private getModelText_(): string {

@@ -51,7 +51,8 @@ interface OnDeviceInternalsPageHandler {
 LoadModel(mojo_base.mojom.FilePath model_path,
 on_device_model.mojom.ModelPerformanceHint performance_hint,
 pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
-(on_device_model.mojom.LoadModelResult result);
+(on_device_model.mojom.LoadModelResult result,
+on_device_model.mojom.Capabilities capabilities);

 // Returns the performance class based on benchmarks run on the device.
 GetEstimatedPerformanceClass() =>

@@ -81,8 +81,14 @@ void OnDeviceInternalsPageHandler::LoadModel(
 on_device_model::mojom::LoadModelResult::kFailedToLoadLibrary);
 return;
 }
-GetService().LoadPlatformModel(uuid, std::move(model), mojo::NullRemote(),
-std::move(callback));
+GetService().LoadPlatformModel(
+uuid, std::move(model), mojo::NullRemote(),
+base::BindOnce(
+[](LoadModelCallback callback,
+on_device_model::mojom::LoadModelResult result) {
+std::move(callback).Run(result, on_device_model::Capabilities());
+},
+std::move(callback)));
 #else
 // Warm the service while assets load in the background.
 std::ignore = GetService();
@@ -122,11 +128,26 @@ void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
 ml::ModelPerformanceHint performance_hint,
 on_device_model::ModelAssets assets) {
 auto params = on_device_model::mojom::LoadModelParams::New();
-params->assets = std::move(assets);
+params->assets = assets;
 params->max_tokens = 4096;
 params->performance_hint = performance_hint;
-GetService().LoadModel(std::move(params), std::move(model),
-std::move(callback));
+GetService().LoadModel(
+std::move(params), std::move(model),
+base::BindOnce(&OnDeviceInternalsPageHandler::OnModelLoaded,
+weak_ptr_factory_.GetWeakPtr(), std::move(callback),
+std::move(assets)));
 }
+
+void OnDeviceInternalsPageHandler::OnModelLoaded(
+LoadModelCallback callback,
+on_device_model::ModelAssets assets,
+on_device_model::mojom::LoadModelResult result) {
+if (result != on_device_model::mojom::LoadModelResult::kSuccess) {
+std::move(callback).Run(result, on_device_model::Capabilities());
+return;
+}
+GetService().GetCapabilities(std::move(assets),
+base::BindOnce(std::move(callback), result));
+}
 #endif

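The handler above threads the original LoadModelCallback through base::BindOnce so that GetCapabilities() can later complete it with both reply values. A minimal sketch of that currying pattern, assuming Chromium's base::OnceCallback API; the names in this sketch are illustrative and not part of the change:

    #include "base/functional/bind.h"
    #include "base/functional/callback.h"

    // Sketch: partially binding the first reply argument of a two-argument
    // callback, as OnModelLoaded() does with (result, capabilities).
    using TwoArgCallback = base::OnceCallback<void(int, bool)>;

    void FinishLater(base::OnceCallback<void(bool)> remaining) {
      // Something asynchronous eventually supplies the second value.
      std::move(remaining).Run(true);
    }

    void Example(TwoArgCallback callback) {
      int first_value = 42;
      // BindOnce(callback, first_value) yields a OnceCallback<void(bool)>;
      // running it invokes the original callback with (42, <bool>).
      FinishLater(base::BindOnce(std::move(callback), first_value));
    }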
@@ -45,6 +45,9 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
 LoadModelCallback callback,
 ml::ModelPerformanceHint performance_hint,
 on_device_model::ModelAssets assets);
+void OnModelLoaded(LoadModelCallback callback,
+on_device_model::ModelAssets assets,
+on_device_model::mojom::LoadModelResult result);
 #endif

 // mojom::OnDeviceInternalsPageHandler:

@@ -38,6 +38,16 @@ std::string PieceToString(const ml::InputPiece& piece) {
 NOTREACHED();
 }

+std::string ReadFile(PlatformFile api_file) {
+base::File file(static_cast<base::PlatformFile>(api_file));
+std::vector<uint8_t> contents;
+contents.resize(file.GetLength());
+if (!file.ReadAndCheck(0, contents)) {
+return std::string();
+}
+return std::string(contents.begin(), contents.end());
+}
+
 }  // namespace

 void InitDawnProcs(const DawnProcTable& procs) {}
@@ -58,6 +68,13 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
 return false;
 }

+bool GetCapabilities(PlatformFile file, ChromeMLCapabilities& capabilities) {
+std::string contents = ReadFile(file);
+capabilities.image_input = contents.find("image") != std::string::npos;
+capabilities.audio_input = contents.find("audio") != std::string::npos;
+return true;
+}
+
 struct FakeModelInstance {
 ml::ModelBackendType backend_type_;
 ml::ModelPerformanceHint performance_hint;
@@ -81,16 +98,6 @@ struct FakeCancelInstance {
 bool cancelled = false;
 };

-std::string ReadFile(PlatformFile api_file) {
-base::File file(static_cast<base::PlatformFile>(api_file));
-std::vector<uint8_t> contents;
-contents.resize(file.GetLength());
-if (!file.ReadAndCheck(0, contents)) {
-return std::string();
-}
-return std::string(contents.begin(), contents.end());
-}
-
 ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
 uintptr_t context,
 ChromeMLScheduleFn schedule) {
@@ -301,6 +308,7 @@ const ChromeMLAPI g_api = {
 .DestroyModel = &DestroyModel,
 .GetEstimatedPerformance = &GetEstimatedPerformance,
 .QueryGPUAdapter = &QueryGPUAdapter,
+.GetCapabilities = &GetCapabilities,
 .SetFatalErrorNonGpuFn = &SetFatalErrorNonGpuFn,

 .SessionCreateModel = &SessionCreateModel,

@@ -229,6 +229,12 @@ struct GpuConfig {
 WGPUBackendType backend_type;
 };

+// A set of capabilities that a model can have.
+struct ChromeMLCapabilities {
+bool image_input = false;
+bool audio_input = false;
+};
+
 struct ChromeMLMetricsFns {
 // Logs an exact sample for the named metric.
 void (*RecordExactLinearHistogram)(const char* name,
@@ -329,6 +335,10 @@ struct ChromeMLAPI {
 void* userdata),
 void* userdata);

+// Gets the model capabilities for the model pointed to by `model_data`.
+bool (*GetCapabilities)(PlatformFile file,
+ChromeMLCapabilities& capabilities);
+
 // Same as SetFatalErrorFn(), but for fatal errors that occur outside of the
 // gpu.
 void (*SetFatalErrorNonGpuFn)(ChromeMLFatalErrorFn error_fn);

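Because GetCapabilities is a newly added member of the ChromeMLAPI function table, an older library build may not provide it, so callers null-check the pointer before using it (OnDeviceModelExecutor::GetCapabilities below does exactly that). A minimal sketch of that guard, assuming the types from the header above; the function name is illustrative:

    // Sketch: treat a newly added function pointer in the C API table as
    // optional and fall back to empty capabilities when it is absent.
    ChromeMLCapabilities QueryCapabilitiesOrDefault(const ChromeMLAPI& api,
                                                    PlatformFile weights_file) {
      ChromeMLCapabilities capabilities;  // defaults: no image, no audio
      if (api.GetCapabilities && api.GetCapabilities(weights_file, capabilities)) {
        return capabilities;
      }
      return ChromeMLCapabilities();
    }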
@@ -391,6 +391,36 @@ OnDeviceModelExecutor::CreateWithResult(
 return base::unexpected(load_model_result);
 }

+// static
+DISABLE_CFI_DLSYM
+on_device_model::Capabilities OnDeviceModelExecutor::GetCapabilities(
+const ChromeML& chrome_ml,
+on_device_model::ModelAssets assets) {
+on_device_model::Capabilities result;
+if (!chrome_ml.api().GetCapabilities) {
+return result;
+}
+
+PlatformFile platform_file;
+std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
+if (assets.weights.IsValid()) {
+platform_file = assets.weights.TakePlatformFile();
+} else {
+base::File file(assets.weights_path, base::File::FLAG_OPEN);
+platform_file = file.TakePlatformFile();
+}
+ChromeMLCapabilities capabilities;
+chrome_ml.api().GetCapabilities(platform_file, capabilities);
+
+if (capabilities.image_input) {
+result.Put(on_device_model::CapabilityFlags::kImageInput);
+}
+if (capabilities.audio_input) {
+result.Put(on_device_model::CapabilityFlags::kAudioInput);
+}
+return result;
+}
+
 std::unique_ptr<SessionImpl> OnDeviceModelExecutor::CreateSession(
 const ScopedAdaptation* adaptation,
 on_device_model::mojom::SessionParamsPtr params) {

@@ -96,6 +96,10 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) OnDeviceModelExecutor final {
 on_device_model::mojom::LoadModelParamsPtr params,
 base::OnceClosure on_complete);

+static on_device_model::Capabilities GetCapabilities(
+const ChromeML& chrome_ml,
+on_device_model::ModelAssets assets);
+
 std::unique_ptr<SessionImpl> CreateSession(
 const ScopedAdaptation* adaptation,
 on_device_model::mojom::SessionParamsPtr params);

@@ -385,6 +385,12 @@ void OnDeviceModelService::LoadModel(
 std::move(callback).Run(mojom::LoadModelResult::kSuccess);
 }

+void OnDeviceModelService::GetCapabilities(ModelAssets assets,
+GetCapabilitiesCallback callback) {
+std::move(callback).Run(ml::OnDeviceModelExecutor::GetCapabilities(
+*chrome_ml_, std::move(assets)));
+}
+
 void OnDeviceModelService::GetEstimatedPerformanceClass(
 GetEstimatedPerformanceClassCallback callback) {
 base::ElapsedTimer timer;

@@ -71,6 +71,8 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL) OnDeviceModelService
 void LoadModel(mojom::LoadModelParamsPtr params,
 mojo::PendingReceiver<mojom::OnDeviceModel> model,
 LoadModelCallback callback) override;
+void GetCapabilities(ModelAssets assets,
+GetCapabilitiesCallback callback) override;
 void LoadTextSafetyModel(
 mojom::TextSafetyModelParamsPtr params,
 mojo::PendingReceiver<mojom::TextSafetyModel> model) override;

@@ -605,5 +605,22 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
 ElementsAre("Fastest inference\n", "Context: foo\n"));
 }

+TEST_F(OnDeviceModelServiceTest, Capabilities) {
+auto expect_capabilities = [&](const std::string& data,
+const Capabilities& expected) {
+FakeFile file(data);
+ModelAssets assets;
+assets.weights = file.Open();
+base::test::TestFuture<const Capabilities&> future;
+service()->GetCapabilities(std::move(assets), future.GetCallback());
+EXPECT_EQ(expected, future.Take());
+};
+expect_capabilities("none", {});
+expect_capabilities("image", {CapabilityFlags::kImageInput});
+expect_capabilities("audio", {CapabilityFlags::kAudioInput});
+expect_capabilities("image audio", {CapabilityFlags::kImageInput,
+CapabilityFlags::kAudioInput});
+}
+
 }  // namespace
 }  // namespace on_device_model

@@ -54,6 +54,19 @@ ModelAssetPaths::ModelAssetPaths(const ModelAssetPaths&) = default;
 ModelAssetPaths::~ModelAssetPaths() = default;

 ModelAssets::ModelAssets() = default;

+ModelAssets::ModelAssets(const ModelAssets& other)
+: weights(other.weights.Duplicate()),
+weights_path(other.weights_path),
+sp_model_path(other.sp_model_path) {}
+
+ModelAssets& ModelAssets::operator=(const ModelAssets& other) {
+weights = other.weights.Duplicate();
+weights_path = other.weights_path;
+sp_model_path = other.sp_model_path;
+return *this;
+}
+
 ModelAssets::ModelAssets(ModelAssets&&) = default;
 ModelAssets& ModelAssets::operator=(ModelAssets&&) = default;
 ModelAssets::~ModelAssets() = default;

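ModelAssets gains copy support (duplicating the weights file handle) because the same assets may now be handed both to LoadModel() and, separately, to GetCapabilities(). A small sketch of the intended usage, assuming Chromium's base::File semantics and the header path shown; the function name is illustrative:

    #include "base/files/file.h"
    #include "base/files/file_path.h"
    #include "services/on_device_model/public/cpp/model_assets.h"

    // Sketch: a copy carries its own duplicated weights handle, so the original
    // and the copy can be consumed by different service calls.
    void CopyAssetsExample(const base::FilePath& weights_path) {
      on_device_model::ModelAssets assets;
      assets.weights =
          base::File(weights_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
      on_device_model::ModelAssets copy = assets;  // Duplicate()s the file handle.
      // e.g. move `copy` into LoadModelParams and `assets` into GetCapabilities().
    }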
@@ -25,6 +25,8 @@ struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssetPaths {
 // execution.
 struct COMPONENT_EXPORT(ON_DEVICE_MODEL_CPP) ModelAssets {
 ModelAssets();
+ModelAssets(const ModelAssets&);
+ModelAssets& operator=(const ModelAssets&);
 ModelAssets(ModelAssets&&);
 ModelAssets& operator=(ModelAssets&&);
 ~ModelAssets();

@@ -316,6 +316,20 @@ void FakeOnDeviceModelService::LoadModel(
 std::move(callback).Run(mojom::LoadModelResult::kSuccess);
 }

+void FakeOnDeviceModelService::GetCapabilities(
+ModelAssets assets,
+GetCapabilitiesCallback callback) {
+std::string contents = ReadFile(assets.weights);
+Capabilities capabilities;
+if (contents.find("image") != std::string::npos) {
+capabilities.Put(CapabilityFlags::kImageInput);
+}
+if (contents.find("audio") != std::string::npos) {
+capabilities.Put(CapabilityFlags::kAudioInput);
+}
+std::move(callback).Run(capabilities);
+}
+
 void FakeOnDeviceModelService::LoadTextSafetyModel(
 mojom::TextSafetyModelParamsPtr params,
 mojo::PendingReceiver<mojom::TextSafetyModel> model) {

@@ -192,6 +192,8 @@ class FakeOnDeviceModelService : public mojom::OnDeviceModelService {
 void LoadModel(mojom::LoadModelParamsPtr params,
 mojo::PendingReceiver<mojom::OnDeviceModel> model,
 LoadModelCallback callback) override;
+void GetCapabilities(ModelAssets assets,
+GetCapabilitiesCallback callback) override;
 void LoadTextSafetyModel(
 mojom::TextSafetyModelParamsPtr params,
 mojo::PendingReceiver<mojom::TextSafetyModel> model) override;

@@ -126,6 +126,10 @@ interface OnDeviceModelService {
 LoadModel(LoadModelParams params, pending_receiver<OnDeviceModel> model)
 => (LoadModelResult result);

+// Returns the capabilities for a model specified by `params`. Note that this
+// will not load the model, so is much cheaper than calling LoadModel().
+GetCapabilities(ModelAssets assets) => (Capabilities capabilities);
+
 // Initializes a new TextSafetyModel with the provided params.
 // The model is disconnected on any errors with it.
 LoadTextSafetyModel(
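Because GetCapabilities() does not load the model, a caller can use it to decide cheaply whether a LoadModel() call is worthwhile at all. A minimal sketch of that gating pattern against the service interface, assuming a bound remote; the function name and the placeholder comment for the eventual LoadModel() call are illustrative:

    #include "base/functional/bind.h"
    #include "mojo/public/cpp/bindings/remote.h"
    #include "services/on_device_model/public/cpp/model_assets.h"
    #include "services/on_device_model/public/mojom/on_device_model_service.mojom.h"

    // Sketch: check capabilities first and only pay for LoadModel() when the
    // required input type is supported.
    void MaybeLoadModelForImages(
        mojo::Remote<on_device_model::mojom::OnDeviceModelService>& service,
        on_device_model::ModelAssets assets) {
      on_device_model::ModelAssets assets_for_query = assets;  // copyable now
      service->GetCapabilities(
          std::move(assets_for_query),
          base::BindOnce(
              [](on_device_model::ModelAssets assets,
                 const on_device_model::Capabilities& capabilities) {
                if (!capabilities.Has(
                        on_device_model::CapabilityFlags::kImageInput)) {
                  return;  // Model cannot take images; skip the expensive load.
                }
                // Proceed with LoadModel() here, consuming `assets`.
              },
              std::move(assets)));
    }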