0

[passage-embeddings] Improves test coverage and cleans up APIs

Adds and/or improves test coverage in:
- HistoryEmbeddingsService, for observing the embedder model metadata,
  initializing the storage with the metadata, and unregistering the
  observation on destruction.
- SchedulingEmbedder, for observing the embedder model metadata, setting
  the metadata and initializing the storage, and unregistering the
  observation on destruction.
- PassageEmbedderModelObserver, for observing the optimization guide
  target, notifying the PassageEmbeddingsServiceController, and
  unregistering the observation on destruction.
- Embedder, for returning the original passages in case of failure.
- PassageEmbeddingsServiceController, for handling empty passages set.

Deletes MlEmbedder, replacing it with an EmbedderRemoteProxy interface
implemented by PassageEmbeddingsServiceController and consumed by
SchedulingEmbedder.

Deletes SchedulingClientEmbedder, as it is no longer needed now that
the PassageEmbeddingsServiceController returns a ref to its owned Embedder.

Introduces an EmbedderMetadataProvider interface implemented by
PassageEmbeddingsServiceController and consumed by
HistoryEmbeddingsService and SchedulingEmbedder.

Organizes passage_embeddings:: public constants, enums, and interfaces
in passage_embeddings_types.h.

Organizes passage_embeddings:: testing utilities in one place,
accessible via //components/passage_embeddings:test_support test target.

Hides passage_embeddings::SchedulingEmbedder in an internal target,
//components/passage_embeddings:passage_embeddings_internal.

Removes obsolete includes, cleans up public/private methods, and makes
other miscellaneous clean-ups.

BYPASS_LARGE_CHANGE_WARNING: Largely adding test coverage, moving files and making non-functional changes. It's difficult to do large refactoring in a piecemeal fashion.

Bug: 397906676
Change-Id: I19281d96fcd36206b86b1766fec6bd512c095dc1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6285407
Reviewed-by: Orin Jaworski <orinj@chromium.org>
Reviewed-by: Sophie Chang <sophiechang@chromium.org>
Commit-Queue: Moe Ahmadi <mahmadi@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1423439}
This commit is contained in:
Moe Ahmadi
2025-02-21 15:00:20 -08:00
committed by Chromium LUCI CQ
parent 4938c402ab
commit 4a1ce004f2
43 changed files with 873 additions and 787 deletions

@ -36,6 +36,7 @@
#include "components/history_embeddings/mock_intent_classifier.h"
#include "components/network_session_configurator/common/network_switches.h"
#include "components/optimization_guide/proto/features/common_quality_data.pb.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "content/public/browser/web_contents.h"
#include "content/public/common/content_switches.h"
#include "content/public/test/browser_test.h"
@ -69,10 +70,13 @@ class AiDataKeyedServiceBrowserTest : public InProcessBrowserTest {
HistoryEmbeddingsServiceFactory::GetInstance()->SetTestingFactory(
browser()->profile(),
base::BindLambdaForTesting([](content::BrowserContext* context) {
base::BindLambdaForTesting([this](content::BrowserContext* context) {
return HistoryEmbeddingsServiceFactory::
BuildServiceInstanceForBrowserContextForTesting(
context, std::make_unique<history_embeddings::MockAnswerer>(),
context,
passage_embeddings_test_env_.embedder_metadata_provider(),
passage_embeddings_test_env_.embedder(),
std::make_unique<history_embeddings::MockAnswerer>(),
std::make_unique<history_embeddings::MockIntentClassifier>());
}));
}
@ -117,6 +121,7 @@ class AiDataKeyedServiceBrowserTest : public InProcessBrowserTest {
private:
autofill::test::AutofillBrowserTestEnvironment autofill_test_environment_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
std::unique_ptr<net::EmbeddedTestServer> https_server_ =
std::make_unique<net::EmbeddedTestServer>(
net::EmbeddedTestServer::TYPE_HTTPS);

@ -17,7 +17,6 @@
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/proto/features/history_query.pb.h"
#include "components/optimization_guide/proto/model_quality_service.pb.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
namespace history_embeddings {
@ -27,16 +26,16 @@ ChromeHistoryEmbeddingsService::ChromeHistoryEmbeddingsService(
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service,
optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
passage_embeddings::PassageEmbeddingsServiceController* service_controller,
std::unique_ptr<passage_embeddings::Embedder> embedder,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<Answerer> answerer,
std::unique_ptr<IntentClassifier> intent_classifier)
: HistoryEmbeddingsService(g_browser_process->os_crypt_async(),
history_service,
page_content_annotations_service,
optimization_guide_decider,
service_controller,
std::move(embedder),
embedder_metadata_provider,
embedder,
std::move(answerer),
std::move(intent_classifier)),
profile_(profile) {}

@ -9,8 +9,7 @@
#include "base/no_destructor.h"
#include "chrome/browser/profiles/profile_keyed_service_factory.h"
#include "components/history_embeddings/history_embeddings_service.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "content/public/browser/browser_context.h"
class Profile;
@ -27,9 +26,8 @@ class ChromeHistoryEmbeddingsService : public HistoryEmbeddingsService {
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service,
optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
passage_embeddings::PassageEmbeddingsServiceController*
service_controller,
std::unique_ptr<passage_embeddings::Embedder> embedder,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<Answerer> answerer,
std::unique_ptr<IntentClassifier> intent_classifier);
explicit ChromeHistoryEmbeddingsService(const HistoryEmbeddingsService&) =

@ -32,6 +32,7 @@
#include "components/page_content_annotations/core/page_content_annotations_features.h"
#include "components/page_content_annotations/core/page_content_annotations_service.h"
#include "components/page_content_annotations/core/test_page_content_annotator.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "content/public/browser/weak_document_ptr.h"
#include "content/public/test/browser_test.h"
@ -58,13 +59,15 @@ class HistoryEmbeddingsBrowserTest : public InProcessBrowserTest {
HistoryEmbeddingsServiceFactory::GetInstance()->SetTestingFactory(
browser()->profile(),
base::BindLambdaForTesting([](content::BrowserContext* context) {
base::BindLambdaForTesting([this](content::BrowserContext* context) {
return HistoryEmbeddingsServiceFactory::
BuildServiceInstanceForBrowserContextForTesting(
context, std::make_unique<MockAnswerer>(),
context,
passage_embeddings_test_env_.embedder_metadata_provider(),
passage_embeddings_test_env_.embedder(),
std::make_unique<MockAnswerer>(),
std::make_unique<MockIntentClassifier>());
}));
service()->EmbedderMetadataUpdated({1, 768});
HistoryEmbeddingsTabHelper::CreateForWebContents(GetActiveWebContents());
@ -136,6 +139,7 @@ class HistoryEmbeddingsBrowserTest : public InProcessBrowserTest {
private:
page_content_annotations::TestPageContentAnnotator page_content_annotator_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
};
class HistoryEmbeddingsRestrictedSigninBrowserTest

@ -26,8 +26,6 @@
#include "components/history_embeddings/mock_answerer.h"
#include "components/history_embeddings/mock_intent_classifier.h"
#include "components/keyed_service/core/service_access_type.h"
#include "components/passage_embeddings/ml_embedder.h"
#include "components/passage_embeddings/mock_embedder.h"
#if BUILDFLAG(IS_CHROMEOS_ASH)
#include "chrome/browser/ash/profiles/profile_helper.h"
@ -85,6 +83,9 @@ HistoryEmbeddingsServiceFactory::GetInstance() {
std::unique_ptr<KeyedService> HistoryEmbeddingsServiceFactory::
BuildServiceInstanceForBrowserContextForTesting(
content::BrowserContext* context,
passage_embeddings::EmbedderMetadataProvider*
embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<history_embeddings::Answerer> answerer,
std::unique_ptr<history_embeddings::IntentClassifier>
intent_classifier) {
@ -93,17 +94,14 @@ std::unique_ptr<KeyedService> HistoryEmbeddingsServiceFactory::
return nullptr;
}
std::unique_ptr<passage_embeddings::Embedder> embedder =
std::make_unique<passage_embeddings::MockEmbedder>();
return std::make_unique<history_embeddings::ChromeHistoryEmbeddingsService>(
profile,
HistoryServiceFactory::GetForProfile(profile,
ServiceAccessType::EXPLICIT_ACCESS),
PageContentAnnotationsServiceFactory::GetForProfile(profile),
OptimizationGuideKeyedServiceFactory::GetForProfile(profile),
passage_embeddings::ChromePassageEmbeddingsServiceController::Get(),
std::move(embedder), std::move(answerer), std::move(intent_classifier));
embedder_metadata_provider, embedder, std::move(answerer),
std::move(intent_classifier));
}
HistoryEmbeddingsServiceFactory::HistoryEmbeddingsServiceFactory()
@ -135,9 +133,8 @@ HistoryEmbeddingsServiceFactory::BuildServiceInstanceForBrowserContext(
OptimizationGuideKeyedService* optimization_guide_keyed_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
std::unique_ptr<passage_embeddings::Embedder> embedder =
passage_embeddings::ChromePassageEmbeddingsServiceController::Get()
->MakeEmbedder();
auto* passage_embeddings_service_controller =
passage_embeddings::ChromePassageEmbeddingsServiceController::Get();
std::unique_ptr<history_embeddings::Answerer> answerer;
if (history_embeddings::IsHistoryEmbeddingsAnswersFeatureEnabled()) {
@ -166,7 +163,7 @@ HistoryEmbeddingsServiceFactory::BuildServiceInstanceForBrowserContext(
HistoryServiceFactory::GetForProfile(profile,
ServiceAccessType::EXPLICIT_ACCESS),
PageContentAnnotationsServiceFactory::GetForProfile(profile),
optimization_guide_keyed_service,
passage_embeddings::ChromePassageEmbeddingsServiceController::Get(),
std::move(embedder), std::move(answerer), std::move(intent_classifier));
optimization_guide_keyed_service, passage_embeddings_service_controller,
passage_embeddings_service_controller->GetEmbedder(), std::move(answerer),
std::move(intent_classifier));
}

@ -9,7 +9,6 @@
#include "base/no_destructor.h"
#include "chrome/browser/profiles/profile_keyed_service_factory.h"
#include "components/passage_embeddings/embedder.h"
#include "content/public/browser/browser_context.h"
namespace history_embeddings {
@ -18,6 +17,11 @@ class HistoryEmbeddingsService;
class IntentClassifier;
} // namespace history_embeddings
namespace passage_embeddings {
class Embedder;
class EmbedderMetadataProvider;
} // namespace passage_embeddings
class HistoryEmbeddingsServiceFactory : public ProfileKeyedServiceFactory {
public:
static history_embeddings::HistoryEmbeddingsService* GetForProfile(
@ -28,6 +32,8 @@ class HistoryEmbeddingsServiceFactory : public ProfileKeyedServiceFactory {
static std::unique_ptr<KeyedService>
BuildServiceInstanceForBrowserContextForTesting(
content::BrowserContext* context,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<history_embeddings::Answerer> answerer,
std::unique_ptr<history_embeddings::IntentClassifier> intent_classifier);

@ -1 +1 @@
file://components/optimization_guide/OWNERS
file://components/passage_embeddings/OWNERS

@ -25,7 +25,6 @@
#include "content/public/browser/weak_document_ptr.h"
#include "content/public/browser/web_contents.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
#include "services/service_manager/public/cpp/interface_provider.h"
#include "url/gurl.h"
@ -56,9 +55,7 @@ void OnGotEmbeddings(base::ElapsedTimer embeddings_computation_timer,
} // namespace
EmbedderTabObserver::EmbedderTabObserver(content::WebContents* web_contents)
: content::WebContentsObserver(web_contents),
embedder_(
ChromePassageEmbeddingsServiceController::Get()->MakeEmbedder()) {}
: content::WebContentsObserver(web_contents) {}
EmbedderTabObserver::~EmbedderTabObserver() = default;
@ -160,10 +157,12 @@ void EmbedderTabObserver::OnGotPassages(
<< total_text_size;
base::ElapsedTimer embeddings_computation_timer;
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, std::move(passages),
base::BindOnce(&OnGotEmbeddings,
std::move(embeddings_computation_timer)));
ChromePassageEmbeddingsServiceController::Get()
->GetEmbedder()
->ComputePassagesEmbeddings(
PassagePriority::kPassive, std::move(passages),
base::BindOnce(&OnGotEmbeddings,
std::move(embeddings_computation_timer)));
}
Profile* EmbedderTabObserver::GetProfile() {

@ -7,7 +7,6 @@
#include "base/memory/weak_ptr.h"
#include "base/timer/elapsed_timer.h"
#include "components/passage_embeddings/embedder.h"
#include "content/public/browser/web_contents_observer.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "third_party/blink/public/mojom/content_extraction/inner_text.mojom.h"
@ -61,8 +60,6 @@ class EmbedderTabObserver : public content::WebContentsObserver {
const raw_ptr<content::WebContents> web_contents_;
std::unique_ptr<Embedder> embedder_;
// Used to cancel scheduled passage extraction.
base::WeakPtrFactory<EmbedderTabObserver> weak_ptr_factory_{this};
};

@ -28,6 +28,7 @@
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/history_embeddings/history_embeddings_service.h"
#include "components/page_content_annotations/core/test_page_content_annotations_service.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/user_education/test/mock_feature_promo_controller.h"
#include "content/public/browser/browser_context.h"
#include "content/public/test/test_web_ui.h"
@ -63,10 +64,13 @@ class MockPage : public history_embeddings::mojom::Page {
} // namespace
std::unique_ptr<KeyedService> BuildTestHistoryEmbeddingsService(
passage_embeddings::TestEnvironment* passage_embeddings_test_env,
content::BrowserContext* browser_context) {
return HistoryEmbeddingsServiceFactory::
BuildServiceInstanceForBrowserContextForTesting(
browser_context,
passage_embeddings_test_env->embedder_metadata_provider(),
passage_embeddings_test_env->embedder(),
/*answerer=*/nullptr,
/*intent_classifier=*/nullptr);
}
@ -118,7 +122,8 @@ class HistoryEmbeddingsHandlerTest : public BrowserWithTestWindowTest {
HistoryServiceFactory::GetDefaultFactory()},
TestingProfile::TestingFactory{
HistoryEmbeddingsServiceFactory::GetInstance(),
base::BindRepeating(&BuildTestHistoryEmbeddingsService)},
base::BindRepeating(&BuildTestHistoryEmbeddingsService,
&passage_embeddings_test_env_)},
TestingProfile::TestingFactory{
PageContentAnnotationsServiceFactory::GetInstance(),
base::BindRepeating(&BuildTestPageContentAnnotationsService)},
@ -168,6 +173,7 @@ class HistoryEmbeddingsHandlerTest : public BrowserWithTestWindowTest {
base::test::ScopedFeatureList feature_list_;
std::unique_ptr<content::WebContents> web_contents_;
content::TestWebUI web_ui_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
std::unique_ptr<HistoryEmbeddingsHandler> handler_;
testing::NiceMock<MockPage> page_;
raw_ptr<MockHatsService> mock_hats_service_;

@ -20,6 +20,7 @@
#include "components/page_content_annotations/core/page_content_annotations_features.h"
#include "components/page_content_annotations/core/page_content_annotations_service.h"
#include "components/page_content_annotations/core/test_page_content_annotator.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "content/public/test/browser_test.h"
namespace {
@ -45,12 +46,14 @@ class HistoryEmbeddingsInteractiveTest
void SetUpOnMainThread() override {
HistoryEmbeddingsServiceFactory::GetInstance()->SetTestingFactory(
browser()->profile(),
base::BindLambdaForTesting([](content::BrowserContext* context) {
base::BindLambdaForTesting([this](content::BrowserContext* context) {
return HistoryEmbeddingsServiceFactory::
BuildServiceInstanceForBrowserContextForTesting(
context, /*answerer=*/nullptr, /*intent_classifier=*/nullptr);
context,
passage_embeddings_test_env_.embedder_metadata_provider(),
passage_embeddings_test_env_.embedder(),
/*answerer=*/nullptr, /*intent_classifier=*/nullptr);
}));
service()->EmbedderMetadataUpdated({1, 768});
InteractiveBrowserTest::SetUpOnMainThread();
}
@ -84,6 +87,7 @@ class HistoryEmbeddingsInteractiveTest
private:
base::test::ScopedFeatureList scoped_feature_list_;
page_content_annotations::TestPageContentAnnotator page_content_annotator_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
};
// Opening the feedback dialog on CrOS & LaCrOS open a system level dialog,

@ -2337,7 +2337,6 @@ if (!is_android) {
"//components/optimization_guide/core:bloomfilter",
"//components/optimization_guide/core:prediction",
"//components/optimization_guide/core:test_support",
"//components/optimization_guide/core:test_support",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/origin_trials:browser",
"//components/origin_trials/test",
@ -2351,6 +2350,8 @@ if (!is_android) {
"//components/page_load_metrics/common",
"//components/page_load_metrics/common:test_support",
"//components/page_load_metrics/google/browser",
"//components/passage_embeddings",
"//components/passage_embeddings:test_support",
"//components/password_manager/content/browser",
"//components/password_manager/content/common",
"//components/password_manager/core/browser/features:password_features",
@ -7901,6 +7902,8 @@ test("unit_tests") {
"//chrome/browser/ui/webui/signin:unit_tests",
"//components/autofill_ai/core/browser:browser",
"//components/data_sharing:test_support",
"//components/passage_embeddings",
"//components/passage_embeddings:test_support",
]
if (is_chrome_branded) {
@ -10880,6 +10883,8 @@ if (!is_android && !is_chromeos_device) {
"//components/media_router/browser:test_support",
"//components/metrics:content",
"//components/os_crypt/sync:test_support",
"//components/passage_embeddings",
"//components/passage_embeddings:test_support",
"//components/password_manager/content/browser",
"//components/plus_addresses",
"//components/plus_addresses:test_support",

@ -91,6 +91,7 @@ source_set("unit_tests") {
"//components/os_crypt/async/browser:test_support",
"//components/page_content_annotations/core:test_support",
"//components/passage_embeddings",
"//components/passage_embeddings:test_support",
"//mojo/public/cpp/bindings",
"//services/passage_embeddings/public/mojom",
"//testing/gtest",

@ -33,7 +33,6 @@
#include "components/optimization_guide/core/optimization_guide_decider.h"
#include "components/os_crypt/async/browser/os_crypt_async.h"
#include "components/page_content_annotations/core/page_content_annotations_service.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "url/gurl.h"
@ -228,15 +227,15 @@ HistoryEmbeddingsService::HistoryEmbeddingsService(
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service,
optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
passage_embeddings::PassageEmbeddingsServiceController* service_controller,
std::unique_ptr<passage_embeddings::Embedder> embedder,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<Answerer> answerer,
std::unique_ptr<IntentClassifier> intent_classifier)
: os_crypt_async_(os_crypt_async),
history_service_(history_service),
page_content_annotations_service_(page_content_annotations_service),
optimization_guide_decider_(optimization_guide_decider),
embedder_(std::move(embedder)),
embedder_(embedder),
answerer_(std::move(answerer)),
intent_classifier_(std::move(intent_classifier)),
query_id_weak_ptr_factory_(&query_id_),
@ -267,8 +266,8 @@ HistoryEmbeddingsService::HistoryEmbeddingsService(
// Observation needs to be set up after the `storage_` construction since the
// update notification could be invoked immediately.
if (service_controller) {
embedder_metadata_observation_.Observe(service_controller);
if (embedder_metadata_provider) {
embedder_metadata_observation_.Observe(embedder_metadata_provider);
}
}
@ -595,7 +594,7 @@ void HistoryEmbeddingsService::EmbedderMetadataUpdated(
return;
}
embedder_metadata_ = metadata;
subscription_ = os_crypt_async_->GetInstance(
os_crypt_async_subscription_ = os_crypt_async_->GetInstance(
base::BindOnce(&HistoryEmbeddingsService::OnOsCryptAsyncReady,
weak_ptr_factory_.GetWeakPtr()));
}

@ -33,8 +33,6 @@
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/proto/features/common_quality_data.pb.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
namespace optimization_guide {
@ -164,9 +162,8 @@ class HistoryEmbeddingsService
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service,
optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
passage_embeddings::PassageEmbeddingsServiceController*
service_controller,
std::unique_ptr<passage_embeddings::Embedder> embedder,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<Answerer> answerer,
std::unique_ptr<IntentClassifier> intent_classifier);
HistoryEmbeddingsService(const HistoryEmbeddingsService&) = delete;
@ -227,12 +224,6 @@ class HistoryEmbeddingsService
void OnHistoryDeletions(history::HistoryService* history_service,
const history::DeletionInfo& deletion_info) override;
// EmbedderMetadataObserver:
// Called when the embedder metadata is available. Passes the metadata to
// the internal storage.
void EmbedderMetadataUpdated(
passage_embeddings::EmbedderMetadata metadata) override;
// This can be overridden to gate answer generation for some accounts.
virtual bool IsAnswererUseAllowed() const;
@ -318,6 +309,11 @@ class HistoryEmbeddingsService
SqlDatabase sql_database;
};
// passage_embeddings::EmbedderMetadataObserver:
// Passes the metadata to the internal storage.
void EmbedderMetadataUpdated(
passage_embeddings::EmbedderMetadata metadata) override;
void OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor, bool success);
// This can be overridden to prepare a log entry that will then be filled
@ -428,8 +424,8 @@ class HistoryEmbeddingsService
history::HistoryServiceObserver>
history_service_observation_{this};
// The embedder used to compute embeddings.
std::unique_ptr<passage_embeddings::Embedder> embedder_;
// The embedder used to compute embeddings. Outlives this.
raw_ptr<passage_embeddings::Embedder> embedder_;
// The answerer used to answer queries with context. May be nullptr if
// the kHistoryEmbeddingsAnswers feature is disabled.
@ -438,7 +434,8 @@ class HistoryEmbeddingsService
// The intent classifier used to determine query intent and answerability.
std::unique_ptr<IntentClassifier> intent_classifier_;
// Metadata about the embedder.
// Metadata about the embedder; Set when valid metadata is received from
// `embedder_metadata_provider`.
passage_embeddings::EmbedderMetadata embedder_metadata_{0, 0};
// Storage is bound to a separate sequence.
@ -462,11 +459,12 @@ class HistoryEmbeddingsService
passage_embeddings::Embedder::TaskId query_embedding_task_id_ =
passage_embeddings::Embedder::kInvalidTaskId;
base::CallbackListSubscription subscription_;
// Callback subscription for receiving OsCryptAsync ready event.
base::CallbackListSubscription os_crypt_async_subscription_;
base::ScopedObservation<
passage_embeddings::PassageEmbeddingsServiceController,
passage_embeddings::EmbedderMetadataObserver>
// Scoped observation for when the embedder metadata is available.
base::ScopedObservation<passage_embeddings::EmbedderMetadataProvider,
passage_embeddings::EmbedderMetadataObserver>
embedder_metadata_observation_{this};
base::WeakPtrFactory<std::atomic<size_t>> query_id_weak_ptr_factory_;

@ -38,10 +38,8 @@
#include "components/os_crypt/async/browser/test_utils.h"
#include "components/page_content_annotations/core/test_page_content_annotations_service.h"
#include "components/page_content_annotations/core/test_page_content_annotator.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/mock_embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/scheduling_embedder.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace history_embeddings {
@ -68,17 +66,16 @@ class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService {
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service,
optimization_guide::OptimizationGuideDecider* optimization_guide_decider,
passage_embeddings::PassageEmbeddingsServiceController*
service_controller,
std::unique_ptr<passage_embeddings::Embedder> embedder,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
std::unique_ptr<Answerer> answerer,
std::unique_ptr<IntentClassifier> intent_classfier)
: HistoryEmbeddingsService(os_crypt_async,
history_service,
page_content_annotations_service,
optimization_guide_decider,
service_controller,
std::move(embedder),
embedder_metadata_provider,
embedder,
std::move(answerer),
std::move(intent_classfier)) {}
@ -90,7 +87,6 @@ class HistoryEmbeddingsServicePublic : public HistoryEmbeddingsService {
using HistoryEmbeddingsService::RebuildAbsentEmbeddings;
using HistoryEmbeddingsService::answerer_;
using HistoryEmbeddingsService::embedder_;
using HistoryEmbeddingsService::embedder_metadata_;
using HistoryEmbeddingsService::intent_classifier_;
using HistoryEmbeddingsService::storage_;
@ -107,10 +103,10 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
SetFeatureParametersForTesting(feature_parameters);
CHECK(history_dir_.CreateUniqueTempDir());
history_service_ =
history::CreateHistoryService(history_dir_.GetPath(), true);
CHECK(history_service_);
os_crypt_ = os_crypt_async::GetTestOSCryptAsyncForTesting(
/*is_sync_for_unittests=*/true);
@ -126,11 +122,11 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
os_crypt_.get(), history_service_.get(),
page_content_annotations_service_.get(),
/*optimization_guide_decider=*/nullptr,
/*service_controller=*/nullptr,
std::make_unique<passage_embeddings::MockEmbedder>(),
passage_embeddings_test_env_.embedder_metadata_provider(),
passage_embeddings_test_env_.embedder(),
std::make_unique<MockAnswerer>(),
std::make_unique<MockIntentClassifier>());
service_->EmbedderMetadataUpdated({1, 768});
ASSERT_TRUE(service_->embedder_metadata_.IsValid());
ASSERT_TRUE(listener()->filter_words_hashes().empty());
listener()->OnSearchStringsUpdate(
@ -195,7 +191,7 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
service_->OnPassagesEmbeddingsComputed(
std::move(url_passages), std::move(passages),
std::move(passages_embeddings),
passage_embeddings::SchedulingEmbedder::kInvalidTaskId, status);
passage_embeddings::Embedder::kInvalidTaskId, status);
}
void SetMetadataScoreThreshold(double threshold) {
@ -231,6 +227,7 @@ class HistoryEmbeddingsServiceTest : public testing::Test {
optimization_guide_decider_;
std::unique_ptr<page_content_annotations::TestPageContentAnnotationsService>
page_content_annotations_service_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
page_content_annotations::TestPageContentAnnotator page_content_annotator_;
std::unique_ptr<HistoryEmbeddingsServicePublic> service_;
};

@ -4,22 +4,21 @@
#include "components/history_embeddings/mock_history_embeddings_service.h"
#include "components/passage_embeddings/mock_embedder.h"
namespace history_embeddings {
MockHistoryEmbeddingsService::MockHistoryEmbeddingsService(
os_crypt_async::OSCryptAsync* os_crypt_async,
history::HistoryService* history_service)
: HistoryEmbeddingsService(
os_crypt_async,
history_service,
/*page_content_annotations_service=*/nullptr,
/*optimization_guide_decider=*/nullptr,
/*service_controller=*/nullptr,
std::make_unique<passage_embeddings::MockEmbedder>(),
/*answerer=*/nullptr,
/*intent_classifier=*/nullptr) {}
history::HistoryService* history_service,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder)
: HistoryEmbeddingsService(os_crypt_async,
history_service,
/*page_content_annotations_service=*/nullptr,
/*optimization_guide_decider=*/nullptr,
embedder_metadata_provider,
embedder,
/*answerer=*/nullptr,
/*intent_classifier=*/nullptr) {}
MockHistoryEmbeddingsService::~MockHistoryEmbeddingsService() = default;

@ -11,6 +11,7 @@
#include "base/time/time.h"
#include "components/history/core/browser/history_service.h"
#include "components/history_embeddings/history_embeddings_service.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gmock/include/gmock/gmock.h"
namespace os_crypt_async {
@ -32,7 +33,9 @@ class MockHistoryEmbeddingsService : public HistoryEmbeddingsService {
(override));
explicit MockHistoryEmbeddingsService(
os_crypt_async::OSCryptAsync* os_crypt_async,
history::HistoryService* history_service);
history::HistoryService* history_service,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder);
~MockHistoryEmbeddingsService() override;
};

@ -17,7 +17,6 @@
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/history_embeddings/vector_database.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "sql/database.h"
#include "sql/init_status.h"

@ -18,7 +18,7 @@
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/os_crypt/async/browser/test_utils.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace history_embeddings {

@ -13,7 +13,7 @@
#include "components/history/core/browser/history_types.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
namespace history_embeddings {

@ -17,7 +17,7 @@
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

@ -902,6 +902,8 @@ source_set("unit_tests") {
"//components/optimization_guide/core:test_support",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/os_crypt/async/browser",
"//components/passage_embeddings:passage_embeddings",
"//components/passage_embeddings:test_support",
"//components/prefs:test_support",
"//components/query_parser:query_parser",
"//components/safe_browsing/core/common:common",

@ -28,6 +28,7 @@ include_rules = [
"+components/optimization_guide/core",
"+components/optimization_guide/proto",
"+components/os_crypt/async/browser",
"+components/passage_embeddings",
"+components/pref_registry",
"+components/prefs",
"+components/query_parser",

@ -34,6 +34,7 @@
#include "components/optimization_guide/proto/features/history_answer.pb.h"
#include "components/os_crypt/async/browser/os_crypt_async.h"
#include "components/os_crypt/async/browser/test_utils.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/search_engines/template_url.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/metrics_proto/omnibox_event.pb.h"
@ -123,7 +124,9 @@ class HistoryEmbeddingsProviderTest : public testing::Test,
client_->set_history_embeddings_service(
std::make_unique<testing::NiceMock<
history_embeddings::MockHistoryEmbeddingsService>>(
os_crypt_.get(), client_->GetHistoryService()));
os_crypt_.get(), client_->GetHistoryService(),
passage_embeddings_test_env_.embedder_metadata_provider(),
passage_embeddings_test_env_.embedder()));
history_embeddings_service_ = static_cast<
testing::NiceMock<history_embeddings::MockHistoryEmbeddingsService>*>(
client_->GetHistoryEmbeddingsService());
@ -160,6 +163,7 @@ class HistoryEmbeddingsProviderTest : public testing::Test,
base::ScopedTempDir history_dir_;
std::unique_ptr<os_crypt_async::OSCryptAsync> os_crypt_;
base::test::TaskEnvironment task_environment_;
passage_embeddings::TestEnvironment passage_embeddings_test_env_;
std::unique_ptr<FakeAutocompleteProviderClient> client_;
raw_ptr<testing::NiceMock<history_embeddings::MockHistoryEmbeddingsService>>
history_embeddings_service_;

@ -4,14 +4,34 @@
import("//build/config/ui.gni")
# Holds passage_embeddings_types.{cc,h}: the component's public constants,
# enums, and interfaces. The other library targets in this file depend on it
# (":passage_embeddings_internal" via deps, ":passage_embeddings" via
# public_deps).
static_library("passage_embeddings_types") {
sources = [
"passage_embeddings_types.cc",
"passage_embeddings_types.h",
]
deps = [
"//base",
"//services/passage_embeddings/public/mojom",
]
}
# Internal implementation target that holds SchedulingEmbedder
# (internal/scheduling_embedder.{cc,h}). It is listed under `deps` (not
# `public_deps`) of ":passage_embeddings", and directly by the unit tests.
static_library("passage_embeddings_internal") {
sources = [
"internal/scheduling_embedder.cc",
"internal/scheduling_embedder.h",
]
deps = [
":passage_embeddings_types",
"//base",
"//components/performance_manager/scenario_api",
"//services/passage_embeddings/public/mojom",
]
}
static_library("passage_embeddings") {
sources = [
"embedder.cc",
"embedder.h",
"ml_embedder.cc",
"ml_embedder.h",
"mock_embedder.cc",
"mock_embedder.h",
"passage_embedder_model_observer.cc",
"passage_embedder_model_observer.h",
"passage_embeddings_features.cc",
@ -19,31 +39,48 @@ static_library("passage_embeddings") {
"passage_embeddings_service_controller.cc",
"passage_embeddings_service_controller.h",
"passage_embeddings_types.h",
"scheduling_embedder.cc",
"scheduling_embedder.h",
]
deps = [
":passage_embeddings_internal",
"//base",
"//build:blink_buildflags",
"//components/optimization_guide/core",
"//components/optimization_guide/proto:optimization_guide_proto",
"//components/performance_manager/scenario_api",
"//mojo/public/cpp/bindings",
"//services/passage_embeddings/public/mojom",
]
public_deps = [ ":passage_embeddings_types" ]
}
# Test-only utilities for this component (passage_embeddings_test_util.{cc,h}).
# Depended on by this component's unit tests and by tests of its clients.
static_library("test_support") {
testonly = true
sources = [
"passage_embeddings_test_util.cc",
"passage_embeddings_test_util.h",
]
deps = [
":passage_embeddings",
"//base",
"//components/optimization_guide/core",
"//components/optimization_guide/core:test_support",
"//components/optimization_guide/proto:optimization_guide_proto",
]
}
source_set("unit_tests") {
testonly = true
sources = [
"internal/scheduling_embedder_unittest.cc",
"ml_embedder_unittest.cc",
"passage_embedder_model_observer_unittest.cc",
"scheduling_embedder_unittest.cc",
]
deps = [
":passage_embeddings",
":passage_embeddings_internal",
":test_support",
"//base/test:test_support",
"//components/history/core/browser",
"//components/history/core/test",

@ -1,97 +0,0 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_EMBEDDER_H_
#include <optional>
#include <string>
#include <vector>
#include "base/functional/callback.h"
#include "base/observer_list_types.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
namespace passage_embeddings {
// A vector of floats (`data_`) paired with the word count of the passage it
// belongs to (`passage_word_count_`). Copyable, movable, and comparable with
// operator==.
class Embedding {
public:
// Constructs from the raw vector `data`.
explicit Embedding(std::vector<float> data);
// Constructs from the raw vector `data` and the passage's word count.
Embedding(std::vector<float> data, size_t passage_word_count);
Embedding();
~Embedding();
Embedding(const Embedding&);
Embedding& operator=(const Embedding&);
Embedding(Embedding&&);
Embedding& operator=(Embedding&&);
bool operator==(const Embedding&) const;
// The number of elements in the data vector.
size_t Dimensions() const;
// The length of the vector.
float Magnitude() const;
// Scale the vector to unit length.
void Normalize();
// Compares one embedding with another and returns a similarity measure.
float ScoreWith(const Embedding& other_embedding) const;
// Const accessor used for storage.
const std::vector<float>& GetData() const { return data_; }
// Used for search filtering of passages with low word count.
size_t GetPassageWordCount() const { return passage_word_count_; }
// Overwrites the stored passage word count.
void SetPassageWordCount(size_t passage_word_count) {
passage_word_count_ = passage_word_count;
}
private:
// The embedding values, exposed read-only through GetData().
std::vector<float> data_;
// Word count of the associated passage; 0 unless set.
size_t passage_word_count_ = 0;
};
// Observer interface for parties that want to be told when the embedder model
// metadata changes.
class EmbedderMetadataObserver : public base::CheckedObserver {
public:
// This is notified when model metadata is updated.
// `metadata` is passed by value.
virtual void EmbedderMetadataUpdated(EmbedderMetadata metadata) = 0;
};
// Base class that hides implementation details for how text is embedded.
// Abstract: both operations are pure virtual and the constructor is protected,
// so only subclasses can be instantiated.
class Embedder {
public:
// Identifier for a ComputePassagesEmbeddings() request; it is reported to the
// callback and accepted by TryCancel().
using TaskId = uint64_t;
// NOTE(review): presumably never identifies a real task — confirm against
// implementations.
static constexpr TaskId kInvalidTaskId = 0;
virtual ~Embedder() = default;
// Computes embeddings for each entry in `passages`. Will invoke callback on
// done. If successful, it is guaranteed that the number of passages in
// `passages` will match the number of entries in the embeddings vector and in
// the same order. If unsuccessful, the callback will still return the
// original passages but an empty embeddings vector.
using ComputePassagesEmbeddingsCallback =
base::OnceCallback<void(std::vector<std::string> passages,
std::vector<Embedding> embeddings,
TaskId task_id,
ComputeEmbeddingsStatus status)>;
// Starts a computation for `passages` at the given `priority` and returns a
// TaskId for the request.
virtual TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) = 0;
// Cancels computation of embeddings iff none of the passages given to
// `ComputePassagesEmbeddings()` has been submitted for embedding yet.
// If successful, the callback for the canceled task will be invoked with
// `ComputeEmbeddingsStatus::kCanceled` status.
virtual bool TryCancel(TaskId task_id) = 0;
protected:
// Only subclasses may construct an Embedder.
Embedder() = default;
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_EMBEDDER_H_

@ -2,9 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/scheduling_embedder.h"
#include "components/passage_embeddings/internal/scheduling_embedder.h"
#include <atomic>
#include <memory>
#include <numeric>
#include <string>
@ -90,14 +89,19 @@ void SchedulingEmbedder::Job::Finish(ComputeEmbeddingsStatus status) {
////////////////////////////////////////////////////////////////////////////////
SchedulingEmbedder::SchedulingEmbedder(std::unique_ptr<Embedder> embedder,
size_t max_jobs,
size_t max_batch_size,
bool use_performance_scenario)
: embedder_(std::move(embedder)),
SchedulingEmbedder::SchedulingEmbedder(
EmbedderMetadataProvider* embedder_metadata_provider,
GetEmbeddingsCallback get_embeddings_callback,
size_t max_jobs,
size_t max_batch_size,
bool use_performance_scenario)
: get_embeddings_callback_(get_embeddings_callback),
max_jobs_(max_jobs),
max_batch_size_(max_batch_size),
use_performance_scenario_(use_performance_scenario) {
if (embedder_metadata_provider) {
embedder_metadata_observation_.Observe(embedder_metadata_provider);
}
if (use_performance_scenario_) {
performance_scenario_observation_.Observe(
performance_scenarios::PerformanceScenarioObserverList::GetForScope(
@ -201,8 +205,8 @@ void SchedulingEmbedder::SubmitWorkToEmbedder() {
}
work_submitted_ = true;
embedder_->ComputePassagesEmbeddings(
priority, std::move(passages),
get_embeddings_callback_.Run(
std::move(passages), priority,
base::BindOnce(&SchedulingEmbedder::OnEmbeddingsComputed,
weak_ptr_factory_.GetWeakPtr()));
}
@ -270,10 +274,17 @@ void SchedulingEmbedder::OnInputScenarioChanged(ScenarioScope scope,
SubmitWorkToEmbedder();
}
void SchedulingEmbedder::OnEmbeddingsComputed(std::vector<std::string> passages,
std::vector<Embedding> embeddings,
TaskId task_id,
ComputeEmbeddingsStatus status) {
void SchedulingEmbedder::OnEmbeddingsComputed(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status) {
std::vector<std::string> passages;
std::vector<Embedding> embeddings;
for (auto& result : results) {
passages.push_back(result->passage);
embeddings.emplace_back(result->embeddings);
embeddings.back().Normalize();
}
VLOG(3) << embeddings.size() << " embeddings computed for " << passages.size()
<< " passages with status " << static_cast<int>(status);
CHECK_EQ(passages.size(), embeddings.size());
@ -315,24 +326,4 @@ void SchedulingEmbedder::OnEmbeddingsComputed(std::vector<std::string> passages,
SubmitWorkToEmbedder();
}
////////////////////////////////////////////////////////////////////////////////
// Stores the `scheduling_embedder` pointer; every request made on this object
// is forwarded to it.
SchedulingClientEmbedder::SchedulingClientEmbedder(
SchedulingEmbedder* scheduling_embedder)
: scheduling_embedder_(scheduling_embedder) {}
SchedulingClientEmbedder::~SchedulingClientEmbedder() = default;
// Pure pass-through: hands `priority`, `passages`, and `callback` to the
// wrapped SchedulingEmbedder and returns the TaskId it produces.
Embedder::TaskId SchedulingClientEmbedder::ComputePassagesEmbeddings(
passage_embeddings::PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) {
return scheduling_embedder_->ComputePassagesEmbeddings(
priority, std::move(passages), std::move(callback));
}
// Pure pass-through: returns whatever the wrapped SchedulingEmbedder reports
// for `task_id`.
bool SchedulingClientEmbedder::TryCancel(TaskId task_id) {
return scheduling_embedder_->TryCancel(task_id);
}
} // namespace passage_embeddings

@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_SCHEDULING_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_SCHEDULING_EMBEDDER_H_
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_
#include <memory>
#include <optional>
@ -16,70 +16,52 @@
#include "base/scoped_observation.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "build/build_config.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "components/performance_manager/scenario_api/performance_scenario_observer.h"
#include "components/performance_manager/scenario_api/performance_scenarios.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
// The SchedulingEmbedder wraps a primary embedder and adds scheduling control
// with batching and priorities so that high priority queries can be computed as
// soon as possible. Scheduling is also needed to avoid clogging the pipes for a
// slow remote embedder. Even single pages can take a while, and when the model
// changes, all existing passages need their embeddings recomputed, which can
// take a very long time and should be done at lower priority.
// The SchedulingEmbedder adds scheduling control with batching and priorities
// so that high priority queries can be computed as soon as possible. Scheduling
// is also needed to avoid clogging the pipes for a slow remote embedder. Even
// single pages can take a while, and when the model changes, all existing
// passages need their embeddings recomputed, which can take a very long time
// and should be done at lower priority.
class SchedulingEmbedder
: public Embedder,
public EmbedderMetadataObserver,
public performance_scenarios::PerformanceScenarioObserver {
public:
SchedulingEmbedder(std::unique_ptr<Embedder> embedder,
using GetEmbeddingsResultCallback = base::OnceCallback<void(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status)>;
using GetEmbeddingsCallback =
base::RepeatingCallback<void(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback)>;
SchedulingEmbedder(EmbedderMetadataProvider* embedder_metadata_provider,
GetEmbeddingsCallback get_embeddings_callback,
size_t max_jobs,
size_t scheduled_max_batch_size,
bool use_performance_scenario);
~SchedulingEmbedder() override;
// Returns latest metadata; may be zero/invalid if embedder is not yet ready.
EmbedderMetadata GetEmbedderMetadata() const { return embedder_metadata_; }
// Embedder:
// Computes embeddings for each entry in `passages`. Will invoke callback on
// done. If successful, it is guaranteed that the number of passages in
// `passages` will match the number of entries in the embeddings vector and in
// the same order. If unsuccessful, the callback will still return the
// original passages but with an empty embeddings vector and an appropriate
// status.
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
// Cancels computation of embeddings iff none of the passages given to
// `ComputePassagesEmbeddings()` has been submitted to the embedder yet.
// If successful, the callback for the canceled task will be invoked with
// `ComputeEmbeddingsStatus::kCanceled` status.
bool TryCancel(TaskId task_id) override;
// EmbedderMetadataObserver:
void EmbedderMetadataUpdated(EmbedderMetadata metadata) override;
// PerformanceScenarioObserver:
void OnLoadingScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::LoadingScenario old_scenario,
performance_scenarios::LoadingScenario new_scenario) override;
void OnInputScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::InputScenario old_scenario,
performance_scenarios::InputScenario new_scenario) override;
private:
friend class SchedulingEmbedderPublic;
// A job consists of multiple passages, and each passage must have its
// embedding computed. When all are finished, the job is done and its
// callback will be invoked. Multiple jobs may be batched together when
// when submitting work to the `embedder_`, and jobs can also be broken
// submitting work to the `embedder_remote_proxy`, and jobs can also be broken
// down so that partial progress is made across multiple work submissions.
struct Job {
Job(PassagePriority priority,
@ -109,12 +91,24 @@ class SchedulingEmbedder
base::ElapsedTimer timer;
};
// EmbedderMetadataObserver:
void EmbedderMetadataUpdated(EmbedderMetadata metadata) override;
// PerformanceScenarioObserver:
void OnLoadingScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::LoadingScenario old_scenario,
performance_scenarios::LoadingScenario new_scenario) override;
void OnInputScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::InputScenario old_scenario,
performance_scenarios::InputScenario new_scenario) override;
// Invoked after the embedding for the current job has been computed.
// Continues processing next job if one is pending.
void OnEmbeddingsComputed(std::vector<std::string> passages,
std::vector<Embedding> embedding,
TaskId task_id,
ComputeEmbeddingsStatus status);
void OnEmbeddingsComputed(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status);
// Stable-sort jobs by priority and submit a batch of work to embedder.
// This will only submit new work if the embedder is not already working.
@ -142,11 +136,12 @@ class SchedulingEmbedder
// embedder work submissions may be required to complete a job.
bool work_submitted_ = false;
// The primary embedder that does the actual embedding computations.
// This may be slow, and we await results before sending the next request.
std::unique_ptr<Embedder> embedder_;
// The callback that does the actual embeddings computations.
// May be slow; await results before sending the next request.
GetEmbeddingsCallback get_embeddings_callback_;
// Starts empty; set when valid metadata is received from `embedder_`.
// Metadata about the embedder; set when valid metadata is received from
// `embedder_metadata_provider`.
EmbedderMetadata embedder_metadata_{0, 0};
// The maximum number of jobs to hold at once. Exceeding the cap
@ -164,31 +159,15 @@ class SchedulingEmbedder
base::ScopedObservation<
performance_scenarios::PerformanceScenarioObserverList,
SchedulingEmbedder>
performance_scenarios::PerformanceScenarioObserver>
performance_scenario_observation_{this};
base::ScopedObservation<EmbedderMetadataProvider, EmbedderMetadataObserver>
embedder_metadata_observation_{this};
base::WeakPtrFactory<SchedulingEmbedder> weak_ptr_factory_{this};
};
// This is a common use embedder type that simply routes its requests to
// a non-owned SchedulingEmbedder.
class SchedulingClientEmbedder : public Embedder {
public:
// `embedder` is stored as a raw_ptr and is not owned by this object.
explicit SchedulingClientEmbedder(SchedulingEmbedder* embedder);
~SchedulingClientEmbedder() override;
// Embedder:
// Both overrides forward to `scheduling_embedder_`.
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
bool TryCancel(TaskId task_id) override;
private:
// The non-owned SchedulingEmbedder that receives every request.
raw_ptr<SchedulingEmbedder> scheduling_embedder_;
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_SCHEDULING_EMBEDDER_H_
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_

@ -2,77 +2,108 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/scheduling_embedder.h"
#include "components/passage_embeddings/internal/scheduling_embedder.h"
#include <memory>
#include <tuple>
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/mock_embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
namespace {
using ComputePassagesEmbeddingsFuture =
base::test::TestFuture<std::vector<std::string>,
std::vector<Embedding>,
SchedulingEmbedder::TaskId,
ComputeEmbeddingsStatus>;
class MockEmbedderWithDelay : public MockEmbedder {
void GetEmbeddings(std::vector<std::string> passages,
PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(
[](std::vector<std::string> passages,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
std::vector<mojom::PassageEmbeddingsResultPtr> results;
for (const std::string& passage : passages) {
results.push_back(mojom::PassageEmbeddingsResult::New());
results.back()->embeddings =
std::vector<float>(kEmbeddingsModelOutputSize, 1.0);
results.back()->passage = passage;
}
std::move(callback).Run(std::move(results),
ComputeEmbeddingsStatus::kSuccess);
},
std::move(passages), std::move(callback)),
base::Seconds(1));
}
} // namespace
class SchedulingEmbedderPublic : public SchedulingEmbedder {
public:
static constexpr base::TimeDelta kTimeout = base::Seconds(1);
SchedulingEmbedderPublic(EmbedderMetadataProvider* embedder_metadata_provider,
GetEmbeddingsCallback get_embeddings_callback,
size_t max_jobs,
size_t scheduled_max_batch_size,
bool use_performance_scenario)
: SchedulingEmbedder(embedder_metadata_provider,
get_embeddings_callback,
max_jobs,
scheduled_max_batch_size,
use_performance_scenario) {}
MockEmbedderWithDelay() = default;
~MockEmbedderWithDelay() override = default;
// Embedder:
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override {
base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(std::move(callback), std::move(passages),
ComputeEmbeddingsForPassages(passages), kInvalidTaskId,
ComputeEmbeddingsStatus::kSuccess),
kTimeout);
return kInvalidTaskId;
}
using SchedulingEmbedder::embedder_metadata_;
using SchedulingEmbedder::GetEmbeddingsCallback;
using SchedulingEmbedder::GetEmbeddingsResultCallback;
};
class SchedulingEmbedderTest : public testing::Test {
protected:
std::unique_ptr<SchedulingEmbedder> MakeEmbedder() {
auto embedder = std::make_unique<SchedulingEmbedder>(
std::make_unique<MockEmbedderWithDelay>(), 4u, 1u, false);
embedder->EmbedderMetadataUpdated(EmbedderMetadata{1, 768});
return embedder;
public:
void SetUp() override {
embedder_metadata_provider_ =
std::make_unique<TestEmbedderMetadataProvider>();
embedder_ = std::make_unique<SchedulingEmbedderPublic>(
/*embedder_metadata_provider=*/embedder_metadata_provider_.get(),
/*get_embeddings_callback=*/base::BindRepeating(&GetEmbeddings),
/*max_jobs=*/4u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
ASSERT_TRUE(embedder_->embedder_metadata_.IsValid());
}
protected:
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
base::HistogramTester histogram_tester_;
std::unique_ptr<EmbedderMetadataProvider> embedder_metadata_provider_;
std::unique_ptr<SchedulingEmbedderPublic> embedder_;
};
TEST_F(SchedulingEmbedderTest, UserInitiatedJobTakesPriority) {
auto embedder = MakeEmbedder();
// Submit a passive priority task.
ComputePassagesEmbeddingsFuture future_1;
auto expected_task_id_1 = embedder->ComputePassagesEmbeddings(
auto expected_task_id_1 = embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1", "test passage 2"},
future_1.GetCallback());
// Submit a user-initiated priority task. This will suspend the partially
// completed passive priority task.
ComputePassagesEmbeddingsFuture future_2;
auto expected_task_id_2 = embedder->ComputePassagesEmbeddings(
auto expected_task_id_2 = embedder_->ComputePassagesEmbeddings(
PassagePriority::kUserInitiated, {"query"}, future_2.GetCallback());
// The user-initiated priority task finishes first.
@ -96,19 +127,17 @@ TEST_F(SchedulingEmbedderTest, UserInitiatedJobTakesPriority) {
}
TEST_F(SchedulingEmbedderTest, RecordsHistograms) {
auto embedder = MakeEmbedder();
ComputePassagesEmbeddingsFuture future1;
ComputePassagesEmbeddingsFuture future2;
ComputePassagesEmbeddingsFuture future3;
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1"}, future1.GetCallback());
auto task_id = embedder->ComputePassagesEmbeddings(
auto task_id = embedder_->ComputePassagesEmbeddings(
PassagePriority::kUserInitiated, {"test passage 2a", "test passage 2b"},
future2.GetCallback());
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 3"}, future3.GetCallback());
embedder->TryCancel(task_id);
embedder_->TryCancel(task_id);
EXPECT_TRUE(future1.Wait());
EXPECT_TRUE(future2.Wait());
EXPECT_TRUE(future3.Wait());
@ -156,22 +185,21 @@ TEST_F(SchedulingEmbedderTest, RecordsHistograms) {
}
TEST_F(SchedulingEmbedderTest, LimitsJobCount) {
auto embedder = MakeEmbedder();
ComputePassagesEmbeddingsFuture future1;
ComputePassagesEmbeddingsFuture future2;
ComputePassagesEmbeddingsFuture future3;
ComputePassagesEmbeddingsFuture future4;
ComputePassagesEmbeddingsFuture future5;
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1"}, future1.GetCallback());
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 2"}, future2.GetCallback());
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 3"}, future3.GetCallback());
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 4"}, future4.GetCallback());
embedder->ComputePassagesEmbeddings(
embedder_->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 5"}, future5.GetCallback());
// Final job interrupts the job at back of line when limit (4) is reached.

@ -1,47 +0,0 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/ml_embedder.h"
#include "base/task/sequenced_task_runner.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
// Stores `service_controller`, which ComputePassagesEmbeddings() uses to
// request embeddings.
MlEmbedder::MlEmbedder(PassageEmbeddingsServiceController* service_controller)
: service_controller_(service_controller) {}
MlEmbedder::~MlEmbedder() = default;
// Forwards `passages` and `priority` to the service controller and adapts its
// reply to the Embedder callback shape.
//
// The reply is a list of (passage, embeddings) results plus a status. Each
// result is split into the parallel `passages` / `embeddings` vectors that
// `callback` expects, and every embedding is normalized to unit length. The
// two vectors always have the same size: one entry per result received, so
// both are empty when the service returns no results.
//
// Always returns `kInvalidTaskId` and reports `kInvalidTaskId` to `callback`.
Embedder::TaskId MlEmbedder::ComputePassagesEmbeddings(
    PassagePriority priority,
    std::vector<std::string> passages,
    ComputePassagesEmbeddingsCallback callback) {
  service_controller_->GetEmbeddings(
      std::move(passages), priority,
      base::BindOnce(
          [](ComputePassagesEmbeddingsCallback callback,
             std::vector<mojom::PassageEmbeddingsResultPtr> results,
             ComputeEmbeddingsStatus status) {
            std::vector<std::string> result_passages;
            std::vector<Embedding> result_embeddings;
            result_passages.reserve(results.size());
            result_embeddings.reserve(results.size());
            for (auto& result : results) {
              // `results` is owned by this lambda and discarded once it
              // returns, so move the payloads out instead of copying each
              // string and float vector.
              result_passages.push_back(std::move(result->passage));
              result_embeddings.emplace_back(std::move(result->embeddings));
              result_embeddings.back().Normalize();
            }
            std::move(callback).Run(std::move(result_passages),
                                    std::move(result_embeddings),
                                    kInvalidTaskId, status);
          },
          std::move(callback)));
  return kInvalidTaskId;
}
// Never cancels anything: always returns false and ignores `task_id`.
bool MlEmbedder::TryCancel(TaskId task_id) {
return false;
}
} // namespace passage_embeddings

@ -1,37 +0,0 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_ML_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_ML_EMBEDDER_H_
#include "base/memory/raw_ptr.h"
#include "components/passage_embeddings/embedder.h"
namespace passage_embeddings {
class PassageEmbeddingsServiceController;
// An embedder that returns embeddings from a machine learning model.
// Requests are delegated to a PassageEmbeddingsServiceController.
class MlEmbedder : public Embedder {
public:
// `service_controller` is stored as a raw_ptr and is not owned by this object.
explicit MlEmbedder(PassageEmbeddingsServiceController* service_controller);
~MlEmbedder() override;
// Embedder:
// ComputePassagesEmbeddings() returns kInvalidTaskId and TryCancel() returns
// false in this implementation.
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
bool TryCancel(TaskId task_id) override;
private:
// The controller used to interact with the PassageEmbeddingsService.
// It is a singleton and guaranteed not to be nullptr.
raw_ptr<PassageEmbeddingsServiceController> service_controller_;
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_ML_EMBEDDER_H_

@ -2,19 +2,19 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/ml_embedder.h"
#include <memory>
#include "base/memory/raw_ptr.h"
#include "base/path_service.h"
#include "base/scoped_observation.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
@ -24,46 +24,12 @@ namespace passage_embeddings {
namespace {
constexpr int64_t kEmbeddingsModelVersion = 1l;
constexpr uint32_t kEmbeddingsModelInputWindowSize = 256u;
constexpr size_t kEmbeddingsModelOutputSize = 768ul;
using ComputePassagesEmbeddingsFuture =
base::test::TestFuture<std::vector<std::string>,
std::vector<Embedding>,
Embedder::TaskId,
ComputeEmbeddingsStatus>;
// Creates a TestModelInfoBuilder describing a valid passage-embeddings model:
// version `kEmbeddingsModelVersion`, metadata carrying the input window size
// and output size constants, and placeholder model files.
optimization_guide::TestModelInfoBuilder GetBuilderWithValidModelInfo() {
  // Resolve <test data root>/components/test/data/passage_embeddings.
  base::FilePath data_root;
  base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &data_root);
  const base::FilePath model_dir = data_root.AppendASCII("components")
                                       .AppendASCII("test")
                                       .AppendASCII("data")
                                       .AppendASCII("passage_embeddings");

  // One placeholder file serves as both the model file and the additional
  // file. It exists only to satisfy the mojo run-time check for null
  // arguments; the fake embedder never reads it.
  const base::FilePath fake_model_file =
      model_dir.AppendASCII("fake_model_file");

  // Metadata proto that gets wrapped into the model info.
  optimization_guide::proto::PassageEmbeddingsModelMetadata metadata_proto;
  metadata_proto.set_input_window_size(kEmbeddingsModelInputWindowSize);
  metadata_proto.set_output_size(kEmbeddingsModelOutputSize);

  optimization_guide::TestModelInfoBuilder model_info_builder;
  model_info_builder.SetModelFilePath(fake_model_file);
  model_info_builder.SetAdditionalFiles({fake_model_file});
  model_info_builder.SetVersion(kEmbeddingsModelVersion);
  model_info_builder.SetModelMetadata(
      optimization_guide::AnyWrapProto(metadata_proto));
  return model_info_builder;
}
class FakePassageEmbedder : public mojom::PassageEmbedder {
public:
explicit FakePassageEmbedder(
@ -75,23 +41,18 @@ class FakePassageEmbedder : public mojom::PassageEmbedder {
void GenerateEmbeddings(const std::vector<std::string>& inputs,
mojom::PassagePriority priority,
GenerateEmbeddingsCallback callback) override {
std::vector<std::string> passages = inputs;
std::vector<Embedding> embeddings;
std::vector<mojom::PassageEmbeddingsResultPtr> results;
for (const std::string& input : inputs) {
// Fails the generation on an "error" string to simulate failed model
// execution.
// Fail embeddings generation for the entire batch when an "error" string
// is encountered, to simulate failed model execution.
if (input == "error") {
results.clear();
break;
return std::move(callback).Run({});
}
results.push_back(mojom::PassageEmbeddingsResult::New());
results.back()->embeddings =
std::vector<float>(kEmbeddingsModelOutputSize, 1.0);
results.back()->passage = input;
}
std::move(callback).Run(std::move(results));
}
@ -134,6 +95,7 @@ class FakePassageEmbeddingsServiceController
service_remote_.BindNewPipeAndPassReceiver());
}
using PassageEmbeddingsServiceController::GetEmbeddingsCallback;
using PassageEmbeddingsServiceController::ResetEmbedderRemote;
void ResetServiceRemote() override {
@ -141,41 +103,58 @@ class FakePassageEmbeddingsServiceController
service_remote_.reset();
}
using PassageEmbeddingsServiceController::GetEmbeddings;
private:
std::unique_ptr<FakePassageEmbeddingsService> service_;
};
class FakeMlEmbedder : public MlEmbedder, public EmbedderMetadataObserver {
class FakeEmbedder : public TestEmbedder, public EmbedderMetadataObserver {
public:
explicit FakeMlEmbedder(
PassageEmbeddingsServiceController* service_controller)
: MlEmbedder(service_controller) {
embedder_metadata_observation_.Observe(service_controller);
explicit FakeEmbedder(
EmbedderMetadataProvider* embedder_metadata_provider,
FakePassageEmbeddingsServiceController::GetEmbeddingsCallback
get_embeddings_callback,
base::test::TestFuture<EmbedderMetadata>* embedder_metadata_future)
: get_embeddings_callback_(get_embeddings_callback),
embedder_metadata_future_(embedder_metadata_future) {
embedder_metadata_observation_.Observe(embedder_metadata_provider);
}
using OnEmbedderReadyCallback = base::OnceCallback<void(EmbedderMetadata)>;
void SetOnEmbedderReadyCallback(OnEmbedderReadyCallback callback) {
callback_ = std::move(callback);
if (callback_ && metadata_.IsValid()) {
std::move(callback_).Run(metadata_);
}
// Embedder:
Embedder::TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override {
get_embeddings_callback_.Run(
passages, priority,
base::BindOnce(
[](std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback,
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status) {
std::vector<Embedding> embeddings;
if (status == ComputeEmbeddingsStatus::kSuccess) {
embeddings = ComputeEmbeddingsForPassages(passages);
}
std::move(callback).Run(passages, embeddings, kInvalidTaskId,
status);
},
passages, std::move(callback)));
return kInvalidTaskId;
}
protected:
// EmbedderMetadataObserver:
void EmbedderMetadataUpdated(
passage_embeddings::EmbedderMetadata metadata) override {
metadata_ = metadata;
if (callback_) {
std::move(callback_).Run(metadata_);
}
void EmbedderMetadataUpdated(EmbedderMetadata metadata) override {
embedder_metadata_future_->SetValue(metadata);
}
EmbedderMetadata metadata_{0, 0};
OnEmbedderReadyCallback callback_;
base::ScopedObservation<PassageEmbeddingsServiceController,
EmbedderMetadataObserver>
base::ScopedObservation<EmbedderMetadataProvider, EmbedderMetadataObserver>
embedder_metadata_observation_{this};
FakePassageEmbeddingsServiceController::GetEmbeddingsCallback
get_embeddings_callback_;
raw_ptr<base::test::TestFuture<EmbedderMetadata>> embedder_metadata_future_;
};
} // namespace
@ -185,72 +164,63 @@ class MlEmbedderTest : public testing::Test {
void SetUp() override {
service_controller_ =
std::make_unique<FakePassageEmbeddingsServiceController>();
service_controller_->SetEmbedderForTesting(std::make_unique<FakeEmbedder>(
/*embedder_metadata_provider=*/service_controller_.get(),
/*get_embeddings_callback=*/
base::BindRepeating(
&FakePassageEmbeddingsServiceController::GetEmbeddings,
base::Unretained(service_controller_.get())),
/*embedder_metadata_future=*/embedder_metadata_future()));
EXPECT_FALSE(embedder_metadata_future()->IsReady());
}
void TearDown() override {}
protected:
base::test::TestFuture<EmbedderMetadata>* embedder_metadata_future() {
return &embedder_metadata_future_;
}
Embedder* embedder() { return service_controller_->GetEmbedder(); }
base::test::TaskEnvironment task_environment_;
base::HistogramTester histogram_tester_;
base::test::TestFuture<EmbedderMetadata> embedder_metadata_future_;
std::unique_ptr<FakePassageEmbeddingsServiceController> service_controller_;
};
TEST_F(MlEmbedderTest, ReceivesValidModelInfo) {
service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build());
EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build()));
auto metadata = embedder_metadata_future()->Take();
EXPECT_TRUE(metadata.IsValid());
EXPECT_EQ(metadata.model_version, kEmbeddingsModelVersion);
EXPECT_EQ(metadata.output_size, kEmbeddingsModelOutputSize);
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
bool on_embedder_ready_invoked = false;
ml_embedder->SetOnEmbedderReadyCallback(
base::BindLambdaForTesting([&](EmbedderMetadata metadata) {
EXPECT_EQ(metadata.model_version, kEmbeddingsModelVersion);
EXPECT_EQ(metadata.output_size, kEmbeddingsModelOutputSize);
on_embedder_ready_invoked = true;
}));
EXPECT_TRUE(on_embedder_ready_invoked);
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(kModelInfoMetricName,
EmbeddingsModelInfoStatus::kValid, 1);
}
TEST_F(MlEmbedderTest, ReceivesEmptyModelInfo) {
service_controller_->MaybeUpdateModelInfo({});
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
bool on_embedder_ready_invoked = false;
EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo({}));
EXPECT_FALSE(embedder_metadata_future()->IsReady());
ml_embedder->SetOnEmbedderReadyCallback(base::BindLambdaForTesting(
[&](EmbedderMetadata metadata) { on_embedder_ready_invoked = true; }));
EXPECT_FALSE(on_embedder_ready_invoked);
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(kModelInfoMetricName,
EmbeddingsModelInfoStatus::kEmpty, 1);
}
TEST_F(MlEmbedderTest, ReceivesModelInfoWithInvalidModelMetadata) {
// Make some invalid metadata.
optimization_guide::proto::Any metadata_any;
metadata_any.set_type_url("not a valid type url");
metadata_any.set_value("not a valid serialized metadata");
optimization_guide::TestModelInfoBuilder builder =
GetBuilderWithValidModelInfo();
builder.SetModelMetadata(metadata_any);
service_controller_->MaybeUpdateModelInfo(*builder.Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
bool on_embedder_ready_invoked = false;
EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
EXPECT_FALSE(embedder_metadata_future()->IsReady());
ml_embedder->SetOnEmbedderReadyCallback(base::BindLambdaForTesting(
[&](EmbedderMetadata metadata) { on_embedder_ready_invoked = true; }));
EXPECT_FALSE(on_embedder_ready_invoked);
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(
kModelInfoMetricName, EmbeddingsModelInfoStatus::kInvalidMetadata, 1);
@ -261,15 +231,9 @@ TEST_F(MlEmbedderTest, ReceivesModelInfoWithoutModelMetadata) {
GetBuilderWithValidModelInfo();
builder.SetModelMetadata(std::nullopt);
service_controller_->MaybeUpdateModelInfo(*builder.Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
bool on_embedder_ready_invoked = false;
EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
EXPECT_FALSE(embedder_metadata_future()->IsReady());
ml_embedder->SetOnEmbedderReadyCallback(base::BindLambdaForTesting(
[&](EmbedderMetadata metadata) { on_embedder_ready_invoked = true; }));
EXPECT_FALSE(on_embedder_ready_invoked);
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(
kModelInfoMetricName, EmbeddingsModelInfoStatus::kNoMetadata, 1);
@ -283,35 +247,36 @@ TEST_F(MlEmbedderTest, ReceivesModelInfoWithoutAdditionalFiles) {
builder.SetAdditionalFiles(
{test_data_dir.AppendASCII("foo"), test_data_dir.AppendASCII("bar")});
service_controller_->MaybeUpdateModelInfo(*builder.Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
bool on_embedder_ready_invoked = false;
EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
EXPECT_FALSE(embedder_metadata_future()->IsReady());
ml_embedder->SetOnEmbedderReadyCallback(base::BindLambdaForTesting(
[&](EmbedderMetadata metadata) { on_embedder_ready_invoked = true; }));
EXPECT_FALSE(on_embedder_ready_invoked);
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(
kModelInfoMetricName, EmbeddingsModelInfoStatus::kInvalidAdditionalFiles,
1);
}
TEST_F(MlEmbedderTest, ReturnsEmbeddings) {
service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(kModelInfoMetricName,
EmbeddingsModelInfoStatus::kValid, 1);
TEST_F(MlEmbedderTest, ReceivesEmptyPassages) {
EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build()));
ComputePassagesEmbeddingsFuture future;
ml_embedder->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"foo", "bar"}, future.GetCallback());
embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive, {},
future.GetCallback());
auto [passages, embeddings, task_id, status] = future.Get();
EXPECT_EQ(status, ComputeEmbeddingsStatus::kSuccess);
EXPECT_TRUE(passages.empty());
EXPECT_TRUE(embeddings.empty());
}
TEST_F(MlEmbedderTest, ReturnsEmbeddings) {
EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build()));
ComputePassagesEmbeddingsFuture future;
embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"foo", "bar"}, future.GetCallback());
auto [passages, embeddings, task_id, status] = future.Get();
EXPECT_EQ(status, ComputeEmbeddingsStatus::kSuccess);
@ -326,54 +291,46 @@ TEST_F(MlEmbedderTest, ReturnsModelUnavailableErrorIfModelInfoNotValid) {
GetBuilderWithValidModelInfo();
builder.SetModelMetadata(std::nullopt);
service_controller_->MaybeUpdateModelInfo(*builder.Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
ComputePassagesEmbeddingsFuture future;
ml_embedder->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"foo", "bar"}, future.GetCallback());
embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"foo", "bar"}, future.GetCallback());
auto [passages, embeddings, task_id, status] = future.Get();
EXPECT_EQ(status, ComputeEmbeddingsStatus::kModelUnavailable);
EXPECT_TRUE(passages.empty());
EXPECT_EQ(passages[0], "foo");
EXPECT_EQ(passages[1], "bar");
EXPECT_TRUE(embeddings.empty());
histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
histogram_tester_.ExpectUniqueSample(
kModelInfoMetricName, EmbeddingsModelInfoStatus::kNoMetadata, 1);
}
TEST_F(MlEmbedderTest, ReturnsExecutionFailure) {
service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build()));
ComputePassagesEmbeddingsFuture future;
ml_embedder->ComputePassagesEmbeddings(PassagePriority::kPassive, {"error"},
future.GetCallback());
embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"error", "baz"}, future.GetCallback());
auto [passages, embeddings, task_id, status] = future.Get();
EXPECT_EQ(status, ComputeEmbeddingsStatus::kExecutionFailure);
EXPECT_TRUE(passages.empty());
EXPECT_EQ(passages[0], "error");
EXPECT_EQ(passages[1], "baz");
EXPECT_TRUE(embeddings.empty());
}
TEST_F(MlEmbedderTest, EmbedderRunningStatus) {
service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build());
auto ml_embedder =
std::make_unique<FakeMlEmbedder>(service_controller_.get());
EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
*GetBuilderWithValidModelInfo().Build()));
{
ComputePassagesEmbeddingsFuture future1;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"foo", "bar"}, future1.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
ComputePassagesEmbeddingsFuture future2;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"baz", "qux"}, future2.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
@ -390,13 +347,13 @@ TEST_F(MlEmbedderTest, EmbedderRunningStatus) {
}
{
ComputePassagesEmbeddingsFuture future1;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"foo", "bar"}, future1.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
ComputePassagesEmbeddingsFuture future2;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"baz", "qux"}, future2.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
@ -414,13 +371,13 @@ TEST_F(MlEmbedderTest, EmbedderRunningStatus) {
{
// Calling `ComputePassagesEmbeddings()` again launches the service.
ComputePassagesEmbeddingsFuture future1;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"foo", "bar"}, future1.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
ComputePassagesEmbeddingsFuture future2;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"baz", "qux"}, future2.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
@ -437,13 +394,13 @@ TEST_F(MlEmbedderTest, EmbedderRunningStatus) {
}
{
ComputePassagesEmbeddingsFuture future1;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"foo", "bar"}, future1.GetCallback());
// Embedder is running.
EXPECT_TRUE(service_controller_->EmbedderRunning());
ComputePassagesEmbeddingsFuture future2;
ml_embedder->ComputePassagesEmbeddings(
embedder()->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"baz", "qux"}, future2.GetCallback());
// Embedder is still running.
EXPECT_TRUE(service_controller_->EmbedderRunning());

@ -1,57 +0,0 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/mock_embedder.h"
#include "base/task/sequenced_task_runner.h"
namespace passage_embeddings {
namespace {
constexpr int64_t kModelVersion = 1;
constexpr size_t kOutputSize = 768ul;
constexpr size_t kMockPassageWordCount = 10;
// Produces the canned embedding the mock returns for any passage: a
// `kOutputSize`-element vector of 1.0f that is normalized and then tagged with
// `kMockPassageWordCount`. The `passage` argument is not inspected.
Embedding ComputeEmbeddingForPassage(const std::string& passage) {
  std::vector<float> ones(kOutputSize, 1.0f);
  Embedding canned(std::move(ones));
  canned.Normalize();
  canned.SetPassageWordCount(kMockPassageWordCount);
  return canned;
}
} // namespace
MockEmbedder::MockEmbedder() = default;
MockEmbedder::~MockEmbedder() = default;
// Replies asynchronously with one canned embedding per entry in `passages`.
// The reply is posted to the current default sequenced task runner and always
// carries `ComputeEmbeddingsStatus::kSuccess` and `kInvalidTaskId`.
// `priority` is ignored. Returns `kInvalidTaskId`.
Embedder::TaskId MockEmbedder::ComputePassagesEmbeddings(
    PassagePriority priority,
    std::vector<std::string> passages,
    ComputePassagesEmbeddingsCallback callback) {
  // Compute the embeddings before `passages` is handed to BindOnce. Reading
  // `passages` in the same argument list that applies std::move() to it is
  // only correct while the callee defers the actual move; doing the read
  // first removes that dependency.
  std::vector<Embedding> embeddings = ComputeEmbeddingsForPassages(passages);
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(std::move(callback), std::move(passages),
                     std::move(embeddings), kInvalidTaskId,
                     ComputeEmbeddingsStatus::kSuccess));
  return kInvalidTaskId;
}
// Always returns false, whatever `task_id` is passed.
bool MockEmbedder::TryCancel(TaskId task_id) {
return false;
}
// Runs `callback` synchronously with metadata built from the file-local
// `kModelVersion` and `kOutputSize` constants; the callback is never stored.
void MockEmbedder::SetOnEmbedderReadyCallback(
    OnEmbedderReadyCallback callback) {
  // The mock has nothing to load, so it reports readiness right away.
  EmbedderMetadata mock_metadata{kModelVersion, kOutputSize};
  std::move(callback).Run(mock_metadata);
}
std::vector<Embedding> MockEmbedder::ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages) {
return std::vector<Embedding>(passages.size(),
ComputeEmbeddingForPassage(""));
}
} // namespace passage_embeddings

@ -1,38 +0,0 @@
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_MOCK_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_MOCK_EMBEDDER_H_
#include <string>
#include <vector>
#include "components/passage_embeddings/embedder.h"
namespace passage_embeddings {
// An Embedder that answers with fixed data instead of running a model.
class MockEmbedder : public Embedder {
public:
MockEmbedder();
~MockEmbedder() override;
// Embedder:
// Posts `callback` to the current sequence with `passages`, one canned
// embedding per passage, `kInvalidTaskId` and a success status. Returns
// `kInvalidTaskId`.
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
// Always returns false.
bool TryCancel(TaskId task_id) override;
// Callback type that receives the embedder's metadata.
using OnEmbedderReadyCallback = base::OnceCallback<void(EmbedderMetadata)>;
// Invokes `callback` immediately with the mock's fixed metadata.
void SetOnEmbedderReadyCallback(OnEmbedderReadyCallback callback);
protected:
// Returns one canned embedding for each element of `passages`.
std::vector<Embedding> ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages);
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_MOCK_EMBEDDER_H_

@ -6,15 +6,11 @@
#include <memory>
#include "base/logging.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/memory/raw_ptr.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/mock_embedder.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
@ -22,61 +18,108 @@ namespace passage_embeddings {
// Service controller double that records, in a caller-owned TestFuture<bool>,
// whether each MaybeUpdateModelInfo() call was given a model info. Service
// launch and remote reset are stubbed out as no-ops.
class FakePassageEmbeddingsServiceController
: public passage_embeddings::PassageEmbeddingsServiceController {
public:
// NOTE(review): this constructor never assigns `model_info_received_future_`,
// which MaybeUpdateModelInfo() dereferences unconditionally. It looks like a
// leftover next to the explicit constructor below - confirm it can go.
FakePassageEmbeddingsServiceController() = default;
// `model_info_future` is stored as a non-owning pointer; presumably it must
// outlive this object - verify against callers.
explicit FakePassageEmbeddingsServiceController(
base::test::TestFuture<bool>* model_info_future)
: model_info_received_future_(model_info_future) {}
~FakePassageEmbeddingsServiceController() override = default;
// passage_embeddings::PassageEmbeddingsServiceController:
// Sets the future to `model_info.has_value()` and returns that same flag; the
// contents of the model info are not examined.
bool MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info)
override {
const bool received_model_info = model_info.has_value();
model_info_received_future_->SetValue(received_model_info);
return received_model_info;
}
// Intentionally empty overrides.
void MaybeLaunchService() override {}
void ResetServiceRemote() override {}
protected:
// Not owned. Receives one value per MaybeUpdateModelInfo() call.
raw_ptr<base::test::TestFuture<bool>> model_info_received_future_;
};
class TestOptimizationGuideModelProvider
: public optimization_guide::TestOptimizationGuideModelProvider {
public:
explicit TestOptimizationGuideModelProvider(
base::test::TestFuture<bool>* target_observed_future)
: target_observed_future_(target_observed_future) {}
// optimization_guide::OptimizationGuideModelProvider:
void AddObserverForOptimizationTargetModel(
optimization_guide::proto::OptimizationTarget optimization_target,
const std::optional<optimization_guide::proto::Any>& model_metadata,
optimization_guide::OptimizationTargetModelObserver* observer) override {
if (optimization_target ==
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER) {
passage_embedder_target_registered_ = true;
}
if (!model_info_) {
observer->OnModelUpdated(
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
std::nullopt);
} else {
observer->OnModelUpdated(
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
*model_info_);
}
target_observed_future_->SetValue(
optimization_target ==
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER);
observer_list_.AddObserver(observer);
NotifyObservers();
}
bool passage_embedder_target_registered() const {
return passage_embedder_target_registered_;
void RemoveObserverForOptimizationTargetModel(
optimization_guide::proto::OptimizationTarget optimization_target,
optimization_guide::OptimizationTargetModelObserver* observer) override {
observer_list_.RemoveObserver(observer);
}
// Set the model info to be sent to the observer.
void SetModelInfo(std::unique_ptr<optimization_guide::ModelInfo> model_info) {
model_info_ = std::move(model_info);
NotifyObservers();
}
private:
bool passage_embedder_target_registered_ = false;
void NotifyObservers() {
if (model_info_) {
observer_list_.Notify(
&optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
*model_info_);
} else {
observer_list_.Notify(
&optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
std::nullopt);
}
}
raw_ptr<base::test::TestFuture<bool>> target_observed_future_;
base::ObserverList<optimization_guide::OptimizationTargetModelObserver>
observer_list_;
std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};
class PassageEmbedderModelObserverTest : public testing::Test {};
class PassageEmbedderModelObserverTest : public testing::Test {
protected:
base::test::TestFuture<bool> target_observed_future_;
base::test::TestFuture<bool> model_info_received_future_;
};
TEST_F(PassageEmbedderModelObserverTest, ObservesTargetAndNotifiesObserver) {
auto model_provider = std::make_unique<TestOptimizationGuideModelProvider>(
&target_observed_future_);
EXPECT_FALSE(target_observed_future_.IsReady());
TEST_F(PassageEmbedderModelObserverTest, ObservesTarget) {
auto model_provider = std::make_unique<TestOptimizationGuideModelProvider>();
auto service_controller =
std::make_unique<FakePassageEmbeddingsServiceController>();
std::make_unique<FakePassageEmbeddingsServiceController>(
&model_info_received_future_);
EXPECT_FALSE(model_info_received_future_.IsReady());
EXPECT_FALSE(model_provider->passage_embedder_target_registered());
auto passage_embedder_model_observer =
std::make_unique<PassageEmbedderModelObserver>(
model_provider.get(), service_controller.get(), false);
EXPECT_TRUE(model_provider->passage_embedder_target_registered());
EXPECT_TRUE(target_observed_future_.IsReady());
EXPECT_TRUE(target_observed_future_.Take());
EXPECT_TRUE(model_info_received_future_.IsReady());
EXPECT_FALSE(model_info_received_future_.Take());
model_provider->SetModelInfo(GetBuilderWithValidModelInfo().Build());
EXPECT_TRUE(model_info_received_future_.IsReady());
EXPECT_TRUE(model_info_received_future_.Take());
}
} // namespace passage_embeddings

@ -6,16 +6,17 @@
#include <ranges>
#include "base/functional/bind.h"
#include "base/metrics/histogram_functions.h"
#include "base/not_fatal_until.h"
#include "base/notreached.h"
#include "base/task/thread_pool.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/passage_embeddings/ml_embedder.h"
#include "components/passage_embeddings/internal/scheduling_embedder.h"
#include "components/passage_embeddings/passage_embeddings_features.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom-shared.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
@ -72,17 +73,18 @@ class ScopedEmbeddingsModelInfoStatusLogger {
} // namespace
PassageEmbeddingsServiceController::PassageEmbeddingsServiceController()
: scheduling_embedder_(std::make_unique<SchedulingEmbedder>(
std::make_unique<MlEmbedder>(this),
: embedder_(std::make_unique<SchedulingEmbedder>(
/*embedder_metadata_provider=*/this,
/*get_embeddings_callback=*/
base::BindRepeating(
&PassageEmbeddingsServiceController::GetEmbeddings,
base::Unretained(this)),
kSchedulerMaxJobs.Get(),
kSchedulerMaxBatchSize.Get(),
kUsePerformanceScenario.Get())) {
AddObserver(scheduling_embedder_.get());
}
kUsePerformanceScenario.Get())) {}
PassageEmbeddingsServiceController::~PassageEmbeddingsServiceController() {
RemoveObserver(scheduling_embedder_.get());
}
PassageEmbeddingsServiceController::~PassageEmbeddingsServiceController() =
default;
bool PassageEmbeddingsServiceController::MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info) {
@ -164,10 +166,13 @@ void PassageEmbeddingsServiceController::OnLoadModelsResult(bool success) {
}
}
std::unique_ptr<Embedder> PassageEmbeddingsServiceController::MakeEmbedder() {
auto client =
std::make_unique<SchedulingClientEmbedder>(scheduling_embedder_.get());
return client;
Embedder* PassageEmbeddingsServiceController::GetEmbedder() {
return embedder_.get();
}
void PassageEmbeddingsServiceController::SetEmbedderForTesting(
std::unique_ptr<Embedder> embedder) {
embedder_ = std::move(embedder);
}
void PassageEmbeddingsServiceController::AddObserver(
@ -186,7 +191,12 @@ void PassageEmbeddingsServiceController::RemoveObserver(
void PassageEmbeddingsServiceController::GetEmbeddings(
std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsCallback callback) {
GetEmbeddingsResultCallback callback) {
if (passages.empty()) {
std::move(callback).Run({}, ComputeEmbeddingsStatus::kSuccess);
return;
}
if (!EmbedderReady()) {
VLOG(1) << "Missing model path: embeddings='" << embeddings_model_path_
<< "'; sp='" << sp_model_path_ << "'";
@ -250,7 +260,7 @@ void PassageEmbeddingsServiceController::ResetEmbedderRemote() {
void PassageEmbeddingsServiceController::OnGotEmbeddings(
RequestId request_id,
GetEmbeddingsCallback callback,
GetEmbeddingsResultCallback callback,
std::vector<mojom::PassageEmbeddingsResultPtr> results) {
// Mojo invokes the callbacks in the order in which `GenerateEmbeddings()` was
// called. Therefore, `request_id` should be expected at the front of

@ -4,61 +4,60 @@
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#include <memory>
#include <vector>
#include "base/callback_list.h"
#include "base/observer_list.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/proto/passage_embeddings_model_metadata.pb.h"
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "components/passage_embeddings/scheduling_embedder.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
class PassageEmbeddingsServiceController {
class PassageEmbeddingsServiceController : public EmbedderMetadataProvider {
public:
PassageEmbeddingsServiceController();
virtual ~PassageEmbeddingsServiceController();
~PassageEmbeddingsServiceController() override;
// Updates the paths and the metadata needed for executing the passage
// embeddings model. The original paths and metadata will be erased regardless
// of the validity of the new model paths. Returns true if the given paths are
// valid.
bool MaybeUpdateModelInfo(
// of the validity of the new model paths.
// Returns true and notifies the observers if the given paths are valid.
// Virtual for testing.
virtual bool MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info);
// Returns true if the embedder is currently running.
bool EmbedderRunning();
// Returns an embedder that can be used to generate passage embeddings.
std::unique_ptr<Embedder> MakeEmbedder();
// Returns the embedder used to generate embeddings.
Embedder* GetEmbedder();
// Subscribe for notification when embedder metadata is ready. This may
// result in immediate notification if metadata is ready at time of call.
void AddObserver(EmbedderMetadataObserver* observer);
// Must be called exactly once for each corresponding call to
// `AddEmbedderMetadataObserver` when observation is no longer needed.
void RemoveObserver(EmbedderMetadataObserver* observer);
void SetEmbedderForTesting(std::unique_ptr<Embedder> embedder);
protected:
// Embedders are the way to access the `GetEmbeddings` API. Protecting it from
// general use avoids bare access calls that would interrupt scheduled tasks.
friend class MlEmbedder;
// EmbedderMetadataProvider:
void AddObserver(EmbedderMetadataObserver* observer) override;
void RemoveObserver(EmbedderMetadataObserver* observer) override;
// Starts the service and calls `callback` with the embeddings. It is
// guaranteed that the result will have the same number of elements as
// `passages` when all embeddings executions succeed. Otherwise, will return
// an empty vector.
using GetEmbeddingsCallback = base::OnceCallback<void(
// Computes embeddings for each entry in `passages`. Will invoke `callback`
// when done. If successful, it is guaranteed that `results` will have the
// same number of passages and embeddings and in the same order as
// `passages`. Otherwise `results` will have empty passages and embeddings.
using GetEmbeddingsResultCallback = base::OnceCallback<void(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status)>;
using GetEmbeddingsCallback =
base::RepeatingCallback<void(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback)>;
void GetEmbeddings(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsCallback callback);
GetEmbeddingsResultCallback callback);
// Returns true if this service controller is ready for embeddings generation.
bool EmbedderReady();
@ -97,7 +96,7 @@ class PassageEmbeddingsServiceController {
// Called when an attempt to generate embeddings finishes.
void OnGotEmbeddings(RequestId request_id,
GetEmbeddingsCallback callback,
GetEmbeddingsResultCallback callback,
std::vector<mojom::PassageEmbeddingsResultPtr> results);
// Version of the embeddings model.
@ -118,10 +117,10 @@ class PassageEmbeddingsServiceController {
// Notifies embedders that model metadata updated.
base::ObserverList<EmbedderMetadataObserver> observer_list_;
// This holds the main scheduler that receives requests from multiple separate
// client embedders, prioritizes all the jobs, and ultimately submits batches
// of work via `GetEmbeddings` when the time is right.
std::unique_ptr<SchedulingEmbedder> scheduling_embedder_;
// This holds the main scheduler that receives requests from multiple clients,
// prioritizes all the jobs, and ultimately submits batches of work via
// `GetEmbeddings` when the time is right.
std::unique_ptr<Embedder> embedder_;
// Used to generate weak pointers to self.
base::WeakPtrFactory<PassageEmbeddingsServiceController> weak_ptr_factory_{

@ -0,0 +1,113 @@
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "base/path_service.h"
#include "base/task/sequenced_task_runner.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/passage_embeddings_model_metadata.pb.h"
namespace passage_embeddings {
namespace {
inline constexpr uint32_t kEmbeddingsModelInputWindowSize = 256u;
// Builds a placeholder embedding of `embeddings_model_output_size` floats, all
// set to 1.0f before normalization, tagged with a fixed passage word count.
Embedding ComputeEmbeddingForPassage(size_t embeddings_model_output_size) {
  constexpr size_t kMockPassageWordCount = 10;
  std::vector<float> ones(embeddings_model_output_size, 1.0f);
  Embedding placeholder(std::move(ones));
  placeholder.Normalize();
  placeholder.SetPassageWordCount(kMockPassageWordCount);
  return placeholder;
}
// Returns metadata built from the shared test constants
// `kEmbeddingsModelVersion` and `kEmbeddingsModelOutputSize`.
EmbedderMetadata GetValidEmbedderMetadata() {
  EmbedderMetadata metadata(kEmbeddingsModelVersion,
                            kEmbeddingsModelOutputSize);
  return metadata;
}
} // namespace
// Returns a TestModelInfoBuilder pre-populated with a model file path, one
// additional file, the test model version, and serialized passage embeddings
// metadata (input window size and output size).
optimization_guide::TestModelInfoBuilder GetBuilderWithValidModelInfo() {
  // Locate <test data root>/components/test/data/passage_embeddings.
  // NOTE(review): the result of PathService::Get() is not checked.
  base::FilePath data_root;
  base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &data_root);
  const base::FilePath model_dir = data_root.AppendASCII("components")
                                       .AppendASCII("test")
                                       .AppendASCII("data")
                                       .AppendASCII("passage_embeddings");

  // Both paths point at the same placeholder file. It only needs to exist so
  // that the mojo run-time check for null arguments passes; the fake embedder
  // never reads it.
  const base::FilePath model_file = model_dir.AppendASCII("fake_model_file");
  const base::FilePath sentencepiece_file =
      model_dir.AppendASCII("fake_model_file");

  // Describe the model's input window and output dimensions.
  optimization_guide::proto::PassageEmbeddingsModelMetadata metadata_proto;
  metadata_proto.set_input_window_size(kEmbeddingsModelInputWindowSize);
  metadata_proto.set_output_size(kEmbeddingsModelOutputSize);

  // Assemble the builder from the pieces above.
  optimization_guide::TestModelInfoBuilder model_info_builder;
  model_info_builder.SetModelFilePath(model_file);
  model_info_builder.SetAdditionalFiles({sentencepiece_file});
  model_info_builder.SetVersion(kEmbeddingsModelVersion);
  model_info_builder.SetModelMetadata(
      optimization_guide::AnyWrapProto(metadata_proto));
  return model_info_builder;
}
std::vector<Embedding> ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages) {
return std::vector<Embedding>(
passages.size(), ComputeEmbeddingForPassage(kEmbeddingsModelOutputSize));
}
////////////////////////////////////////////////////////////////////////////////
// Posts a task to the current default sequenced task runner that invokes
// `callback` with the given passages, one mock embedding per passage,
// kInvalidTaskId, and a success status. `priority` is not used. Returns
// kInvalidTaskId immediately.
Embedder::TaskId TestEmbedder::ComputePassagesEmbeddings(
    PassagePriority priority,
    std::vector<std::string> passages,
    ComputePassagesEmbeddingsCallback callback) {
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE,
      base::BindOnce(
          [](std::vector<std::string> passages,
             ComputePassagesEmbeddingsCallback callback) {
            // Compute the embeddings before handing `passages` off: argument
            // evaluation order is unspecified, so the vector must not be
            // moved in the same call expression that reads it.
            std::vector<Embedding> embeddings =
                ComputeEmbeddingsForPassages(passages);
            std::move(callback).Run(std::move(passages), std::move(embeddings),
                                    kInvalidTaskId,
                                    ComputeEmbeddingsStatus::kSuccess);
          },
          // `passages` is owned by value here, so move it into the bound
          // state instead of copying the whole vector.
          std::move(passages), std::move(callback)));
  return kInvalidTaskId;
}
// Always returns false; `task_id` is ignored and nothing is canceled.
bool TestEmbedder::TryCancel(TaskId task_id) {
return false;
}
////////////////////////////////////////////////////////////////////////////////
// Constructor and destructor are defaulted here, out of line from the class
// declaration.
TestEmbedderMetadataProvider::TestEmbedderMetadataProvider() = default;
TestEmbedderMetadataProvider::~TestEmbedderMetadataProvider() = default;
// Notifies `observer` synchronously with valid embedder metadata, then
// registers it in the observer list.
void TestEmbedderMetadataProvider::AddObserver(
    EmbedderMetadataObserver* observer) {
  EmbedderMetadata valid_metadata = GetValidEmbedderMetadata();
  observer->EmbedderMetadataUpdated(valid_metadata);
  observer_list_.AddObserver(observer);
}
// Removes `observer` from the observer list.
void TestEmbedderMetadataProvider::RemoveObserver(
EmbedderMetadataObserver* observer) {
observer_list_.RemoveObserver(observer);
}
////////////////////////////////////////////////////////////////////////////////
// Creates the owned test helpers: a TestEmbedder and a
// TestEmbedderMetadataProvider.
TestEnvironment::TestEnvironment()
: embedder_(std::make_unique<TestEmbedder>()),
embedder_metadata_provider_(
std::make_unique<TestEmbedderMetadataProvider>()) {}
// Defaulted out of line; the unique_ptr members release the helpers.
TestEnvironment::~TestEnvironment() = default;
} // namespace passage_embeddings

@ -0,0 +1,82 @@
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_TEST_UTIL_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_TEST_UTIL_H_
#include <memory>
#include <string>
#include <vector>
#include "base/observer_list.h"
#include "base/time/time.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
namespace passage_embeddings {
// Model version and embedding output size shared by the test helpers.
inline constexpr int64_t kEmbeddingsModelVersion{1};
inline constexpr size_t kEmbeddingsModelOutputSize{768};
// Returns a model info builder preloaded with valid model info: placeholder
// model file paths, `kEmbeddingsModelVersion`, and metadata whose output size
// is `kEmbeddingsModelOutputSize`.
optimization_guide::TestModelInfoBuilder GetBuilderWithValidModelInfo();
// Returns valid Embeddings for the given passages: one embedding of
// `kEmbeddingsModelOutputSize` dimensions per passage. All returned embeddings
// are identical; the passage text itself is not used.
std::vector<Embedding> ComputeEmbeddingsForPassages(
const std::vector<std::string>& passages);
////////////////////////////////////////////////////////////////////////////////
// An Embedder that generates Embeddings asynchronously.
// ComputePassagesEmbeddings() posts a task to the current default sequenced
// task runner; that task runs the callback with mock embeddings and a success
// status. No task ids are issued (kInvalidTaskId is returned) and TryCancel()
// always returns false.
class TestEmbedder : public Embedder {
public:
TestEmbedder() = default;
~TestEmbedder() override = default;
// Embedder:
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
bool TryCancel(TaskId task_id) override;
};
////////////////////////////////////////////////////////////////////////////////
// An EmbedderMetadataProvider that notifies the observer immediately with valid
// embedder metadata.
class TestEmbedderMetadataProvider : public EmbedderMetadataProvider {
public:
TestEmbedderMetadataProvider();
~TestEmbedderMetadataProvider() override;
// EmbedderMetadataProvider:
// AddObserver() calls `observer->EmbedderMetadataUpdated()` synchronously with
// valid metadata before adding `observer` to the list.
void AddObserver(EmbedderMetadataObserver* observer) override;
void RemoveObserver(EmbedderMetadataObserver* observer) override;
private:
// Observers registered through AddObserver().
base::ObserverList<EmbedderMetadataObserver> observer_list_;
};
////////////////////////////////////////////////////////////////////////////////
// The TestEnvironment that encapsulates test helper instances.
// It owns a TestEmbedder and a TestEmbedderMetadataProvider and exposes them
// through their base interfaces.
class TestEnvironment {
public:
TestEnvironment();
~TestEnvironment();
// Returns a non-owning pointer to the owned test embedder.
Embedder* embedder() { return embedder_.get(); }
// Returns a non-owning pointer to the owned test metadata provider.
EmbedderMetadataProvider* embedder_metadata_provider() {
return embedder_metadata_provider_.get();
}
private:
std::unique_ptr<TestEmbedder> embedder_;
std::unique_ptr<TestEmbedderMetadataProvider> embedder_metadata_provider_;
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_TEST_UTIL_H_

@ -2,10 +2,9 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/embedder.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include <algorithm>
#include <queue>
namespace passage_embeddings {

@ -6,27 +6,18 @@
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_TYPES_H_
#include <optional>
#include <string>
#include <vector>
#include "base/functional/callback.h"
#include "base/observer_list_types.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
// Name of the metric for the embedder model info status.
// NOTE(review): presumably a histogram recording EmbeddingsModelInfoStatus
// values; confirm against the code that logs it.
inline constexpr char kModelInfoMetricName[] =
"History.Embeddings.Embedder.ModelInfoStatus";
// Describes an embedder model: its version, the size of the embeddings it
// outputs, and an optional search score threshold.
struct EmbedderMetadata {
  EmbedderMetadata(int64_t model_version,
                   size_t output_size,
                   std::optional<double> search_score_threshold = std::nullopt)
      : model_version(model_version),
        output_size(output_size),
        search_score_threshold(search_score_threshold) {}

  // Returns true when both the model version and the output size are non-zero.
  // Const-qualified so it can be called on const instances and references.
  bool IsValid() const { return model_version != 0 && output_size != 0; }

  int64_t model_version;
  size_t output_size;
  std::optional<double> search_score_threshold;
};
enum class EmbeddingsModelInfoStatus {
kUnknown = 0,
@ -81,6 +72,116 @@ enum class ComputeEmbeddingsStatus {
kMaxValue = kCanceled,
};
// Describes an embedder model: its version, the size of the embeddings it
// outputs, and an optional search score threshold.
struct EmbedderMetadata {
  EmbedderMetadata(int64_t model_version,
                   size_t output_size,
                   std::optional<double> search_score_threshold = std::nullopt)
      : model_version(model_version),
        output_size(output_size),
        search_score_threshold(search_score_threshold) {}

  // Returns true when both the model version and the output size are non-zero.
  // Const-qualified so it can be called on const instances and references.
  bool IsValid() const { return model_version != 0 && output_size != 0; }

  int64_t model_version;
  size_t output_size;
  std::optional<double> search_score_threshold;
};
// Observer interface for getting notified when the embedder metadata is updated.
// Implementations are registered via EmbedderMetadataProvider::AddObserver().
class EmbedderMetadataObserver : public base::CheckedObserver {
public:
// Called when the embedder metadata is updated. `metadata` is passed by
// value, so the observer receives its own copy.
virtual void EmbedderMetadataUpdated(EmbedderMetadata metadata) = 0;
};
// Notifies observers when the embedder metadata is updated.
// Pure interface: the default constructor is protected, so it can only be
// instantiated through a derived class.
class EmbedderMetadataProvider {
public:
virtual ~EmbedderMetadataProvider() = default;
// Subscribes `observer` for notifications when the embedder metadata is
// updated. Will immediately notify if metadata is ready at the time of call.
virtual void AddObserver(EmbedderMetadataObserver* observer) = 0;
// Unsubscribes `observer` from notifications when the embedder metadata is
// updated.
virtual void RemoveObserver(EmbedderMetadataObserver* observer) = 0;
protected:
EmbedderMetadataProvider() = default;
};
// Encapsulate embeddings and related helpers.
// Holds a vector of floats plus the word count of the passage it was computed
// from. Copyable and movable.
class Embedding {
public:
// Construct from the embedding data, optionally with the passage word count.
// The default-constructed word count is 0 (see `passage_word_count_`).
explicit Embedding(std::vector<float> data);
Embedding(std::vector<float> data, size_t passage_word_count);
Embedding();
~Embedding();
Embedding(const Embedding&);
Embedding& operator=(const Embedding&);
Embedding(Embedding&&);
Embedding& operator=(Embedding&&);
bool operator==(const Embedding&) const;
// The number of elements in the data vector.
size_t Dimensions() const;
// The length of the vector.
// NOTE(review): "length" here presumably means the vector's norm, not its
// element count (that is Dimensions()); confirm against the definition.
float Magnitude() const;
// Scale the vector to unit length.
void Normalize();
// Compares one embedding with another and returns a similarity measure.
float ScoreWith(const Embedding& other_embedding) const;
// Const accessor used for storage.
const std::vector<float>& GetData() const { return data_; }
// Used for search filtering of passages with low word count.
size_t GetPassageWordCount() const { return passage_word_count_; }
void SetPassageWordCount(size_t passage_word_count) {
passage_word_count_ = passage_word_count;
}
private:
// The embedding components.
std::vector<float> data_;
// Word count of the source passage; 0 until set.
size_t passage_word_count_ = 0;
};
// Computes embeddings for passages. Allows for cancellation of tasks.
// Pure interface: the default constructor is protected, so it can only be
// instantiated through a derived class.
class Embedder {
public:
// Identifies a ComputePassagesEmbeddings() request; accepted by TryCancel().
using TaskId = uint64_t;
// Reserved id value (0) that does not refer to a task.
static constexpr TaskId kInvalidTaskId = 0;
virtual ~Embedder() = default;
// Computes embeddings for each entry in `passages`. Will invoke `callback`
// when done. If successful, it is guaranteed that the callback will return
// the same number of passages and embeddings and in the same order as
// `passages`. Otherwise the callback will return the original passages but
// with an empty embeddings vector.
using ComputePassagesEmbeddingsCallback =
base::OnceCallback<void(std::vector<std::string> passages,
std::vector<Embedding> embeddings,
TaskId task_id,
ComputeEmbeddingsStatus status)>;
virtual TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) = 0;
// Cancels computation of embeddings iff none of the passages given to
// `ComputePassagesEmbeddings()` has been submitted for embedding yet.
// If successful, the callback for the canceled task will be invoked with
// `ComputeEmbeddingsStatus::kCanceled` status.
// NOTE(review): the bool result presumably reports whether the cancellation
// succeeded; confirm against implementations.
virtual bool TryCancel(TaskId task_id) = 0;
protected:
Embedder() = default;
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_TYPES_H_

@ -1 +1 @@
file://components/optimization_guide/OWNERS
file://components/passage_embeddings/OWNERS