Add support for performance hints when loading on-device model
This allows passing a performance hint to the backend when loading the on-device model. The hint can be used to tune various parameters in the backend, or to use a smaller model if requested.

This is controlled from the optimization guide side by a new feature param, "compatible_low_tier_on_device_performance_classes", which allows us to choose the set of performance classes that will use the kFastestInference performance hint.

As part of this change, all the feature params related to model performance were moved under a new base::Feature so that they can be adjusted independently of the other on-device related params.

Bug: 379723772
Change-Id: I28db8de699a8af026ee2ee1cfa08136342860e13
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6089104
Reviewed-by: Matthew Denton <mpdenton@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1395676}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
82818acdfe
commit
3fdbf21ee2
chrome/browser
components/optimization_guide/core
model_execution
on_device_model_component.cc
on_device_model_component.h
on_device_model_component_unittest.cc
on_device_model_service_controller.cc
on_device_model_service_controller_unittest.cc
optimization_guide_features.cc
optimization_guide_features.h
services/on_device_model
fake
ml
chrome_ml_api.h
chrome_ml_types.h
chrome_ml_types_traits.cc
chrome_ml_types_traits.h
gpu_blocklist.cc
on_device_model_executor.cc
performance_class.cc
on_device_model_service_unittest.cc
public
tools/metrics/histograms
@@ -6645,7 +6645,7 @@ const FeatureEntry kFeatureEntries[] = {
|
|||||||
flag_descriptions::kOptimizationGuideOnDeviceModelName,
|
flag_descriptions::kOptimizationGuideOnDeviceModelName,
|
||||||
flag_descriptions::kOptimizationGuideOnDeviceModelDescription, kOsDesktop,
|
flag_descriptions::kOptimizationGuideOnDeviceModelDescription, kOsDesktop,
|
||||||
FEATURE_WITH_PARAMS_VALUE_TYPE(
|
FEATURE_WITH_PARAMS_VALUE_TYPE(
|
||||||
optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
kOptimizationGuideOnDeviceModelVariations,
|
kOptimizationGuideOnDeviceModelVariations,
|
||||||
"OptimizationGuideOnDeviceModel")},
|
"OptimizationGuideOnDeviceModel")},
|
||||||
|
|
||||||
|
@@ -690,7 +690,8 @@ class OnDeviceModelExecutionEnabledBrowserTest
|
|||||||
scoped_feature_list_.InitWithFeaturesAndParameters(
|
scoped_feature_list_.InitWithFeaturesAndParameters(
|
||||||
{{features::kOptimizationGuideModelExecution, {}},
|
{{features::kOptimizationGuideModelExecution, {}},
|
||||||
{features::kModelQualityLogging, {}},
|
{features::kModelQualityLogging, {}},
|
||||||
{features::kOptimizationGuideOnDeviceModel,
|
{features::kOptimizationGuideOnDeviceModel, {}},
|
||||||
|
{features::kOnDeviceModelPerformanceParams,
|
||||||
{{"compatible_on_device_performance_classes", "*"}}}},
|
{{"compatible_on_device_performance_classes", "*"}}}},
|
||||||
{});
|
{});
|
||||||
}
|
}
|
||||||
|
@@ -16,6 +16,7 @@ build_webui("build") {
|
|||||||
mojo_files = [
|
mojo_files = [
|
||||||
"$root_gen_dir/chrome/browser/ui/webui/on_device_internals/on_device_internals_page.mojom-webui.ts",
|
"$root_gen_dir/chrome/browser/ui/webui/on_device_internals/on_device_internals_page.mojom-webui.ts",
|
||||||
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model.mojom-webui.ts",
|
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model.mojom-webui.ts",
|
||||||
|
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model_service.mojom-webui.ts",
|
||||||
]
|
]
|
||||||
|
|
||||||
static_files = [ "on_device_internals.html" ]
|
static_files = [ "on_device_internals.html" ]
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
<style include="cr-shared-style cr-hidden-style">
|
<style include="cr-shared-style cr-hidden-style md-select">
|
||||||
:host {
|
:host {
|
||||||
display: block;
|
display: block;
|
||||||
margin: auto;
|
margin: auto;
|
||||||
@@ -110,24 +110,40 @@
|
|||||||
cr-expand-button:hover {
|
cr-expand-button:hover {
|
||||||
background-color: var(--cr-hover-background-color);
|
background-color: var(--cr-hover-background-color);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.model-options {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
}
|
||||||
|
|
||||||
|
cr-checkbox {
|
||||||
|
margin-left: 20px;
|
||||||
|
--cr-checkbox-label-padding-start: 10px;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
<div class="performance-class">
|
<div class="performance-class">
|
||||||
Device performance class: <strong>[[performanceClassText_]]</strong>
|
Device performance class: <strong>[[performanceClassText_]]</strong>
|
||||||
</div>
|
</div>
|
||||||
<cr-input id="modelInput" label="Model directory" placeholder="/tmp/model"
|
<cr-input id="modelInput" label="Model directory" placeholder="/tmp/model"
|
||||||
disabled="[[isLoading_(loadModelStart_)]]"
|
disabled="[[isLoading_(loadModelStart_)]]"
|
||||||
on-change="onModelSelected_" error-message="[[error_]]"
|
error-message="[[error_]]" invalid="[[error_.length]]" autofocus>
|
||||||
invalid="[[error_.length]]" autofocus>
|
|
||||||
<cr-button slot="suffix" disabled="[[isLoading_(loadModelStart_)]]"
|
<cr-button slot="suffix" disabled="[[isLoading_(loadModelStart_)]]"
|
||||||
on-click="onLoadClick_">
|
on-click="onLoadClick_">
|
||||||
Load
|
Load
|
||||||
</cr-button>
|
</cr-button>
|
||||||
</cr-input>
|
</cr-input>
|
||||||
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
|
<div class="model-options">
|
||||||
Enable images
|
<select id="performanceHintSelect" class="md-select"
|
||||||
</cr-checkbox>
|
value="[[performanceHint_]]" on-change="onPerformanceHintChange_">
|
||||||
|
<option value="kHighestQuality">Highest Quality</option>
|
||||||
|
<option value="kFastestInference">Fastest Inference</option>
|
||||||
|
</select>
|
||||||
|
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
|
||||||
|
Enable images
|
||||||
|
</cr-checkbox>
|
||||||
|
</div>
|
||||||
<div class="model-text">
|
<div class="model-text">
|
||||||
[[getModelText_(modelPath_, loadModelDuration_)]]
|
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
|
||||||
<div class="throbber" hidden$="[[!isLoading_(loadModelStart_)]]"></div>
|
<div class="throbber" hidden$="[[!isLoading_(loadModelStart_)]]"></div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
@@ -10,6 +10,7 @@ import '//resources/cr_elements/cr_hidden_style.css.js';
|
|||||||
import '//resources/cr_elements/cr_input/cr_input.js';
|
import '//resources/cr_elements/cr_input/cr_input.js';
|
||||||
import '//resources/cr_elements/cr_shared_vars.css.js';
|
import '//resources/cr_elements/cr_shared_vars.css.js';
|
||||||
import '//resources/cr_elements/cr_textarea/cr_textarea.js';
|
import '//resources/cr_elements/cr_textarea/cr_textarea.js';
|
||||||
|
import '//resources/cr_elements/md_select.css.js';
|
||||||
|
|
||||||
import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js';
|
import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js';
|
||||||
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
|
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
|
||||||
@@ -17,6 +18,7 @@ import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.m
|
|||||||
import {BrowserProxy} from './browser_proxy.js';
|
import {BrowserProxy} from './browser_proxy.js';
|
||||||
import type {InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
|
import type {InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
|
||||||
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
|
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
|
||||||
|
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
|
||||||
import {getTemplate} from './tools.html.js';
|
import {getTemplate} from './tools.html.js';
|
||||||
|
|
||||||
interface Response {
|
interface Response {
|
||||||
@@ -34,6 +36,7 @@ interface OnDeviceInternalsToolsElement {
|
|||||||
textInput: CrInputElement,
|
textInput: CrInputElement,
|
||||||
imageInput: HTMLInputElement,
|
imageInput: HTMLInputElement,
|
||||||
topKInput: CrInputElement,
|
topKInput: CrInputElement,
|
||||||
|
performanceHintSelect: HTMLSelectElement,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,6 +134,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
type: Object,
|
type: Object,
|
||||||
value: null,
|
value: null,
|
||||||
},
|
},
|
||||||
|
performanceHint_: String,
|
||||||
|
loadedPerformanceHint_: Number,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,6 +164,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
private topK_: number = 1;
|
private topK_: number = 1;
|
||||||
private imageFile_: File|null = null;
|
private imageFile_: File|null = null;
|
||||||
private enableImageInput_: boolean = false;
|
private enableImageInput_: boolean = false;
|
||||||
|
private performanceHint_: string = 'kHighestQuality';
|
||||||
|
private loadedPerformanceHint_: ModelPerformanceHint|null;
|
||||||
|
|
||||||
private proxy_: BrowserProxy = BrowserProxy.getInstance();
|
private proxy_: BrowserProxy = BrowserProxy.getInstance();
|
||||||
private responseRouter_: StreamingResponderCallbackRouter =
|
private responseRouter_: StreamingResponderCallbackRouter =
|
||||||
@@ -199,6 +206,10 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
this.$.imageInput.value = '';
|
this.$.imageInput.value = '';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private onPerformanceHintChange_() {
|
||||||
|
this.performanceHint_ = this.$.performanceHintSelect.value;
|
||||||
|
}
|
||||||
|
|
||||||
private onServiceCrashed_() {
|
private onServiceCrashed_() {
|
||||||
if (this.currentResponse_) {
|
if (this.currentResponse_) {
|
||||||
this.currentResponse_.error = true;
|
this.currentResponse_.error = true;
|
||||||
@@ -233,6 +244,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
this.baseModel_ = null;
|
this.baseModel_ = null;
|
||||||
this.model_ = null;
|
this.model_ = null;
|
||||||
this.loadModelStart_ = new Date().getTime();
|
this.loadModelStart_ = new Date().getTime();
|
||||||
|
const performanceHint = ModelPerformanceHint[(
|
||||||
|
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
|
||||||
const modelPath = this.$.modelInput.value;
|
const modelPath = this.$.modelInput.value;
|
||||||
// <if expr="is_win">
|
// <if expr="is_win">
|
||||||
// Windows file paths are std::wstring, so use Array<Number>.
|
// Windows file paths are std::wstring, so use Array<Number>.
|
||||||
@@ -244,7 +257,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
const baseModel = new OnDeviceModelRemote();
|
const baseModel = new OnDeviceModelRemote();
|
||||||
let newModel = new OnDeviceModelRemote();
|
let newModel = new OnDeviceModelRemote();
|
||||||
let {result} = await this.proxy_.handler.loadModel(
|
let {result} = await this.proxy_.handler.loadModel(
|
||||||
{path: processedPath}, baseModel.$.bindNewPipeAndPassReceiver());
|
{path: processedPath}, performanceHint,
|
||||||
|
baseModel.$.bindNewPipeAndPassReceiver());
|
||||||
if (result === LoadModelResult.kSuccess && this.enableImageInput_) {
|
if (result === LoadModelResult.kSuccess && this.enableImageInput_) {
|
||||||
result = (await baseModel.loadAdaptation(
|
result = (await baseModel.loadAdaptation(
|
||||||
{
|
{
|
||||||
@@ -272,6 +286,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
});
|
});
|
||||||
this.startNewSession_();
|
this.startNewSession_();
|
||||||
this.modelPath_ = modelPath;
|
this.modelPath_ = modelPath;
|
||||||
|
this.loadedPerformanceHint_ = performanceHint;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -419,9 +434,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
|
|||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
let text = 'Model loaded from ' + this.modelPath_ + ' in ' +
|
let text = 'Model loaded from ' + this.modelPath_ + ' in ' +
|
||||||
this.loadModelDuration_ + 'ms';
|
this.loadModelDuration_ + 'ms ';
|
||||||
if (this.imagesEnabled_()) {
|
if (this.imagesEnabled_()) {
|
||||||
text += ' [images enabled]';
|
text += '[images enabled]';
|
||||||
|
}
|
||||||
|
if (this.loadedPerformanceHint_ ===
|
||||||
|
ModelPerformanceHint.kFastestInference) {
|
||||||
|
text += '[fastest inference]';
|
||||||
}
|
}
|
||||||
return text;
|
return text;
|
||||||
}
|
}
|
||||||
|
@@ -8,6 +8,7 @@ import "mojo/public/mojom/base/big_buffer.mojom";
|
|||||||
import "mojo/public/mojom/base/file_path.mojom";
|
import "mojo/public/mojom/base/file_path.mojom";
|
||||||
import "mojo/public/mojom/base/time.mojom";
|
import "mojo/public/mojom/base/time.mojom";
|
||||||
import "services/on_device_model/public/mojom/on_device_model.mojom";
|
import "services/on_device_model/public/mojom/on_device_model.mojom";
|
||||||
|
import "services/on_device_model/public/mojom/on_device_model_service.mojom";
|
||||||
import "skia/public/mojom/bitmap.mojom";
|
import "skia/public/mojom/bitmap.mojom";
|
||||||
|
|
||||||
// Struct containing data to be displayed on on-device-internals page.
|
// Struct containing data to be displayed on on-device-internals page.
|
||||||
@@ -43,6 +44,7 @@ interface OnDeviceInternalsPageHandler {
|
|||||||
// Binds a new OnDeviceModel interface if possible using model assets loaded
|
// Binds a new OnDeviceModel interface if possible using model assets loaded
|
||||||
// from within `model_path`.
|
// from within `model_path`.
|
||||||
LoadModel(mojo_base.mojom.FilePath model_path,
|
LoadModel(mojo_base.mojom.FilePath model_path,
|
||||||
|
on_device_model.mojom.ModelPerformanceHint performance_hint,
|
||||||
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
|
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
|
||||||
(on_device_model.mojom.LoadModelResult result);
|
(on_device_model.mojom.LoadModelResult result);
|
||||||
|
|
||||||
|
@@ -61,6 +61,7 @@ OnDeviceInternalsPageHandler::~OnDeviceInternalsPageHandler() {
|
|||||||
|
|
||||||
void OnDeviceInternalsPageHandler::LoadModel(
|
void OnDeviceInternalsPageHandler::LoadModel(
|
||||||
const base::FilePath& model_path,
|
const base::FilePath& model_path,
|
||||||
|
ml::ModelPerformanceHint performance_hint,
|
||||||
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback) {
|
LoadModelCallback callback) {
|
||||||
#if BUILDFLAG(USE_CHROMEOS_MODEL_SERVICE)
|
#if BUILDFLAG(USE_CHROMEOS_MODEL_SERVICE)
|
||||||
@@ -81,7 +82,7 @@ void OnDeviceInternalsPageHandler::LoadModel(
|
|||||||
base::BindOnce(&LoadModelAssets, model_path),
|
base::BindOnce(&LoadModelAssets, model_path),
|
||||||
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelAssetsLoaded,
|
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelAssetsLoaded,
|
||||||
weak_ptr_factory_.GetWeakPtr(), std::move(model),
|
weak_ptr_factory_.GetWeakPtr(), std::move(model),
|
||||||
std::move(callback)));
|
std::move(callback), performance_hint));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,10 +110,12 @@ OnDeviceInternalsPageHandler::GetService() {
|
|||||||
void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
|
void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
|
||||||
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback,
|
LoadModelCallback callback,
|
||||||
|
ml::ModelPerformanceHint performance_hint,
|
||||||
on_device_model::ModelAssets assets) {
|
on_device_model::ModelAssets assets) {
|
||||||
auto params = on_device_model::mojom::LoadModelParams::New();
|
auto params = on_device_model::mojom::LoadModelParams::New();
|
||||||
params->assets = std::move(assets);
|
params->assets = std::move(assets);
|
||||||
params->max_tokens = 4096;
|
params->max_tokens = 4096;
|
||||||
|
params->performance_hint = performance_hint;
|
||||||
GetService().LoadModel(std::move(params), std::move(model),
|
GetService().LoadModel(std::move(params), std::move(model),
|
||||||
std::move(callback));
|
std::move(callback));
|
||||||
}
|
}
|
||||||
|
@@ -43,12 +43,14 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
|
|||||||
void OnModelAssetsLoaded(
|
void OnModelAssetsLoaded(
|
||||||
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback,
|
LoadModelCallback callback,
|
||||||
|
ml::ModelPerformanceHint performance_hint,
|
||||||
on_device_model::ModelAssets assets);
|
on_device_model::ModelAssets assets);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// mojom::OnDeviceInternalsPageHandler:
|
// mojom::OnDeviceInternalsPageHandler:
|
||||||
void LoadModel(
|
void LoadModel(
|
||||||
const base::FilePath& model_path,
|
const base::FilePath& model_path,
|
||||||
|
ml::ModelPerformanceHint performance_hint,
|
||||||
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
|
||||||
LoadModelCallback callback) override;
|
LoadModelCallback callback) override;
|
||||||
void GetEstimatedPerformanceClass(
|
void GetEstimatedPerformanceClass(
|
||||||
|
@@ -138,6 +138,13 @@ OnDeviceModelComponentStateManager::GetRegistrationCriteria() {
|
|||||||
return registration_criteria_.get();
|
return registration_criteria_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool OnDeviceModelComponentStateManager::IsLowTierDevice() const {
|
||||||
|
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
||||||
|
return IsPerformanceClassCompatible(
|
||||||
|
features::kLowTierPerformanceClassListForOnDeviceModel.Get(),
|
||||||
|
PerformanceClassFromPref(*local_state_));
|
||||||
|
}
|
||||||
|
|
||||||
void OnDeviceModelComponentStateManager::OnDeviceEligibleFeatureUsed(
|
void OnDeviceModelComponentStateManager::OnDeviceEligibleFeatureUsed(
|
||||||
ModelBasedCapabilityKey feature) {
|
ModelBasedCapabilityKey feature) {
|
||||||
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
||||||
|
@@ -179,6 +179,9 @@ class OnDeviceModelComponentStateManager
|
|||||||
// registration has been computed yet.
|
// registration has been computed yet.
|
||||||
const RegistrationCriteria* GetRegistrationCriteria();
|
const RegistrationCriteria* GetRegistrationCriteria();
|
||||||
|
|
||||||
|
// Returns true if this is determined to be a low tier device.
|
||||||
|
bool IsLowTierDevice() const;
|
||||||
|
|
||||||
base::WeakPtr<OnDeviceModelComponentStateManager> GetWeakPtr() {
|
base::WeakPtr<OnDeviceModelComponentStateManager> GetWeakPtr() {
|
||||||
return weak_ptr_factory_.GetWeakPtr();
|
return weak_ptr_factory_.GetWeakPtr();
|
||||||
}
|
}
|
||||||
|
@@ -53,7 +53,8 @@ class OnDeviceModelComponentTest : public testing::Test {
|
|||||||
|
|
||||||
feature_list_.InitWithFeaturesAndParameters(
|
feature_list_.InitWithFeaturesAndParameters(
|
||||||
{{features::kOptimizationGuideModelExecution, {}},
|
{{features::kOptimizationGuideModelExecution, {}},
|
||||||
{features::kOptimizationGuideOnDeviceModel,
|
{features::kOptimizationGuideOnDeviceModel, {}},
|
||||||
|
{features::kOnDeviceModelPerformanceParams,
|
||||||
{{"compatible_on_device_performance_classes", "3,4,5,6"}}}},
|
{{"compatible_on_device_performance_classes", "3,4,5,6"}}}},
|
||||||
/*disabled_features=*/{});
|
/*disabled_features=*/{});
|
||||||
}
|
}
|
||||||
|
@@ -279,6 +279,10 @@ void OnDeviceModelServiceController::OnModelAssetsLoaded(
|
|||||||
// TODO(crbug.com/302402959): Choose max_tokens based on device.
|
// TODO(crbug.com/302402959): Choose max_tokens based on device.
|
||||||
params->max_tokens = features::GetOnDeviceModelMaxTokens();
|
params->max_tokens = features::GetOnDeviceModelMaxTokens();
|
||||||
params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
|
params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
|
||||||
|
if (on_device_component_state_manager_ &&
|
||||||
|
on_device_component_state_manager_->IsLowTierDevice()) {
|
||||||
|
params->performance_hint = ml::ModelPerformanceHint::kFastestInference;
|
||||||
|
}
|
||||||
service_client_.Get()->LoadModel(
|
service_client_.Get()->LoadModel(
|
||||||
std::move(params), std::move(model),
|
std::move(params), std::move(model),
|
||||||
base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
|
base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
|
||||||
|
16
components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
16
components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -213,6 +213,9 @@ class OnDeviceModelServiceControllerTest : public testing::Test {
|
|||||||
{"on_device_model_disable_crash_count", "3"},
|
{"on_device_model_disable_crash_count", "3"},
|
||||||
{"on_device_model_crash_backoff_base_time", "1m"},
|
{"on_device_model_crash_backoff_base_time", "1m"},
|
||||||
{"on_device_model_max_crash_backoff_time", "1h"}}},
|
{"on_device_model_max_crash_backoff_time", "1h"}}},
|
||||||
|
{features::kOnDeviceModelPerformanceParams,
|
||||||
|
{{"compatible_on_device_performance_classes", "*"},
|
||||||
|
{"compatible_low_tier_on_device_performance_classes", "3"}}},
|
||||||
{features::kTextSafetyClassifier, {}},
|
{features::kTextSafetyClassifier, {}},
|
||||||
{features::kOnDeviceModelValidation,
|
{features::kOnDeviceModelValidation,
|
||||||
{{"on_device_model_validation_delay", "0"}}}},
|
{{"on_device_model_validation_delay", "0"}}}},
|
||||||
@@ -3851,4 +3854,17 @@ TEST_F(OnDeviceModelServiceControllerTest, LoggingModeAlwaysDisable) {
|
|||||||
EXPECT_EQ(0u, test_uploader.uploaded_logs().size());
|
EXPECT_EQ(0u, test_uploader.uploaded_logs().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
|
||||||
|
// Low performance class should use fastest inference.
|
||||||
|
pref_service_.SetInteger(
|
||||||
|
model_execution::prefs::localstate::kOnDevicePerformanceClass,
|
||||||
|
base::to_underlying(OnDeviceModelPerformanceClass::kLow));
|
||||||
|
Initialize(standard_assets_);
|
||||||
|
auto session = CreateSession();
|
||||||
|
session->ExecuteModel(PageUrlRequest("foo"),
|
||||||
|
response_.GetStreamingCallback());
|
||||||
|
ASSERT_TRUE(response_.GetFinalStatus());
|
||||||
|
EXPECT_EQ(*response_.value(), "Fastest inference\nInput: execute:foo\n");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace optimization_guide
|
} // namespace optimization_guide
|
||||||
|
@@ -202,10 +202,19 @@ BASE_FEATURE(kAiSettingsPageEnterpriseDisabledUi,
|
|||||||
"AiSettingsPageEnterpriseDisabledUi",
|
"AiSettingsPageEnterpriseDisabledUi",
|
||||||
base::FEATURE_DISABLED_BY_DEFAULT);
|
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||||
|
|
||||||
|
BASE_FEATURE(kOnDeviceModelPerformanceParams,
|
||||||
|
"OnDeviceModelPerformanceParams",
|
||||||
|
base::FEATURE_ENABLED_BY_DEFAULT);
|
||||||
|
|
||||||
const base::FeatureParam<std::string> kPerformanceClassListForOnDeviceModel{
|
const base::FeatureParam<std::string> kPerformanceClassListForOnDeviceModel{
|
||||||
&kOptimizationGuideOnDeviceModel,
|
&kOnDeviceModelPerformanceParams,
|
||||||
"compatible_on_device_performance_classes", "5,6"};
|
"compatible_on_device_performance_classes", "5,6"};
|
||||||
|
|
||||||
|
const base::FeatureParam<std::string>
|
||||||
|
kLowTierPerformanceClassListForOnDeviceModel{
|
||||||
|
&kOnDeviceModelPerformanceParams,
|
||||||
|
"compatible_low_tier_on_device_performance_classes", ""};
|
||||||
|
|
||||||
BASE_FEATURE(kOptimizationGuideIconView,
|
BASE_FEATURE(kOptimizationGuideIconView,
|
||||||
"OptimizationGuideIconView",
|
"OptimizationGuideIconView",
|
||||||
base::FEATURE_DISABLED_BY_DEFAULT);
|
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||||
|
@@ -90,6 +90,11 @@ BASE_DECLARE_FEATURE(kPrivacyGuideAiSettings);
|
|||||||
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
||||||
extern const base::FeatureParam<bool> kShowAiSettingsForTesting;
|
extern const base::FeatureParam<bool> kShowAiSettingsForTesting;
|
||||||
|
|
||||||
|
// Allows setting feature params for model download configuration, such as
|
||||||
|
// minimum performance class for download.
|
||||||
|
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
||||||
|
BASE_DECLARE_FEATURE(kOnDeviceModelPerformanceParams);
|
||||||
|
|
||||||
// Comma-separated list of performance classes (e.g. "3,4,5") that should
|
// Comma-separated list of performance classes (e.g. "3,4,5") that should
|
||||||
// download the base model. Use "*" if there is no performance class
|
// download the base model. Use "*" if there is no performance class
|
||||||
// requirement.
|
// requirement.
|
||||||
@@ -97,6 +102,13 @@ COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
|||||||
extern const base::FeatureParam<std::string>
|
extern const base::FeatureParam<std::string>
|
||||||
kPerformanceClassListForOnDeviceModel;
|
kPerformanceClassListForOnDeviceModel;
|
||||||
|
|
||||||
|
// Comma-separated list of performance classes that should use a smaller model
|
||||||
|
// if available. This should be a subset of
|
||||||
|
// kPerformanceClassListForOnDeviceModel.
|
||||||
|
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
||||||
|
extern const base::FeatureParam<std::string>
|
||||||
|
kLowTierPerformanceClassListForOnDeviceModel;
|
||||||
|
|
||||||
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
|
||||||
BASE_DECLARE_FEATURE(kOptimizationGuideIconView);
|
BASE_DECLARE_FEATURE(kOptimizationGuideIconView);
|
||||||
|
|
||||||
|
@@ -65,7 +65,8 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct FakeModelInstance {
|
struct FakeModelInstance {
|
||||||
ModelBackendType backend_type_;
|
ml::ModelBackendType backend_type_;
|
||||||
|
ml::ModelPerformanceHint performance_hint;
|
||||||
std::string model_data_;
|
std::string model_data_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -97,8 +98,10 @@ std::string ReadFile(PlatformFile api_file) {
|
|||||||
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
|
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
|
||||||
uintptr_t context,
|
uintptr_t context,
|
||||||
ChromeMLScheduleFn schedule) {
|
ChromeMLScheduleFn schedule) {
|
||||||
return reinterpret_cast<ChromeMLModel>(
|
return reinterpret_cast<ChromeMLModel>(new FakeModelInstance{
|
||||||
new FakeModelInstance{.backend_type_ = descriptor->backend_type});
|
.backend_type_ = descriptor->backend_type,
|
||||||
|
.performance_hint = descriptor->performance_hint,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void DestroyModel(ChromeMLModel model) {
|
void DestroyModel(ChromeMLModel model) {
|
||||||
@@ -121,11 +124,11 @@ ChromeMLSession CreateSession(ChromeMLModel model,
|
|||||||
if (descriptor) {
|
if (descriptor) {
|
||||||
instance->enable_image_input = descriptor->enable_image_input;
|
instance->enable_image_input = descriptor->enable_image_input;
|
||||||
if (descriptor->model_data) {
|
if (descriptor->model_data) {
|
||||||
if (model_instance->backend_type_ == ModelBackendType::kGpuBackend) {
|
if (model_instance->backend_type_ == ml::ModelBackendType::kGpuBackend) {
|
||||||
instance->adaptation_data_ =
|
instance->adaptation_data_ =
|
||||||
ReadFile(descriptor->model_data->weights_file);
|
ReadFile(descriptor->model_data->weights_file);
|
||||||
} else if (model_instance->backend_type_ ==
|
} else if (model_instance->backend_type_ ==
|
||||||
ModelBackendType::kApuBackend) {
|
ml::ModelBackendType::kApuBackend) {
|
||||||
base::ReadFileToString(
|
base::ReadFileToString(
|
||||||
base::FilePath::FromUTF8Unsafe(descriptor->model_data->model_path),
|
base::FilePath::FromUTF8Unsafe(descriptor->model_data->model_path),
|
||||||
&instance->adaptation_data_);
|
&instance->adaptation_data_);
|
||||||
@@ -206,6 +209,10 @@ bool SessionExecuteModel(ChromeMLSession session,
|
|||||||
output_fn(&output);
|
output_fn(&output);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (reinterpret_cast<FakeModelInstance*>(model)->performance_hint ==
|
||||||
|
ml::ModelPerformanceHint::kFastestInference) {
|
||||||
|
OutputChunk("Fastest inference\n");
|
||||||
|
}
|
||||||
if (!instance->adaptation_data_.empty()) {
|
if (!instance->adaptation_data_.empty()) {
|
||||||
OutputChunk("Adaptation: " + instance->adaptation_data_ + "\n");
|
OutputChunk("Adaptation: " + instance->adaptation_data_ + "\n");
|
||||||
}
|
}
|
||||||
|
@@ -15,6 +15,9 @@
|
|||||||
|
|
||||||
// This header defines the public interface to the ChromeML shared library.
|
// This header defines the public interface to the ChromeML shared library.
|
||||||
|
|
||||||
|
// TODO: crbug.com/379723772 - Remove this when internal code migrates.
|
||||||
|
using ::ml::ModelBackendType;
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
||||||
// A function used to handle fatal errors.
|
// A function used to handle fatal errors.
|
||||||
@@ -43,15 +46,6 @@ using ChromeMLTSModel = uintptr_t;
|
|||||||
// Opaque handle to a video-frame-specific ML inference engine.
|
// Opaque handle to a video-frame-specific ML inference engine.
|
||||||
using ChromeMLInferenceEngine = uintptr_t;
|
using ChromeMLInferenceEngine = uintptr_t;
|
||||||
|
|
||||||
// Type of the backend to run the model.
|
|
||||||
enum ModelBackendType {
|
|
||||||
// The default WebGPU backend.
|
|
||||||
kGpuBackend = 0,
|
|
||||||
// The APU accelerator backend. Only available on devices with APU, and need
|
|
||||||
// special APU model files.
|
|
||||||
kApuBackend = 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
// A contiguous byte span.
|
// A contiguous byte span.
|
||||||
struct ChromeMLByteSpan {
|
struct ChromeMLByteSpan {
|
||||||
uint8_t* data;
|
uint8_t* data;
|
||||||
@@ -79,7 +73,7 @@ struct ChromeMLModelData {
|
|||||||
// Describes a model to use with ChromeML.
|
// Describes a model to use with ChromeML.
|
||||||
struct ChromeMLModelDescriptor {
|
struct ChromeMLModelDescriptor {
|
||||||
// The backend to run this model.
|
// The backend to run this model.
|
||||||
ModelBackendType backend_type;
|
ml::ModelBackendType backend_type;
|
||||||
|
|
||||||
// The model data to use.
|
// The model data to use.
|
||||||
const ChromeMLModelData* model_data;
|
const ChromeMLModelData* model_data;
|
||||||
@@ -105,6 +99,8 @@ struct ChromeMLModelDescriptor {
|
|||||||
bool enable_host_mapped_pointer;
|
bool enable_host_mapped_pointer;
|
||||||
bool use_low_power;
|
bool use_low_power;
|
||||||
bool allow_fp16;
|
bool allow_fp16;
|
||||||
|
|
||||||
|
ml::ModelPerformanceHint performance_hint;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Describes an adaptation for a model.
|
// Describes an adaptation for a model.
|
||||||
|
@@ -28,6 +28,21 @@ enum class Token {
|
|||||||
// current library version.
|
// current library version.
|
||||||
using InputPiece = std::variant<Token, std::string, SkBitmap, bool>;
|
using InputPiece = std::variant<Token, std::string, SkBitmap, bool>;
|
||||||
|
|
||||||
|
// Options for specifying the performance characteristics of the model to load.
|
||||||
|
enum class ModelPerformanceHint {
|
||||||
|
kHighestQuality,
|
||||||
|
kFastestInference,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Type of the backend to run the model.
|
||||||
|
enum ModelBackendType {
|
||||||
|
// The default WebGPU backend.
|
||||||
|
kGpuBackend = 0,
|
||||||
|
// The APU accelerator backend. Only available on devices with an APU, and
|
||||||
|
// needs special APU model files.
|
||||||
|
kApuBackend = 1,
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace ml
|
} // namespace ml
|
||||||
|
|
||||||
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_H_
|
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_H_
|
||||||
|
@@ -98,4 +98,62 @@ bool UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece>::
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
on_device_model::mojom::ModelBackendType
|
||||||
|
EnumTraits<on_device_model::mojom::ModelBackendType,
|
||||||
|
ml::ModelBackendType>::ToMojom(ml::ModelBackendType input) {
|
||||||
|
switch (input) {
|
||||||
|
case ml::ModelBackendType::kGpuBackend:
|
||||||
|
return on_device_model::mojom::ModelBackendType::kGpu;
|
||||||
|
case ml::ModelBackendType::kApuBackend:
|
||||||
|
return on_device_model::mojom::ModelBackendType::kApu;
|
||||||
|
}
|
||||||
|
NOTREACHED();
|
||||||
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
bool EnumTraits<on_device_model::mojom::ModelBackendType,
|
||||||
|
ml::ModelBackendType>::
|
||||||
|
FromMojom(on_device_model::mojom::ModelBackendType input,
|
||||||
|
ml::ModelBackendType* output) {
|
||||||
|
switch (input) {
|
||||||
|
case on_device_model::mojom::ModelBackendType::kGpu:
|
||||||
|
*output = ml::ModelBackendType::kGpuBackend;
|
||||||
|
return true;
|
||||||
|
case on_device_model::mojom::ModelBackendType::kApu:
|
||||||
|
*output = ml::ModelBackendType::kApuBackend;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
on_device_model::mojom::ModelPerformanceHint
|
||||||
|
EnumTraits<on_device_model::mojom::ModelPerformanceHint,
|
||||||
|
ml::ModelPerformanceHint>::ToMojom(ml::ModelPerformanceHint input) {
|
||||||
|
switch (input) {
|
||||||
|
case ml::ModelPerformanceHint::kHighestQuality:
|
||||||
|
return on_device_model::mojom::ModelPerformanceHint::kHighestQuality;
|
||||||
|
case ml::ModelPerformanceHint::kFastestInference:
|
||||||
|
return on_device_model::mojom::ModelPerformanceHint::kFastestInference;
|
||||||
|
}
|
||||||
|
NOTREACHED();
|
||||||
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
bool EnumTraits<on_device_model::mojom::ModelPerformanceHint,
|
||||||
|
ml::ModelPerformanceHint>::
|
||||||
|
FromMojom(on_device_model::mojom::ModelPerformanceHint input,
|
||||||
|
ml::ModelPerformanceHint* output) {
|
||||||
|
switch (input) {
|
||||||
|
case on_device_model::mojom::ModelPerformanceHint::kHighestQuality:
|
||||||
|
*output = ml::ModelPerformanceHint::kHighestQuality;
|
||||||
|
return true;
|
||||||
|
case on_device_model::mojom::ModelPerformanceHint::kFastestInference:
|
||||||
|
*output = ml::ModelPerformanceHint::kFastestInference;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mojo
|
} // namespace mojo
|
||||||
|
@@ -43,6 +43,24 @@ struct UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece> {
|
|||||||
ml::InputPiece* out);
|
ml::InputPiece* out);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EnumTraits<on_device_model::mojom::ModelBackendType,
|
||||||
|
ml::ModelBackendType> {
|
||||||
|
static on_device_model::mojom::ModelBackendType ToMojom(
|
||||||
|
ml::ModelBackendType input);
|
||||||
|
static bool FromMojom(on_device_model::mojom::ModelBackendType input,
|
||||||
|
ml::ModelBackendType* output);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EnumTraits<on_device_model::mojom::ModelPerformanceHint,
|
||||||
|
ml::ModelPerformanceHint> {
|
||||||
|
static on_device_model::mojom::ModelPerformanceHint ToMojom(
|
||||||
|
ml::ModelPerformanceHint input);
|
||||||
|
static bool FromMojom(on_device_model::mojom::ModelPerformanceHint input,
|
||||||
|
ml::ModelPerformanceHint* output);
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mojo
|
} // namespace mojo
|
||||||
|
|
||||||
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_TRAITS_H_
|
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_TRAITS_H_
|
||||||
|
@@ -16,7 +16,7 @@ namespace ml {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
const base::FeatureParam<std::string> kGpuBlockList{
|
const base::FeatureParam<std::string> kGpuBlockList{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_model_gpu_block_list",
|
"on_device_model_gpu_block_list",
|
||||||
// These devices are nearly always crashing or have very low performance.
|
// These devices are nearly always crashing or have very low performance.
|
||||||
"8086:412|8086:a16|8086:41e|8086:416|8086:402|8086:166|8086:1616|8086:22b1|"
|
"8086:412|8086:a16|8086:41e|8086:416|8086:402|8086:166|8086:1616|8086:22b1|"
|
||||||
|
@@ -105,18 +105,6 @@ uint32_t GetTopK(std::optional<uint32_t> top_k) {
|
|||||||
std::max(1u, top_k.value_or(1)));
|
std::max(1u, top_k.value_or(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<ModelBackendType> ModelBackendTypeFromMojom(
|
|
||||||
on_device_model::mojom::ModelBackendType backend) {
|
|
||||||
switch (backend) {
|
|
||||||
case on_device_model::mojom::ModelBackendType::kGpu:
|
|
||||||
return ModelBackendType::kGpuBackend;
|
|
||||||
case on_device_model::mojom::ModelBackendType::kApu:
|
|
||||||
return ModelBackendType::kApuBackend;
|
|
||||||
default:
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Handles sending and canceling responses.
|
// Handles sending and canceling responses.
|
||||||
@@ -466,24 +454,17 @@ LoadModelResult OnDeviceModelExecutor::Init(
|
|||||||
|
|
||||||
max_tokens_ = std::max(params->max_tokens, kReserveTokensForSafety);
|
max_tokens_ = std::max(params->max_tokens, kReserveTokensForSafety);
|
||||||
|
|
||||||
std::optional<ModelBackendType> backend_type =
|
|
||||||
ModelBackendTypeFromMojom(params->backend_type);
|
|
||||||
if (!backend_type.has_value()) {
|
|
||||||
LOG(ERROR) << "Failed to parse model backend type";
|
|
||||||
return LoadModelResult::kFailedToLoadLibrary;
|
|
||||||
}
|
|
||||||
|
|
||||||
ChromeMLModelData data;
|
ChromeMLModelData data;
|
||||||
std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
|
std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
|
||||||
std::string sp_model_path_str = assets.sp_model_path.AsUTF8Unsafe();
|
std::string sp_model_path_str = assets.sp_model_path.AsUTF8Unsafe();
|
||||||
if (*backend_type == ModelBackendType::kGpuBackend) {
|
if (params->backend_type == ml::ModelBackendType::kGpuBackend) {
|
||||||
data.weights_file = assets.weights.TakePlatformFile();
|
data.weights_file = assets.weights.TakePlatformFile();
|
||||||
} else {
|
} else {
|
||||||
data.model_path = weights_path_str.data();
|
data.model_path = weights_path_str.data();
|
||||||
data.sentencepiece_model_path = sp_model_path_str.data();
|
data.sentencepiece_model_path = sp_model_path_str.data();
|
||||||
}
|
}
|
||||||
ChromeMLModelDescriptor descriptor = {
|
ChromeMLModelDescriptor descriptor = {
|
||||||
.backend_type = *backend_type,
|
.backend_type = params->backend_type,
|
||||||
.model_data = &data,
|
.model_data = &data,
|
||||||
.max_tokens = max_tokens_,
|
.max_tokens = max_tokens_,
|
||||||
.temperature = 0.0f,
|
.temperature = 0.0f,
|
||||||
@@ -494,6 +475,7 @@ LoadModelResult OnDeviceModelExecutor::Init(
|
|||||||
.enable_host_mapped_pointer = kEnableHostMappedPointer.Get(),
|
.enable_host_mapped_pointer = kEnableHostMappedPointer.Get(),
|
||||||
.use_low_power = kUseLowPower.Get(),
|
.use_low_power = kUseLowPower.Get(),
|
||||||
.allow_fp16 = kAllowFp16.Get(),
|
.allow_fp16 = kAllowFp16.Get(),
|
||||||
|
.performance_hint = params->performance_hint,
|
||||||
};
|
};
|
||||||
model_ = chrome_ml_->api().SessionCreateModel(
|
model_ = chrome_ml_->api().SessionCreateModel(
|
||||||
&descriptor, reinterpret_cast<uintptr_t>(this),
|
&descriptor, reinterpret_cast<uintptr_t>(this),
|
||||||
|
@@ -18,30 +18,30 @@ constexpr uint64_t kBytesPerMb = 1024 * 1024;
|
|||||||
|
|
||||||
// The threshold for GPU RAM below which the device is considered VeryLow.
|
// The threshold for GPU RAM below which the device is considered VeryLow.
|
||||||
const base::FeatureParam<int> kLowRAMThreshold{
|
const base::FeatureParam<int> kLowRAMThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_low_ram_threshold_mb", 3000};
|
"on_device_low_ram_threshold_mb", 3000};
|
||||||
// RAM threshold necessary to be considered High or better.
|
// RAM threshold necessary to be considered High or better.
|
||||||
const base::FeatureParam<int> kHighRAMThreshold{
|
const base::FeatureParam<int> kHighRAMThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_high_ram_threshold_mb", 5500};
|
"on_device_high_ram_threshold_mb", 5500};
|
||||||
|
|
||||||
// Output threshold to be considered Low or better.
|
// Output threshold to be considered Low or better.
|
||||||
const base::FeatureParam<int> kLowOutputThreshold{
|
const base::FeatureParam<int> kLowOutputThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_low_output_threshold", 5};
|
"on_device_low_output_threshold", 5};
|
||||||
|
|
||||||
// Input speed min thresholds for each device class.
|
// Input speed min thresholds for each device class.
|
||||||
const base::FeatureParam<int> kLowThreshold{
|
const base::FeatureParam<int> kLowThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_low_threshold", 50};
|
"on_device_low_threshold", 50};
|
||||||
const base::FeatureParam<int> kMediumThreshold{
|
const base::FeatureParam<int> kMediumThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_medium_threshold", 75};
|
"on_device_medium_threshold", 75};
|
||||||
const base::FeatureParam<int> kHighThreshold{
|
const base::FeatureParam<int> kHighThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_high_threshold", 150};
|
"on_device_high_threshold", 150};
|
||||||
const base::FeatureParam<int> kVeryHighThreshold{
|
const base::FeatureParam<int> kVeryHighThreshold{
|
||||||
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
|
&optimization_guide::features::kOnDeviceModelPerformanceParams,
|
||||||
"on_device_very_high_threshold", 500};
|
"on_device_very_high_threshold", 500};
|
||||||
|
|
||||||
// These values are persisted to logs. Entries should not be renumbered and
|
// These values are persisted to logs. Entries should not be renumbered and
|
||||||
|
@@ -81,11 +81,14 @@ class OnDeviceModelServiceTest : public testing::Test {
|
|||||||
mojo::Remote<mojom::OnDeviceModelService>& service() { return service_; }
|
mojo::Remote<mojom::OnDeviceModelService>& service() { return service_; }
|
||||||
|
|
||||||
mojo::Remote<mojom::OnDeviceModel> LoadModel(
|
mojo::Remote<mojom::OnDeviceModel> LoadModel(
|
||||||
mojom::ModelBackendType backend_type = mojom::ModelBackendType::kGpu) {
|
ml::ModelBackendType backend_type = ml::ModelBackendType::kGpuBackend,
|
||||||
|
ml::ModelPerformanceHint performance_hint =
|
||||||
|
ml::ModelPerformanceHint::kHighestQuality) {
|
||||||
base::RunLoop run_loop;
|
base::RunLoop run_loop;
|
||||||
mojo::Remote<mojom::OnDeviceModel> remote;
|
mojo::Remote<mojom::OnDeviceModel> remote;
|
||||||
auto params = mojom::LoadModelParams::New();
|
auto params = mojom::LoadModelParams::New();
|
||||||
params->backend_type = backend_type;
|
params->backend_type = backend_type;
|
||||||
|
params->performance_hint = performance_hint;
|
||||||
params->max_tokens = 8000;
|
params->max_tokens = 8000;
|
||||||
service()->LoadModel(
|
service()->LoadModel(
|
||||||
std::move(params), remote.BindNewPipeAndPassReceiver(),
|
std::move(params), remote.BindNewPipeAndPassReceiver(),
|
||||||
@@ -452,7 +455,7 @@ TEST_F(OnDeviceModelServiceTest, DestroysAdaptationSession) {
|
|||||||
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
|
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
|
||||||
FakeFile weights1("Adapt1");
|
FakeFile weights1("Adapt1");
|
||||||
FakeFile weights2("Adapt2");
|
FakeFile weights2("Adapt2");
|
||||||
auto model = LoadModel(mojom::ModelBackendType::kApu);
|
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
|
||||||
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
|
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
|
||||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
|
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
|
||||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||||
@@ -633,5 +636,12 @@ TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
|
|||||||
EXPECT_THAT(resp2->class_scores, ElementsAre(0.2, 0.2));
|
EXPECT_THAT(resp2->class_scores, ElementsAre(0.2, 0.2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
|
||||||
|
auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
|
||||||
|
ml::ModelPerformanceHint::kFastestInference);
|
||||||
|
EXPECT_THAT(GetResponses(*model, "foo"),
|
||||||
|
ElementsAre("Fastest inference\n", "Input: foo\n"));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace on_device_model
|
} // namespace on_device_model
|
||||||
|
@@ -130,6 +130,12 @@ void FakeOnDeviceSession::ExecuteImpl(
|
|||||||
mojom::InputOptionsPtr input,
|
mojom::InputOptionsPtr input,
|
||||||
mojo::PendingRemote<mojom::StreamingResponder> response) {
|
mojo::PendingRemote<mojom::StreamingResponder> response) {
|
||||||
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
|
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
|
||||||
|
if (model_->performance_hint() ==
|
||||||
|
ml::ModelPerformanceHint::kFastestInference) {
|
||||||
|
auto chunk = mojom::ResponseChunk::New();
|
||||||
|
chunk->text = "Fastest inference\n";
|
||||||
|
remote->OnResponse(std::move(chunk));
|
||||||
|
}
|
||||||
if (model_->data().base_weight != "0") {
|
if (model_->data().base_weight != "0") {
|
||||||
auto chunk = mojom::ResponseChunk::New();
|
auto chunk = mojom::ResponseChunk::New();
|
||||||
chunk->text = "Base model: " + model_->data().base_weight + "\n";
|
chunk->text = "Base model: " + model_->data().base_weight + "\n";
|
||||||
@@ -184,8 +190,11 @@ void FakeOnDeviceSession::AddContextInternal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
FakeOnDeviceModel::FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
|
FakeOnDeviceModel::FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
|
||||||
FakeOnDeviceModel::Data&& data)
|
FakeOnDeviceModel::Data&& data,
|
||||||
: settings_(settings), data_(std::move(data)) {}
|
ml::ModelPerformanceHint performance_hint)
|
||||||
|
: settings_(settings),
|
||||||
|
data_(std::move(data)),
|
||||||
|
performance_hint_(performance_hint) {}
|
||||||
|
|
||||||
FakeOnDeviceModel::~FakeOnDeviceModel() = default;
|
FakeOnDeviceModel::~FakeOnDeviceModel() = default;
|
||||||
|
|
||||||
@@ -221,8 +230,8 @@ void FakeOnDeviceModel::LoadAdaptation(
|
|||||||
LoadAdaptationCallback callback) {
|
LoadAdaptationCallback callback) {
|
||||||
Data data = data_;
|
Data data = data_;
|
||||||
data.adaptation_model_weight = ReadFile(params->assets.weights);
|
data.adaptation_model_weight = ReadFile(params->assets.weights);
|
||||||
auto test_model =
|
auto test_model = std::make_unique<FakeOnDeviceModel>(
|
||||||
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
|
settings_, std::move(data), ml::ModelPerformanceHint::kHighestQuality);
|
||||||
model_adaptation_receivers_.Add(std::move(test_model), std::move(model));
|
model_adaptation_receivers_.Add(std::move(test_model), std::move(model));
|
||||||
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
||||||
}
|
}
|
||||||
@@ -292,8 +301,8 @@ void FakeOnDeviceModelService::LoadModel(
|
|||||||
}
|
}
|
||||||
FakeOnDeviceModel::Data data;
|
FakeOnDeviceModel::Data data;
|
||||||
data.base_weight = ReadFile(params->assets.weights);
|
data.base_weight = ReadFile(params->assets.weights);
|
||||||
auto test_model =
|
auto test_model = std::make_unique<FakeOnDeviceModel>(
|
||||||
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
|
settings_, std::move(data), params->performance_hint);
|
||||||
model_receivers_.Add(std::move(test_model), std::move(model));
|
model_receivers_.Add(std::move(test_model), std::move(model));
|
||||||
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
|
||||||
}
|
}
|
||||||
|
@@ -112,7 +112,8 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {
|
|||||||
std::string adaptation_model_weight = "";
|
std::string adaptation_model_weight = "";
|
||||||
};
|
};
|
||||||
explicit FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
|
explicit FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
|
||||||
Data&& data);
|
Data&& data,
|
||||||
|
ml::ModelPerformanceHint performance_hint);
|
||||||
~FakeOnDeviceModel() override;
|
~FakeOnDeviceModel() override;
|
||||||
|
|
||||||
// mojom::OnDeviceModel:
|
// mojom::OnDeviceModel:
|
||||||
@@ -134,9 +135,14 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {
|
|||||||
|
|
||||||
const Data& data() const { return data_; }
|
const Data& data() const { return data_; }
|
||||||
|
|
||||||
|
ml::ModelPerformanceHint performance_hint() const {
|
||||||
|
return performance_hint_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
raw_ptr<FakeOnDeviceServiceSettings> settings_;
|
raw_ptr<FakeOnDeviceServiceSettings> settings_;
|
||||||
Data data_;
|
Data data_;
|
||||||
|
ml::ModelPerformanceHint performance_hint_;
|
||||||
|
|
||||||
mojo::UniqueReceiverSet<mojom::Session> receivers_;
|
mojo::UniqueReceiverSet<mojom::Session> receivers_;
|
||||||
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_adaptation_receivers_;
|
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_adaptation_receivers_;
|
||||||
|
@@ -37,6 +37,14 @@ mojom("mojom") {
|
|||||||
mojom = "on_device_model.mojom.InputPiece"
|
mojom = "on_device_model.mojom.InputPiece"
|
||||||
cpp = "::ml::InputPiece"
|
cpp = "::ml::InputPiece"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
mojom = "on_device_model.mojom.ModelBackendType"
|
||||||
|
cpp = "::ml::ModelBackendType"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
mojom = "on_device_model.mojom.ModelPerformanceHint"
|
||||||
|
cpp = "::ml::ModelPerformanceHint"
|
||||||
|
},
|
||||||
]
|
]
|
||||||
traits_headers = [
|
traits_headers = [
|
||||||
"//services/on_device_model/public/cpp/adaptation_assets_mojom_traits.h",
|
"//services/on_device_model/public/cpp/adaptation_assets_mojom_traits.h",
|
||||||
|
@@ -36,6 +36,12 @@ enum ModelBackendType {
|
|||||||
kApu,
|
kApu,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Options for specifying the performance characteristics of the model to load.
|
||||||
|
enum ModelPerformanceHint {
|
||||||
|
kHighestQuality,
|
||||||
|
kFastestInference,
|
||||||
|
};
|
||||||
|
|
||||||
// Params to describe the model to load.
|
// Params to describe the model to load.
|
||||||
struct LoadModelParams {
|
struct LoadModelParams {
|
||||||
// Backend type of the model.
|
// Backend type of the model.
|
||||||
@@ -50,6 +56,10 @@ struct LoadModelParams {
|
|||||||
|
|
||||||
// List of adaptation ranks the model should support.
|
// List of adaptation ranks the model should support.
|
||||||
array<uint32> adaptation_ranks;
|
array<uint32> adaptation_ranks;
|
||||||
|
|
||||||
|
// Chooses the performance characteristics of the model loaded. If only a
|
||||||
|
// single model is available, this field will do nothing.
|
||||||
|
ModelPerformanceHint performance_hint = kHighestQuality;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TextSafetyModelAssets {
|
struct TextSafetyModelAssets {
|
||||||
|
@@ -15298,6 +15298,7 @@ from previous Chrome versions.
|
|||||||
<int value="-2036127998" label="LocalWebApprovals:disabled"/>
|
<int value="-2036127998" label="LocalWebApprovals:disabled"/>
|
||||||
<int value="-2035845836" label="ArcExtendServiceAnrTimeout:enabled"/>
|
<int value="-2035845836" label="ArcExtendServiceAnrTimeout:enabled"/>
|
||||||
<int value="-2035126988" label="enabled-new-style-notification"/>
|
<int value="-2035126988" label="enabled-new-style-notification"/>
|
||||||
|
<int value="-2034315473" label="OnDeviceModelPerformanceParams:disabled"/>
|
||||||
<int value="-2034064186" label="EnableKeyboardBacklightToggle:disabled"/>
|
<int value="-2034064186" label="EnableKeyboardBacklightToggle:disabled"/>
|
||||||
<int value="-2033950090" label="AutofillNoLocalSaveOnUploadSuccess:disabled"/>
|
<int value="-2033950090" label="AutofillNoLocalSaveOnUploadSuccess:disabled"/>
|
||||||
<int value="-2033908928" label="NightLight:enabled"/>
|
<int value="-2033908928" label="NightLight:enabled"/>
|
||||||
@@ -22130,6 +22131,7 @@ from previous Chrome versions.
|
|||||||
<int value="705946076"
|
<int value="705946076"
|
||||||
label="ContextMenuPerformanceInfoAndRemoteHintFetching:disabled"/>
|
label="ContextMenuPerformanceInfoAndRemoteHintFetching:disabled"/>
|
||||||
<int value="706280254" label="StoragePressureEvent:enabled"/>
|
<int value="706280254" label="StoragePressureEvent:enabled"/>
|
||||||
|
<int value="707243920" label="OnDeviceModelPerformanceParams:enabled"/>
|
||||||
<int value="707463326" label="DynamicSafeAreaInsets:enabled"/>
|
<int value="707463326" label="DynamicSafeAreaInsets:enabled"/>
|
||||||
<int value="708015891"
|
<int value="708015891"
|
||||||
label="AutofillUpdateChromeSettingsLinkToGPayWeb:enabled"/>
|
label="AutofillUpdateChromeSettingsLinkToGPayWeb:enabled"/>
|
||||||
|
Reference in New Issue
Block a user