0

Add support for performance hints when loading on-device model

This allows passing a performance hint to the backend when loading the
on-device model. This can be used to tune various params in the backend
or use a smaller model if requested.

This is controlled from the optimization guide side by a new feature
param "compatible_low_tier_on_device_performance_classes", which
allows us to choose a set of performance classes that will use the
kFastestInference performance hint.

As part of this change, all the feature params related to model
performance were moved under a new base::Feature so that they can be
adjusted independently of the other on-device-related params.

Bug: 379723772
Change-Id: I28db8de699a8af026ee2ee1cfa08136342860e13
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6089104
Reviewed-by: Matthew Denton <mpdenton@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1395676}
This commit is contained in:
Clark DuVall
2024-12-12 16:09:27 -08:00
committed by Chromium LUCI CQ
parent 82818acdfe
commit 3fdbf21ee2
29 changed files with 285 additions and 68 deletions

@ -6645,7 +6645,7 @@ const FeatureEntry kFeatureEntries[] = {
flag_descriptions::kOptimizationGuideOnDeviceModelName,
flag_descriptions::kOptimizationGuideOnDeviceModelDescription, kOsDesktop,
FEATURE_WITH_PARAMS_VALUE_TYPE(
optimization_guide::features::kOptimizationGuideOnDeviceModel,
optimization_guide::features::kOnDeviceModelPerformanceParams,
kOptimizationGuideOnDeviceModelVariations,
"OptimizationGuideOnDeviceModel")},

@ -690,7 +690,8 @@ class OnDeviceModelExecutionEnabledBrowserTest
scoped_feature_list_.InitWithFeaturesAndParameters(
{{features::kOptimizationGuideModelExecution, {}},
{features::kModelQualityLogging, {}},
{features::kOptimizationGuideOnDeviceModel,
{features::kOptimizationGuideOnDeviceModel, {}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "*"}}}},
{});
}

@ -16,6 +16,7 @@ build_webui("build") {
mojo_files = [
"$root_gen_dir/chrome/browser/ui/webui/on_device_internals/on_device_internals_page.mojom-webui.ts",
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model.mojom-webui.ts",
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model_service.mojom-webui.ts",
]
static_files = [ "on_device_internals.html" ]

@ -1,4 +1,4 @@
<style include="cr-shared-style cr-hidden-style">
<style include="cr-shared-style cr-hidden-style md-select">
:host {
display: block;
margin: auto;
@ -110,24 +110,40 @@
cr-expand-button:hover {
background-color: var(--cr-hover-background-color);
}
.model-options {
display: flex;
flex-direction: row;
}
cr-checkbox {
margin-left: 20px;
--cr-checkbox-label-padding-start: 10px;
}
</style>
<div class="performance-class">
Device performance class: <strong>[[performanceClassText_]]</strong>
</div>
<cr-input id="modelInput" label="Model directory" placeholder="/tmp/model"
disabled="[[isLoading_(loadModelStart_)]]"
on-change="onModelSelected_" error-message="[[error_]]"
invalid="[[error_.length]]" autofocus>
error-message="[[error_]]" invalid="[[error_.length]]" autofocus>
<cr-button slot="suffix" disabled="[[isLoading_(loadModelStart_)]]"
on-click="onLoadClick_">
Load
</cr-button>
</cr-input>
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
Enable images
</cr-checkbox>
<div class="model-options">
<select id="performanceHintSelect" class="md-select"
value="[[performanceHint_]]" on-change="onPerformanceHintChange_">
<option value="kHighestQuality">Highest Quality</option>
<option value="kFastestInference">Fastest Inference</option>
</select>
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
Enable images
</cr-checkbox>
</div>
<div class="model-text">
[[getModelText_(modelPath_, loadModelDuration_)]]
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
<div class="throbber" hidden$="[[!isLoading_(loadModelStart_)]]"></div>
</div>

@ -10,6 +10,7 @@ import '//resources/cr_elements/cr_hidden_style.css.js';
import '//resources/cr_elements/cr_input/cr_input.js';
import '//resources/cr_elements/cr_shared_vars.css.js';
import '//resources/cr_elements/cr_textarea/cr_textarea.js';
import '//resources/cr_elements/md_select.css.js';
import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js';
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
@ -17,6 +18,7 @@ import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.m
import {BrowserProxy} from './browser_proxy.js';
import type {InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
import {getTemplate} from './tools.html.js';
interface Response {
@ -34,6 +36,7 @@ interface OnDeviceInternalsToolsElement {
textInput: CrInputElement,
imageInput: HTMLInputElement,
topKInput: CrInputElement,
performanceHintSelect: HTMLSelectElement,
};
}
@ -131,6 +134,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
type: Object,
value: null,
},
performanceHint_: String,
loadedPerformanceHint_: Number,
};
}
@ -159,6 +164,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
private topK_: number = 1;
private imageFile_: File|null = null;
private enableImageInput_: boolean = false;
private performanceHint_: string = 'kHighestQuality';
private loadedPerformanceHint_: ModelPerformanceHint|null;
private proxy_: BrowserProxy = BrowserProxy.getInstance();
private responseRouter_: StreamingResponderCallbackRouter =
@ -199,6 +206,10 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
this.$.imageInput.value = '';
}
// Mirrors the <select> element's value into `performanceHint_` so the next
// call to load a model picks up the chosen hint.
private onPerformanceHintChange_() {
  this.performanceHint_ = this.$.performanceHintSelect.value;
}
private onServiceCrashed_() {
if (this.currentResponse_) {
this.currentResponse_.error = true;
@ -233,6 +244,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
this.baseModel_ = null;
this.model_ = null;
this.loadModelStart_ = new Date().getTime();
const performanceHint = ModelPerformanceHint[(
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
const modelPath = this.$.modelInput.value;
// <if expr="is_win">
// Windows file paths are std::wstring, so use Array<Number>.
@ -244,7 +257,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
const baseModel = new OnDeviceModelRemote();
let newModel = new OnDeviceModelRemote();
let {result} = await this.proxy_.handler.loadModel(
{path: processedPath}, baseModel.$.bindNewPipeAndPassReceiver());
{path: processedPath}, performanceHint,
baseModel.$.bindNewPipeAndPassReceiver());
if (result === LoadModelResult.kSuccess && this.enableImageInput_) {
result = (await baseModel.loadAdaptation(
{
@ -272,6 +286,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
});
this.startNewSession_();
this.modelPath_ = modelPath;
this.loadedPerformanceHint_ = performanceHint;
}
}
@ -419,9 +434,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
return '';
}
let text = 'Model loaded from ' + this.modelPath_ + ' in ' +
this.loadModelDuration_ + 'ms';
this.loadModelDuration_ + 'ms ';
if (this.imagesEnabled_()) {
text += ' [images enabled]';
text += '[images enabled]';
}
if (this.loadedPerformanceHint_ ===
ModelPerformanceHint.kFastestInference) {
text += '[fastest inference]';
}
return text;
}

@ -8,6 +8,7 @@ import "mojo/public/mojom/base/big_buffer.mojom";
import "mojo/public/mojom/base/file_path.mojom";
import "mojo/public/mojom/base/time.mojom";
import "services/on_device_model/public/mojom/on_device_model.mojom";
import "services/on_device_model/public/mojom/on_device_model_service.mojom";
import "skia/public/mojom/bitmap.mojom";
// Struct containing data to be displayed on on-device-internals page.
@ -43,6 +44,7 @@ interface OnDeviceInternalsPageHandler {
// Binds a new OnDeviceModel interface if possible using model assets loaded
// from within `model_path`.
LoadModel(mojo_base.mojom.FilePath model_path,
on_device_model.mojom.ModelPerformanceHint performance_hint,
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
(on_device_model.mojom.LoadModelResult result);

@ -61,6 +61,7 @@ OnDeviceInternalsPageHandler::~OnDeviceInternalsPageHandler() {
void OnDeviceInternalsPageHandler::LoadModel(
const base::FilePath& model_path,
ml::ModelPerformanceHint performance_hint,
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback) {
#if BUILDFLAG(USE_CHROMEOS_MODEL_SERVICE)
@ -81,7 +82,7 @@ void OnDeviceInternalsPageHandler::LoadModel(
base::BindOnce(&LoadModelAssets, model_path),
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelAssetsLoaded,
weak_ptr_factory_.GetWeakPtr(), std::move(model),
std::move(callback)));
std::move(callback), performance_hint));
#endif
}
@ -109,10 +110,12 @@ OnDeviceInternalsPageHandler::GetService() {
void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback,
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets) {
auto params = on_device_model::mojom::LoadModelParams::New();
params->assets = std::move(assets);
params->max_tokens = 4096;
params->performance_hint = performance_hint;
GetService().LoadModel(std::move(params), std::move(model),
std::move(callback));
}

@ -43,12 +43,14 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
void OnModelAssetsLoaded(
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback,
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets);
#endif
// mojom::OnDeviceInternalsPageHandler:
void LoadModel(
const base::FilePath& model_path,
ml::ModelPerformanceHint performance_hint,
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback) override;
void GetEstimatedPerformanceClass(

@ -138,6 +138,13 @@ OnDeviceModelComponentStateManager::GetRegistrationCriteria() {
return registration_criteria_.get();
}
// Returns true when the device's persisted performance class appears in the
// "compatible_low_tier_on_device_performance_classes" feature param. Callers
// use this to load the on-device model with the kFastestInference hint.
bool OnDeviceModelComponentStateManager::IsLowTierDevice() const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return IsPerformanceClassCompatible(
      features::kLowTierPerformanceClassListForOnDeviceModel.Get(),
      PerformanceClassFromPref(*local_state_));
}
void OnDeviceModelComponentStateManager::OnDeviceEligibleFeatureUsed(
ModelBasedCapabilityKey feature) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

@ -179,6 +179,9 @@ class OnDeviceModelComponentStateManager
// registration has been computed yet.
const RegistrationCriteria* GetRegistrationCriteria();
// Returns true if this is determined to be a low tier device.
bool IsLowTierDevice() const;
base::WeakPtr<OnDeviceModelComponentStateManager> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}

@ -53,7 +53,8 @@ class OnDeviceModelComponentTest : public testing::Test {
feature_list_.InitWithFeaturesAndParameters(
{{features::kOptimizationGuideModelExecution, {}},
{features::kOptimizationGuideOnDeviceModel,
{features::kOptimizationGuideOnDeviceModel, {}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "3,4,5,6"}}}},
/*disabled_features=*/{});
}

@ -279,6 +279,10 @@ void OnDeviceModelServiceController::OnModelAssetsLoaded(
// TODO(crbug.com/302402959): Choose max_tokens based on device.
params->max_tokens = features::GetOnDeviceModelMaxTokens();
params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
if (on_device_component_state_manager_ &&
on_device_component_state_manager_->IsLowTierDevice()) {
params->performance_hint = ml::ModelPerformanceHint::kFastestInference;
}
service_client_.Get()->LoadModel(
std::move(params), std::move(model),
base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());

@ -213,6 +213,9 @@ class OnDeviceModelServiceControllerTest : public testing::Test {
{"on_device_model_disable_crash_count", "3"},
{"on_device_model_crash_backoff_base_time", "1m"},
{"on_device_model_max_crash_backoff_time", "1h"}}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "*"},
{"compatible_low_tier_on_device_performance_classes", "3"}}},
{features::kTextSafetyClassifier, {}},
{features::kOnDeviceModelValidation,
{{"on_device_model_validation_delay", "0"}}}},
@ -3851,4 +3854,17 @@ TEST_F(OnDeviceModelServiceControllerTest, LoggingModeAlwaysDisable) {
EXPECT_EQ(0u, test_uploader.uploaded_logs().size());
}
// Verifies that a device whose stored performance class is in the low-tier
// list loads the model with the kFastestInference hint. The fake backend
// acknowledges the hint by prefixing responses with "Fastest inference\n".
// NOTE(review): this relies on kLow mapping to class "3" as configured in the
// fixture's feature params — confirm if the enum values change.
TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
  // Low performance class should use fastest inference.
  pref_service_.SetInteger(
      model_execution::prefs::localstate::kOnDevicePerformanceClass,
      base::to_underlying(OnDeviceModelPerformanceClass::kLow));
  Initialize(standard_assets_);
  auto session = CreateSession();
  session->ExecuteModel(PageUrlRequest("foo"),
                        response_.GetStreamingCallback());
  ASSERT_TRUE(response_.GetFinalStatus());
  EXPECT_EQ(*response_.value(), "Fastest inference\nInput: execute:foo\n");
}
} // namespace optimization_guide

@ -202,10 +202,19 @@ BASE_FEATURE(kAiSettingsPageEnterpriseDisabledUi,
"AiSettingsPageEnterpriseDisabledUi",
base::FEATURE_DISABLED_BY_DEFAULT);
BASE_FEATURE(kOnDeviceModelPerformanceParams,
"OnDeviceModelPerformanceParams",
base::FEATURE_ENABLED_BY_DEFAULT);
const base::FeatureParam<std::string> kPerformanceClassListForOnDeviceModel{
&kOptimizationGuideOnDeviceModel,
&kOnDeviceModelPerformanceParams,
"compatible_on_device_performance_classes", "5,6"};
const base::FeatureParam<std::string>
kLowTierPerformanceClassListForOnDeviceModel{
&kOnDeviceModelPerformanceParams,
"compatible_low_tier_on_device_performance_classes", ""};
BASE_FEATURE(kOptimizationGuideIconView,
"OptimizationGuideIconView",
base::FEATURE_DISABLED_BY_DEFAULT);

@ -90,6 +90,11 @@ BASE_DECLARE_FEATURE(kPrivacyGuideAiSettings);
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<bool> kShowAiSettingsForTesting;
// Allows setting feature params for model download configuration, such as
// minimum performance class for download.
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
BASE_DECLARE_FEATURE(kOnDeviceModelPerformanceParams);
// Comma-separated list of performance classes (e.g. "3,4,5") that should
// download the base model. Use "*" if there is no performance class
// requirement.
@ -97,6 +102,13 @@ COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<std::string>
kPerformanceClassListForOnDeviceModel;
// Comma-separated list of performance classes that should use a smaller model
// if available. This should be a subset of
// kPerformanceClassListForOnDeviceModel.
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<std::string>
kLowTierPerformanceClassListForOnDeviceModel;
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
BASE_DECLARE_FEATURE(kOptimizationGuideIconView);

@ -65,7 +65,8 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
}
struct FakeModelInstance {
ModelBackendType backend_type_;
ml::ModelBackendType backend_type_;
ml::ModelPerformanceHint performance_hint;
std::string model_data_;
};
@ -97,8 +98,10 @@ std::string ReadFile(PlatformFile api_file) {
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
uintptr_t context,
ChromeMLScheduleFn schedule) {
return reinterpret_cast<ChromeMLModel>(
new FakeModelInstance{.backend_type_ = descriptor->backend_type});
return reinterpret_cast<ChromeMLModel>(new FakeModelInstance{
.backend_type_ = descriptor->backend_type,
.performance_hint = descriptor->performance_hint,
});
}
void DestroyModel(ChromeMLModel model) {
@ -121,11 +124,11 @@ ChromeMLSession CreateSession(ChromeMLModel model,
if (descriptor) {
instance->enable_image_input = descriptor->enable_image_input;
if (descriptor->model_data) {
if (model_instance->backend_type_ == ModelBackendType::kGpuBackend) {
if (model_instance->backend_type_ == ml::ModelBackendType::kGpuBackend) {
instance->adaptation_data_ =
ReadFile(descriptor->model_data->weights_file);
} else if (model_instance->backend_type_ ==
ModelBackendType::kApuBackend) {
ml::ModelBackendType::kApuBackend) {
base::ReadFileToString(
base::FilePath::FromUTF8Unsafe(descriptor->model_data->model_path),
&instance->adaptation_data_);
@ -206,6 +209,10 @@ bool SessionExecuteModel(ChromeMLSession session,
output_fn(&output);
};
if (reinterpret_cast<FakeModelInstance*>(model)->performance_hint ==
ml::ModelPerformanceHint::kFastestInference) {
OutputChunk("Fastest inference\n");
}
if (!instance->adaptation_data_.empty()) {
OutputChunk("Adaptation: " + instance->adaptation_data_ + "\n");
}

@ -15,6 +15,9 @@
// This header defines the public interface to the ChromeML shared library.
// TODO: crbug.com/379723772 - Remove this when internal code migrates.
using ::ml::ModelBackendType;
extern "C" {
// A function used to handle fatal errors.
@ -43,15 +46,6 @@ using ChromeMLTSModel = uintptr_t;
// Opaque handle to a video-frame-specific ML inference engine.
using ChromeMLInferenceEngine = uintptr_t;
// Type of the backend to run the model.
enum ModelBackendType {
// The default WebGPU backend.
kGpuBackend = 0,
// The APU accelerator backend. Only available on devices with APU, and need
// special APU model files.
kApuBackend = 1,
};
// A contiguous byte span.
struct ChromeMLByteSpan {
uint8_t* data;
@ -79,7 +73,7 @@ struct ChromeMLModelData {
// Describes a model to use with ChromeML.
struct ChromeMLModelDescriptor {
// The backend to run this model.
ModelBackendType backend_type;
ml::ModelBackendType backend_type;
// The model data to use.
const ChromeMLModelData* model_data;
@ -105,6 +99,8 @@ struct ChromeMLModelDescriptor {
bool enable_host_mapped_pointer;
bool use_low_power;
bool allow_fp16;
ml::ModelPerformanceHint performance_hint;
};
// Describes an adaptation for a model.

@ -28,6 +28,21 @@ enum class Token {
// current library version.
using InputPiece = std::variant<Token, std::string, SkBitmap, bool>;
// Options for specifying the performance characteristics of the model to
// load. The backend may tune internal params or select a smaller model when
// a faster hint is requested; with only a single model available the hint may
// have no effect.
enum class ModelPerformanceHint {
  // Prefer output quality (the default).
  kHighestQuality,
  // Prefer low latency; may use a smaller model if one is available.
  kFastestInference,
};
// Type of the backend to run the model.
enum ModelBackendType {
  // The default WebGPU backend.
  kGpuBackend = 0,
  // The APU accelerator backend. Only available on devices with an APU, and
  // requires special APU model files.
  kApuBackend = 1,
};
} // namespace ml
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_H_

@ -98,4 +98,62 @@ bool UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece>::
return false;
}
// static
// Converts the C++ backend-type enum to its mojo wire representation.
on_device_model::mojom::ModelBackendType
EnumTraits<on_device_model::mojom::ModelBackendType,
           ml::ModelBackendType>::ToMojom(ml::ModelBackendType input) {
  switch (input) {
    case ml::ModelBackendType::kGpuBackend:
      return on_device_model::mojom::ModelBackendType::kGpu;
    case ml::ModelBackendType::kApuBackend:
      return on_device_model::mojom::ModelBackendType::kApu;
  }
  // The switch above is exhaustive; reaching here is a logic error.
  NOTREACHED();
}
// static
// Deserializes a backend type received over mojo. Returns false for values
// outside the known range so message validation rejects the payload instead
// of silently defaulting.
bool EnumTraits<on_device_model::mojom::ModelBackendType,
                ml::ModelBackendType>::
    FromMojom(on_device_model::mojom::ModelBackendType input,
              ml::ModelBackendType* output) {
  if (input == on_device_model::mojom::ModelBackendType::kGpu) {
    *output = ml::ModelBackendType::kGpuBackend;
    return true;
  }
  if (input == on_device_model::mojom::ModelBackendType::kApu) {
    *output = ml::ModelBackendType::kApuBackend;
    return true;
  }
  return false;
}
// static
// Converts the C++ performance-hint enum to its mojo wire representation.
on_device_model::mojom::ModelPerformanceHint
EnumTraits<on_device_model::mojom::ModelPerformanceHint,
           ml::ModelPerformanceHint>::ToMojom(ml::ModelPerformanceHint input) {
  switch (input) {
    case ml::ModelPerformanceHint::kHighestQuality:
      return on_device_model::mojom::ModelPerformanceHint::kHighestQuality;
    case ml::ModelPerformanceHint::kFastestInference:
      return on_device_model::mojom::ModelPerformanceHint::kFastestInference;
  }
  // The switch above is exhaustive; reaching here is a logic error.
  NOTREACHED();
}
// static
// Deserializes a performance hint received over mojo. Returns false for
// values outside the known range so message validation rejects the payload
// instead of silently defaulting.
bool EnumTraits<on_device_model::mojom::ModelPerformanceHint,
                ml::ModelPerformanceHint>::
    FromMojom(on_device_model::mojom::ModelPerformanceHint input,
              ml::ModelPerformanceHint* output) {
  if (input == on_device_model::mojom::ModelPerformanceHint::kHighestQuality) {
    *output = ml::ModelPerformanceHint::kHighestQuality;
    return true;
  }
  if (input ==
      on_device_model::mojom::ModelPerformanceHint::kFastestInference) {
    *output = ml::ModelPerformanceHint::kFastestInference;
    return true;
  }
  return false;
}
} // namespace mojo

@ -43,6 +43,24 @@ struct UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece> {
ml::InputPiece* out);
};
// Teaches mojo how to serialize ml::ModelBackendType as
// on_device_model.mojom.ModelBackendType (wired up via the typemap in the
// mojom BUILD target).
template <>
struct EnumTraits<on_device_model::mojom::ModelBackendType,
                  ml::ModelBackendType> {
  static on_device_model::mojom::ModelBackendType ToMojom(
      ml::ModelBackendType input);
  static bool FromMojom(on_device_model::mojom::ModelBackendType input,
                        ml::ModelBackendType* output);
};
// Teaches mojo how to serialize ml::ModelPerformanceHint as
// on_device_model.mojom.ModelPerformanceHint (wired up via the typemap in the
// mojom BUILD target).
template <>
struct EnumTraits<on_device_model::mojom::ModelPerformanceHint,
                  ml::ModelPerformanceHint> {
  static on_device_model::mojom::ModelPerformanceHint ToMojom(
      ml::ModelPerformanceHint input);
  static bool FromMojom(on_device_model::mojom::ModelPerformanceHint input,
                        ml::ModelPerformanceHint* output);
};
} // namespace mojo
#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_TRAITS_H_

@ -16,7 +16,7 @@ namespace ml {
namespace {
const base::FeatureParam<std::string> kGpuBlockList{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_model_gpu_block_list",
// These devices are nearly always crashing or have very low performance.
"8086:412|8086:a16|8086:41e|8086:416|8086:402|8086:166|8086:1616|8086:22b1|"

@ -105,18 +105,6 @@ uint32_t GetTopK(std::optional<uint32_t> top_k) {
std::max(1u, top_k.value_or(1)));
}
std::optional<ModelBackendType> ModelBackendTypeFromMojom(
on_device_model::mojom::ModelBackendType backend) {
switch (backend) {
case on_device_model::mojom::ModelBackendType::kGpu:
return ModelBackendType::kGpuBackend;
case on_device_model::mojom::ModelBackendType::kApu:
return ModelBackendType::kApuBackend;
default:
return std::nullopt;
}
}
} // namespace
// Handles sending and canceling responses.
@ -466,24 +454,17 @@ LoadModelResult OnDeviceModelExecutor::Init(
max_tokens_ = std::max(params->max_tokens, kReserveTokensForSafety);
std::optional<ModelBackendType> backend_type =
ModelBackendTypeFromMojom(params->backend_type);
if (!backend_type.has_value()) {
LOG(ERROR) << "Failed to parse model backend type";
return LoadModelResult::kFailedToLoadLibrary;
}
ChromeMLModelData data;
std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
std::string sp_model_path_str = assets.sp_model_path.AsUTF8Unsafe();
if (*backend_type == ModelBackendType::kGpuBackend) {
if (params->backend_type == ml::ModelBackendType::kGpuBackend) {
data.weights_file = assets.weights.TakePlatformFile();
} else {
data.model_path = weights_path_str.data();
data.sentencepiece_model_path = sp_model_path_str.data();
}
ChromeMLModelDescriptor descriptor = {
.backend_type = *backend_type,
.backend_type = params->backend_type,
.model_data = &data,
.max_tokens = max_tokens_,
.temperature = 0.0f,
@ -494,6 +475,7 @@ LoadModelResult OnDeviceModelExecutor::Init(
.enable_host_mapped_pointer = kEnableHostMappedPointer.Get(),
.use_low_power = kUseLowPower.Get(),
.allow_fp16 = kAllowFp16.Get(),
.performance_hint = params->performance_hint,
};
model_ = chrome_ml_->api().SessionCreateModel(
&descriptor, reinterpret_cast<uintptr_t>(this),

@ -18,30 +18,30 @@ constexpr uint64_t kBytesPerMb = 1024 * 1024;
// The threshold for GPU RAM below which the device is considered VeryLow.
const base::FeatureParam<int> kLowRAMThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_ram_threshold_mb", 3000};
// RAM threshold necessary to be considered High or better.
const base::FeatureParam<int> kHighRAMThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_high_ram_threshold_mb", 5500};
// Output threshold to be considered Low or better.
const base::FeatureParam<int> kLowOutputThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_output_threshold", 5};
// Input speed min thresholds or each device class.
const base::FeatureParam<int> kLowThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_threshold", 50};
const base::FeatureParam<int> kMediumThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_medium_threshold", 75};
const base::FeatureParam<int> kHighThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_high_threshold", 150};
const base::FeatureParam<int> kVeryHighThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_very_high_threshold", 500};
// These values are persisted to logs. Entries should not be renumbered and

@ -81,11 +81,14 @@ class OnDeviceModelServiceTest : public testing::Test {
mojo::Remote<mojom::OnDeviceModelService>& service() { return service_; }
mojo::Remote<mojom::OnDeviceModel> LoadModel(
mojom::ModelBackendType backend_type = mojom::ModelBackendType::kGpu) {
ml::ModelBackendType backend_type = ml::ModelBackendType::kGpuBackend,
ml::ModelPerformanceHint performance_hint =
ml::ModelPerformanceHint::kHighestQuality) {
base::RunLoop run_loop;
mojo::Remote<mojom::OnDeviceModel> remote;
auto params = mojom::LoadModelParams::New();
params->backend_type = backend_type;
params->performance_hint = performance_hint;
params->max_tokens = 8000;
service()->LoadModel(
std::move(params), remote.BindNewPipeAndPassReceiver(),
@ -452,7 +455,7 @@ TEST_F(OnDeviceModelServiceTest, DestroysAdaptationSession) {
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
FakeFile weights1("Adapt1");
FakeFile weights2("Adapt2");
auto model = LoadModel(mojom::ModelBackendType::kApu);
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
@ -633,5 +636,12 @@ TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
EXPECT_THAT(resp2->class_scores, ElementsAre(0.2, 0.2));
}
// Loading a model with the kFastestInference hint should forward the hint to
// the backend; the fake backend acknowledges it by emitting a
// "Fastest inference\n" chunk before echoing the input.
TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
  auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
                         ml::ModelPerformanceHint::kFastestInference);
  EXPECT_THAT(GetResponses(*model, "foo"),
              ElementsAre("Fastest inference\n", "Input: foo\n"));
}
} // namespace
} // namespace on_device_model

@ -130,6 +130,12 @@ void FakeOnDeviceSession::ExecuteImpl(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) {
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
if (model_->performance_hint() ==
ml::ModelPerformanceHint::kFastestInference) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Fastest inference\n";
remote->OnResponse(std::move(chunk));
}
if (model_->data().base_weight != "0") {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Base model: " + model_->data().base_weight + "\n";
@ -184,8 +190,11 @@ void FakeOnDeviceSession::AddContextInternal(
}
FakeOnDeviceModel::FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
FakeOnDeviceModel::Data&& data)
: settings_(settings), data_(std::move(data)) {}
FakeOnDeviceModel::Data&& data,
ml::ModelPerformanceHint performance_hint)
: settings_(settings),
data_(std::move(data)),
performance_hint_(performance_hint) {}
FakeOnDeviceModel::~FakeOnDeviceModel() = default;
@ -221,8 +230,8 @@ void FakeOnDeviceModel::LoadAdaptation(
LoadAdaptationCallback callback) {
Data data = data_;
data.adaptation_model_weight = ReadFile(params->assets.weights);
auto test_model =
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
auto test_model = std::make_unique<FakeOnDeviceModel>(
settings_, std::move(data), ml::ModelPerformanceHint::kHighestQuality);
model_adaptation_receivers_.Add(std::move(test_model), std::move(model));
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
@ -292,8 +301,8 @@ void FakeOnDeviceModelService::LoadModel(
}
FakeOnDeviceModel::Data data;
data.base_weight = ReadFile(params->assets.weights);
auto test_model =
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
auto test_model = std::make_unique<FakeOnDeviceModel>(
settings_, std::move(data), params->performance_hint);
model_receivers_.Add(std::move(test_model), std::move(model));
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}

@ -112,7 +112,8 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {
std::string adaptation_model_weight = "";
};
explicit FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
Data&& data);
Data&& data,
ml::ModelPerformanceHint performance_hint);
~FakeOnDeviceModel() override;
// mojom::OnDeviceModel:
@ -134,9 +135,14 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {
const Data& data() const { return data_; }
// The hint this fake model was loaded with (adaptations are always created
// with kHighestQuality).
ml::ModelPerformanceHint performance_hint() const {
  return performance_hint_;
}
private:
raw_ptr<FakeOnDeviceServiceSettings> settings_;
Data data_;
ml::ModelPerformanceHint performance_hint_;
mojo::UniqueReceiverSet<mojom::Session> receivers_;
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_adaptation_receivers_;

@ -37,6 +37,14 @@ mojom("mojom") {
mojom = "on_device_model.mojom.InputPiece"
cpp = "::ml::InputPiece"
},
{
mojom = "on_device_model.mojom.ModelBackendType"
cpp = "::ml::ModelBackendType"
},
{
mojom = "on_device_model.mojom.ModelPerformanceHint"
cpp = "::ml::ModelPerformanceHint"
},
]
traits_headers = [
"//services/on_device_model/public/cpp/adaptation_assets_mojom_traits.h",

@ -36,6 +36,12 @@ enum ModelBackendType {
kApu,
};
// Options for specifying the performance characteristics of the model to
// load. Typemapped to ml::ModelPerformanceHint in C++.
enum ModelPerformanceHint {
  kHighestQuality,
  kFastestInference,
};
// Params to describe the model to load.
struct LoadModelParams {
// Backend type of the model.
@ -50,6 +56,10 @@ struct LoadModelParams {
// List of adaptation ranks the model should support.
array<uint32> adaptation_ranks;
// Chooses the performance characteristics of the model loaded. If only a
// single model is available, this field will do nothing.
ModelPerformanceHint performance_hint = kHighestQuality;
};
struct TextSafetyModelAssets {

@ -15298,6 +15298,7 @@ from previous Chrome versions.
<int value="-2036127998" label="LocalWebApprovals:disabled"/>
<int value="-2035845836" label="ArcExtendServiceAnrTimeout:enabled"/>
<int value="-2035126988" label="enabled-new-style-notification"/>
<int value="-2034315473" label="OnDeviceModelPerformanceParams:disabled"/>
<int value="-2034064186" label="EnableKeyboardBacklightToggle:disabled"/>
<int value="-2033950090" label="AutofillNoLocalSaveOnUploadSuccess:disabled"/>
<int value="-2033908928" label="NightLight:enabled"/>
@ -22130,6 +22131,7 @@ from previous Chrome versions.
<int value="705946076"
label="ContextMenuPerformanceInfoAndRemoteHintFetching:disabled"/>
<int value="706280254" label="StoragePressureEvent:enabled"/>
<int value="707243920" label="OnDeviceModelPerformanceParams:enabled"/>
<int value="707463326" label="DynamicSafeAreaInsets:enabled"/>
<int value="708015891"
label="AutofillUpdateChromeSettingsLinkToGPayWeb:enabled"/>