0

Reland: Split out OnDeviceAssetManager.

Move all of the on-device asset observation logic out of
ModelExecutionManager and into the new OnDeviceAssetManager. No functional
changes.

Reland of crrev.com/c/6115297 + crrev.com/c/6159444 + an iOS fix, since this
was not the culprit for crbug.com/388547765.

Bug: 388538282, 373725257
Change-Id: I4db8241d37ef45be8a6b235479cb394d7a0b9a66
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6289432
Auto-Submit: Steven Holte <holte@chromium.org>
Commit-Queue: Steven Holte <holte@chromium.org>
Reviewed-by: Zekun Jiang <zekunjiang@google.com>
Cr-Commit-Position: refs/heads/main@{#1424147}
This commit is contained in:
Steven Holte
2025-02-24 13:54:05 -08:00
committed by Chromium LUCI CQ
parent be443cc694
commit 643c0e2fc4
11 changed files with 472 additions and 375 deletions

@ -4,6 +4,7 @@
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include <memory>
#include <optional>
#include "base/files/file_path.h"
@ -48,6 +49,7 @@
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
#include "components/optimization_guide/core/model_execution/model_execution_features_controller.h"
#include "components/optimization_guide/core/model_execution/model_execution_manager.h"
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/performance_class.h"
@ -94,7 +96,6 @@ using ::optimization_guide::OnDeviceModelComponentStateManager;
using ::optimization_guide::OnDeviceModelPerformanceClass;
using ::optimization_guide::OnDeviceModelServiceController;
// Returns the profile to use for when setting up the keyed service when the
// profile is Off-The-Record. For guest profiles, returns a loaded profile if
// one exists, otherwise just the original profile of the OTR profile. Note:
@ -434,15 +435,16 @@ void OptimizationGuideKeyedService::InitializeModelExecution(Profile* profile) {
optimization_guide::features::kOptimizationGuideOnDeviceModel)) {
service_controller = GetOnDeviceModelServiceController(
on_device_component_manager_->GetWeakPtr());
on_device_asset_manager_ =
std::make_unique<optimization_guide::OnDeviceAssetManager>(
g_browser_process->local_state(), service_controller->GetWeakPtr(),
on_device_component_manager_->GetWeakPtr(), this);
}
model_execution_manager_ =
std::make_unique<optimization_guide::ModelExecutionManager>(
url_loader_factory, g_browser_process->local_state(),
IdentityManagerFactory::GetForProfile(profile),
std::move(service_controller), this,
on_device_component_manager_->GetWeakPtr(),
optimization_guide_logger_.get(),
url_loader_factory, IdentityManagerFactory::GetForProfile(profile),
std::move(service_controller), optimization_guide_logger_.get(),
model_quality_logs_uploader_service_
? model_quality_logs_uploader_service_->GetWeakPtr()
: nullptr);

@ -48,6 +48,7 @@ class ModelExecutionManager;
class ModelInfo;
class ModelQualityLogsUploaderService;
class ModelValidatorKeyedService;
class OnDeviceAssetManager;
class OnDeviceModelAvailabilityObserver;
class OnDeviceModelComponentStateManager;
class OptimizationGuideStore;
@ -359,6 +360,9 @@ class OptimizationGuideKeyedService
// prediction models.
std::unique_ptr<optimization_guide::PredictionManager> prediction_manager_;
std::unique_ptr<optimization_guide::OnDeviceAssetManager>
on_device_asset_manager_;
// Manages the model execution. Not created for off the record profiles.
std::unique_ptr<optimization_guide::ModelExecutionManager>
model_execution_manager_;

@ -250,6 +250,8 @@ static_library("core") {
"model_execution/model_execution_util.h",
"model_execution/multimodal_message.cc",
"model_execution/multimodal_message.h",
"model_execution/on_device_asset_manager.cc",
"model_execution/on_device_asset_manager.h",
"model_execution/on_device_context.cc",
"model_execution/on_device_context.h",
"model_execution/on_device_execution.cc",
@ -532,6 +534,7 @@ source_set("unit_tests") {
"model_execution/model_execution_fetcher_unittest.cc",
"model_execution/model_execution_manager_unittest.cc",
"model_execution/multimodal_message_unittest.cc",
"model_execution/on_device_asset_manager_unittest.cc",
"model_execution/on_device_model_adaptation_loader_unittest.cc",
"model_execution/on_device_model_component_unittest.cc",
"model_execution/on_device_model_execution_proto_descriptors_unittest.cc",

@ -122,31 +122,6 @@ void NoOpExecuteRemoteFn(
nullptr);
}
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
GetRequiredModelAdaptationLoaders(
OptimizationGuideModelProvider* model_provider,
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager,
PrefService* local_state,
base::WeakPtr<OnDeviceModelServiceController>
on_device_model_service_controller) {
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader> loaders;
for (const auto feature : kAllModelBasedCapabilityKeys) {
if (!features::internal::GetOptimizationTargetForCapability(feature)) {
continue;
}
loaders.emplace(
std::piecewise_construct, std::forward_as_tuple(feature),
std::forward_as_tuple(
feature, model_provider, on_device_component_state_manager,
local_state,
base::BindRepeating(
&OnDeviceModelServiceController::MaybeUpdateModelAdaptation,
on_device_model_service_controller, feature)));
}
return loaders;
}
} // namespace
using ModelExecutionError =
@ -154,18 +129,13 @@ using ModelExecutionError =
ModelExecutionManager::ModelExecutionManager(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* local_state,
signin::IdentityManager* identity_manager,
scoped_refptr<OnDeviceModelServiceController>
on_device_model_service_controller,
OptimizationGuideModelProvider* model_provider,
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager,
OptimizationGuideLogger* optimization_guide_logger,
base::WeakPtr<ModelQualityLogsUploaderService>
model_quality_uploader_service)
: model_quality_uploader_service_(model_quality_uploader_service),
on_device_component_state_manager_(on_device_component_state_manager),
optimization_guide_logger_(optimization_guide_logger),
model_execution_service_url_(net::AppendOrReplaceQueryParameter(
GetModelExecutionServiceURL(),
@ -173,47 +143,11 @@ ModelExecutionManager::ModelExecutionManager(
features::GetOptimizationGuideServiceAPIKey())),
url_loader_factory_(url_loader_factory),
identity_manager_(identity_manager),
model_adaptation_loaders_(GetRequiredModelAdaptationLoaders(
model_provider,
on_device_component_state_manager_,
local_state,
on_device_model_service_controller
? on_device_model_service_controller->GetWeakPtr()
: nullptr)),
model_provider_(model_provider),
on_device_model_service_controller_(
std::move(on_device_model_service_controller)) {
if (!model_provider_ && !on_device_model_service_controller_) {
return;
}
if (!features::ShouldUseTextSafetyClassifierModel()) {
return;
}
if (GetGenAILocalFoundationalModelEnterprisePolicySettings(local_state) !=
model_execution::prefs::
GenAILocalFoundationalModelEnterprisePolicySettings::kAllowed) {
return;
}
if (on_device_component_state_manager_) {
on_device_component_state_manager_->AddObserver(this);
if (on_device_component_state_manager_->IsInstallerRegistered()) {
RegisterTextSafetyAndLanguageModels();
}
}
}
ModelExecutionManager::~ModelExecutionManager() {
if (on_device_component_state_manager_) {
on_device_component_state_manager_->RemoveObserver(this);
}
if (did_register_for_supplementary_on_device_models_) {
model_provider_->RemoveObserverForOptimizationTargetModel(
proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY, this);
model_provider_->RemoveObserverForOptimizationTargetModel(
proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
this);
}
}
void ModelExecutionManager::Shutdown() {
@ -323,11 +257,6 @@ ModelExecutionManager::StartSession(
std::move(execute_fn), config_params);
}
// Whether the supplementary on-device models are registered.
bool ModelExecutionManager::IsSupplementaryModelRegistered() {
return did_register_for_supplementary_on_device_models_;
}
void ModelExecutionManager::OnModelExecuteResponse(
ModelBasedCapabilityKey feature,
std::unique_ptr<proto::LogAiDataRequest> log_ai_data_request,
@ -459,47 +388,6 @@ void ModelExecutionManager::OnModelExecuteResponse(
std::move(log_entry));
}
void ModelExecutionManager::RegisterTextSafetyAndLanguageModels() {
if (!did_register_for_supplementary_on_device_models_) {
did_register_for_supplementary_on_device_models_ = true;
model_provider_->AddObserverForOptimizationTargetModel(
proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
/*model_metadata=*/std::nullopt, this);
model_provider_->AddObserverForOptimizationTargetModel(
proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
/*model_metadata=*/std::nullopt, this);
}
}
void ModelExecutionManager::OnModelUpdated(
proto::OptimizationTarget optimization_target,
base::optional_ref<const ModelInfo> model_info) {
switch (optimization_target) {
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
if (on_device_model_service_controller_) {
on_device_model_service_controller_->MaybeUpdateSafetyModel(model_info);
}
break;
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
if (on_device_model_service_controller_) {
on_device_model_service_controller_->SetLanguageDetectionModel(
model_info);
}
break;
default:
break;
}
}
void ModelExecutionManager::StateChanged(
const OnDeviceModelComponentState* state) {
if (state) {
RegisterTextSafetyAndLanguageModels();
}
}
optimization_guide::OnDeviceModelEligibilityReason
ModelExecutionManager::GetOnDeviceModelEligibility(
optimization_guide::ModelBasedCapabilityKey feature) {

@ -20,8 +20,6 @@
#include "components/optimization_guide/core/optimization_target_model_observer.h"
#include "components/optimization_guide/proto/model_execution.pb.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "url/gurl.h"
class OptimizationGuideLogger;
@ -37,28 +35,19 @@ class IdentityManager;
namespace optimization_guide {
class ModelExecutionFetcher;
class OnDeviceModelAdaptationLoader;
class OnDeviceModelServiceController;
class OptimizationGuideModelProvider;
class ModelExecutionManager
: public OptimizationTargetModelObserver,
public OnDeviceModelComponentStateManager::Observer {
class ModelExecutionManager final {
public:
ModelExecutionManager(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* local_state,
signin::IdentityManager* identity_manager,
scoped_refptr<OnDeviceModelServiceController>
on_device_model_service_controller,
OptimizationGuideModelProvider* model_provider,
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager,
OptimizationGuideLogger* optimization_guide_logger,
base::WeakPtr<ModelQualityLogsUploaderService>
model_quality_uploader_service);
~ModelExecutionManager() override;
~ModelExecutionManager();
ModelExecutionManager(const ModelExecutionManager&) = delete;
ModelExecutionManager& operator=(const ModelExecutionManager&) = delete;
@ -93,16 +82,6 @@ class ModelExecutionManager
ModelBasedCapabilityKey feature,
const std::optional<SessionConfigParams>& config_params);
// Whether the supplementary on-device models are registered.
bool IsSupplementaryModelRegistered();
// OptimizationTargetModelObserver:
void OnModelUpdated(proto::OptimizationTarget target,
base::optional_ref<const ModelInfo> model_info) override;
// OnDeviceModelComponentStateManager::Observer:
void StateChanged(const OnDeviceModelComponentState* state) override;
void Shutdown();
private:
@ -114,10 +93,6 @@ class ModelExecutionManager
base::expected<const proto::ExecuteResponse,
OptimizationGuideModelExecutionError> execute_response);
// Registers text safety and language detection models. Does nothing if
// already registered.
void RegisterTextSafetyAndLanguageModels();
// Returns the `OnDeviceModelAdaptationMetadata` for `feature`.
std::optional<optimization_guide::OnDeviceModelAdaptationMetadata>
GetOnDeviceModelAdaptationMetadata(
@ -129,9 +104,6 @@ class ModelExecutionManager
base::WeakPtr<ModelQualityLogsUploaderService>
model_quality_uploader_service_;
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager_;
// Owned by OptimizationGuideKeyedService and outlives `this`.
raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
@ -149,21 +121,10 @@ class ModelExecutionManager
// incognito profiles.
const raw_ptr<signin::IdentityManager> identity_manager_;
// Map from feature to its model adaptation loader. Present only for features
// that require model adaptation.
const std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
model_adaptation_loaders_;
// The model provider to observe for updates to auxiliary models.
raw_ptr<OptimizationGuideModelProvider> model_provider_;
// Controller for the on-device service.
scoped_refptr<OnDeviceModelServiceController>
on_device_model_service_controller_;
// Whether the user registered for supplementary on-device models.
bool did_register_for_supplementary_on_device_models_ = false;
SEQUENCE_CHECKER(sequence_checker_);
// Used to get `weak_ptr_` to self.

@ -7,18 +7,16 @@
#include <memory>
#include "base/files/file_path.h"
#include "base/functional/callback_helpers.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test.pb.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
#include "components/optimization_guide/core/model_execution/model_execution_prefs.h"
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/test/request_builder.h"
#include "components/optimization_guide/core/model_execution/test/response_holder.h"
#include "components/optimization_guide/core/model_execution/test/test_on_device_model_component_state_manager.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
@ -55,68 +53,9 @@ proto::ExecuteResponse BuildComposeResponse(const std::string& output) {
return execute_response;
}
class FakeServiceController : public OnDeviceModelServiceController {
public:
FakeServiceController()
: OnDeviceModelServiceController(nullptr, nullptr, base::DoNothing()) {}
void MaybeUpdateSafetyModel(
base::optional_ref<const ModelInfo> model_info) override {
received_safety_info_ = true;
}
bool received_safety_info() const { return received_safety_info_; }
std::optional<base::FilePath> language_detection_model_path() {
return OnDeviceModelServiceController::language_detection_model_path();
}
private:
~FakeServiceController() override = default;
bool received_safety_info_ = false;
};
class FakeModelProvider : public TestOptimizationGuideModelProvider {
public:
void AddObserverForOptimizationTargetModel(
proto::OptimizationTarget optimization_target,
const std::optional<optimization_guide::proto::Any>& model_metadata,
OptimizationTargetModelObserver* observer) override {
switch (optimization_target) {
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
registered_for_text_safety_ = true;
break;
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
registered_for_language_detection_ = true;
break;
default:
NOTREACHED();
}
}
void Reset() {
registered_for_text_safety_ = false;
registered_for_language_detection_ = false;
}
bool was_registered() const {
return registered_for_text_safety_ && registered_for_language_detection_;
}
private:
bool registered_for_text_safety_ = false;
bool registered_for_language_detection_ = false;
};
class ModelExecutionManagerTest : public testing::Test {
public:
ModelExecutionManagerTest() {
scoped_feature_list_.InitWithFeatures({},
{features::kTextSafetyClassifier});
}
ModelExecutionManagerTest() = default;
~ModelExecutionManagerTest() override = default;
// Sets up most of the fields except `model_execution_manager_` and
@ -125,29 +64,11 @@ class ModelExecutionManagerTest : public testing::Test {
url_loader_factory_ =
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_url_loader_factory_);
local_state_ = std::make_unique<TestingPrefServiceSimple>();
model_execution::prefs::RegisterLocalStatePrefs(local_state_->registry());
service_controller_ = base::MakeRefCounted<FakeServiceController>();
}
void CreateModelExecutionManager() {
service_controller_ = base::MakeRefCounted<OnDeviceModelServiceController>(
nullptr, nullptr, base::DoNothing());
model_execution_manager_ = std::make_unique<ModelExecutionManager>(
url_loader_factory_, local_state_.get(),
identity_test_env_.identity_manager(), service_controller_,
&model_provider_,
component_manager_ ? component_manager_->get()->GetWeakPtr() : nullptr,
&optimization_guide_logger_, nullptr);
}
void CreateComponentManager(bool should_observe) {
component_manager_ =
std::make_unique<TestOnDeviceModelComponentStateManager>(
local_state_.get());
component_manager_->get()->OnStartup();
task_environment_.FastForwardBy(base::Seconds(1));
if (should_observe) {
component_manager_->get()->AddObserver(model_execution_manager_.get());
}
url_loader_factory_, identity_test_env_.identity_manager(),
service_controller_, &optimization_guide_logger_, nullptr);
}
bool SimulateResponse(const std::string& content,
@ -179,9 +100,7 @@ class ModelExecutionManagerTest : public testing::Test {
return model_execution_manager_.get();
}
FakeModelProvider* model_provider() { return &model_provider_; }
FakeServiceController* service_controller() {
OnDeviceModelServiceController* service_controller() {
return service_controller_.get();
}
@ -195,37 +114,25 @@ class ModelExecutionManagerTest : public testing::Test {
EXPECT_THAT(body_bytes, HasSubstr(message));
}
void SetModelComponentReady() {
component_manager_->SetReady(base::FilePath());
}
network::TestURLLoaderFactory* test_url_loader_factory() {
return &test_url_loader_factory_;
}
PrefService* local_state() { return local_state_.get(); }
void Reset() { model_execution_manager_ = nullptr; }
private:
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
base::test::ScopedFeatureList scoped_feature_list_;
std::unique_ptr<TestingPrefServiceSimple> local_state_;
signin::IdentityTestEnvironment identity_test_env_;
variations::ScopedVariationsIdsProvider scoped_variations_ids_provider_{
variations::VariationsIdsProvider::Mode::kUseSignedInState};
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
network::TestURLLoaderFactory test_url_loader_factory_;
scoped_refptr<FakeServiceController> service_controller_;
std::unique_ptr<TestOnDeviceModelComponentStateManager> component_manager_;
FakeModelProvider model_provider_;
scoped_refptr<OnDeviceModelServiceController> service_controller_;
OptimizationGuideLogger optimization_guide_logger_;
std::unique_ptr<ModelExecutionManager> model_execution_manager_;
};
TEST_F(ModelExecutionManagerTest, ExecuteModelEmptyAccessToken) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
model_execution_manager()->ExecuteModel(
@ -244,7 +151,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelEmptyAccessToken) {
}
TEST_F(ModelExecutionManagerTest, ExecuteModelWithUserSignIn) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
SetAutomaticIssueOfAccessTokens();
@ -266,7 +172,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithUserSignIn) {
}
TEST_F(ModelExecutionManagerTest, ExecuteModelWithServerError) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
@ -298,7 +203,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithServerError) {
TEST_F(ModelExecutionManagerTest,
ExecuteModelWithServerErrorAllowedForLogging) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
@ -340,7 +244,6 @@ TEST_F(ModelExecutionManagerTest,
}
TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetOnDeviceOnly) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
SetAutomaticIssueOfAccessTokens();
@ -359,7 +262,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetOnDeviceOnly) {
}
TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetToServerOnly) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
@ -392,7 +294,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetToServerOnly) {
TEST_F(ModelExecutionManagerTest,
ExecuteModelExecutionModeExplicitlySetToDefault) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
@ -424,7 +325,6 @@ TEST_F(ModelExecutionManagerTest,
}
TEST_F(ModelExecutionManagerTest, ExecuteModelWithPassthroughSession) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder;
@ -450,7 +350,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithPassthroughSession) {
}
TEST_F(ModelExecutionManagerTest, LogsContextToExecutionTimeHistogram) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
SetAutomaticIssueOfAccessTokens();
auto session = model_execution_manager()->StartSession(
@ -491,7 +390,6 @@ TEST_F(ModelExecutionManagerTest, LogsContextToExecutionTimeHistogram) {
TEST_F(ModelExecutionManagerTest,
ExecuteModelWithPassthroughSessionAddContext) {
CreateModelExecutionManager();
ResponseHolder response_holder;
SetAutomaticIssueOfAccessTokens();
auto session = model_execution_manager()->StartSession(
@ -508,7 +406,6 @@ TEST_F(ModelExecutionManagerTest,
TEST_F(ModelExecutionManagerTest,
ExecuteModelWithPassthroughSessionMultipleAddContext) {
CreateModelExecutionManager();
ResponseHolder response_holder;
SetAutomaticIssueOfAccessTokens();
auto session = model_execution_manager()->StartSession(
@ -525,7 +422,6 @@ TEST_F(ModelExecutionManagerTest,
TEST_F(ModelExecutionManagerTest,
ExecuteModelWithPassthroughSessionExecuteOverridesAddContext) {
CreateModelExecutionManager();
ResponseHolder response_holder;
SetAutomaticIssueOfAccessTokens();
auto session = model_execution_manager()->StartSession(
@ -541,7 +437,6 @@ TEST_F(ModelExecutionManagerTest,
}
TEST_F(ModelExecutionManagerTest, TestMultipleParallelRequests) {
CreateModelExecutionManager();
base::HistogramTester histogram_tester;
ResponseHolder response_holder1, response_holder2;
@ -582,103 +477,6 @@ TEST_F(ModelExecutionManagerTest, TestMultipleParallelRequests) {
"OptimizationGuide.ModelExecution.Result.Compose", false, 1);
}
TEST_F(ModelExecutionManagerTest, DoesNotRegisterTextSafetyIfNotEnabled) {
CreateModelExecutionManager();
EXPECT_FALSE(model_provider()->was_registered());
}
class ModelExecutionManagerSafetyEnabledTest
: public ModelExecutionManagerTest {
public:
ModelExecutionManagerSafetyEnabledTest() {
scoped_feature_list_.InitWithFeatures({features::kTextSafetyClassifier},
{});
}
private:
base::test::ScopedFeatureList scoped_feature_list_;
};
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC) || BUILDFLAG(IS_LINUX)
TEST_F(ModelExecutionManagerSafetyEnabledTest,
RegistersTextSafetyModelWithOverrideModel) {
// Effectively, when an override is set, the model component will be ready
// before ModelExecutionManager can be added as an observer. Here we simulate
// that by simply setting up the component without adding
// ModelExecutionManager as an observer.
CreateComponentManager(/*should_observe=*/false);
SetModelComponentReady();
CreateModelExecutionManager();
EXPECT_TRUE(model_provider()->was_registered());
}
TEST_F(ModelExecutionManagerSafetyEnabledTest,
RegistersTextSafetyModelIfEnabled) {
CreateModelExecutionManager();
EXPECT_FALSE(model_provider()->was_registered());
// Text safety model should only be registered after the base model is ready.
local_state()->SetInteger(
model_execution::prefs::localstate::kOnDevicePerformanceClass,
base::to_underlying(OnDeviceModelPerformanceClass::kHigh));
CreateComponentManager(/*should_observe=*/true);
SetModelComponentReady();
EXPECT_TRUE(model_provider()->was_registered());
}
#endif
TEST_F(ModelExecutionManagerSafetyEnabledTest,
DoesNotNotifyServiceControllerWrongTarget) {
CreateModelExecutionManager();
std::unique_ptr<ModelInfo> model_info =
TestModelInfoBuilder().SetVersion(123).Build();
model_execution_manager()->OnModelUpdated(
proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, *model_info);
EXPECT_FALSE(service_controller()->received_safety_info());
}
TEST_F(ModelExecutionManagerSafetyEnabledTest, NotifiesServiceController) {
CreateModelExecutionManager();
std::unique_ptr<ModelInfo> model_info =
TestModelInfoBuilder().SetVersion(123).Build();
model_execution_manager()->OnModelUpdated(
proto::OPTIMIZATION_TARGET_TEXT_SAFETY, *model_info);
EXPECT_TRUE(service_controller()->received_safety_info());
}
TEST_F(ModelExecutionManagerSafetyEnabledTest, UpdateLanguageDetection) {
CreateModelExecutionManager();
const base::FilePath kTestPath{FILE_PATH_LITERAL("foo")};
std::unique_ptr<ModelInfo> model_info = TestModelInfoBuilder()
.SetVersion(123)
.SetModelFilePath(kTestPath)
.Build();
model_execution_manager()->OnModelUpdated(
proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, *model_info);
EXPECT_EQ(kTestPath, service_controller()->language_detection_model_path());
}
TEST_F(ModelExecutionManagerSafetyEnabledTest,
NotRegisteredWhenDisabledByEnterprisePolicy) {
CreateModelExecutionManager();
model_provider()->Reset();
local_state()->SetInteger(
model_execution::prefs::localstate::
kGenAILocalFoundationalModelEnterprisePolicySettings,
static_cast<int>(model_execution::prefs::
GenAILocalFoundationalModelEnterprisePolicySettings::
kDisallowed));
CreateModelExecutionManager();
EXPECT_FALSE(model_provider()->was_registered());
// Reset manager to make sure removing observer doesn't crash.
Reset();
}
} // namespace
} // namespace optimization_guide

@ -0,0 +1,132 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
#include "components/optimization_guide/core/model_execution/model_execution_util.h"
#include "components/optimization_guide/core/model_execution/on_device_model_adaptation_loader.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
namespace optimization_guide {
namespace {
// Builds one OnDeviceModelAdaptationLoader for every model-based capability
// that maps to an optimization target; capabilities without a target are
// skipped. Each loader is handed a repeating callback that forwards adaptation
// updates for its capability to
// OnDeviceModelServiceController::MaybeUpdateModelAdaptation, bound through the
// weak `on_device_model_service_controller`.
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
GetRequiredModelAdaptationLoaders(
    OptimizationGuideModelProvider* model_provider,
    base::WeakPtr<OnDeviceModelComponentStateManager>
        on_device_component_state_manager,
    PrefService* local_state,
    base::WeakPtr<OnDeviceModelServiceController>
        on_device_model_service_controller) {
  std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader> result;
  for (const auto capability : kAllModelBasedCapabilityKeys) {
    // Only capabilities with an optimization target need an adaptation loader.
    if (features::internal::GetOptimizationTargetForCapability(capability)) {
      // try_emplace constructs the loader in place from these arguments.
      result.try_emplace(
          capability, capability, model_provider,
          on_device_component_state_manager, local_state,
          base::BindRepeating(
              &OnDeviceModelServiceController::MaybeUpdateModelAdaptation,
              on_device_model_service_controller, capability));
    }
  }
  return result;
}
} // namespace
// Creates the per-capability adaptation loaders and, when allowed, starts
// observing the on-device model component so the supplementary (text safety
// and language detection) models can be registered.
OnDeviceAssetManager::OnDeviceAssetManager(
    PrefService* local_state,
    base::WeakPtr<OnDeviceModelServiceController> service_controller,
    base::WeakPtr<OnDeviceModelComponentStateManager> component_state_manager,
    raw_ptr<OptimizationGuideModelProvider> model_provider)
    : on_device_model_service_controller_(service_controller),
      on_device_component_state_manager_(component_state_manager),
      model_provider_(model_provider),
      // Built from the constructor parameters directly so this initializer does
      // not depend on member declaration order.
      model_adaptation_loaders_(
          GetRequiredModelAdaptationLoaders(model_provider,
                                            component_state_manager,
                                            local_state,
                                            service_controller)) {
  // Observation requires the safety classifier to be enabled, enterprise
  // policy to allow the local foundational model, and a live component state
  // manager. The checks are evaluated in that order and short-circuit.
  const bool should_observe_component =
      features::ShouldUseTextSafetyClassifierModel() &&
      GetGenAILocalFoundationalModelEnterprisePolicySettings(local_state) ==
          model_execution::prefs::
              GenAILocalFoundationalModelEnterprisePolicySettings::kAllowed &&
      on_device_component_state_manager_;
  if (!should_observe_component) {
    return;
  }
  on_device_component_state_manager_->AddObserver(this);
  // The component can already be set up before this observer is attached
  // (e.g. when a model override is in effect), so register eagerly in that
  // case instead of waiting for StateChanged().
  if (on_device_component_state_manager_->IsInstallerRegistered()) {
    RegisterTextSafetyAndLanguageModels();
  }
}
// Undoes every registration this object performed: first the component state
// observer, then (if they were added) the two optimization-target observers.
OnDeviceAssetManager::~OnDeviceAssetManager() {
  if (on_device_component_state_manager_) {
    on_device_component_state_manager_->RemoveObserver(this);
  }
  // The model provider only knows about us if
  // RegisterTextSafetyAndLanguageModels() did its registration.
  if (!did_register_for_supplementary_on_device_models_) {
    return;
  }
  for (const proto::OptimizationTarget target :
       {proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
        proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION}) {
    model_provider_->RemoveObserverForOptimizationTargetModel(target, this);
  }
}
// Reports whether RegisterTextSafetyAndLanguageModels() has already performed
// its one-time registration for the text safety and language detection models.
bool OnDeviceAssetManager::IsSupplementaryModelRegistered() {
  const bool registered = did_register_for_supplementary_on_device_models_;
  return registered;
}
// Registers `this` with the model provider as an observer of the text safety
// and language detection optimization targets. Registration happens at most
// once. It is skipped entirely when no model provider was supplied, instead of
// dereferencing a null pointer. In that case the flag stays false, so the
// destructor will not try to remove observers that were never added.
void OnDeviceAssetManager::RegisterTextSafetyAndLanguageModels() {
  if (did_register_for_supplementary_on_device_models_ || !model_provider_) {
    return;
  }
  did_register_for_supplementary_on_device_models_ = true;
  model_provider_->AddObserverForOptimizationTargetModel(
      proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
      /*model_metadata=*/std::nullopt, this);
  model_provider_->AddObserverForOptimizationTargetModel(
      proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
      /*model_metadata=*/std::nullopt, this);
}
// Forwards updated supplementary models to the on-device service controller:
// text safety models go to MaybeUpdateSafetyModel() and language detection
// models go to SetLanguageDetectionModel(). Other targets are ignored.
void OnDeviceAssetManager::OnModelUpdated(
    proto::OptimizationTarget optimization_target,
    base::optional_ref<const ModelInfo> model_info) {
  // Every handled target is delivered to the controller, so there is nothing
  // to do once it has been destroyed.
  if (!on_device_model_service_controller_) {
    return;
  }
  if (optimization_target == proto::OPTIMIZATION_TARGET_TEXT_SAFETY) {
    on_device_model_service_controller_->MaybeUpdateSafetyModel(model_info);
  } else if (optimization_target ==
             proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) {
    on_device_model_service_controller_->SetLanguageDetectionModel(model_info);
  }
}
// Component-state callback: any non-null state triggers (idempotent)
// registration for the supplementary models; a null state is ignored.
void OnDeviceAssetManager::StateChanged(
    const OnDeviceModelComponentState* state) {
  if (!state) {
    return;
  }
  RegisterTextSafetyAndLanguageModels();
}
} // namespace optimization_guide

@ -0,0 +1,76 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_
#include <map>
#include <memory>
#include "base/files/file_path.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "components/optimization_guide/core/model_execution/feature_keys.h"
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_target_model_observer.h"
#include "components/optimization_guide/proto/model_execution.pb.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
namespace optimization_guide {
class OnDeviceModelAdaptationLoader;
class OnDeviceModelServiceController;
class OptimizationGuideModelProvider;
// Registers for on-device asset downloads and notifies about updates.
class OnDeviceAssetManager final
: public OptimizationTargetModelObserver,
public OnDeviceModelComponentStateManager::Observer {
public:
OnDeviceAssetManager(
PrefService* local_state,
base::WeakPtr<OnDeviceModelServiceController> service_controller,
base::WeakPtr<OnDeviceModelComponentStateManager> component_state_manager,
raw_ptr<OptimizationGuideModelProvider> model_provider);
~OnDeviceAssetManager() final;
// OptimizationTargetModelObserver:
void OnModelUpdated(proto::OptimizationTarget target,
base::optional_ref<const ModelInfo> model_info) override;
private:
// Registers text safety and language detection models. Does nothing if
// already registered.
void RegisterTextSafetyAndLanguageModels();
// Whether the supplementary on-device models are registered.
bool IsSupplementaryModelRegistered();
// OnDeviceModelComponentStateManager::Observer:
void StateChanged(const OnDeviceModelComponentState* state) override;
// Controller for the on-device service.
base::WeakPtr<OnDeviceModelServiceController>
on_device_model_service_controller_;
base::WeakPtr<OnDeviceModelComponentStateManager>
on_device_component_state_manager_;
// The model provider to observe for updates to auxiliary models.
raw_ptr<OptimizationGuideModelProvider> model_provider_;
// Map from feature to its model adaptation loader. Present only for features
// that require model adaptation.
const std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
model_adaptation_loaders_;
// Whether the user registered for supplementary on-device models.
bool did_register_for_supplementary_on_device_models_ = false;
};
} // namespace optimization_guide
#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_

@ -0,0 +1,225 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
#include <memory>
#include "base/files/file_path.h"
#include "base/functional/callback_helpers.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/test/test.pb.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
#include "components/optimization_guide/core/model_execution/model_execution_prefs.h"
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
#include "components/optimization_guide/core/model_execution/on_device_model_adaptation_loader.h"
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
#include "components/optimization_guide/core/model_execution/test/test_on_device_model_component_state_manager.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/prefs/pref_service.h"
#include "components/sync_preferences/testing_pref_service_syncable.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace optimization_guide {
namespace {
// Service controller stand-in. It records whether a safety model update was
// delivered and exposes the language detection model path stored by the real
// base class.
class FakeServiceController : public OnDeviceModelServiceController {
 public:
  FakeServiceController()
      : OnDeviceModelServiceController(nullptr, nullptr, base::DoNothing()) {}

  // True once MaybeUpdateSafetyModel() has been invoked at least once.
  bool received_safety_info() const { return received_safety_info_; }

  // Test accessor for the base class's language detection model path.
  std::optional<base::FilePath> language_detection_model_path() {
    return OnDeviceModelServiceController::language_detection_model_path();
  }

  // OnDeviceModelServiceController:
  void MaybeUpdateSafetyModel(
      base::optional_ref<const ModelInfo> model_info) override {
    // Only the arrival of an update matters here; its payload is ignored.
    received_safety_info_ = true;
  }

 private:
  // Reference counted (see scoped_refptr in the fixture); not deleted directly.
  ~FakeServiceController() override = default;

  bool received_safety_info_{false};
};
// Model provider stand-in that only remembers which of the two supplementary
// optimization targets an observer was added for.
class FakeModelProvider : public TestOptimizationGuideModelProvider {
 public:
  void AddObserverForOptimizationTargetModel(
      proto::OptimizationTarget optimization_target,
      const std::optional<optimization_guide::proto::Any>& model_metadata,
      OptimizationTargetModelObserver* observer) override {
    if (optimization_target == proto::OPTIMIZATION_TARGET_TEXT_SAFETY) {
      registered_for_text_safety_ = true;
    } else if (optimization_target ==
               proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) {
      registered_for_language_detection_ = true;
    } else {
      // The asset manager should only ever ask for these two targets.
      NOTREACHED();
    }
  }

  // Forgets all previously recorded registrations.
  void Reset() {
    registered_for_language_detection_ = false;
    registered_for_text_safety_ = false;
  }

  // True only when both supplementary targets have been registered.
  bool was_registered() const {
    const bool both_registered =
        registered_for_language_detection_ && registered_for_text_safety_;
    return both_registered;
  }

 private:
  bool registered_for_text_safety_{false};
  bool registered_for_language_detection_{false};
};
// Fixture that wires an OnDeviceAssetManager to fakes: a ref-counted fake
// service controller, a test component state manager backed by `local_state_`,
// and a fake model provider that records observer registrations.
class OnDeviceAssetManagerTest : public testing::Test {
 public:
  OnDeviceAssetManagerTest() {
    scoped_feature_list_.InitWithFeatures({features::kTextSafetyClassifier},
                                          {});
    model_execution::prefs::RegisterLocalStatePrefs(local_state_.registry());
    // Mark the device as high performance class in local state.
    local_state_.SetInteger(
        model_execution::prefs::localstate::kOnDevicePerformanceClass,
        base::to_underlying(OnDeviceModelPerformanceClass::kHigh));
    service_controller_ = base::MakeRefCounted<FakeServiceController>();
  }

  // Starts up the component state manager, then advances mock time so that
  // any delayed startup work can run.
  void CreateComponentManager() {
    component_manager_.get()->OnStartup();
    task_environment_.FastForwardBy(base::Seconds(1));
  }

  // Marks the on-device model component as ready.
  void SetModelComponentReady() {
    component_manager_.SetReady(base::FilePath());
  }

  // (Re)creates the asset manager under test. Any previous instance is
  // destroyed after the new one is constructed.
  void CreateAssetManager() {
    // Previously this call had a trailing comma after the last argument,
    // which is not valid C++ and failed to compile.
    asset_manager_ = std::make_unique<OnDeviceAssetManager>(
        &local_state_, service_controller_->GetWeakPtr(),
        component_manager_.get()->GetWeakPtr(), &model_provider_);
  }

  OnDeviceAssetManager* asset_manager() { return asset_manager_.get(); }
  PrefService* local_state() { return &local_state_; }
  FakeModelProvider* model_provider() { return &model_provider_; }
  FakeServiceController* service_controller() {
    return service_controller_.get();
  }

  // Destroys the asset manager under test.
  void Reset() { asset_manager_ = nullptr; }

 private:
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  base::test::ScopedFeatureList scoped_feature_list_;
  TestingPrefServiceSimple local_state_;
  scoped_refptr<FakeServiceController> service_controller_;
  // Declared after `local_state_`, which it is constructed from.
  TestOnDeviceModelComponentStateManager component_manager_{&local_state_};
  FakeModelProvider model_provider_;
  // Declared last so it is destroyed before the collaborators it references.
  std::unique_ptr<OnDeviceAssetManager> asset_manager_;
};
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC) || BUILDFLAG(IS_LINUX)
// Verifies that the supplementary models get registered when the model
// component is already ready at the time the asset manager is created.
// NOTE(review): this relies on the asset manager checking the current
// component state when it is constructed -- the constructor is not visible
// here; confirm.
TEST_F(OnDeviceAssetManagerTest, RegistersTextSafetyModelWithOverrideModel) {
  // Effectively, when an override is set, the model component will be ready
  // before OnDeviceAssetManager can be added as an observer.
  CreateComponentManager();
  SetModelComponentReady();
  CreateAssetManager();
  EXPECT_TRUE(model_provider()->was_registered());
}
// Verifies that registration is deferred until the model component becomes
// ready, and that it happens once it does (the fixture enables
// kTextSafetyClassifier).
TEST_F(OnDeviceAssetManagerTest, RegistersTextSafetyModelIfEnabled) {
  CreateAssetManager();
  // Text safety model should not be registered until the base model is ready.
  EXPECT_FALSE(model_provider()->was_registered());
  CreateComponentManager();
  // Becoming ready notifies the asset manager, which then registers.
  SetModelComponentReady();
  EXPECT_TRUE(model_provider()->was_registered());
}
// Verifies that nothing is registered when kTextSafetyClassifier is disabled,
// even after the model component becomes ready.
TEST_F(OnDeviceAssetManagerTest, DoesNotRegisterTextSafetyIfNotEnabled) {
  // Overrides the fixture's feature list, which enables the classifier.
  base::test::ScopedFeatureList scoped_feature_list;
  scoped_feature_list.InitWithFeatures({}, {features::kTextSafetyClassifier});
  CreateAssetManager();
  CreateComponentManager();
  SetModelComponentReady();
  EXPECT_FALSE(model_provider()->was_registered());
}
#endif
// An update for an unrelated optimization target must not be forwarded to the
// service controller as a safety model update.
TEST_F(OnDeviceAssetManagerTest, DoesNotNotifyServiceControllerWrongTarget) {
  CreateAssetManager();
  const auto unrelated_model = TestModelInfoBuilder().SetVersion(123).Build();
  asset_manager()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES,
                                  *unrelated_model);
  EXPECT_FALSE(service_controller()->received_safety_info());
}
// A text safety model update is forwarded to the service controller.
TEST_F(OnDeviceAssetManagerTest, NotifiesServiceController) {
  CreateAssetManager();
  const auto safety_model = TestModelInfoBuilder().SetVersion(123).Build();
  asset_manager()->OnModelUpdated(proto::OPTIMIZATION_TARGET_TEXT_SAFETY,
                                  *safety_model);
  EXPECT_TRUE(service_controller()->received_safety_info());
}
// A language detection model update reaches the service controller, which
// then reports that model's file path.
TEST_F(OnDeviceAssetManagerTest, UpdateLanguageDetection) {
  CreateAssetManager();
  const base::FilePath kTestPath{FILE_PATH_LITERAL("foo")};
  const auto language_model = TestModelInfoBuilder()
                                  .SetVersion(123)
                                  .SetModelFilePath(kTestPath)
                                  .Build();
  asset_manager()->OnModelUpdated(
      proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, *language_model);
  const std::optional<base::FilePath> stored_path =
      service_controller()->language_detection_model_path();
  EXPECT_EQ(kTestPath, stored_path);
}
// Verifies that the supplementary models are not registered when the
// enterprise policy disallows the local foundational model, and that
// destroying the asset manager afterwards does not crash.
// NOTE(review): CreateComponentManager()/SetModelComponentReady() are never
// called here. The component therefore presumably never reports a non-null
// state, and registration would not be expected even without the policy, so
// this test may not exercise the policy path at all -- confirm. It also sits
// outside the platform #if that guards the positive registration tests.
TEST_F(OnDeviceAssetManagerTest, NotRegisteredWhenDisabledByEnterprisePolicy) {
  CreateAssetManager();
  // Discard anything recorded by the first asset manager.
  model_provider()->Reset();
  local_state()->SetInteger(
      model_execution::prefs::localstate::
          kGenAILocalFoundationalModelEnterprisePolicySettings,
      static_cast<int>(model_execution::prefs::
                           GenAILocalFoundationalModelEnterprisePolicySettings::
                               kDisallowed));
  // Recreate the asset manager now that the policy pref is set.
  CreateAssetManager();
  EXPECT_FALSE(model_provider()->was_registered());
  // Reset manager to make sure removing observer doesn't crash.
  Reset();
}
} // namespace
} // namespace optimization_guide

@ -238,6 +238,10 @@ class OptimizationGuideService
scoped_refptr<optimization_guide::OnDeviceModelComponentStateManager>
on_device_model_state_manager_;
// Downloads other model assets for on-device execution.
std::unique_ptr<optimization_guide::OnDeviceAssetManager>
on_device_asset_manager_;
#endif
// Manages the model execution. Not created for off the record profiles.

@ -40,6 +40,7 @@
#import "services/network/public/cpp/shared_url_loader_factory.h"
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
#import "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
#import "components/optimization_guide/core/model_execution/on_device_model_component.h"
#import "ios/chrome/browser/optimization_guide/model/on_device_model_service_controller_ios.h"
#endif // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
@ -173,9 +174,9 @@ OptimizationGuideService::OptimizationGuideService(
}
if (!off_the_record_) {
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
PrefService* local_state = GetApplicationContext()->GetLocalState();
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
// Create and startup the on-device model's state manager.
on_device_model_state_manager_ =
optimization_guide::OnDeviceModelComponentStateManager::CreateOrGet(
@ -193,17 +194,20 @@ OptimizationGuideService::OptimizationGuideService(
on_device_model_service_controller =
GetApplicationContext()->GetOnDeviceModelServiceController(
on_device_model_state_manager_->GetWeakPtr());
on_device_asset_manager_ =
std::make_unique<optimization_guide::OnDeviceAssetManager>(
local_state, on_device_model_service_controller->GetWeakPtr(),
on_device_model_state_manager_->GetWeakPtr(), this);
model_execution_manager_ =
std::make_unique<optimization_guide::ModelExecutionManager>(
url_loader_factory, local_state, identity_manager,
std::move(on_device_model_service_controller), this,
on_device_model_state_manager_->GetWeakPtr(),
url_loader_factory, identity_manager,
std::move(on_device_model_service_controller),
optimization_guide_logger_.get(), nullptr);
#else // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
model_execution_manager_ =
std::make_unique<optimization_guide::ModelExecutionManager>(
url_loader_factory, local_state, identity_manager, nullptr, this,
nullptr, optimization_guide_logger_.get(), nullptr);
url_loader_factory, identity_manager, nullptr,
optimization_guide_logger_.get(), nullptr);
#endif // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
}