[TDC] Add sampling to time triggered data collection
Change-Id: Ia43152d129a7a11887191ecf7bb6a2be61ee5c7c Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4757485 Reviewed-by: Siddhartha S <ssid@chromium.org> Commit-Queue: Hailey Wang <haileywang@google.com> Cr-Commit-Position: refs/heads/main@{#1181741}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
9f51e807c3
commit
46865afffb
components/segmentation_platform
internal
data_collection
training_data_collector_impl.cctraining_data_collector_impl.htraining_data_collector_impl_unittest.cc
stats.hpublic
tools/metrics/histograms
@ -8,9 +8,11 @@
|
||||
#include "base/containers/contains.h"
|
||||
#include "base/functional/callback_helpers.h"
|
||||
#include "base/logging.h"
|
||||
#include "base/metrics/field_trial_params.h"
|
||||
#include "base/metrics/metrics_hashes.h"
|
||||
#include "base/metrics/user_metrics.h"
|
||||
#include "base/notreached.h"
|
||||
#include "base/rand_util.h"
|
||||
#include "base/task/single_thread_task_runner.h"
|
||||
#include "base/time/clock.h"
|
||||
#include "base/time/time.h"
|
||||
@ -25,7 +27,6 @@
|
||||
#include "components/segmentation_platform/internal/platform_options.h"
|
||||
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
|
||||
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
|
||||
#include "components/segmentation_platform/internal/stats.h"
|
||||
#include "components/segmentation_platform/public/config.h"
|
||||
#include "components/segmentation_platform/public/local_state_helper.h"
|
||||
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
|
||||
@ -84,6 +85,10 @@ bool IsPeriodic(const proto::SegmentInfo& info) {
|
||||
return type == proto::TrainingOutputs::TriggerConfig::PERIODIC;
|
||||
}
|
||||
|
||||
// "SamplingRate" param of kSegmentationPlatformTimeDelaySampling (default 20).
// The collector reads it once at construction and uses the value as the range
// for base::RandGenerator() when deciding whether to post a delayed
// observation for a non-periodic segment.
// NOTE(review): the param is an int that is later stored in a uint64_t and
// handed to base::RandGenerator(); a configured value of 0 or a negative value
// looks unsafe there - confirm whether the value should be validated.
constexpr base::FeatureParam<int> TimeDelaySamplingRate{
|
||||
&features::kSegmentationPlatformTimeDelaySampling,
|
||||
/*name=*/"SamplingRate", /*default_value=*/20};
|
||||
|
||||
} // namespace
|
||||
|
||||
struct TrainingDataCollectorImpl::TrainingTimings {
|
||||
@ -112,7 +117,8 @@ TrainingDataCollectorImpl::TrainingDataCollectorImpl(
|
||||
cached_result_provider_(cached_result_provider),
|
||||
training_cache_(std::make_unique<TrainingDataCache>(
|
||||
storage_service->segment_info_database())),
|
||||
default_model_manager_(storage_service->default_model_manager()) {}
|
||||
default_model_manager_(storage_service->default_model_manager()),
|
||||
time_trigger_sampling_rate_(TimeDelaySamplingRate.Get()) {}
|
||||
|
||||
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() {
|
||||
histogram_signal_handler_->RemoveObserver(this);
|
||||
@ -277,6 +283,11 @@ void TrainingDataCollectorImpl::OnUserAction(const std::string& user_action,
|
||||
}
|
||||
}
|
||||
|
||||
void TrainingDataCollectorImpl::SetSamplingRateForTesting(
|
||||
uint64_t sampling_rate) {
|
||||
time_trigger_sampling_rate_ = sampling_rate;
|
||||
}
|
||||
|
||||
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
|
||||
const absl::optional<ImmediateCollectionParam>& param,
|
||||
absl::optional<proto::SegmentInfo> segment) {
|
||||
@ -604,26 +615,28 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
|
||||
// TODO(haileywang): This is slightly inaccurate since the delay timer is
|
||||
// only started after the input training tensors are cached.
|
||||
if (training_request.observation_delayed_task) {
|
||||
if (training_request.observation_delayed_task.value().is_zero()) {
|
||||
RecordTrainingDataCollectionEvent(
|
||||
segment_info.segment_id(),
|
||||
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
|
||||
} else {
|
||||
RecordTrainingDataCollectionEvent(
|
||||
segment_info.segment_id(),
|
||||
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
|
||||
}
|
||||
|
||||
VLOG(1) << "Observation timeout set for "
|
||||
<< proto::SegmentId_Name(segment_info.segment_id()) << " "
|
||||
<< *training_request.observation_delayed_task;
|
||||
|
||||
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
|
||||
FROM_HERE,
|
||||
base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
|
||||
weak_ptr_factory_.GetWeakPtr(), absl::nullopt,
|
||||
request_id, segment_info, base::DoNothing()),
|
||||
*training_request.observation_delayed_task);
|
||||
if (training_request.observation_delayed_task.value().is_zero()) {
|
||||
PostObservationTask(
|
||||
request_id, segment_info, *training_request.observation_delayed_task,
|
||||
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
|
||||
} else {
|
||||
// Sample time triggered data for ondemand models.
|
||||
if (IsPeriodic(segment_info) ||
|
||||
base::RandGenerator(time_trigger_sampling_rate_) == 0) {
|
||||
PostObservationTask(
|
||||
request_id, segment_info,
|
||||
*training_request.observation_delayed_task,
|
||||
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
|
||||
} else {
|
||||
RecordTrainingDataCollectionEvent(
|
||||
segment_info.segment_id(),
|
||||
stats::TrainingDataCollectionEvent::kDelayTriggerSampled);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
VLOG(1) << "Observation without timeout "
|
||||
<< proto::SegmentId_Name(segment_info.segment_id());
|
||||
@ -633,6 +646,20 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
|
||||
}
|
||||
}
|
||||
|
||||
// Posts OnObservationTrigger() for |request_id| and |segment_info| to the
// current default single-thread task runner, to run after |delay|, and then
// records |event| for the segment. The callback is bound to a weak pointer to
// this collector and carries no ImmediateCollectionParam (absl::nullopt).
void TrainingDataCollectorImpl::PostObservationTask(
    TrainingRequestId request_id,
    const proto::SegmentInfo& segment_info,
    const base::TimeDelta& delay,
    stats::TrainingDataCollectionEvent event) {
  const auto& task_runner = base::SingleThreadTaskRunner::GetCurrentDefault();
  auto weak_self = weak_ptr_factory_.GetWeakPtr();

  task_runner->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
                     weak_self, absl::nullopt, request_id, segment_info,
                     base::DoNothing()),
      delay);

  // The event is recorded only after the task has been handed to the runner.
  RecordTrainingDataCollectionEvent(segment_info.segment_id(), event);
}
|
||||
|
||||
void TrainingDataCollectorImpl::OnObservationTrigger(
|
||||
const absl::optional<ImmediateCollectionParam>& param,
|
||||
TrainingRequestId request_id,
|
||||
|
@ -5,6 +5,7 @@
|
||||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
||||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
@ -13,6 +14,7 @@
|
||||
#include "base/memory/raw_ptr.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/metrics/histogram_base.h"
|
||||
#include "base/time/time.h"
|
||||
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
|
||||
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
|
||||
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
|
||||
@ -22,6 +24,8 @@
|
||||
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
|
||||
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
|
||||
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
|
||||
#include "components/segmentation_platform/internal/stats.h"
|
||||
#include "components/segmentation_platform/public/features.h"
|
||||
#include "components/segmentation_platform/public/model_provider.h"
|
||||
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
|
||||
#include "components/segmentation_platform/public/trigger.h"
|
||||
@ -68,6 +72,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
||||
void OnUserAction(const std::string& user_action,
|
||||
base::TimeTicks action_time) override;
|
||||
|
||||
// Overrides the sampling rate used when deciding whether to post a delayed
// observation task. The value is used as the range for base::RandGenerator(),
// so a rate of 1 posts every delayed observation. Intended for tests only.
void SetSamplingRateForTesting(uint64_t sampling_rate);
|
||||
|
||||
private:
|
||||
struct TrainingTimings;
|
||||
|
||||
@ -127,6 +133,11 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
||||
const ModelProvider::Request& input_tensors,
|
||||
const ModelProvider::Response& output_tensors);
|
||||
|
||||
// Posts a task that calls OnObservationTrigger() for |request_id| and
// |segment_info| after |delay|, then records |event| for the segment.
void PostObservationTask(TrainingRequestId request_id,
|
||||
const proto::SegmentInfo& segment_info,
|
||||
const base::TimeDelta& delay,
|
||||
stats::TrainingDataCollectionEvent event);
|
||||
|
||||
// Returns whether training data can be reported through UKM. If
|
||||
// |include_output| is false, only input data will be checked to see if they
|
||||
// meet the collection requirement.
|
||||
@ -190,6 +201,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
||||
// list.
|
||||
base::flat_set<SegmentId> all_segments_for_training_;
|
||||
|
||||
// Range passed to base::RandGenerator() when sampling delayed observations
// for non-periodic segments; an observation is posted when the draw is 0.
// Set from the "SamplingRate" feature param in the constructor and
// overridable through SetSamplingRateForTesting().
uint64_t time_trigger_sampling_rate_{0};
|
||||
|
||||
base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
|
||||
};
|
||||
|
||||
|
2
components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
2
components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
@ -191,6 +191,8 @@ class TrainingDataCollectorImplTest : public ::testing::Test {
|
||||
&histogram_signal_handler_, &user_action_signal_handler_,
|
||||
storage_service_.get(), &prefs_, &clock_,
|
||||
cached_result_provider_.get());
|
||||
|
||||
collector_->SetSamplingRateForTesting(1);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -284,7 +284,8 @@ enum class TrainingDataCollectionEvent {
|
||||
kObservationDisallowed = 18,
|
||||
kTrainingDataMissing = 19,
|
||||
kOnDecisionTimeTypeMistmatch = 20,
|
||||
kMaxValue = kOnDecisionTimeTypeMistmatch,
|
||||
kDelayTriggerSampled = 21,
|
||||
kMaxValue = kDelayTriggerSampled,
|
||||
};
|
||||
|
||||
// Records analytics for training data collection.
|
||||
|
@ -115,4 +115,8 @@ BASE_FEATURE(kSegmentationPlatformIosModuleRanker,
|
||||
#else
|
||||
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||
#endif
|
||||
|
||||
// Feature associated with sampling of time-delayed training data collection.
// Disabled by default.
BASE_FEATURE(kSegmentationPlatformTimeDelaySampling,
|
||||
"SegmentationPlatformTimeDelaySampling",
|
||||
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||
} // namespace segmentation_platform::features
|
||||
|
@ -86,6 +86,9 @@ BASE_DECLARE_FEATURE(kSegmentationPlatformTabResumptionRanker);
|
||||
|
||||
// Feature flag for enabling ios module ranker.
|
||||
BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker);
|
||||
|
||||
// Feature flag for controlling sampling of training data collection.
|
||||
BASE_DECLARE_FEATURE(kSegmentationPlatformTimeDelaySampling);
|
||||
} // namespace segmentation_platform::features
|
||||
|
||||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_
|
||||
|
@ -96105,6 +96105,8 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf
|
||||
<int value="17" label="Disallowed for recording"/>
|
||||
<int value="18" label="Observation disallowed for recording"/>
|
||||
<int value="19" label="Training data missing"/>
|
||||
<int value="20" label="Decision type mismatch failure"/>
|
||||
<int value="21" label="Delay trigger sampled"/>
|
||||
</enum>
|
||||
|
||||
<enum name="SegmentationPlatformValidationResult">
|
||||
|
Reference in New Issue
Block a user