Passed FlatBuffer from Browser to Renderer process.
Stored the Client-Side Phishing Detection model (as a FlatBuffer, go/flatbuffers) in a ReadOnlySharedMemoryRegion in client_side_phishing_model. Duplicated the shared memory region and passed it via client_side_detection_host to components/safe_browsing/content/renderer/phishing_classifier/scorer.cc in the renderer process. Updated unit tests and browser tests. Design Doc: http://go/memory-regression-csd-android Bug: 1210696 Change-Id: I8168f521bd8bc0c68ba873a52b0fdbc488aa8678 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2912076 Commit-Queue: Rohit Bhatia <bhatiarohit@google.com> Reviewed-by: danakj <danakj@chromium.org> Reviewed-by: Daniel Rubery <drubery@chromium.org> Cr-Commit-Position: refs/heads/master@{#886554}
This commit is contained in:

committed by
Chromium LUCI CQ

parent
d4da940008
commit
26d7ea11b2
chrome
browser
safe_browsing
renderer
safe_browsing
components/safe_browsing/content
@ -12,6 +12,7 @@
|
||||
#include "base/callback_helpers.h"
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/macros.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/ref_counted.h"
|
||||
#include "base/run_loop.h"
|
||||
#include "base/synchronization/waitable_event.h"
|
||||
@ -26,6 +27,7 @@
|
||||
#include "components/prefs/scoped_user_pref_update.h"
|
||||
#include "components/safe_browsing/content/browser/client_side_detection_service.h"
|
||||
#include "components/safe_browsing/content/browser/client_side_model_loader.h"
|
||||
#include "components/safe_browsing/content/browser/client_side_phishing_model.h"
|
||||
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
|
||||
#include "components/safe_browsing/core/browser/sync/sync_utils.h"
|
||||
#include "components/safe_browsing/core/common/safe_browsing_prefs.h"
|
||||
@ -133,6 +135,8 @@ class MockClientSideDetectionService : public ClientSideDetectionService {
|
||||
MOCK_METHOD1(IsInCache, bool(const GURL&));
|
||||
MOCK_METHOD0(OverPhishingReportLimit, bool());
|
||||
MOCK_METHOD0(GetModelStr, std::string());
|
||||
MOCK_METHOD0(GetModelSharedMemoryRegion, base::ReadOnlySharedMemoryRegion());
|
||||
MOCK_METHOD0(GetModelType, CSDModelType());
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(MockClientSideDetectionService);
|
||||
@ -195,6 +199,12 @@ class FakePhishingDetector : public mojom::PhishingDetector {
|
||||
model_ = model;
|
||||
}
|
||||
|
||||
// mojom::PhishingDetector
|
||||
void SetPhishingFlatBufferModel(base::ReadOnlySharedMemoryRegion region,
|
||||
base::File file) override {
|
||||
region_ = std::move(region);
|
||||
}
|
||||
|
||||
// mojom::PhishingDetector
|
||||
void StartPhishingDetection(
|
||||
const GURL& url,
|
||||
@ -223,10 +233,15 @@ class FakePhishingDetector : public mojom::PhishingDetector {
|
||||
|
||||
void CheckModel(const std::string& model) {
  // Verifies that |model| equals the model string stored in |model_|.
  EXPECT_EQ(model, model_);
}
|
||||
|
||||
void CheckModel(base::ReadOnlySharedMemoryRegion region) {
  // Regions are compared by GUID; callers pass a Duplicate() of the region
  // they expect this fake to have received.
  const auto& expected_guid = region.GetGUID();
  EXPECT_EQ(expected_guid, region_.GetGUID());
}
|
||||
|
||||
void Reset() {
  // Return every recorded field to its default (empty) state.
  phishing_detection_started_ = false;
  url_ = GURL();
  model_.clear();
  region_ = base::ReadOnlySharedMemoryRegion();
}
|
||||
|
||||
private:
|
||||
@ -234,6 +249,7 @@ class FakePhishingDetector : public mojom::PhishingDetector {
|
||||
bool phishing_detection_started_ = false;
|
||||
GURL url_;
|
||||
std::string model_ = "";
|
||||
base::ReadOnlySharedMemoryRegion region_ = base::ReadOnlySharedMemoryRegion();
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(FakePhishingDetector);
|
||||
};
|
||||
@ -274,8 +290,7 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
|
||||
InitTestApi();
|
||||
|
||||
// Inject service classes.
|
||||
csd_service_ =
|
||||
std::make_unique<StrictMock<MockClientSideDetectionService>>();
|
||||
csd_service_ = std::make_unique<MockClientSideDetectionService>();
|
||||
database_manager_ = new StrictMock<MockSafeBrowsingDatabaseManager>();
|
||||
ui_manager_ = new StrictMock<MockSafeBrowsingUIManager>(
|
||||
// TODO(crbug/925153): Port consumers of the SafeBrowsingService to
|
||||
@ -298,6 +313,8 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
|
||||
std::make_unique<StrictMock<MockSafeBrowsingTokenFetcher>>();
|
||||
raw_token_fetcher_ = token_fetcher.get();
|
||||
csd_host_->set_token_fetcher_for_testing(std::move(token_fetcher));
|
||||
|
||||
testing::DefaultValue<CSDModelType>::Set(CSDModelType::kProtobuf);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
@ -375,9 +392,9 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ClientSideDetectionHost> csd_host_;
|
||||
std::unique_ptr<StrictMock<MockClientSideDetectionService>> csd_service_;
|
||||
std::unique_ptr<MockClientSideDetectionService> csd_service_;
|
||||
scoped_refptr<StrictMock<MockSafeBrowsingUIManager> > ui_manager_;
|
||||
scoped_refptr<StrictMock<MockSafeBrowsingDatabaseManager> > database_manager_;
|
||||
scoped_refptr<StrictMock<MockSafeBrowsingDatabaseManager>> database_manager_;
|
||||
FakePhishingDetector fake_phishing_detector_;
|
||||
StrictMock<MockSafeBrowsingTokenFetcher>* raw_token_fetcher_ = nullptr;
|
||||
base::SimpleTestTickClock clock_;
|
||||
@ -1138,6 +1155,20 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
|
||||
.min);
|
||||
}
|
||||
|
||||
TEST_F(ClientSideDetectionHostTest, TestSendFlatBufferModelToRenderFrame) {
  // Verifies that a flatbuffer-typed model is forwarded to the renderer as the
  // service's shared memory region (compared by GUID).
  base::MappedReadOnlyRegion mapped_region =
      base::ReadOnlySharedMemoryRegion::Create(10);
  // Region creation can fail; fail the test cleanly instead of continuing with
  // an invalid region.
  ASSERT_TRUE(mapped_region.IsValid());

  EXPECT_CALL(*csd_service_, GetModelType())
      .WillRepeatedly(Return(CSDModelType::kFlatbuffer));
  // A Return(ByMove(...)) action can only run once, which makes it unsafe
  // under WillRepeatedly(). Hand out a fresh duplicate (same GUID) per call.
  EXPECT_CALL(*csd_service_, GetModelSharedMemoryRegion())
      .WillRepeatedly(testing::Invoke(
          [&mapped_region]() { return mapped_region.region.Duplicate(); }));

  csd_host_->SendModelToRenderFrame();
  base::RunLoop().RunUntilIdle();

  fake_phishing_detector_.CheckModel(mapped_region.region.Duplicate());
  fake_phishing_detector_.Reset();
}
|
||||
|
||||
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
|
||||
EXPECT_CALL(*csd_service_, GetModelStr()).WillRepeatedly(Return("standard"));
|
||||
csd_host_->SendModelToRenderFrame();
|
||||
|
@ -249,6 +249,26 @@ TEST_F(PhishingClassifierDelegateTest, HasPhishingModel) {
|
||||
EXPECT_CALL(*classifier_, CancelPendingClassification());
|
||||
}
|
||||
|
||||
TEST_F(PhishingClassifierDelegateTest, HasFlatBufferModel) {
  // Verifies that handing the delegate a valid flatbuffer region makes the
  // classifier ready.
  ASSERT_FALSE(classifier_->is_ready());

  // Serialize a ClientSideModel flatbuffer with no fields set.
  flatbuffers::FlatBufferBuilder builder(1024);
  flat::ClientSideModelBuilder csd_model_builder(builder);
  builder.Finish(csd_model_builder.Finish());
  std::string model_str(reinterpret_cast<char*>(builder.GetBufferPointer()),
                        builder.GetSize());

  // Copy it into a read-only shared memory region. Creation can fail, and
  // writing through an invalid (null) mapping would be undefined behavior, so
  // check validity first.
  base::MappedReadOnlyRegion mapped_region =
      base::ReadOnlySharedMemoryRegion::Create(model_str.length());
  ASSERT_TRUE(mapped_region.IsValid());
  memcpy(mapped_region.mapping.memory(), model_str.data(), model_str.length());

  delegate_->SetPhishingFlatBufferModel(mapped_region.region.Duplicate(),
                                        base::File());
  ASSERT_TRUE(classifier_->is_ready());

  // The delegate will cancel pending classification on destruction.
  EXPECT_CALL(*classifier_, CancelPendingClassification());
}
|
||||
|
||||
TEST_F(PhishingClassifierDelegateTest, HasVisualTfLiteModel) {
|
||||
ASSERT_FALSE(classifier_->is_ready());
|
||||
|
||||
|
@ -369,6 +369,22 @@ void ClientSideDetectionHost::DidFinishNavigation(
|
||||
classification_request_->Start();
|
||||
}
|
||||
|
||||
void ClientSideDetectionHost::SetPhishingModel() {
|
||||
switch (csd_service_->GetModelType()) {
|
||||
case CSDModelType::kNone:
|
||||
case CSDModelType::kProtobuf:
|
||||
phishing_detector_->SetPhishingModel(
|
||||
csd_service_->GetModelStr(),
|
||||
csd_service_->GetVisualTfLiteModel().Duplicate());
|
||||
return;
|
||||
case CSDModelType::kFlatbuffer:
|
||||
phishing_detector_->SetPhishingFlatBufferModel(
|
||||
csd_service_->GetModelSharedMemoryRegion(),
|
||||
csd_service_->GetVisualTfLiteModel().Duplicate());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void ClientSideDetectionHost::SendModelToRenderFrame() {
|
||||
DCHECK_CURRENTLY_ON(BrowserThread::UI);
|
||||
if (!web_contents() || web_contents() != tab_ || !csd_service_)
|
||||
@ -381,9 +397,7 @@ void ClientSideDetectionHost::SendModelToRenderFrame() {
|
||||
phishing_detector_.reset();
|
||||
frame->GetRemoteInterfaces()->GetInterface(
|
||||
phishing_detector_.BindNewPipeAndPassReceiver());
|
||||
phishing_detector_->SetPhishingModel(
|
||||
csd_service_->GetModelStr(),
|
||||
csd_service_->GetVisualTfLiteModel().Duplicate());
|
||||
SetPhishingModel();
|
||||
}
|
||||
}
|
||||
|
||||
@ -402,9 +416,7 @@ void ClientSideDetectionHost::RenderFrameCreated(
|
||||
phishing_detector_.reset();
|
||||
render_frame_host->GetRemoteInterfaces()->GetInterface(
|
||||
phishing_detector_.BindNewPipeAndPassReceiver());
|
||||
phishing_detector_->SetPhishingModel(
|
||||
csd_service_->GetModelStr(),
|
||||
csd_service_->GetVisualTfLiteModel().Duplicate());
|
||||
SetPhishingModel();
|
||||
}
|
||||
|
||||
void ClientSideDetectionHost::OnPhishingPreClassificationDone(
|
||||
|
@ -160,6 +160,9 @@ class ClientSideDetectionHost : public content::WebContentsObserver {
|
||||
// who are signed in and not in incognito mode.
|
||||
bool CanGetAccessToken();
|
||||
|
||||
// Set phishing model in PhishingDetector in renderers.
|
||||
void SetPhishingModel();
|
||||
|
||||
// Send the client report to CSD server.
|
||||
void SendRequest(std::unique_ptr<ClientPhishingRequest> verdict,
|
||||
const std::string& access_token);
|
||||
|
@ -435,6 +435,15 @@ std::string ClientSideDetectionService::GetModelStr() {
|
||||
return ClientSidePhishingModel::GetInstance()->GetModelStr();
|
||||
}
|
||||
|
||||
CSDModelType ClientSideDetectionService::GetModelType() {
  // Forwards to the process-wide ClientSidePhishingModel instance.
  ClientSidePhishingModel* const phishing_model =
      ClientSidePhishingModel::GetInstance();
  return phishing_model->GetModelType();
}
|
||||
|
||||
base::ReadOnlySharedMemoryRegion
ClientSideDetectionService::GetModelSharedMemoryRegion() {
  // Forwards to the process-wide ClientSidePhishingModel instance, which
  // returns a duplicate of its flatbuffer region.
  ClientSidePhishingModel* const phishing_model =
      ClientSidePhishingModel::GetInstance();
  return phishing_model->GetModelSharedMemoryRegion();
}
|
||||
|
||||
const base::File& ClientSideDetectionService::GetVisualTfLiteModel() {
  // The returned reference comes from the ClientSidePhishingModel instance;
  // callers Duplicate() it before sending it to a renderer.
  ClientSidePhishingModel* const phishing_model =
      ClientSidePhishingModel::GetInstance();
  return phishing_model->GetVisualTfLiteModel();
}
|
||||
|
@ -24,11 +24,13 @@
|
||||
#include "base/containers/queue.h"
|
||||
#include "base/gtest_prod_util.h"
|
||||
#include "base/macros.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/ref_counted.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/time/time.h"
|
||||
#include "components/keyed_service/core/keyed_service.h"
|
||||
#include "components/prefs/pref_change_registrar.h"
|
||||
#include "components/safe_browsing/content/browser/client_side_phishing_model.h"
|
||||
#include "components/safe_browsing/core/proto/csd.pb.h"
|
||||
#include "content/public/browser/browser_thread.h"
|
||||
#include "content/public/browser/notification_observer.h"
|
||||
@ -123,10 +125,18 @@ class ClientSideDetectionService : public KeyedService {
|
||||
// Sends a model to each renderer.
|
||||
virtual void SendModelToRenderers();
|
||||
|
||||
// Returns the model string. Virtual so that mock implementation can override
|
||||
// it.
|
||||
// Returns the model string. Used only for protobuf model. Virtual so that
|
||||
// mock implementation can override it.
|
||||
virtual std::string GetModelStr();
|
||||
|
||||
// Returns the model type (protobuf or flatbuffer). Virtual so that mock
|
||||
// implementation can override it.
|
||||
virtual CSDModelType GetModelType();
|
||||
|
||||
// Returns the ReadOnlySharedMemoryRegion for the flatbuffer model. Virtual so
|
||||
// that mock implementation can override it.
|
||||
virtual base::ReadOnlySharedMemoryRegion GetModelSharedMemoryRegion();
|
||||
|
||||
// Returns the TfLite model file. Virtual so that mock implementation can
|
||||
// override it.
|
||||
virtual const base::File& GetVisualTfLiteModel();
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include "base/command_line.h"
|
||||
#include "base/feature_list.h"
|
||||
#include "base/logging.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/shared_memory_mapping.h"
|
||||
#include "base/memory/singleton.h"
|
||||
#include "base/metrics/histogram_functions.h"
|
||||
#include "base/task/post_task.h"
|
||||
@ -80,10 +82,14 @@ base::CallbackListSubscription ClientSidePhishingModel::RegisterCallback(
|
||||
}
|
||||
|
||||
bool ClientSidePhishingModel::IsEnabled() const {
|
||||
return !model_str_.empty() || visual_tflite_model_.IsValid();
|
||||
return (model_type_ == CSDModelType::kFlatbuffer &&
|
||||
mapped_region_.IsValid()) ||
|
||||
(model_type_ == CSDModelType::kProtobuf && !model_str_.empty()) ||
|
||||
visual_tflite_model_.IsValid();
|
||||
}
|
||||
|
||||
std::string ClientSidePhishingModel::GetModelStr() const {
|
||||
DCHECK(model_type_ != CSDModelType::kFlatbuffer);
|
||||
return model_str_;
|
||||
}
|
||||
|
||||
@ -95,6 +101,11 @@ CSDModelType ClientSidePhishingModel::GetModelType() const {
|
||||
return model_type_;
|
||||
}
|
||||
|
||||
base::ReadOnlySharedMemoryRegion
ClientSidePhishingModel::GetModelSharedMemoryRegion() const {
  // Hand out a duplicate so |mapped_region_| keeps its own region handle.
  const base::ReadOnlySharedMemoryRegion& owned_region = mapped_region_.region;
  return owned_region.Duplicate();
}
|
||||
|
||||
void ClientSidePhishingModel::PopulateFromDynamicUpdate(
|
||||
const std::string& model_str,
|
||||
base::File visual_tflite_model) {
|
||||
@ -109,14 +120,21 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
|
||||
!model_str.empty()) {
|
||||
if (base::FeatureList::IsEnabled(kClientSideDetectionModelIsFlatBuffer)) {
|
||||
flatbuffers::Verifier verifier(
|
||||
const_cast<uint8_t*>(
|
||||
reinterpret_cast<const uint8_t*>(model_str.data())),
|
||||
reinterpret_cast<const uint8_t*>(model_str.data()),
|
||||
model_str.length());
|
||||
model_valid = flat::VerifyClientSideModelBuffer(verifier);
|
||||
if (model_valid) {
|
||||
model_type_ = CSDModelType::kFlatbuffer;
|
||||
model_version_field =
|
||||
flat::GetClientSideModel(model_str.data())->version();
|
||||
mapped_region_ =
|
||||
base::ReadOnlySharedMemoryRegion::Create(model_str.length());
|
||||
if (mapped_region_.IsValid()) {
|
||||
model_type_ = CSDModelType::kFlatbuffer;
|
||||
model_version_field =
|
||||
flat::GetClientSideModel(model_str.data())->version();
|
||||
memcpy(mapped_region_.mapping.memory(), model_str.data(),
|
||||
model_str.length());
|
||||
} else {
|
||||
model_valid = false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ClientSideModel model_proto;
|
||||
@ -124,6 +142,7 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
|
||||
if (model_valid) {
|
||||
model_type_ = CSDModelType::kProtobuf;
|
||||
model_version_field = model_proto.version();
|
||||
model_str_ = model_str;
|
||||
}
|
||||
}
|
||||
|
||||
@ -137,7 +156,6 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
|
||||
base::UmaHistogramExactLinear(
|
||||
"SBClientPhishing.ModelDynamicUpdateVersion", model_version_field,
|
||||
kMaxVersion + 1);
|
||||
model_str_ = model_str;
|
||||
}
|
||||
}
|
||||
|
||||
@ -170,6 +188,21 @@ void ClientSidePhishingModel::SetVisualTfLiteModelForTesting(base::File file) {
|
||||
visual_tflite_model_ = std::move(file);
|
||||
}
|
||||
|
||||
void ClientSidePhishingModel::SetModelTypeForTesting(CSDModelType model_type) {
  // |model_type_| is documented as guarded by |lock_|.
  AutoLock scoped_lock(lock_);
  model_type_ = model_type;
}
|
||||
|
||||
void ClientSidePhishingModel::ClearMappedRegionForTesting() {
|
||||
AutoLock lock(lock_);
|
||||
mapped_region_.mapping = base::WritableSharedMemoryMapping();
|
||||
mapped_region_.region = base::ReadOnlySharedMemoryRegion();
|
||||
}
|
||||
|
||||
void* ClientSidePhishingModel::GetFlatBufferMemoryAddressForTesting() {
|
||||
return mapped_region_.mapping.memory();
|
||||
}
|
||||
|
||||
void ClientSidePhishingModel::MaybeOverrideModel() {
|
||||
if (base::CommandLine::ForCurrentProcess()->HasSwitch(
|
||||
kOverrideCsdModelFlag)) {
|
||||
@ -206,18 +239,26 @@ void ClientSidePhishingModel::OnGetOverridenModelData(
|
||||
return;
|
||||
}
|
||||
model_type_ = model_type;
|
||||
model_str_ = model_data;
|
||||
break;
|
||||
}
|
||||
case CSDModelType::kFlatbuffer: {
|
||||
flatbuffers::Verifier verifier(
|
||||
const_cast<uint8_t*>(
|
||||
reinterpret_cast<const uint8_t*>(model_data.data())),
|
||||
reinterpret_cast<const uint8_t*>(model_data.data()),
|
||||
model_data.length());
|
||||
if (!flat::VerifyClientSideModelBuffer(verifier)) {
|
||||
VLOG(2)
|
||||
<< "Overriden model data is not a valid ClientSideModel flatbuffer";
|
||||
return;
|
||||
}
|
||||
mapped_region_ =
|
||||
base::ReadOnlySharedMemoryRegion::Create(model_data.length());
|
||||
if (!mapped_region_.IsValid()) {
|
||||
VLOG(2) << "Could not create shared memory region for flatbuffer";
|
||||
return;
|
||||
}
|
||||
memcpy(mapped_region_.mapping.memory(), model_data.data(),
|
||||
model_data.length());
|
||||
model_type_ = model_type;
|
||||
break;
|
||||
}
|
||||
@ -227,7 +268,6 @@ void ClientSidePhishingModel::OnGetOverridenModelData(
|
||||
}
|
||||
|
||||
VLOG(2) << "Model overriden successfully";
|
||||
model_str_ = model_data;
|
||||
|
||||
// Unretained is safe because this is a singleton.
|
||||
base::PostTask(FROM_HERE, {content::BrowserThread::UI},
|
||||
|
@ -10,8 +10,8 @@
|
||||
|
||||
#include "base/callback_list.h"
|
||||
#include "base/files/file.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/synchronization/lock.h"
|
||||
#include "base/version.h"
|
||||
#include "components/safe_browsing/content/browser/client_side_model_loader.h"
|
||||
|
||||
namespace safe_browsing {
|
||||
@ -47,6 +47,9 @@ class ClientSidePhishingModel {
|
||||
// Returns the model string, as a serialized protobuf or flatbuffer.
|
||||
std::string GetModelStr() const;
|
||||
|
||||
// Returns the shared memory region for the flatbuffer.
|
||||
base::ReadOnlySharedMemoryRegion GetModelSharedMemoryRegion() const;
|
||||
|
||||
// Updates the internal model string, when one is received from a component
|
||||
// update.
|
||||
void PopulateFromDynamicUpdate(const std::string& model_str,
|
||||
@ -57,6 +60,12 @@ class ClientSidePhishingModel {
|
||||
// Overrides the model string for use in tests.
|
||||
void SetModelStrForTesting(const std::string& model_str);
|
||||
void SetVisualTfLiteModelForTesting(base::File file);
|
||||
// Overrides model type.
|
||||
void SetModelTypeForTesting(CSDModelType model_type);
|
||||
// Removes mapping.
|
||||
void ClearMappedRegionForTesting();
|
||||
// Get flatbuffer memory address.
|
||||
void* GetFlatBufferMemoryAddressForTesting();
|
||||
|
||||
// Called to check the command line and maybe override the current model.
|
||||
void MaybeOverrideModel();
|
||||
@ -76,14 +85,18 @@ class ClientSidePhishingModel {
|
||||
// lock_. Will always be notified on the UI thread.
|
||||
base::RepeatingCallbackList<void()> callbacks_;
|
||||
|
||||
// Model string (protobuf or flatbuffer). Protected by lock_.
|
||||
// Model protobuf string. Protected by lock_.
|
||||
std::string model_str_;
|
||||
|
||||
// Visual TFLite model file. Protected by lock_.
|
||||
base::File visual_tflite_model_;
|
||||
|
||||
// Model type as inferred by feature flag. Protected by lock_.
|
||||
CSDModelType model_type_;
|
||||
CSDModelType model_type_ = CSDModelType::kNone;
|
||||
|
||||
// MappedReadOnlyRegion where the flatbuffer has been copied to. Protected by
|
||||
// lock_.
|
||||
base::MappedReadOnlyRegion mapped_region_ = base::MappedReadOnlyRegion();
|
||||
|
||||
mutable base::Lock lock_;
|
||||
|
||||
|
@ -11,6 +11,8 @@
|
||||
#include "base/files/file.h"
|
||||
#include "base/files/scoped_temp_dir.h"
|
||||
#include "base/logging.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/shared_memory_mapping.h"
|
||||
#include "base/run_loop.h"
|
||||
#include "base/test/scoped_command_line.h"
|
||||
#include "base/test/scoped_feature_list.h"
|
||||
@ -29,6 +31,27 @@ void ResetClientSidePhishingModel() {
|
||||
ClientSidePhishingModel::GetInstance()->SetModelStrForTesting("");
|
||||
ClientSidePhishingModel::GetInstance()->SetVisualTfLiteModelForTesting(
|
||||
base::File());
|
||||
ClientSidePhishingModel::GetInstance()->ClearMappedRegionForTesting();
|
||||
ClientSidePhishingModel::GetInstance()->SetModelTypeForTesting(
|
||||
CSDModelType::kNone);
|
||||
}
|
||||
|
||||
std::string CreateFlatBufferString() {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
flat::ClientSideModelBuilder csd_model_builder(builder);
|
||||
builder.Finish(csd_model_builder.Finish());
|
||||
return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
}
|
||||
|
||||
void GetFlatBufferStringFromMappedMemory(
    base::ReadOnlySharedMemoryRegion region,
    std::string* output) {
  // Maps |region| read-only and copies its full contents into |*output|.
  // Uses ASSERT_*, so callers wrap it in ASSERT_NO_FATAL_FAILURE().
  ASSERT_TRUE(region.IsValid());
  const base::ReadOnlySharedMemoryMapping read_mapping = region.Map();
  ASSERT_TRUE(read_mapping.IsValid());
  const char* const begin = static_cast<const char*>(read_mapping.memory());
  output->assign(begin, begin + read_mapping.size());
}
|
||||
|
||||
} // namespace
|
||||
@ -235,12 +258,7 @@ TEST(ClientSidePhishingModelTest, CanOverrideFlatBufferWithFlag) {
|
||||
base::File::FLAG_READ |
|
||||
base::File::FLAG_WRITE);
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
flat::ClientSideModelBuilder csd_model_builder(builder);
|
||||
csd_model_builder.add_version(123);
|
||||
builder.Finish(csd_model_builder.Finish());
|
||||
const std::string file_contents(
|
||||
reinterpret_cast<char*>(builder.GetBufferPointer()), builder.GetSize());
|
||||
const std::string file_contents = CreateFlatBufferString();
|
||||
file.WriteAtCurrentPos(file_contents.data(), file_contents.size());
|
||||
|
||||
base::test::ScopedCommandLine command_line;
|
||||
@ -263,8 +281,11 @@ TEST(ClientSidePhishingModelTest, CanOverrideFlatBufferWithFlag) {
|
||||
|
||||
run_loop.Run();
|
||||
|
||||
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelStr(),
|
||||
file_contents);
|
||||
std::string model_str_from_shared_mem;
|
||||
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
|
||||
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
|
||||
&model_str_from_shared_mem));
|
||||
EXPECT_EQ(model_str_from_shared_mem, file_contents);
|
||||
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
|
||||
CSDModelType::kFlatbuffer);
|
||||
EXPECT_TRUE(called);
|
||||
@ -288,18 +309,82 @@ TEST(ClientSidePhishingModelTest, AcceptsValidFlatbufferIfFeatureEnabled) {
|
||||
},
|
||||
run_loop.QuitClosure(), &called));
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
flat::ClientSideModelBuilder csd_model_builder(builder);
|
||||
builder.Finish(csd_model_builder.Finish());
|
||||
const std::string model_str(
|
||||
reinterpret_cast<char*>(builder.GetBufferPointer()), builder.GetSize());
|
||||
const std::string model_str = CreateFlatBufferString();
|
||||
ClientSidePhishingModel::GetInstance()->PopulateFromDynamicUpdate(
|
||||
model_str, base::File());
|
||||
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
|
||||
EXPECT_EQ(model_str, ClientSidePhishingModel::GetInstance()->GetModelStr());
|
||||
|
||||
run_loop.Run();
|
||||
|
||||
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
|
||||
std::string model_str_from_shared_mem;
|
||||
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
|
||||
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
|
||||
&model_str_from_shared_mem));
|
||||
EXPECT_EQ(model_str, model_str_from_shared_mem);
|
||||
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
|
||||
CSDModelType::kFlatbuffer);
|
||||
EXPECT_TRUE(called);
|
||||
}
|
||||
|
||||
TEST(ClientSidePhishingModelTest, FlatbufferonFollowingUpdate) {
  // Two consecutive flatbuffer updates: the second must replace the shared
  // memory region, notify listeners, and unmap the first region's memory.
  ResetClientSidePhishingModel();
  base::test::ScopedFeatureList feature_list;
  feature_list.InitWithFeatures(
      /*enabled_features=*/{kClientSideDetectionModelIsFlatBuffer},
      /*disabled_features=*/{});
  content::BrowserTaskEnvironment task_environment;
  base::RunLoop run_loop;
  ClientSidePhishingModel* const phishing_model =
      ClientSidePhishingModel::GetInstance();

  // First update: the flatbuffer is copied into shared memory.
  const std::string first_model = CreateFlatBufferString();
  phishing_model->PopulateFromDynamicUpdate(first_model, base::File());
  run_loop.RunUntilIdle();
  EXPECT_TRUE(phishing_model->IsEnabled());
  std::string first_from_shared_mem;
  ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
      phishing_model->GetModelSharedMemoryRegion(), &first_from_shared_mem));
  EXPECT_EQ(first_model, first_from_shared_mem);
  EXPECT_EQ(phishing_model->GetModelType(), CSDModelType::kFlatbuffer);

  // Should be able to write to memory with WritableSharedMemoryMapping field.
  // The byte written here would be visible below if the second update did not
  // replace the region.
  void* const first_mapping_addr =
      phishing_model->GetFlatBufferMemoryAddressForTesting();
  EXPECT_EQ(memset(first_mapping_addr, 'G', 1), first_mapping_addr);

  bool called = false;
  base::CallbackListSubscription subscription =
      phishing_model->RegisterCallback(base::BindRepeating(
          [](base::RepeatingClosure quit_closure, bool* called) {
            *called = true;
            std::move(quit_closure).Run();
          },
          run_loop.QuitClosure(), &called));

  // Second update: a fresh region replaces the old one and listeners fire.
  const std::string second_model = CreateFlatBufferString();
  phishing_model->PopulateFromDynamicUpdate(second_model, base::File());
  run_loop.RunUntilIdle();
  EXPECT_TRUE(called);
  EXPECT_TRUE(phishing_model->IsEnabled());
  std::string second_from_shared_mem;
  ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
      phishing_model->GetModelSharedMemoryRegion(), &second_from_shared_mem));
  EXPECT_EQ(second_model, second_from_shared_mem);
  EXPECT_EQ(phishing_model->GetModelType(), CSDModelType::kFlatbuffer);

  // Mapping should be undone automatically, even with a region copy lying
  // around. Death tests misbehave on Android, or the memory may be re-mapped.
  // See https://crbug.com/815537 and base/test/gtest_util.h.
  // Can remove this if flaky.
#if defined(GTEST_HAS_DEATH_TEST) && !defined(OS_ANDROID)
  EXPECT_DEATH_IF_SUPPORTED(memset(first_mapping_addr, 'G', 1), "");
#endif
}
|
||||
|
||||
} // namespace safe_browsing
|
||||
|
@ -4,8 +4,10 @@
|
||||
|
||||
module safe_browsing.mojom;
|
||||
|
||||
|
||||
import "components/safe_browsing/core/common/safe_browsing_url_checker.mojom";
|
||||
import "mojo/public/mojom/base/read_only_file.mojom";
|
||||
import "mojo/public/mojom/base/shared_memory.mojom";
|
||||
import "services/network/public/mojom/http_request_headers.mojom";
|
||||
import "services/network/public/mojom/fetch_api.mojom";
|
||||
import "url/mojom/url.mojom";
|
||||
@ -128,6 +130,16 @@ interface PhishingDetector {
|
||||
// to classify the appearance of pages.
|
||||
SetPhishingModel(string model, mojo_base.mojom.ReadOnlyFile? tflite_model);
|
||||
|
||||
// A classification model for client-side phishing detection. This call sends
|
||||
// the model from the browser process to the renderer process. The model is
|
||||
// sent as a safe_browsing::ClientSideModel flatbuffer string in a
|
||||
// ReadOnlySharedMemoryRegion to client-side phishing detection on the
|
||||
// renderer process. An invalid region is used to disable classification. The
|
||||
// |tflite_model| is a file handle with contents a TfLite model, which is used
|
||||
// to classify the appearance of pages.
|
||||
SetPhishingFlatBufferModel(mojo_base.mojom.ReadOnlySharedMemoryRegion region,
|
||||
mojo_base.mojom.ReadOnlyFile? tflite_model);
|
||||
|
||||
// Tells the renderer to begin phishing detection for the given toplevel URL
|
||||
// which it has started loading. Returns the serialized request proto and a
|
||||
// |result| enum to indicate failure. If the URL is phishing the request proto
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
#include "base/bind.h"
|
||||
#include "base/memory/ptr_util.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/scoped_refptr.h"
|
||||
#include "base/run_loop.h"
|
||||
#include "base/single_thread_task_runner.h"
|
||||
@ -120,6 +121,9 @@ class TestPhishingDetector : public mojom::PhishingDetector {
|
||||
|
||||
void SetPhishingModel(const std::string& model, base::File file) override {}
|
||||
|
||||
void SetPhishingFlatBufferModel(base::ReadOnlySharedMemoryRegion region,
|
||||
base::File file) override {}
|
||||
|
||||
void StartPhishingDetection(
|
||||
const GURL& url,
|
||||
StartPhishingDetectionCallback callback) override {
|
||||
|
@ -37,6 +37,7 @@ source_set("phishing_classifier") {
|
||||
"//components/safe_browsing/core:client_model_proto",
|
||||
"//components/safe_browsing/core:csd_proto",
|
||||
"//components/safe_browsing/core/common",
|
||||
"//components/safe_browsing/core/fbs:client_model",
|
||||
"//content/public/renderer",
|
||||
"//crypto",
|
||||
"//skia",
|
||||
@ -70,6 +71,7 @@ source_set("unit_tests") {
|
||||
"//components/safe_browsing/content/renderer/phishing_classifier:unit_tests_support",
|
||||
"//components/safe_browsing/core:client_model_proto",
|
||||
"//components/safe_browsing/core:csd_proto",
|
||||
"//components/safe_browsing/core/fbs:client_model",
|
||||
"//crypto",
|
||||
"//skia",
|
||||
"//testing/gmock",
|
||||
@ -97,6 +99,7 @@ if (use_libfuzzer) {
|
||||
":client_side_phishing_fuzzer_proto",
|
||||
":phishing_classifier",
|
||||
"//base:base",
|
||||
"//components/safe_browsing/core/fbs:client_model",
|
||||
"//skia",
|
||||
"//third_party/libprotobuf-mutator",
|
||||
]
|
||||
|
@ -96,6 +96,23 @@ void PhishingClassifierDelegate::SetPhishingModel(
|
||||
g_phishing_scorer.Get().reset(scorer);
|
||||
}
|
||||
|
||||
void PhishingClassifierDelegate::SetPhishingFlatBufferModel(
    base::ReadOnlySharedMemoryRegion flatbuffer_region,
    base::File tflite_visual_model) {
  // If BOTH the flatbuffer region and the visual model file are invalid,
  // client-side phishing detection is disabled by installing a null scorer.
  // Otherwise a scorer is built from them; if that fails, the currently
  // installed scorer is left untouched.
  const bool disable_detection =
      !flatbuffer_region.IsValid() && !tflite_visual_model.IsValid();
  safe_browsing::Scorer* new_scorer = nullptr;
  if (!disable_detection) {
    new_scorer = safe_browsing::Scorer::Create(std::move(flatbuffer_region),
                                               std::move(tflite_visual_model));
    if (new_scorer == nullptr)
      return;
  }

  // Hand the (possibly null) scorer to every live delegate, then take
  // ownership of it globally.
  for (auto* delegate : PhishingClassifierDelegates())
    delegate->SetPhishingScorer(new_scorer);
  g_phishing_scorer.Get().reset(new_scorer);
}
|
||||
|
||||
// static
|
||||
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
|
||||
content::RenderFrame* render_frame,
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <string>
|
||||
|
||||
#include "base/macros.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "components/safe_browsing/content/common/safe_browsing.mojom.h"
|
||||
#include "content/public/renderer/render_frame_observer.h"
|
||||
#include "content/public/renderer/render_thread_observer.h"
|
||||
@ -52,6 +53,11 @@ class PhishingClassifierDelegate : public content::RenderFrameObserver,
|
||||
void SetPhishingModel(const std::string& model,
|
||||
base::File tflite_visual_model) override;
|
||||
|
||||
// mojom::PhishingDetector
|
||||
void SetPhishingFlatBufferModel(
|
||||
base::ReadOnlySharedMemoryRegion flatbuffer_region,
|
||||
base::File tflite_visual_model) override;
|
||||
|
||||
// Called by the RenderFrame once there is a phishing scorer available.
|
||||
// The scorer is passed on to the classifier.
|
||||
void SetPhishingScorer(const safe_browsing::Scorer* scorer);
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/shared_memory_mapping.h"
|
||||
#include "base/metrics/histogram_macros.h"
|
||||
#include "base/strings/string_number_conversions.h"
|
||||
#include "base/strings/string_piece.h"
|
||||
@ -41,6 +43,9 @@ enum ScorerCreationStatus {
|
||||
SCORER_FAIL_MODEL_PARSE_ERROR,
|
||||
SCORER_FAIL_MODEL_MISSING_FIELDS,
|
||||
SCORER_FAIL_MAP_VISUAL_TFLITE_MODEL,
|
||||
SCORER_FAIL_FLATBUFFER_INVALID_REGION,
|
||||
SCORER_FAIL_FLATBUFFER_INVALID_MAPPING,
|
||||
SCORER_FAIL_FLATBUFFER_FAILED_VERIFY,
|
||||
SCORER_STATUS_MAX // Always add new values before this one.
|
||||
};
|
||||
|
||||
@ -255,6 +260,42 @@ Scorer* Scorer::Create(const base::StringPiece& model_str,
|
||||
return scorer.release();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Scorer* Scorer::Create(base::ReadOnlySharedMemoryRegion region,
|
||||
base::File visual_tflite_model) {
|
||||
std::unique_ptr<Scorer> scorer(new Scorer());
|
||||
|
||||
if (!region.IsValid()) {
|
||||
RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_INVALID_REGION);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
base::ReadOnlySharedMemoryMapping mapping = region.Map();
|
||||
if (!mapping.IsValid()) {
|
||||
RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_INVALID_MAPPING);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
flatbuffers::Verifier verifier(
|
||||
reinterpret_cast<const uint8_t*>(mapping.memory()), mapping.size());
|
||||
if (!flat::VerifyClientSideModelBuffer(verifier)) {
|
||||
RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_FAILED_VERIFY);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Only do this part if the visual model file exists
|
||||
if (visual_tflite_model.IsValid() && !scorer->visual_tflite_model_.Initialize(
|
||||
std::move(visual_tflite_model))) {
|
||||
RecordScorerCreationStatus(SCORER_FAIL_MAP_VISUAL_TFLITE_MODEL);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
RecordScorerCreationStatus(SCORER_SUCCESS);
|
||||
scorer->flatbuffer_model_ = flat::GetClientSideModel(mapping.memory());
|
||||
scorer->flatbuffer_mapping_ = std::move(mapping);
|
||||
return scorer.release();
|
||||
}
|
||||
|
||||
double Scorer::ComputeScore(const FeatureMap& features) const {
|
||||
double logodds = 0.0;
|
||||
for (int i = 0; i < model_.rule_size(); ++i) {
|
||||
|
@ -24,8 +24,10 @@
|
||||
#include "base/files/file.h"
|
||||
#include "base/files/memory_mapped_file.h"
|
||||
#include "base/macros.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/strings/string_piece.h"
|
||||
#include "components/safe_browsing/core/fbs/client_model_generated.h"
|
||||
#include "components/safe_browsing/core/proto/client_model.pb.h"
|
||||
#include "components/safe_browsing/core/proto/csd.pb.h"
|
||||
#include "third_party/skia/include/core/SkBitmap.h"
|
||||
@ -39,10 +41,17 @@ class Scorer {
|
||||
virtual ~Scorer();
|
||||
|
||||
// Factory method which creates a new Scorer object by parsing the given
// model. If parsing fails this method returns NULL.
// Can use this if model_str is empty.
|
||||
static Scorer* Create(const base::StringPiece& model_str,
|
||||
base::File visual_tflite_model);
|
||||
|
||||
// Factory method which creates a new Scorer object by parsing the given
|
||||
// flatbuffer or tflite model. If parsing fails this method returns NULL.
|
||||
// Use this only if region is valid.
|
||||
static Scorer* Create(base::ReadOnlySharedMemoryRegion region,
|
||||
base::File visual_tflite_model);
|
||||
|
||||
// This method computes the probability that the given features are indicative
|
||||
// of phishing. It returns a score value that falls in the range [0.0,1.0]
|
||||
// (range is inclusive on both ends).
|
||||
@ -122,6 +131,12 @@ class Scorer {
|
||||
std::unordered_set<std::string> page_terms_;
|
||||
std::unordered_set<uint32_t> page_words_;
|
||||
|
||||
// Unowned. Points within flatbuffer_mapping_ and should not be free()d.
|
||||
// It remains valid till flatbuffer_mapping_ is valid and should be reassigned
|
||||
// if the mapping is updated.
|
||||
const flat::ClientSideModel* flatbuffer_model_;
|
||||
base::ReadOnlySharedMemoryMapping flatbuffer_mapping_;
|
||||
|
||||
base::MemoryMappedFile visual_tflite_model_;
|
||||
|
||||
base::WeakPtrFactory<Scorer> weak_ptr_factory_{this};
|
||||
|
@ -12,12 +12,14 @@
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/files/scoped_temp_dir.h"
|
||||
#include "base/format_macros.h"
|
||||
#include "base/memory/read_only_shared_memory_region.h"
|
||||
#include "base/run_loop.h"
|
||||
#include "base/test/bind.h"
|
||||
#include "base/test/task_environment.h"
|
||||
#include "base/test/test_discardable_memory_allocator.h"
|
||||
#include "base/threading/thread.h"
|
||||
#include "components/safe_browsing/content/renderer/phishing_classifier/features.h"
|
||||
#include "components/safe_browsing/core/fbs/client_model_generated.h"
|
||||
#include "components/safe_browsing/core/proto/client_model.pb.h"
|
||||
#include "components/safe_browsing/core/proto/csd.pb.h"
|
||||
#include "testing/gmock/include/gmock/gmock.h"
|
||||
@ -25,6 +27,26 @@
|
||||
|
||||
namespace safe_browsing {
|
||||
|
||||
namespace {
|
||||
|
||||
std::string GetFlatBufferString() {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
flat::ClientSideModelBuilder csd_model_builder(builder);
|
||||
builder.Finish(csd_model_builder.Finish());
|
||||
return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
}
|
||||
|
||||
base::MappedReadOnlyRegion GetMappedReadOnlyRegionWithData(std::string data) {
|
||||
base::MappedReadOnlyRegion mapped_region =
|
||||
base::ReadOnlySharedMemoryRegion::Create(data.length());
|
||||
EXPECT_TRUE(mapped_region.IsValid());
|
||||
memcpy(mapped_region.mapping.memory(), data.data(), data.length());
|
||||
return mapped_region;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class PhishingScorerTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
@ -108,6 +130,25 @@ class PhishingScorerTest : public ::testing::Test {
|
||||
base::TestDiscardableMemoryAllocator test_allocator_;
|
||||
};
|
||||
|
||||
// Scorer::Create() must accept a well-formed flatbuffer model and reject both
// an invalid region and a region whose bytes are not a ClientSideModel.
TEST_F(PhishingScorerTest, HasValidFlatBufferModel) {
  // Well-formed (empty) model.
  base::MappedReadOnlyRegion mapped_region =
      GetMappedReadOnlyRegionWithData(GetFlatBufferString());
  std::unique_ptr<Scorer> scorer(
      Scorer::Create(mapped_region.region.Duplicate(), base::File()));
  EXPECT_NE(nullptr, scorer.get());

  // Invalid region.
  scorer.reset(
      Scorer::Create(base::ReadOnlySharedMemoryRegion(), base::File()));
  EXPECT_EQ(nullptr, scorer.get());

  // Invalid buffer in region.
  mapped_region = GetMappedReadOnlyRegionWithData("bogus string");
  scorer.reset(Scorer::Create(mapped_region.region.Duplicate(), base::File()));
  EXPECT_EQ(nullptr, scorer.get());
}
|
||||
|
||||
TEST_F(PhishingScorerTest, HasValidModel) {
|
||||
std::unique_ptr<Scorer> scorer;
|
||||
scorer.reset(Scorer::Create(model_.SerializeAsString(), base::File()));
|
||||
|
Reference in New Issue
Block a user