0

Simplify strings returned from the fake on-device model service

The "\n" suffix and "Context: " prefix in the output were mostly noise
and did not meaningfully contribute to test correctness. This change
removes them so that tests written against the fake are a bit cleaner.

Bug: 415808003
Change-Id: Id34f394ba24850f2c89219d66d27ed80474b58fa
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6542686
Reviewed-by: Steven Holte <holte@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1459596}
This commit is contained in:
Clark DuVall
2025-05-13 11:38:28 -07:00
committed by Chromium LUCI CQ
parent 63b4e023e3
commit c5a6b2936d
5 changed files with 136 additions and 170 deletions
chrome/browser/ai
components/optimization_guide/core/model_execution
services/on_device_model

@ -279,7 +279,7 @@ bool SessionGenerate(ChromeMLSession session,
if (instance->model_instance->performance_hint ==
ml::ModelPerformanceHint::kFastestInference) {
OutputChunk("Fastest inference\n");
OutputChunk("Fastest inference");
}
if (!instance->adaptation_data.empty()) {
std::string adaptation_str = "Adaptation: " + instance->adaptation_data;
@ -287,19 +287,19 @@ bool SessionGenerate(ChromeMLSession session,
adaptation_str +=
" (" + base::NumberToString(*instance->adaptation_file_id) + ")";
}
OutputChunk(adaptation_str + "\n");
OutputChunk(adaptation_str);
}
// Only include sampling params if they're not the respective default values.
if (instance->top_k != 1 || instance->temperature != 0) {
OutputChunk(base::StrCat(
{"TopK: ", base::NumberToString(instance->top_k),
", Temp: ", base::NumberToString(instance->temperature), "\n"}));
", Temp: ", base::NumberToString(instance->temperature)}));
}
if (!instance->context.empty()) {
for (const std::string& context : instance->context) {
OutputChunk("Context: " + context + "\n");
OutputChunk(context);
}
}
if (options->constraint) {

@ -179,9 +179,9 @@ class OnDeviceModelServiceTest : public testing::Test {
TEST_F(OnDeviceModelServiceTest, Responds) {
auto model = LoadModel();
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("bar"));
// Try another input on the same model.
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Context: cat\n"));
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("cat"));
}
TEST_F(OnDeviceModelServiceTest, Append) {
@ -196,9 +196,7 @@ TEST_F(OnDeviceModelServiceTest, Append) {
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: cheddar\n"));
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "cheddar"));
}
TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
@ -221,8 +219,7 @@ TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("TopK: 2, Temp: 0.5\n", "Context: cheese\n",
"Context: more\n", "Context: cheddar\n"));
ElementsAre("TopK: 2, Temp: 0.5", "cheese", "more", "cheddar"));
}
TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
@ -240,15 +237,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
TestResponseHolder response;
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n"));
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
}
{
TestResponseHolder response;
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n"));
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
}
session->Append(MakeInput("foo"), {});
@ -257,17 +252,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
TestResponseHolder response;
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(
response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Context: foo\n"));
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "foo"));
}
{
TestResponseHolder response;
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(
response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Context: bar\n"));
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "bar"));
}
}
@ -310,21 +301,11 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsAppend) {
response4.WaitForCompletion();
response5.WaitForCompletion();
EXPECT_THAT(response1.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: cheddar\n"));
EXPECT_THAT(
response2.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Context: candy\n"));
EXPECT_THAT(
response3.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Context: chip\n"));
EXPECT_THAT(
response4.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Context: choco\n"));
EXPECT_THAT(response5.responses(),
ElementsAre("Context: apple\n", "Context: banana\n",
"Context: orange\n"));
EXPECT_THAT(response1.responses(), ElementsAre("cheese", "more", "cheddar"));
EXPECT_THAT(response2.responses(), ElementsAre("apple", "banana", "candy"));
EXPECT_THAT(response3.responses(), ElementsAre("apple", "banana", "chip"));
EXPECT_THAT(response4.responses(), ElementsAre("cheese", "more", "choco"));
EXPECT_THAT(response5.responses(), ElementsAre("apple", "banana", "orange"));
}
TEST_F(OnDeviceModelServiceTest, CountTokens) {
@ -369,8 +350,7 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokenLimits) {
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: big \n", "Context: big cheese\n",
"Context: cheddar\n"));
ElementsAre("big ", "big cheese", "cheddar"));
}
TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
@ -392,14 +372,14 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
// Response from first session should still work.
response1.WaitForCompletion();
EXPECT_THAT(response1.responses(), ElementsAre("Context: 1\n"));
EXPECT_THAT(response1.responses(), ElementsAre("1"));
// Second session still works.
TestResponseHolder response2;
session2->Append(MakeInput("2"), {});
session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
response2.WaitForCompletion();
EXPECT_THAT(response2.responses(), ElementsAre("Context: 2\n"));
EXPECT_THAT(response2.responses(), ElementsAre("2"));
}
TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
@ -407,18 +387,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
FakeFile weights2("Adapt2");
auto model = LoadModel();
auto adaptation1 = LoadAdaptation(*model, weights1.Open());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
auto adaptation2 = LoadAdaptation(*model, weights2.Open());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt2 (1)", "foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
}
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
@ -426,18 +406,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
FakeFile weights2("Adapt2");
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
auto adaptation2 = LoadAdaptation(*model, weights2.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt2 (1)", "foo"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
}
TEST_F(OnDeviceModelServiceTest, LoadingAdaptationDoesNotCancelSession) {
@ -532,9 +512,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokens) {
}
response.WaitForCompletion();
EXPECT_THAT(response.responses(), ElementsAre("Context: System: hi End.\n",
"Context: Model: hello End.\n",
"Context: User: bye\n"));
EXPECT_THAT(response.responses(),
ElementsAre("System: hi End.", "Model: hello End.", "User: bye"));
}
TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
@ -580,8 +559,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
}
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheddar[Bitmap of size 7x21]cheese\n",
"Context: bleu[Bitmap of size 63x42]cheese\n"));
ElementsAre("cheddar[Bitmap of size 7x21]cheese",
"bleu[Bitmap of size 63x42]cheese"));
}
TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
@ -641,7 +620,7 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
ml::ModelPerformanceHint::kFastestInference);
EXPECT_THAT(GetResponses(*model, "foo"),
ElementsAre("Fastest inference\n", "Context: foo\n"));
ElementsAre("Fastest inference", "foo"));
}
TEST_F(OnDeviceModelServiceTest, Capabilities) {

@ -181,29 +181,28 @@ void FakeOnDeviceSession::GenerateImpl(
if (model_->performance_hint() ==
ml::ModelPerformanceHint::kFastestInference) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Fastest inference\n";
chunk->text = "Fastest inference";
remote->OnResponse(std::move(chunk));
}
if (model_->data().base_weight != "0") {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Base model: " + model_->data().base_weight + "\n";
chunk->text = "Base model: " + model_->data().base_weight;
remote->OnResponse(std::move(chunk));
}
if (!model_->data().adaptation_model_weight.empty()) {
auto chunk = mojom::ResponseChunk::New();
chunk->text =
"Adaptation model: " + model_->data().adaptation_model_weight + "\n";
chunk->text = "Adaptation model: " + model_->data().adaptation_model_weight;
remote->OnResponse(std::move(chunk));
}
if (!model_->data().cache_weight.empty()) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Cache weight: " + model_->data().cache_weight + "\n";
chunk->text = "Cache weight: " + model_->data().cache_weight;
remote->OnResponse(std::move(chunk));
}
if (priority_ == on_device_model::mojom::Priority::kBackground) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Priority: background\n";
chunk->text = "Priority: background";
remote->OnResponse(std::move(chunk));
}
@ -211,11 +210,11 @@ void FakeOnDeviceSession::GenerateImpl(
const auto& constraint = *options->constraint;
auto chunk = mojom::ResponseChunk::New();
if (constraint.is_json_schema()) {
chunk->text = "Constraint: json " + constraint.get_json_schema() + "\n";
chunk->text = "Constraint: json " + constraint.get_json_schema();
} else if (constraint.is_regex()) {
chunk->text = "Constraint: regex " + constraint.get_regex() + "\n";
chunk->text = "Constraint: regex " + constraint.get_regex();
} else {
chunk->text = "Constraint: unknown\n";
chunk->text = "Constraint: unknown";
}
remote->OnResponse(std::move(chunk));
}
@ -226,15 +225,14 @@ void FakeOnDeviceSession::GenerateImpl(
std::string text = CtxToString(*context, params_->capabilities);
output_token_count += text.size();
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Context: " + text + "\n";
chunk->text = text;
remote->OnResponse(std::move(chunk));
}
if (params_->top_k != ml::kMinTopK ||
params_->temperature != ml::kMinTemperature) {
auto chunk = mojom::ResponseChunk::New();
chunk->text += "TopK: " + base::NumberToString(params_->top_k) +
", Temp: " + base::NumberToString(params_->temperature) +
"\n";
", Temp: " + base::NumberToString(params_->temperature);
remote->OnResponse(std::move(chunk));
}
} else {