0

Passed FlatBuffer from Browser to Renderer process.

Stored Client-Side Phishing Detection Model (as a FlatBuffer, go/flatbuffers) in ReadOnlySharedMemory in client_side_phishing_model.

Duplicated the shared memory region and passed it to components/safe_browsing/content/renderer/phishing_classifier/scorer.cc in the renderer process via
client_side_detection_host.

Updated unit tests and browser tests.

Design Doc: http://go/memory-regression-csd-android

Bug: 1210696
Change-Id: I8168f521bd8bc0c68ba873a52b0fdbc488aa8678
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2912076
Commit-Queue: Rohit Bhatia <bhatiarohit@google.com>
Reviewed-by: danakj <danakj@chromium.org>
Reviewed-by: Daniel Rubery <drubery@chromium.org>
Cr-Commit-Position: refs/heads/master@{#886554}
This commit is contained in:
Rohit Bhatia
2021-05-26 01:33:03 +00:00
committed by Chromium LUCI CQ
parent d4da940008
commit 26d7ea11b2
17 changed files with 404 additions and 42 deletions

@ -12,6 +12,7 @@
#include "base/callback_helpers.h"
#include "base/files/file_path.h"
#include "base/macros.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/ref_counted.h"
#include "base/run_loop.h"
#include "base/synchronization/waitable_event.h"
@ -26,6 +27,7 @@
#include "components/prefs/scoped_user_pref_update.h"
#include "components/safe_browsing/content/browser/client_side_detection_service.h"
#include "components/safe_browsing/content/browser/client_side_model_loader.h"
#include "components/safe_browsing/content/browser/client_side_phishing_model.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom-shared.h"
#include "components/safe_browsing/core/browser/sync/sync_utils.h"
#include "components/safe_browsing/core/common/safe_browsing_prefs.h"
@ -133,6 +135,8 @@ class MockClientSideDetectionService : public ClientSideDetectionService {
MOCK_METHOD1(IsInCache, bool(const GURL&));
MOCK_METHOD0(OverPhishingReportLimit, bool());
MOCK_METHOD0(GetModelStr, std::string());
MOCK_METHOD0(GetModelSharedMemoryRegion, base::ReadOnlySharedMemoryRegion());
MOCK_METHOD0(GetModelType, CSDModelType());
private:
DISALLOW_COPY_AND_ASSIGN(MockClientSideDetectionService);
@ -195,6 +199,12 @@ class FakePhishingDetector : public mojom::PhishingDetector {
model_ = model;
}
// mojom::PhishingDetector
// Test double: keeps the flatbuffer region it receives in region_ so that
// CheckModel() can compare against it. The |file| argument is not used.
void SetPhishingFlatBufferModel(base::ReadOnlySharedMemoryRegion region,
base::File file) override {
region_ = std::move(region);
}
// mojom::PhishingDetector
void StartPhishingDetection(
const GURL& url,
@ -223,10 +233,15 @@ class FakePhishingDetector : public mojom::PhishingDetector {
void CheckModel(const std::string& model) { EXPECT_EQ(model, model_); }
// Expects |region| to have the same GUID as the region currently held in
// region_. Only the GUIDs are compared; the region contents are not read.
void CheckModel(base::ReadOnlySharedMemoryRegion region) {
EXPECT_EQ(region.GetGUID(), region_.GetGUID());
}
void Reset() {
  // Return every piece of captured state to its initial value so the fake can
  // be reused by the next expectation.
  region_ = base::ReadOnlySharedMemoryRegion();
  model_.clear();
  url_ = GURL();
  phishing_detection_started_ = false;
}
private:
@ -234,6 +249,7 @@ class FakePhishingDetector : public mojom::PhishingDetector {
bool phishing_detection_started_ = false;
GURL url_;
std::string model_ = "";
base::ReadOnlySharedMemoryRegion region_ = base::ReadOnlySharedMemoryRegion();
DISALLOW_COPY_AND_ASSIGN(FakePhishingDetector);
};
@ -274,8 +290,7 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
InitTestApi();
// Inject service classes.
csd_service_ =
std::make_unique<StrictMock<MockClientSideDetectionService>>();
csd_service_ = std::make_unique<MockClientSideDetectionService>();
database_manager_ = new StrictMock<MockSafeBrowsingDatabaseManager>();
ui_manager_ = new StrictMock<MockSafeBrowsingUIManager>(
// TODO(crbug/925153): Port consumers of the SafeBrowsingService to
@ -298,6 +313,8 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
std::make_unique<StrictMock<MockSafeBrowsingTokenFetcher>>();
raw_token_fetcher_ = token_fetcher.get();
csd_host_->set_token_fetcher_for_testing(std::move(token_fetcher));
testing::DefaultValue<CSDModelType>::Set(CSDModelType::kProtobuf);
}
void TearDown() override {
@ -375,9 +392,9 @@ class ClientSideDetectionHostTestBase : public ChromeRenderViewHostTestHarness {
protected:
std::unique_ptr<ClientSideDetectionHost> csd_host_;
std::unique_ptr<StrictMock<MockClientSideDetectionService>> csd_service_;
std::unique_ptr<MockClientSideDetectionService> csd_service_;
scoped_refptr<StrictMock<MockSafeBrowsingUIManager> > ui_manager_;
scoped_refptr<StrictMock<MockSafeBrowsingDatabaseManager> > database_manager_;
scoped_refptr<StrictMock<MockSafeBrowsingDatabaseManager>> database_manager_;
FakePhishingDetector fake_phishing_detector_;
StrictMock<MockSafeBrowsingTokenFetcher>* raw_token_fetcher_ = nullptr;
base::SimpleTestTickClock clock_;
@ -1138,6 +1155,20 @@ TEST_F(ClientSideDetectionHostTest, RecordsPhishingDetectionDuration) {
.min);
}
TEST_F(ClientSideDetectionHostTest, TestSendFlatBufferModelToRenderFrame) {
  // A small region stands in for a real model; the check below only compares
  // region GUIDs, not contents.
  base::MappedReadOnlyRegion model_memory =
      base::ReadOnlySharedMemoryRegion::Create(10);

  // Have the mocked service report a flatbuffer model backed by that region.
  EXPECT_CALL(*csd_service_, GetModelSharedMemoryRegion())
      .WillRepeatedly(
          Return(testing::ByMove(model_memory.region.Duplicate())));
  EXPECT_CALL(*csd_service_, GetModelType())
      .WillRepeatedly(Return(CSDModelType::kFlatbuffer));

  csd_host_->SendModelToRenderFrame();
  base::RunLoop().RunUntilIdle();

  // The fake detector should now hold a region with the same GUID.
  fake_phishing_detector_.CheckModel(model_memory.region.Duplicate());
  fake_phishing_detector_.Reset();
}
TEST_F(ClientSideDetectionHostTest, TestSendModelToRenderFrame) {
EXPECT_CALL(*csd_service_, GetModelStr()).WillRepeatedly(Return("standard"));
csd_host_->SendModelToRenderFrame();

@ -249,6 +249,26 @@ TEST_F(PhishingClassifierDelegateTest, HasPhishingModel) {
EXPECT_CALL(*classifier_, CancelPendingClassification());
}
// Checks that handing the delegate a shared memory region containing a valid
// ClientSideModel flatbuffer makes the classifier ready.
TEST_F(PhishingClassifierDelegateTest, HasFlatBufferModel) {
  ASSERT_FALSE(classifier_->is_ready());

  // Serialize a ClientSideModel flatbuffer with no fields set.
  flatbuffers::FlatBufferBuilder builder(1024);
  flat::ClientSideModelBuilder csd_model_builder(builder);
  builder.Finish(csd_model_builder.Finish());

  // Copy the serialized bytes straight into shared memory. Creating the region
  // can fail, in which case the mapping has no memory to write to, so stop
  // here instead of calling memcpy on it.
  base::MappedReadOnlyRegion mapped_region =
      base::ReadOnlySharedMemoryRegion::Create(builder.GetSize());
  ASSERT_TRUE(mapped_region.IsValid());
  memcpy(mapped_region.mapping.memory(), builder.GetBufferPointer(),
         builder.GetSize());

  delegate_->SetPhishingFlatBufferModel(mapped_region.region.Duplicate(),
                                        base::File());
  ASSERT_TRUE(classifier_->is_ready());

  // The delegate will cancel pending classification on destruction.
  EXPECT_CALL(*classifier_, CancelPendingClassification());
}
TEST_F(PhishingClassifierDelegateTest, HasVisualTfLiteModel) {
ASSERT_FALSE(classifier_->is_ready());

@ -369,6 +369,22 @@ void ClientSideDetectionHost::DidFinishNavigation(
classification_request_->Start();
}
void ClientSideDetectionHost::SetPhishingModel() {
switch (csd_service_->GetModelType()) {
case CSDModelType::kNone:
case CSDModelType::kProtobuf:
phishing_detector_->SetPhishingModel(
csd_service_->GetModelStr(),
csd_service_->GetVisualTfLiteModel().Duplicate());
return;
case CSDModelType::kFlatbuffer:
phishing_detector_->SetPhishingFlatBufferModel(
csd_service_->GetModelSharedMemoryRegion(),
csd_service_->GetVisualTfLiteModel().Duplicate());
return;
}
}
void ClientSideDetectionHost::SendModelToRenderFrame() {
DCHECK_CURRENTLY_ON(BrowserThread::UI);
if (!web_contents() || web_contents() != tab_ || !csd_service_)
@ -381,9 +397,7 @@ void ClientSideDetectionHost::SendModelToRenderFrame() {
phishing_detector_.reset();
frame->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(
csd_service_->GetModelStr(),
csd_service_->GetVisualTfLiteModel().Duplicate());
SetPhishingModel();
}
}
@ -402,9 +416,7 @@ void ClientSideDetectionHost::RenderFrameCreated(
phishing_detector_.reset();
render_frame_host->GetRemoteInterfaces()->GetInterface(
phishing_detector_.BindNewPipeAndPassReceiver());
phishing_detector_->SetPhishingModel(
csd_service_->GetModelStr(),
csd_service_->GetVisualTfLiteModel().Duplicate());
SetPhishingModel();
}
void ClientSideDetectionHost::OnPhishingPreClassificationDone(

@ -160,6 +160,9 @@ class ClientSideDetectionHost : public content::WebContentsObserver {
// who are signed in and not in incognito mode.
bool CanGetAccessToken();
// Set phishing model in PhishingDetector in renderers.
void SetPhishingModel();
// Send the client report to CSD server.
void SendRequest(std::unique_ptr<ClientPhishingRequest> verdict,
const std::string& access_token);

@ -435,6 +435,15 @@ std::string ClientSideDetectionService::GetModelStr() {
return ClientSidePhishingModel::GetInstance()->GetModelStr();
}
CSDModelType ClientSideDetectionService::GetModelType() {
  // Forwards to the ClientSidePhishingModel singleton.
  auto* const phishing_model = ClientSidePhishingModel::GetInstance();
  return phishing_model->GetModelType();
}
base::ReadOnlySharedMemoryRegion
ClientSideDetectionService::GetModelSharedMemoryRegion() {
  // Forwards to the ClientSidePhishingModel singleton.
  auto* const phishing_model = ClientSidePhishingModel::GetInstance();
  return phishing_model->GetModelSharedMemoryRegion();
}
const base::File& ClientSideDetectionService::GetVisualTfLiteModel() {
return ClientSidePhishingModel::GetInstance()->GetVisualTfLiteModel();
}

@ -24,11 +24,13 @@
#include "base/containers/queue.h"
#include "base/gtest_prod_util.h"
#include "base/macros.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/prefs/pref_change_registrar.h"
#include "components/safe_browsing/content/browser/client_side_phishing_model.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/browser/notification_observer.h"
@ -123,10 +125,18 @@ class ClientSideDetectionService : public KeyedService {
// Sends a model to each renderer.
virtual void SendModelToRenderers();
// Returns the model string. Virtual so that mock implementation can override
// it.
// Returns the model string. Used only for protobuf model. Virtual so that
// mock implementation can override it.
virtual std::string GetModelStr();
// Returns the model type (protobuf or flatbuffer). Virtual so that mock
// implementation can override it.
virtual CSDModelType GetModelType();
// Returns the ReadOnlySharedMemoryRegion for the flatbuffer model. Virtual so
// that mock implementation can override it.
virtual base::ReadOnlySharedMemoryRegion GetModelSharedMemoryRegion();
// Returns the TfLite model file. Virtual so that mock implementation can
// override it.
virtual const base::File& GetVisualTfLiteModel();

@ -10,6 +10,8 @@
#include "base/command_line.h"
#include "base/feature_list.h"
#include "base/logging.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/shared_memory_mapping.h"
#include "base/memory/singleton.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/post_task.h"
@ -80,10 +82,14 @@ base::CallbackListSubscription ClientSidePhishingModel::RegisterCallback(
}
bool ClientSidePhishingModel::IsEnabled() const {
return !model_str_.empty() || visual_tflite_model_.IsValid();
return (model_type_ == CSDModelType::kFlatbuffer &&
mapped_region_.IsValid()) ||
(model_type_ == CSDModelType::kProtobuf && !model_str_.empty()) ||
visual_tflite_model_.IsValid();
}
std::string ClientSidePhishingModel::GetModelStr() const {
DCHECK(model_type_ != CSDModelType::kFlatbuffer);
return model_str_;
}
@ -95,6 +101,11 @@ CSDModelType ClientSidePhishingModel::GetModelType() const {
return model_type_;
}
// Returns a duplicate of the read-only region stored in mapped_region_; the
// stored region itself stays in place.
// NOTE(review): no validity check is made before Duplicate(); presumably
// callers are expected to handle an invalid result - confirm.
base::ReadOnlySharedMemoryRegion
ClientSidePhishingModel::GetModelSharedMemoryRegion() const {
return mapped_region_.region.Duplicate();
}
void ClientSidePhishingModel::PopulateFromDynamicUpdate(
const std::string& model_str,
base::File visual_tflite_model) {
@ -109,14 +120,21 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
!model_str.empty()) {
if (base::FeatureList::IsEnabled(kClientSideDetectionModelIsFlatBuffer)) {
flatbuffers::Verifier verifier(
const_cast<uint8_t*>(
reinterpret_cast<const uint8_t*>(model_str.data())),
reinterpret_cast<const uint8_t*>(model_str.data()),
model_str.length());
model_valid = flat::VerifyClientSideModelBuffer(verifier);
if (model_valid) {
model_type_ = CSDModelType::kFlatbuffer;
model_version_field =
flat::GetClientSideModel(model_str.data())->version();
mapped_region_ =
base::ReadOnlySharedMemoryRegion::Create(model_str.length());
if (mapped_region_.IsValid()) {
model_type_ = CSDModelType::kFlatbuffer;
model_version_field =
flat::GetClientSideModel(model_str.data())->version();
memcpy(mapped_region_.mapping.memory(), model_str.data(),
model_str.length());
} else {
model_valid = false;
}
}
} else {
ClientSideModel model_proto;
@ -124,6 +142,7 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
if (model_valid) {
model_type_ = CSDModelType::kProtobuf;
model_version_field = model_proto.version();
model_str_ = model_str;
}
}
@ -137,7 +156,6 @@ void ClientSidePhishingModel::PopulateFromDynamicUpdate(
base::UmaHistogramExactLinear(
"SBClientPhishing.ModelDynamicUpdateVersion", model_version_field,
kMaxVersion + 1);
model_str_ = model_str;
}
}
@ -170,6 +188,21 @@ void ClientSidePhishingModel::SetVisualTfLiteModelForTesting(base::File file) {
visual_tflite_model_ = std::move(file);
}
// Test-only: overwrites model_type_ with |model_type| while holding lock_.
// No model data is touched.
void ClientSidePhishingModel::SetModelTypeForTesting(CSDModelType model_type) {
AutoLock lock(lock_);
model_type_ = model_type;
}
void ClientSidePhishingModel::ClearMappedRegionForTesting() {
  AutoLock lock(lock_);
  // Replace both halves (region and writable mapping) with default-constructed
  // ones in a single assignment.
  mapped_region_ = base::MappedReadOnlyRegion();
}
// Test-only: returns the address of the writable mapping that backs the
// flatbuffer model.
// NOTE(review): unlike the other *ForTesting helpers nearby, this reads
// mapped_region_ without taking lock_ - confirm that is intentional.
void* ClientSidePhishingModel::GetFlatBufferMemoryAddressForTesting() {
return mapped_region_.mapping.memory();
}
void ClientSidePhishingModel::MaybeOverrideModel() {
if (base::CommandLine::ForCurrentProcess()->HasSwitch(
kOverrideCsdModelFlag)) {
@ -206,18 +239,26 @@ void ClientSidePhishingModel::OnGetOverridenModelData(
return;
}
model_type_ = model_type;
model_str_ = model_data;
break;
}
case CSDModelType::kFlatbuffer: {
flatbuffers::Verifier verifier(
const_cast<uint8_t*>(
reinterpret_cast<const uint8_t*>(model_data.data())),
reinterpret_cast<const uint8_t*>(model_data.data()),
model_data.length());
if (!flat::VerifyClientSideModelBuffer(verifier)) {
VLOG(2)
<< "Overriden model data is not a valid ClientSideModel flatbuffer";
return;
}
mapped_region_ =
base::ReadOnlySharedMemoryRegion::Create(model_data.length());
if (!mapped_region_.IsValid()) {
VLOG(2) << "Could not create shared memory region for flatbuffer";
return;
}
memcpy(mapped_region_.mapping.memory(), model_data.data(),
model_data.length());
model_type_ = model_type;
break;
}
@ -227,7 +268,6 @@ void ClientSidePhishingModel::OnGetOverridenModelData(
}
VLOG(2) << "Model overriden successfully";
model_str_ = model_data;
// Unretained is safe because this is a singleton.
base::PostTask(FROM_HERE, {content::BrowserThread::UI},

@ -10,8 +10,8 @@
#include "base/callback_list.h"
#include "base/files/file.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/synchronization/lock.h"
#include "base/version.h"
#include "components/safe_browsing/content/browser/client_side_model_loader.h"
namespace safe_browsing {
@ -47,6 +47,9 @@ class ClientSidePhishingModel {
// Returns the model string, as a serialized protobuf or flatbuffer.
std::string GetModelStr() const;
// Returns the shared memory region for the flatbuffer.
base::ReadOnlySharedMemoryRegion GetModelSharedMemoryRegion() const;
// Updates the internal model string, when one is received from a component
// update.
void PopulateFromDynamicUpdate(const std::string& model_str,
@ -57,6 +60,12 @@ class ClientSidePhishingModel {
// Overrides the model string for use in tests.
void SetModelStrForTesting(const std::string& model_str);
void SetVisualTfLiteModelForTesting(base::File file);
// Overrides model type.
void SetModelTypeForTesting(CSDModelType model_type);
// Removes mapping.
void ClearMappedRegionForTesting();
// Get flatbuffer memory address.
void* GetFlatBufferMemoryAddressForTesting();
// Called to check the command line and maybe override the current model.
void MaybeOverrideModel();
@ -76,14 +85,18 @@ class ClientSidePhishingModel {
// lock_. Will always be notified on the UI thread.
base::RepeatingCallbackList<void()> callbacks_;
// Model string (protobuf or flatbuffer). Protected by lock_.
// Model protobuf string. Protected by lock_.
std::string model_str_;
// Visual TFLite model file. Protected by lock_.
base::File visual_tflite_model_;
// Model type as inferred by feature flag. Protected by lock_.
CSDModelType model_type_;
CSDModelType model_type_ = CSDModelType::kNone;
// MappedReadOnlyRegion where the flatbuffer has been copied to. Protected by
// lock_.
base::MappedReadOnlyRegion mapped_region_ = base::MappedReadOnlyRegion();
mutable base::Lock lock_;

@ -11,6 +11,8 @@
#include "base/files/file.h"
#include "base/files/scoped_temp_dir.h"
#include "base/logging.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/shared_memory_mapping.h"
#include "base/run_loop.h"
#include "base/test/scoped_command_line.h"
#include "base/test/scoped_feature_list.h"
@ -29,6 +31,27 @@ void ResetClientSidePhishingModel() {
ClientSidePhishingModel::GetInstance()->SetModelStrForTesting("");
ClientSidePhishingModel::GetInstance()->SetVisualTfLiteModelForTesting(
base::File());
ClientSidePhishingModel::GetInstance()->ClearMappedRegionForTesting();
ClientSidePhishingModel::GetInstance()->SetModelTypeForTesting(
CSDModelType::kNone);
}
std::string CreateFlatBufferString() {
flatbuffers::FlatBufferBuilder builder(1024);
flat::ClientSideModelBuilder csd_model_builder(builder);
builder.Finish(csd_model_builder.Finish());
return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
builder.GetSize());
}
// Maps |region| and copies its bytes into |output|. Raises a fatal test
// failure (leaving |output| untouched) if the region or its mapping is invalid.
void GetFlatBufferStringFromMappedMemory(
    base::ReadOnlySharedMemoryRegion region,
    std::string* output) {
  ASSERT_TRUE(region.IsValid());

  base::ReadOnlySharedMemoryMapping read_mapping = region.Map();
  ASSERT_TRUE(read_mapping.IsValid());

  const char* bytes = reinterpret_cast<const char*>(read_mapping.memory());
  *output = std::string(bytes, read_mapping.size());
}
} // namespace
@ -235,12 +258,7 @@ TEST(ClientSidePhishingModelTest, CanOverrideFlatBufferWithFlag) {
base::File::FLAG_READ |
base::File::FLAG_WRITE);
flatbuffers::FlatBufferBuilder builder(1024);
flat::ClientSideModelBuilder csd_model_builder(builder);
csd_model_builder.add_version(123);
builder.Finish(csd_model_builder.Finish());
const std::string file_contents(
reinterpret_cast<char*>(builder.GetBufferPointer()), builder.GetSize());
const std::string file_contents = CreateFlatBufferString();
file.WriteAtCurrentPos(file_contents.data(), file_contents.size());
base::test::ScopedCommandLine command_line;
@ -263,8 +281,11 @@ TEST(ClientSidePhishingModelTest, CanOverrideFlatBufferWithFlag) {
run_loop.Run();
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelStr(),
file_contents);
std::string model_str_from_shared_mem;
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
&model_str_from_shared_mem));
EXPECT_EQ(model_str_from_shared_mem, file_contents);
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
CSDModelType::kFlatbuffer);
EXPECT_TRUE(called);
@ -288,18 +309,82 @@ TEST(ClientSidePhishingModelTest, AcceptsValidFlatbufferIfFeatureEnabled) {
},
run_loop.QuitClosure(), &called));
flatbuffers::FlatBufferBuilder builder(1024);
flat::ClientSideModelBuilder csd_model_builder(builder);
builder.Finish(csd_model_builder.Finish());
const std::string model_str(
reinterpret_cast<char*>(builder.GetBufferPointer()), builder.GetSize());
const std::string model_str = CreateFlatBufferString();
ClientSidePhishingModel::GetInstance()->PopulateFromDynamicUpdate(
model_str, base::File());
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
EXPECT_EQ(model_str, ClientSidePhishingModel::GetInstance()->GetModelStr());
run_loop.Run();
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
std::string model_str_from_shared_mem;
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
&model_str_from_shared_mem));
EXPECT_EQ(model_str, model_str_from_shared_mem);
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
CSDModelType::kFlatbuffer);
EXPECT_TRUE(called);
}
// Applies two flatbuffer dynamic updates in a row and checks that each one is
// readable through the shared memory region, that callbacks fire on the second
// update, and (where death tests are enabled) that writing through the address
// of the first mapping crashes afterwards.
TEST(ClientSidePhishingModelTest, FlatbufferonFollowingUpdate) {
ResetClientSidePhishingModel();
base::test::ScopedFeatureList feature_list;
feature_list.InitWithFeatures(
/*enabled_features=*/{kClientSideDetectionModelIsFlatBuffer},
/*disabled_features=*/{});
content::BrowserTaskEnvironment task_environment;
base::RunLoop run_loop;
// First update: the model should read back unchanged from a duplicate of the
// shared memory region.
const std::string model_str1 = CreateFlatBufferString();
ClientSidePhishingModel::GetInstance()->PopulateFromDynamicUpdate(
model_str1, base::File());
run_loop.RunUntilIdle();
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
std::string model_str_from_shared_mem1;
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
&model_str_from_shared_mem1));
EXPECT_EQ(model_str1, model_str_from_shared_mem1);
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
CSDModelType::kFlatbuffer);
// Should be able to write to memory with WritableSharedMemoryMapping field.
// The address is kept so it can be probed again after the second update.
void* memory_addr = ClientSidePhishingModel::GetInstance()
->GetFlatBufferMemoryAddressForTesting();
EXPECT_EQ(memset(memory_addr, 'G', 1), memory_addr);
// Register a callback before the second update so the test can verify that
// observers are notified.
bool called = false;
base::CallbackListSubscription subscription =
ClientSidePhishingModel::GetInstance()->RegisterCallback(
base::BindRepeating(
[](base::RepeatingClosure quit_closure, bool* called) {
*called = true;
std::move(quit_closure).Run();
},
run_loop.QuitClosure(), &called));
// Second update: same checks as the first, plus the callback expectation.
const std::string model_str2 = CreateFlatBufferString();
ClientSidePhishingModel::GetInstance()->PopulateFromDynamicUpdate(
model_str2, base::File());
run_loop.RunUntilIdle();
EXPECT_TRUE(called);
EXPECT_TRUE(ClientSidePhishingModel::GetInstance()->IsEnabled());
std::string model_str_from_shared_mem2;
ASSERT_NO_FATAL_FAILURE(GetFlatBufferStringFromMappedMemory(
ClientSidePhishingModel::GetInstance()->GetModelSharedMemoryRegion(),
&model_str_from_shared_mem2));
EXPECT_EQ(model_str2, model_str_from_shared_mem2);
EXPECT_EQ(ClientSidePhishingModel::GetInstance()->GetModelType(),
CSDModelType::kFlatbuffer);
// Mapping should be undone automatically, even with a region copy lying
// around. Death tests misbehave on Android, or the memory may be re-mapped.
// See https://crbug.com/815537 and base/test/gtest_util.h.
// Can remove this if flaky.
#if defined(GTEST_HAS_DEATH_TEST) && !defined(OS_ANDROID)
EXPECT_DEATH_IF_SUPPORTED(memset(memory_addr, 'G', 1), "");
#endif
}
} // namespace safe_browsing

@ -4,8 +4,10 @@
module safe_browsing.mojom;
import "components/safe_browsing/core/common/safe_browsing_url_checker.mojom";
import "mojo/public/mojom/base/read_only_file.mojom";
import "mojo/public/mojom/base/shared_memory.mojom";
import "services/network/public/mojom/http_request_headers.mojom";
import "services/network/public/mojom/fetch_api.mojom";
import "url/mojom/url.mojom";
@ -128,6 +130,16 @@ interface PhishingDetector {
// to classify the appearance of pages.
SetPhishingModel(string model, mojo_base.mojom.ReadOnlyFile? tflite_model);
// A classification model for client-side phishing detection. This call sends
// the model from the browser process to the renderer process. The model is
// sent as a safe_browsing::ClientSideModel flatbuffer string in a
// ReadOnlySharedMemoryRegion to client-side phishing detection on the
// renderer process. An invalid region is used to disable classification. The
// |tflite_model| is a file handle with contents a TfLite model, which is used
// to classify the appearance of pages.
SetPhishingFlatBufferModel(mojo_base.mojom.ReadOnlySharedMemoryRegion region,
mojo_base.mojom.ReadOnlyFile? tflite_model);
// Tells the renderer to begin phishing detection for the given toplevel URL
// which it has started loading. Returns the serialized request proto and a
// |result| enum to indicate failure. If the URL is phishing the request proto

@ -7,6 +7,7 @@
#include "base/bind.h"
#include "base/memory/ptr_util.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/scoped_refptr.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
@ -120,6 +121,9 @@ class TestPhishingDetector : public mojom::PhishingDetector {
void SetPhishingModel(const std::string& model, base::File file) override {}
// mojom::PhishingDetector. Intentionally a no-op in this test double: both
// arguments are dropped.
void SetPhishingFlatBufferModel(base::ReadOnlySharedMemoryRegion region,
base::File file) override {}
void StartPhishingDetection(
const GURL& url,
StartPhishingDetectionCallback callback) override {

@ -37,6 +37,7 @@ source_set("phishing_classifier") {
"//components/safe_browsing/core:client_model_proto",
"//components/safe_browsing/core:csd_proto",
"//components/safe_browsing/core/common",
"//components/safe_browsing/core/fbs:client_model",
"//content/public/renderer",
"//crypto",
"//skia",
@ -70,6 +71,7 @@ source_set("unit_tests") {
"//components/safe_browsing/content/renderer/phishing_classifier:unit_tests_support",
"//components/safe_browsing/core:client_model_proto",
"//components/safe_browsing/core:csd_proto",
"//components/safe_browsing/core/fbs:client_model",
"//crypto",
"//skia",
"//testing/gmock",
@ -97,6 +99,7 @@ if (use_libfuzzer) {
":client_side_phishing_fuzzer_proto",
":phishing_classifier",
"//base:base",
"//components/safe_browsing/core/fbs:client_model",
"//skia",
"//third_party/libprotobuf-mutator",
]

@ -96,6 +96,23 @@ void PhishingClassifierDelegate::SetPhishingModel(
g_phishing_scorer.Get().reset(scorer);
}
// mojom::PhishingDetector
// Builds a Scorer from the flatbuffer model in |flatbuffer_region| (plus the
// optional |tflite_visual_model|) and installs it on every delegate and in the
// global scorer slot. An invalid region installs a null scorer, which disables
// client-side phishing detection. If the region is valid but the scorer cannot
// be created, the previously installed scorer is left untouched.
void PhishingClassifierDelegate::SetPhishingFlatBufferModel(
    base::ReadOnlySharedMemoryRegion flatbuffer_region,
    base::File tflite_visual_model) {
  safe_browsing::Scorer* scorer = nullptr;
  // Only the region's validity decides whether a scorer is built. The
  // region-based Scorer::Create() rejects an invalid region, so also entering
  // this branch for "invalid region, valid TfLite file" would hit the early
  // return below and keep a stale scorer instead of disabling detection.
  if (flatbuffer_region.IsValid()) {
    scorer = safe_browsing::Scorer::Create(std::move(flatbuffer_region),
                                           std::move(tflite_visual_model));
    if (!scorer)
      return;
  }
  for (auto* delegate : PhishingClassifierDelegates())
    delegate->SetPhishingScorer(scorer);
  // The global slot takes ownership of the new scorer (or clears it).
  g_phishing_scorer.Get().reset(scorer);
}
// static
PhishingClassifierDelegate* PhishingClassifierDelegate::Create(
content::RenderFrame* render_frame,

@ -11,6 +11,7 @@
#include <string>
#include "base/macros.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "components/safe_browsing/content/common/safe_browsing.mojom.h"
#include "content/public/renderer/render_frame_observer.h"
#include "content/public/renderer/render_thread_observer.h"
@ -52,6 +53,11 @@ class PhishingClassifierDelegate : public content::RenderFrameObserver,
void SetPhishingModel(const std::string& model,
base::File tflite_visual_model) override;
// mojom::PhishingDetector
void SetPhishingFlatBufferModel(
base::ReadOnlySharedMemoryRegion flatbuffer_region,
base::File tflite_visual_model) override;
// Called by the RenderFrame once there is a phishing scorer available.
// The scorer is passed on to the classifier.
void SetPhishingScorer(const safe_browsing::Scorer* scorer);

@ -10,6 +10,8 @@
#include <unordered_map>
#include <unordered_set>
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/shared_memory_mapping.h"
#include "base/metrics/histogram_macros.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_piece.h"
@ -41,6 +43,9 @@ enum ScorerCreationStatus {
SCORER_FAIL_MODEL_PARSE_ERROR,
SCORER_FAIL_MODEL_MISSING_FIELDS,
SCORER_FAIL_MAP_VISUAL_TFLITE_MODEL,
SCORER_FAIL_FLATBUFFER_INVALID_REGION,
SCORER_FAIL_FLATBUFFER_INVALID_MAPPING,
SCORER_FAIL_FLATBUFFER_FAILED_VERIFY,
SCORER_STATUS_MAX // Always add new values before this one.
};
@ -255,6 +260,42 @@ Scorer* Scorer::Create(const base::StringPiece& model_str,
return scorer.release();
}
/* static */
// Flatbuffer overload: maps |region|, verifies that it holds a ClientSideModel
// flatbuffer, optionally maps |visual_tflite_model|, and returns a new Scorer
// that owns the mapping. Returns nullptr, after recording the failure reason,
// if any step fails. The caller takes ownership of the returned pointer.
Scorer* Scorer::Create(base::ReadOnlySharedMemoryRegion region,
                       base::File visual_tflite_model) {
  std::unique_ptr<Scorer> scorer(new Scorer());

  if (!region.IsValid()) {
    RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_INVALID_REGION);
    return nullptr;
  }

  base::ReadOnlySharedMemoryMapping mapping = region.Map();
  if (!mapping.IsValid()) {
    RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_INVALID_MAPPING);
    return nullptr;
  }

  // Verify the buffer before any accessor is used on it.
  const uint8_t* model_bytes =
      reinterpret_cast<const uint8_t*>(mapping.memory());
  flatbuffers::Verifier verifier(model_bytes, mapping.size());
  if (!flat::VerifyClientSideModelBuffer(verifier)) {
    RecordScorerCreationStatus(SCORER_FAIL_FLATBUFFER_FAILED_VERIFY);
    return nullptr;
  }

  // The visual model is optional; only map it when a file was supplied.
  if (visual_tflite_model.IsValid()) {
    const bool visual_model_mapped =
        scorer->visual_tflite_model_.Initialize(std::move(visual_tflite_model));
    if (!visual_model_mapped) {
      RecordScorerCreationStatus(SCORER_FAIL_MAP_VISUAL_TFLITE_MODEL);
      return nullptr;
    }
  }

  RecordScorerCreationStatus(SCORER_SUCCESS);
  // The model pointer refers into the mapped memory, so the scorer must take
  // ownership of the mapping as well.
  scorer->flatbuffer_model_ = flat::GetClientSideModel(mapping.memory());
  scorer->flatbuffer_mapping_ = std::move(mapping);
  return scorer.release();
}
double Scorer::ComputeScore(const FeatureMap& features) const {
double logodds = 0.0;
for (int i = 0; i < model_.rule_size(); ++i) {

@ -24,8 +24,10 @@
#include "base/files/file.h"
#include "base/files/memory_mapped_file.h"
#include "base/macros.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/string_piece.h"
#include "components/safe_browsing/core/fbs/client_model_generated.h"
#include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "third_party/skia/include/core/SkBitmap.h"
@ -39,10 +41,17 @@ class Scorer {
virtual ~Scorer();
// Factory method which creates a new Scorer object by parsing the given
// model. If parsing fails this method returns NULL.
// model. If parsing fails this method returns NULL.
// Can use this if model_str is empty.
static Scorer* Create(const base::StringPiece& model_str,
base::File visual_tflite_model);
// Factory method which creates a new Scorer object by parsing the given
// flatbuffer or tflite model. If parsing fails this method returns NULL.
// Use this only if region is valid.
static Scorer* Create(base::ReadOnlySharedMemoryRegion region,
base::File visual_tflite_model);
// This method computes the probability that the given features are indicative
// of phishing. It returns a score value that falls in the range [0.0,1.0]
// (range is inclusive on both ends).
@ -122,6 +131,12 @@ class Scorer {
std::unordered_set<std::string> page_terms_;
std::unordered_set<uint32_t> page_words_;
// Unowned. Points within flatbuffer_mapping_ and must not be free()d. It is
// only usable while flatbuffer_mapping_ stays valid and has to be reassigned
// whenever the mapping is replaced. Defaults to null so that a Scorer that
// never had a flatbuffer model assigned does not carry an indeterminate
// pointer.
const flat::ClientSideModel* flatbuffer_model_ = nullptr;
base::ReadOnlySharedMemoryMapping flatbuffer_mapping_;
base::MemoryMappedFile visual_tflite_model_;
base::WeakPtrFactory<Scorer> weak_ptr_factory_{this};

@ -12,12 +12,14 @@
#include "base/files/file_path.h"
#include "base/files/scoped_temp_dir.h"
#include "base/format_macros.h"
#include "base/memory/read_only_shared_memory_region.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "base/test/test_discardable_memory_allocator.h"
#include "base/threading/thread.h"
#include "components/safe_browsing/content/renderer/phishing_classifier/features.h"
#include "components/safe_browsing/core/fbs/client_model_generated.h"
#include "components/safe_browsing/core/proto/client_model.pb.h"
#include "components/safe_browsing/core/proto/csd.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
@ -25,6 +27,26 @@
namespace safe_browsing {
namespace {
std::string GetFlatBufferString() {
flatbuffers::FlatBufferBuilder builder(1024);
flat::ClientSideModelBuilder csd_model_builder(builder);
builder.Finish(csd_model_builder.Finish());
return std::string(reinterpret_cast<char*>(builder.GetBufferPointer()),
builder.GetSize());
}
// Creates a shared memory region of data.length() bytes and copies |data| into
// it. If the region cannot be created, a non-fatal test failure is recorded
// and the invalid region is returned without writing to its mapping.
base::MappedReadOnlyRegion GetMappedReadOnlyRegionWithData(
    const std::string& data) {
  base::MappedReadOnlyRegion mapped_region =
      base::ReadOnlySharedMemoryRegion::Create(data.length());
  EXPECT_TRUE(mapped_region.IsValid());
  // EXPECT_TRUE does not stop execution, so guard the copy explicitly: an
  // invalid region has no mapped memory to copy into.
  if (mapped_region.IsValid()) {
    memcpy(mapped_region.mapping.memory(), data.data(), data.length());
  }
  return mapped_region;
}
} // namespace
class PhishingScorerTest : public ::testing::Test {
protected:
void SetUp() override {
@ -108,6 +130,25 @@ class PhishingScorerTest : public ::testing::Test {
base::TestDiscardableMemoryAllocator test_allocator_;
};
// Exercises the region-based Scorer::Create() with a valid flatbuffer, an
// invalid region, and a region whose bytes are not a flatbuffer.
TEST_F(PhishingScorerTest, HasValidFlatBufferModel) {
  std::unique_ptr<Scorer> scorer;

  // Valid flatbuffer in a valid region.
  const std::string serialized_model = GetFlatBufferString();
  base::MappedReadOnlyRegion model_memory =
      GetMappedReadOnlyRegionWithData(serialized_model);
  scorer.reset(Scorer::Create(model_memory.region.Duplicate(), base::File()));
  EXPECT_TRUE(scorer.get() != nullptr);

  // Invalid region.
  base::ReadOnlySharedMemoryRegion invalid_region;
  scorer.reset(Scorer::Create(std::move(invalid_region), base::File()));
  EXPECT_FALSE(scorer.get());

  // Invalid buffer in region.
  model_memory = GetMappedReadOnlyRegionWithData("bogus string");
  scorer.reset(Scorer::Create(model_memory.region.Duplicate(), base::File()));
  EXPECT_FALSE(scorer.get());
}
TEST_F(PhishingScorerTest, HasValidModel) {
std::unique_ptr<Scorer> scorer;
scorer.reset(Scorer::Create(model_.SerializeAsString(), base::File()));