Introduce Append() and Generate() to on-device Session interface

These new methods replace the old AddContext() and Execute() methods.
Append() is now used exclusively for appending input to a session, and
Generate() is used to start producing output. This simplifies the
interface, since previously it was confusing that Execute() could also
take input.

As part of this change, the `ignore_context` option was removed from
Append(), since no features actively use it.

The AddContext() and Execute() methods can be removed and associated
code cleaned up once the internal repos have updated to use the new
methods.
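
To illustrate the migration, a minimal sketch based on the changes in
this CL (`model_input`, `context_client`, and `responder` are
placeholder names, not identifiers from the diff):

  // Old pattern: a single Execute() call carried both input and options.
  auto input = on_device_model::mojom::InputOptions::New();
  input->input = std::move(model_input);
  input->max_tokens = max_execute_tokens;
  session_->Execute(std::move(input), responder.BindNewPipeAndPassRemote());

  // New pattern: Append() feeds input, Generate() starts output.
  auto append_options = on_device_model::mojom::AppendOptions::New();
  append_options->input = std::move(model_input);
  append_options->max_tokens = max_execute_tokens;
  session_->Append(std::move(append_options),
                   context_client.BindNewPipeAndPassRemote());

  auto generate_options = on_device_model::mojom::GenerateOptions::New();
  generate_options->max_output_tokens = max_output_tokens;
  session_->Generate(std::move(generate_options),
                     responder.BindNewPipeAndPassRemote());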

Bug: 395163391
Change-Id: Id2dec21149025532b46d7a922119d6a54d48032a
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6251405
Reviewed-by: Alex Gough <ajgo@chromium.org>
Reviewed-by: Steven Holte <holte@chromium.org>
Commit-Queue: Clark DuVall <cduvall@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1418958}
Author: Clark DuVall
Date: 2025-02-11 15:40:57 -08:00
Committed-by: Chromium LUCI CQ
Parent: 179b9472a8
Commit: ca35d7c79d
18 changed files with 434 additions and 269 deletions

@ -74,11 +74,11 @@ void OnDeviceContext::CloneSession(
}
void OnDeviceContext::AddContext() {
auto options = on_device_model::mojom::InputOptions::New();
auto options = on_device_model::mojom::AppendOptions::New();
options->input = input_.Clone();
options->max_tokens = opts_.token_limits.max_context_tokens;
options->token_offset = 0;
session_->AddContext(std::move(options), client_.BindNewPipeAndPassRemote());
session_->Append(std::move(options), client_.BindNewPipeAndPassRemote());
}
void OnDeviceContext::OnComplete(uint32_t tokens_processed) {

@ -127,8 +127,7 @@ OnDeviceExecution::OnDeviceExecution(
last_message_(std::move(message)),
histogram_logger_(std::move(logger)),
callback_(std::move(callback)),
cleanup_callback_(std::move(cleanup_callback)),
receiver_(this) {
cleanup_callback_(std::move(cleanup_callback)) {
log_.mutable_model_execution_info()
->mutable_on_device_model_execution_info()
->add_execution_infos();
@ -209,10 +208,13 @@ void OnDeviceExecution::BeginExecution(OnDeviceContext& context,
logged_request->set_execution_string(input->ToString());
LogRequest(opts_.logger.get(), *logged_request);
auto options = on_device_model::mojom::InputOptions::New();
options->input = std::move(input->input);
options->max_tokens = opts_.token_limits.max_execute_tokens;
options->ignore_context = input->should_ignore_input_context;
auto append_options = on_device_model::mojom::AppendOptions::New();
append_options->input = std::move(input->input);
append_options->max_tokens = opts_.token_limits.max_execute_tokens;
session_->Append(std::move(append_options),
context_receiver_.BindNewPipeAndPassRemote());
auto options = on_device_model::mojom::GenerateOptions::New();
options->max_output_tokens = opts_.token_limits.max_output_tokens;
options->top_k = sampling_params.top_k;
options->temperature = sampling_params.temperature;
@ -224,7 +226,7 @@ void OnDeviceExecution::BeginExecution(OnDeviceContext& context,
}
void OnDeviceExecution::OnRequestSafetyResult(
on_device_model::mojom::InputOptionsPtr options,
on_device_model::mojom::GenerateOptionsPtr options,
SafetyChecker::Result safety_result) {
if (safety_result.failed_to_run) {
FallbackToRemote(Result::kFailedConstructingMessage);
@ -250,8 +252,8 @@ void OnDeviceExecution::OnRequestSafetyResult(
}
void OnDeviceExecution::BeginRequestExecution(
on_device_model::mojom::InputOptionsPtr options) {
session_->Execute(std::move(options), receiver_.BindNewPipeAndPassRemote());
on_device_model::mojom::GenerateOptionsPtr options) {
session_->Generate(std::move(options), receiver_.BindNewPipeAndPassRemote());
receiver_.set_disconnect_handler(base::BindOnce(
&OnDeviceExecution::OnResponderDisconnect, base::Unretained(this)));
}
@ -310,11 +312,6 @@ void OnDeviceExecution::OnComplete(
receiver_.reset(); // Suppress expected disconnect
bool has_repeats = MutableLoggedResponse()->has_repeats();
// TODO(holte): Make input_token_count available earlier / in more cases.
if (!has_repeats) {
MutableLoggedRequest()->set_execution_num_tokens_processed(
summary->input_token_count);
}
LogResponseHasRepeats(feature_, has_repeats);
LogResponseCompleteTokens(feature_, num_response_tokens_);
@ -331,6 +328,10 @@ void OnDeviceExecution::OnComplete(
RunRawOutputSafetyCheck(ResponseCompleteness::kComplete);
}
void OnDeviceExecution::OnComplete(uint32_t tokens_processed) {
MutableLoggedRequest()->set_execution_num_tokens_processed(tokens_processed);
}
void OnDeviceExecution::OnResponderDisconnect() {
// OnComplete resets the receiver, so this implies that the response is
// incomplete and there was either a service crash or model eviction.
@ -552,6 +553,7 @@ void OnDeviceExecution::Cleanup(bool healthy) {
weak_ptr_factory_.InvalidateWeakPtrs();
session_.reset();
receiver_.reset();
context_receiver_.reset();
callback_.Reset();
log_.Clear();
current_response_.clear();

@ -47,7 +47,8 @@ void InvokeStreamingCallbackWithRemoteResult(
// The state for an ongoing ExecuteModel() call.
class OnDeviceExecution final
: public on_device_model::mojom::StreamingResponder {
: public on_device_model::mojom::StreamingResponder,
public on_device_model::mojom::ContextClient {
public:
// Possible outcomes of ExecuteModel().
// These values are persisted to logs. Entries should not be renumbered and
@ -144,16 +145,21 @@ class OnDeviceExecution final
// Callback invoked with RequestSafetyCheck result.
// Calls BeginRequestExecution if safety checks pass.
void OnRequestSafetyResult(on_device_model::mojom::InputOptionsPtr options,
void OnRequestSafetyResult(on_device_model::mojom::GenerateOptionsPtr options,
SafetyChecker::Result safety_result);
// Begins request execution (leads to OnResponse/OnComplete, which will
// call RunRawOutputSafetyCheck).
void BeginRequestExecution(on_device_model::mojom::InputOptionsPtr options);
void BeginRequestExecution(
on_device_model::mojom::GenerateOptionsPtr options);
// on_device_model::mojom::StreamingResponder:
void OnResponse(on_device_model::mojom::ResponseChunkPtr chunk) override;
void OnComplete(on_device_model::mojom::ResponseSummaryPtr summary) override;
// on_device_model::mojom::ContextClient:
void OnComplete(uint32_t tokens_processed) override;
void OnResponderDisconnect();
// Evaluates raw output safety (leads to OnRawOutputSafetyResult).
@ -265,7 +271,8 @@ class OnDeviceExecution final
// Should pass true to indicate healthy completion, or false if unhealthy.
base::OnceCallback<void(bool)> cleanup_callback_;
mojo::Receiver<on_device_model::mojom::StreamingResponder> receiver_;
mojo::Receiver<on_device_model::mojom::StreamingResponder> receiver_{this};
mojo::Receiver<on_device_model::mojom::ContextClient> context_receiver_{this};
// Factory for weak pointers related to this session that are invalidated
// with the request state.

@ -431,7 +431,7 @@ TEST_F(OnDeviceModelServiceControllerTest, BaseModelExecutionSuccess) {
session->ExecuteModel(PageUrlRequest("foo"),
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
const std::string expected_response = "Input: execute:foo\n";
const std::string expected_response = "Context: execute:foo off:0 max:1024\n";
EXPECT_EQ(*response_.value(), expected_response);
EXPECT_TRUE(*response_.provided_by_on_device());
EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
@ -503,7 +503,8 @@ TEST_F(OnDeviceModelServiceControllerTest, AdaptationModelExecutionSuccess) {
session->ExecuteModel(PageUrlRequest("foo"),
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
EXPECT_EQ(*response_.value(), "Adaptation model: 1015\nInput: execute:foo\n");
EXPECT_EQ(*response_.value(),
"Adaptation model: 1015\nContext: execute:foo off:0 max:1024\n");
// If we destroy all sessions and wait long enough, everything should idle out
// and the service should get terminated.
@ -559,11 +560,11 @@ TEST_F(OnDeviceModelServiceControllerTest,
ASSERT_TRUE(compose_response.GetFinalStatus());
EXPECT_EQ(*compose_response.value(),
"Adaptation model: 1015\nInput: execute:foo\n");
"Adaptation model: 1015\nContext: execute:foo off:0 max:1024\n");
EXPECT_TRUE(*compose_response.provided_by_on_device());
ASSERT_TRUE(test_response.GetFinalStatus());
EXPECT_EQ(*test_response.value(),
"Adaptation model: 2024\nInput: execute:bar\n");
"Adaptation model: 2024\nContext: execute:bar off:0 max:1024\n");
EXPECT_TRUE(*test_response.provided_by_on_device());
session_compose.reset();
@ -622,10 +623,10 @@ TEST_F(OnDeviceModelServiceControllerTest, ModelAdaptationAndBaseModelSuccess) {
ASSERT_TRUE(compose_response.GetFinalStatus());
EXPECT_EQ(*compose_response.value(),
"Adaptation model: 1015\nInput: execute:foo\n");
"Adaptation model: 1015\nContext: execute:foo off:0 max:1024\n");
EXPECT_TRUE(*compose_response.provided_by_on_device());
ASSERT_TRUE(test_response.GetFinalStatus());
EXPECT_EQ(*test_response.value(), "Input: execute:bar\n");
EXPECT_EQ(*test_response.value(), "Context: execute:bar off:0 max:1024\n");
EXPECT_TRUE(*test_response.provided_by_on_device());
session_compose.reset();
@ -725,7 +726,7 @@ TEST_F(OnDeviceModelServiceControllerTest, MidSessionModelUpdate) {
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
// Note that the session does not execute with the new model.
EXPECT_EQ(*response_.value(), "Input: execute:foo\n");
EXPECT_EQ(*response_.value(), "Context: execute:foo off:0 max:1024\n");
}
TEST_F(OnDeviceModelServiceControllerTest, SessionBeforeAndAfterModelUpdate) {
@ -753,7 +754,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SessionBeforeAndAfterModelUpdate) {
session2->ExecuteModel(PageUrlRequest("foo"),
response2.GetStreamingCallback());
ASSERT_TRUE(response2.GetFinalStatus());
EXPECT_EQ(*response2.value(), "Base model: 2\nInput: execute:foo\n");
EXPECT_EQ(*response2.value(),
"Base model: 2\nContext: execute:foo off:0 max:1024\n");
}
TEST_F(OnDeviceModelServiceControllerTest, SessionFailsForInvalidFeature) {
@ -1642,7 +1644,7 @@ TEST_F(OnDeviceModelServiceControllerTest, CancelsExecuteOnExecute) {
EXPECT_EQ(
*resp1.error(),
OptimizationGuideModelExecutionError::ModelExecutionError::kCancelled);
EXPECT_EQ(*resp2.value(), "Input: execute:bar\n");
EXPECT_EQ(*resp2.value(), "Context: execute:bar off:0 max:1024\n");
}
TEST_F(OnDeviceModelServiceControllerTest, WontStartSessionAfterGpuBlocked) {
@ -1829,7 +1831,7 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextDisconnectExecute) {
ExecuteModelResult::kUsedOnDevice, 1);
std::string expected_response =
("Context: ctx:foo off:0 max:4096\n"
"Input: execute:foobaz\n");
"Context: execute:foobaz off:0 max:1024\n");
EXPECT_EQ(*response_.value(), expected_response);
}
@ -1870,7 +1872,7 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextMultipleSessions) {
ASSERT_TRUE(response_.GetFinalStatus());
std::string expected_response1 =
("Context: ctx:bar off:0 max:4096\n"
"Input: execute:bar2\n");
"Context: execute:bar2 off:0 max:1024\n");
EXPECT_EQ(*response_.value(), expected_response1);
ResponseHolder response2;
@ -1878,7 +1880,7 @@ TEST_F(OnDeviceModelServiceControllerTest, AddContextMultipleSessions) {
ASSERT_TRUE(response2.GetFinalStatus());
std::string expected_response2 =
("Context: ctx:foo off:0 max:4096\n"
"Input: execute:foo1\n");
"Context: execute:foo1 off:0 max:1024\n");
EXPECT_EQ(*response2.value(), expected_response2);
}
@ -2081,7 +2083,8 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
session1->ExecuteModel(UserInputRequest("foo"),
response_.GetStreamingCallback());
task_environment_.RunUntilIdle();
const std::string expected_response1 = "Input: execute:foo\n";
const std::string expected_response1 =
"Context: execute:foo off:0 max:1024\n";
EXPECT_EQ(*response_.value(), expected_response1);
EXPECT_THAT(response_.partials(), ElementsAre(expected_response1));
@ -2092,19 +2095,20 @@ TEST_F(OnDeviceModelServiceControllerTest, RedactedField) {
session2->ExecuteModel(UserInputRequest("abarx"),
response2.GetStreamingCallback());
task_environment_.RunUntilIdle();
const std::string expected_response2 = "Input: execute:abarx\n";
const std::string expected_response2 =
"Context: execute:abarx off:0 max:1024\n";
EXPECT_EQ(*response2.value(), expected_response2);
EXPECT_THAT(response2.partials(), ElementsAre(expected_response2));
// Output contains redacted text (and input doesn't), so redact.
fake_settings_.set_execute_result({"Input: abarx\n"});
fake_settings_.set_execute_result({"Context: abarx off:0 max:1024\n"});
auto session3 = CreateSession();
ASSERT_TRUE(session3);
ResponseHolder response3;
session3->ExecuteModel(UserInputRequest("foo"),
response3.GetStreamingCallback());
task_environment_.RunUntilIdle();
const std::string expected_response3 = "Input: a[###]x\n";
const std::string expected_response3 = "Context: a[###]x off:0 max:1024\n";
EXPECT_EQ(*response3.value(), expected_response3);
EXPECT_THAT(response3.partials(), ElementsAre(expected_response3));
}
@ -2182,7 +2186,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
});
// Force 'bar' to be returned from model.
fake_settings_.set_execute_result({"Input: bar\n"});
fake_settings_.set_execute_result({"Context: bar off:0 max:1024\n"});
auto session = CreateSession();
ASSERT_TRUE(session);
@ -2191,7 +2195,7 @@ TEST_F(OnDeviceModelServiceControllerTest, UsePreviousResponseForRewrite) {
response_.GetStreamingCallback());
task_environment_.RunUntilIdle();
// `bar` shouldn't be rewritten as it's in the input.
const std::string expected_response = "Input: bar\n";
const std::string expected_response = "Context: bar off:0 max:1024\n";
EXPECT_EQ(*response_.value(), expected_response);
EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
}
@ -2208,13 +2212,14 @@ TEST_F(OnDeviceModelServiceControllerTest, ReplacementText) {
});
// Output contains redacted text (and input doesn't), so redact.
fake_settings_.set_execute_result({"Input: abarx\n"});
fake_settings_.set_execute_result({"Context: abarx off:0 max:1024\n"});
auto session = CreateSession();
ASSERT_TRUE(session);
session->ExecuteModel(UserInputRequest("foo"),
response_.GetStreamingCallback());
task_environment_.RunUntilIdle();
const std::string expected_response = "Input: a[redacted]x\n";
const std::string expected_response =
"Context: a[redacted]x off:0 max:1024\n";
EXPECT_EQ(*response_.value(), expected_response);
EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
}
@ -2516,10 +2521,12 @@ TEST_F(OnDeviceModelServiceControllerTest, UsesAdapterTopKAndTemperature) {
response_.GetStreamingCallback());
task_environment_.RunUntilIdle();
EXPECT_TRUE(response_.value());
const std::string expected_response =
"Input: execute:foo\nTopK: 4, Temp: 1.5\n";
EXPECT_EQ(*response_.value(), expected_response);
EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
const std::vector<std::string> partial_responses = ConcatResponses({
"Context: execute:foo off:0 max:1024\n",
"TopK: 4, Temp: 1.5\n",
});
EXPECT_EQ(*response_.value(), partial_responses.back());
EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
}
TEST_F(OnDeviceModelServiceControllerTest, UsesSessionTopKAndTemperature) {
@ -2546,10 +2553,12 @@ TEST_F(OnDeviceModelServiceControllerTest, UsesSessionTopKAndTemperature) {
response_.GetStreamingCallback());
task_environment_.RunUntilIdle();
EXPECT_TRUE(response_.value());
const std::string expected_response =
"Input: execute:foo\nTopK: 3, Temp: 2\n";
EXPECT_EQ(*response_.value(), expected_response);
EXPECT_THAT(response_.partials(), ElementsAre(expected_response));
const std::vector<std::string> partial_responses = ConcatResponses({
"Context: execute:foo off:0 max:1024\n",
"TopK: 3, Temp: 2\n",
});
EXPECT_EQ(*response_.value(), partial_responses.back());
EXPECT_THAT(response_.partials(), ElementsAreArray(partial_responses));
}
// Validate that a missing partial output config suppresses partial output.
@ -3406,7 +3415,8 @@ TEST_F(OnDeviceModelServiceControllerTest, SendsPerformanceHint) {
session->ExecuteModel(PageUrlRequest("foo"),
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
EXPECT_EQ(*response_.value(), "Fastest inference\nInput: execute:foo\n");
EXPECT_EQ(*response_.value(),
"Fastest inference\nContext: execute:foo off:0 max:1024\n");
}
SkBitmap CreateBlackSkBitmap(int width, int height) {
@ -3472,7 +3482,7 @@ TEST_F(OnDeviceModelServiceControllerTest, ImageExecutionSuccess) {
response_.GetStreamingCallback());
ASSERT_TRUE(response_.GetFinalStatus());
EXPECT_EQ(*response_.value(),
"Context: <image> off:0 max:22\nInput: <image>\n");
"Context: <image> off:0 max:22\nContext: <image> off:0 max:1024\n");
}
} // namespace optimization_guide

@ -31,14 +31,25 @@ void OnDeviceModelValidator::ValidateNextPrompt() {
}
receiver_.reset();
active_session_.reset();
session_->Clone(active_session_.BindNewPipeAndPassReceiver());
// base::Unretained is safe since `this` owns the session.
active_session_.set_disconnect_handler(base::BindOnce(
&OnDeviceModelValidator::FinishValidation, base::Unretained(this),
OnDeviceModelValidationResult::kInterrupted));
current_response_ = "";
auto options = on_device_model::mojom::InputOptions::New();
options->input = on_device_model::mojom::Input::New();
options->input->pieces.push_back(
auto append_options = on_device_model::mojom::AppendOptions::New();
append_options->input = on_device_model::mojom::Input::New();
append_options->input->pieces.push_back(
validation_config_.validation_prompts(index_).prompt());
active_session_->Append(std::move(append_options), {});
auto generate_options = on_device_model::mojom::GenerateOptions::New();
// Avoid bad responses spamming output and taking too long.
options->max_output_tokens = 64;
session_->Execute(std::move(options), receiver_.BindNewPipeAndPassRemote());
generate_options->max_output_tokens = 64;
active_session_->Generate(std::move(generate_options),
receiver_.BindNewPipeAndPassRemote());
}
void OnDeviceModelValidator::OnResponse(

@ -38,6 +38,7 @@ class OnDeviceModelValidator
int index_ = 0;
proto::OnDeviceModelValidationConfig validation_config_;
mojo::Remote<on_device_model::mojom::Session> session_;
mojo::Remote<on_device_model::mojom::Session> active_session_;
mojo::Receiver<on_device_model::mojom::StreamingResponder> receiver_{this};
FinishCallback finish_callback_;
};

@ -221,12 +221,9 @@ bool SessionExecuteModel(ChromeMLSession session,
OutputChunk("Adaptation: " + instance->adaptation_data_ + "\n");
}
if (!instance->context_.empty()) {
const std::string last = instance->context_.back();
instance->context_.pop_back();
for (const std::string& context : instance->context_) {
OutputChunk("Context: " + context + "\n");
}
OutputChunk("Input: " + last + "\n");
}
OutputChunk("");
return true;

@ -307,23 +307,21 @@ SessionImpl::SessionImpl(const ChromeML& chrome_ml,
SessionImpl::~SessionImpl() = default;
DISABLE_CFI_DLSYM
void SessionImpl::AddContext(
on_device_model::mojom::InputOptionsPtr input,
void SessionImpl::Append(
on_device_model::mojom::AppendOptionsPtr options,
mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
base::OnceClosure on_complete) {
auto context_holder = std::make_unique<ContextHolder>(
std::move(client),
base::BindOnce(&SessionImpl::RemoveContext, base::Unretained(this)),
std::move(on_complete));
if (input->max_tokens == 0 || input->max_tokens > max_tokens_) {
input->max_tokens = max_tokens_;
if (options->max_tokens == 0 || options->max_tokens > max_tokens_) {
options->max_tokens = max_tokens_;
}
input->top_k = GetTopK(input->top_k);
input->temperature = GetTemperature(input->temperature);
ChromeMLContextSavedFn context_saved_fn =
context_holder->CreateContextSavedFn();
*context_holder->GetCancelFn() =
session_->Execute(std::move(input), nullptr, context_saved_fn);
session_->Append(std::move(options), context_saved_fn);
context_holders_.insert(std::move(context_holder));
}
@ -348,6 +346,22 @@ void SessionImpl::Execute(
cloned_raw->Execute(std::move(input), output_fn, context_saved_fn);
}
DISABLE_CFI_DLSYM
void SessionImpl::Generate(
on_device_model::mojom::GenerateOptionsPtr options,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
base::OnceClosure on_complete) {
auto cloned = session_->Clone();
auto cloned_raw = cloned.get(); // For Generate after std::move
responder_ = std::make_unique<Responder>(
std::move(response), std::move(on_complete), std::move(cloned));
ChromeMLExecutionOutputFn output_fn = responder_->CreateOutputFn();
options->top_k = GetTopK(options->top_k);
options->temperature = GetTemperature(options->temperature);
*responder_->GetCancelFn() =
cloned_raw->Generate(std::move(options), output_fn);
}
DISABLE_CFI_DLSYM
void SessionImpl::SizeInTokens(on_device_model::mojom::InputPtr input,
base::OnceCallback<void(uint32_t)> callback) {

@ -43,14 +43,17 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) SessionImpl final {
SessionImpl(const SessionImpl&) = delete;
SessionImpl& operator=(const SessionImpl&) = delete;
void AddContext(
on_device_model::mojom::InputOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
base::OnceClosure on_complete);
void Append(on_device_model::mojom::AppendOptionsPtr options,
mojo::PendingRemote<on_device_model::mojom::ContextClient> client,
base::OnceClosure on_complete);
void Execute(
on_device_model::mojom::InputOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
base::OnceClosure on_complete);
void Generate(
on_device_model::mojom::GenerateOptionsPtr input,
mojo::PendingRemote<on_device_model::mojom::StreamingResponder> response,
base::OnceClosure on_complete);
void SizeInTokens(on_device_model::mojom::InputPtr input,
base::OnceCallback<void(uint32_t)> callback);
void Score(const std::string& text, base::OnceCallback<void(float)> callback);

@ -86,6 +86,28 @@ ChromeMLCancelFn SessionAccessor::Execute(
return [canceler] { canceler->Cancel(); };
}
ChromeMLCancelFn SessionAccessor::Append(
on_device_model::mojom::AppendOptionsPtr options,
ChromeMLContextSavedFn context_saved_fn) {
auto canceler = base::MakeRefCounted<Canceler>(chrome_ml_.get());
task_runner_->PostTask(
FROM_HERE, base::BindOnce(&SessionAccessor::AppendInternal,
base::Unretained(this), std::move(options),
std::move(context_saved_fn), canceler));
return [canceler] { canceler->Cancel(); };
}
ChromeMLCancelFn SessionAccessor::Generate(
on_device_model::mojom::GenerateOptionsPtr options,
ChromeMLExecutionOutputFn output_fn) {
auto canceler = base::MakeRefCounted<Canceler>(chrome_ml_.get());
task_runner_->PostTask(
FROM_HERE,
base::BindOnce(&SessionAccessor::GenerateInternal, base::Unretained(this),
std::move(options), std::move(output_fn), canceler));
return [canceler] { canceler->Cancel(); };
}
void SessionAccessor::Score(const std::string& text, ChromeMLScoreFn score_fn) {
task_runner_->PostTask(
FROM_HERE,
@ -160,6 +182,43 @@ void SessionAccessor::ExecuteInternal(
canceler->get());
}
DISABLE_CFI_DLSYM
void SessionAccessor::AppendInternal(
on_device_model::mojom::AppendOptionsPtr append_options,
ChromeMLContextSavedFn context_saved_fn,
scoped_refptr<Canceler> canceler) {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
ChromeMLExecuteOptions options{
.max_tokens = append_options->max_tokens,
.token_offset = append_options->token_offset,
};
options.input = append_options->input->pieces.data();
options.input_size = append_options->input->pieces.size();
if (context_saved_fn) {
options.context_saved_fn = &context_saved_fn;
}
chrome_ml_->api().SessionExecuteModel(session_, model_, &options,
canceler->get());
}
DISABLE_CFI_DLSYM
void SessionAccessor::GenerateInternal(
on_device_model::mojom::GenerateOptionsPtr generate_options,
ChromeMLExecutionOutputFn output_fn,
scoped_refptr<Canceler> canceler) {
DCHECK(task_runner_->RunsTasksInCurrentSequence());
ChromeMLExecuteOptions options{
.max_output_tokens = generate_options->max_output_tokens,
.top_k = generate_options->top_k.value_or(1),
.temperature = generate_options->temperature.value_or(0),
};
if (output_fn) {
options.execution_output_fn = &output_fn;
}
chrome_ml_->api().SessionExecuteModel(session_, model_, &options,
canceler->get());
}
DISABLE_CFI_DLSYM
void SessionAccessor::ScoreInternal(const std::string& text,
ChromeMLScoreFn score_fn) {

@ -37,6 +37,10 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) SessionAccessor {
ChromeMLCancelFn Execute(on_device_model::mojom::InputOptionsPtr input,
ChromeMLExecutionOutputFn output_fn,
ChromeMLContextSavedFn context_saved_fn);
ChromeMLCancelFn Append(on_device_model::mojom::AppendOptionsPtr options,
ChromeMLContextSavedFn context_saved_fn);
ChromeMLCancelFn Generate(on_device_model::mojom::GenerateOptionsPtr options,
ChromeMLExecutionOutputFn output_fn);
void Score(const std::string& text, ChromeMLScoreFn score_fn);
void SizeInTokens(on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn);
@ -54,6 +58,13 @@ class COMPONENT_EXPORT(ON_DEVICE_MODEL_ML) SessionAccessor {
ChromeMLExecutionOutputFn output_fn,
ChromeMLContextSavedFn context_saved_fn,
scoped_refptr<Canceler> canceler);
void AppendInternal(on_device_model::mojom::AppendOptionsPtr append_options,
ChromeMLContextSavedFn context_saved_fn,
scoped_refptr<Canceler> canceler);
void GenerateInternal(
on_device_model::mojom::GenerateOptionsPtr generate_options,
ChromeMLExecutionOutputFn output_fn,
scoped_refptr<Canceler> canceler);
void ScoreInternal(const std::string& text, ChromeMLScoreFn score_fn);
void SizeInTokensInternal(on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn);

@ -48,9 +48,14 @@ class SessionWrapper final : public mojom::Session {
void AddContext(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client) override;
void Append(mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client) override;
void Execute(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override;
void Generate(
mojom::GenerateOptionsPtr options,
mojo::PendingRemote<mojom::StreamingResponder> response) override;
void GetSizeInTokens(mojom::InputPtr input,
GetSizeInTokensCallback callback) override;
void Score(const std::string& text, ScoreCallback callback) override;
@ -59,11 +64,11 @@ class SessionWrapper final : public mojom::Session {
mojo::Receiver<mojom::Session>& receiver() { return receiver_; }
private:
void AddContextInternal(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client,
base::OnceClosure on_complete) {
session_->AddContext(std::move(input), std::move(client),
std::move(on_complete));
void AppendInternal(mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client,
base::OnceClosure on_complete) {
session_->Append(std::move(options), std::move(client),
std::move(on_complete));
}
void ExecuteInternal(mojom::InputOptionsPtr input,
@ -73,6 +78,13 @@ class SessionWrapper final : public mojom::Session {
std::move(on_complete));
}
void GenerateInternal(mojom::GenerateOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response,
base::OnceClosure on_complete) {
session_->Generate(std::move(input), std::move(response),
std::move(on_complete));
}
void GetSizeInTokensInternal(mojom::InputPtr input,
GetSizeInTokensCallback callback,
base::OnceClosure on_complete) {
@ -246,15 +258,24 @@ class ModelWrapper final : public mojom::OnDeviceModel {
void SessionWrapper::AddContext(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client) {
auto append_options = mojom::AppendOptions::New();
append_options->input = std::move(input->input);
append_options->max_tokens = input->max_tokens;
append_options->token_offset = input->token_offset;
Append(std::move(append_options), std::move(client));
}
void SessionWrapper::Append(mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client) {
if (!model_) {
return;
}
auto add_context_internal = base::BindOnce(
&SessionWrapper::AddContextInternal, weak_ptr_factory_.GetWeakPtr(),
std::move(input), std::move(client));
auto append_internal = base::BindOnce(&SessionWrapper::AppendInternal,
weak_ptr_factory_.GetWeakPtr(),
std::move(options), std::move(client));
model_->AddAndRunPendingTask(std::move(add_context_internal),
model_->AddAndRunPendingTask(std::move(append_internal),
weak_ptr_factory_.GetWeakPtr());
}
@ -273,6 +294,21 @@ void SessionWrapper::Execute(
weak_ptr_factory_.GetWeakPtr());
}
void SessionWrapper::Generate(
mojom::GenerateOptionsPtr options,
mojo::PendingRemote<mojom::StreamingResponder> response) {
if (!model_) {
return;
}
auto generate_internal = base::BindOnce(
&SessionWrapper::GenerateInternal, weak_ptr_factory_.GetWeakPtr(),
std::move(options), std::move(response));
model_->AddAndRunPendingTask(std::move(generate_internal),
weak_ptr_factory_.GetWeakPtr());
}
void SessionWrapper::GetSizeInTokens(mojom::InputPtr input,
GetSizeInTokensCallback callback) {
if (!model_) {

@ -131,12 +131,12 @@ class OnDeviceModelServiceTest : public testing::Test {
return LoadAdaptationWithParams(model, std::move(params));
}
mojom::InputOptionsPtr MakeInput(const std::string& input) {
mojom::AppendOptionsPtr MakeInput(const std::string& input) {
return MakeInput({ml::InputPiece(input)});
}
mojom::InputOptionsPtr MakeInput(std::vector<ml::InputPiece> input) {
auto options = mojom::InputOptions::New();
mojom::AppendOptionsPtr MakeInput(std::vector<ml::InputPiece> input) {
auto options = mojom::AppendOptions::New();
options->input = mojom::Input::New(std::move(input));
return options;
}
@ -146,7 +146,11 @@ class OnDeviceModelServiceTest : public testing::Test {
TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model.StartSession(session.BindNewPipeAndPassReceiver());
session->Execute(MakeInput(input), response.BindRemote());
auto options = mojom::AppendOptions::New();
options->input =
mojom::Input::New(std::vector<ml::InputPiece>{ml::InputPiece(input)});
session->Append(std::move(options), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
return response.responses();
}
@ -165,25 +169,26 @@ class OnDeviceModelServiceTest : public testing::Test {
TEST_F(OnDeviceModelServiceTest, Responds) {
auto model = LoadModel();
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Input: bar\n"));
EXPECT_THAT(GetResponses(*model, "bar"), ElementsAre("Context: bar\n"));
// Try another input on the same model.
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Input: cat\n"));
EXPECT_THAT(GetResponses(*model, "cat"), ElementsAre("Context: cat\n"));
}
TEST_F(OnDeviceModelServiceTest, AddContext) {
TEST_F(OnDeviceModelServiceTest, Append) {
auto model = LoadModel();
TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
session->AddContext(MakeInput("more"), {});
session->Execute(MakeInput("cheddar"), response.BindRemote());
session->Append(MakeInput("cheese"), {});
session->Append(MakeInput("more"), {});
session->Append(MakeInput("cheddar"), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(
response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Input: cheddar\n"));
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: cheddar\n"));
}
TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
@ -191,134 +196,79 @@ TEST_F(OnDeviceModelServiceTest, CloneContextAndContinue) {
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
session->AddContext(MakeInput("more"), {});
session->Append(MakeInput("cheese"), {});
session->Append(MakeInput("more"), {});
mojo::Remote<mojom::Session> cloned;
session->Clone(cloned.BindNewPipeAndPassReceiver());
{
TestResponseHolder response;
cloned->Execute(MakeInput("cheddar"), response.BindRemote());
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Input: cheddar\n"));
ElementsAre("Context: cheese\n", "Context: more\n"));
}
{
TestResponseHolder response;
session->Execute(MakeInput("swiss"), response.BindRemote());
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n"));
}
session->Append(MakeInput("foo"), {});
cloned->Append(MakeInput("bar"), {});
{
TestResponseHolder response;
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(
response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Input: swiss\n"));
}
session->AddContext(MakeInput("foo"), {});
cloned->AddContext(MakeInput("bar"), {});
{
TestResponseHolder response;
session->Execute(MakeInput("swiss"), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: foo\n", "Input: swiss\n"));
ElementsAre("Context: cheese\n", "Context: more\n", "Context: foo\n"));
}
{
TestResponseHolder response;
cloned->Execute(MakeInput("cheddar"), response.BindRemote());
cloned->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: bar\n", "Input: cheddar\n"));
EXPECT_THAT(
response.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Context: bar\n"));
}
}
TEST_F(OnDeviceModelServiceTest, MultipleSessionsAddContext) {
TEST_F(OnDeviceModelServiceTest, MultipleSessionsAppend) {
auto model = LoadModel();
TestResponseHolder response1, response2, response3, response4, response5;
mojo::Remote<mojom::Session> session1, session2;
mojo::Remote<mojom::Session> session1, session2, session3, session4, session5;
model->StartSession(session1.BindNewPipeAndPassReceiver());
model->StartSession(session2.BindNewPipeAndPassReceiver());
session1->AddContext(MakeInput("cheese"), {});
session1->AddContext(MakeInput("more"), {});
session2->AddContext(MakeInput("apple"), {});
session1->Append(MakeInput("cheese"), {});
session1->Append(MakeInput("more"), {});
session2->Append(MakeInput("apple"), {});
session1->Execute(MakeInput("cheddar"), response1.BindRemote());
session1->Clone(session3.BindNewPipeAndPassReceiver());
session1->Append(MakeInput("cheddar"), {});
session1->Generate(mojom::GenerateOptions::New(), response1.BindRemote());
session2->AddContext(MakeInput("banana"), {});
session2->Append(MakeInput("banana"), {});
session2->Execute(MakeInput("candy"), response2.BindRemote());
session2->Execute(MakeInput("chip"), response3.BindRemote());
session1->Execute(MakeInput("choco"), response4.BindRemote());
session2->Execute(MakeInput("orange"), response5.BindRemote());
session2->Clone(session4.BindNewPipeAndPassReceiver());
session2->Append(MakeInput("candy"), {});
session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
response1.WaitForCompletion();
response2.WaitForCompletion();
response3.WaitForCompletion();
response4.WaitForCompletion();
response5.WaitForCompletion();
session4->Clone(session5.BindNewPipeAndPassReceiver());
session4->Append(MakeInput("chip"), {});
session4->Generate(mojom::GenerateOptions::New(), response3.BindRemote());
EXPECT_THAT(
response1.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Input: cheddar\n"));
EXPECT_THAT(
response2.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Input: candy\n"));
EXPECT_THAT(
response3.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Input: chip\n"));
EXPECT_THAT(
response4.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Input: choco\n"));
EXPECT_THAT(
response5.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Input: orange\n"));
}
session3->Append(MakeInput("choco"), {});
session3->Generate(mojom::GenerateOptions::New(), response4.BindRemote());
TEST_F(OnDeviceModelServiceTest, IgnoresContext) {
auto model = LoadModel();
TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
auto input = MakeInput("cheddar");
input->ignore_context = true;
session->Execute(std::move(input), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.responses(), ElementsAre("Input: cheddar\n"));
}
TEST_F(OnDeviceModelServiceTest, MultipleSessionsIgnoreContext) {
auto model = LoadModel();
TestResponseHolder response1, response2, response3, response4, response5;
mojo::Remote<mojom::Session> session1, session2;
model->StartSession(session1.BindNewPipeAndPassReceiver());
model->StartSession(session2.BindNewPipeAndPassReceiver());
session1->AddContext(MakeInput("cheese"), {});
session1->Execute(MakeInput("cheddar"), response1.BindRemote());
session1->AddContext(MakeInput("more"), {});
session2->AddContext(MakeInput("apple"), {});
session2->AddContext(MakeInput("banana"), {});
session2->Execute(MakeInput("candy"), response2.BindRemote());
auto chip = MakeInput("chip");
chip->ignore_context = true;
session2->Execute(std::move(chip), response3.BindRemote());
auto choco = MakeInput("choco");
choco->ignore_context = true;
session1->Execute(std::move(choco), response4.BindRemote());
session2->Execute(MakeInput("orange"), response5.BindRemote());
session5->Append(MakeInput("orange"), {});
session5->Generate(mojom::GenerateOptions::New(), response5.BindRemote());
response1.WaitForCompletion();
response2.WaitForCompletion();
@ -327,15 +277,20 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsIgnoreContext) {
response5.WaitForCompletion();
EXPECT_THAT(response1.responses(),
ElementsAre("Context: cheese\n", "Input: cheddar\n"));
ElementsAre("Context: cheese\n", "Context: more\n",
"Context: cheddar\n"));
EXPECT_THAT(
response2.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Input: candy\n"));
EXPECT_THAT(response3.responses(), ElementsAre("Input: chip\n"));
EXPECT_THAT(response4.responses(), ElementsAre("Input: choco\n"));
ElementsAre("Context: apple\n", "Context: banana\n", "Context: candy\n"));
EXPECT_THAT(
response5.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Input: orange\n"));
response3.responses(),
ElementsAre("Context: apple\n", "Context: banana\n", "Context: chip\n"));
EXPECT_THAT(
response4.responses(),
ElementsAre("Context: cheese\n", "Context: more\n", "Context: choco\n"));
EXPECT_THAT(response5.responses(),
ElementsAre("Context: apple\n", "Context: banana\n",
"Context: orange\n"));
}
TEST_F(OnDeviceModelServiceTest, CountTokens) {
@ -344,19 +299,19 @@ TEST_F(OnDeviceModelServiceTest, CountTokens) {
TestResponseHolder response;
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("cheese"), {});
session->AddContext(MakeInput("more"), {});
session->Append(MakeInput("cheese"), {});
session->Append(MakeInput("more"), {});
std::string input = "cheddar";
session->Execute(MakeInput(input), response.BindRemote());
session->Append(MakeInput(input), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(response.input_token_count(), input.size());
// 2 context + 1 input.
// 3 context.
EXPECT_THAT(response.output_token_count(), 3);
}
TEST_F(OnDeviceModelServiceTest, AddContextWithTokenLimits) {
TEST_F(OnDeviceModelServiceTest, AppendWithTokenLimits) {
auto model = LoadModel();
TestResponseHolder response;
@ -367,21 +322,22 @@ TEST_F(OnDeviceModelServiceTest, AddContextWithTokenLimits) {
ContextClientWaiter client1;
auto max_input = MakeInput("big cheese");
max_input->max_tokens = 4;
session->AddContext(std::move(max_input), client1.BindRemote());
session->Append(std::move(max_input), client1.BindRemote());
EXPECT_EQ(client1.WaitForCompletion(), 4);
ContextClientWaiter client2;
auto offset_input = MakeInput("big cheese");
offset_input->token_offset = 4;
session->AddContext(std::move(offset_input), client2.BindRemote());
session->Append(std::move(offset_input), client2.BindRemote());
EXPECT_EQ(client2.WaitForCompletion(), 6);
session->Execute(MakeInput("cheddar"), response.BindRemote());
session->Append(MakeInput("cheddar"), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
EXPECT_THAT(
response.responses(),
ElementsAre("Context: big \n", "Context: cheese\n", "Input: cheddar\n"));
EXPECT_THAT(response.responses(),
ElementsAre("Context: big \n", "Context: cheese\n",
"Context: cheddar\n"));
}
TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
@ -390,7 +346,8 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
TestResponseHolder response1;
mojo::Remote<mojom::Session> session1;
model->StartSession(session1.BindNewPipeAndPassReceiver());
session1->Execute(MakeInput("1"), response1.BindRemote());
session1->Append(MakeInput("1"), {});
session1->Generate(mojom::GenerateOptions::New(), response1.BindRemote());
mojo::Remote<mojom::Session> session2;
model->StartSession(session2.BindNewPipeAndPassReceiver());
@ -402,13 +359,14 @@ TEST_F(OnDeviceModelServiceTest, MultipleSessionsWaitPreviousSession) {
// Response from first session should still work.
response1.WaitForCompletion();
EXPECT_THAT(response1.responses(), ElementsAre("Input: 1\n"));
EXPECT_THAT(response1.responses(), ElementsAre("Context: 1\n"));
// Second session still works.
TestResponseHolder response2;
session2->Execute(MakeInput("2"), response2.BindRemote());
session2->Append(MakeInput("2"), {});
session2->Generate(mojom::GenerateOptions::New(), response2.BindRemote());
response2.WaitForCompletion();
EXPECT_THAT(response2.responses(), ElementsAre("Input: 2\n"));
EXPECT_THAT(response2.responses(), ElementsAre("Context: 2\n"));
}
TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
@ -416,16 +374,16 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptation) {
FakeFile weights2("Adapt2");
auto model = LoadModel();
auto adaptation1 = LoadAdaptation(*model, weights1.Open());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt1\n", "Context: foo\n"));
auto adaptation2 = LoadAdaptation(*model, weights2.Open());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt1\n", "Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
ElementsAre("Adaptation: Adapt2\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt2\n", "Context: foo\n"));
}
TEST_F(OnDeviceModelServiceTest, DestroysAdaptationSession) {
@ -457,16 +415,16 @@ TEST_F(OnDeviceModelServiceTest, LoadsAdaptationWithPath) {
FakeFile weights2("Adapt2");
auto model = LoadModel(ml::ModelBackendType::kApuBackend);
auto adaptation1 = LoadAdaptation(*model, weights1.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt1\n", "Context: foo\n"));
auto adaptation2 = LoadAdaptation(*model, weights2.Path());
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Input: foo\n"));
EXPECT_THAT(GetResponses(*model, "foo"), ElementsAre("Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation1, "foo"),
ElementsAre("Adaptation: Adapt1\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt1\n", "Context: foo\n"));
EXPECT_THAT(GetResponses(*adaptation2, "foo"),
ElementsAre("Adaptation: Adapt2\n", "Input: foo\n"));
ElementsAre("Adaptation: Adapt2\n", "Context: foo\n"));
}
TEST_F(OnDeviceModelServiceTest, LoadingAdaptationDoesNotCancelSession) {
@ -518,7 +476,7 @@ TEST_F(OnDeviceModelServiceTest, Score) {
mojo::Remote<mojom::Session> session;
model->StartSession(session.BindNewPipeAndPassReceiver());
session->AddContext(MakeInput("hi"), {});
session->Append(MakeInput("hi"), {});
{
base::test::TestFuture<float> future;
@ -532,7 +490,7 @@ TEST_F(OnDeviceModelServiceTest, Score) {
}
}
TEST_F(OnDeviceModelServiceTest, AddContextWithTokens) {
TEST_F(OnDeviceModelServiceTest, AppendWithTokens) {
auto model = LoadModel();
TestResponseHolder response;
@ -543,29 +501,30 @@ TEST_F(OnDeviceModelServiceTest, AddContextWithTokens) {
pieces.push_back(ml::Token::kSystem);
pieces.push_back("hi");
pieces.push_back(ml::Token::kEnd);
session->AddContext(MakeInput(std::move(pieces)), {});
session->Append(MakeInput(std::move(pieces)), {});
}
{
std::vector<ml::InputPiece> pieces;
pieces.push_back(ml::Token::kModel);
pieces.push_back("hello");
pieces.push_back(ml::Token::kEnd);
session->AddContext(MakeInput(std::move(pieces)), {});
session->Append(MakeInput(std::move(pieces)), {});
}
{
std::vector<ml::InputPiece> pieces;
pieces.push_back(ml::Token::kUser);
pieces.push_back("bye");
session->Execute(MakeInput(std::move(pieces)), response.BindRemote());
session->Append(MakeInput(std::move(pieces)), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
}
response.WaitForCompletion();
EXPECT_THAT(response.responses(), ElementsAre("Context: System: hi End.\n",
"Context: Model: hello End.\n",
"Input: User: bye\n"));
"Context: User: bye\n"));
}
TEST_F(OnDeviceModelServiceTest, AddContextWithImages) {
TEST_F(OnDeviceModelServiceTest, AppendWithImages) {
auto model = LoadModel();
auto params = mojom::LoadAdaptationParams::New();
params->enable_image_input = true;
@ -588,7 +547,7 @@ TEST_F(OnDeviceModelServiceTest, AddContextWithImages) {
pieces.push_back("cheese");
session->AddContext(MakeInput(std::move(pieces)), {});
session->Append(MakeInput(std::move(pieces)), {});
}
{
@ -604,13 +563,14 @@ TEST_F(OnDeviceModelServiceTest, AddContextWithImages) {
pieces.push_back("cheese");
session->Execute(MakeInput(std::move(pieces)), response.BindRemote());
session->Append(MakeInput(std::move(pieces)), {});
session->Generate(mojom::GenerateOptions::New(), response.BindRemote());
response.WaitForCompletion();
}
EXPECT_THAT(response.responses(),
ElementsAre("Context: cheddar[Bitmap of size 7x21]cheese\n",
"Input: bleu[Bitmap of size 63x42]cheese\n"));
"Context: bleu[Bitmap of size 63x42]cheese\n"));
}
TEST_F(OnDeviceModelServiceTest, ClassifyTextSafety) {
@ -640,7 +600,7 @@ TEST_F(OnDeviceModelServiceTest, PerformanceHint) {
auto model = LoadModel(ml::ModelBackendType::kGpuBackend,
ml::ModelPerformanceHint::kFastestInference);
EXPECT_THAT(GetResponses(*model, "foo"),
ElementsAre("Fastest inference\n", "Input: foo\n"));
ElementsAre("Fastest inference\n", "Context: foo\n"));
}
} // namespace

@ -41,7 +41,7 @@ std::string OnDeviceInputToString(const mojom::Input& input) {
return oss.str();
}
std::string CtxToString(const mojom::InputOptions& input) {
std::string CtxToString(const mojom::AppendOptions& input) {
std::string suffix;
std::string context = OnDeviceInputToString(*input.input);
if (input.token_offset > 0) {
@ -90,23 +90,35 @@ FakeOnDeviceSession::~FakeOnDeviceSession() = default;
void FakeOnDeviceSession::AddContext(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client) {
NOTREACHED();
}
void FakeOnDeviceSession::Append(
mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&FakeOnDeviceSession::AddContextInternal,
weak_factory_.GetWeakPtr(), std::move(input),
FROM_HERE, base::BindOnce(&FakeOnDeviceSession::AppendImpl,
weak_factory_.GetWeakPtr(), std::move(options),
std::move(client)));
}
void FakeOnDeviceSession::Execute(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) {
NOTREACHED();
}
void FakeOnDeviceSession::Generate(
mojom::GenerateOptionsPtr options,
mojo::PendingRemote<mojom::StreamingResponder> response) {
if (settings_->execute_delay.is_zero()) {
ExecuteImpl(std::move(input), std::move(response));
GenerateImpl(std::move(options), std::move(response));
return;
}
base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&FakeOnDeviceSession::ExecuteImpl,
weak_factory_.GetWeakPtr(), std::move(input),
base::BindOnce(&FakeOnDeviceSession::GenerateImpl,
weak_factory_.GetWeakPtr(), std::move(options),
std::move(response)),
settings_->execute_delay);
}
@ -130,8 +142,8 @@ void FakeOnDeviceSession::Clone(
model_->AddSession(std::move(session), std::move(new_session));
}
void FakeOnDeviceSession::ExecuteImpl(
mojom::InputOptionsPtr input,
void FakeOnDeviceSession::GenerateImpl(
mojom::GenerateOptionsPtr options,
mojo::PendingRemote<mojom::StreamingResponder> response) {
mojo::Remote<mojom::StreamingResponder> remote(std::move(response));
if (model_->performance_hint() ==
@ -151,21 +163,20 @@ void FakeOnDeviceSession::ExecuteImpl(
"Adaptation model: " + model_->data().adaptation_model_weight + "\n";
remote->OnResponse(std::move(chunk));
}
for (const auto& context : context_) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Context: " + CtxToString(*context) + "\n";
remote->OnResponse(std::move(chunk));
}
if (settings_->model_execute_result.empty()) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Input: " + OnDeviceInputToString(*input->input) + "\n";
if (input->top_k > 1) {
chunk->text += "TopK: " + base::NumberToString(*input->top_k) +
", Temp: " + base::NumberToString(*input->temperature) +
"\n";
for (const auto& context : context_) {
auto chunk = mojom::ResponseChunk::New();
chunk->text = "Context: " + CtxToString(*context) + "\n";
remote->OnResponse(std::move(chunk));
}
if (options->top_k > 1) {
auto chunk = mojom::ResponseChunk::New();
chunk->text += "TopK: " + base::NumberToString(*options->top_k) +
", Temp: " + base::NumberToString(*options->temperature) +
"\n";
remote->OnResponse(std::move(chunk));
}
remote->OnResponse(std::move(chunk));
} else {
for (const auto& text : settings_->model_execute_result) {
auto chunk = mojom::ResponseChunk::New();
@ -177,16 +188,16 @@ void FakeOnDeviceSession::ExecuteImpl(
remote->OnComplete(std::move(summary));
}
void FakeOnDeviceSession::AddContextInternal(
mojom::InputOptionsPtr input,
void FakeOnDeviceSession::AppendImpl(
mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client) {
uint32_t input_tokens =
static_cast<uint32_t>(OnDeviceInputToString(*input->input).size());
static_cast<uint32_t>(OnDeviceInputToString(*options->input).size());
uint32_t max_tokens =
input->max_tokens > 0 ? input->max_tokens : input_tokens;
uint32_t token_offset = input->token_offset;
options->max_tokens > 0 ? options->max_tokens : input_tokens;
uint32_t token_offset = options->token_offset;
uint32_t tokens_processed = std::min(input_tokens - token_offset, max_tokens);
context_.emplace_back(std::move(input));
context_.emplace_back(std::move(options));
if (client) {
mojo::Remote<mojom::ContextClient> remote(std::move(client));
remote->OnComplete(tokens_processed);

@ -78,10 +78,17 @@ class FakeOnDeviceSession final : public mojom::Session {
void AddContext(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client) override;
void Append(mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client) override;
void Execute(
mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override;
void Generate(
mojom::GenerateOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response) override;
void GetSizeInTokens(mojom::InputPtr input,
GetSizeInTokensCallback callback) override;
@ -91,15 +98,14 @@ class FakeOnDeviceSession final : public mojom::Session {
mojo::PendingReceiver<on_device_model::mojom::Session> session) override;
private:
void ExecuteImpl(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::StreamingResponder> response);
void AddContextInternal(mojom::InputOptionsPtr input,
mojo::PendingRemote<mojom::ContextClient> client);
void GenerateImpl(mojom::GenerateOptionsPtr options,
mojo::PendingRemote<mojom::StreamingResponder> response);
void AppendImpl(mojom::AppendOptionsPtr options,
mojo::PendingRemote<mojom::ContextClient> client);
raw_ptr<FakeOnDeviceServiceSettings> settings_;
std::string adaptation_model_weight_;
std::vector<mojom::InputOptionsPtr> context_;
std::vector<mojom::AppendOptionsPtr> context_;
raw_ptr<FakeOnDeviceModel> model_;
base::WeakPtrFactory<FakeOnDeviceSession> weak_factory_{this};

@ -32,7 +32,6 @@ void TestResponseHolder::OnResponse(mojom::ResponseChunkPtr chunk) {
void TestResponseHolder::OnComplete(mojom::ResponseSummaryPtr summary) {
complete_ = true;
input_token_count_ = summary->input_token_count;
output_token_count_ = summary->output_token_count;
run_loop_.Quit();
}

@ -32,7 +32,6 @@ class TestResponseHolder : public mojom::StreamingResponder {
bool complete() const { return complete_; }
bool disconnected() const { return disconnected_; }
bool terminated() const { return disconnected_ || complete_; }
uint32_t input_token_count() const { return input_token_count_; }
uint32_t output_token_count() const { return output_token_count_; }
// Spins a RunLoop until this object observes completion of its response.
@ -48,7 +47,6 @@ class TestResponseHolder : public mojom::StreamingResponder {
std::vector<std::string> responses_;
bool complete_ = false;
bool disconnected_ = false;
uint32_t input_token_count_ = 0;
uint32_t output_token_count_ = 0;
mojo::Receiver<mojom::StreamingResponder> receiver_{this};
};

@ -62,7 +62,8 @@ struct ResponseSummary {
// Optional safety information computed against the full response.
SafetyInfo? safety_info;
// The total number of input tokens from the call to Execute().
// Deprecated: Execute() is now deprecated and input should be passed through
// Append().
uint32 input_token_count;
// The total number of output tokens for this response.
@ -191,18 +192,57 @@ struct InputOptions {
float? temperature;
};
[Stable]
struct AppendOptions {
// The input for the model.
Input input;
// The maximum number of tokens that should be processed. If zero, will
// process all tokens from this input.
uint32 max_tokens = 0;
// After text is tokenized, the offset into that vector to start processing.
// If zero, will start at the first token.
uint32 token_offset = 0;
};
[Stable]
struct GenerateOptions {
// The maximum number of tokens that should be output from a call to
// Execute(). If zero, will output tokens until an end token or the maximum
// sequence length.
uint32 max_output_tokens = 0;
// These params control the output sampling. Higher `top_k` means more tokens
// are considered, higher `temperature` means less likely tokens are more
// probable.
// `top_k` should be a value from 1 to the max top K value the model was
// initialized with.
uint32? top_k;
// `temperature` should be a value greater than 0.0. Values above 1.0 may give
// poor results.
float? temperature;
};
// A session for a model that allows adding context and then executing an input
// with that context.
[Stable]
interface Session {
// Adds context to this session. Any context added here will build off of
// previous calls to |AddContext()|.
// Deprecated: use Append() instead.
AddContext@0(InputOptions input, pending_remote<ContextClient>? client);
// Appends input to this session. Any input added here will build off of
// previous calls to |Append()|. To cancel, close the |client| pipe.
[MinVersion=1]
Append@6(AppendOptions options, pending_remote<ContextClient>? client);
// Executes model on the given input. The input will be added on top of the
// context provided by |AddContext()|. The response will be streamed to
// |response|. To cancel the request, close the |response| pipe.
// Deprecated: use Generate() instead.
Execute@1(InputOptions input, pending_remote<StreamingResponder> response);
// Generates output from the model on top of any input added from Append().
// The response will be streamed to |response|. To cancel the request, close
// the |response| pipe.
[MinVersion=1]
Generate@7(
GenerateOptions options, pending_remote<StreamingResponder> response);
// Gets the size of the given text in tokens. Will return 0 if the text is
// empty or error occurred.
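
For illustration, a minimal client-side sketch of the
Append()/Generate() call pattern described by the comments above,
assuming a bound mojo::Remote<on_device_model::mojom::Session> named
`session` and a StreamingResponder implementation behind `responder`
(both are placeholders, not names from this CL):

  // Append input incrementally; each call builds on the previous ones.
  auto append = on_device_model::mojom::AppendOptions::New();
  append->input = on_device_model::mojom::Input::New();
  append->input->pieces.push_back("first chunk of input");
  append->max_tokens = 0;    // 0 processes all tokens from this input.
  append->token_offset = 0;  // 0 starts at the first token.
  session->Append(std::move(append), /*client=*/{});

  // Generate output on top of everything appended so far. Closing the
  // response pipe cancels the request.
  auto generate = on_device_model::mojom::GenerateOptions::New();
  generate->max_output_tokens = 64;  // 0 generates until an end token.
  session->Generate(std::move(generate),
                    responder.BindNewPipeAndPassRemote());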