0

[TDC] Add sampling to time-triggered training data collection

Change-Id: Ia43152d129a7a11887191ecf7bb6a2be61ee5c7c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4757485
Reviewed-by: Siddhartha S <ssid@chromium.org>
Commit-Queue: Hailey Wang <haileywang@google.com>
Cr-Commit-Position: refs/heads/main@{#1181741}
This commit is contained in:
Hailey Wang
2023-08-09 22:49:23 +00:00
committed by Chromium LUCI CQ
parent 9f51e807c3
commit 46865afffb
7 changed files with 71 additions and 19 deletions

@@ -8,9 +8,11 @@
#include "base/containers/contains.h" #include "base/containers/contains.h"
#include "base/functional/callback_helpers.h" #include "base/functional/callback_helpers.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/metrics/field_trial_params.h"
#include "base/metrics/metrics_hashes.h" #include "base/metrics/metrics_hashes.h"
#include "base/metrics/user_metrics.h" #include "base/metrics/user_metrics.h"
#include "base/notreached.h" #include "base/notreached.h"
#include "base/rand_util.h"
#include "base/task/single_thread_task_runner.h" #include "base/task/single_thread_task_runner.h"
#include "base/time/clock.h" #include "base/time/clock.h"
#include "base/time/time.h" #include "base/time/time.h"
@@ -25,7 +27,6 @@
#include "components/segmentation_platform/internal/platform_options.h" #include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h" #include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h" #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h" #include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/local_state_helper.h" #include "components/segmentation_platform/public/local_state_helper.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h" #include "components/segmentation_platform/public/proto/model_metadata.pb.h"
@@ -84,6 +85,10 @@ bool IsPeriodic(const proto::SegmentInfo& info) {
return type == proto::TrainingOutputs::TriggerConfig::PERIODIC; return type == proto::TrainingOutputs::TriggerConfig::PERIODIC;
} }
// Feature param holding N for the "1 in N" sampling applied to time-delay
// triggered observations of non-periodic segments (periodic segments are
// always collected).
// NOTE(review): FeatureParam::Get() falls back to the default value (20) when
// kSegmentationPlatformTimeDelaySampling is not enabled, so sampling appears
// to be active even while the feature is disabled — confirm this is intended.
// NOTE(review): the value ends up as the range of base::RandGenerator(), which
// requires a positive range; values <= 0 are not guarded against here.
constexpr base::FeatureParam<int> TimeDelaySamplingRate(
    &features::kSegmentationPlatformTimeDelaySampling,
    /*name=*/"SamplingRate",
    /*default_value=*/20);
} // namespace } // namespace
struct TrainingDataCollectorImpl::TrainingTimings { struct TrainingDataCollectorImpl::TrainingTimings {
@@ -112,7 +117,8 @@ TrainingDataCollectorImpl::TrainingDataCollectorImpl(
cached_result_provider_(cached_result_provider), cached_result_provider_(cached_result_provider),
training_cache_(std::make_unique<TrainingDataCache>( training_cache_(std::make_unique<TrainingDataCache>(
storage_service->segment_info_database())), storage_service->segment_info_database())),
default_model_manager_(storage_service->default_model_manager()) {} default_model_manager_(storage_service->default_model_manager()),
time_trigger_sampling_rate_(TimeDelaySamplingRate.Get()) {}
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() { TrainingDataCollectorImpl::~TrainingDataCollectorImpl() {
histogram_signal_handler_->RemoveObserver(this); histogram_signal_handler_->RemoveObserver(this);
@@ -277,6 +283,11 @@ void TrainingDataCollectorImpl::OnUserAction(const std::string& user_action,
} }
} }
// Overrides the "1 in N" sampling rate used for time-delay triggered
// observations of non-periodic segments. Passing 1 keeps every observation.
void TrainingDataCollectorImpl::SetSamplingRateForTesting(
    uint64_t sampling_rate) {
  // The rate is later used as the range for base::RandGenerator(), which
  // requires a non-zero range. Reject 0 here so the misuse surfaces at the call
  // site instead of deep inside the collection flow.
  DCHECK_GT(sampling_rate, 0u);
  time_trigger_sampling_rate_ = sampling_rate;
}
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo( void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
const absl::optional<ImmediateCollectionParam>& param, const absl::optional<ImmediateCollectionParam>& param,
absl::optional<proto::SegmentInfo> segment) { absl::optional<proto::SegmentInfo> segment) {
@@ -604,26 +615,28 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
// TODO(haileywang): This is slightly inaccurate since the the delay timer is // TODO(haileywang): This is slightly inaccurate since the the delay timer is
// only started after the input training tensors are cached. // only started after the input training tensors are cached.
if (training_request.observation_delayed_task) { if (training_request.observation_delayed_task) {
if (training_request.observation_delayed_task.value().is_zero()) {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
} else {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
}
VLOG(1) << "Observation timeout set for " VLOG(1) << "Observation timeout set for "
<< proto::SegmentId_Name(segment_info.segment_id()) << " " << proto::SegmentId_Name(segment_info.segment_id()) << " "
<< *training_request.observation_delayed_task; << *training_request.observation_delayed_task;
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask( if (training_request.observation_delayed_task.value().is_zero()) {
FROM_HERE, PostObservationTask(
base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger, request_id, segment_info, *training_request.observation_delayed_task,
weak_ptr_factory_.GetWeakPtr(), absl::nullopt, stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
request_id, segment_info, base::DoNothing()), } else {
*training_request.observation_delayed_task); // Sample time triggered data for ondemand models.
if (IsPeriodic(segment_info) ||
base::RandGenerator(time_trigger_sampling_rate_) == 0) {
PostObservationTask(
request_id, segment_info,
*training_request.observation_delayed_task,
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
} else {
RecordTrainingDataCollectionEvent(
segment_info.segment_id(),
stats::TrainingDataCollectionEvent::kDelayTriggerSampled);
}
}
} else { } else {
VLOG(1) << "Observation without timeout " VLOG(1) << "Observation without timeout "
<< proto::SegmentId_Name(segment_info.segment_id()); << proto::SegmentId_Name(segment_info.segment_id());
@@ -633,6 +646,20 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
} }
} }
// Schedules OnObservationTrigger() for `request_id` on the current default
// single-thread task runner after `delay`, then records `event` for the
// segment. The task is bound through a WeakPtr, so it does not run if this
// collector has been destroyed before `delay` elapses.
void TrainingDataCollectorImpl::PostObservationTask(
    TrainingRequestId request_id,
    const proto::SegmentInfo& segment_info,
    const base::TimeDelta& delay,
    stats::TrainingDataCollectionEvent event) {
  const auto& task_runner = base::SingleThreadTaskRunner::GetCurrentDefault();
  // No ImmediateCollectionParam is supplied, and the completion callback is a
  // no-op.
  task_runner->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
                     weak_ptr_factory_.GetWeakPtr(), /*param=*/absl::nullopt,
                     request_id, segment_info, base::DoNothing()),
      delay);
  RecordTrainingDataCollectionEvent(segment_info.segment_id(), event);
}
void TrainingDataCollectorImpl::OnObservationTrigger( void TrainingDataCollectorImpl::OnObservationTrigger(
const absl::optional<ImmediateCollectionParam>& param, const absl::optional<ImmediateCollectionParam>& param,
TrainingRequestId request_id, TrainingRequestId request_id,

@@ -5,6 +5,7 @@
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_ #ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_ #define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
#include <cstdint>
#include <set> #include <set>
#include <vector> #include <vector>
@@ -13,6 +14,7 @@
#include "base/memory/raw_ptr.h" #include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_base.h" #include "base/metrics/histogram_base.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h" #include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h" #include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
#include "components/segmentation_platform/internal/database/cached_result_provider.h" #include "components/segmentation_platform/internal/database/cached_result_provider.h"
@@ -22,6 +24,8 @@
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h" #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h" #include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h" #include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/model_provider.h" #include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h" #include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/trigger.h" #include "components/segmentation_platform/public/trigger.h"
@@ -68,6 +72,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
void OnUserAction(const std::string& user_action, void OnUserAction(const std::string& user_action,
base::TimeTicks action_time) override; base::TimeTicks action_time) override;
// Overrides the sampling rate otherwise read from the "SamplingRate" feature
// param. Time-delay triggered observations of non-periodic segments are kept
// when base::RandGenerator(`sampling_rate`) returns 0, i.e. roughly 1 in
// `sampling_rate`; pass 1 to keep all of them. Must be greater than 0.
void SetSamplingRateForTesting(uint64_t sampling_rate);
private: private:
struct TrainingTimings; struct TrainingTimings;
@@ -127,6 +133,11 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
const ModelProvider::Request& input_tensors, const ModelProvider::Request& input_tensors,
const ModelProvider::Response& output_tensors); const ModelProvider::Response& output_tensors);
// Posts OnObservationTrigger() for `request_id` to the current default task
// runner with the given `delay` (which may be zero), and records `event` for
// the segment in `segment_info`.
void PostObservationTask(TrainingRequestId request_id,
const proto::SegmentInfo& segment_info,
const base::TimeDelta& delay,
stats::TrainingDataCollectionEvent event);
// Returns whether training data can be reported through UKM. If // Returns whether training data can be reported through UKM. If
// |include_output| is false, only input data will be checked to see if they // |include_output| is false, only input data will be checked to see if they
// meet the collection requirement. // meet the collection requirement.
@@ -190,6 +201,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
// list. // list.
base::flat_set<SegmentId> all_segments_for_training_; base::flat_set<SegmentId> all_segments_for_training_;
// N for the "1 in N" sampling of time-delay triggered observations of
// non-periodic segments. Set from the "SamplingRate" feature param in the
// constructor; tests may override it via SetSamplingRateForTesting().
uint64_t time_trigger_sampling_rate_{0};
base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this}; base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
}; };

@@ -191,6 +191,8 @@ class TrainingDataCollectorImplTest : public ::testing::Test {
&histogram_signal_handler_, &user_action_signal_handler_, &histogram_signal_handler_, &user_action_signal_handler_,
storage_service_.get(), &prefs_, &clock_, storage_service_.get(), &prefs_, &clock_,
cached_result_provider_.get()); cached_result_provider_.get());
collector_->SetSamplingRateForTesting(1);
} }
protected: protected:

@@ -284,7 +284,8 @@ enum class TrainingDataCollectionEvent {
kObservationDisallowed = 18, kObservationDisallowed = 18,
kTrainingDataMissing = 19, kTrainingDataMissing = 19,
kOnDecisionTimeTypeMistmatch = 20, kOnDecisionTimeTypeMistmatch = 20,
kMaxValue = kOnDecisionTimeTypeMistmatch, kDelayTriggerSampled = 21,
kMaxValue = kDelayTriggerSampled,
}; };
// Records analytics for training data collection. // Records analytics for training data collection.

@@ -115,4 +115,8 @@ BASE_FEATURE(kSegmentationPlatformIosModuleRanker,
#else #else
base::FEATURE_DISABLED_BY_DEFAULT); base::FEATURE_DISABLED_BY_DEFAULT);
#endif #endif
// Controls sampling of time-delay triggered training data collection. The
// sampling rate is configured through this feature's "SamplingRate" param.
// Disabled by default.
BASE_FEATURE(kSegmentationPlatformTimeDelaySampling,
"SegmentationPlatformTimeDelaySampling",
base::FEATURE_DISABLED_BY_DEFAULT);
} // namespace segmentation_platform::features } // namespace segmentation_platform::features

@@ -86,6 +86,9 @@ BASE_DECLARE_FEATURE(kSegmentationPlatformTabResumptionRanker);
// Feature flag for enabling ios module ranker. // Feature flag for enabling ios module ranker.
BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker); BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker);
// Feature flag for controlling sampling of time-delay triggered training data
// collection. The sampling rate comes from the "SamplingRate" feature param.
BASE_DECLARE_FEATURE(kSegmentationPlatformTimeDelaySampling);
} // namespace segmentation_platform::features } // namespace segmentation_platform::features
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_ #endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_

@@ -96105,6 +96105,8 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf
<int value="17" label="Disallowed for recording"/> <int value="17" label="Disallowed for recording"/>
<int value="18" label="Observation disallowed for recording"/> <int value="18" label="Observation disallowed for recording"/>
<int value="19" label="Training data missing"/> <int value="19" label="Training data missing"/>
<int value="20" label="Decision type mismatch failure"/>
<int value="21" label="Delay trigger sampled"/>
</enum> </enum>
<enum name="SegmentationPlatformValidationResult"> <enum name="SegmentationPlatformValidationResult">