Simplify strings returned from the fake on-device model service
The "\n" and "Context: " pieces of the output were mostly noise and didn't meaningfully add anything to test correctness. This change removes those pieces to make tests written against the fake a bit cleaner.

Bug: 415808003
Change-Id: Id34f394ba24850f2c89219d66d27ed80474b58fa
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6542686
Reviewed-by: Steven Holte <holte@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1459596}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
63b4e023e3
commit
c5a6b2936d
chrome/browser/ai
components/optimization_guide/core/model_execution
services/on_device_model
@ -279,7 +279,7 @@ bool SessionGenerate(ChromeMLSession session,
|
||||
|
||||
if (instance->model_instance->performance_hint ==
|
||||
ml::ModelPerformanceHint::kFastestInference) {
|
||||
OutputChunk("Fastest inference\n");
|
||||
OutputChunk("Fastest inference");
|
||||
}
|
||||
if (!instance->adaptation_data.empty()) {
|
||||
std::string adaptation_str = "Adaptation: " + instance->adaptation_data;
|
||||
@ -287,19 +287,19 @@ bool SessionGenerate(ChromeMLSession session,
|
||||
adaptation_str +=
|
||||
" (" + base::NumberToString(*instance->adaptation_file_id) + ")";
|
||||
}
|
||||
OutputChunk(adaptation_str + "\n");
|
||||
OutputChunk(adaptation_str);
|
||||
}
|
||||
|
||||
// Only include sampling params if they're not the respective default values.
|
||||
if (instance->top_k != 1 || instance->temperature != 0) {
|
||||
OutputChunk(base::StrCat(
|
||||
{"TopK: ", base::NumberToString(instance->top_k),
|
||||
", Temp: ", base::NumberToString(instance->temperature), "\n"}));
|
||||
", Temp: ", base::NumberToString(instance->temperature)}));
|
||||
}
|
||||
|
||||
if (!instance->context.empty()) {
|
||||
for (const std::string& context : instance->context) {
|
||||
OutputChunk("Context: " + context + "\n");
|
||||
OutputChunk(context);
|
||||
}
|
||||
}
|
||||
if (options->constraint) {
|
||||
|
@ -179,9 +179,9 @@ class OnDeviceModelServiceTest : public testing::Test {
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, Responds) {
|
||||
auto model = LoadModel();
|
||||
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("bar"));
|
||||
// Try another input on the same model.
|
||||
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Context: cat\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("cat"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, Append) {
|
||||
@ -196,9 +196,7 @@ TEST_F(OnDeviceModelServiceTest, Append) {
|
||||
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
|
||||
response.WaitForCompletion();
|
||||
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n",
|
||||
"Context: cheddar\n"));
|
||||
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "cheddar"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
|
||||
@ -221,8 +219,7 @@ TEST_F(OnDeviceModelServiceTest, PerSessionSamplingParams) {
|
||||
response.WaitForCompletion();
|
||||
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("TopK: 2, Temp: 0.5\n", "Context: cheese\n",
|
||||
"Context: more\n", "Context: cheddar\n"));
|
||||
ElementsAre("TopK: 2, Temp: 0.5", "cheese", "more", "cheddar"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
|
||||
@ -240,15 +237,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
|
||||
TestResponseHolder response;
|
||||
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
|
||||
response.WaitForCompletion();
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n"));
|
||||
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
|
||||
}
|
||||
{
|
||||
TestResponseHolder response;
|
||||
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
|
||||
response.WaitForCompletion();
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n"));
|
||||
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more"));
|
||||
}
|
||||
|
||||
session->Append(MakeInput("foo"), {});
|
||||
@ -257,17 +252,13 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
|
||||
TestResponseHolder response;
|
||||
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
|
||||
response.WaitForCompletion();
|
||||
EXPECT_THAT(
|
||||
response.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n", "Context: foo\n"));
|
||||
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "foo"));
|
||||
}
|
||||
{
|
||||
TestResponseHolder response;
|
||||
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
|
||||
response.WaitForCompletion();
|
||||
EXPECT_THAT(
|
||||
response.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n", "Context: bar\n"));
|
||||
EXPECT_THAT(response.responses(), ElementsAre("cheese", "more", "bar"));
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,21 +301,11 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsAppend) {
|
||||
response4.WaitForCompletion();
|
||||
response5.WaitForCompletion();
|
||||
|
||||
EXPECT_THAT(response1.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n",
|
||||
"Context: cheddar\n"));
|
||||
EXPECT_THAT(
|
||||
response2.responses(),
|
||||
ElementsAre("Context: apple\n", "Context: banana\n", "Context: candy\n"));
|
||||
EXPECT_THAT(
|
||||
response3.responses(),
|
||||
ElementsAre("Context: apple\n", "Context: banana\n", "Context: chip\n"));
|
||||
EXPECT_THAT(
|
||||
response4.responses(),
|
||||
ElementsAre("Context: cheese\n", "Context: more\n", "Context: choco\n"));
|
||||
EXPECT_THAT(response5.responses(),
|
||||
ElementsAre("Context: apple\n", "Context: banana\n",
|
||||
"Context: orange\n"));
|
||||
EXPECT_THAT(response1.responses(), ElementsAre("cheese", "more", "cheddar"));
|
||||
EXPECT_THAT(response2.responses(), ElementsAre("apple", "banana", "candy"));
|
||||
EXPECT_THAT(response3.responses(), ElementsAre("apple", "banana", "chip"));
|
||||
EXPECT_THAT(response4.responses(), ElementsAre("cheese", "more", "choco"));
|
||||
EXPECT_THAT(response5.responses(), ElementsAre("apple", "banana", "orange"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, CountTokens) {
|
||||
@ -369,8 +350,7 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokenLimits) {
|
||||
response.WaitForCompletion();
|
||||
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("Context: big \n", "Context: big cheese\n",
|
||||
"Context: cheddar\n"));
|
||||
ElementsAre("big ", "big cheese", "cheddar"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
|
||||
@ -392,14 +372,14 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
|
||||
|
||||
// Response from first session should still work.
|
||||
response1.WaitForCompletion();
|
||||
EXPECT_THAT(response1.responses(), ElementsAre("Context: 1\n"));
|
||||
EXPECT_THAT(response1.responses(), ElementsAre("1"));
|
||||
|
||||
// Second session still works.
|
||||
TestResponseHolder response2;
|
||||
session2->Append(MakeInput("2"), {});
|
||||
session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
|
||||
response2.WaitForCompletion();
|
||||
EXPECT_THAT(response2.responses(), ElementsAre("Context: 2\n"));
|
||||
EXPECT_THAT(response2.responses(), ElementsAre("2"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
|
||||
@ -407,18 +387,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
|
||||
FakeFile weights2("Adapt2");
|
||||
auto model = LoadModel();
|
||||
auto adaptation1 = LoadAdaptation(*model, weights1.Open());
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
|
||||
auto adaptation2 = LoadAdaptation(*model, weights2.Open());
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
|
||||
ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt2 (1)", "foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
|
||||
@ -426,18 +406,18 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
|
||||
FakeFile weights2("Adapt2");
|
||||
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
|
||||
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
|
||||
auto adaptation2 = LoadAdaptation(*model, weights2.Path());
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
|
||||
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
|
||||
ElementsAre("Adaptation: Adapt2 (1)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt2 (1)", "foo"));
|
||||
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
|
||||
ElementsAre("Adaptation: Adapt1 (0)\n", "Context: foo\n"));
|
||||
ElementsAre("Adaptation: Adapt1 (0)", "foo"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, LoadingAdaptationDoesNotCancelSession) {
|
||||
@ -532,9 +512,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithTokens) {
|
||||
}
|
||||
response.WaitForCompletion();
|
||||
|
||||
EXPECT_THAT(response.responses(), ElementsAre("Context: System: hi End.\n",
|
||||
"Context: Model: hello End.\n",
|
||||
"Context: User: bye\n"));
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("System: hi End.", "Model: hello End.", "User: bye"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
|
||||
@ -580,8 +559,8 @@ TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
|
||||
}
|
||||
|
||||
EXPECT_THAT(response.responses(),
|
||||
ElementsAre("Context: cheddar[Bitmap of size 7x21]cheese\n",
|
||||
"Context: bleu[Bitmap of size 63x42]cheese\n"));
|
||||
ElementsAre("cheddar[Bitmap of size 7x21]cheese",
|
||||
"bleu[Bitmap of size 63x42]cheese"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
|
||||
@ -641,7 +620,7 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
|
||||
auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
|
||||
ml::ModelPerformanceHint::kFastestInference);
|
||||
EXPECT_THAT(GetResponses(*model, "foo"),
|
||||
ElementsAre("Fastest inference\n", "Context: foo\n"));
|
||||
ElementsAre("Fastest inference", "foo"));
|
||||
}
|
||||
|
||||
TEST_F(OnDeviceModelServiceTest, Capabilities) {
|
||||
|
@ -181,29 +181,28 @@ void FakeOnDeviceSession::GenerateImpl(
|
||||
if (model_->performance_hint() ==
|
||||
ml::ModelPerformanceHint::kFastestInference) {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text = "Fastest inference\n";
|
||||
chunk->text = "Fastest inference";
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
if (model_->data().base_weight != "0") {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text = "Base model: " + model_->data().base_weight + "\n";
|
||||
chunk->text = "Base model: " + model_->data().base_weight;
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
if (!model_->data().adaptation_model_weight.empty()) {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text =
|
||||
"Adaptation model: " + model_->data().adaptation_model_weight + "\n";
|
||||
chunk->text = "Adaptation model: " + model_->data().adaptation_model_weight;
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
if (!model_->data().cache_weight.empty()) {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text = "Cache weight: " + model_->data().cache_weight + "\n";
|
||||
chunk->text = "Cache weight: " + model_->data().cache_weight;
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
|
||||
if (priority_ == on_device_model::mojom::Priority::kBackground) {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text = "Priority: background\n";
|
||||
chunk->text = "Priority: background";
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
|
||||
@ -211,11 +210,11 @@ void FakeOnDeviceSession::GenerateImpl(
|
||||
const auto& constraint = *options->constraint;
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
if (constraint.is_json_schema()) {
|
||||
chunk->text = "Constraint: json " + constraint.get_json_schema() + "\n";
|
||||
chunk->text = "Constraint: json " + constraint.get_json_schema();
|
||||
} else if (constraint.is_regex()) {
|
||||
chunk->text = "Constraint: regex " + constraint.get_regex() + "\n";
|
||||
chunk->text = "Constraint: regex " + constraint.get_regex();
|
||||
} else {
|
||||
chunk->text = "Constraint: unknown\n";
|
||||
chunk->text = "Constraint: unknown";
|
||||
}
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
@ -226,15 +225,14 @@ void FakeOnDeviceSession::GenerateImpl(
|
||||
std::string text = CtxToString(*context, params_->capabilities);
|
||||
output_token_count += text.size();
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text = "Context: " + text + "\n";
|
||||
chunk->text = text;
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
if (params_->top_k != ml::kMinTopK ||
|
||||
params_->temperature != ml::kMinTemperature) {
|
||||
auto chunk = mojom::ResponseChunk::New();
|
||||
chunk->text += "TopK: " + base::NumberToString(params_->top_k) +
|
||||
", Temp: " + base::NumberToString(params_->temperature) +
|
||||
"\n";
|
||||
", Temp: " + base::NumberToString(params_->temperature);
|
||||
remote->OnResponse(std::move(chunk));
|
||||
}
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user