0

[net] Refactor TCPSocketWin to allow alternative read/write impls.

TCPSocketWin uses base::win::ObjectWatcher to watch for completed
reads and writes. Under the hood, this incurs one PostTask per
signaled event [1], which comes with OS thread scheduling and task
queuing overhead. We want to test the hypothesis that using an
IO completion port to be notified of completed reads and writes,
without going through PostTask, improves page load time.

Steps:

1. [this CL] Make the methods of TCPSocketWin which implement
   the read/write operations pure virtual and implement them in
   TCPSocketDefaultWin. Introduce TCPSocketWin::Create which
   instantiates a TCPSocketDefaultWin. No behavior change in this CL.

2. [crrev.com/c/5627052] Introduce TcpSocketIoCompletionPortWin which
   uses an IO completion port to be notified of completed reads and
   writes, instead of an event watched by base::win::ObjectWatcher.
   Make TCPSocketWin::Create instantiate that class instead of
TCPSocketDefaultWin when the new "TcpSocketIoCompletionPortWin"
   feature is enabled.

[1] https://source.chromium.org/chromium/chromium/src/+/main:base/win/object_watcher.cc;l=87-92;drc=82d96b9965565487874b66b939ca2527a9b2783d

Bug: 40287434
Change-Id: I9d3a60cb3877138b89da03a844da413c6ec4d6e3
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5627341
Reviewed-by: Greg Thompson <grt@chromium.org>
Reviewed-by: Adam Rice <ricea@chromium.org>
Auto-Submit: Francois Pierre Doray <fdoray@chromium.org>
Reviewed-by: Matt Reynolds <mattreynolds@chromium.org>
Commit-Queue: Matt Reynolds <mattreynolds@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1352176}
This commit is contained in:
François Doray
2024-09-06 18:21:11 +00:00
committed by Chromium LUCI CQ
parent cf0e898948
commit dac6cb749d
16 changed files with 438 additions and 242 deletions

@ -490,7 +490,7 @@ void WifiDirectMedium::CreateAndConnectSocket(
}
auto ip_endpoint = net::IPEndPoint(*ip, port);
auto tcp_socket =
std::make_unique<net::TCPSocket>(nullptr, nullptr, net::NetLogSource());
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
tcp_socket->AdoptUnconnectedSocket(fd);
tcp_socket->Connect(
ip_endpoint,

@ -100,9 +100,10 @@ IN_PROC_BROWSER_TEST_F(TransferableSocketBrowserTest, TransferSocket) {
network_service_pending;
GetNetworkService()->BindTestInterfaceForTesting(
network_service_pending.InitWithNewPipeAndPassReceiver());
net::TCPSocket socket(nullptr, nullptr, net::NetLogSource());
socket.Open(net::AddressFamily::ADDRESS_FAMILY_IPV4);
socket.DetachFromThread();
std::unique_ptr<net::TCPSocket> socket =
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
socket->Open(net::AddressFamily::ADDRESS_FAMILY_IPV4);
socket->DetachFromThread();
net::IPEndPoint endpoint(net::IPAddress::IPv4Localhost(),
embedded_test_server()->port());
@ -110,7 +111,7 @@ IN_PROC_BROWSER_TEST_F(TransferableSocketBrowserTest, TransferSocket) {
content::GetIOThreadTaskRunner({})->PostTaskAndReplyWithResult(
FROM_HERE,
base::BindOnce(
&net::TCPSocket::Connect, base::Unretained(&socket), endpoint,
&net::TCPSocket::Connect, base::Unretained(socket.get()), endpoint,
base::BindLambdaForTesting([&connect_run_loop](int result) {
EXPECT_EQ(result, net::OK);
connect_run_loop.Quit();
@ -122,7 +123,7 @@ IN_PROC_BROWSER_TEST_F(TransferableSocketBrowserTest, TransferSocket) {
EXPECT_EQ(result, net::ERR_IO_PENDING);
}));
connect_run_loop.Run();
socket.DetachFromThread();
socket->DetachFromThread();
#if BUILDFLAG(IS_WIN)
// Obtain the running process id of the network service, as this is needed to
// duplicate the socket on Windows only.
@ -136,12 +137,12 @@ IN_PROC_BROWSER_TEST_F(TransferableSocketBrowserTest, TransferSocket) {
}
ASSERT_TRUE(network_process.IsValid());
network::TransferableSocket transferable(
socket.ReleaseSocketDescriptorForTesting(), network_process);
socket->ReleaseSocketDescriptorForTesting(), network_process);
#else
base::test::TestFuture<net::SocketDescriptor> socket_descriptor;
GetIOThreadTaskRunner({})->PostTaskAndReplyWithResult(
FROM_HERE, base::BindLambdaForTesting([&]() {
return socket.ReleaseSocketDescriptorForTesting();
return socket->ReleaseSocketDescriptorForTesting();
}),
socket_descriptor.GetCallback());
network::TransferableSocket transferable(socket_descriptor.Get());

@ -722,13 +722,14 @@ class NetworkServiceTestHelper::NetworkServiceTestImpl
void MakeRequestToServer(network::TransferableSocket transferred,
const net::IPEndPoint& endpoint,
MakeRequestToServerCallback callback) override {
net::TCPSocket socket(nullptr, nullptr, net::NetLogSource());
socket.AdoptConnectedSocket(transferred.TakeSocket(), endpoint);
std::unique_ptr<net::TCPSocket> socket =
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
socket->AdoptConnectedSocket(transferred.TakeSocket(), endpoint);
const std::string kRequest("GET / HTTP/1.0\r\n\r\n");
auto io_buffer = base::MakeRefCounted<net::StringIOBuffer>(kRequest);
int rv = socket.Write(io_buffer.get(), io_buffer->size(), base::DoNothing(),
TRAFFIC_ANNOTATION_FOR_TESTS);
int rv = socket->Write(io_buffer.get(), io_buffer->size(),
base::DoNothing(), TRAFFIC_ANNOTATION_FOR_TESTS);
// For purposes of tests, this IPC only supports sync Write calls.
DCHECK_NE(net::ERR_IO_PENDING, rv);
std::move(callback).Run(rv == static_cast<int>(kRequest.size()));

@ -99,7 +99,7 @@ void BluetoothSocketNet::ResetData() {
}
void BluetoothSocketNet::ResetTCPSocket() {
tcp_socket_.reset(new net::TCPSocket(NULL, NULL, net::NetLogSource()));
tcp_socket_ = net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
}
void BluetoothSocketNet::SetTCPSocket(

@ -188,8 +188,8 @@ void BluetoothSocketWin::DoConnect(base::OnceClosure success_callback,
return;
}
std::unique_ptr<net::TCPSocket> scoped_socket(
new net::TCPSocket(NULL, NULL, net::NetLogSource()));
std::unique_ptr<net::TCPSocket> scoped_socket =
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
net::EnsureWinsockInit();
SOCKET socket_fd = socket(AF_BTH, SOCK_STREAM, BTHPROTO_RFCOMM);
SOCKADDR_BTH sa;
@ -255,8 +255,8 @@ void BluetoothSocketWin::DoListen(const BluetoothUUID& uuid,
// Note that |socket_fd| belongs to a non-TCP address family (i.e. AF_BTH),
// TCPSocket methods that involve address could not be called. So bind()
// is called on |socket_fd| directly.
std::unique_ptr<net::TCPSocket> scoped_socket(
new net::TCPSocket(NULL, NULL, net::NetLogSource()));
std::unique_ptr<net::TCPSocket> scoped_socket =
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
scoped_socket->AdoptUnconnectedSocket(socket_fd);
SOCKADDR_BTH sa;

@ -91,18 +91,20 @@ TEST(NetworkLibraryTest, BindToNetwork) {
NetworkChangeNotifierFactoryAndroid ncn_factory;
NetworkChangeNotifier::DisableForTest ncn_disable_for_test;
std::unique_ptr<NetworkChangeNotifier> ncn(ncn_factory.CreateInstance());
TCPSocket socket_tcp_ipv4(nullptr, nullptr, NetLogSource());
ASSERT_EQ(OK, socket_tcp_ipv4.Open(ADDRESS_FAMILY_IPV4));
TCPSocket socket_tcp_ipv6(nullptr, nullptr, NetLogSource());
ASSERT_EQ(OK, socket_tcp_ipv6.Open(ADDRESS_FAMILY_IPV6));
std::unique_ptr<TCPSocket> socket_tcp_ipv4 =
TCPSocket::Create(nullptr, nullptr, NetLogSource());
ASSERT_EQ(OK, socket_tcp_ipv4->Open(ADDRESS_FAMILY_IPV4));
std::unique_ptr<TCPSocket> socket_tcp_ipv6 =
TCPSocket::Create(nullptr, nullptr, NetLogSource());
ASSERT_EQ(OK, socket_tcp_ipv6->Open(ADDRESS_FAMILY_IPV6));
UDPSocket socket_udp_ipv4(DatagramSocket::DEFAULT_BIND, nullptr,
NetLogSource());
ASSERT_EQ(OK, socket_udp_ipv4.Open(ADDRESS_FAMILY_IPV4));
UDPSocket socket_udp_ipv6(DatagramSocket::DEFAULT_BIND, nullptr,
NetLogSource());
ASSERT_EQ(OK, socket_udp_ipv6.Open(ADDRESS_FAMILY_IPV6));
std::array sockets{socket_tcp_ipv4.SocketDescriptorForTesting(),
socket_tcp_ipv6.SocketDescriptorForTesting(),
std::array sockets{socket_tcp_ipv4->SocketDescriptorForTesting(),
socket_tcp_ipv6->SocketDescriptorForTesting(),
socket_udp_ipv4.SocketDescriptorForTesting(),
socket_udp_ipv6.SocketDescriptorForTesting()};

@ -37,15 +37,14 @@ TCPClientSocket::TCPClientSocket(
net::NetLog* net_log,
const net::NetLogSource& source,
handles::NetworkHandle network)
: TCPClientSocket(
std::make_unique<TCPSocket>(std::move(socket_performance_watcher),
net_log,
source),
addresses,
-1 /* current_address_index */,
nullptr /* bind_address */,
network_quality_estimator,
network) {}
: TCPClientSocket(TCPSocket::Create(std::move(socket_performance_watcher),
net_log,
source),
addresses,
-1 /* current_address_index */,
nullptr /* bind_address */,
network_quality_estimator,
network) {}
TCPClientSocket::TCPClientSocket(std::unique_ptr<TCPSocket> connected_socket,
const IPEndPoint& peer_address)

@ -19,9 +19,9 @@ namespace net {
TCPServerSocket::TCPServerSocket(NetLog* net_log, const NetLogSource& source)
: TCPServerSocket(
std::make_unique<TCPSocket>(nullptr /* socket_performance_watcher */,
net_log,
source)) {}
TCPSocket::Create(nullptr /* socket_performance_watcher */,
net_log,
source)) {}
TCPServerSocket::TCPServerSocket(std::unique_ptr<TCPSocket> socket)
: socket_(std::move(socket)) {}

@ -17,6 +17,7 @@
#include "base/functional/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/memory/ptr_util.h"
#include "base/metrics/histogram_macros.h"
#include "base/posix/eintr_wrapper.h"
#include "base/strings/string_number_conversions.h"
@ -142,6 +143,23 @@ base::TimeDelta GetTransportRtt(SocketDescriptor fd) {
//-----------------------------------------------------------------------------
// static
// Factory for a TCPSocketPosix built from a NetLog and a NetLogSource.
// The constructors are private (see the class declaration), so
// std::make_unique cannot reach them. The object is created with `new` and
// adopted immediately by base::WrapUnique.
// |socket_performance_watcher| may be null.
std::unique_ptr<TCPSocketPosix> TCPSocketPosix::Create(
    std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
    NetLog* net_log,
    const NetLogSource& source) {
  std::unique_ptr<TCPSocketPosix> created_socket = base::WrapUnique(
      new TCPSocketPosix(std::move(socket_performance_watcher), net_log,
                         source));
  return created_socket;
}
// static
// Factory overload taking an already-assembled NetLogWithSource.
// It exists for the same reason as the other overload: the constructor is
// private, so ownership of the `new`-ed object is taken immediately via
// base::WrapUnique. |socket_performance_watcher| may be null.
std::unique_ptr<TCPSocketPosix> TCPSocketPosix::Create(
    std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
    NetLogWithSource net_log_source) {
  std::unique_ptr<TCPSocketPosix> created_socket = base::WrapUnique(
      new TCPSocketPosix(std::move(socket_performance_watcher),
                         net_log_source));
  return created_socket;
}
TCPSocketPosix::TCPSocketPosix(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
@ -540,8 +558,8 @@ int TCPSocketPosix::BuildTcpSocketPosix(
return ERR_ADDRESS_INVALID;
}
*tcp_socket = std::make_unique<TCPSocketPosix>(nullptr, net_log_.net_log(),
net_log_.source());
*tcp_socket =
TCPSocketPosix::Create(nullptr, net_log_.net_log(), net_log_.source());
(*tcp_socket)->socket_ = std::move(accept_socket_);
return OK;
}

@ -38,12 +38,11 @@ class NET_EXPORT TCPSocketPosix {
public:
// |socket_performance_watcher| is notified of the performance metrics related
// to this socket. |socket_performance_watcher| may be null.
TCPSocketPosix(
static std::unique_ptr<TCPSocketPosix> Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source);
TCPSocketPosix(
static std::unique_ptr<TCPSocketPosix> Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source);
@ -180,6 +179,14 @@ class NET_EXPORT TCPSocketPosix {
int BindToNetwork(handles::NetworkHandle network);
private:
TCPSocketPosix(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source);
TCPSocketPosix(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source);
void AcceptCompleted(std::unique_ptr<TCPSocketPosix>* tcp_socket,
IPEndPoint* address,
CompletionOnceCallback callback,

@ -114,27 +114,28 @@ const int kListenBacklog = 5;
class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
protected:
TCPSocketTest() : socket_(nullptr, nullptr, NetLogSource()) {}
TCPSocketTest()
: socket_(TCPSocket::Create(nullptr, nullptr, NetLogSource())) {}
void SetUpListenIPv4() {
ASSERT_THAT(socket_.Open(ADDRESS_FAMILY_IPV4), IsOk());
ASSERT_THAT(socket_.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
ASSERT_THAT(socket_->Open(ADDRESS_FAMILY_IPV4), IsOk());
ASSERT_THAT(socket_->Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
IsOk());
ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
ASSERT_THAT(socket_->Listen(kListenBacklog), IsOk());
ASSERT_THAT(socket_->GetLocalAddress(&local_address_), IsOk());
}
void SetUpListenIPv6(bool* success) {
*success = false;
if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
socket_.Bind(IPEndPoint(IPAddress::IPv6Localhost(), 0)) != OK ||
socket_.Listen(kListenBacklog) != OK) {
if (socket_->Open(ADDRESS_FAMILY_IPV6) != OK ||
socket_->Bind(IPEndPoint(IPAddress::IPv6Localhost(), 0)) != OK ||
socket_->Listen(kListenBacklog) != OK) {
LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
"disabled. Skipping the test";
return;
}
ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
ASSERT_THAT(socket_->GetLocalAddress(&local_address_), IsOk());
*success = true;
}
@ -142,8 +143,8 @@ class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
ASSERT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -182,18 +183,19 @@ class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
should_notify_updated_rtt);
TestSocketPerformanceWatcher* watcher_ptr = watcher.get();
TCPSocket connecting_socket(std::move(watcher), nullptr, NetLogSource());
std::unique_ptr<TCPSocket> connecting_socket =
TCPSocket::Create(std::move(watcher), nullptr, NetLogSource());
int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
ASSERT_THAT(result, IsOk());
int connect_result =
connecting_socket.Connect(local_address_, connect_callback.callback());
connecting_socket->Connect(local_address_, connect_callback.callback());
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get());
@ -220,7 +222,7 @@ class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
scoped_refptr<IOBufferWithSize> read_buffer =
base::MakeRefCounted<IOBufferWithSize>(message.size());
TestCompletionCallback read_callback;
int read_result = connecting_socket.Read(
int read_result = connecting_socket->Read(
read_buffer.get(), read_buffer->size(), read_callback.callback());
ASSERT_EQ(1, write_callback.GetResult(write_result));
@ -237,7 +239,7 @@ class TCPSocketTest : public PlatformTest, public WithTaskEnvironment {
return AddressList(local_address_);
}
TCPSocket socket_;
std::unique_ptr<TCPSocket> socket_;
IPEndPoint local_address_;
};
@ -255,8 +257,8 @@ TEST_F(TCPSocketTest, Accept) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
int result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
int result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
EXPECT_TRUE(accepted_socket.get());
@ -275,12 +277,13 @@ TEST_F(TCPSocketTest, AcceptAsync) {
// Test AdoptConnectedSocket()
TEST_F(TCPSocketTest, AdoptConnectedSocket) {
TCPSocket accepting_socket(nullptr, nullptr, NetLogSource());
ASSERT_THAT(accepting_socket.Open(ADDRESS_FAMILY_IPV4), IsOk());
ASSERT_THAT(accepting_socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
std::unique_ptr<TCPSocket> accepting_socket =
TCPSocket::Create(nullptr, nullptr, NetLogSource());
ASSERT_THAT(accepting_socket->Open(ADDRESS_FAMILY_IPV4), IsOk());
ASSERT_THAT(accepting_socket->Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0)),
IsOk());
ASSERT_THAT(accepting_socket.GetLocalAddress(&local_address_), IsOk());
ASSERT_THAT(accepting_socket.Listen(kListenBacklog), IsOk());
ASSERT_THAT(accepting_socket->GetLocalAddress(&local_address_), IsOk());
ASSERT_THAT(accepting_socket->Listen(kListenBacklog), IsOk());
TestCompletionCallback connect_callback;
// TODO(yzshen): Switch to use TCPSocket when it supports client socket
@ -292,20 +295,20 @@ TEST_F(TCPSocketTest, AdoptConnectedSocket) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
int result = accepting_socket.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
int result = accepting_socket->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
SocketDescriptor accepted_descriptor =
accepted_socket->ReleaseSocketDescriptorForTesting();
ASSERT_THAT(
socket_.AdoptConnectedSocket(accepted_descriptor, accepted_address),
socket_->AdoptConnectedSocket(accepted_descriptor, accepted_address),
IsOk());
// socket_ should now have the local address.
IPEndPoint adopted_address;
ASSERT_THAT(socket_.GetLocalAddress(&adopted_address), IsOk());
ASSERT_THAT(socket_->GetLocalAddress(&adopted_address), IsOk());
EXPECT_EQ(local_address_.address(), adopted_address.address());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
@ -315,15 +318,15 @@ TEST_F(TCPSocketTest, AdoptConnectedSocket) {
TEST_F(TCPSocketTest, AcceptForAdoptedUnconnectedSocket) {
SocketDescriptor existing_socket =
CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ASSERT_THAT(socket_.AdoptUnconnectedSocket(existing_socket), IsOk());
ASSERT_THAT(socket_->AdoptUnconnectedSocket(existing_socket), IsOk());
IPEndPoint address(IPAddress::IPv4Localhost(), 0);
SockaddrStorage storage;
ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len));
ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len));
ASSERT_THAT(socket_.Listen(kListenBacklog), IsOk());
ASSERT_THAT(socket_.GetLocalAddress(&local_address_), IsOk());
ASSERT_THAT(socket_->Listen(kListenBacklog), IsOk());
ASSERT_THAT(socket_->GetLocalAddress(&local_address_), IsOk());
TestAcceptAsync();
}
@ -336,8 +339,8 @@ TEST_F(TCPSocketTest, Accept2Connections) {
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
ASSERT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
ASSERT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -357,8 +360,8 @@ TEST_F(TCPSocketTest, Accept2Connections) {
std::unique_ptr<TCPSocket> accepted_socket2;
IPEndPoint accepted_address2;
int result = socket_.Accept(&accepted_socket2, &accepted_address2,
accept_callback2.callback());
int result = socket_->Accept(&accepted_socket2, &accepted_address2,
accept_callback2.callback());
ASSERT_THAT(accept_callback2.GetResult(result), IsOk());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
@ -387,8 +390,8 @@ TEST_F(TCPSocketTest, AcceptIPv6) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
int result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
int result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
EXPECT_TRUE(accepted_socket.get());
@ -403,17 +406,18 @@ TEST_F(TCPSocketTest, ReadWrite) {
ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
TestCompletionCallback connect_callback;
TCPSocket connecting_socket(nullptr, nullptr, NetLogSource());
int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
std::unique_ptr<TCPSocket> connecting_socket =
TCPSocket::Create(nullptr, nullptr, NetLogSource());
int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
ASSERT_THAT(result, IsOk());
int connect_result =
connecting_socket.Connect(local_address_, connect_callback.callback());
connecting_socket->Connect(local_address_, connect_callback.callback());
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get());
@ -448,7 +452,7 @@ TEST_F(TCPSocketTest, ReadWrite) {
scoped_refptr<IOBufferWithSize> read_buffer =
base::MakeRefCounted<IOBufferWithSize>(message.size() - bytes_read);
TestCompletionCallback read_callback;
int read_result = connecting_socket.Read(
int read_result = connecting_socket->Read(
read_buffer.get(), read_buffer->size(), read_callback.callback());
read_result = read_callback.GetResult(read_result);
ASSERT_TRUE(read_result >= 0);
@ -471,7 +475,7 @@ TEST_F(TCPSocketTest, DestroyWithPendingRead) {
TestCompletionCallback connect_callback;
std::unique_ptr<TCPSocket> connecting_socket =
std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
TCPSocket::Create(nullptr, nullptr, NetLogSource());
int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
ASSERT_THAT(result, IsOk());
int connect_result =
@ -480,8 +484,8 @@ TEST_F(TCPSocketTest, DestroyWithPendingRead) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get());
ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
@ -512,7 +516,7 @@ TEST_F(TCPSocketTest, DestroyWithPendingWrite) {
TestCompletionCallback connect_callback;
std::unique_ptr<TCPSocket> connecting_socket =
std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
TCPSocket::Create(nullptr, nullptr, NetLogSource());
int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
ASSERT_THAT(result, IsOk());
int connect_result =
@ -521,8 +525,8 @@ TEST_F(TCPSocketTest, DestroyWithPendingWrite) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get());
ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
@ -558,7 +562,7 @@ TEST_F(TCPSocketTest, CancelPendingReadIfReady) {
// Create a connected socket.
TestCompletionCallback connect_callback;
std::unique_ptr<TCPSocket> connecting_socket =
std::make_unique<TCPSocket>(nullptr, nullptr, NetLogSource());
TCPSocket::Create(nullptr, nullptr, NetLogSource());
int result = connecting_socket->Open(ADDRESS_FAMILY_IPV4);
ASSERT_THAT(result, IsOk());
int connect_result =
@ -567,8 +571,8 @@ TEST_F(TCPSocketTest, CancelPendingReadIfReady) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
result = socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
result = socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback());
ASSERT_THAT(accept_callback.GetResult(result), IsOk());
ASSERT_TRUE(accepted_socket.get());
ASSERT_THAT(connect_callback.GetResult(connect_result), IsOk());
@ -617,8 +621,8 @@ TEST_F(TCPSocketTest, IsConnected) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
EXPECT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -696,8 +700,8 @@ TEST_F(TCPSocketTest, BeforeConnectCallback) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
EXPECT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -741,8 +745,8 @@ TEST_F(TCPSocketTest, BeforeConnectCallbackFails) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
EXPECT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -769,8 +773,8 @@ TEST_F(TCPSocketTest, SetKeepAlive) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
EXPECT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -801,8 +805,8 @@ TEST_F(TCPSocketTest, SetNoDelay) {
TestCompletionCallback accept_callback;
std::unique_ptr<TCPSocket> accepted_socket;
IPEndPoint accepted_address;
EXPECT_THAT(socket_.Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
EXPECT_THAT(socket_->Accept(&accepted_socket, &accepted_address,
accept_callback.callback()),
IsError(ERR_IO_PENDING));
TestCompletionCallback connect_callback;
@ -856,16 +860,16 @@ TEST_F(TCPSocketTest, Tag) {
AddressList addr_list;
ASSERT_TRUE(test_server.GetAddressList(&addr_list));
EXPECT_EQ(socket_.Open(addr_list[0].GetFamily()), OK);
EXPECT_EQ(socket_->Open(addr_list[0].GetFamily()), OK);
// Verify TCP connect packets are tagged and counted properly.
int32_t tag_val1 = 0x12345678;
uint64_t old_traffic = GetTaggedBytes(tag_val1);
SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
socket_.ApplySocketTag(tag1);
socket_->ApplySocketTag(tag1);
TestCompletionCallback connect_callback;
int connect_result =
socket_.Connect(addr_list[0], connect_callback.callback());
socket_->Connect(addr_list[0], connect_callback.callback());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
@ -874,32 +878,32 @@ TEST_F(TCPSocketTest, Tag) {
int32_t tag_val2 = 0x87654321;
old_traffic = GetTaggedBytes(tag_val2);
SocketTag tag2(getuid(), tag_val2);
socket_.ApplySocketTag(tag2);
socket_->ApplySocketTag(tag2);
const char kRequest1[] = "GET / HTTP/1.0";
scoped_refptr<IOBuffer> write_buffer1 =
base::MakeRefCounted<StringIOBuffer>(kRequest1);
TestCompletionCallback write_callback1;
EXPECT_EQ(
socket_.Write(write_buffer1.get(), strlen(kRequest1),
write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
socket_->Write(write_buffer1.get(), strlen(kRequest1),
write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
static_cast<int>(strlen(kRequest1)));
EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);
// Verify socket can be retagged with a new value and the current process's
// UID.
old_traffic = GetTaggedBytes(tag_val1);
socket_.ApplySocketTag(tag1);
socket_->ApplySocketTag(tag1);
const char kRequest2[] = "\n\n";
scoped_refptr<IOBuffer> write_buffer2 =
base::MakeRefCounted<StringIOBuffer>(kRequest2);
TestCompletionCallback write_callback2;
EXPECT_EQ(
socket_.Write(write_buffer2.get(), strlen(kRequest2),
write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
socket_->Write(write_buffer2.get(), strlen(kRequest2),
write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
static_cast<int>(strlen(kRequest2)));
EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
socket_.Close();
socket_->Close();
}
TEST_F(TCPSocketTest, TagAfterConnect) {
@ -915,12 +919,12 @@ TEST_F(TCPSocketTest, TagAfterConnect) {
AddressList addr_list;
ASSERT_TRUE(test_server.GetAddressList(&addr_list));
EXPECT_EQ(socket_.Open(addr_list[0].GetFamily()), OK);
EXPECT_EQ(socket_->Open(addr_list[0].GetFamily()), OK);
// Connect socket.
TestCompletionCallback connect_callback;
int connect_result =
socket_.Connect(addr_list[0], connect_callback.callback());
socket_->Connect(addr_list[0], connect_callback.callback());
EXPECT_THAT(connect_callback.GetResult(connect_result), IsOk());
// Verify socket can be tagged with a new value and the current process's
@ -928,14 +932,14 @@ TEST_F(TCPSocketTest, TagAfterConnect) {
int32_t tag_val2 = 0x87654321;
uint64_t old_traffic = GetTaggedBytes(tag_val2);
SocketTag tag2(getuid(), tag_val2);
socket_.ApplySocketTag(tag2);
socket_->ApplySocketTag(tag2);
const char kRequest1[] = "GET / HTTP/1.0";
scoped_refptr<IOBuffer> write_buffer1 =
base::MakeRefCounted<StringIOBuffer>(kRequest1);
TestCompletionCallback write_callback1;
EXPECT_EQ(
socket_.Write(write_buffer1.get(), strlen(kRequest1),
write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
socket_->Write(write_buffer1.get(), strlen(kRequest1),
write_callback1.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
static_cast<int>(strlen(kRequest1)));
EXPECT_GT(GetTaggedBytes(tag_val2), old_traffic);
@ -944,18 +948,18 @@ TEST_F(TCPSocketTest, TagAfterConnect) {
int32_t tag_val1 = 0x12345678;
old_traffic = GetTaggedBytes(tag_val1);
SocketTag tag1(SocketTag::UNSET_UID, tag_val1);
socket_.ApplySocketTag(tag1);
socket_->ApplySocketTag(tag1);
const char kRequest2[] = "\n\n";
scoped_refptr<IOBuffer> write_buffer2 =
base::MakeRefCounted<StringIOBuffer>(kRequest2);
TestCompletionCallback write_callback2;
EXPECT_EQ(
socket_.Write(write_buffer2.get(), strlen(kRequest2),
write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
socket_->Write(write_buffer2.get(), strlen(kRequest2),
write_callback2.callback(), TRAFFIC_ANNOTATION_FOR_TESTS),
static_cast<int>(strlen(kRequest2)));
EXPECT_GT(GetTaggedBytes(tag_val1), old_traffic);
socket_.Close();
socket_->Close();
}
TEST_F(TCPSocketTest, BindToNetwork) {

@ -107,17 +107,65 @@ bool SetNonBlockingAndGetError(int fd, int* os_error) {
//-----------------------------------------------------------------------------
// This class encapsulates all the state that has to be preserved as long as
// there is a network IO operation in progress. If the owner TCPSocketWin is
// destroyed while an operation is in progress, the Core is detached and it
// lives until the operation completes and the OS doesn't reference any resource
// declared on this class anymore.
class TCPSocketWin::Core : public base::RefCounted<Core> {
class NET_EXPORT TCPSocketDefaultWin : public TCPSocketWin {
public:
explicit Core(TCPSocketWin* socket);
TCPSocketDefaultWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source);
Core(const Core&) = delete;
Core& operator=(const Core&) = delete;
TCPSocketDefaultWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source);
~TCPSocketDefaultWin() override;
// TCPSocketWin:
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int CancelReadIfReady() override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
protected:
// TCPSocketWin:
scoped_refptr<Core> CreateCore() override;
bool HasPendingRead() const override;
void OnClosed() override;
private:
class CoreImpl;
void RetryRead(int rv);
void DidCompleteWrite();
void DidSignalRead();
CoreImpl& GetCoreImpl();
// External callback; called when read is complete.
CompletionOnceCallback read_callback_;
// Non-null if a ReadIfReady() is to be completed asynchronously. This is an
// external callback if user used ReadIfReady() instead of Read(), but a
// wrapped callback on top of RetryRead() if Read() is used.
CompletionOnceCallback read_if_ready_callback_;
// External callback; called when write is complete.
CompletionOnceCallback write_callback_;
};
class TCPSocketDefaultWin::CoreImpl : public TCPSocketWin::Core {
public:
explicit CoreImpl(TCPSocketDefaultWin* socket);
CoreImpl(const CoreImpl&) = delete;
CoreImpl& operator=(const CoreImpl&) = delete;
// Start watching for the end of a read or write operation.
void WatchForRead();
@ -126,8 +174,10 @@ class TCPSocketWin::Core : public base::RefCounted<Core> {
// Stops watching for read.
void StopWatchingForRead();
// The TCPSocketWin is going away.
void Detach();
// TCPSocketWin::Core:
void Detach() override;
HANDLE GetConnectEvent() override;
void WatchForConnect() override;
// Event handle for monitoring connect and read events through WSAEventSelect.
HANDLE read_event_;
@ -147,36 +197,34 @@ class TCPSocketWin::Core : public base::RefCounted<Core> {
bool non_blocking_reads_initialized_ = false;
private:
friend class base::RefCounted<Core>;
class ReadDelegate : public base::win::ObjectWatcher::Delegate {
public:
explicit ReadDelegate(Core* core) : core_(core) {}
explicit ReadDelegate(CoreImpl* core) : core_(core) {}
~ReadDelegate() override = default;
// base::ObjectWatcher::Delegate methods:
void OnObjectSignaled(HANDLE object) override;
private:
const raw_ptr<Core> core_;
const raw_ptr<CoreImpl> core_;
};
class WriteDelegate : public base::win::ObjectWatcher::Delegate {
public:
explicit WriteDelegate(Core* core) : core_(core) {}
explicit WriteDelegate(CoreImpl* core) : core_(core) {}
~WriteDelegate() override = default;
// base::ObjectWatcher::Delegate methods:
void OnObjectSignaled(HANDLE object) override;
private:
const raw_ptr<Core> core_;
const raw_ptr<CoreImpl> core_;
};
~Core();
~CoreImpl() override;
// The socket that created this object.
raw_ptr<TCPSocketWin> socket_;
raw_ptr<TCPSocketDefaultWin> socket_;
// |reader_| handles the signals from |read_watcher_|.
ReadDelegate reader_;
@ -189,7 +237,10 @@ class TCPSocketWin::Core : public base::RefCounted<Core> {
base::win::ObjectWatcher write_watcher_;
};
TCPSocketWin::Core::Core(TCPSocketWin* socket)
TCPSocketWin::Core::Core() = default;
TCPSocketWin::Core::~Core() = default;
TCPSocketDefaultWin::CoreImpl::CoreImpl(TCPSocketDefaultWin* socket)
: read_event_(WSACreateEvent()),
socket_(socket),
reader_(this),
@ -198,7 +249,7 @@ TCPSocketWin::Core::Core(TCPSocketWin* socket)
write_overlapped_.hEvent = WSACreateEvent();
}
TCPSocketWin::Core::~Core() {
TCPSocketDefaultWin::CoreImpl::~CoreImpl() {
// Detach should already have been called.
DCHECK(!socket_);
@ -209,26 +260,26 @@ TCPSocketWin::Core::~Core() {
memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_));
}
void TCPSocketWin::Core::WatchForRead() {
void TCPSocketDefaultWin::CoreImpl::WatchForRead() {
// Reads use WSAEventSelect, which closesocket() cancels so unlike writes,
// there's no need to increment the reference count here.
read_watcher_.StartWatchingOnce(read_event_, &reader_);
}
void TCPSocketWin::Core::WatchForWrite() {
void TCPSocketDefaultWin::CoreImpl::WatchForWrite() {
// We grab an extra reference because there is an IO operation in progress.
// Balanced in WriteDelegate::OnObjectSignaled().
AddRef();
write_watcher_.StartWatchingOnce(write_overlapped_.hEvent, &writer_);
}
void TCPSocketWin::Core::StopWatchingForRead() {
void TCPSocketDefaultWin::CoreImpl::StopWatchingForRead() {
DCHECK(!socket_->connect_callback_);
read_watcher_.StopWatching();
}
void TCPSocketWin::Core::Detach() {
void TCPSocketDefaultWin::CoreImpl::Detach() {
// Stop watching the read watcher. A read won't be signalled after the Detach
// call, since the socket has been closed, but it's possible the event was
// signalled when the socket was closed, but hasn't been handled yet, so need
@ -240,7 +291,18 @@ void TCPSocketWin::Core::Detach() {
socket_ = nullptr;
}
void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
HANDLE TCPSocketDefaultWin::CoreImpl::GetConnectEvent() {
// `read_event_` is used to watch for connect.
return read_event_;
}
void TCPSocketDefaultWin::CoreImpl::WatchForConnect() {
// `read_event_` is used to watch for connect.
WatchForRead();
}
void TCPSocketDefaultWin::CoreImpl::ReadDelegate::OnObjectSignaled(
HANDLE object) {
DCHECK_EQ(object, core_->read_event_);
DCHECK(core_->socket_);
if (core_->socket_->connect_callback_) {
@ -250,7 +312,7 @@ void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) {
}
}
void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled(
void TCPSocketDefaultWin::CoreImpl::WriteDelegate::OnObjectSignaled(
HANDLE object) {
DCHECK_EQ(object, core_->write_overlapped_.hEvent);
if (core_->socket_)
@ -262,6 +324,23 @@ void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled(
//-----------------------------------------------------------------------------
// static
std::unique_ptr<TCPSocketWin> TCPSocketWin::Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source) {
return std::make_unique<TCPSocketDefaultWin>(
std::move(socket_performance_watcher), net_log, source);
}
// static
std::unique_ptr<TCPSocketWin> TCPSocketWin::Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source) {
return std::make_unique<TCPSocketDefaultWin>(
std::move(socket_performance_watcher), std::move(net_log_source));
}
TCPSocketWin::TCPSocketWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
net::NetLog* net_log,
@ -280,14 +359,17 @@ TCPSocketWin::TCPSocketWin(
: socket_(INVALID_SOCKET),
socket_performance_watcher_(std::move(socket_performance_watcher)),
accept_event_(WSA_INVALID_EVENT),
net_log_(net_log_source) {
net_log_(std::move(net_log_source)) {
net_log_.BeginEvent(NetLogEventType::SOCKET_ALIVE);
EnsureWinsockInit();
}
TCPSocketWin::~TCPSocketWin() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
Close();
// The subclass must call `Close`. See comment in header file.
CHECK(!core_);
net_log_.EndEvent(NetLogEventType::SOCKET_ALIVE);
}
@ -327,7 +409,7 @@ int TCPSocketWin::AdoptConnectedSocket(SocketDescriptor socket,
return result;
}
core_ = base::MakeRefCounted<Core>(this);
core_ = CreateCore();
peer_address_ = std::make_unique<IPEndPoint>(peer_address);
return OK;
@ -460,8 +542,7 @@ bool TCPSocketWin::IsConnected() const {
return false;
}
if (read_if_ready_callback_) {
// Outstanding read on a connected socket.
if (HasPendingRead()) {
return true;
}
@ -489,8 +570,7 @@ bool TCPSocketWin::IsConnectedAndIdle() const {
return false;
}
if (read_if_ready_callback_) {
// Outstanding read on a connected socket.
if (HasPendingRead()) {
return true;
}
@ -511,34 +591,36 @@ bool TCPSocketWin::IsConnectedAndIdle() const {
return true;
}
int TCPSocketWin::Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
int TCPSocketDefaultWin::Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
DCHECK(!core_->read_iobuffer_.get());
CoreImpl& core = GetCoreImpl();
DCHECK(!core.read_iobuffer_.get());
// base::Unretained() is safe because RetryRead() won't be called when |this|
// is gone.
int rv = ReadIfReady(
buf, buf_len,
base::BindOnce(&TCPSocketWin::RetryRead, base::Unretained(this)));
base::BindOnce(&TCPSocketDefaultWin::RetryRead, base::Unretained(this)));
if (rv != ERR_IO_PENDING)
return rv;
read_callback_ = std::move(callback);
core_->read_iobuffer_ = buf;
core_->read_buffer_length_ = buf_len;
core.read_iobuffer_ = buf;
core.read_buffer_length_ = buf_len;
return ERR_IO_PENDING;
}
int TCPSocketWin::ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
int TCPSocketDefaultWin::ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
DCHECK_NE(socket_, INVALID_SOCKET);
DCHECK(read_if_ready_callback_.is_null());
if (!core_->non_blocking_reads_initialized_) {
WSAEventSelect(socket_, core_->read_event_, FD_READ | FD_CLOSE);
core_->non_blocking_reads_initialized_ = true;
CoreImpl& core = GetCoreImpl();
if (!core.non_blocking_reads_initialized_) {
WSAEventSelect(socket_, core.read_event_, FD_READ | FD_CLOSE);
core.non_blocking_reads_initialized_ = true;
}
int rv = recv(socket_, buf->data(), buf_len, 0);
int os_error = WSAGetLastError();
@ -557,20 +639,20 @@ int TCPSocketWin::ReadIfReady(IOBuffer* buf,
}
read_if_ready_callback_ = std::move(callback);
core_->WatchForRead();
core.WatchForRead();
return ERR_IO_PENDING;
}
int TCPSocketWin::CancelReadIfReady() {
int TCPSocketDefaultWin::CancelReadIfReady() {
DCHECK(read_callback_.is_null());
DCHECK(!read_if_ready_callback_.is_null());
core_->StopWatchingForRead();
GetCoreImpl().StopWatchingForRead();
read_if_ready_callback_.Reset();
return net::OK;
}
int TCPSocketWin::Write(
int TCPSocketDefaultWin::Write(
IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
@ -579,18 +661,19 @@ int TCPSocketWin::Write(
DCHECK_NE(socket_, INVALID_SOCKET);
CHECK(write_callback_.is_null());
DCHECK_GT(buf_len, 0);
DCHECK(!core_->write_iobuffer_.get());
CoreImpl& core = GetCoreImpl();
DCHECK(!core.write_iobuffer_.get());
WSABUF write_buffer;
write_buffer.len = buf_len;
write_buffer.buf = buf->data();
DWORD num;
int rv = WSASend(socket_, &write_buffer, 1, &num, 0,
&core_->write_overlapped_, nullptr);
int rv = WSASend(socket_, &write_buffer, 1, &num, 0, &core.write_overlapped_,
nullptr);
int os_error = WSAGetLastError();
if (rv == 0) {
if (ResetEventIfSignaled(core_->write_overlapped_.hEvent)) {
if (ResetEventIfSignaled(core.write_overlapped_.hEvent)) {
rv = static_cast<int>(num);
if (rv > buf_len || rv < 0) {
// It seems that some winsock interceptors report that more was written
@ -612,9 +695,9 @@ int TCPSocketWin::Write(
}
}
write_callback_ = std::move(callback);
core_->write_iobuffer_ = buf;
core_->write_buffer_length_ = buf_len;
core_->WatchForWrite();
core.write_iobuffer_ = buf;
core.write_buffer_length_ = buf_len;
core.WatchForWrite();
return ERR_IO_PENDING;
}
@ -748,9 +831,9 @@ void TCPSocketWin::Close() {
// when the socket is closed. This is not the case for reads.
}
read_callback_.Reset();
read_if_ready_callback_.Reset();
write_callback_.Reset();
connect_callback_.Reset();
OnClosed();
peer_address_.reset();
connect_os_error_ = 0;
}
@ -810,8 +893,8 @@ int TCPSocketWin::AcceptInternal(std::unique_ptr<TCPSocketWin>* socket,
net_log_.EndEventWithNetErrorCode(NetLogEventType::TCP_ACCEPT, net_error);
return net_error;
}
auto tcp_socket = std::make_unique<TCPSocketWin>(nullptr, net_log_.net_log(),
net_log_.source());
auto tcp_socket =
TCPSocketWin::Create(nullptr, net_log_.net_log(), net_log_.source());
int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point);
if (adopt_result != OK) {
net_log_.EndEventWithNetErrorCode(NetLogEventType::TCP_ACCEPT,
@ -859,11 +942,11 @@ int TCPSocketWin::DoConnect() {
return CreateNetLogIPEndPointParams(peer_address_.get());
});
core_ = base::MakeRefCounted<Core>(this);
core_ = CreateCore();
// WSAEventSelect sets the socket to non-blocking mode as a side effect.
// Our connect() and recv() calls require that the socket be non-blocking.
WSAEventSelect(socket_, core_->read_event_, FD_CONNECT);
WSAEventSelect(socket_, core_->GetConnectEvent(), FD_CONNECT);
SockaddrStorage storage;
if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len))
@ -892,8 +975,9 @@ int TCPSocketWin::DoConnect() {
// and we don't know if it's correct.
NOTREACHED_IN_MIGRATION();
if (ResetEventIfSignaled(core_->read_event_))
if (ResetEventIfSignaled(core_->GetConnectEvent())) {
return OK;
}
} else {
int os_error = WSAGetLastError();
if (os_error != WSAEWOULDBLOCK) {
@ -905,7 +989,7 @@ int TCPSocketWin::DoConnect() {
}
}
core_->WatchForRead();
core_->WatchForConnect();
return ERR_IO_PENDING;
}
@ -947,20 +1031,21 @@ void TCPSocketWin::LogConnectEnd(int net_error) {
});
}
void TCPSocketWin::RetryRead(int rv) {
DCHECK(core_->read_iobuffer_);
void TCPSocketDefaultWin::RetryRead(int rv) {
CoreImpl& core = GetCoreImpl();
DCHECK(core.read_iobuffer_);
if (rv == OK) {
// base::Unretained() is safe because RetryRead() won't be called when
// |this| is gone.
rv = ReadIfReady(
core_->read_iobuffer_.get(), core_->read_buffer_length_,
base::BindOnce(&TCPSocketWin::RetryRead, base::Unretained(this)));
rv = ReadIfReady(core.read_iobuffer_.get(), core.read_buffer_length_,
base::BindOnce(&TCPSocketDefaultWin::RetryRead,
base::Unretained(this)));
if (rv == ERR_IO_PENDING)
return;
}
core_->read_iobuffer_ = nullptr;
core_->read_buffer_length_ = 0;
core.read_iobuffer_ = nullptr;
core.read_buffer_length_ = 0;
std::move(read_callback_).Run(rv);
}
@ -969,7 +1054,7 @@ void TCPSocketWin::DidCompleteConnect() {
int result;
WSANETWORKEVENTS events;
int rv = WSAEnumNetworkEvents(socket_, core_->read_event_, &events);
int rv = WSAEnumNetworkEvents(socket_, core_->GetConnectEvent(), &events);
int os_error = WSAGetLastError();
if (rv == SOCKET_ERROR) {
DLOG(FATAL)
@ -991,14 +1076,15 @@ void TCPSocketWin::DidCompleteConnect() {
std::move(connect_callback_).Run(result);
}
void TCPSocketWin::DidCompleteWrite() {
void TCPSocketDefaultWin::DidCompleteWrite() {
DCHECK(!write_callback_.is_null());
CoreImpl& core = GetCoreImpl();
DWORD num_bytes, flags;
BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_,
&num_bytes, FALSE, &flags);
BOOL ok = WSAGetOverlappedResult(socket_, &core.write_overlapped_, &num_bytes,
FALSE, &flags);
int os_error = WSAGetLastError();
WSAResetEvent(core_->write_overlapped_.hEvent);
WSAResetEvent(core.write_overlapped_.hEvent);
int rv;
if (!ok) {
rv = MapSystemError(os_error);
@ -1006,31 +1092,32 @@ void TCPSocketWin::DidCompleteWrite() {
os_error);
} else {
rv = static_cast<int>(num_bytes);
if (rv > core_->write_buffer_length_ || rv < 0) {
if (rv > core.write_buffer_length_ || rv < 0) {
// It seems that some winsock interceptors report that more was written
// than was available. Treat this as an error. http://crbug.com/27870
LOG(ERROR) << "Detected broken LSP: Asked to write "
<< core_->write_buffer_length_ << " bytes, but " << rv
<< core.write_buffer_length_ << " bytes, but " << rv
<< " bytes reported.";
rv = ERR_WINSOCK_UNEXPECTED_WRITTEN_BYTES;
} else {
net_log_.AddByteTransferEvent(NetLogEventType::SOCKET_BYTES_SENT,
num_bytes, core_->write_iobuffer_->data());
num_bytes, core.write_iobuffer_->data());
}
}
core_->write_iobuffer_ = nullptr;
core.write_iobuffer_ = nullptr;
DCHECK_NE(rv, ERR_IO_PENDING);
std::move(write_callback_).Run(rv);
}
void TCPSocketWin::DidSignalRead() {
void TCPSocketDefaultWin::DidSignalRead() {
DCHECK(!read_if_ready_callback_.is_null());
CoreImpl& core = GetCoreImpl();
int os_error = 0;
WSANETWORKEVENTS network_events;
int rv = WSAEnumNetworkEvents(socket_, core_->read_event_, &network_events);
int rv = WSAEnumNetworkEvents(socket_, core.read_event_, &network_events);
os_error = WSAGetLastError();
if (rv == SOCKET_ERROR) {
@ -1055,7 +1142,7 @@ void TCPSocketWin::DidSignalRead() {
} else {
// This may happen because Read() may succeed synchronously and
// consume all the received data without resetting the event object.
core_->WatchForRead();
core.WatchForRead();
return;
}
@ -1081,4 +1168,40 @@ int TCPSocketWin::BindToNetwork(handles::NetworkHandle network) {
return ERR_NOT_IMPLEMENTED;
}
TCPSocketDefaultWin::TCPSocketDefaultWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source)
: TCPSocketWin(std::move(socket_performance_watcher), net_log, source) {}
TCPSocketDefaultWin::TCPSocketDefaultWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source)
: TCPSocketWin(std::move(socket_performance_watcher),
std::move(net_log_source)) {}
TCPSocketDefaultWin::~TCPSocketDefaultWin() {
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
Close();
}
TCPSocketDefaultWin::CoreImpl& TCPSocketDefaultWin::GetCoreImpl() {
return CHECK_DEREF(static_cast<CoreImpl*>(core_.get()));
}
scoped_refptr<TCPSocketWin::Core> TCPSocketDefaultWin::CreateCore() {
return base::MakeRefCounted<CoreImpl>(this);
}
bool TCPSocketDefaultWin::HasPendingRead() const {
CHECK(!read_callback_ || read_if_ready_callback_);
return !read_if_ready_callback_.is_null();
}
void TCPSocketDefaultWin::OnClosed() {
read_callback_.Reset();
read_if_ready_callback_.Reset();
write_callback_.Reset();
}
} // namespace net

@ -11,6 +11,7 @@
#include <memory>
#include "base/check_deref.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/threading/thread_checker.h"
@ -26,6 +27,8 @@
namespace net {
class TCPSocketDefaultWin;
class AddressList;
class IOBuffer;
class IPEndPoint;
@ -35,18 +38,20 @@ class SocketTag;
class NET_EXPORT TCPSocketWin : public base::win::ObjectWatcher::Delegate {
public:
TCPSocketWin(
static std::unique_ptr<TCPSocketWin> Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source);
TCPSocketWin(
static std::unique_ptr<TCPSocketWin> Create(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source);
TCPSocketWin(const TCPSocketWin&) = delete;
TCPSocketWin& operator=(const TCPSocketWin&) = delete;
// IMPORTANT: All subclasses must call `Close`. The base class cannot do it
// because `Close` invokes virtual methods, but it CHECKs that the socket is
// closed.
~TCPSocketWin() override;
int Open(AddressFamily family);
@ -75,13 +80,17 @@ class NET_EXPORT TCPSocketWin : public base::win::ObjectWatcher::Delegate {
// Multiple outstanding requests are not supported.
// Full duplex mode (reading and writing at the same time) is supported.
int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback);
int ReadIfReady(IOBuffer* buf, int buf_len, CompletionOnceCallback callback);
int CancelReadIfReady();
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation);
virtual int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) = 0;
virtual int ReadIfReady(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) = 0;
virtual int CancelReadIfReady() = 0;
virtual int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) = 0;
int GetLocalAddress(IPEndPoint* address) const;
int GetPeerAddress(IPEndPoint* address) const;
@ -149,8 +158,54 @@ class NET_EXPORT TCPSocketWin : public base::win::ObjectWatcher::Delegate {
return socket_performance_watcher_.get();
}
private:
class Core;
protected:
friend class TCPSocketDefaultWin;
// Encapsulates state that must be preserved while network IO operations are
// in progress. If the owning TCPSocketWin is destroyed while an operation is
// in progress, the Core is detached and lives until the operation completes
// and the OS doesn't reference any resource owned by it.
class Core : public base::RefCounted<Core> {
public:
Core(const Core&) = delete;
Core& operator=(const Core&) = delete;
// Invoked when the socket is closed. Clears any reference from the `Core`
// to its parent socket.
virtual void Detach() = 0;
// Returns the event to use for watching the completion of a connect()
// operation.
virtual HANDLE GetConnectEvent() = 0;
// Must be invoked after initiating a connect() operation. Will invoke
// `DidCompleteConnect()` when the connect() operation is complete.
virtual void WatchForConnect() = 0;
protected:
friend class base::RefCounted<Core>;
Core();
virtual ~Core();
};
TCPSocketWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLog* net_log,
const NetLogSource& source);
TCPSocketWin(
std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
NetLogWithSource net_log_source);
// Instantiates a `Core` object for this socket.
virtual scoped_refptr<Core> CreateCore() = 0;
// Whether there is a pending read operation on this socket.
virtual bool HasPendingRead() const = 0;
// Invoked when the socket is closed.
virtual void OnClosed() = 0;
// base::ObjectWatcher::Delegate implementation.
void OnObjectSignaled(HANDLE object) override;
@ -164,10 +219,7 @@ class NET_EXPORT TCPSocketWin : public base::win::ObjectWatcher::Delegate {
void LogConnectBegin(const AddressList& addresses);
void LogConnectEnd(int net_error);
void RetryRead(int rv);
void DidCompleteConnect();
void DidCompleteWrite();
void DidSignalRead();
SOCKET socket_;
@ -189,17 +241,6 @@ class NET_EXPORT TCPSocketWin : public base::win::ObjectWatcher::Delegate {
// Callback invoked when connect is complete.
CompletionOnceCallback connect_callback_;
// External callback; called when connect or read is complete.
CompletionOnceCallback read_callback_;
// Non-null if a ReadIfReady() is to be completed asynchronously. This is an
// external callback if user used ReadIfReady() instead of Read(), but a
// wrapped callback on top of RetryRead() if Read() is used.
CompletionOnceCallback read_if_ready_callback_;
// External callback; called when write is complete.
CompletionOnceCallback write_callback_;
std::unique_ptr<IPEndPoint> peer_address_;
// The OS error that a connect attempt last completed with.
int connect_os_error_ = 0;

@ -128,7 +128,7 @@ void BrokeredTcpClientSocket ::DidCompleteCreate(
// Create an unconnected TCPSocket with the socket fd that was opened in the
// browser process.
std::unique_ptr<net::TCPSocket> tcp_socket = std::make_unique<net::TCPSocket>(
std::unique_ptr<net::TCPSocket> tcp_socket = net::TCPSocket::Create(
std::move(socket_performance_watcher_), net_log_source_);
tcp_socket->AdoptUnconnectedSocket(socket.TakeSocket());

@ -29,9 +29,10 @@ TEST_F(TransferableSocketTest, MojoTraits) {
#if BUILDFLAG(IS_WIN)
net::EnsureWinsockInit();
#endif
net::TCPSocket socket(nullptr, nullptr, net::NetLogSource());
socket.Open(net::AddressFamily::ADDRESS_FAMILY_IPV4);
auto socket_desc = socket.ReleaseSocketDescriptorForTesting();
std::unique_ptr<net::TCPSocket> socket =
net::TCPSocket::Create(nullptr, nullptr, net::NetLogSource());
socket->Open(net::AddressFamily::ADDRESS_FAMILY_IPV4);
auto socket_desc = socket->ReleaseSocketDescriptorForTesting();
TransferableSocket transferable(socket_desc
#if BUILDFLAG(IS_WIN)
,

@ -27,10 +27,9 @@ TCPBoundSocket::TCPBoundSocket(
net::NetLog* net_log,
const net::NetworkTrafficAnnotationTag& traffic_annotation)
: socket_factory_(socket_factory),
socket_(std::make_unique<net::TCPSocket>(
nullptr /*socket_performance_watcher*/,
net_log,
net::NetLogSource())),
socket_(net::TCPSocket::Create(nullptr /*socket_performance_watcher*/,
net_log,
net::NetLogSource())),
traffic_annotation_(traffic_annotation) {}
TCPBoundSocket::~TCPBoundSocket() = default;