Simplify strings returned from the fake on-device model service
The "\n" and "Context: " pieces of the output were mostly noise and don't meaningfully add anything to test correctness. This change removes those pieces to make tests written against the fake a bit cleaner. Bug: 415808003 Change-Id: Id34f394ba24850f2c89219d66d27ed80474b58fa Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6542686 Reviewed-by: Steven Holte <holte@chromium.org> Commit-Queue: Clark DuVall <cduvall@chromium.org> Cr-Commit-Position: refs/heads/main@{#1459596}
committed by Chromium LUCI CQ
parent 63b4e023e3
commit c5a6b2936d
chrome/browser/ai
components/optimization_guide/core/model_execution
services/on_device_model
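Net effect on test expectations, as a minimal before/after sketch (both lines are taken from the service test diffs below):

    // Before: every chunk from the fake service was wrapped as "Context: ...\n".
    EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
    // After: the fake echoes the appended input back verbatim.
    EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("bar"));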
@@ -243,13 +243,12 @@ optimization_guide::proto::OnDeviceModelExecutionFeatureConfig CreateConfig() {
 // fake service would look something like this:
 // - s1.Prompt("foo")
 //   - Adds "UfooEM" to the session
-//   - Gets output of ["Context: UfooEM\n"] from fake service
-//   - Adds "Context: UfooEM\nE" to the session (fake response + end token)
+//   - Gets output of ["UfooEM"] from fake service
+//   - Adds "UfooEME" to the session (fake response + end token)
 // - s1.Prompt("bar")
 //   - Adds "UbarEM" to the session
-//   - Gets output of ["Context: UfooEM\n", "Context: Context: UfooEM\nE\n",
-//     "Context: UbarEM\n"].
-//   - Adds "Context: UfooEM\nContext: Context: UfooEM\nE\nContext: UbarEM\n"
+//   - Gets output of ["UfooEM", "UfooEME", "UbarEM"].
+//   - Adds "UfooEMUfooEMEUbarEM"
 //     (concatenated output from fake service) to the session
 // This behavior verifies the correct inputs and outputs are being returned from
 // the model, and this helper makes it easier to construct these expectations.
@@ -260,10 +259,10 @@ std::vector<std::string> FormatResponses(
   std::string last_output;
   for (const std::string& response : responses) {
     if (!last_output.empty()) {
-      formatted.push_back("Context: " + last_output + "E\n");
+      formatted.push_back(last_output + "E");
       last_output += formatted.back();
     }
-    formatted.push_back("Context: " + response + "\n");
+    formatted.push_back(response);
    last_output += formatted.back();
   }
   return formatted;
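A minimal usage sketch of the simplified helper, following the walkthrough above (session setup elided; ElementsAreArray as used elsewhere in these tests):

    Prompt(*session, MakeInput("foo"));  // first turn responds with "UfooEM"
    // The second turn replays the first response plus the "E" end token
    // before the new response; FormatResponses builds that expectation:
    // FormatResponses({"UfooEM", "UbarEM"}) == {"UfooEM", "UfooEME", "UbarEM"}
    EXPECT_THAT(Prompt(*session, MakeInput("bar")),
                ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));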
@@ -383,7 +382,7 @@ TEST_F(AILanguageModelTest, Append) {
   auto session = CreateSession();
   Append(*session, MakeInput("foo"));
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Context: UfooE\n", "Context: UbarEM\n"));
+              ElementsAre("UfooE", "UbarEM"));
 }

 TEST_F(AILanguageModelTest, PromptTokenCounts) {
@@ -464,9 +463,9 @@ TEST_F(AILanguageModelTest, SamplingParams) {
   auto fork = Fork(*session);

   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 2, Temp: 1\n"));
+              ElementsAre("UfooEM", "TopK: 2, Temp: 1"));
   EXPECT_THAT(Prompt(*fork, MakeInput("bar")),
-              ElementsAre("Context: UbarEM\n", "TopK: 2, Temp: 1\n"));
+              ElementsAre("UbarEM", "TopK: 2, Temp: 1"));
 }

 TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
@@ -479,7 +478,7 @@ TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
   auto session = CreateSession(std::move(options));

   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 1, Temp: 1.5\n"));
+              ElementsAre("UfooEM", "TopK: 1, Temp: 1.5"));
 }

 TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
@@ -492,7 +491,7 @@ TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
   auto session = CreateSession(std::move(options));

   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 2, Temp: 0\n"));
+              ElementsAre("UfooEM", "TopK: 2, Temp: 0"));
 }

 TEST_F(AILanguageModelTest, MaxSamplingParams) {
@@ -505,7 +504,7 @@ TEST_F(AILanguageModelTest, MaxSamplingParams) {
   auto session = CreateSession(std::move(options));

   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: UfooEM\n", "TopK: 5, Temp: 1.5\n"));
+              ElementsAre("UfooEM", "TopK: 5, Temp: 1.5"));
 }

 TEST_F(AILanguageModelTest, InitialPrompts) {
@@ -515,7 +514,7 @@ TEST_F(AILanguageModelTest, InitialPrompts) {
   auto session = CreateSession(std::move(options));

   EXPECT_THAT(Prompt(*session, MakeInput("foo")),
-              ElementsAre("Context: ShiEUbyeE\n", "Context: UfooEM\n"));
+              ElementsAre("ShiEUbyeE", "UfooEM"));
 }

 TEST_F(AILanguageModelTest, InitialPromptsInstanceInfo) {
@@ -591,8 +590,7 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnPromptInput) {
   // Response should include input/output of previous prompt with the original
   // long prompt not present.
   EXPECT_THAT(responder.responses(),
-              ElementsAre("Context: SinitE\n", "Context: UfooEMhiE\n",
-                          "Context: U" + long_prompt + "EM\n"));
+              ElementsAre("SinitE", "UfooEMhiE", "U" + long_prompt + "EM"));
 }

 TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
@@ -609,10 +607,8 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
   responder.WaitForQuotaOverflow();
   EXPECT_TRUE(responder.WaitForCompletion());

-  EXPECT_THAT(
-      Prompt(*session, MakeInput("foo")),
-      ElementsAre("Context: SinitE\n", "Context: U" + long_prompt + "E\n",
-                  "Context: UfooEM\n"));
+  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
+              ElementsAre("SinitE", "U" + long_prompt + "E", "UfooEM"));
 }

 TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
@@ -643,8 +639,7 @@ TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
   // - "bar" from the current prompt call
   fake_broker_.settings().set_execute_result({});
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Context: UfooEM" + long_response + "E\n",
-                          "Context: UbarEM\n"));
+              ElementsAre("UfooEM" + long_response + "E", "UbarEM"));
 }

 TEST_F(AILanguageModelTest, Destroy) {
@@ -949,7 +944,7 @@ TEST_F(AILanguageModelTest, Constraint) {
   EXPECT_THAT(
       Prompt(*session, MakeInput("foo"),
              on_device_model::mojom::ResponseConstraint::NewRegex("reg")),
-      ElementsAre("Constraint: regex reg\n", "Context: UfooEM\n"));
+      ElementsAre("Constraint: regex reg", "UfooEM"));
 }

 TEST_F(AILanguageModelTest, ServiceCrash) {
@@ -1112,11 +1107,11 @@ TEST_F(AILanguageModelTest, Priority) {

   main_rfh()->GetRenderWidgetHost()->GetView()->Hide();
   EXPECT_THAT(Prompt(*session, MakeInput("bar")),
-              ElementsAre("Priority: background\n", "hi"));
+              ElementsAre("Priority: background", "hi"));

   auto fork = Fork(*session);
   EXPECT_THAT(Prompt(*fork, MakeInput("bar")),
-              ElementsAre("Priority: background\n", "hi"));
+              ElementsAre("Priority: background", "hi"));

   main_rfh()->GetRenderWidgetHost()->GetView()->Show();
   EXPECT_THAT(Prompt(*session, MakeInput("baz")), ElementsAre("hi"));
components/optimization_guide/core/model_execution/on_device_model_service_controller_unittest.cc
@@ -330,7 +330,7 @@ TEST_F(OnDeviceModelServiceControllerTest, BaseModelExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  const std::string expected_response = "Context: execute:foo max:1024\n";
+  const std::string expected_response = "execute:foo max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_TRUE(*response_.provided_by_on_device());
   EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
@@ -453,8 +453,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CacheWeightExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Cache weight: 1015\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Cache weight: 1015execute:foo max:1024");

   // If we destroy all sessions and wait long enough, everything should idle out
   // and the service should get terminated.
@@ -481,8 +480,7 @@ TEST_F(OnDeviceModelServiceControllerTest, AdaptationModelExecutionSuccess) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Adaptation model: 1015execute:foo max:1024");

   // If we destroy all sessions and wait long enough, everything should idle out
   // and the service should get terminated.
@@ -535,11 +533,11 @@ TEST_F(OnDeviceModelServiceControllerTest,

   ASSERT_TRUE(compose_response.GetFinalStatus());
   EXPECT_EQ(*compose_response.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+            "Adaptation model: 1015execute:foo max:1024");
   EXPECT_TRUE(*compose_response.provided_by_on_device());
   ASSERT_TRUE(test_response.GetFinalStatus());
   EXPECT_EQ(*test_response.value(),
-            "Adaptation model: 2024\nContext: execute:bar max:1024\n");
+            "Adaptation model: 2024execute:bar max:1024");
   EXPECT_TRUE(*test_response.provided_by_on_device());

   session_compose.reset();
@@ -594,10 +592,10 @@ TEST_F(OnDeviceModelServiceControllerTest, ModelAdaptationAndBaseModelSuccess) {

   ASSERT_TRUE(compose_response.GetFinalStatus());
   EXPECT_EQ(*compose_response.value(),
-            "Adaptation model: 1015\nContext: execute:foo max:1024\n");
+            "Adaptation model: 1015execute:foo max:1024");
   EXPECT_TRUE(*compose_response.provided_by_on_device());
   ASSERT_TRUE(test_response.GetFinalStatus());
-  EXPECT_EQ(*test_response.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*test_response.value(), "execute:bar max:1024");
   EXPECT_TRUE(*test_response.provided_by_on_device());

   session_compose.reset();
@@ -719,8 +717,7 @@ TEST_F(OnDeviceModelServiceControllerTest, SessionBeforeAndAfterModelUpdate) {
   session2->ExecuteModel(PageUrlRequest("foo"),
                          response2.GetStreamingCallback());
   ASSERT_TRUE(response2.GetFinalStatus());
-  EXPECT_EQ(*response2.value(),
-            "Base model: 2\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response2.value(), "Base model: 2execute:foo max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, SessionFailsForInvalidFeature) {
@@ -1636,7 +1633,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CancelsExecuteOnExecute) {
   EXPECT_EQ(
       *resp1.error(),
       OptimizationGuideModelExecutionError::ModelExecutionError::kCancelled);
-  EXPECT_EQ(*resp2.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*resp2.value(), "execute:bar max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, WontStartSessionAfterGpuBlocked) {
@@ -1822,8 +1819,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextDisconnectExecute) {
       "OptimizationGuide.ModelExecution.OnDeviceExecuteModelResult.Compose",
       ExecuteModelResult::kUsedOnDevice, 1);
   std::string expected_response =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foobaz max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foobaz max:1024");
   EXPECT_EQ(*response_.value(), expected_response);
 }

@@ -1861,16 +1858,16 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextMultipleSessions) {
   session2->ExecuteModel(PageUrlRequest("2"), response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   std::string expected_response1 =
-      ("Context: ctx:bar max:8192\n"
-       "Context: execute:bar2 max:1024\n");
+      ("ctx:bar max:8192"
+       "execute:bar2 max:1024");
   EXPECT_EQ(*response_.value(), expected_response1);

   ResponseHolder response2;
   session1->ExecuteModel(PageUrlRequest("1"), response2.GetStreamingCallback());
   ASSERT_TRUE(response2.GetFinalStatus());
   std::string expected_response2 =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foo1 max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foo1 max:1024");
   EXPECT_EQ(*response2.value(), expected_response2);
 }

@@ -2067,7 +2064,7 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
   session1->ExecuteModel(UserInputRequest("foo"),
                          response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response1 = "Context: execute:foo max:1024\n";
+  const std::string expected_response1 = "execute:foo max:1024";
   EXPECT_EQ(*response_.value(), expected_response1);
   EXPECT_THAT(response_.partials(), IsEmpty());

@@ -2078,19 +2075,19 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
   session2->ExecuteModel(UserInputRequest("abarx"),
                          response2.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response2 = "Context: execute:abarx max:1024\n";
+  const std::string expected_response2 = "execute:abarx max:1024";
   EXPECT_EQ(*response2.value(), expected_response2);
   EXPECT_THAT(response2.partials(), IsEmpty());

   // Output contains redacted text (and input doesn't), so redact.
-  fake_settings_.set_execute_result({"Context: abarx max:1024\n"});
+  fake_settings_.set_execute_result({"abarx max:1024"});
   auto session3 = CreateSession();
   ASSERT_TRUE(session3);
   ResponseHolder response3;
   session3->ExecuteModel(UserInputRequest("foo"),
                          response3.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response3 = "Context: a[###]x max:1024\n";
+  const std::string expected_response3 = "a[###]x max:1024";
   EXPECT_EQ(*response3.value(), expected_response3);
   EXPECT_THAT(response3.partials(), IsEmpty());
 }
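A short sketch of the redaction flow these expectations exercise (the redaction rule and the "[###]" replacement come from the test's configuration; the values are illustrative):

    // set_execute_result() forces the fake model's raw output. "bar" is a
    // redactable span that does not appear in the input "foo", so the
    // controller rewrites it before the response reaches the caller:
    fake_settings_.set_execute_result({"abarx max:1024"});
    // observed response: "a[###]x max:1024"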
@@ -2151,7 +2148,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
   });

   // Force 'bar' to be returned from model.
-  fake_settings_.set_execute_result({"Context: bar max:1024\n"});
+  fake_settings_.set_execute_result({"bar max:1024"});

   auto session = CreateSession();
   ASSERT_TRUE(session);
@@ -2160,7 +2157,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
                         response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
   // `bar` shouldn't be rewritten as it's in the input.
-  const std::string expected_response = "Context: bar max:1024\n";
+  const std::string expected_response = "bar max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_THAT(response_.partials(), IsEmpty());
 }
@@ -2177,13 +2174,13 @@ TEST_F(OnDeviceModelServiceControllerTest, ReplacementText) {
   });

   // Output contains redacted text (and input doesn't), so redact.
-  fake_settings_.set_execute_result({"Context: abarx max:1024\n"});
+  fake_settings_.set_execute_result({"abarx max:1024"});
   auto session = CreateSession();
   ASSERT_TRUE(session);
   session->ExecuteModel(UserInputRequest("foo"),
                         response_.GetStreamingCallback());
   task_environment_.RunUntilIdle();
-  const std::string expected_response = "Context: a[redacted]x max:1024\n";
+  const std::string expected_response = "a[redacted]x max:1024";
   EXPECT_EQ(*response_.value(), expected_response);
   EXPECT_THAT(response_.partials(), IsEmpty());
 }
@@ -2435,8 +2432,8 @@ TEST_F(OnDeviceModelServiceControllerTest, UsesSessionTopKAndTemperature) {
   task_environment_.RunUntilIdle();
   EXPECT_TRUE(response_.value());
   const std::vector<std::string> partial_responses = {
-      "Context: execute:foo max:1024\n",
-      "TopK: 3, Temp: 2\n",
+      "execute:foo max:1024",
+      "TopK: 3, Temp: 2",
   };
   EXPECT_EQ(*response_.value(), ConcatResponses(partial_responses));
   EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
@@ -3238,8 +3235,7 @@ TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(*response_.value(),
-            "Fastest inference\nContext: execute:foo max:1024\n");
+  EXPECT_EQ(*response_.value(), "Fastest inferenceexecute:foo max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
@@ -3296,8 +3292,7 @@ TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
   session->ExecuteModel(proto::ExampleForTestingRequest(),
                         response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
-  EXPECT_EQ(*response.value(),
-            "Context: <image> max:22\nContext: <image> max:1024\n");
+  EXPECT_EQ(*response.value(), "<image> max:22<image> max:1024");
 }

 // Session without capabilities should not allow images.
@@ -3310,8 +3305,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
                           response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     EXPECT_EQ(*response.value(),
-              "Context: <unsupported> max:22\nContext: <unsupported> "
-              "max:1024\n");
+              "<unsupported> max:22<unsupported> "
+              "max:1024");
   }
 }

@@ -3403,8 +3398,7 @@ TEST_F(OnDeviceModelServiceControllerTest, KeepInputOnExtension) {
   altered_clone->ExecuteModel(proto::ExampleForTestingRequest(),
                               altered_response.GetStreamingCallback());
   ASSERT_TRUE(altered_response.GetFinalStatus());
-  EXPECT_EQ(*altered_response.value(),
-            "Context: v1<image><audio>v2v3v4 max:22\n");
+  EXPECT_EQ(*altered_response.value(), "v1<image><audio>v2v3v4 max:22");

   // The clone that only extended should have sent input in separate chunks.
   ResponseHolder extended_response;
@@ -3412,19 +3406,19 @@ TEST_F(OnDeviceModelServiceControllerTest, KeepInputOnExtension) {
                               extended_response.GetStreamingCallback());
   ASSERT_TRUE(extended_response.GetFinalStatus());
   EXPECT_EQ(*extended_response.value(),
-            "Context: v1<image><audio> max:22\n"
-            "Context: v2 max:22\n"
-            "Context: v3 max:4\n"
-            "Context: v4 max:4\n");
+            "v1<image><audio> max:22"
+            "v2 max:22"
+            "v3 max:4"
+            "v4 max:4");

   // The original should have input in separate chunks.
   session->ExecuteModel(proto::ExampleForTestingRequest(),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(*response_.value(),
-            "Context: v1<image><audio> max:22\n"
-            "Context: v2 max:22\n"
-            "Context: v3 max:4\n");
+            "v1<image><audio> max:22"
+            "v2 max:22"
+            "v3 max:4");
 }

 TEST_F(OnDeviceModelServiceControllerTest, OmitEmptyInputs) {
@@ -3495,8 +3489,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneUsesSessionTopKAndTemperature) {
   task_environment_.RunUntilIdle();
   EXPECT_TRUE(response_.value());
   const std::vector<std::string> partial_responses = {
-      "Context: execute:foo max:1024\n",
-      "TopK: 3, Temp: 2\n",
+      "execute:foo max:1024",
+      "TopK: 3, Temp: 2",
   };
   EXPECT_EQ(*response_.value(), ConcatResponses(partial_responses));
   EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
@@ -3586,8 +3580,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextAndClone) {
     clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:foobar max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:foobar max:1024");
     EXPECT_EQ(*response.value(), expected_response);
   }

@@ -3598,8 +3592,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextAndClone) {
                         response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:fooblah max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:fooblah max:1024");
     EXPECT_EQ(*response.value(), expected_response);
   }
 }
@@ -3618,7 +3612,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneBeforeAddContext) {
     ResponseHolder response;
     clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
-    EXPECT_EQ(*response.value(), "Context: execute:bar max:1024\n");
+    EXPECT_EQ(*response.value(), "execute:bar max:1024");
   }

   // Original session should execute with context
@@ -3628,8 +3622,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneBeforeAddContext) {
                         response.GetStreamingCallback());
     ASSERT_TRUE(response.GetFinalStatus());
     std::string expected_response =
-        ("Context: ctx:foo max:8192\n"
-         "Context: execute:fooblah max:1024\n");
+        ("ctx:foo max:8192"
+         "execute:fooblah max:1024");
     EXPECT_EQ(*response.value(), expected_response);
   }
 }
@@ -3647,7 +3641,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CancelAddContextAndClone) {
   ResponseHolder response;
   clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
-  EXPECT_EQ(*response.value(), "Context: execute:foobar max:1024\n");
+  EXPECT_EQ(*response.value(), "execute:foobar max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, CloneAddContextDisconnectExecute) {
@@ -3666,8 +3660,8 @@ TEST_F(OnDeviceModelServiceControllerTest, CloneAddContextDisconnectExecute) {
   clone->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
   std::string expected_response =
-      ("Context: ctx:foo max:8192\n"
-       "Context: execute:foobar max:1024\n");
+      ("ctx:foo max:8192"
+       "execute:foobar max:1024");
   EXPECT_EQ(*response.value(), expected_response);
 }

@@ -3692,7 +3686,7 @@ TEST_F(OnDeviceModelServiceControllerTest, Broker) {
   ResponseHolder response;
   session->ExecuteModel(PageUrlRequest("bar"), response.GetStreamingCallback());
   ASSERT_TRUE(response.GetFinalStatus());
-  EXPECT_EQ(*response.value(), "Context: execute:bar max:1024\n");
+  EXPECT_EQ(*response.value(), "execute:bar max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest,
@@ -3724,16 +3718,16 @@ TEST_F(OnDeviceModelServiceControllerTest, Priority) {
   auto session = CreateSession();
   EXPECT_TRUE(session);

-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");

   session->SetPriority(on_device_model::mojom::Priority::kBackground);
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");

   session->SetPriority(on_device_model::mojom::Priority::kForeground);
-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, PriorityClone) {
@@ -3742,17 +3736,17 @@ TEST_F(OnDeviceModelServiceControllerTest, PriorityClone) {
   auto session = CreateSession();
   EXPECT_TRUE(session);

-  EXPECT_EQ(GetResponse(*session, "foo"), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(GetResponse(*session, "foo"), "execute:foo max:1024");

   session->SetPriority(on_device_model::mojom::Priority::kBackground);
   EXPECT_EQ(GetResponse(*session, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");

   auto clone = session->Clone();
   EXPECT_EQ(GetResponse(*clone, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
   EXPECT_EQ(GetResponse(*clone, "foo"),
-            "Priority: background\nContext: execute:foo max:1024\n");
+            "Priority: backgroundexecute:foo max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallback) {
@@ -3772,8 +3766,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SetInputCallback) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Context: ctx:foo max:8192\nContext: execute:foobar "
-            "max:1024\n");
+            "ctx:foo max:8192execute:foobar "
+            "max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackCancelled) {
@@ -3802,8 +3796,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackCancelled) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Context: ctx:foo max:8192\nContext: execute:foobar "
-            "max:1024\n");
+            "ctx:foo max:8192execute:foobar "
+            "max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, SetInputCallbackError) {
@@ -3831,7 +3825,7 @@ TEST_F(OnDeviceModelServiceControllerTest, TokenCounts) {
   session->ExecuteModel(PageUrlRequest("foo"),
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
-  EXPECT_EQ(response_.value(), "Context: execute:foo max:1024\n");
+  EXPECT_EQ(response_.value(), "execute:foo max:1024");
   EXPECT_EQ(response_.input_token_count(), strlen("execute:foo"));
   EXPECT_EQ(response_.output_token_count(), strlen("execute:foo max:1024"));
 }
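A note on the token-count expectations above: the fake counts tokens per character of the text it handles (the fake session's output_token_count += text.size() appears in the diff below), so with the "Context: ...\n" wrapper gone the expected counts are plain string lengths:

    // input:  strlen("execute:foo")           == 11 input tokens
    // output: strlen("execute:foo max:1024")  == 20 output tokens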
@@ -3848,8 +3842,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintOnExecute) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: regex [A-Z]*\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: regex [A-Z]*"
+            "execute:input max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigJson) {
@@ -3880,8 +3874,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigJson) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: json { type: \"object\"}\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: json { type: \"object\"}"
+            "execute:input max:1024");
 }

 TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigRegex) {
@@ -3912,8 +3906,8 @@ TEST_F(OnDeviceModelServiceControllerTest, ResponseConstraintConfigRegex) {
                         response_.GetStreamingCallback());
   ASSERT_TRUE(response_.GetFinalStatus());
   EXPECT_EQ(response_.value(),
-            "Constraint: regex [A-Z]*\n"
-            "Context: execute:input max:1024\n");
+            "Constraint: regex [A-Z]*"
+            "execute:input max:1024");
 }

 }  // namespace optimization_guide
@@ -279,7 +279,7 @@ bool SessionGenerate(ChromeMLSession session,

   if (instance->model_instance->performance_hint ==
       ml::ModelPerformanceHint::kFastestInference) {
-    OutputChunk("Fastest inference\n");
+    OutputChunk("Fastest inference");
   }
   if (!instance->adaptation_data.empty()) {
     std::string adaptation_str = "Adaptation: " + instance->adaptation_data;
@@ -287,19 +287,19 @@ bool SessionGenerate(ChromeMLSession session,
       adaptation_str +=
           " (" + base::NumberToString(*instance->adaptation_file_id) + ")";
     }
-    OutputChunk(adaptation_str + "\n");
+    OutputChunk(adaptation_str);
   }

   // Only include sampling params if they're not the respective default values.
   if (instance->top_k != 1 || instance->temperature != 0) {
     OutputChunk(base::StrCat(
         {"TopK: ", base::NumberToString(instance->top_k),
-         ", Temp: ", base::NumberToString(instance->temperature), "\n"}));
+         ", Temp: ", base::NumberToString(instance->temperature)}));
   }

   if (!instance->context.empty()) {
     for (const std::string& context : instance->context) {
-      OutputChunk("Context: " + context + "\n");
+      OutputChunk(context);
     }
   }
   if (options->constraint) {
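In summary, the fake backend now emits bare chunks in a fixed order, which the service tests below match element-wise or concatenate (a sketch of the code above; the literal values are illustrative):

    // "Fastest inference"        (only with the kFastestInference hint)
    // "Adaptation: Adapt1 (0)"   (only when adaptation data was loaded)
    // "TopK: 2, Temp: 0.5"       (only for non-default sampling params)
    // "cheese"                   (one chunk per context entry, verbatim)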
@@ -179,9 +179,9 @@ class OnDeviceModelServiceTest : public testing::Test {

 TEST_F(OnDeviceModelServiceTest, Responds) {
   auto model = LoadModel();
-  EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
+  EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("bar"));
   // Try another input on the same model.
-  EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Context: cat\n"));
+  EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("cat"));
 }

 TEST_F(OnDeviceModelServiceTest, Append) {
@@ -196,9 +196,7 @@ TEST_F(OnDeviceModelServiceTest, Append) {
   session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
   response.WaitForCompletion();

-  EXPECT_THAT(response.responses(),
-              ElementsAre("Context: cheese\n", "Context: more\n",
-                          "Context: cheddar\n"));
+  EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "cheddar"));
 }

 TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
@@ -221,8 +219,7 @@ TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
   response.WaitForCompletion();

   EXPECT_THAT(response.responses(),
-              ElementsAre("TopK: 2, Temp: 0.5\n", "Context: cheese\n",
-                          "Context: more\n", "Context: cheddar\n"));
+              ElementsAre("TopK: 2, Temp: 0.5", "cheese", "more", "cheddar"));
 }

 TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
@@ -240,15 +237,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
     TestResponseHolder response;
     cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(response.responses(),
-                ElementsAre("Context: cheese\n", "Context: more\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
   }
   {
     TestResponseHolder response;
     session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(response.responses(),
-                ElementsAre("Context: cheese\n", "Context: more\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
   }

   session->Append(MakeInput("foo"), {});
@@ -257,17 +252,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
     TestResponseHolder response;
     session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(
-        response.responses(),
-        ElementsAre("Context: cheese\n", "Context: more\n", "Context: foo\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "foo"));
   }
   {
     TestResponseHolder response;
     cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
     response.WaitForCompletion();
-    EXPECT_THAT(
-        response.responses(),
-        ElementsAre("Context: cheese\n", "Context: more\n", "Context: bar\n"));
+    EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "bar"));
   }
 }

@@ -310,21 +301,11 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsAppend) {
   response4.WaitForCompletion();
   response5.WaitForCompletion();

-  EXPECT_THAT(response1.responses(),
-              ElementsAre("Context: cheese\n", "Context: more\n",
-                          "Context: cheddar\n"));
-  EXPECT_THAT(
-      response2.responses(),
-      ElementsAre("Context: apple\n", "Context: banana\n", "Context: candy\n"));
-  EXPECT_THAT(
-      response3.responses(),
-      ElementsAre("Context: apple\n", "Context: banana\n", "Context: chip\n"));
-  EXPECT_THAT(
-      response4.responses(),
-      ElementsAre("Context: cheese\n", "Context: more\n", "Context: choco\n"));
-  EXPECT_THAT(response5.responses(),
-              ElementsAre("Context: apple\n", "Context: banana\n",
-                          "Context: orange\n"));
+  EXPECT_THAT(response1.responses(), ElementsAre("cheese", "more", "cheddar"));
+  EXPECT_THAT(response2.responses(), ElementsAre("apple", "banana", "candy"));
+  EXPECT_THAT(response3.responses(), ElementsAre("apple", "banana", "chip"));
+  EXPECT_THAT(response4.responses(), ElementsAre("cheese", "more", "choco"));
+  EXPECT_THAT(response5.responses(), ElementsAre("apple", "banana", "orange"));
 }

 TEST_F(OnDeviceModelServiceTest, CountTokens) {
@@ -369,8 +350,7 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokenLimits) {
   response.WaitForCompletion();

   EXPECT_THAT(response.responses(),
-              ElementsAre("Context: big \n", "Context: big cheese\n",
-                          "Context: cheddar\n"));
+              ElementsAre("big ", "big cheese", "cheddar"));
 }

 TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
@@ -392,14 +372,14 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {

   // Response from first session should still work.
   response1.WaitForCompletion();
-  EXPECT_THAT(response1.responses(), ElementsAre("Context: 1\n"));
+  EXPECT_THAT(response1.responses(), ElementsAre("1"));

   // Second session still works.
   TestResponseHolder response2;
   session2->Append(MakeInput("2"), {});
   session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
   response2.WaitForCompletion();
-  EXPECT_THAT(response2.responses(), ElementsAre("Context: 2\n"));
+  EXPECT_THAT(response2.responses(), ElementsAre("2"));
 }

 TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
@@ -407,18 +387,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
   FakeFile weights2("Adapt2");
   auto model = LoadModel();
   auto adaptation1 = LoadAdaptation(*model, weights1.Open());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));

   auto adaptation2 = LoadAdaptation(*model, weights2.Open());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation2, "foo"),
-              ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt2 (1)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 }

 TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
@@ -426,18 +406,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
   FakeFile weights2("Adapt2");
   auto model = LoadModel(ml::ModelBackendType::kApuBackend);
   auto adaptation1 = LoadAdaptation(*model, weights1.Path());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));

   auto adaptation2 = LoadAdaptation(*model, weights2.Path());
-  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
+  EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation2, "foo"),
-              ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt2 (1)", "foo"));
   EXPECT_THAT(GetResponses(*adaptation1, "foo"),
-              ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
+              ElementsAre("Adaptation: Adapt1 (0)", "foo"));
 }

 TEST_F(OnDeviceModelServiceTest, LoadingAdaptationDoesNotCancelSession) {
@@ -532,9 +512,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokens) {
   }
   response.WaitForCompletion();

-  EXPECT_THAT(response.responses(), ElementsAre("Context: System: hi End.\n",
-                                                "Context: Model: hello End.\n",
-                                                "Context: User: bye\n"));
+  EXPECT_THAT(response.responses(),
+              ElementsAre("System: hi End.", "Model: hello End.", "User: bye"));
 }

 TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
@@ -580,8 +559,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
   }

   EXPECT_THAT(response.responses(),
-              ElementsAre("Context: cheddar[Bitmap of size 7x21]cheese\n",
-                          "Context: bleu[Bitmap of size 63x42]cheese\n"));
+              ElementsAre("cheddar[Bitmap of size 7x21]cheese",
+                          "bleu[Bitmap of size 63x42]cheese"));
 }

 TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
@@ -641,7 +620,7 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
   auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
                          ml::ModelPerformanceHint::kFastestInference);
   EXPECT_THAT(GetResponses(*model, "foo"),
-              ElementsAre("Fastest inference\n", "Context: foo\n"));
+              ElementsAre("Fastest inference", "foo"));
 }

 TEST_F(OnDeviceModelServiceTest, Capabilities) {
@@ -181,29 +181,28 @@ void FakeOnDeviceSession::GenerateImpl(
   if (model_->performance_hint() ==
       ml::ModelPerformanceHint::kFastestInference) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Fastest inference\n";
+    chunk->text = "Fastest inference";
     remote->OnResponse(std::move(chunk));
   }
   if (model_->data().base_weight != "0") {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Base model: " + model_->data().base_weight + "\n";
+    chunk->text = "Base model: " + model_->data().base_weight;
     remote->OnResponse(std::move(chunk));
   }
   if (!model_->data().adaptation_model_weight.empty()) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text =
-        "Adaptation model: " + model_->data().adaptation_model_weight + "\n";
+    chunk->text = "Adaptation model: " + model_->data().adaptation_model_weight;
     remote->OnResponse(std::move(chunk));
   }
   if (!model_->data().cache_weight.empty()) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Cache weight: " + model_->data().cache_weight + "\n";
+    chunk->text = "Cache weight: " + model_->data().cache_weight;
     remote->OnResponse(std::move(chunk));
   }

   if (priority_ == on_device_model::mojom::Priority::kBackground) {
     auto chunk = mojom::ResponseChunk::New();
-    chunk->text = "Priority: background\n";
+    chunk->text = "Priority: background";
     remote->OnResponse(std::move(chunk));
   }

@@ -211,11 +210,11 @@ void FakeOnDeviceSession::GenerateImpl(
     const auto& constraint = *options->constraint;
     auto chunk = mojom::ResponseChunk::New();
     if (constraint.is_json_schema()) {
-      chunk->text = "Constraint: json " + constraint.get_json_schema() + "\n";
+      chunk->text = "Constraint: json " + constraint.get_json_schema();
     } else if (constraint.is_regex()) {
-      chunk->text = "Constraint: regex " + constraint.get_regex() + "\n";
+      chunk->text = "Constraint: regex " + constraint.get_regex();
     } else {
-      chunk->text = "Constraint: unknown\n";
+      chunk->text = "Constraint: unknown";
     }
     remote->OnResponse(std::move(chunk));
   }
@@ -226,15 +225,14 @@ void FakeOnDeviceSession::GenerateImpl(
       std::string text = CtxToString(*context, params_->capabilities);
       output_token_count += text.size();
       auto chunk = mojom::ResponseChunk::New();
-      chunk->text = "Context: " + text + "\n";
+      chunk->text = text;
       remote->OnResponse(std::move(chunk));
     }
     if (params_->top_k != ml::kMinTopK ||
         params_->temperature != ml::kMinTemperature) {
       auto chunk = mojom::ResponseChunk::New();
       chunk->text += "TopK: " + base::NumberToString(params_->top_k) +
-                     ", Temp: " + base::NumberToString(params_->temperature) +
-                     "\n";
+                     ", Temp: " + base::NumberToString(params_->temperature);
       remote->OnResponse(std::move(chunk));
     }
   } else {