0
Files
src/remoting/protocol/session_authz_authenticator_unittest.cc
Peter Boström e72e8c8dfb Use NOTREACHED_IN_MIGRATION() in remoting/
This was generated by replacing "  NOTREACHED()" with
"  NOTREACHED_IN_MIGRATION()" and running git cl format.

This prepares for making NOTREACHED() [[noreturn]] alongside
NotReachedIsFatal migration of existing inventory.

Bug: 40580068
Change-Id: I57cfd1b2a4e6c91529ab221cf7a7da622be26913
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5540125
Commit-Queue: Lei Zhang <thestig@chromium.org>
Owners-Override: Lei Zhang <thestig@chromium.org>
Commit-Queue: Peter Boström <pbos@chromium.org>
Auto-Submit: Peter Boström <pbos@chromium.org>
Reviewed-by: Lei Zhang <thestig@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1300942}
2024-05-14 22:40:32 +00:00

402 lines
15 KiB
C++

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/session_authz_authenticator.h"
#include <memory>
#include <tuple>
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_forward.h"
#include "base/memory/raw_ptr.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/gmock_move_support.h"
#include "base/test/mock_callback.h"
#include "base/time/time.h"
#include "net/http/http_status_code.h"
#include "remoting/base/mock_session_authz_service_client.h"
#include "remoting/base/protobuf_http_status.h"
#include "remoting/base/rsa_key_pair.h"
#include "remoting/base/session_authz_service_client.h"
#include "remoting/proto/session_authz_service.h"
#include "remoting/protocol/authenticator.h"
#include "remoting/protocol/authenticator_test_base.h"
#include "remoting/protocol/connection_tester.h"
#include "remoting/protocol/credentials_type.h"
#include "remoting/protocol/spake2_authenticator.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace remoting::protocol {
namespace {
using testing::_;
using testing::ByMove;
using testing::Return;
constexpr char kFakeHostToken[] = "fake_host_token";
constexpr char kFakeSessionId[] = "fake_session_id";
constexpr char kFakeSessionToken[] = "fake_session_token";
constexpr char kFakeSharedSecret[] = "fake_shared_secret";
constexpr char kFakeSessionReauthToken[] = "fake_session_reauth_token";
constexpr base::TimeDelta kFakeSessionReauthTokenLifetime = base::Minutes(5);
constexpr int kMessageSize = 100;
constexpr int kMessages = 1;
auto RespondGenerateHostToken() {
auto response = std::make_unique<internal::GenerateHostTokenResponseStruct>();
response->host_token = kFakeHostToken;
response->session_id = kFakeSessionId;
return base::test::RunOnceCallback<0>(ProtobufHttpStatus::OK(),
std::move(response));
}
auto RespondVerifySessionToken(
const std::string& session_id = kFakeSessionId,
const std::string& shared_secret = kFakeSharedSecret) {
auto response =
std::make_unique<internal::VerifySessionTokenResponseStruct>();
response->session_id = session_id;
response->shared_secret = shared_secret;
response->session_reauth_token = kFakeSessionReauthToken;
response->session_reauth_token_lifetime = kFakeSessionReauthTokenLifetime;
return base::test::RunOnceCallback<1>(ProtobufHttpStatus::OK(),
std::move(response));
}
class FakeClientAuthenticator : public Authenticator {
public:
explicit FakeClientAuthenticator(const CreateBaseAuthenticatorCallback&
create_base_authenticator_callback);
~FakeClientAuthenticator() override;
// Authenticator implementation.
CredentialsType credentials_type() const override;
const Authenticator& implementing_authenticator() const override;
State state() const override;
bool started() const override;
RejectionReason rejection_reason() const override;
void ProcessMessage(const jingle_xmpp::XmlElement* message,
base::OnceClosure resume_callback) override;
std::unique_ptr<jingle_xmpp::XmlElement> GetNextMessage() override;
const std::string& GetAuthKey() const override;
std::unique_ptr<ChannelAuthenticator> CreateChannelAuthenticator()
const override;
const std::string& host_token() const { return host_token_; }
// Used to simulate an invalid message from the underlying authenticator.
void set_underlying_authenticator_message_suppressed(bool suppressed) {
underlying_authenticator_message_suppressed_ = suppressed;
}
private:
enum class SessionAuthzState {
WAITING_FOR_HOST_TOKEN,
READY_TO_SEND_SESSION_TOKEN,
AUTHORIZED,
FAILED,
};
SessionAuthzState session_authz_state_ =
SessionAuthzState::WAITING_FOR_HOST_TOKEN;
RejectionReason session_authz_rejection_reason_;
CreateBaseAuthenticatorCallback create_base_authenticator_callback_;
std::unique_ptr<Authenticator> underlying_;
std::string host_token_;
std::unique_ptr<jingle_xmpp::XmlElement> message_;
base::OnceClosure resume_callback_;
bool underlying_authenticator_message_suppressed_ = false;
};
CredentialsType FakeClientAuthenticator::credentials_type() const {
return CredentialsType::CORP_SESSION_AUTHZ;
}
const Authenticator& FakeClientAuthenticator::implementing_authenticator()
const {
return *this;
}
Authenticator::State FakeClientAuthenticator::state() const {
switch (session_authz_state_) {
case SessionAuthzState::READY_TO_SEND_SESSION_TOKEN:
return MESSAGE_READY;
case SessionAuthzState::WAITING_FOR_HOST_TOKEN:
return WAITING_MESSAGE;
case SessionAuthzState::AUTHORIZED:
return underlying_->state();
case SessionAuthzState::FAILED:
return REJECTED;
}
}
bool FakeClientAuthenticator::started() const {
return session_authz_state_ != SessionAuthzState::WAITING_FOR_HOST_TOKEN;
}
FakeClientAuthenticator::FakeClientAuthenticator(
const CreateBaseAuthenticatorCallback& create_base_authenticator_callback)
: create_base_authenticator_callback_(create_base_authenticator_callback) {}
FakeClientAuthenticator::~FakeClientAuthenticator() = default;
Authenticator::RejectionReason FakeClientAuthenticator::rejection_reason()
const {
return session_authz_state_ == SessionAuthzState::FAILED
? session_authz_rejection_reason_
: underlying_->rejection_reason();
}
void FakeClientAuthenticator::ProcessMessage(
const jingle_xmpp::XmlElement* message,
base::OnceClosure resume_callback) {
switch (session_authz_state_) {
case SessionAuthzState::WAITING_FOR_HOST_TOKEN:
host_token_ =
message->TextNamed(SessionAuthzAuthenticator::kHostTokenTag);
ASSERT_FALSE(host_token_.empty());
session_authz_state_ = SessionAuthzState::READY_TO_SEND_SESSION_TOKEN;
underlying_ = create_base_authenticator_callback_.Run(kFakeSharedSecret,
MESSAGE_READY);
std::move(resume_callback).Run();
return;
case SessionAuthzState::AUTHORIZED:
underlying_->ProcessMessage(message, std::move(resume_callback));
return;
default:
NOTREACHED_IN_MIGRATION();
}
}
std::unique_ptr<jingle_xmpp::XmlElement>
FakeClientAuthenticator::GetNextMessage() {
EXPECT_EQ(state(), MESSAGE_READY);
std::unique_ptr<jingle_xmpp::XmlElement> message;
if (underlying_ && underlying_->state() == MESSAGE_READY) {
if (underlying_authenticator_message_suppressed_) {
underlying_->GetNextMessage();
} else {
message = underlying_->GetNextMessage();
}
}
if (!message) {
message = CreateEmptyAuthenticatorMessage();
}
if (session_authz_state_ == SessionAuthzState::READY_TO_SEND_SESSION_TOKEN) {
jingle_xmpp::XmlElement* session_token_element =
new jingle_xmpp::XmlElement(
SessionAuthzAuthenticator::kSessionTokenTag);
session_token_element->SetBodyText(kFakeSessionToken);
message->AddElement(session_token_element);
session_authz_state_ = SessionAuthzState::AUTHORIZED;
}
return message;
}
const std::string& FakeClientAuthenticator::GetAuthKey() const {
EXPECT_EQ(state(), ACCEPTED);
return underlying_->GetAuthKey();
}
std::unique_ptr<ChannelAuthenticator>
FakeClientAuthenticator::CreateChannelAuthenticator() const {
EXPECT_EQ(state(), ACCEPTED);
return underlying_->CreateChannelAuthenticator();
}
} // namespace
class SessionAuthzAuthenticatorTest : public AuthenticatorTestBase {
public:
SessionAuthzAuthenticatorTest();
~SessionAuthzAuthenticatorTest() override;
protected:
void SetUp() override;
// Caller must add an expectation for
// mock_service_client_->GenerateHostToken() and run the callback, otherwise
// this method will not return.
void StartAuthExchange();
raw_ptr<MockSessionAuthzServiceClient> mock_service_client_;
raw_ptr<SessionAuthzAuthenticator> host_authenticator_;
raw_ptr<FakeClientAuthenticator> client_authenticator_;
};
SessionAuthzAuthenticatorTest::SessionAuthzAuthenticatorTest() = default;
SessionAuthzAuthenticatorTest::~SessionAuthzAuthenticatorTest() = default;
void SessionAuthzAuthenticatorTest::SetUp() {
AuthenticatorTestBase::SetUp();
auto mock_service_client = std::make_unique<MockSessionAuthzServiceClient>();
mock_service_client_ = mock_service_client.get();
auto host_authenticator = std::make_unique<SessionAuthzAuthenticator>(
CredentialsType::CORP_SESSION_AUTHZ, std::move(mock_service_client),
base::BindRepeating(&Spake2Authenticator::CreateForHost, kHostId,
kClientId, host_cert_, key_pair_));
host_authenticator_ = host_authenticator.get();
host_ = std::move(host_authenticator);
auto client_authenticator =
std::make_unique<FakeClientAuthenticator>(base::BindRepeating(
&Spake2Authenticator::CreateForClient, kClientId, kHostId));
client_authenticator_ = client_authenticator.get();
client_ = std::move(client_authenticator);
}
void SessionAuthzAuthenticatorTest::StartAuthExchange() {
base::RunLoop run_loop;
host_authenticator_->Start(run_loop.QuitClosure());
run_loop.Run();
RunHostInitiatedAuthExchange();
}
TEST_F(SessionAuthzAuthenticatorTest, SuccessfulAuth) {
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
internal::VerifySessionTokenRequestStruct
expected_verify_session_token_request;
expected_verify_session_token_request.session_token = kFakeSessionToken;
EXPECT_CALL(*mock_service_client_,
VerifySessionToken(expected_verify_session_token_request, _))
.WillOnce(RespondVerifySessionToken());
StartAuthExchange();
ASSERT_EQ(host_->state(), Authenticator::ACCEPTED);
ASSERT_EQ(client_->state(), Authenticator::ACCEPTED);
ASSERT_EQ(client_authenticator_->host_token(), kFakeHostToken);
// Verify that authenticated channels can be created after authentication.
client_auth_ = client_->CreateChannelAuthenticator();
host_auth_ = host_->CreateChannelAuthenticator();
RunChannelAuth(false);
StreamConnectionTester tester(host_socket_.get(), client_socket_.get(),
kMessageSize, kMessages);
base::RunLoop run_loop;
tester.Start(run_loop.QuitClosure());
run_loop.Run();
tester.CheckResults();
}
TEST_F(SessionAuthzAuthenticatorTest, GenerateHostToken_RpcError_Rejected) {
base::MockCallback<base::OnceClosure> mock_resume_callback;
SessionAuthzServiceClient::GenerateHostTokenCallback
generate_host_token_callback;
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(MoveArg<0>(&generate_host_token_callback));
EXPECT_CALL(mock_resume_callback, Run()).Times(0);
host_authenticator_->Start(mock_resume_callback.Get());
ASSERT_EQ(host_->state(), Authenticator::PROCESSING_MESSAGE);
EXPECT_CALL(mock_resume_callback, Run()).Times(1);
std::move(generate_host_token_callback)
.Run(ProtobufHttpStatus(ProtobufHttpStatus::Code::PERMISSION_DENIED,
"Permission denied"),
nullptr);
ASSERT_EQ(host_->state(), Authenticator::REJECTED);
ASSERT_EQ(host_->rejection_reason(),
Authenticator::RejectionReason::AUTHZ_POLICY_CHECK_FAILED);
}
TEST_F(SessionAuthzAuthenticatorTest, VerifySessionToken_RpcError_Rejected) {
SessionAuthzServiceClient::VerifySessionTokenCallback
verify_session_token_callback;
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
EXPECT_CALL(*mock_service_client_, VerifySessionToken(_, _))
.WillOnce(MoveArg<1>(&verify_session_token_callback));
StartAuthExchange();
ASSERT_EQ(host_->state(), Authenticator::PROCESSING_MESSAGE);
std::move(verify_session_token_callback)
.Run(ProtobufHttpStatus(ProtobufHttpStatus::Code::PERMISSION_DENIED,
"Permission denied"),
nullptr);
ASSERT_EQ(host_->state(), Authenticator::REJECTED);
ASSERT_EQ(host_->rejection_reason(),
Authenticator::RejectionReason::AUTHZ_POLICY_CHECK_FAILED);
}
TEST_F(SessionAuthzAuthenticatorTest,
VerifySessionToken_SessionIdMismatch_Rejected) {
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
EXPECT_CALL(*mock_service_client_, VerifySessionToken(_, _))
.WillOnce(RespondVerifySessionToken("mismatched_session_id"));
StartAuthExchange();
ASSERT_EQ(host_->state(), Authenticator::REJECTED);
ASSERT_EQ(host_->rejection_reason(),
Authenticator::RejectionReason::INVALID_ACCOUNT_ID);
}
TEST_F(SessionAuthzAuthenticatorTest,
UnderlyingAuthenticator_InvalidIncomingMessage_Rejected) {
client_authenticator_->set_underlying_authenticator_message_suppressed(true);
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
EXPECT_CALL(*mock_service_client_, VerifySessionToken(_, _))
.WillOnce(RespondVerifySessionToken());
StartAuthExchange();
ASSERT_EQ(host_->state(), Authenticator::REJECTED);
ASSERT_EQ(host_->rejection_reason(),
Authenticator::RejectionReason::PROTOCOL_ERROR);
}
TEST_F(SessionAuthzAuthenticatorTest,
UnderlyingAuthenticator_MismatchedSharedSecret_NotAccepted) {
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
EXPECT_CALL(*mock_service_client_, VerifySessionToken(_, _))
.WillOnce(RespondVerifySessionToken(kFakeSessionId,
"mismatched_shared_secret"));
StartAuthExchange();
// With the mismatched shared secret, the underlying authenticator will just
// believe the client has not sent enough data to complete the key exchange.
ASSERT_NE(host_->state(), Authenticator::ACCEPTED);
}
TEST_F(SessionAuthzAuthenticatorTest, ReauthorizationFailed_Rejected) {
EXPECT_CALL(*mock_service_client_, GenerateHostToken(_))
.WillOnce(RespondGenerateHostToken());
EXPECT_CALL(*mock_service_client_, VerifySessionToken(_, _))
.WillOnce(RespondVerifySessionToken());
base::MockCallback<base::RepeatingClosure>
state_change_after_accepted_callback;
host_->set_state_change_after_accepted_callback(
state_change_after_accepted_callback.Get());
EXPECT_CALL(state_change_after_accepted_callback, Run()).Times(0);
StartAuthExchange();
ASSERT_EQ(host_->state(), Authenticator::ACCEPTED);
ASSERT_EQ(client_->state(), Authenticator::ACCEPTED);
internal::ReauthorizeHostRequestStruct request;
request.session_id = kFakeSessionId;
request.session_reauth_token = kFakeSessionReauthToken;
EXPECT_CALL(*mock_service_client_, ReauthorizeHost(request, _))
.WillOnce(base::test::RunOnceCallback<1>(
ProtobufHttpStatus(net::HttpStatusCode::HTTP_FORBIDDEN), nullptr));
EXPECT_CALL(state_change_after_accepted_callback, Run());
task_environment_.FastForwardBy(kFakeSessionReauthTokenLifetime);
ASSERT_EQ(host_->state(), Authenticator::REJECTED);
ASSERT_EQ(host_->rejection_reason(),
Authenticator::RejectionReason::REAUTHZ_POLICY_CHECK_FAILED);
}
} // namespace remoting::protocol