0

[TDC] Add sampling to time triggered data collection

Change-Id: Ia43152d129a7a11887191ecf7bb6a2be61ee5c7c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4757485
Reviewed-by: Siddhartha S <ssid@chromium.org>
Commit-Queue: Hailey Wang <haileywang@google.com>
Cr-Commit-Position: refs/heads/main@{#1181741}
This commit is contained in:
Hailey Wang
2023-08-09 22:49:23 +00:00
committed by Chromium LUCI CQ
parent 9f51e807c3
commit 46865afffb
7 changed files with 71 additions and 19 deletions

@ -8,9 +8,11 @@
#include "base/containers/contains.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/metrics_hashes.h"
#include "base/metrics/user_metrics.h"
#include "base/notreached.h"
#include "base/rand_util.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/clock.h"
#include "base/time/time.h"
@ -25,7 +27,6 @@
#include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/local_state_helper.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
@ -84,6 +85,10 @@ bool IsPeriodic(const proto::SegmentInfo& info) {
return type == proto::TrainingOutputs::TriggerConfig::PERIODIC;
}
// Field-trial parameter "SamplingRate" attached to
// features::kSegmentationPlatformTimeDelaySampling; defaults to 20.
// The TrainingDataCollectorImpl constructor reads it once into
// |time_trigger_sampling_rate_|, which is then used as the range argument of
// base::RandGenerator() when deciding whether to post a delayed observation
// task. The task is posted only when the draw is 0.
// NOTE(review): presumably this yields ~1 in N posts for a rate of N (uniform
// draw in [0, N)) - confirm against base/rand_util.h. The value is an int that
// is stored into a uint64_t; zero/negative values are not validated here.
constexpr base::FeatureParam<int> TimeDelaySamplingRate{
&features::kSegmentationPlatformTimeDelaySampling,
/*name=*/"SamplingRate", /*default_value=*/20};
} // namespace
struct TrainingDataCollectorImpl::TrainingTimings {
@ -112,7 +117,8 @@ TrainingDataCollectorImpl::TrainingDataCollectorImpl(
cached_result_provider_(cached_result_provider),
training_cache_(std::make_unique<TrainingDataCache>(
storage_service->segment_info_database())),
default_model_manager_(storage_service->default_model_manager()) {}
default_model_manager_(storage_service->default_model_manager()),
time_trigger_sampling_rate_(TimeDelaySamplingRate.Get()) {}
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() {
histogram_signal_handler_->RemoveObserver(this);
@ -277,6 +283,11 @@ void TrainingDataCollectorImpl::OnUserAction(const std::string& user_action,
}
}
// Test-only override of the sampling rate used when deciding whether to post
// a delayed (time-triggered) observation task.
//
// The stored rate is later passed to base::RandGenerator() as its range, and
// that range must be non-zero. A |sampling_rate| of 0 is therefore stored as
// 1, for which the only possible draw is 0, so every delayed task is posted
// instead of failing at the use site.
void TrainingDataCollectorImpl::SetSamplingRateForTesting(
    uint64_t sampling_rate) {
  time_trigger_sampling_rate_ =
      sampling_rate == 0 ? uint64_t{1} : sampling_rate;
}
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
const absl::optional<ImmediateCollectionParam>& param,
absl::optional<proto::SegmentInfo> segment) {
@ -604,26 +615,28 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
// TODO(haileywang): This is slightly inaccurate since the delay timer is
// only started after the input training tensors are cached.
if (training_request.observation_delayed_task) {
if (training_request.observation_delayed_task.value().is_zero()) {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
} else {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
}
VLOG(1) << "Observation timeout set for "
<< proto::SegmentId_Name(segment_info.segment_id()) << " "
<< *training_request.observation_delayed_task;
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
weak_ptr_factory_.GetWeakPtr(), absl::nullopt,
request_id, segment_info, base::DoNothing()),
*training_request.observation_delayed_task);
if (training_request.observation_delayed_task.value().is_zero()) {
PostObservationTask(
request_id, segment_info, *training_request.observation_delayed_task,
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
} else {
// Sample time-triggered data for on-demand models; periodic models always
// post the delayed task.
if (IsPeriodic(segment_info) ||
base::RandGenerator(time_trigger_sampling_rate_) == 0) {
PostObservationTask(
request_id, segment_info,
*training_request.observation_delayed_task,
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
} else {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kDelayTriggerSampled);
}
}
} else {
VLOG(1) << "Observation without timeout "
<< proto::SegmentId_Name(segment_info.segment_id());
@ -633,6 +646,20 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
}
}
// Schedules OnObservationTrigger() for |request_id| and |segment_info| on the
// current default single-thread task runner, to run after |delay|, and then
// records |event| for the segment. The callback is bound to a WeakPtr of this
// object, with no ImmediateCollectionParam (absl::nullopt) and a no-op
// completion callback.
void TrainingDataCollectorImpl::PostObservationTask(
    TrainingRequestId request_id,
    const proto::SegmentInfo& segment_info,
    const base::TimeDelta& delay,
    stats::TrainingDataCollectionEvent event) {
  // Resolve the runner first, as the original call expression did.
  const auto& current_runner =
      base::SingleThreadTaskRunner::GetCurrentDefault();

  current_runner->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
                     weak_ptr_factory_.GetWeakPtr(),
                     /*param=*/absl::nullopt, request_id, segment_info,
                     base::DoNothing()),
      delay);

  // The event is recorded only after the task has been handed to the runner.
  RecordTrainingDataCollectionEvent(segment_info.segment_id(), event);
}
void TrainingDataCollectorImpl::OnObservationTrigger(
const absl::optional<ImmediateCollectionParam>& param,
TrainingRequestId request_id,

@ -5,6 +5,7 @@
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
#include <cstdint>
#include <set>
#include <vector>
@ -13,6 +14,7 @@
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_base.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
@ -22,6 +24,8 @@
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/trigger.h"
@ -68,6 +72,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
void OnUserAction(const std::string& user_action,
base::TimeTicks action_time) override;
// Overrides the sampling rate applied to time-triggered (delayed) observation
// tasks, replacing the value read from the feature param at construction.
// Intended for tests only.
void SetSamplingRateForTesting(uint64_t sampling_rate);
private:
struct TrainingTimings;
@ -127,6 +133,11 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
const ModelProvider::Request& input_tensors,
const ModelProvider::Response& output_tensors);
// Posts a task that calls OnObservationTrigger() for |request_id| and
// |segment_info| after |delay|, and records |event| as the training data
// collection event for the segment.
void PostObservationTask(TrainingRequestId request_id,
const proto::SegmentInfo& segment_info,
const base::TimeDelta& delay,
stats::TrainingDataCollectionEvent event);
// Returns whether training data can be reported through UKM. If
// |include_output| is false, only input data will be checked to see if they
// meet the collection requirement.
@ -190,6 +201,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
// list.
base::flat_set<SegmentId> all_segments_for_training_;
// Range passed to base::RandGenerator() when deciding whether to post a
// delayed observation task for a non-periodic segment; the task is posted only
// when the draw is 0. Set from the "SamplingRate" feature param in the
// constructor and overridable via SetSamplingRateForTesting().
// NOTE(review): base::RandGenerator() presumably requires a non-zero range -
// confirm that 0 can never reach the call site.
uint64_t time_trigger_sampling_rate_{0};
base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
};

@ -191,6 +191,8 @@ class TrainingDataCollectorImplTest : public ::testing::Test {
&histogram_signal_handler_, &user_action_signal_handler_,
storage_service_.get(), &prefs_, &clock_,
cached_result_provider_.get());
collector_->SetSamplingRateForTesting(1);
}
protected:

@ -284,7 +284,8 @@ enum class TrainingDataCollectionEvent {
kObservationDisallowed = 18,
kTrainingDataMissing = 19,
kOnDecisionTimeTypeMistmatch = 20,
kMaxValue = kOnDecisionTimeTypeMistmatch,
kDelayTriggerSampled = 21,
kMaxValue = kDelayTriggerSampled,
};
// Records analytics for training data collection.

@ -115,4 +115,8 @@ BASE_FEATURE(kSegmentationPlatformIosModuleRanker,
#else
base::FEATURE_DISABLED_BY_DEFAULT);
#endif
// Feature flag for controlling sampling of training data collection.
// Disabled by default.
BASE_FEATURE(kSegmentationPlatformTimeDelaySampling,
"SegmentationPlatformTimeDelaySampling",
base::FEATURE_DISABLED_BY_DEFAULT);
} // namespace segmentation_platform::features

@ -86,6 +86,9 @@ BASE_DECLARE_FEATURE(kSegmentationPlatformTabResumptionRanker);
// Feature flag for enabling ios module ranker.
BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker);
// Feature flag for controlling sampling of time-triggered (delayed) training
// data collection.
BASE_DECLARE_FEATURE(kSegmentationPlatformTimeDelaySampling);
} // namespace segmentation_platform::features
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_

@ -96105,6 +96105,8 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf
<int value="17" label="Disallowed for recording"/>
<int value="18" label="Observation disallowed for recording"/>
<int value="19" label="Training data missing"/>
<int value="20" label="Decision type mismatch failure"/>
<int value="21" label="Delay trigger sampled"/>
</enum>
<enum name="SegmentationPlatformValidationResult">