
Simplify strings returned from the fake on-device model service

The "\n" and "Context: " pieces of the output were mostly noise and
don't meaningfully add anything to test correctness. This change removes
those pieces to make tests written against the fake a bit cleaner.

Bug: 415808003
Change-Id: Id34f394ba24850f2c89219d66d27ed80474b58fa
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6542686
Reviewed-by: Steven Holte <holte@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1459596}
Author:    Clark DuVall <cduvall@chromium.org>
Date:      2025-05-13 11:38:28 -07:00
Committer: Chromium LUCI CQ
Commit:    c5a6b2936d
Parent:    63b4e023e3
5 changed files with 136 additions and 170 deletions, in:
  chrome/browser/ai
  components/optimization_guide/core/model_execution
  services/on_device_model
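
For reference, the first two hunks below boil down to the following helper;
this is a self-contained sketch assembled from the diff (the declaration of
`formatted` is assumed, the loop matches the new code):

  #include <string>
  #include <vector>

  // Each fake response is now expected verbatim: no "Context: " prefix and no
  // trailing "\n". Before every response after the first, the previous output
  // plus the "E" end token is replayed into the session transcript.
  std::vector<std::string> FormatResponses(
      const std::vector<std::string>& responses) {
    std::vector<std::string> formatted;
    std::string last_output;
    for (const std::string& response : responses) {
      if (!last_output.empty()) {
        formatted.push_back(last_output + "E");  // Prior output + end token.
        last_output += formatted.back();
      }
      formatted.push_back(response);  // Fake response, echoed as-is.
      last_output += formatted.back();
    }
    return formatted;
  }

For example, FormatResponses({"UfooEM", "UbarEM"}) yields {"UfooEM", "UfooEME",
"UbarEM"}, matching the transcript walkthrough in the comment below.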

@@ -243,13 +243,12 @@ optimization_guide::proto::OnDeviceModelExecutionFeatureConfig CreateConfig() {
 // fake service would look something like this:
 // - s1.Prompt("foo")
 //   - Adds "UfooEM" to the session
-//   - Gets output of ["Context: UfooEM\n"] from fake service
-//   - Adds "Context: UfooEM\nE" to the session (fake response + end token)
+//   - Gets output of ["UfooEM"] from fake service
+//   - Adds "UfooEME" to the session (fake response + end token)
 // - s1.Prompt("bar")
 //   - Adds "UbarEM" to the session
-//   - Gets output of ["Context: UfooEM\n", "Context: Context: UfooEM\nE\n",
-//     "Context: UbarEM\n"].
-//   - Adds "Context: UfooEM\nContext: Context: UfooEM\nE\nContext: UbarEM\n"
+//   - Gets output of ["UfooEM", "UfooEME", "UbarEM"].
+//   - Adds "UfooEMUfooEMEUbarEM"
 //     (concatenated output from fake service) to the session
 // This behavior verifies the correct inputs and outputs are being returned from
 // the model, and this helper makes it easier to construct these expectations.
@@ -260,10 +259,10 @@ std::vector<std::string> FormatResponses(
   std::string last_output;
   for (const std::string& response : responses) {
     if (!last_output.empty()) {
-      formatted.push_back("Context: " + last_output + "E\n");
+      formatted.push_back(last_output + "E");
       last_output += formatted.back();
     }
-    formatted.push_back("Context: " + response + "\n");
+    formatted.push_back(response);
     last_output += formatted.back();
   }
   return formatted;
@@ -383,7 +382,7 @@ TEST_F(AILanguageModelTest, Append) {
   auto session = CreateSession();
   Append(*session, MakeInput("foo"));
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Context: UfooE\n", "Context: UbarEM\n"));
+              ElementsAre("UfooE", "UbarEM"));
 }
 
 TEST_F(AILanguageModelTest, PromptTokenCounts) {
@@ -464,9 +463,9 @@ TEST_F(AILanguageModelTest, SamplingParams) {
   auto fork = Fork(*session);
   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 2, Temp: 1\n"));
+              ElementsAre("UfooEM", "TopK: 2, Temp: 1"));
   EXPECT_THAT(Prompt(*fork, MakeInput("bar")),
-              ElementsAre("Context: UbarEM\n", "TopK: 2, Temp: 1\n"));
+              ElementsAre("UbarEM", "TopK: 2, Temp: 1"));
 }
 
 TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
@@ -479,7 +478,7 @@ TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
   auto session = CreateSession(std::move(options));
   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 1, Temp: 1.5\n"));
+              ElementsAre("UfooEM", "TopK: 1, Temp: 1.5"));
 }
 
 TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
@@ -492,7 +491,7 @@ TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
   auto session = CreateSession(std::move(options));
   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 2, Temp: 0\n"));
+              ElementsAre("UfooEM", "TopK: 2, Temp: 0"));
 }
 
 TEST_F(AILanguageModelTest, MaxSamplingParams) {
@@ -505,7 +504,7 @@ TEST_F(AILanguageModelTest, MaxSamplingParams) {
   auto session = CreateSession(std::move(options));
   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 5, Temp: 1.5\n"));
+              ElementsAre("UfooEM", "TopK: 5, Temp: 1.5"));
 }
 
 TEST_F(AILanguageModelTest, InitialPrompts) {
@@ -515,7 +514,7 @@ TEST_F(AILanguageModelTest, InitialPrompts) {
   auto session = CreateSession(std::move(options));
   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: ShiEUbyeE\n", "Context: UfooEM\n"));
+              ElementsAre("ShiEUbyeE", "UfooEM"));
 }
 
 TEST_F(AILanguageModelTest, InitialPromptsInstanceInfo) {
@@ -591,8 +590,7 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnPromptInput) {
   // Response should include input/output of previous prompt with the original
   // long prompt not present.
   EXPECT_THAT(responder.responses(),
-              ElementsAre("Context: SinitE\n", "Context: UfooEMhiE\n",
-                          "Context: U" + long_prompt + "EM\n"));
+              ElementsAre("SinitE", "UfooEMhiE", "U" + long_prompt + "EM"));
 }
 
 TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
@@ -609,10 +607,8 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
   responder.WaitForQuotaOverflow();
   EXPECT_TRUE(responder.WaitForCompletion());
-  EXPECT_THAT(
-      Prompt(*session, MakeInput("foo")),
-      ElementsAre("Context: SinitE\n", "Context: U" + long_prompt + "E\n",
-                  "Context: UfooEM\n"));
+  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
+              ElementsAre("SinitE", "U" + long_prompt + "E", "UfooEM"));
 }
 
 TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
@@ -643,8 +639,7 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
   // - "bar" from the current prompt call
   fake_broker_.settings().set_execute_result({});
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Context: UfooEM" + long_response + "E\n",
-                          "Context: UbarEM\n"));
+              ElementsAre("UfooEM" + long_response + "E", "UbarEM"));
 }
 
 TEST_F(AILanguageModelTest, Destroy) {
@@ -949,7 +944,7 @@ TEST_F(AILanguageModelTest, Constraint) {
   EXPECT_THAT(
       Prompt(*session, MakeInput("foo"),
              on_device_model::mojom::ResponseConstraint::NewRegex("reg")),
-      ElementsAre("Constraint: regex reg\n", "Context: UfooEM\n"));
+      ElementsAre("Constraint: regex reg", "UfooEM"));
 }
 
 TEST_F(AILanguageModelTest, ServiceCrash) {
@@ -1112,11 +1107,11 @@ TEST_F(AILanguageModelTest, Priority) {
   main_rfh()->GetRenderWidgetHost()->GetView()->Hide();
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Priority: background\n", "hi"));
+              ElementsAre("Priority: background", "hi"));
 
   auto fork = Fork(*session);
   EXPECT_THAT(Prompt(*fork, MakeInput("bar")),
-              ElementsAre("Priority: background\n", "hi"));
+              ElementsAre("Priority: background", "hi"));
 
   main_rfh()->GetRenderWidgetHost()->GetView()->Show();
   EXPECT_THAT(Prompt(*session, MakeInput("baz")), ElementsAre("hi"));

@@ -330,7 +330,7 @@ TEST_F(OnDeviceModelServiceControllerTest, BaseModelExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  const std::string expected_response = "Context: execute:foo max:1024\n";
+  const std::string expected_response = "execute:foo max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_TRUE(*response_.provided_by_on_device());
   EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
@@ -453,8 +453,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CacheWeightExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Cache weight: 1015\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Cache weight: 1015execute:foo max:1024");
 
   // If we destroy all sessions and wait long enough, everything should idle out
   // and the service should get terminated.
@@ -481,8 +480,7 @@ TEST_F(OnDeviceModelServiceControllerTest, AdaptationModelExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Adaptation model: 1015execute:foo max:1024");
 
   // If we destroy all sessions and wait long enough, everything should idle out
   // and the service should get terminated.
@@ -535,11 +533,11 @@ TEST_F(OnDeviceModelServiceControllerTest,
   ASSERT_TRUE(compose_response.GetFinalStatus());
   EXPECT_EQ(*compose_response.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+            "Adaptation model: 1015execute:foo max:1024");
   EXPECT_TRUE(*compose_response.provided_by_on_device());
   ASSERT_TRUE(test_response.GetFinalStatus());
   EXPECT_EQ(*test_response.value(),
-            "Adaptation model: 2024\nContext: execute:bar max:1024\n");
+            "Adaptation model: 2024execute:bar max:1024");
   EXPECT_TRUE(*test_response.provided_by_on_device());
 
   session_compose.reset();
@@ -594,10 +592,10 @@ TEST_F(OnDeviceModelServiceControllerTest, ModelAdaptationAndBaseModelSuccess) {
   ASSERT_TRUE(compose_response.GetFinalStatus());
   EXPECT_EQ(*compose_response.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+            "Adaptation model: 1015execute:foo max:1024");
   EXPECT_TRUE(*compose_response.provided_by_on_device());
   ASSERT_TRUE(test_response.GetFinalStatus());
-  EXPECT_EQ(*test_response.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*test_response.value(), "execute:bar max:1024");
   EXPECT_TRUE(*test_response.provided_by_on_device());
 
   session_compose.reset();
@@ -719,8 +717,7 @@ TEST_F(OnDeviceModelServiceControllerTest, SessionBeforeAndAfterModelUpdate) {
   session2->ExecuteModel(PageUrlRequest("foo"),
                          response2.GetStreamingCallback());
   ASSERT_TRUE(response2.GetFinalStatus());
-  EXPECT_EQ(*response2.value(),
-            "Base model: 2\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response2.value(), "Base model: 2execute:foo max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, SessionFailsForInvalidFeature) {
@@ -1636,7 +1633,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CancelsExecuteOnExecute) {
   EXPECT_EQ(
      *resp1.error(),
      OptimizationGuideModelExecutionError::ModelExecutionError::kCancelled);
-  EXPECT_EQ(*resp2.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*resp2.value(), "execute:bar max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, WontStartSessionAfterGpuBlocked) {
@@ -1822,8 +1819,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextDisconnectExecute) {
       "OptimizationGuide.ModelExecution.OnDeviceExecuteModelResult.Compose",
       ExecuteModelResult::kUsedOnDevice, 1);
   std::string expected_response =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foobaz max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foobaz max:1024");
   EXPECT_EQ(*response_.value(), expected_response);
 }
@@ -1861,16 +1858,16 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextMultipleSessions) {
   session2->ExecuteModel(PageUrlRequest("2"), response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   std::string expected_response1 =
-      ("Context: ctx:bar max:8192\n"
-       "Context: execute:bar2 max:1024\n");
+      ("ctx:bar max:8192"
+       "execute:bar2 max:1024");
   EXPECT_EQ(*response_.value(), expected_response1);
 
   ResponseHolder response2;
   session1->ExecuteModel(PageUrlRequest("1"), response2.GetStreamingCallback());
   ASSERT_TRUE(response2.GetFinalStatus());
   std::string expected_response2 =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foo1 max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foo1 max:1024");
   EXPECT_EQ(*response2.value(), expected_response2);
 }
@@ -2067,7 +2064,7 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
   session1->ExecuteModel(UserInputRequest("foo"),
                          response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response1 = "Context: execute:foo max:1024\n";
+  const std::string expected_response1 = "execute:foo max:1024";
   EXPECT_EQ(*response_.value(), expected_response1);
   EXPECT_THAT(response_.partials(), IsEmpty());
@@ -2078,19 +2075,19 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
   session2->ExecuteModel(UserInputRequest("abarx"),
                          response2.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response2 = "Context: execute:abarx max:1024\n";
+  const std::string expected_response2 = "execute:abarx max:1024";
   EXPECT_EQ(*response2.value(), expected_response2);
   EXPECT_THAT(response2.partials(), IsEmpty());
 
   // Output contains redacted text (and input doesn't), so redact.
-  fake_settings_.set_execute_result({"Context: abarx max:1024\n"});
+  fake_settings_.set_execute_result({"abarx max:1024"});
   auto session3 = CreateSession();
   ASSERT_TRUE(session3);
   ResponseHolder response3;
   session3->ExecuteModel(UserInputRequest("foo"),
                          response3.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response3 = "Context: a[###]x max:1024\n";
+  const std::string expected_response3 = "a[###]x max:1024";
   EXPECT_EQ(*response3.value(), expected_response3);
   EXPECT_THAT(response3.partials(), IsEmpty());
 }
@@ -2151,7 +2148,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
       });
 
   // Force 'bar' to be returned from model.
-  fake_settings_.set_execute_result({"Context: bar max:1024\n"});
+  fake_settings_.set_execute_result({"bar max:1024"});
 
   auto session = CreateSession();
   ASSERT_TRUE(session);
@@ -2160,7 +2157,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
                         response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
   // `bar` shouldn't be rewritten as it's in the input.
-  const std::string expected_response = "Context: bar max:1024\n";
+  const std::string expected_response = "bar max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_THAT(response_.partials(), IsEmpty());
 }
@@ -2177,13 +2174,13 @@ TEST_F(OnDeviceModelServiceControllerTest, ReplacementText) {
       });
 
   // Output contains redacted text (and input doesn't), so redact.
-  fake_settings_.set_execute_result({"Context: abarx max:1024\n"});
+  fake_settings_.set_execute_result({"abarx max:1024"});
 
   auto session = CreateSession();
   ASSERT_TRUE(session);
   session->ExecuteModel(UserInputRequest("foo"),
                         response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response = "Context: a[redacted]x max:1024\n";
+  const std::string expected_response = "a[redacted]x max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_THAT(response_.partials(), IsEmpty());
 }
@@ -2435,8 +2432,8 @@ TEST_F(OnDeviceModelServiceControllerTest, UsesSessionTopKAndTemperature) {
   task_environment_.RunUntilIdle();
   EXPECT_TRUE(response_.value());
   const std::vector<std::string> partial_responses = {
-      "Context: execute:foo max:1024\n",
-      "TopK: 3, Temp: 2\n",
+      "execute:foo max:1024",
+      "TopK: 3, Temp: 2",
   };
   EXPECT_EQ(*response_.value(), ConcatResponses(partial_responses));
   EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
@@ -3238,8 +3235,7 @@ TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Fastest inference\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Fastest inferenceexecute:foo max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
@@ -3296,8 +3292,7 @@ TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
     session->ExecuteModel(proto::ExampleForTestingRequest(),
                           response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
-    EXPECT_EQ(*response.value(),
-              "Context: <image> max:22\nContext: <image> max:1024\n");
+    EXPECT_EQ(*response.value(), "<image> max:22<image> max:1024");
   }
 
   // Session without capabilities should not allow images.
@@ -3310,8 +3305,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
                           response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     EXPECT_EQ(*response.value(),
-              "Context: <unsupported> max:22\nContext: <unsupported> "
-              "max:1024\n");
+              "<unsupported> max:22<unsupported> "
+              "max:1024");
   }
 }
@@ -3403,8 +3398,7 @@ TEST_F(OnDeviceModelServiceControllerTest, KeepInputOnExtension) {
   altered_clone->ExecuteModel(proto::ExampleForTestingRequest(),
                               altered_response.GetStreamingCallback());
   ASSERT_TRUE(altered_response.GetFinalStatus());
-  EXPECT_EQ(*altered_response.value(),
-            "Context: v1<image><audio>v2v3v4 max:22\n");
+  EXPECT_EQ(*altered_response.value(), "v1<image><audio>v2v3v4 max:22");
 
   // The clone that only extended should have sent input in separate chunks.
   ResponseHolder extended_response;
@@ -3412,19 +3406,19 @@ TEST_F(OnDeviceModelServiceControllerTest, KeepInputOnExtension) {
                                 extended_response.GetStreamingCallback());
   ASSERT_TRUE(extended_response.GetFinalStatus());
   EXPECT_EQ(*extended_response.value(),
-            "Context: v1<image><audio> max:22\n"
-            "Context: v2 max:22\n"
-            "Context: v3 max:4\n"
-            "Context: v4 max:4\n");
+            "v1<image><audio> max:22"
+            "v2 max:22"
+            "v3 max:4"
+            "v4 max:4");
 
   // The original should have input in separate chunks.
   session->ExecuteModel(proto::ExampleForTestingRequest(),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(*response_.value(),
-            "Context: v1<image><audio> max:22\n"
-            "Context: v2 max:22\n"
-            "Context: v3 max:4\n");
+            "v1<image><audio> max:22"
+            "v2 max:22"
+            "v3 max:4");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, OmitEmptyInputs) {
@@ -3495,8 +3489,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneUsesSessionTopKAndTemperature) {
   task_environment_.RunUntilIdle();
   EXPECT_TRUE(response_.value());
   const std::vector<std::string> partial_responses = {
-      "Context: execute:foo max:1024\n",
-      "TopK: 3, Temp: 2\n",
+      "execute:foo max:1024",
+      "TopK: 3, Temp: 2",
   };
   EXPECT_EQ(*response_.value(), ConcatResponses(partial_responses));
   EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
@@ -3586,8 +3580,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextAndClone) {
     clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:foobar max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:foobar max:1024");
    EXPECT_EQ(*response.value(), expected_response);
  }
@@ -3598,8 +3592,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextAndClone) {
                          response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:fooblah max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:fooblah max:1024");
     EXPECT_EQ(*response.value(), expected_response);
   }
 }
@@ -3618,7 +3612,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneBeforeAddContext) {
     ResponseHolder response;
     clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
-    EXPECT_EQ(*response.value(), "Context: execute:bar max:1024\n");
+    EXPECT_EQ(*response.value(), "execute:bar max:1024");
   }
 
   // Original session should execute with context
@@ -3628,8 +3622,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneBeforeAddContext) {
                          response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:fooblah max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:fooblah max:1024");
     EXPECT_EQ(*response.value(), expected_response);
   }
 }
@@ -3647,7 +3641,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CancelAddContextAndClone) {
   ResponseHolder response;
   clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
-  EXPECT_EQ(*response.value(), "Context: execute:foobar max:1024\n");
+  EXPECT_EQ(*response.value(), "execute:foobar max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, CloneAddContextDisconnectExecute) {
@@ -3666,8 +3660,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneAddContextDisconnectExecute) {
   clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
   std::string expected_response =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foobar max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foobar max:1024");
   EXPECT_EQ(*response.value(), expected_response);
 }
@@ -3692,7 +3686,7 @@ TEST_F(OnDeviceModelServiceControllerTest, Broker) {
   ResponseHolder response;
   session->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
-  EXPECT_EQ(*response.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*response.value(), "execute:bar max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest,
@@ -3724,16 +3718,16 @@ TEST_F(OnDeviceModelServiceControllerTest, Priority) {
   auto session = CreateSession();
   EXPECT_TRUE(session);
-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");
 
   session->SetPriority(on_device_model::mojom::Priority::kBackground);
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
 
   session->SetPriority(on_device_model::mojom::Priority::kForeground);
-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, PriorityClone) {
@@ -3742,17 +3736,17 @@ TEST_F(OnDeviceModelServiceControllerTest, PriorityClone) {
   auto session = CreateSession();
   EXPECT_TRUE(session);
-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");
 
   session->SetPriority(on_device_model::mojom::Priority::kBackground);
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
 
   auto clone = session->Clone();
   EXPECT_EQ(GetResponse(*clone, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
   EXPECT_EQ(GetResponse(*clone, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallback) {
@@ -3772,8 +3766,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SetInputCallback) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Context: ctx:foo max:8192\nContext: execute:foobar "
-            "max:1024\n");
+            "ctx:foo max:8192execute:foobar "
+            "max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackCancelled) {
@@ -3802,8 +3796,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackCancelled) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Context: ctx:foo max:8192\nContext: execute:foobar "
-            "max:1024\n");
+            "ctx:foo max:8192execute:foobar "
+            "max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackError) {
@@ -3831,7 +3825,7 @@ TEST_F(OnDeviceModelServiceControllerTest, TokenCounts) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(response_.value(), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(response_.value(), "execute:foo max:1024");
   EXPECT_EQ(response_.input_token_count(), strlen("execute:foo"));
   EXPECT_EQ(response_.output_token_count(), strlen("execute:foo max:1024"));
 }
@@ -3848,8 +3842,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintOnExecute) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: regex [A-Z]*\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: regex [A-Z]*"
+            "execute:input max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigJson) {
@@ -3880,8 +3874,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigJson) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: json { type: \"object\"}\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: json { type: \"object\"}"
+            "execute:input max:1024");
 }
 
 TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigRegex) {
@@ -3912,8 +3906,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigRegex) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: regex [A-Z]*\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: regex [A-Z]*"
+            "execute:input max:1024");
 }
 
 }  // namespace optimization_guide

@@ -279,7 +279,7 @@ bool SessionGenerate(ChromeMLSession session,
   if (instance->model_instance->performance_hint ==
       ml::ModelPerformanceHint::kFastestInference) {
-    OutputChunk("Fastest inference\n");
+    OutputChunk("Fastest inference");
   }
   if (!instance->adaptation_data.empty()) {
     std::string adaptation_str = "Adaptation: " + instance->adaptation_data;
@@ -287,19 +287,19 @@ bool SessionGenerate(ChromeMLSession session,
       adaptation_str +=
          " (" + base::NumberToString(*instance->adaptation_file_id) + ")";
     }
-    OutputChunk(adaptation_str + "\n");
+    OutputChunk(adaptation_str);
   }
   // Only include sampling params if they're not the respective default values.
   if (instance->top_k != 1 || instance->temperature != 0) {
     OutputChunk(base::StrCat(
         {"TopK: ", base::NumberToString(instance->top_k),
-         ", Temp: ", base::NumberToString(instance->temperature), "\n"}));
+         ", Temp: ", base::NumberToString(instance->temperature)}));
   }
   if (!instance->context.empty()) {
     for (const std::string& context : instance->context) {
-      OutputChunk("Context: " + context + "\n");
+      OutputChunk(context);
     }
   }
   if (options->constraint) {

@@ -179,9 +179,9 @@ class OnDeviceModelServiceTest : public testing::Test {
 
 TEST_F(OnDeviceModelServiceTest, Responds) {
   auto model = LoadModel();
-  EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
+  EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("bar"));
   // Try another input on the same model.
-  EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Context: cat\n"));
+  EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("cat"));
 }
 
 TEST_F(OnDeviceModelServiceTest, Append) {
@@ -196,9 +196,7 @@ TEST_F(OnDeviceModelServiceTest, Append) {
   session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
   response.WaitForCompletion();
-  EXPECT_THAT(response.responses(),
-              ElementsAre("Context: cheese\n", "Context: more\n",
-                          "Context: cheddar\n"));
+  EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "cheddar"));
 }
 
 TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
@@ -221,8 +219,7 @@ TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
   response.WaitForCompletion();
   EXPECT_THAT(response.responses(),
-              ElementsAre("TopK: 2, Temp: 0.5\n", "Context: cheese\n",
-                          "Context: more\n", "Context: cheddar\n"));
+              ElementsAre("TopK: 2, Temp: 0.5", "cheese", "more", "cheddar"));
 }
 
 TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
@@ -240,15 +237,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
     TestResponseHolder response;
     cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(response.responses(),
-                ElementsAre("Context: cheese\n", "Context: more\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
   }
   {
     TestResponseHolder response;
     session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(response.responses(),
-                ElementsAre("Context: cheese\n", "Context: more\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
   }
 
   session->Append(MakeInput("foo"), {});
@@ -257,17 +252,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
   {
     TestResponseHolder response;
     session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(
-        response.responses(),
-        ElementsAre("Context: cheese\n", "Context: more\n", "Context: foo\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "foo"));
   }
   {
     TestResponseHolder response;
     cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(
-        response.responses(),
-        ElementsAre("Context: cheese\n", "Context: more\n", "Context: bar\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "bar"));
   }
 }
@@ -310,21 +301,11 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsAppend) {
   response4.WaitForCompletion();
   response5.WaitForCompletion();
 
-  EXPECT_THAT(response1.responses(),
-              ElementsAre("Context: cheese\n", "Context: more\n",
-                          "Context: cheddar\n"));
-  EXPECT_THAT(
-      response2.responses(),
-      ElementsAre("Context: apple\n", "Context: banana\n", "Context: candy\n"));
-  EXPECT_THAT(
-      response3.responses(),
-      ElementsAre("Context: apple\n", "Context: banana\n", "Context: chip\n"));
-  EXPECT_THAT(
-      response4.responses(),
-      ElementsAre("Context: cheese\n", "Context: more\n", "Context: choco\n"));
-  EXPECT_THAT(response5.responses(),
-              ElementsAre("Context: apple\n", "Context: banana\n",
-                          "Context: orange\n"));
+  EXPECT_THAT(response1.responses(), ElementsAre("cheese", "more", "cheddar"));
+  EXPECT_THAT(response2.responses(), ElementsAre("apple", "banana", "candy"));
+  EXPECT_THAT(response3.responses(), ElementsAre("apple", "banana", "chip"));
+  EXPECT_THAT(response4.responses(), ElementsAre("cheese", "more", "choco"));
+  EXPECT_THAT(response5.responses(), ElementsAre("apple", "banana", "orange"));
 }
 
 TEST_F(OnDeviceModelServiceTest, CountTokens) {
@@ -369,8 +350,7 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokenLimits) {
   response.WaitForCompletion();
   EXPECT_THAT(response.responses(),
-              ElementsAre("Context: big \n", "Context: big cheese\n",
-                          "Context: cheddar\n"));
+              ElementsAre("big ", "big cheese", "cheddar"));
 }
 
 TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
@@ -392,14 +372,14 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
   // Response from first session should still work.
   response1.WaitForCompletion();
-  EXPECT_THAT(response1.responses(), ElementsAre("Context: 1\n"));
+  EXPECT_THAT(response1.responses(), ElementsAre("1"));
 
   // Second session still works.
   TestResponseHolder response2;
   session2->Append(MakeInput("2"), {});
   session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
   response2.WaitForCompletion();
-  EXPECT_THAT(response2.responses(), ElementsAre("Context: 2\n"));
+  EXPECT_THAT(response2.responses(), ElementsAre("2"));
 }
 
 TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
@@ -407,18 +387,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
   FakeFile weights1("Adapt1");
   FakeFile weights2("Adapt2");
   auto model = LoadModel();
   auto adaptation1 = LoadAdaptation(*model, weights1.Open());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 
   auto adaptation2 = LoadAdaptation(*model, weights2.Open());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation2, "foo"),
-              ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt2 (1)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 }
 
 TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
@@ -426,18 +406,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
   FakeFile weights1("Adapt1");
   FakeFile weights2("Adapt2");
   auto model = LoadModel(ml::ModelBackendType::kApuBackend);
   auto adaptation1 = LoadAdaptation(*model, weights1.Path());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 
   auto adaptation2 = LoadAdaptation(*model, weights2.Path());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation2, "foo"),
-              ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt2 (1)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 }
 
 TEST_F(OnDeviceModelServiceTest, LoadingAdaptationDoesNotCancelSession) {
@@ -532,9 +512,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokens) {
   }
   response.WaitForCompletion();
-  EXPECT_THAT(response.responses(), ElementsAre("Context: System: hi End.\n",
-                                                "Context: Model: hello End.\n",
-                                                "Context: User: bye\n"));
+  EXPECT_THAT(response.responses(),
+              ElementsAre("System: hi End.", "Model: hello End.", "User: bye"));
 }
 
 TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
@@ -580,8 +559,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
   }
 
   EXPECT_THAT(response.responses(),
-              ElementsAre("Context: cheddar[Bitmap of size 7x21]cheese\n",
-                          "Context: bleu[Bitmap of size 63x42]cheese\n"));
+              ElementsAre("cheddar[Bitmap of size 7x21]cheese",
+                          "bleu[Bitmap of size 63x42]cheese"));
 }
 
 TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
@@ -641,7 +620,7 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
   auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
                          ml::ModelPerformanceHint::kFastestInference);
   EXPECT_THAT(GetResponses(*model, "foo"),
-              ElementsAre("Fastest inference\n", "Context: foo\n"));
+              ElementsAre("Fastest inference", "foo"));
 }
 
 TEST_F(OnDeviceModelServiceTest, Capabilities) {

@@ -181,29 +181,28 @@ void FakeOnDeviceSession::GenerateImpl(
   if (model_->performance_hint() ==
       ml::ModelPerformanceHint::kFastestInference) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Fastest inference\n";
+    chunk->text = "Fastest inference";
     remote->OnResponse(std::move(chunk));
   }
   if (model_->data().base_weight != "0") {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Base model: " + model_->data().base_weight + "\n";
+    chunk->text = "Base model: " + model_->data().base_weight;
     remote->OnResponse(std::move(chunk));
   }
   if (!model_->data().adaptation_model_weight.empty()) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text =
-        "Adaptation model: " + model_->data().adaptation_model_weight + "\n";
+    chunk->text = "Adaptation model: " + model_->data().adaptation_model_weight;
     remote->OnResponse(std::move(chunk));
   }
   if (!model_->data().cache_weight.empty()) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Cache weight: " + model_->data().cache_weight + "\n";
+    chunk->text = "Cache weight: " + model_->data().cache_weight;
     remote->OnResponse(std::move(chunk));
   }
   if (priority_ == on_device_model::mojom::Priority::kBackground) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Priority: background\n";
+    chunk->text = "Priority: background";
     remote->OnResponse(std::move(chunk));
   }
@@ -211,11 +210,11 @@ void FakeOnDeviceSession::GenerateImpl(
     const auto& constraint = *options->constraint;
     auto chunk = mojom::ResponseChunk::New();
     if (constraint.is_json_schema()) {
-      chunk->text = "Constraint: json " + constraint.get_json_schema() + "\n";
+      chunk->text = "Constraint: json " + constraint.get_json_schema();
     } else if (constraint.is_regex()) {
-      chunk->text = "Constraint: regex " + constraint.get_regex() + "\n";
+      chunk->text = "Constraint: regex " + constraint.get_regex();
     } else {
-      chunk->text = "Constraint: unknown\n";
+      chunk->text = "Constraint: unknown";
     }
     remote->OnResponse(std::move(chunk));
   }
@@ -226,15 +225,14 @@ void FakeOnDeviceSession::GenerateImpl(
       std::string text = CtxToString(*context, params_->capabilities);
      output_token_count += text.size();
      auto chunk = mojom::ResponseChunk::New();
-      chunk->text = "Context: " + text + "\n";
+      chunk->text = text;
      remote->OnResponse(std::move(chunk));
    }
    if (params_->top_k != ml::kMinTopK ||
        params_->temperature != ml::kMinTemperature) {
      auto chunk = mojom::ResponseChunk::New();
      chunk->text += "TopK: " + base::NumberToString(params_->top_k) +
-                     ", Temp: " + base::NumberToString(params_->temperature) +
-                     "\n";
+                     ", Temp: " + base::NumberToString(params_->temperature);
      remote->OnResponse(std::move(chunk));
    }
  } else {
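
Taken together, the fake service hunks above mean generated chunks now
concatenate with no separators at all. A rough standalone mirror of the chunk
order in FakeOnDeviceSession::GenerateImpl (the real method streams
mojom::ResponseChunk objects over a mojo remote; FakeState and GenerateChunks
are hypothetical stand-ins for the session and model state):

  #include <string>
  #include <vector>

  struct FakeState {
    bool fastest_inference = false;
    std::string base_weight = "0";  // "0" suppresses the "Base model" chunk.
    std::string adaptation_model_weight;
    std::string cache_weight;
    bool background_priority = false;
    std::vector<std::string> context;
  };

  std::vector<std::string> GenerateChunks(const FakeState& s) {
    std::vector<std::string> chunks;
    if (s.fastest_inference) chunks.push_back("Fastest inference");
    if (s.base_weight != "0") chunks.push_back("Base model: " + s.base_weight);
    if (!s.adaptation_model_weight.empty())
      chunks.push_back("Adaptation model: " + s.adaptation_model_weight);
    if (!s.cache_weight.empty())
      chunks.push_back("Cache weight: " + s.cache_weight);
    if (s.background_priority) chunks.push_back("Priority: background");
    for (const std::string& ctx : s.context)
      chunks.push_back(ctx);  // Previously "Context: " + ctx + "\n".
    return chunks;
  }

A caller that concatenates these chunks reproduces expectations like
"Priority: backgroundexecute:foo max:1024" seen in the controller tests above.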