Reland: Split out OnDeviceModelAssetManager.
Move all of the on-device asset observation logic out of ModelExecutionManager. No functional changes. Reland of crrev.com/c/6115297 + crrev.com/c/6159444 + an iOS fix, since this was not the culprit for crbug.com/388547765. Bug: 388538282, 373725257 Change-Id: I4db8241d37ef45be8a6b235479cb394d7a0b9a66 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6289432 Auto-Submit: Steven Holte <holte@chromium.org> Commit-Queue: Steven Holte <holte@chromium.org> Reviewed-by: Zekun Jiang <zekunjiang@google.com> Cr-Commit-Position: refs/heads/main@{#1424147}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
be443cc694
commit
643c0e2fc4
chrome/browser/optimization_guide
components/optimization_guide/core
BUILD.gn
model_execution
ios/chrome/browser/optimization_guide/model
@ -4,6 +4,7 @@
|
||||
|
||||
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
#include "base/files/file_path.h"
|
||||
@ -48,6 +49,7 @@
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_features_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_manager.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/performance_class.h"
|
||||
@ -94,7 +96,6 @@ using ::optimization_guide::OnDeviceModelComponentStateManager;
|
||||
using ::optimization_guide::OnDeviceModelPerformanceClass;
|
||||
using ::optimization_guide::OnDeviceModelServiceController;
|
||||
|
||||
|
||||
// Returns the profile to use for when setting up the keyed service when the
|
||||
// profile is Off-The-Record. For guest profiles, returns a loaded profile if
|
||||
// one exists, otherwise just the original profile of the OTR profile. Note:
|
||||
@ -434,15 +435,16 @@ void OptimizationGuideKeyedService::InitializeModelExecution(Profile* profile) {
|
||||
optimization_guide::features::kOptimizationGuideOnDeviceModel)) {
|
||||
service_controller = GetOnDeviceModelServiceController(
|
||||
on_device_component_manager_->GetWeakPtr());
|
||||
on_device_asset_manager_ =
|
||||
std::make_unique<optimization_guide::OnDeviceAssetManager>(
|
||||
g_browser_process->local_state(), service_controller->GetWeakPtr(),
|
||||
on_device_component_manager_->GetWeakPtr(), this);
|
||||
}
|
||||
|
||||
model_execution_manager_ =
|
||||
std::make_unique<optimization_guide::ModelExecutionManager>(
|
||||
url_loader_factory, g_browser_process->local_state(),
|
||||
IdentityManagerFactory::GetForProfile(profile),
|
||||
std::move(service_controller), this,
|
||||
on_device_component_manager_->GetWeakPtr(),
|
||||
optimization_guide_logger_.get(),
|
||||
url_loader_factory, IdentityManagerFactory::GetForProfile(profile),
|
||||
std::move(service_controller), optimization_guide_logger_.get(),
|
||||
model_quality_logs_uploader_service_
|
||||
? model_quality_logs_uploader_service_->GetWeakPtr()
|
||||
: nullptr);
|
||||
|
@ -48,6 +48,7 @@ class ModelExecutionManager;
|
||||
class ModelInfo;
|
||||
class ModelQualityLogsUploaderService;
|
||||
class ModelValidatorKeyedService;
|
||||
class OnDeviceAssetManager;
|
||||
class OnDeviceModelAvailabilityObserver;
|
||||
class OnDeviceModelComponentStateManager;
|
||||
class OptimizationGuideStore;
|
||||
@ -359,6 +360,9 @@ class OptimizationGuideKeyedService
|
||||
// prediction models.
|
||||
std::unique_ptr<optimization_guide::PredictionManager> prediction_manager_;
|
||||
|
||||
std::unique_ptr<optimization_guide::OnDeviceAssetManager>
|
||||
on_device_asset_manager_;
|
||||
|
||||
// Manages the model execution. Not created for off the record profiles.
|
||||
std::unique_ptr<optimization_guide::ModelExecutionManager>
|
||||
model_execution_manager_;
|
||||
|
@ -250,6 +250,8 @@ static_library("core") {
|
||||
"model_execution/model_execution_util.h",
|
||||
"model_execution/multimodal_message.cc",
|
||||
"model_execution/multimodal_message.h",
|
||||
"model_execution/on_device_asset_manager.cc",
|
||||
"model_execution/on_device_asset_manager.h",
|
||||
"model_execution/on_device_context.cc",
|
||||
"model_execution/on_device_context.h",
|
||||
"model_execution/on_device_execution.cc",
|
||||
@ -532,6 +534,7 @@ source_set("unit_tests") {
|
||||
"model_execution/model_execution_fetcher_unittest.cc",
|
||||
"model_execution/model_execution_manager_unittest.cc",
|
||||
"model_execution/multimodal_message_unittest.cc",
|
||||
"model_execution/on_device_asset_manager_unittest.cc",
|
||||
"model_execution/on_device_model_adaptation_loader_unittest.cc",
|
||||
"model_execution/on_device_model_component_unittest.cc",
|
||||
"model_execution/on_device_model_execution_proto_descriptors_unittest.cc",
|
||||
|
@ -122,31 +122,6 @@ void NoOpExecuteRemoteFn(
|
||||
nullptr);
|
||||
}
|
||||
|
||||
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
|
||||
GetRequiredModelAdaptationLoaders(
|
||||
OptimizationGuideModelProvider* model_provider,
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager,
|
||||
PrefService* local_state,
|
||||
base::WeakPtr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller) {
|
||||
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader> loaders;
|
||||
for (const auto feature : kAllModelBasedCapabilityKeys) {
|
||||
if (!features::internal::GetOptimizationTargetForCapability(feature)) {
|
||||
continue;
|
||||
}
|
||||
loaders.emplace(
|
||||
std::piecewise_construct, std::forward_as_tuple(feature),
|
||||
std::forward_as_tuple(
|
||||
feature, model_provider, on_device_component_state_manager,
|
||||
local_state,
|
||||
base::BindRepeating(
|
||||
&OnDeviceModelServiceController::MaybeUpdateModelAdaptation,
|
||||
on_device_model_service_controller, feature)));
|
||||
}
|
||||
return loaders;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
using ModelExecutionError =
|
||||
@ -154,18 +129,13 @@ using ModelExecutionError =
|
||||
|
||||
ModelExecutionManager::ModelExecutionManager(
|
||||
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
|
||||
PrefService* local_state,
|
||||
signin::IdentityManager* identity_manager,
|
||||
scoped_refptr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller,
|
||||
OptimizationGuideModelProvider* model_provider,
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager,
|
||||
OptimizationGuideLogger* optimization_guide_logger,
|
||||
base::WeakPtr<ModelQualityLogsUploaderService>
|
||||
model_quality_uploader_service)
|
||||
: model_quality_uploader_service_(model_quality_uploader_service),
|
||||
on_device_component_state_manager_(on_device_component_state_manager),
|
||||
optimization_guide_logger_(optimization_guide_logger),
|
||||
model_execution_service_url_(net::AppendOrReplaceQueryParameter(
|
||||
GetModelExecutionServiceURL(),
|
||||
@ -173,47 +143,11 @@ ModelExecutionManager::ModelExecutionManager(
|
||||
features::GetOptimizationGuideServiceAPIKey())),
|
||||
url_loader_factory_(url_loader_factory),
|
||||
identity_manager_(identity_manager),
|
||||
model_adaptation_loaders_(GetRequiredModelAdaptationLoaders(
|
||||
model_provider,
|
||||
on_device_component_state_manager_,
|
||||
local_state,
|
||||
on_device_model_service_controller
|
||||
? on_device_model_service_controller->GetWeakPtr()
|
||||
: nullptr)),
|
||||
model_provider_(model_provider),
|
||||
on_device_model_service_controller_(
|
||||
std::move(on_device_model_service_controller)) {
|
||||
if (!model_provider_ && !on_device_model_service_controller_) {
|
||||
return;
|
||||
}
|
||||
if (!features::ShouldUseTextSafetyClassifierModel()) {
|
||||
return;
|
||||
}
|
||||
if (GetGenAILocalFoundationalModelEnterprisePolicySettings(local_state) !=
|
||||
model_execution::prefs::
|
||||
GenAILocalFoundationalModelEnterprisePolicySettings::kAllowed) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (on_device_component_state_manager_) {
|
||||
on_device_component_state_manager_->AddObserver(this);
|
||||
if (on_device_component_state_manager_->IsInstallerRegistered()) {
|
||||
RegisterTextSafetyAndLanguageModels();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ModelExecutionManager::~ModelExecutionManager() {
|
||||
if (on_device_component_state_manager_) {
|
||||
on_device_component_state_manager_->RemoveObserver(this);
|
||||
}
|
||||
if (did_register_for_supplementary_on_device_models_) {
|
||||
model_provider_->RemoveObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY, this);
|
||||
model_provider_->RemoveObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
|
||||
this);
|
||||
}
|
||||
}
|
||||
|
||||
void ModelExecutionManager::Shutdown() {
|
||||
@ -323,11 +257,6 @@ ModelExecutionManager::StartSession(
|
||||
std::move(execute_fn), config_params);
|
||||
}
|
||||
|
||||
// Whether the supplementary on-device models are registered.
|
||||
bool ModelExecutionManager::IsSupplementaryModelRegistered() {
|
||||
return did_register_for_supplementary_on_device_models_;
|
||||
}
|
||||
|
||||
void ModelExecutionManager::OnModelExecuteResponse(
|
||||
ModelBasedCapabilityKey feature,
|
||||
std::unique_ptr<proto::LogAiDataRequest> log_ai_data_request,
|
||||
@ -459,47 +388,6 @@ void ModelExecutionManager::OnModelExecuteResponse(
|
||||
std::move(log_entry));
|
||||
}
|
||||
|
||||
void ModelExecutionManager::RegisterTextSafetyAndLanguageModels() {
|
||||
if (!did_register_for_supplementary_on_device_models_) {
|
||||
did_register_for_supplementary_on_device_models_ = true;
|
||||
model_provider_->AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
|
||||
/*model_metadata=*/std::nullopt, this);
|
||||
model_provider_->AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
|
||||
/*model_metadata=*/std::nullopt, this);
|
||||
}
|
||||
}
|
||||
|
||||
void ModelExecutionManager::OnModelUpdated(
|
||||
proto::OptimizationTarget optimization_target,
|
||||
base::optional_ref<const ModelInfo> model_info) {
|
||||
switch (optimization_target) {
|
||||
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
|
||||
if (on_device_model_service_controller_) {
|
||||
on_device_model_service_controller_->MaybeUpdateSafetyModel(model_info);
|
||||
}
|
||||
break;
|
||||
|
||||
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
|
||||
if (on_device_model_service_controller_) {
|
||||
on_device_model_service_controller_->SetLanguageDetectionModel(
|
||||
model_info);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ModelExecutionManager::StateChanged(
|
||||
const OnDeviceModelComponentState* state) {
|
||||
if (state) {
|
||||
RegisterTextSafetyAndLanguageModels();
|
||||
}
|
||||
}
|
||||
|
||||
optimization_guide::OnDeviceModelEligibilityReason
|
||||
ModelExecutionManager::GetOnDeviceModelEligibility(
|
||||
optimization_guide::ModelBasedCapabilityKey feature) {
|
||||
|
@ -20,8 +20,6 @@
|
||||
#include "components/optimization_guide/core/optimization_target_model_observer.h"
|
||||
#include "components/optimization_guide/proto/model_execution.pb.h"
|
||||
#include "components/optimization_guide/proto/model_quality_service.pb.h"
|
||||
#include "mojo/public/cpp/bindings/pending_remote.h"
|
||||
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
|
||||
#include "url/gurl.h"
|
||||
|
||||
class OptimizationGuideLogger;
|
||||
@ -37,28 +35,19 @@ class IdentityManager;
|
||||
namespace optimization_guide {
|
||||
|
||||
class ModelExecutionFetcher;
|
||||
class OnDeviceModelAdaptationLoader;
|
||||
class OnDeviceModelServiceController;
|
||||
class OptimizationGuideModelProvider;
|
||||
|
||||
class ModelExecutionManager
|
||||
: public OptimizationTargetModelObserver,
|
||||
public OnDeviceModelComponentStateManager::Observer {
|
||||
class ModelExecutionManager final {
|
||||
public:
|
||||
ModelExecutionManager(
|
||||
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
|
||||
PrefService* local_state,
|
||||
signin::IdentityManager* identity_manager,
|
||||
scoped_refptr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller,
|
||||
OptimizationGuideModelProvider* model_provider,
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager,
|
||||
OptimizationGuideLogger* optimization_guide_logger,
|
||||
base::WeakPtr<ModelQualityLogsUploaderService>
|
||||
model_quality_uploader_service);
|
||||
|
||||
~ModelExecutionManager() override;
|
||||
~ModelExecutionManager();
|
||||
|
||||
ModelExecutionManager(const ModelExecutionManager&) = delete;
|
||||
ModelExecutionManager& operator=(const ModelExecutionManager&) = delete;
|
||||
@ -93,16 +82,6 @@ class ModelExecutionManager
|
||||
ModelBasedCapabilityKey feature,
|
||||
const std::optional<SessionConfigParams>& config_params);
|
||||
|
||||
// Whether the supplementary on-device models are registered.
|
||||
bool IsSupplementaryModelRegistered();
|
||||
|
||||
// OptimizationTargetModelObserver:
|
||||
void OnModelUpdated(proto::OptimizationTarget target,
|
||||
base::optional_ref<const ModelInfo> model_info) override;
|
||||
|
||||
// OnDeviceModelComponentStateManager::Observer:
|
||||
void StateChanged(const OnDeviceModelComponentState* state) override;
|
||||
|
||||
void Shutdown();
|
||||
|
||||
private:
|
||||
@ -114,10 +93,6 @@ class ModelExecutionManager
|
||||
base::expected<const proto::ExecuteResponse,
|
||||
OptimizationGuideModelExecutionError> execute_response);
|
||||
|
||||
// Registers text safety and language detection models. Does nothing if
|
||||
// already registered.
|
||||
void RegisterTextSafetyAndLanguageModels();
|
||||
|
||||
// Returns the `OnDeviceModelAdaptationMetadata` for `feature`.
|
||||
std::optional<optimization_guide::OnDeviceModelAdaptationMetadata>
|
||||
GetOnDeviceModelAdaptationMetadata(
|
||||
@ -129,9 +104,6 @@ class ModelExecutionManager
|
||||
base::WeakPtr<ModelQualityLogsUploaderService>
|
||||
model_quality_uploader_service_;
|
||||
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager_;
|
||||
|
||||
// Owned by OptimizationGuideKeyedService and outlives `this`.
|
||||
raw_ptr<OptimizationGuideLogger> optimization_guide_logger_;
|
||||
|
||||
@ -149,21 +121,10 @@ class ModelExecutionManager
|
||||
// incognito profiles.
|
||||
const raw_ptr<signin::IdentityManager> identity_manager_;
|
||||
|
||||
// Map from feature to its model adaptation loader. Present only for features
|
||||
// that require model adaptation.
|
||||
const std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
|
||||
model_adaptation_loaders_;
|
||||
|
||||
// The model provider to observe for updates to auxiliary models.
|
||||
raw_ptr<OptimizationGuideModelProvider> model_provider_;
|
||||
|
||||
// Controller for the on-device service.
|
||||
scoped_refptr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller_;
|
||||
|
||||
// Whether the user registered for supplementary on-device models.
|
||||
bool did_register_for_supplementary_on_device_models_ = false;
|
||||
|
||||
SEQUENCE_CHECKER(sequence_checker_);
|
||||
|
||||
// Used to get `weak_ptr_` to self.
|
||||
|
@ -7,18 +7,16 @@
|
||||
#include <memory>
|
||||
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/functional/callback_helpers.h"
|
||||
#include "base/test/metrics/histogram_tester.h"
|
||||
#include "base/test/scoped_feature_list.h"
|
||||
#include "base/test/task_environment.h"
|
||||
#include "base/test/test.pb.h"
|
||||
#include "base/test/test_future.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_prefs.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/test/request_builder.h"
|
||||
#include "components/optimization_guide/core/model_execution/test/response_holder.h"
|
||||
#include "components/optimization_guide/core/model_execution/test/test_on_device_model_component_state_manager.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_constants.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_logger.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
|
||||
@ -55,68 +53,9 @@ proto::ExecuteResponse BuildComposeResponse(const std::string& output) {
|
||||
return execute_response;
|
||||
}
|
||||
|
||||
class FakeServiceController : public OnDeviceModelServiceController {
|
||||
public:
|
||||
FakeServiceController()
|
||||
: OnDeviceModelServiceController(nullptr, nullptr, base::DoNothing()) {}
|
||||
|
||||
void MaybeUpdateSafetyModel(
|
||||
base::optional_ref<const ModelInfo> model_info) override {
|
||||
received_safety_info_ = true;
|
||||
}
|
||||
|
||||
bool received_safety_info() const { return received_safety_info_; }
|
||||
|
||||
std::optional<base::FilePath> language_detection_model_path() {
|
||||
return OnDeviceModelServiceController::language_detection_model_path();
|
||||
}
|
||||
|
||||
private:
|
||||
~FakeServiceController() override = default;
|
||||
|
||||
bool received_safety_info_ = false;
|
||||
};
|
||||
|
||||
class FakeModelProvider : public TestOptimizationGuideModelProvider {
|
||||
public:
|
||||
void AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget optimization_target,
|
||||
const std::optional<optimization_guide::proto::Any>& model_metadata,
|
||||
OptimizationTargetModelObserver* observer) override {
|
||||
switch (optimization_target) {
|
||||
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
|
||||
registered_for_text_safety_ = true;
|
||||
break;
|
||||
|
||||
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
|
||||
registered_for_language_detection_ = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
NOTREACHED();
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
registered_for_text_safety_ = false;
|
||||
registered_for_language_detection_ = false;
|
||||
}
|
||||
|
||||
bool was_registered() const {
|
||||
return registered_for_text_safety_ && registered_for_language_detection_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool registered_for_text_safety_ = false;
|
||||
bool registered_for_language_detection_ = false;
|
||||
};
|
||||
|
||||
class ModelExecutionManagerTest : public testing::Test {
|
||||
public:
|
||||
ModelExecutionManagerTest() {
|
||||
scoped_feature_list_.InitWithFeatures({},
|
||||
{features::kTextSafetyClassifier});
|
||||
}
|
||||
ModelExecutionManagerTest() = default;
|
||||
~ModelExecutionManagerTest() override = default;
|
||||
|
||||
// Sets up most of the fields except `model_execution_manager_` and
|
||||
@ -125,29 +64,11 @@ class ModelExecutionManagerTest : public testing::Test {
|
||||
url_loader_factory_ =
|
||||
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
|
||||
&test_url_loader_factory_);
|
||||
local_state_ = std::make_unique<TestingPrefServiceSimple>();
|
||||
model_execution::prefs::RegisterLocalStatePrefs(local_state_->registry());
|
||||
service_controller_ = base::MakeRefCounted<FakeServiceController>();
|
||||
}
|
||||
|
||||
void CreateModelExecutionManager() {
|
||||
service_controller_ = base::MakeRefCounted<OnDeviceModelServiceController>(
|
||||
nullptr, nullptr, base::DoNothing());
|
||||
model_execution_manager_ = std::make_unique<ModelExecutionManager>(
|
||||
url_loader_factory_, local_state_.get(),
|
||||
identity_test_env_.identity_manager(), service_controller_,
|
||||
&model_provider_,
|
||||
component_manager_ ? component_manager_->get()->GetWeakPtr() : nullptr,
|
||||
&optimization_guide_logger_, nullptr);
|
||||
}
|
||||
|
||||
void CreateComponentManager(bool should_observe) {
|
||||
component_manager_ =
|
||||
std::make_unique<TestOnDeviceModelComponentStateManager>(
|
||||
local_state_.get());
|
||||
component_manager_->get()->OnStartup();
|
||||
task_environment_.FastForwardBy(base::Seconds(1));
|
||||
if (should_observe) {
|
||||
component_manager_->get()->AddObserver(model_execution_manager_.get());
|
||||
}
|
||||
url_loader_factory_, identity_test_env_.identity_manager(),
|
||||
service_controller_, &optimization_guide_logger_, nullptr);
|
||||
}
|
||||
|
||||
bool SimulateResponse(const std::string& content,
|
||||
@ -179,9 +100,7 @@ class ModelExecutionManagerTest : public testing::Test {
|
||||
return model_execution_manager_.get();
|
||||
}
|
||||
|
||||
FakeModelProvider* model_provider() { return &model_provider_; }
|
||||
|
||||
FakeServiceController* service_controller() {
|
||||
OnDeviceModelServiceController* service_controller() {
|
||||
return service_controller_.get();
|
||||
}
|
||||
|
||||
@ -195,37 +114,25 @@ class ModelExecutionManagerTest : public testing::Test {
|
||||
EXPECT_THAT(body_bytes, HasSubstr(message));
|
||||
}
|
||||
|
||||
void SetModelComponentReady() {
|
||||
component_manager_->SetReady(base::FilePath());
|
||||
}
|
||||
|
||||
network::TestURLLoaderFactory* test_url_loader_factory() {
|
||||
return &test_url_loader_factory_;
|
||||
}
|
||||
|
||||
PrefService* local_state() { return local_state_.get(); }
|
||||
|
||||
void Reset() { model_execution_manager_ = nullptr; }
|
||||
|
||||
private:
|
||||
base::test::TaskEnvironment task_environment_{
|
||||
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
|
||||
base::test::ScopedFeatureList scoped_feature_list_;
|
||||
std::unique_ptr<TestingPrefServiceSimple> local_state_;
|
||||
signin::IdentityTestEnvironment identity_test_env_;
|
||||
variations::ScopedVariationsIdsProvider scoped_variations_ids_provider_{
|
||||
variations::VariationsIdsProvider::Mode::kUseSignedInState};
|
||||
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
|
||||
network::TestURLLoaderFactory test_url_loader_factory_;
|
||||
scoped_refptr<FakeServiceController> service_controller_;
|
||||
std::unique_ptr<TestOnDeviceModelComponentStateManager> component_manager_;
|
||||
FakeModelProvider model_provider_;
|
||||
scoped_refptr<OnDeviceModelServiceController> service_controller_;
|
||||
OptimizationGuideLogger optimization_guide_logger_;
|
||||
std::unique_ptr<ModelExecutionManager> model_execution_manager_;
|
||||
};
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelEmptyAccessToken) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
ResponseHolder response_holder;
|
||||
model_execution_manager()->ExecuteModel(
|
||||
@ -244,7 +151,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelEmptyAccessToken) {
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelWithUserSignIn) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
ResponseHolder response_holder;
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
@ -266,7 +172,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithUserSignIn) {
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelWithServerError) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
ResponseHolder response_holder;
|
||||
@ -298,7 +203,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithServerError) {
|
||||
|
||||
TEST_F(ModelExecutionManagerTest,
|
||||
ExecuteModelWithServerErrorAllowedForLogging) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
ResponseHolder response_holder;
|
||||
@ -340,7 +244,6 @@ TEST_F(ModelExecutionManagerTest,
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetOnDeviceOnly) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
@ -359,7 +262,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetOnDeviceOnly) {
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetToServerOnly) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
ResponseHolder response_holder;
|
||||
@ -392,7 +294,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelExecutionModeSetToServerOnly) {
|
||||
|
||||
TEST_F(ModelExecutionManagerTest,
|
||||
ExecuteModelExecutionModeExplicitlySetToDefault) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
ResponseHolder response_holder;
|
||||
@ -424,7 +325,6 @@ TEST_F(ModelExecutionManagerTest,
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, ExecuteModelWithPassthroughSession) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
|
||||
ResponseHolder response_holder;
|
||||
@ -450,7 +350,6 @@ TEST_F(ModelExecutionManagerTest, ExecuteModelWithPassthroughSession) {
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, LogsContextToExecutionTimeHistogram) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
auto session = model_execution_manager()->StartSession(
|
||||
@ -491,7 +390,6 @@ TEST_F(ModelExecutionManagerTest, LogsContextToExecutionTimeHistogram) {
|
||||
|
||||
TEST_F(ModelExecutionManagerTest,
|
||||
ExecuteModelWithPassthroughSessionAddContext) {
|
||||
CreateModelExecutionManager();
|
||||
ResponseHolder response_holder;
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
auto session = model_execution_manager()->StartSession(
|
||||
@ -508,7 +406,6 @@ TEST_F(ModelExecutionManagerTest,
|
||||
|
||||
TEST_F(ModelExecutionManagerTest,
|
||||
ExecuteModelWithPassthroughSessionMultipleAddContext) {
|
||||
CreateModelExecutionManager();
|
||||
ResponseHolder response_holder;
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
auto session = model_execution_manager()->StartSession(
|
||||
@ -525,7 +422,6 @@ TEST_F(ModelExecutionManagerTest,
|
||||
|
||||
TEST_F(ModelExecutionManagerTest,
|
||||
ExecuteModelWithPassthroughSessionExecuteOverridesAddContext) {
|
||||
CreateModelExecutionManager();
|
||||
ResponseHolder response_holder;
|
||||
SetAutomaticIssueOfAccessTokens();
|
||||
auto session = model_execution_manager()->StartSession(
|
||||
@ -541,7 +437,6 @@ TEST_F(ModelExecutionManagerTest,
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, TestMultipleParallelRequests) {
|
||||
CreateModelExecutionManager();
|
||||
base::HistogramTester histogram_tester;
|
||||
ResponseHolder response_holder1, response_holder2;
|
||||
|
||||
@ -582,103 +477,6 @@ TEST_F(ModelExecutionManagerTest, TestMultipleParallelRequests) {
|
||||
"OptimizationGuide.ModelExecution.Result.Compose", false, 1);
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerTest, DoesNotRegisterTextSafetyIfNotEnabled) {
|
||||
CreateModelExecutionManager();
|
||||
EXPECT_FALSE(model_provider()->was_registered());
|
||||
}
|
||||
|
||||
class ModelExecutionManagerSafetyEnabledTest
|
||||
: public ModelExecutionManagerTest {
|
||||
public:
|
||||
ModelExecutionManagerSafetyEnabledTest() {
|
||||
scoped_feature_list_.InitWithFeatures({features::kTextSafetyClassifier},
|
||||
{});
|
||||
}
|
||||
|
||||
private:
|
||||
base::test::ScopedFeatureList scoped_feature_list_;
|
||||
};
|
||||
|
||||
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC) || BUILDFLAG(IS_LINUX)
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest,
|
||||
RegistersTextSafetyModelWithOverrideModel) {
|
||||
// Effectively, when an override is set, the model component will be ready
|
||||
// before ModelExecutionManager can be added as an observer. Here we simulate
|
||||
// that by simply setting up the component without adding
|
||||
// ModelExecutionManager as an observer.
|
||||
CreateComponentManager(/*should_observe=*/false);
|
||||
SetModelComponentReady();
|
||||
CreateModelExecutionManager();
|
||||
|
||||
EXPECT_TRUE(model_provider()->was_registered());
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest,
|
||||
RegistersTextSafetyModelIfEnabled) {
|
||||
CreateModelExecutionManager();
|
||||
EXPECT_FALSE(model_provider()->was_registered());
|
||||
|
||||
// Text safety model should only be registered after the base model is ready.
|
||||
local_state()->SetInteger(
|
||||
model_execution::prefs::localstate::kOnDevicePerformanceClass,
|
||||
base::to_underlying(OnDeviceModelPerformanceClass::kHigh));
|
||||
CreateComponentManager(/*should_observe=*/true);
|
||||
SetModelComponentReady();
|
||||
|
||||
EXPECT_TRUE(model_provider()->was_registered());
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest,
|
||||
DoesNotNotifyServiceControllerWrongTarget) {
|
||||
CreateModelExecutionManager();
|
||||
std::unique_ptr<ModelInfo> model_info =
|
||||
TestModelInfoBuilder().SetVersion(123).Build();
|
||||
model_execution_manager()->OnModelUpdated(
|
||||
proto::OPTIMIZATION_TARGET_PAGE_ENTITIES, *model_info);
|
||||
|
||||
EXPECT_FALSE(service_controller()->received_safety_info());
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest, NotifiesServiceController) {
|
||||
CreateModelExecutionManager();
|
||||
std::unique_ptr<ModelInfo> model_info =
|
||||
TestModelInfoBuilder().SetVersion(123).Build();
|
||||
model_execution_manager()->OnModelUpdated(
|
||||
proto::OPTIMIZATION_TARGET_TEXT_SAFETY, *model_info);
|
||||
|
||||
EXPECT_TRUE(service_controller()->received_safety_info());
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest, UpdateLanguageDetection) {
|
||||
CreateModelExecutionManager();
|
||||
const base::FilePath kTestPath{FILE_PATH_LITERAL("foo")};
|
||||
std::unique_ptr<ModelInfo> model_info = TestModelInfoBuilder()
|
||||
.SetVersion(123)
|
||||
.SetModelFilePath(kTestPath)
|
||||
.Build();
|
||||
model_execution_manager()->OnModelUpdated(
|
||||
proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, *model_info);
|
||||
EXPECT_EQ(kTestPath, service_controller()->language_detection_model_path());
|
||||
}
|
||||
|
||||
TEST_F(ModelExecutionManagerSafetyEnabledTest,
|
||||
NotRegisteredWhenDisabledByEnterprisePolicy) {
|
||||
CreateModelExecutionManager();
|
||||
model_provider()->Reset();
|
||||
local_state()->SetInteger(
|
||||
model_execution::prefs::localstate::
|
||||
kGenAILocalFoundationalModelEnterprisePolicySettings,
|
||||
static_cast<int>(model_execution::prefs::
|
||||
GenAILocalFoundationalModelEnterprisePolicySettings::
|
||||
kDisallowed));
|
||||
CreateModelExecutionManager();
|
||||
EXPECT_FALSE(model_provider()->was_registered());
|
||||
|
||||
// Reset manager to make sure removing observer doesn't crash.
|
||||
Reset();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace optimization_guide
|
||||
|
@ -0,0 +1,132 @@
|
||||
// Copyright 2024 The Chromium Authors
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
|
||||
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_util.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_adaptation_loader.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
|
||||
|
||||
namespace optimization_guide {
|
||||
|
||||
namespace {
|
||||
|
||||
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
|
||||
GetRequiredModelAdaptationLoaders(
|
||||
OptimizationGuideModelProvider* model_provider,
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager,
|
||||
PrefService* local_state,
|
||||
base::WeakPtr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller) {
|
||||
std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader> loaders;
|
||||
for (const auto feature : kAllModelBasedCapabilityKeys) {
|
||||
if (!features::internal::GetOptimizationTargetForCapability(feature)) {
|
||||
continue;
|
||||
}
|
||||
loaders.emplace(
|
||||
std::piecewise_construct, std::forward_as_tuple(feature),
|
||||
std::forward_as_tuple(
|
||||
feature, model_provider, on_device_component_state_manager,
|
||||
local_state,
|
||||
base::BindRepeating(
|
||||
&OnDeviceModelServiceController::MaybeUpdateModelAdaptation,
|
||||
on_device_model_service_controller, feature)));
|
||||
}
|
||||
return loaders;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OnDeviceAssetManager::OnDeviceAssetManager(
    PrefService* local_state,
    base::WeakPtr<OnDeviceModelServiceController> service_controller,
    base::WeakPtr<OnDeviceModelComponentStateManager> component_state_manager,
    raw_ptr<OptimizationGuideModelProvider> model_provider)
    : on_device_model_service_controller_(service_controller),
      on_device_component_state_manager_(component_state_manager),
      model_provider_(model_provider),
      model_adaptation_loaders_(
          GetRequiredModelAdaptationLoaders(model_provider,
                                            component_state_manager,
                                            local_state,
                                            service_controller)) {
  // Supplementary models (text safety and language detection) are only
  // wanted when the text-safety classifier feature is enabled AND enterprise
  // policy allows on-device foundational models. The policy is read only if
  // the feature check passes.
  const bool wants_supplementary_models =
      features::ShouldUseTextSafetyClassifierModel() &&
      GetGenAILocalFoundationalModelEnterprisePolicySettings(local_state) ==
          model_execution::prefs::
              GenAILocalFoundationalModelEnterprisePolicySettings::kAllowed;
  if (!wants_supplementary_models || !on_device_component_state_manager_) {
    return;
  }

  on_device_component_state_manager_->AddObserver(this);
  // Register right away if the installer is already registered; otherwise
  // registration is deferred to StateChanged().
  if (on_device_component_state_manager_->IsInstallerRegistered()) {
    RegisterTextSafetyAndLanguageModels();
  }
}
|
||||
OnDeviceAssetManager::~OnDeviceAssetManager() {
  // Stop observing the component state manager if it is still alive. This
  // also runs when the constructor returned before AddObserver().
  // NOTE(review): relies on RemoveObserver tolerating an unregistered
  // observer — confirm.
  if (on_device_component_state_manager_) {
    on_device_component_state_manager_->RemoveObserver(this);
  }

  if (!did_register_for_supplementary_on_device_models_) {
    return;
  }
  // Undo the registrations made by RegisterTextSafetyAndLanguageModels().
  const proto::OptimizationTarget kSupplementaryTargets[] = {
      proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
      proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
  };
  for (const proto::OptimizationTarget target : kSupplementaryTargets) {
    model_provider_->RemoveObserverForOptimizationTargetModel(target, this);
  }
}
|
||||
|
||||
// Whether the supplementary on-device models are registered.
|
||||
bool OnDeviceAssetManager::IsSupplementaryModelRegistered() {
|
||||
return did_register_for_supplementary_on_device_models_;
|
||||
}
|
||||
|
||||
void OnDeviceAssetManager::RegisterTextSafetyAndLanguageModels() {
|
||||
if (!did_register_for_supplementary_on_device_models_) {
|
||||
did_register_for_supplementary_on_device_models_ = true;
|
||||
model_provider_->AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY,
|
||||
/*model_metadata=*/std::nullopt, this);
|
||||
model_provider_->AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
|
||||
/*model_metadata=*/std::nullopt, this);
|
||||
}
|
||||
}
|
||||
|
||||
void OnDeviceAssetManager::OnModelUpdated(
    proto::OptimizationTarget optimization_target,
    base::optional_ref<const ModelInfo> model_info) {
  // Updates are dropped when the service controller is gone.
  if (!on_device_model_service_controller_) {
    return;
  }

  // Forward the two supplementary model types; ignore every other target.
  if (optimization_target == proto::OPTIMIZATION_TARGET_TEXT_SAFETY) {
    on_device_model_service_controller_->MaybeUpdateSafetyModel(model_info);
  } else if (optimization_target ==
             proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION) {
    on_device_model_service_controller_->SetLanguageDetectionModel(model_info);
  }
}
|
||||
|
||||
void OnDeviceAssetManager::StateChanged(
|
||||
const OnDeviceModelComponentState* state) {
|
||||
if (state) {
|
||||
RegisterTextSafetyAndLanguageModels();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace optimization_guide
|
@ -0,0 +1,76 @@
|
||||
// Copyright 2024 The Chromium Authors
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_
|
||||
#define COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/memory/raw_ptr.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/sequence_checker.h"
|
||||
#include "components/optimization_guide/core/model_execution/feature_keys.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_features.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
|
||||
#include "components/optimization_guide/core/optimization_target_model_observer.h"
|
||||
#include "components/optimization_guide/proto/model_execution.pb.h"
|
||||
#include "components/optimization_guide/proto/model_quality_service.pb.h"
|
||||
|
||||
namespace optimization_guide {
|
||||
|
||||
class OnDeviceModelAdaptationLoader;
|
||||
class OnDeviceModelServiceController;
|
||||
class OptimizationGuideModelProvider;
|
||||
|
||||
// Registers for on-device asset downloads and notifies about updates.
|
||||
class OnDeviceAssetManager final
|
||||
: public OptimizationTargetModelObserver,
|
||||
public OnDeviceModelComponentStateManager::Observer {
|
||||
public:
|
||||
OnDeviceAssetManager(
|
||||
PrefService* local_state,
|
||||
base::WeakPtr<OnDeviceModelServiceController> service_controller,
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager> component_state_manager,
|
||||
raw_ptr<OptimizationGuideModelProvider> model_provider);
|
||||
~OnDeviceAssetManager() final;
|
||||
|
||||
// OptimizationTargetModelObserver:
|
||||
void OnModelUpdated(proto::OptimizationTarget target,
|
||||
base::optional_ref<const ModelInfo> model_info) override;
|
||||
|
||||
private:
|
||||
// Registers text safety and language detection models. Does nothing if
|
||||
// already registered.
|
||||
void RegisterTextSafetyAndLanguageModels();
|
||||
|
||||
// Whether the supplementary on-device models are registered.
|
||||
bool IsSupplementaryModelRegistered();
|
||||
|
||||
// OnDeviceModelComponentStateManager::Observer:
|
||||
void StateChanged(const OnDeviceModelComponentState* state) override;
|
||||
|
||||
// Controller for the on-device service.
|
||||
base::WeakPtr<OnDeviceModelServiceController>
|
||||
on_device_model_service_controller_;
|
||||
|
||||
base::WeakPtr<OnDeviceModelComponentStateManager>
|
||||
on_device_component_state_manager_;
|
||||
|
||||
// The model provider to observe for updates to auxiliary models.
|
||||
raw_ptr<OptimizationGuideModelProvider> model_provider_;
|
||||
|
||||
// Map from feature to its model adaptation loader. Present only for features
|
||||
// that require model adaptation.
|
||||
const std::map<ModelBasedCapabilityKey, OnDeviceModelAdaptationLoader>
|
||||
model_adaptation_loaders_;
|
||||
|
||||
// Whether the user registered for supplementary on-device models.
|
||||
bool did_register_for_supplementary_on_device_models_ = false;
|
||||
};
|
||||
|
||||
} // namespace optimization_guide
|
||||
|
||||
#endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_MODEL_EXECUTION_ON_DEVICE_ASSET_MANAGER_H_
|
225
components/optimization_guide/core/model_execution/on_device_asset_manager_unittest.cc
Normal file
225
components/optimization_guide/core/model_execution/on_device_asset_manager_unittest.cc
Normal file
@ -0,0 +1,225 @@
|
||||
// Copyright 2024 The Chromium Authors
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/functional/callback_helpers.h"
|
||||
#include "base/test/metrics/histogram_tester.h"
|
||||
#include "base/test/scoped_feature_list.h"
|
||||
#include "base/test/task_environment.h"
|
||||
#include "base/test/test.pb.h"
|
||||
#include "base/test/test_future.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_features.h"
|
||||
#include "components/optimization_guide/core/model_execution/model_execution_prefs.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_access_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_adaptation_loader.h"
|
||||
#include "components/optimization_guide/core/model_execution/on_device_model_service_controller.h"
|
||||
#include "components/optimization_guide/core/model_execution/test/test_on_device_model_component_state_manager.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_constants.h"
|
||||
#include "components/optimization_guide/core/optimization_guide_util.h"
|
||||
#include "components/optimization_guide/core/test_model_info_builder.h"
|
||||
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
|
||||
#include "components/prefs/pref_service.h"
|
||||
#include "components/sync_preferences/testing_pref_service_syncable.h"
|
||||
#include "testing/gmock/include/gmock/gmock.h"
|
||||
#include "testing/gtest/include/gtest/gtest.h"
|
||||
|
||||
namespace optimization_guide {
|
||||
|
||||
namespace {
|
||||
|
||||
class FakeServiceController : public OnDeviceModelServiceController {
|
||||
public:
|
||||
FakeServiceController()
|
||||
: OnDeviceModelServiceController(nullptr, nullptr, base::DoNothing()) {}
|
||||
|
||||
void MaybeUpdateSafetyModel(
|
||||
base::optional_ref<const ModelInfo> model_info) override {
|
||||
received_safety_info_ = true;
|
||||
}
|
||||
|
||||
bool received_safety_info() const { return received_safety_info_; }
|
||||
|
||||
std::optional<base::FilePath> language_detection_model_path() {
|
||||
return OnDeviceModelServiceController::language_detection_model_path();
|
||||
}
|
||||
|
||||
private:
|
||||
~FakeServiceController() override = default;
|
||||
|
||||
bool received_safety_info_ = false;
|
||||
};
|
||||
|
||||
class FakeModelProvider : public TestOptimizationGuideModelProvider {
|
||||
public:
|
||||
void AddObserverForOptimizationTargetModel(
|
||||
proto::OptimizationTarget optimization_target,
|
||||
const std::optional<optimization_guide::proto::Any>& model_metadata,
|
||||
OptimizationTargetModelObserver* observer) override {
|
||||
switch (optimization_target) {
|
||||
case proto::OPTIMIZATION_TARGET_TEXT_SAFETY:
|
||||
registered_for_text_safety_ = true;
|
||||
break;
|
||||
|
||||
case proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION:
|
||||
registered_for_language_detection_ = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
NOTREACHED();
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
registered_for_text_safety_ = false;
|
||||
registered_for_language_detection_ = false;
|
||||
}
|
||||
|
||||
bool was_registered() const {
|
||||
return registered_for_text_safety_ && registered_for_language_detection_;
|
||||
}
|
||||
|
||||
private:
|
||||
bool registered_for_text_safety_ = false;
|
||||
bool registered_for_language_detection_ = false;
|
||||
};
|
||||
|
||||
class OnDeviceAssetManagerTest : public testing::Test {
|
||||
public:
|
||||
OnDeviceAssetManagerTest() {
|
||||
scoped_feature_list_.InitWithFeatures({features::kTextSafetyClassifier},
|
||||
{});
|
||||
model_execution::prefs::RegisterLocalStatePrefs(local_state_.registry());
|
||||
local_state_.SetInteger(
|
||||
model_execution::prefs::localstate::kOnDevicePerformanceClass,
|
||||
base::to_underlying(OnDeviceModelPerformanceClass::kHigh));
|
||||
service_controller_ = base::MakeRefCounted<FakeServiceController>();
|
||||
}
|
||||
|
||||
void CreateComponentManager() {
|
||||
component_manager_.get()->OnStartup();
|
||||
task_environment_.FastForwardBy(base::Seconds(1));
|
||||
}
|
||||
|
||||
void SetModelComponentReady() {
|
||||
component_manager_.SetReady(base::FilePath());
|
||||
}
|
||||
|
||||
void CreateAssetManager() {
|
||||
asset_manager_ = std::make_unique<OnDeviceAssetManager>(
|
||||
&local_state_, service_controller_->GetWeakPtr(),
|
||||
component_manager_.get()->GetWeakPtr(), &model_provider_, );
|
||||
}
|
||||
|
||||
OnDeviceAssetManager* asset_manager() { return asset_manager_.get(); }
|
||||
|
||||
PrefService* local_state() { return &local_state_; }
|
||||
|
||||
FakeModelProvider* model_provider() { return &model_provider_; }
|
||||
|
||||
FakeServiceController* service_controller() {
|
||||
return service_controller_.get();
|
||||
}
|
||||
|
||||
void Reset() { asset_manager_ = nullptr; }
|
||||
|
||||
private:
|
||||
base::test::TaskEnvironment task_environment_{
|
||||
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
|
||||
base::test::ScopedFeatureList scoped_feature_list_;
|
||||
TestingPrefServiceSimple local_state_;
|
||||
scoped_refptr<FakeServiceController> service_controller_;
|
||||
TestOnDeviceModelComponentStateManager component_manager_{&local_state_};
|
||||
FakeModelProvider model_provider_;
|
||||
std::unique_ptr<OnDeviceAssetManager> asset_manager_;
|
||||
};
|
||||
|
||||
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_MAC) || BUILDFLAG(IS_LINUX)
|
||||
TEST_F(OnDeviceAssetManagerTest, RegistersTextSafetyModelWithOverrideModel) {
  // Effectively, when an override is set, the model component will be ready
  // before the OnDeviceAssetManager can be added as an observer, so
  // registration has to happen when the asset manager is constructed.
  CreateComponentManager();
  SetModelComponentReady();

  CreateAssetManager();

  EXPECT_TRUE(model_provider()->was_registered());
}
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, RegistersTextSafetyModelIfEnabled) {
  CreateAssetManager();
  // Registration must wait until the base model is ready.
  EXPECT_FALSE(model_provider()->was_registered());

  // Bring the base model component up; registration should follow.
  CreateComponentManager();
  SetModelComponentReady();
  EXPECT_TRUE(model_provider()->was_registered());
}
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, DoesNotRegisterTextSafetyIfNotEnabled) {
  // Override the fixture's default and turn the text-safety feature off.
  base::test::ScopedFeatureList disable_text_safety;
  disable_text_safety.InitWithFeatures({}, {features::kTextSafetyClassifier});

  CreateAssetManager();
  CreateComponentManager();
  SetModelComponentReady();

  EXPECT_FALSE(model_provider()->was_registered());
}
|
||||
#endif
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, DoesNotNotifyServiceControllerWrongTarget) {
  CreateAssetManager();
  // An update for an unrelated target must not be treated as a safety model.
  const std::unique_ptr<ModelInfo> unrelated_model =
      TestModelInfoBuilder().SetVersion(123).Build();

  asset_manager()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_ENTITIES,
                                  *unrelated_model);

  EXPECT_FALSE(service_controller()->received_safety_info());
}
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, NotifiesServiceController) {
  CreateAssetManager();
  // A text-safety update must be forwarded to the service controller.
  const std::unique_ptr<ModelInfo> safety_model =
      TestModelInfoBuilder().SetVersion(123).Build();

  asset_manager()->OnModelUpdated(proto::OPTIMIZATION_TARGET_TEXT_SAFETY,
                                  *safety_model);

  EXPECT_TRUE(service_controller()->received_safety_info());
}
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, UpdateLanguageDetection) {
  CreateAssetManager();
  // The language-detection model path should reach the service controller.
  const base::FilePath kModelPath{FILE_PATH_LITERAL("foo")};
  const std::unique_ptr<ModelInfo> language_model =
      TestModelInfoBuilder().SetVersion(123).SetModelFilePath(kModelPath).Build();

  asset_manager()->OnModelUpdated(proto::OPTIMIZATION_TARGET_LANGUAGE_DETECTION,
                                  *language_model);

  EXPECT_EQ(kModelPath, service_controller()->language_detection_model_path());
}
|
||||
|
||||
TEST_F(OnDeviceAssetManagerTest, NotRegisteredWhenDisabledByEnterprisePolicy) {
  CreateAssetManager();
  model_provider()->Reset();

  // Disallow on-device foundational models through enterprise policy, then
  // rebuild the asset manager so it observes the new policy value.
  const int kDisallowedPolicy = static_cast<int>(
      model_execution::prefs::
          GenAILocalFoundationalModelEnterprisePolicySettings::kDisallowed);
  local_state()->SetInteger(
      model_execution::prefs::localstate::
          kGenAILocalFoundationalModelEnterprisePolicySettings,
      kDisallowedPolicy);
  CreateAssetManager();

  EXPECT_FALSE(model_provider()->was_registered());

  // Destroy the manager to ensure observer removal is safe in this state.
  Reset();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace optimization_guide
|
@ -238,6 +238,10 @@ class OptimizationGuideService
|
||||
scoped_refptr<optimization_guide::OnDeviceModelComponentStateManager>
|
||||
on_device_model_state_manager_;
|
||||
|
||||
// Downloads other model assets for on-device execution.
|
||||
std::unique_ptr<optimization_guide::OnDeviceAssetManager>
|
||||
on_device_asset_manager_;
|
||||
|
||||
#endif
|
||||
|
||||
// Manages the model execution. Not created for off the record profiles.
|
||||
|
@ -40,6 +40,7 @@
|
||||
#import "services/network/public/cpp/shared_url_loader_factory.h"
|
||||
|
||||
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
#import "components/optimization_guide/core/model_execution/on_device_asset_manager.h"
|
||||
#import "components/optimization_guide/core/model_execution/on_device_model_component.h"
|
||||
#import "ios/chrome/browser/optimization_guide/model/on_device_model_service_controller_ios.h"
|
||||
#endif // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
@ -173,9 +174,9 @@ OptimizationGuideService::OptimizationGuideService(
|
||||
}
|
||||
|
||||
if (!off_the_record_) {
|
||||
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
PrefService* local_state = GetApplicationContext()->GetLocalState();
|
||||
|
||||
#if BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
// Create and startup the on-device model's state manager.
|
||||
on_device_model_state_manager_ =
|
||||
optimization_guide::OnDeviceModelComponentStateManager::CreateOrGet(
|
||||
@ -193,17 +194,20 @@ OptimizationGuideService::OptimizationGuideService(
|
||||
on_device_model_service_controller =
|
||||
GetApplicationContext()->GetOnDeviceModelServiceController(
|
||||
on_device_model_state_manager_->GetWeakPtr());
|
||||
on_device_asset_manager_ =
|
||||
std::make_unique<optimization_guide::OnDeviceAssetManager>(
|
||||
local_state, on_device_model_service_controller->GetWeakPtr(),
|
||||
on_device_model_state_manager_->GetWeakPtr(), this);
|
||||
model_execution_manager_ =
|
||||
std::make_unique<optimization_guide::ModelExecutionManager>(
|
||||
url_loader_factory, local_state, identity_manager,
|
||||
std::move(on_device_model_service_controller), this,
|
||||
on_device_model_state_manager_->GetWeakPtr(),
|
||||
url_loader_factory, identity_manager,
|
||||
std::move(on_device_model_service_controller),
|
||||
optimization_guide_logger_.get(), nullptr);
|
||||
#else // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
model_execution_manager_ =
|
||||
std::make_unique<optimization_guide::ModelExecutionManager>(
|
||||
url_loader_factory, local_state, identity_manager, nullptr, this,
|
||||
nullptr, optimization_guide_logger_.get(), nullptr);
|
||||
url_loader_factory, identity_manager, nullptr,
|
||||
optimization_guide_logger_.get(), nullptr);
|
||||
#endif // BUILDFLAG(BUILD_WITH_INTERNAL_OPTIMIZATION_GUIDE)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user