[TDC] Add sampling to time triggered data collection
Change-Id: Ia43152d129a7a11887191ecf7bb6a2be61ee5c7c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4757485
Reviewed-by: Siddhartha S <ssid@chromium.org>
Commit-Queue: Hailey Wang <haileywang@google.com>
Cr-Commit-Position: refs/heads/main@{#1181741}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
9f51e807c3
commit
46865afffb
components/segmentation_platform
internal
data_collection
training_data_collector_impl.cc
training_data_collector_impl.h
training_data_collector_impl_unittest.cc
stats.h
public
tools/metrics/histograms
@@ -8,9 +8,11 @@
|
|||||||
#include "base/containers/contains.h"
|
#include "base/containers/contains.h"
|
||||||
#include "base/functional/callback_helpers.h"
|
#include "base/functional/callback_helpers.h"
|
||||||
#include "base/logging.h"
|
#include "base/logging.h"
|
||||||
|
#include "base/metrics/field_trial_params.h"
|
||||||
#include "base/metrics/metrics_hashes.h"
|
#include "base/metrics/metrics_hashes.h"
|
||||||
#include "base/metrics/user_metrics.h"
|
#include "base/metrics/user_metrics.h"
|
||||||
#include "base/notreached.h"
|
#include "base/notreached.h"
|
||||||
|
#include "base/rand_util.h"
|
||||||
#include "base/task/single_thread_task_runner.h"
|
#include "base/task/single_thread_task_runner.h"
|
||||||
#include "base/time/clock.h"
|
#include "base/time/clock.h"
|
||||||
#include "base/time/time.h"
|
#include "base/time/time.h"
|
||||||
@@ -25,7 +27,6 @@
|
|||||||
#include "components/segmentation_platform/internal/platform_options.h"
|
#include "components/segmentation_platform/internal/platform_options.h"
|
||||||
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
|
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
|
||||||
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
|
#include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
|
||||||
#include "components/segmentation_platform/internal/stats.h"
|
|
||||||
#include "components/segmentation_platform/public/config.h"
|
#include "components/segmentation_platform/public/config.h"
|
||||||
#include "components/segmentation_platform/public/local_state_helper.h"
|
#include "components/segmentation_platform/public/local_state_helper.h"
|
||||||
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
|
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
|
||||||
@@ -84,6 +85,10 @@ bool IsPeriodic(const proto::SegmentInfo& info) {
|
|||||||
return type == proto::TrainingOutputs::TriggerConfig::PERIODIC;
|
return type == proto::TrainingOutputs::TriggerConfig::PERIODIC;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "SamplingRate" parameter of kSegmentationPlatformTimeDelaySampling; defaults
// to 20. The constructor of TrainingDataCollectorImpl reads it once into
// |time_trigger_sampling_rate_|, which is then used as the range argument of
// base::RandGenerator() when deciding whether to keep a delayed observation
// for a non-periodic segment.
// NOTE(review): the param is a signed int but is stored in a uint64_t; a
// configured value of 0 or a negative number would be forwarded to
// base::RandGenerator() unchecked -- confirm whether that needs validation.
constexpr base::FeatureParam<int> TimeDelaySamplingRate{
|
||||||
|
&features::kSegmentationPlatformTimeDelaySampling,
|
||||||
|
/*name=*/"SamplingRate", /*default_value=*/20};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
struct TrainingDataCollectorImpl::TrainingTimings {
|
struct TrainingDataCollectorImpl::TrainingTimings {
|
||||||
@@ -112,7 +117,8 @@ TrainingDataCollectorImpl::TrainingDataCollectorImpl(
|
|||||||
cached_result_provider_(cached_result_provider),
|
cached_result_provider_(cached_result_provider),
|
||||||
training_cache_(std::make_unique<TrainingDataCache>(
|
training_cache_(std::make_unique<TrainingDataCache>(
|
||||||
storage_service->segment_info_database())),
|
storage_service->segment_info_database())),
|
||||||
default_model_manager_(storage_service->default_model_manager()) {}
|
default_model_manager_(storage_service->default_model_manager()),
|
||||||
|
time_trigger_sampling_rate_(TimeDelaySamplingRate.Get()) {}
|
||||||
|
|
||||||
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() {
|
TrainingDataCollectorImpl::~TrainingDataCollectorImpl() {
|
||||||
histogram_signal_handler_->RemoveObserver(this);
|
histogram_signal_handler_->RemoveObserver(this);
|
||||||
@@ -277,6 +283,11 @@ void TrainingDataCollectorImpl::OnUserAction(const std::string& user_action,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test-only hook: overwrites |time_trigger_sampling_rate_|, replacing the
// value the constructor read from the TimeDelaySamplingRate feature param.
// The unit-test fixture calls this with 1.
// NOTE(review): |sampling_rate| is stored as-is; 0 would later be passed to
// base::RandGenerator(), which presumably requires a non-zero range -- confirm.
void TrainingDataCollectorImpl::SetSamplingRateForTesting(
|
||||||
|
uint64_t sampling_rate) {
|
||||||
|
time_trigger_sampling_rate_ = sampling_rate;
|
||||||
|
}
|
||||||
|
|
||||||
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
|
void TrainingDataCollectorImpl::OnUmaUpdatedReportForSegmentInfo(
|
||||||
const absl::optional<ImmediateCollectionParam>& param,
|
const absl::optional<ImmediateCollectionParam>& param,
|
||||||
absl::optional<proto::SegmentInfo> segment) {
|
absl::optional<proto::SegmentInfo> segment) {
|
||||||
@@ -604,26 +615,28 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
|
|||||||
// TODO(haileywang): This is slightly inaccurate since the delay timer is
|
// TODO(haileywang): This is slightly inaccurate since the delay timer is
|
||||||
// only started after the input training tensors are cached.
|
// only started after the input training tensors are cached.
|
||||||
if (training_request.observation_delayed_task) {
|
if (training_request.observation_delayed_task) {
|
||||||
if (training_request.observation_delayed_task.value().is_zero()) {
|
|
||||||
RecordTrainingDataCollectionEvent(
|
|
||||||
segment_info.segment_id(),
|
|
||||||
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
|
|
||||||
} else {
|
|
||||||
RecordTrainingDataCollectionEvent(
|
|
||||||
segment_info.segment_id(),
|
|
||||||
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
|
|
||||||
}
|
|
||||||
|
|
||||||
VLOG(1) << "Observation timeout set for "
|
VLOG(1) << "Observation timeout set for "
|
||||||
<< proto::SegmentId_Name(segment_info.segment_id()) << " "
|
<< proto::SegmentId_Name(segment_info.segment_id()) << " "
|
||||||
<< *training_request.observation_delayed_task;
|
<< *training_request.observation_delayed_task;
|
||||||
|
|
||||||
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
|
if (training_request.observation_delayed_task.value().is_zero()) {
|
||||||
FROM_HERE,
|
PostObservationTask(
|
||||||
base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
|
request_id, segment_info, *training_request.observation_delayed_task,
|
||||||
weak_ptr_factory_.GetWeakPtr(), absl::nullopt,
|
stats::TrainingDataCollectionEvent::kImmediateObservationPosted);
|
||||||
request_id, segment_info, base::DoNothing()),
|
} else {
|
||||||
*training_request.observation_delayed_task);
|
// Sample time triggered data for ondemand models.
|
||||||
|
if (IsPeriodic(segment_info) ||
|
||||||
|
base::RandGenerator(time_trigger_sampling_rate_) == 0) {
|
||||||
|
PostObservationTask(
|
||||||
|
request_id, segment_info,
|
||||||
|
*training_request.observation_delayed_task,
|
||||||
|
stats::TrainingDataCollectionEvent::kDelayedTaskPosted);
|
||||||
|
} else {
|
||||||
|
RecordTrainingDataCollectionEvent(
|
||||||
|
segment_info.segment_id(),
|
||||||
|
stats::TrainingDataCollectionEvent::kDelayTriggerSampled);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
VLOG(1) << "Observation without timeout "
|
VLOG(1) << "Observation without timeout "
|
||||||
<< proto::SegmentId_Name(segment_info.segment_id());
|
<< proto::SegmentId_Name(segment_info.segment_id());
|
||||||
@@ -633,6 +646,20 @@ void TrainingDataCollectorImpl::OnGetTrainingTensorsAtDecisionTime(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Schedules the observation step for a training request and records that it
// was scheduled.
//
// Posts a task on the current default single-thread task runner that, after
// |delay|, calls OnObservationTrigger() with no ImmediateCollectionParam
// (absl::nullopt), |request_id|, |segment_info| and a no-op callback. The
// task is bound to a weak pointer from |weak_ptr_factory_| (presumably so it
// is dropped if the collector is destroyed first -- confirm).
//
// |event| is the stats event recorded for the segment once the task has been
// posted; the visible callers pass kImmediateObservationPosted for a zero
// delay and kDelayedTaskPosted otherwise.
void TrainingDataCollectorImpl::PostObservationTask(
|
||||||
|
TrainingRequestId request_id,
|
||||||
|
const proto::SegmentInfo& segment_info,
|
||||||
|
const base::TimeDelta& delay,
|
||||||
|
stats::TrainingDataCollectionEvent event) {
|
||||||
|
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
|
||||||
|
FROM_HERE,
|
||||||
|
base::BindOnce(&TrainingDataCollectorImpl::OnObservationTrigger,
|
||||||
|
weak_ptr_factory_.GetWeakPtr(), absl::nullopt, request_id,
|
||||||
|
segment_info, base::DoNothing()),
|
||||||
|
delay);
|
||||||
|
// Record the caller-supplied event after the task has been queued.
RecordTrainingDataCollectionEvent(segment_info.segment_id(), event);
|
||||||
|
}
|
||||||
|
|
||||||
void TrainingDataCollectorImpl::OnObservationTrigger(
|
void TrainingDataCollectorImpl::OnObservationTrigger(
|
||||||
const absl::optional<ImmediateCollectionParam>& param,
|
const absl::optional<ImmediateCollectionParam>& param,
|
||||||
TrainingRequestId request_id,
|
TrainingRequestId request_id,
|
||||||
|
@@ -5,6 +5,7 @@
|
|||||||
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
||||||
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATA_COLLECTION_TRAINING_DATA_COLLECTOR_IMPL_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@
|
|||||||
#include "base/memory/raw_ptr.h"
|
#include "base/memory/raw_ptr.h"
|
||||||
#include "base/memory/weak_ptr.h"
|
#include "base/memory/weak_ptr.h"
|
||||||
#include "base/metrics/histogram_base.h"
|
#include "base/metrics/histogram_base.h"
|
||||||
|
#include "base/time/time.h"
|
||||||
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
|
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
|
||||||
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
|
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
|
||||||
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
|
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
|
||||||
@@ -22,6 +24,8 @@
|
|||||||
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
|
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
|
||||||
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
|
#include "components/segmentation_platform/internal/signals/histogram_signal_handler.h"
|
||||||
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
|
#include "components/segmentation_platform/internal/signals/user_action_signal_handler.h"
|
||||||
|
#include "components/segmentation_platform/internal/stats.h"
|
||||||
|
#include "components/segmentation_platform/public/features.h"
|
||||||
#include "components/segmentation_platform/public/model_provider.h"
|
#include "components/segmentation_platform/public/model_provider.h"
|
||||||
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
|
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
|
||||||
#include "components/segmentation_platform/public/trigger.h"
|
#include "components/segmentation_platform/public/trigger.h"
|
||||||
@@ -68,6 +72,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
|||||||
void OnUserAction(const std::string& user_action,
|
void OnUserAction(const std::string& user_action,
|
||||||
base::TimeTicks action_time) override;
|
base::TimeTicks action_time) override;
|
||||||
|
|
||||||
|
void SetSamplingRateForTesting(uint64_t sampling_rate);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct TrainingTimings;
|
struct TrainingTimings;
|
||||||
|
|
||||||
@@ -127,6 +133,11 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
|||||||
const ModelProvider::Request& input_tensors,
|
const ModelProvider::Request& input_tensors,
|
||||||
const ModelProvider::Response& output_tensors);
|
const ModelProvider::Response& output_tensors);
|
||||||
|
|
||||||
|
void PostObservationTask(TrainingRequestId request_id,
|
||||||
|
const proto::SegmentInfo& segment_info,
|
||||||
|
const base::TimeDelta& delay,
|
||||||
|
stats::TrainingDataCollectionEvent event);
|
||||||
|
|
||||||
// Returns whether training data can be reported through UKM. If
|
// Returns whether training data can be reported through UKM. If
|
||||||
// |include_output| is false, only input data will be checked to see if they
|
// |include_output| is false, only input data will be checked to see if they
|
||||||
// meet the collection requirement.
|
// meet the collection requirement.
|
||||||
@@ -190,6 +201,8 @@ class TrainingDataCollectorImpl : public TrainingDataCollector,
|
|||||||
// list.
|
// list.
|
||||||
base::flat_set<SegmentId> all_segments_for_training_;
|
base::flat_set<SegmentId> all_segments_for_training_;
|
||||||
|
|
||||||
|
uint64_t time_trigger_sampling_rate_{0};
|
||||||
|
|
||||||
base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
|
base::WeakPtrFactory<TrainingDataCollectorImpl> weak_ptr_factory_{this};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
2
components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
2
components/segmentation_platform/internal/data_collection/training_data_collector_impl_unittest.cc
@@ -191,6 +191,8 @@ class TrainingDataCollectorImplTest : public ::testing::Test {
|
|||||||
&histogram_signal_handler_, &user_action_signal_handler_,
|
&histogram_signal_handler_, &user_action_signal_handler_,
|
||||||
storage_service_.get(), &prefs_, &clock_,
|
storage_service_.get(), &prefs_, &clock_,
|
||||||
cached_result_provider_.get());
|
cached_result_provider_.get());
|
||||||
|
|
||||||
|
collector_->SetSamplingRateForTesting(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@@ -284,7 +284,8 @@ enum class TrainingDataCollectionEvent {
|
|||||||
kObservationDisallowed = 18,
|
kObservationDisallowed = 18,
|
||||||
kTrainingDataMissing = 19,
|
kTrainingDataMissing = 19,
|
||||||
kOnDecisionTimeTypeMistmatch = 20,
|
kOnDecisionTimeTypeMistmatch = 20,
|
||||||
kMaxValue = kOnDecisionTimeTypeMistmatch,
|
kDelayTriggerSampled = 21,
|
||||||
|
kMaxValue = kDelayTriggerSampled,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Records analytics for training data collection.
|
// Records analytics for training data collection.
|
||||||
|
@@ -115,4 +115,8 @@ BASE_FEATURE(kSegmentationPlatformIosModuleRanker,
|
|||||||
#else
|
#else
|
||||||
base::FEATURE_DISABLED_BY_DEFAULT);
|
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Feature controlling sampling of time-delayed training data collection.
// Disabled by default. Its "SamplingRate" param is read through the
// TimeDelaySamplingRate FeatureParam in training_data_collector_impl.cc.
BASE_FEATURE(kSegmentationPlatformTimeDelaySampling,
|
||||||
|
"SegmentationPlatformTimeDelaySampling",
|
||||||
|
base::FEATURE_DISABLED_BY_DEFAULT);
|
||||||
} // namespace segmentation_platform::features
|
} // namespace segmentation_platform::features
|
||||||
|
@@ -86,6 +86,9 @@ BASE_DECLARE_FEATURE(kSegmentationPlatformTabResumptionRanker);
|
|||||||
|
|
||||||
// Feature flag for enabling ios module ranker.
|
// Feature flag for enabling ios module ranker.
|
||||||
BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker);
|
BASE_DECLARE_FEATURE(kSegmentationPlatformIosModuleRanker);
|
||||||
|
|
||||||
|
// Feature flag for controlling sampling of training data collection.
|
||||||
|
BASE_DECLARE_FEATURE(kSegmentationPlatformTimeDelaySampling);
|
||||||
} // namespace segmentation_platform::features
|
} // namespace segmentation_platform::features
|
||||||
|
|
||||||
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_
|
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_FEATURES_H_
|
||||||
|
@@ -96105,6 +96105,8 @@ https://www.dmtf.org/sites/default/files/standards/documents/DSP0134_2.7.1.pdf
|
|||||||
<int value="17" label="Disallowed for recording"/>
|
<int value="17" label="Disallowed for recording"/>
|
||||||
<int value="18" label="Observation disallowed for recording"/>
|
<int value="18" label="Observation disallowed for recording"/>
|
||||||
<int value="19" label="Training data missing"/>
|
<int value="19" label="Training data missing"/>
|
||||||
|
<int value="20" label="Decision type mismatch failure"/>
|
||||||
|
<int value="21" label="Delay trigger sampled"/>
|
||||||
</enum>
|
</enum>
|
||||||
|
|
||||||
<enum name="SegmentationPlatformValidationResult">
|
<enum name="SegmentationPlatformValidationResult">
|
||||||
|
Reference in New Issue
Block a user