Add support for performance hints when loading on-device model

This allows passing a performance hint to the backend when loading the
on-device model. The hint can be used to tune various params in the
backend, or to use a smaller model if requested. This is controlled from
the optimization guide side by a new feature param
"compatible_low_tier_on_device_performance_classes", which allows us to
choose a set of performance classes that will use the kFastestInference
performance hint.

As part of this change, all the feature params related to model
performance were moved under a new base::Feature to allow them to be
adjusted independently of the other on-device related params.

Bug: 379723772
Change-Id: I28db8de699a8af026ee2ee1cfa08136342860e13
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6089104
Reviewed-by: Matthew Denton <mpdenton@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1395676}
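For illustration only, here is a minimal standalone sketch (not the Chromium implementation) of the selection logic this CL adds: a comma-separated low-tier class list, in the spirit of the "compatible_low_tier_on_device_performance_classes" param, decides whether a device gets the kFastestInference hint. IsClassInList and ChooseHint below are simplified stand-ins for the IsPerformanceClassCompatible()/IsLowTierDevice() code shown in the diff.

// Sketch only: simplified stand-ins for the performance-class matching and
// the low-tier check in OnDeviceModelComponentStateManager::IsLowTierDevice().
#include <sstream>
#include <string>

enum class ModelPerformanceHint { kHighestQuality, kFastestInference };

// `list` follows the feature param format assumed here: "*" matches every
// class, "" matches none, otherwise a comma-separated list such as "3,4".
bool IsClassInList(const std::string& list, int perf_class) {
  if (list == "*") {
    return true;
  }
  std::stringstream stream(list);
  std::string entry;
  while (std::getline(stream, entry, ',')) {
    if (!entry.empty() && std::stoi(entry) == perf_class) {
      return true;
    }
  }
  return false;
}

ModelPerformanceHint ChooseHint(const std::string& low_tier_list,
                                int device_perf_class) {
  // Low-tier devices get the fastest-inference hint so the backend can pick
  // a smaller model or cheaper settings; everything else stays on
  // kHighestQuality, which is also the mojom default.
  return IsClassInList(low_tier_list, device_perf_class)
             ? ModelPerformanceHint::kFastestInference
             : ModelPerformanceHint::kHighestQuality;
}

For example, ChooseHint("3", 3) yields kFastestInference, while the default empty list leaves every device on kHighestQuality.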
Changed files:
chrome/browser
components/optimization_guide/core/model_execution/on_device_model_component.cc
components/optimization_guide/core/model_execution/on_device_model_component.h
components/optimization_guide/core/model_execution/on_device_model_component_unittest.cc
components/optimization_guide/core/model_execution/on_device_model_service_controller.cc
components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
components/optimization_guide/core/optimization_guide_features.cc
components/optimization_guide/core/optimization_guide_features.h
services/on_device_model/fake
services/on_device_model/ml/chrome_ml_api.h
services/on_device_model/ml/chrome_ml_types.h
services/on_device_model/ml/chrome_ml_types_traits.cc
services/on_device_model/ml/chrome_ml_types_traits.h
services/on_device_model/ml/gpu_blocklist.cc
services/on_device_model/ml/on_device_model_executor.cc
services/on_device_model/ml/performance_class.cc
services/on_device_model/on_device_model_service_unittest.cc
services/on_device_model/public
tools/metrics/histograms
@@ -6645,7 +6645,7 @@ const FeatureEntry kFeatureEntries[] = {
flag_descriptions::kOptimizationGuideOnDeviceModelName,
flag_descriptions::kOptimizationGuideOnDeviceModelDescription, kOsDesktop,
FEATURE_WITH_PARAMS_VALUE_TYPE(
optimization_guide::features::kOptimizationGuideOnDeviceModel,
optimization_guide::features::kOnDeviceModelPerformanceParams,
kOptimizationGuideOnDeviceModelVariations,
"OptimizationGuideOnDeviceModel")},
@@ -690,7 +690,8 @@ class OnDeviceModelExecutionEnabledBrowserTest
scoped_feature_list_.InitWithFeaturesAndParameters(
{{features::kOptimizationGuideModelExecution, {}},
{features::kModelQualityLogging, {}},
{features::kOptimizationGuideOnDeviceModel,
{features::kOptimizationGuideOnDeviceModel, {}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "*"}}}},
{});
}
@@ -16,6 +16,7 @@ build_webui("build") {
mojo_files = [
"$root_gen_dir/chrome/browser/ui/webui/on_device_internals/on_device_internals_page.mojom-webui.ts",
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model.mojom-webui.ts",
"$root_gen_dir/services/on_device_model/public/mojom/on_device_model_service.mojom-webui.ts",
]

static_files = [ "on_device_internals.html" ]
@@ -1,4 +1,4 @@
<style include="cr-shared-style cr-hidden-style">
<style include="cr-shared-style cr-hidden-style md-select">
:host {
display: block;
margin: auto;
@@ -110,24 +110,40 @@
cr-expand-button:hover {
background-color: var(--cr-hover-background-color);
}

.model-options {
display: flex;
flex-direction: row;
}

cr-checkbox {
margin-left: 20px;
--cr-checkbox-label-padding-start: 10px;
}
</style>
<div class="performance-class">
Device performance class: <strong>[[performanceClassText_]]</strong>
</div>
<cr-input id="modelInput" label="Model directory" placeholder="/tmp/model"
disabled="[[isLoading_(loadModelStart_)]]"
on-change="onModelSelected_" error-message="[[error_]]"
invalid="[[error_.length]]" autofocus>
error-message="[[error_]]" invalid="[[error_.length]]" autofocus>
<cr-button slot="suffix" disabled="[[isLoading_(loadModelStart_)]]"
on-click="onLoadClick_">
Load
</cr-button>
</cr-input>
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
Enable images
</cr-checkbox>
<div class="model-options">
<select id="performanceHintSelect" class="md-select"
value="[[performanceHint_]]" on-change="onPerformanceHintChange_">
<option value="kHighestQuality">Highest Quality</option>
<option value="kFastestInference">Fastest Inference</option>
</select>
<cr-checkbox slot="suffix" checked="{{enableImageInput_}}">
Enable images
</cr-checkbox>
</div>
<div class="model-text">
[[getModelText_(modelPath_, loadModelDuration_)]]
[[getModelText_(modelPath_, loadModelDuration_, loadedPerformanceHint_)]]
<div class="throbber" hidden$="[[!isLoading_(loadModelStart_)]]"></div>
</div>
@@ -10,6 +10,7 @@ import '//resources/cr_elements/cr_hidden_style.css.js';
import '//resources/cr_elements/cr_input/cr_input.js';
import '//resources/cr_elements/cr_shared_vars.css.js';
import '//resources/cr_elements/cr_textarea/cr_textarea.js';
import '//resources/cr_elements/md_select.css.js';

import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js';
import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.min.js';
@@ -17,6 +18,7 @@ import {PolymerElement} from '//resources/polymer/v3_0/polymer/polymer_bundled.m
import {BrowserProxy} from './browser_proxy.js';
import type {InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js';
import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js';
import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js';
import {getTemplate} from './tools.html.js';

interface Response {
@@ -34,6 +36,7 @@ interface OnDeviceInternalsToolsElement {
textInput: CrInputElement,
imageInput: HTMLInputElement,
topKInput: CrInputElement,
performanceHintSelect: HTMLSelectElement,
};
}

@@ -131,6 +134,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
type: Object,
value: null,
},
performanceHint_: String,
loadedPerformanceHint_: Number,
};
}

@@ -159,6 +164,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
private topK_: number = 1;
private imageFile_: File|null = null;
private enableImageInput_: boolean = false;
private performanceHint_: string = 'kHighestQuality';
private loadedPerformanceHint_: ModelPerformanceHint|null;

private proxy_: BrowserProxy = BrowserProxy.getInstance();
private responseRouter_: StreamingResponderCallbackRouter =
@@ -199,6 +206,10 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
this.$.imageInput.value = '';
}

private onPerformanceHintChange_() {
this.performanceHint_ = this.$.performanceHintSelect.value;
}

private onServiceCrashed_() {
if (this.currentResponse_) {
this.currentResponse_.error = true;
@@ -233,6 +244,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
this.baseModel_ = null;
this.model_ = null;
this.loadModelStart_ = new Date().getTime();
const performanceHint = ModelPerformanceHint[(
this.performanceHint_ as keyof typeof ModelPerformanceHint)];
const modelPath = this.$.modelInput.value;
// <if expr="is_win">
// Windows file paths are std::wstring, so use Array<Number>.
@@ -244,7 +257,8 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
const baseModel = new OnDeviceModelRemote();
let newModel = new OnDeviceModelRemote();
let {result} = await this.proxy_.handler.loadModel(
{path: processedPath}, baseModel.$.bindNewPipeAndPassReceiver());
{path: processedPath}, performanceHint,
baseModel.$.bindNewPipeAndPassReceiver());
if (result === LoadModelResult.kSuccess && this.enableImageInput_) {
result = (await baseModel.loadAdaptation(
{
@@ -272,6 +286,7 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
});
this.startNewSession_();
this.modelPath_ = modelPath;
this.loadedPerformanceHint_ = performanceHint;
}
}

@@ -419,9 +434,13 @@ class OnDeviceInternalsToolsElement extends PolymerElement {
return '';
}
let text = 'Model loaded from ' + this.modelPath_ + ' in ' +
this.loadModelDuration_ + 'ms';
this.loadModelDuration_ + 'ms ';
if (this.imagesEnabled_()) {
text += ' [images enabled]';
text += '[images enabled]';
}
if (this.loadedPerformanceHint_ ===
ModelPerformanceHint.kFastestInference) {
text += '[fastest inference]';
}
return text;
}
@@ -8,6 +8,7 @@ import "mojo/public/mojom/base/big_buffer.mojom";
import "mojo/public/mojom/base/file_path.mojom";
import "mojo/public/mojom/base/time.mojom";
import "services/on_device_model/public/mojom/on_device_model.mojom";
import "services/on_device_model/public/mojom/on_device_model_service.mojom";
import "skia/public/mojom/bitmap.mojom";

// Struct containing data to be displayed on on-device-internals page.
@@ -43,6 +44,7 @@ interface OnDeviceInternalsPageHandler {
// Binds a new OnDeviceModel interface if possible using model assets loaded
// from within `model_path`.
LoadModel(mojo_base.mojom.FilePath model_path,
on_device_model.mojom.ModelPerformanceHint performance_hint,
pending_receiver<on_device_model.mojom.OnDeviceModel> model) =>
(on_device_model.mojom.LoadModelResult result);
@@ -61,6 +61,7 @@ OnDeviceInternalsPageHandler::~OnDeviceInternalsPageHandler() {

void OnDeviceInternalsPageHandler::LoadModel(
const base::FilePath& model_path,
ml::ModelPerformanceHint performance_hint,
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback) {
#if BUILDFLAG(USE_CHROMEOS_MODEL_SERVICE)
@@ -81,7 +82,7 @@ void OnDeviceInternalsPageHandler::LoadModel(
base::BindOnce(&LoadModelAssets, model_path),
base::BindOnce(&OnDeviceInternalsPageHandler::OnModelAssetsLoaded,
weak_ptr_factory_.GetWeakPtr(), std::move(model),
std::move(callback)));
std::move(callback), performance_hint));
#endif
}

@@ -109,10 +110,12 @@ OnDeviceInternalsPageHandler::GetService() {
void OnDeviceInternalsPageHandler::OnModelAssetsLoaded(
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback,
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets) {
auto params = on_device_model::mojom::LoadModelParams::New();
params->assets = std::move(assets);
params->max_tokens = 4096;
params->performance_hint = performance_hint;
GetService().LoadModel(std::move(params), std::move(model),
std::move(callback));
}
@@ -43,12 +43,14 @@ class OnDeviceInternalsPageHandler : public mojom::OnDeviceInternalsPageHandler,
void OnModelAssetsLoaded(
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback,
ml::ModelPerformanceHint performance_hint,
on_device_model::ModelAssets assets);
#endif

// mojom::OnDeviceInternalsPageHandler:
void LoadModel(
const base::FilePath& model_path,
ml::ModelPerformanceHint performance_hint,
mojo::PendingReceiver<on_device_model::mojom::OnDeviceModel> model,
LoadModelCallback callback) override;
void GetEstimatedPerformanceClass(
@@ -138,6 +138,13 @@ OnDeviceModelComponentStateManager::GetRegistrationCriteria() {
return registration_criteria_.get();
}

bool OnDeviceModelComponentStateManager::IsLowTierDevice() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return IsPerformanceClassCompatible(
features::kLowTierPerformanceClassListForOnDeviceModel.Get(),
PerformanceClassFromPref(*local_state_));
}

void OnDeviceModelComponentStateManager::OnDeviceEligibleFeatureUsed(
ModelBasedCapabilityKey feature) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
@@ -179,6 +179,9 @@ class OnDeviceModelComponentStateManager
// registration has been computed yet.
const RegistrationCriteria* GetRegistrationCriteria();

// Returns true if this is determined to be a low tier device.
bool IsLowTierDevice() const;

base::WeakPtr<OnDeviceModelComponentStateManager> GetWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
@@ -53,7 +53,8 @@ class OnDeviceModelComponentTest : public testing::Test {

feature_list_.InitWithFeaturesAndParameters(
{{features::kOptimizationGuideModelExecution, {}},
{features::kOptimizationGuideOnDeviceModel,
{features::kOptimizationGuideOnDeviceModel, {}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "3,4,5,6"}}}},
/*disabled_features=*/{});
}
@@ -279,6 +279,10 @@ void OnDeviceModelServiceController::OnModelAssetsLoaded(
// TODO(crbug.com/302402959): Choose max_tokens based on device.
params->max_tokens = features::GetOnDeviceModelMaxTokens();
params->adaptation_ranks = features::GetOnDeviceModelAllowedAdaptationRanks();
if (on_device_component_state_manager_ &&
on_device_component_state_manager_->IsLowTierDevice()) {
params->performance_hint = ml::ModelPerformanceHint::kFastestInference;
}
service_client_.Get()->LoadModel(
std::move(params), std::move(model),
base::DoNothingAs<void(on_device_model::mojom::LoadModelResult)>());
components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -213,6 +213,9 @@ class OnDeviceModelServiceControllerTest : public testing::Test {
{"on_device_model_disable_crash_count", "3"},
{"on_device_model_crash_backoff_base_time", "1m"},
{"on_device_model_max_crash_backoff_time", "1h"}}},
{features::kOnDeviceModelPerformanceParams,
{{"compatible_on_device_performance_classes", "*"},
{"compatible_low_tier_on_device_performance_classes", "3"}}},
{features::kTextSafetyClassifier, {}},
{features::kOnDeviceModelValidation,
{{"on_device_model_validation_delay", "0"}}}},
@@ -3851,4 +3854,17 @@ TEST_F(OnDeviceModelServiceControllerTest, LoggingModeAlwaysDisable) {
EXPECT_EQ(0u, test_uploader.uploaded_logs().size());
}

TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
// Low performance class should use fastest inference.
pref_service_.SetInteger(
model_execution::prefs::localstate::kOnDevicePerformanceClass,
base::to_underlying(OnDeviceModelPerformanceClass::kLow));
Initialize(standard_assets_);
auto session = CreateSession();
session->ExecuteModel(PageUrlRequest("foo"),
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
EXPECT_EQ(*response_.value(), "Fastest inference\nInput: execute:foo\n");
}

} // namespace optimization_guide
@@ -202,10 +202,19 @@ BASE_FEATURE(kAiSettingsPageEnterpriseDisabledUi,
"AiSettingsPageEnterpriseDisabledUi",
base::FEATURE_DISABLED_BY_DEFAULT);

BASE_FEATURE(kOnDeviceModelPerformanceParams,
"OnDeviceModelPerformanceParams",
base::FEATURE_ENABLED_BY_DEFAULT);

const base::FeatureParam<std::string> kPerformanceClassListForOnDeviceModel{
&kOptimizationGuideOnDeviceModel,
&kOnDeviceModelPerformanceParams,
"compatible_on_device_performance_classes", "5,6"};

const base::FeatureParam<std::string>
kLowTierPerformanceClassListForOnDeviceModel{
&kOnDeviceModelPerformanceParams,
"compatible_low_tier_on_device_performance_classes", ""};

BASE_FEATURE(kOptimizationGuideIconView,
"OptimizationGuideIconView",
base::FEATURE_DISABLED_BY_DEFAULT);
@@ -90,6 +90,11 @@ BASE_DECLARE_FEATURE(kPrivacyGuideAiSettings);
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<bool> kShowAiSettingsForTesting;

// Allows setting feature params for model download configuration, such as
// minimum performance class for download.
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
BASE_DECLARE_FEATURE(kOnDeviceModelPerformanceParams);

// Comma-separated list of performance classes (e.g. "3,4,5") that should
// download the base model. Use "*" if there is no performance class
// requirement.
@@ -97,6 +102,13 @@ COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<std::string>
kPerformanceClassListForOnDeviceModel;

// Comma-separated list of performance classes that should use a smaller model
// if available. This should be a subset of
// kPerformanceClassListForOnDeviceModel.
COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
extern const base::FeatureParam<std::string>
kLowTierPerformanceClassListForOnDeviceModel;

COMPONENT_EXPORT(OPTIMIZATION_GUIDE_FEATURES)
BASE_DECLARE_FEATURE(kOptimizationGuideIconView);
@@ -65,7 +65,8 @@ bool QueryGPUAdapter(void (*adapter_callback_fn)(WGPUAdapter adapter,
}

struct FakeModelInstance {
ModelBackendType backend_type_;
ml::ModelBackendType backend_type_;
ml::ModelPerformanceHint performance_hint;
std::string model_data_;
};

@@ -97,8 +98,10 @@ std::string ReadFile(PlatformFile api_file) {
ChromeMLModel SessionCreateModel(const ChromeMLModelDescriptor* descriptor,
uintptr_t context,
ChromeMLScheduleFn schedule) {
return reinterpret_cast<ChromeMLModel>(
new FakeModelInstance{.backend_type_ = descriptor->backend_type});
return reinterpret_cast<ChromeMLModel>(new FakeModelInstance{
.backend_type_ = descriptor->backend_type,
.performance_hint = descriptor->performance_hint,
});
}

void DestroyModel(ChromeMLModel model) {
@@ -121,11 +124,11 @@ ChromeMLSession CreateSession(ChromeMLModel model,
if (descriptor) {
instance->enable_image_input = descriptor->enable_image_input;
if (descriptor->model_data) {
if (model_instance->backend_type_ == ModelBackendType::kGpuBackend) {
if (model_instance->backend_type_ == ml::ModelBackendType::kGpuBackend) {
instance->adaptation_data_ =
ReadFile(descriptor->model_data->weights_file);
} else if (model_instance->backend_type_ ==
ModelBackendType::kApuBackend) {
ml::ModelBackendType::kApuBackend) {
base::ReadFileToString(
base::FilePath::FromUTF8Unsafe(descriptor->model_data->model_path),
&instance->adaptation_data_);
@@ -206,6 +209,10 @@ bool SessionExecuteModel(ChromeMLSession session,
output_fn(&output);
};

if (reinterpret_cast<FakeModelInstance*>(model)->performance_hint ==
ml::ModelPerformanceHint::kFastestInference) {
OutputChunk("Fastest inference\n");
}
if (!instance->adaptation_data_.empty()) {
OutputChunk("Adaptation: " + instance->adaptation_data_ + "\n");
}
@@ -15,6 +15,9 @@

// This header defines the public interface to the ChromeML shared library.

// TODO: crbug.com/379723772 - Remove this when internal code migrates.
using ::ml::ModelBackendType;

extern "C" {

// A function used to handle fatal errors.
@@ -43,15 +46,6 @@ using ChromeMLTSModel = uintptr_t;
// Opaque handle to a video-frame-specific ML inference engine.
using ChromeMLInferenceEngine = uintptr_t;

// Type of the backend to run the model.
enum ModelBackendType {
// The default WebGPU backend.
kGpuBackend = 0,
// The APU accelerator backend. Only available on devices with APU, and need
// special APU model files.
kApuBackend = 1,
};

// A contiguous byte span.
struct ChromeMLByteSpan {
uint8_t* data;
@@ -79,7 +73,7 @@ struct ChromeMLModelData {
// Describes a model to use with ChromeML.
struct ChromeMLModelDescriptor {
// The backend to run this model.
ModelBackendType backend_type;
ml::ModelBackendType backend_type;

// The model data to use.
const ChromeMLModelData* model_data;
@@ -105,6 +99,8 @@ struct ChromeMLModelDescriptor {
bool enable_host_mapped_pointer;
bool use_low_power;
bool allow_fp16;

ml::ModelPerformanceHint performance_hint;
};

// Describes an adaptation for a model.
@@ -28,6 +28,21 @@ enum class Token {
// current library version.
using InputPiece = std::variant<Token, std::string, SkBitmap, bool>;

// Options for specifying the performance characteristics of the model to load.
enum class ModelPerformanceHint {
kHighestQuality,
kFastestInference,
};

// Type of the backend to run the model.
enum ModelBackendType {
// The default WebGPU backend.
kGpuBackend = 0,
// The APU accelerator backend. Only available on devices with APU, and need
// special APU model files.
kApuBackend = 1,
};

} // namespace ml

#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_H_
@@ -98,4 +98,62 @@ bool UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece>::
return false;
}

// static
on_device_model::mojom::ModelBackendType
EnumTraits<on_device_model::mojom::ModelBackendType,
ml::ModelBackendType>::ToMojom(ml::ModelBackendType input) {
switch (input) {
case ml::ModelBackendType::kGpuBackend:
return on_device_model::mojom::ModelBackendType::kGpu;
case ml::ModelBackendType::kApuBackend:
return on_device_model::mojom::ModelBackendType::kApu;
}
NOTREACHED();
}

// static
bool EnumTraits<on_device_model::mojom::ModelBackendType,
ml::ModelBackendType>::
FromMojom(on_device_model::mojom::ModelBackendType input,
ml::ModelBackendType* output) {
switch (input) {
case on_device_model::mojom::ModelBackendType::kGpu:
*output = ml::ModelBackendType::kGpuBackend;
return true;
case on_device_model::mojom::ModelBackendType::kApu:
*output = ml::ModelBackendType::kApuBackend;
return true;
}
return false;
}

// static
on_device_model::mojom::ModelPerformanceHint
EnumTraits<on_device_model::mojom::ModelPerformanceHint,
ml::ModelPerformanceHint>::ToMojom(ml::ModelPerformanceHint input) {
switch (input) {
case ml::ModelPerformanceHint::kHighestQuality:
return on_device_model::mojom::ModelPerformanceHint::kHighestQuality;
case ml::ModelPerformanceHint::kFastestInference:
return on_device_model::mojom::ModelPerformanceHint::kFastestInference;
}
NOTREACHED();
}

// static
bool EnumTraits<on_device_model::mojom::ModelPerformanceHint,
ml::ModelPerformanceHint>::
FromMojom(on_device_model::mojom::ModelPerformanceHint input,
ml::ModelPerformanceHint* output) {
switch (input) {
case on_device_model::mojom::ModelPerformanceHint::kHighestQuality:
*output = ml::ModelPerformanceHint::kHighestQuality;
return true;
case on_device_model::mojom::ModelPerformanceHint::kFastestInference:
*output = ml::ModelPerformanceHint::kFastestInference;
return true;
}
return false;
}

} // namespace mojo
@@ -43,6 +43,24 @@ struct UnionTraits<on_device_model::mojom::InputPieceDataView, ml::InputPiece> {
ml::InputPiece* out);
};

template <>
struct EnumTraits<on_device_model::mojom::ModelBackendType,
ml::ModelBackendType> {
static on_device_model::mojom::ModelBackendType ToMojom(
ml::ModelBackendType input);
static bool FromMojom(on_device_model::mojom::ModelBackendType input,
ml::ModelBackendType* output);
};

template <>
struct EnumTraits<on_device_model::mojom::ModelPerformanceHint,
ml::ModelPerformanceHint> {
static on_device_model::mojom::ModelPerformanceHint ToMojom(
ml::ModelPerformanceHint input);
static bool FromMojom(on_device_model::mojom::ModelPerformanceHint input,
ml::ModelPerformanceHint* output);
};

} // namespace mojo

#endif // SERVICES_ON_DEVICE_MODEL_ML_CHROME_ML_TYPES_TRAITS_H_
@@ -16,7 +16,7 @@ namespace ml {
namespace {

const base::FeatureParam<std::string> kGpuBlockList{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_model_gpu_block_list",
// These devices are nearly always crashing or have very low performance.
"8086:412|8086:a16|8086:41e|8086:416|8086:402|8086:166|8086:1616|8086:22b1|"
@@ -105,18 +105,6 @@ uint32_t GetTopK(std::optional<uint32_t> top_k) {
std::max(1u, top_k.value_or(1)));
}

std::optional<ModelBackendType> ModelBackendTypeFromMojom(
on_device_model::mojom::ModelBackendType backend) {
switch (backend) {
case on_device_model::mojom::ModelBackendType::kGpu:
return ModelBackendType::kGpuBackend;
case on_device_model::mojom::ModelBackendType::kApu:
return ModelBackendType::kApuBackend;
default:
return std::nullopt;
}
}

} // namespace

// Handles sending and canceling responses.
@@ -466,24 +454,17 @@ LoadModelResult OnDeviceModelExecutor::Init(

max_tokens_ = std::max(params->max_tokens, kReserveTokensForSafety);

std::optional<ModelBackendType> backend_type =
ModelBackendTypeFromMojom(params->backend_type);
if (!backend_type.has_value()) {
LOG(ERROR) << "Failed to parse model backend type";
return LoadModelResult::kFailedToLoadLibrary;
}

ChromeMLModelData data;
std::string weights_path_str = assets.weights_path.AsUTF8Unsafe();
std::string sp_model_path_str = assets.sp_model_path.AsUTF8Unsafe();
if (*backend_type == ModelBackendType::kGpuBackend) {
if (params->backend_type == ml::ModelBackendType::kGpuBackend) {
data.weights_file = assets.weights.TakePlatformFile();
} else {
data.model_path = weights_path_str.data();
data.sentencepiece_model_path = sp_model_path_str.data();
}
ChromeMLModelDescriptor descriptor = {
.backend_type = *backend_type,
.backend_type = params->backend_type,
.model_data = &data,
.max_tokens = max_tokens_,
.temperature = 0.0f,
@@ -494,6 +475,7 @@ LoadModelResult OnDeviceModelExecutor::Init(
.enable_host_mapped_pointer = kEnableHostMappedPointer.Get(),
.use_low_power = kUseLowPower.Get(),
.allow_fp16 = kAllowFp16.Get(),
.performance_hint = params->performance_hint,
};
model_ = chrome_ml_->api().SessionCreateModel(
&descriptor, reinterpret_cast<uintptr_t>(this),
@@ -18,30 +18,30 @@ constexpr uint64_t kBytesPerMb = 1024 * 1024;

// The threshold for GPU RAM below which the device is considered VeryLow.
const base::FeatureParam<int> kLowRAMThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_ram_threshold_mb", 3000};
// RAM threshold necessary to be considered High or better.
const base::FeatureParam<int> kHighRAMThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_high_ram_threshold_mb", 5500};

// Output threshold to be considered Low or better.
const base::FeatureParam<int> kLowOutputThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_output_threshold", 5};

// Input speed min thresholds or each device class.
const base::FeatureParam<int> kLowThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_low_threshold", 50};
const base::FeatureParam<int> kMediumThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_medium_threshold", 75};
const base::FeatureParam<int> kHighThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_high_threshold", 150};
const base::FeatureParam<int> kVeryHighThreshold{
&optimization_guide::features::kOptimizationGuideOnDeviceModel,
&optimization_guide::features::kOnDeviceModelPerformanceParams,
"on_device_very_high_threshold", 500};

// These values are persisted to logs. Entries should not be renumbered and
@@ -81,11 +81,14 @@ class OnDeviceModelServiceTest : public testing::Test {
mojo::Remote<mojom::OnDeviceModelService>& service() { return service_; }

mojo::Remote<mojom::OnDeviceModel> LoadModel(
mojom::ModelBackendType backend_type = mojom::ModelBackendType::kGpu) {
ml::ModelBackendType backend_type = ml::ModelBackendType::kGpuBackend,
ml::ModelPerformanceHint performance_hint =
ml::ModelPerformanceHint::kHighestQuality) {
base::RunLoop run_loop;
mojo::Remote<mojom::OnDeviceModel> remote;
auto params = mojom::LoadModelParams::New();
params->backend_type = backend_type;
params->performance_hint = performance_hint;
params->max_tokens = 8000;
service()->LoadModel(
std::move(params), remote.BindNewPipeAndPassReceiver(),
@@ -452,7 +455,7 @@ TEST_F(OnDeviceModelServiceTest, DestroysAdaptationSession) {
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
FakeFile weights1("Adapt1");
FakeFile weights2("Adapt2");
auto model = LoadModel(mojom::ModelBackendType::kApu);
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
@@ -633,5 +636,12 @@ TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
EXPECT_THAT(resp2->class_scores, ElementsAre(0.2, 0.2));
}

TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
ml::ModelPerformanceHint::kFastestInference);
EXPECT_THAT(GetResponses(*model, "foo"),
ElementsAre("Fastest inference\n", "Input: foo\n"));
}

} // namespace
} // namespace on_device_model
@@ -130,6 +130,12 @@ void FakeOnDeviceSession::ExecuteImpl(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) {
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
if (model_->performance_hint() ==
ml::ModelPerformanceHint::kFastestInference) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Fastest inference\n";
remote->OnResponse(std::move(chunk));
}
if (model_->data().base_weight != "0") {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Base model: " + model_->data().base_weight + "\n";
@@ -184,8 +190,11 @@ void FakeOnDeviceSession::AddContextInternal(
}

FakeOnDeviceModel::FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
FakeOnDeviceModel::Data&& data)
: settings_(settings), data_(std::move(data)) {}
FakeOnDeviceModel::Data&& data,
ml::ModelPerformanceHint performance_hint)
: settings_(settings),
data_(std::move(data)),
performance_hint_(performance_hint) {}

FakeOnDeviceModel::~FakeOnDeviceModel() = default;

@@ -221,8 +230,8 @@ void FakeOnDeviceModel::LoadAdaptation(
LoadAdaptationCallback callback) {
Data data = data_;
data.adaptation_model_weight = ReadFile(params->assets.weights);
auto test_model =
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
auto test_model = std::make_unique<FakeOnDeviceModel>(
settings_, std::move(data), ml::ModelPerformanceHint::kHighestQuality);
model_adaptation_receivers_.Add(std::move(test_model), std::move(model));
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
@@ -292,8 +301,8 @@ void FakeOnDeviceModelService::LoadModel(
}
FakeOnDeviceModel::Data data;
data.base_weight = ReadFile(params->assets.weights);
auto test_model =
std::make_unique<FakeOnDeviceModel>(settings_, std::move(data));
auto test_model = std::make_unique<FakeOnDeviceModel>(
settings_, std::move(data), params->performance_hint);
model_receivers_.Add(std::move(test_model), std::move(model));
std::move(callback).Run(mojom::LoadModelResult::kSuccess);
}
@@ -112,7 +112,8 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {
std::string adaptation_model_weight = "";
};
explicit FakeOnDeviceModel(FakeOnDeviceServiceSettings* settings,
Data&& data);
Data&& data,
ml::ModelPerformanceHint performance_hint);
~FakeOnDeviceModel() override;

// mojom::OnDeviceModel:
@@ -134,9 +135,14 @@ class FakeOnDeviceModel : public mojom::OnDeviceModel {

const Data& data() const { return data_; }

ml::ModelPerformanceHint performance_hint() const {
return performance_hint_;
}

private:
raw_ptr<FakeOnDeviceServiceSettings> settings_;
Data data_;
ml::ModelPerformanceHint performance_hint_;

mojo::UniqueReceiverSet<mojom::Session> receivers_;
mojo::UniqueReceiverSet<mojom::OnDeviceModel> model_adaptation_receivers_;
@@ -37,6 +37,14 @@ mojom("mojom") {
mojom = "on_device_model.mojom.InputPiece"
cpp = "::ml::InputPiece"
},
{
mojom = "on_device_model.mojom.ModelBackendType"
cpp = "::ml::ModelBackendType"
},
{
mojom = "on_device_model.mojom.ModelPerformanceHint"
cpp = "::ml::ModelPerformanceHint"
},
]
traits_headers = [
"//services/on_device_model/public/cpp/adaptation_assets_mojom_traits.h",
@@ -36,6 +36,12 @@ enum ModelBackendType {
kApu,
};

// Options for specifying the performance characteristics of the model to load.
enum ModelPerformanceHint {
kHighestQuality,
kFastestInference,
};

// Params to describe the model to load.
struct LoadModelParams {
// Backend type of the model.
@@ -50,6 +56,10 @@ struct LoadModelParams {

// List of adaptation ranks the model should support.
array<uint32> adaptation_ranks;

// Chooses the performance characteristics of the model loaded. If only a
// single model is available, this field will do nothing.
ModelPerformanceHint performance_hint = kHighestQuality;
};

struct TextSafetyModelAssets {
@@ -15298,6 +15298,7 @@ from previous Chrome versions.
<int value="-2036127998" label="LocalWebApprovals:disabled"/>
<int value="-2035845836" label="ArcExtendServiceAnrTimeout:enabled"/>
<int value="-2035126988" label="enabled-new-style-notification"/>
<int value="-2034315473" label="OnDeviceModelPerformanceParams:disabled"/>
<int value="-2034064186" label="EnableKeyboardBacklightToggle:disabled"/>
<int value="-2033950090" label="AutofillNoLocalSaveOnUploadSuccess:disabled"/>
<int value="-2033908928" label="NightLight:enabled"/>
@@ -22130,6 +22131,7 @@ from previous Chrome versions.
<int value="705946076"
label="ContextMenuPerformanceInfoAndRemoteHintFetching:disabled"/>
<int value="706280254" label="StoragePressureEvent:enabled"/>
<int value="707243920" label="OnDeviceModelPerformanceParams:enabled"/>
<int value="707463326" label="DynamicSafeAreaInsets:enabled"/>
<int value="708015891"
label="AutofillUpdateChromeSettingsLinkToGPayWeb:enabled"/>