[Trigger API] Add request id to classification result
Return the request ID as part of the classification result. The request ID will be used by the trigger API in a follow-up CL to allow clients to manually trigger data collection. Bug: 1424531 Change-Id: I26a64c4d223496a87a6d6d3584177e84386bca2e Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4337729 Reviewed-by: Shakti Sahu <shaktisahu@chromium.org> Reviewed-by: Siddhartha S <ssid@chromium.org> Commit-Queue: Hailey Wang <haileywang@google.com> Cr-Commit-Position: refs/heads/main@{#1136308}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
27ef30d73b
commit
d0e8f4c490
components/segmentation_platform
@@ -72,10 +72,12 @@ class TrainingDataCollector {
|
||||
virtual void ReportCollectedContinuousTrainingData() = 0;
|
||||
|
||||
// Called to collect and store training input data. The data will only be
|
||||
// uploaded once |OnObservationTrigger| is triggered.
|
||||
virtual void OnDecisionTime(proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type) = 0;
|
||||
// uploaded once |OnObservationTrigger| is triggered. |TrainingRequestId| can
|
||||
// be used to trigger observation for a specific set of training data.
|
||||
virtual TrainingRequestId OnDecisionTime(
|
||||
proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type) = 0;
|
||||
|
||||
// Called when a relevant uma histogram is recorded or when a time delay
|
||||
// trigger is hit, retrieve input training data from storage, collect output
|
||||
|
@@ -24,6 +24,7 @@
|
||||
#include "components/segmentation_platform/public/config.h"
|
||||
#include "components/segmentation_platform/public/local_state_helper.h"
|
||||
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
|
||||
#include "components/segmentation_platform/public/trigger.h"
|
||||
|
||||
namespace segmentation_platform {
|
||||
namespace {
|
||||
@@ -197,9 +198,10 @@ void TrainingDataCollectorImpl::OnGetSegmentsInfoList(
|
||||
const auto& training_data = segment_info.training_data(i);
|
||||
if (current_time > training_data.observation_trigger_timestamp()) {
|
||||
// Observation is reached for the current training data.
|
||||
OnObservationTrigger(absl::nullopt,
|
||||
(TrainingRequestId)training_data.request_id(),
|
||||
segment_info);
|
||||
OnObservationTrigger(
|
||||
absl::nullopt,
|
||||
TrainingRequestId::FromUnsafeValue(training_data.request_id()),
|
||||
segment_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -416,12 +418,12 @@ void TrainingDataCollectorImpl::ReportCollectedContinuousTrainingData() {
|
||||
}
|
||||
}
|
||||
|
||||
void TrainingDataCollectorImpl::OnDecisionTime(
|
||||
TrainingRequestId TrainingDataCollectorImpl::OnDecisionTime(
|
||||
proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type) {
|
||||
if (all_segments_for_training_.count(id) == 0) {
|
||||
return;
|
||||
return TrainingRequestId();
|
||||
}
|
||||
|
||||
const TrainingRequestId request_id = training_cache_->GenerateNextId();
|
||||
@@ -431,6 +433,8 @@ void TrainingDataCollectorImpl::OnDecisionTime(
|
||||
base::BindOnce(&TrainingDataCollectorImpl::OnGetSegmentInfoAtDecisionTime,
|
||||
weak_ptr_factory_.GetWeakPtr(), id, request_id, type,
|
||||
input_context));
|
||||
|
||||
return request_id;
|
||||
}
|
||||
|
||||
void TrainingDataCollectorImpl::OnGetSegmentInfoAtDecisionTime(
|
||||
@@ -550,6 +554,10 @@ void TrainingDataCollectorImpl::OnObservationTrigger(
|
||||
const absl::optional<ImmediaCollectionParam>& param,
|
||||
TrainingRequestId request_id,
|
||||
const proto::SegmentInfo& segment_info) {
|
||||
if (request_id.is_null()) {
|
||||
return;
|
||||
}
|
||||
|
||||
RecordTrainingDataCollectionEvent(
|
||||
segment_info.segment_id(),
|
||||
stats::TrainingDataCollectionEvent::kObservationTimeReached);
|
||||
|
@@ -21,6 +21,7 @@
|
||||
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
|
||||
#include "components/segmentation_platform/public/model_provider.h"
|
||||
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
|
||||
#include "components/segmentation_platform/public/trigger.h"
|
||||
#include "third_party/abseil-cpp/absl/types/optional.h"
|
||||
|
||||
namespace segmentation_platform {
|
||||
@@ -47,9 +48,9 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
||||
void OnModelMetadataUpdated() override;
|
||||
void OnServiceInitialized() override;
|
||||
void ReportCollectedContinuousTrainingData() override;
|
||||
void OnDecisionTime(proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type) override;
|
||||
TrainingRequestId OnDecisionTime(proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type) override;
|
||||
|
||||
void OnObservationTrigger(const absl::optional<ImmediaCollectionParam>& param,
|
||||
TrainingRequestId request_id,
|
||||
|
@@ -15,6 +15,7 @@
|
||||
#include "components/segmentation_platform/public/input_context.h"
|
||||
#include "components/segmentation_platform/public/prediction_options.h"
|
||||
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
|
||||
#include "components/segmentation_platform/public/trigger.h"
|
||||
|
||||
namespace segmentation_platform {
|
||||
namespace {
|
||||
@@ -59,6 +60,9 @@ class RequestHandlerImpl : public RequestHandler {
|
||||
ClassificationResultCallback classification_callback,
|
||||
std::unique_ptr<SegmentResultProvider::SegmentResult> result);
|
||||
|
||||
TrainingRequestId CollectTrainingData(
|
||||
scoped_refptr<InputContext> input_context);
|
||||
|
||||
// The config for providing client config params.
|
||||
const raw_ref<const Config> config_;
|
||||
|
||||
@@ -118,6 +122,7 @@ void RequestHandlerImpl::OnGetModelResultForClassification(
|
||||
PostProcessor post_processor;
|
||||
PredictionStatus status = PredictionStatus::kFailed;
|
||||
proto::PredictionResult pred_result;
|
||||
absl::optional<TrainingRequestId> request_id;
|
||||
if (result) {
|
||||
stats::RecordSegmentSelectionFailure(
|
||||
*config_, stats::GetSuccessOrFailureReason(result->state));
|
||||
@@ -125,13 +130,9 @@ void RequestHandlerImpl::OnGetModelResultForClassification(
|
||||
pred_result = result->result;
|
||||
stats::RecordClassificationResultComputed(*config_, pred_result);
|
||||
|
||||
// Collect training data. The execution service and training data collector
|
||||
// might be null in testing.
|
||||
if (execution_service_ && execution_service_->training_data_collector()) {
|
||||
execution_service_->training_data_collector()->OnDecisionTime(
|
||||
config_->segments.begin()->first, input_context,
|
||||
proto::TrainingOutputs::TriggerConfig::ONDEMAND);
|
||||
}
|
||||
// Collect training data. The training data collector might be null in
|
||||
// testing.
|
||||
request_id = CollectTrainingData(input_context);
|
||||
} else {
|
||||
stats::RecordSegmentSelectionFailure(
|
||||
*config_, stats::SegmentationSelectionFailureReason::
|
||||
@@ -139,11 +140,27 @@ void RequestHandlerImpl::OnGetModelResultForClassification(
|
||||
}
|
||||
ClassificationResult classification_result =
|
||||
post_processor.GetPostProcessedClassificationResult(pred_result, status);
|
||||
|
||||
if (request_id && !request_id.value().is_null()) {
|
||||
classification_result.request_id = request_id.value();
|
||||
}
|
||||
|
||||
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
|
||||
FROM_HERE, base::BindOnce(std::move(classification_callback),
|
||||
classification_result));
|
||||
}
|
||||
|
||||
// Notifies the training data collector, if one is available, that an
// ONDEMAND decision was made for the first segment in this client's config,
// and returns the TrainingRequestId the collector assigned to that decision.
//
// Returns a default-constructed (null) TrainingRequestId when training data
// cannot be collected: no execution service, no training data collector, or
// no configured segment. The execution service may legitimately be null
// (e.g. tests that create the handler with nullptr), so it is checked before
// being dereferenced.
TrainingRequestId RequestHandlerImpl::CollectTrainingData(
    scoped_refptr<InputContext> input_context) {
  if (!execution_service_ || !execution_service_->training_data_collector()) {
    return TrainingRequestId();
  }

  // Dereferencing begin() of an empty container is undefined behavior.
  if (config_->segments.empty()) {
    return TrainingRequestId();
  }

  return execution_service_->training_data_collector()->OnDecisionTime(
      config_->segments.begin()->first, std::move(input_context),
      proto::TrainingOutputs::TriggerConfig::ONDEMAND);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// static
|
||||
|
@@ -58,8 +58,8 @@ class RequestHandlerTest : public testing::Test {
|
||||
config_ = test_utils::CreateTestConfig("test_client", kSegmentId);
|
||||
auto provider = std::make_unique<MockResultProvider>();
|
||||
result_provider_ = provider.get();
|
||||
request_handler_ =
|
||||
RequestHandler::Create(*config_, std::move(provider), nullptr);
|
||||
request_handler_ = RequestHandler::Create(*config_, std::move(provider),
|
||||
&execution_service_);
|
||||
}
|
||||
|
||||
void OnGetClassificationResult(base::RepeatingClosure closure,
|
||||
@@ -70,6 +70,7 @@ class RequestHandlerTest : public testing::Test {
|
||||
std::move(closure).Run();
|
||||
}
|
||||
|
||||
ExecutionService execution_service_;
|
||||
base::test::TaskEnvironment task_environment_{
|
||||
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
|
||||
std::unique_ptr<Config> config_;
|
||||
|
@@ -77,9 +77,9 @@ class MockTrainingDataCollector : public TrainingDataCollector {
|
||||
MOCK_METHOD0(OnServiceInitialized, void());
|
||||
MOCK_METHOD0(ReportCollectedContinuousTrainingData, void());
|
||||
MOCK_METHOD3(OnDecisionTime,
|
||||
void(proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type));
|
||||
TrainingRequestId(proto::SegmentId id,
|
||||
scoped_refptr<InputContext> input_context,
|
||||
DecisionType type));
|
||||
MOCK_METHOD3(OnObservationTrigger,
|
||||
void(const absl::optional<ImmediaCollectionParam>& param,
|
||||
TrainingRequestId request_id,
|
||||
|
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "base/functional/callback_helpers.h"
|
||||
#include "components/segmentation_platform/public/trigger.h"
|
||||
|
||||
namespace segmentation_platform {
|
||||
|
||||
@@ -41,6 +42,10 @@ struct ClassificationResult {
|
||||
// label from one of the bin depending on where the score from the model
|
||||
// evaluation lies.
|
||||
std::vector<std::string> ordered_labels;
|
||||
|
||||
// The request ID used for identifying a specific training data inputs. Can be
|
||||
// null if training data was not uploaded for that execution.
|
||||
TrainingRequestId request_id;
|
||||
};
|
||||
|
||||
// RegressionResult is returned when Predictor specified by the client in
|
||||
|
Reference in New Issue
Block a user