
This is a preparation to fix each files. This CL has no behavior changes. This patch was fully automated using script: https://paste.googleplex.com/5614491201175552 See internal doc about it: https://docs.google.com/document/d/1erdcokeh6rfBqs_h0drHqSLtbDbB61j7j3O2Pz8NH78/edit?resourcekey=0-hNe6w1hYAYyVXGEpWI7HVA&tab=t.0 Bug: 40285824 Change-Id: Ibfa3bf8ef4117b3c6c73d64c1295368ecdf16c18 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5759698 Reviewed-by: Joe Downing <joedow@chromium.org> Commit-Queue: Joe Downing <joedow@chromium.org> Auto-Submit: Arthur Sonzogni <arthursonzogni@chromium.org> Cr-Commit-Position: refs/heads/main@{#1344112}
491 lines
15 KiB
C++
491 lines
15 KiB
C++
// Copyright 2012 The Chromium Authors
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#ifdef UNSAFE_BUFFERS_BUILD
|
|
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
|
|
#pragma allow_unsafe_buffers
|
|
#endif
|
|
|
|
#include "remoting/protocol/channel_multiplexer.h"
|
|
|
|
#include <stddef.h>
|
|
#include <string.h>
|
|
|
|
#include <utility>
|
|
|
|
#include "base/functional/bind.h"
|
|
#include "base/functional/callback.h"
|
|
#include "base/functional/callback_helpers.h"
|
|
#include "base/location.h"
|
|
#include "base/memory/raw_ptr.h"
|
|
#include "base/sequence_checker.h"
|
|
#include "base/task/single_thread_task_runner.h"
|
|
#include "net/base/net_errors.h"
|
|
#include "remoting/protocol/message_serialization.h"
|
|
#include "remoting/protocol/p2p_stream_socket.h"
|
|
|
|
namespace remoting::protocol {
|
|
|
|
namespace {
|
|
const int kChannelIdUnknown = -1;
|
|
const int kMaxPacketSize = 1024;
|
|
|
|
class PendingPacket {
|
|
public:
|
|
PendingPacket(std::unique_ptr<MultiplexPacket> packet)
|
|
: packet(std::move(packet)) {}
|
|
|
|
PendingPacket(const PendingPacket&) = delete;
|
|
PendingPacket& operator=(const PendingPacket&) = delete;
|
|
|
|
~PendingPacket() = default;
|
|
|
|
bool is_empty() { return pos >= packet->data().size(); }
|
|
|
|
int Read(char* buffer, size_t size) {
|
|
size = std::min(size, packet->data().size() - pos);
|
|
memcpy(buffer, packet->data().data() + pos, size);
|
|
pos += size;
|
|
return size;
|
|
}
|
|
|
|
private:
|
|
std::unique_ptr<MultiplexPacket> packet;
|
|
size_t pos = 0U;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
const char ChannelMultiplexer::kMuxChannelName[] = "mux";
|
|
|
|
struct ChannelMultiplexer::PendingChannel {
|
|
PendingChannel(const std::string& name, ChannelCreatedCallback callback)
|
|
: name(name), callback(std::move(callback)) {}
|
|
std::string name;
|
|
ChannelCreatedCallback callback;
|
|
};
|
|
|
|
class ChannelMultiplexer::MuxChannel {
|
|
public:
|
|
MuxChannel(ChannelMultiplexer* multiplexer,
|
|
const std::string& name,
|
|
int send_id);
|
|
|
|
MuxChannel(const MuxChannel&) = delete;
|
|
MuxChannel& operator=(const MuxChannel&) = delete;
|
|
|
|
~MuxChannel();
|
|
|
|
const std::string& name() { return name_; }
|
|
int receive_id() { return receive_id_; }
|
|
void set_receive_id(int id) { receive_id_ = id; }
|
|
|
|
// Called by ChannelMultiplexer.
|
|
std::unique_ptr<P2PStreamSocket> CreateSocket();
|
|
void OnIncomingPacket(std::unique_ptr<MultiplexPacket> packet);
|
|
void OnBaseChannelError(int error);
|
|
|
|
// Called by MuxSocket.
|
|
void OnSocketDestroyed();
|
|
void DoWrite(std::unique_ptr<MultiplexPacket> packet,
|
|
base::OnceClosure done_task,
|
|
const net::NetworkTrafficAnnotationTag& traffic_annotation);
|
|
int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);
|
|
|
|
private:
|
|
raw_ptr<ChannelMultiplexer> multiplexer_;
|
|
std::string name_;
|
|
int send_id_;
|
|
bool id_sent_;
|
|
int receive_id_;
|
|
raw_ptr<MuxSocket> socket_;
|
|
std::list<std::unique_ptr<PendingPacket>> pending_packets_;
|
|
};
|
|
|
|
class ChannelMultiplexer::MuxSocket : public P2PStreamSocket {
|
|
public:
|
|
MuxSocket(MuxChannel* channel);
|
|
|
|
MuxSocket(const MuxSocket&) = delete;
|
|
MuxSocket& operator=(const MuxSocket&) = delete;
|
|
|
|
~MuxSocket() override;
|
|
|
|
void OnWriteComplete();
|
|
void OnBaseChannelError(int error);
|
|
void OnPacketReceived();
|
|
|
|
// P2PStreamSocket interface.
|
|
int Read(const scoped_refptr<net::IOBuffer>& buffer,
|
|
int buffer_len,
|
|
net::CompletionOnceCallback callback) override;
|
|
int Write(
|
|
const scoped_refptr<net::IOBuffer>& buffer,
|
|
int buffer_len,
|
|
net::CompletionOnceCallback callback,
|
|
const net::NetworkTrafficAnnotationTag& traffic_annotation) override;
|
|
|
|
private:
|
|
raw_ptr<MuxChannel> channel_;
|
|
|
|
int base_channel_error_ = net::OK;
|
|
|
|
net::CompletionOnceCallback read_callback_;
|
|
scoped_refptr<net::IOBuffer> read_buffer_;
|
|
int read_buffer_size_;
|
|
|
|
bool write_pending_;
|
|
int write_result_;
|
|
net::CompletionOnceCallback write_callback_;
|
|
|
|
SEQUENCE_CHECKER(sequence_checker_);
|
|
|
|
base::WeakPtrFactory<MuxSocket> weak_factory_{this};
|
|
};
|
|
|
|
ChannelMultiplexer::MuxChannel::MuxChannel(ChannelMultiplexer* multiplexer,
|
|
const std::string& name,
|
|
int send_id)
|
|
: multiplexer_(multiplexer),
|
|
name_(name),
|
|
send_id_(send_id),
|
|
id_sent_(false),
|
|
receive_id_(kChannelIdUnknown),
|
|
socket_(nullptr) {}
|
|
|
|
ChannelMultiplexer::MuxChannel::~MuxChannel() {
|
|
// Socket must be destroyed before the channel.
|
|
DCHECK(!socket_);
|
|
}
|
|
|
|
std::unique_ptr<P2PStreamSocket>
|
|
ChannelMultiplexer::MuxChannel::CreateSocket() {
|
|
DCHECK(!socket_); // Can't create more than one socket per channel.
|
|
std::unique_ptr<MuxSocket> result(new MuxSocket(this));
|
|
socket_ = result.get();
|
|
return std::move(result);
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
|
|
std::unique_ptr<MultiplexPacket> packet) {
|
|
DCHECK_EQ(packet->channel_id(), receive_id_);
|
|
if (packet->data().size() > 0) {
|
|
pending_packets_.push_back(
|
|
std::make_unique<PendingPacket>(std::move(packet)));
|
|
if (socket_) {
|
|
// Notify the socket that we have more data.
|
|
socket_->OnPacketReceived();
|
|
}
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
|
|
if (socket_) {
|
|
socket_->OnBaseChannelError(error);
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
|
|
DCHECK(socket_);
|
|
socket_ = nullptr;
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxChannel::DoWrite(
|
|
std::unique_ptr<MultiplexPacket> packet,
|
|
base::OnceClosure done_task,
|
|
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
|
|
packet->set_channel_id(send_id_);
|
|
if (!id_sent_) {
|
|
packet->set_channel_name(name_);
|
|
id_sent_ = true;
|
|
}
|
|
multiplexer_->DoWrite(std::move(packet), std::move(done_task),
|
|
traffic_annotation);
|
|
}
|
|
|
|
int ChannelMultiplexer::MuxChannel::DoRead(
|
|
const scoped_refptr<net::IOBuffer>& buffer,
|
|
int buffer_len) {
|
|
int pos = 0;
|
|
while (buffer_len > 0 && !pending_packets_.empty()) {
|
|
DCHECK(!pending_packets_.front()->is_empty());
|
|
int result =
|
|
pending_packets_.front()->Read(buffer->data() + pos, buffer_len);
|
|
DCHECK_LE(result, buffer_len);
|
|
pos += result;
|
|
buffer_len -= pos;
|
|
if (pending_packets_.front()->is_empty()) {
|
|
pending_packets_.pop_front();
|
|
}
|
|
}
|
|
return pos;
|
|
}
|
|
|
|
ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
|
|
: channel_(channel),
|
|
read_buffer_size_(0),
|
|
write_pending_(false),
|
|
write_result_(0) {}
|
|
|
|
ChannelMultiplexer::MuxSocket::~MuxSocket() {
|
|
channel_->OnSocketDestroyed();
|
|
}
|
|
|
|
int ChannelMultiplexer::MuxSocket::Read(
|
|
const scoped_refptr<net::IOBuffer>& buffer,
|
|
int buffer_len,
|
|
net::CompletionOnceCallback callback) {
|
|
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
|
DCHECK(read_callback_.is_null());
|
|
|
|
if (base_channel_error_ != net::OK) {
|
|
return base_channel_error_;
|
|
}
|
|
|
|
int result = channel_->DoRead(buffer, buffer_len);
|
|
if (result == 0) {
|
|
read_buffer_ = buffer;
|
|
read_buffer_size_ = buffer_len;
|
|
read_callback_ = std::move(callback);
|
|
return net::ERR_IO_PENDING;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
int ChannelMultiplexer::MuxSocket::Write(
|
|
const scoped_refptr<net::IOBuffer>& buffer,
|
|
int buffer_len,
|
|
net::CompletionOnceCallback callback,
|
|
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
|
|
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
|
DCHECK(write_callback_.is_null());
|
|
|
|
if (base_channel_error_ != net::OK) {
|
|
return base_channel_error_;
|
|
}
|
|
|
|
std::unique_ptr<MultiplexPacket> packet(new MultiplexPacket());
|
|
size_t size = std::min(kMaxPacketSize, buffer_len);
|
|
packet->mutable_data()->assign(buffer->data(), size);
|
|
|
|
write_pending_ = true;
|
|
channel_->DoWrite(
|
|
std::move(packet),
|
|
base::BindOnce(&ChannelMultiplexer::MuxSocket::OnWriteComplete,
|
|
weak_factory_.GetWeakPtr()),
|
|
traffic_annotation);
|
|
|
|
// OnWriteComplete() might be called above synchronously.
|
|
if (write_pending_) {
|
|
DCHECK(write_callback_.is_null());
|
|
write_callback_ = std::move(callback);
|
|
write_result_ = size;
|
|
return net::ERR_IO_PENDING;
|
|
}
|
|
|
|
return size;
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
|
|
write_pending_ = false;
|
|
if (!write_callback_.is_null()) {
|
|
std::move(write_callback_).Run(write_result_);
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
|
|
base_channel_error_ = error;
|
|
|
|
// Here only one of the read and write callbacks is called if both of them are
|
|
// pending. Ideally both of them should be called in that case, but that would
|
|
// require the second one to be called asynchronously which would complicate
|
|
// this code. Channels handle read and write errors the same way (see
|
|
// ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
|
|
// callbacks is enough.
|
|
|
|
if (!read_callback_.is_null()) {
|
|
std::move(read_callback_).Run(error);
|
|
return;
|
|
}
|
|
|
|
if (!write_callback_.is_null()) {
|
|
std::move(write_callback_).Run(error);
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
|
|
if (!read_callback_.is_null()) {
|
|
int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
|
|
read_buffer_ = nullptr;
|
|
DCHECK_GT(result, 0);
|
|
std::move(read_callback_).Run(result);
|
|
}
|
|
}
|
|
|
|
ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
|
|
const std::string& base_channel_name)
|
|
: base_channel_factory_(factory),
|
|
base_channel_name_(base_channel_name),
|
|
next_channel_id_(0) {}
|
|
|
|
ChannelMultiplexer::~ChannelMultiplexer() {
|
|
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
|
|
DCHECK(pending_channels_.empty());
|
|
|
|
// Cancel creation of the base channel if it hasn't finished.
|
|
if (base_channel_factory_) {
|
|
base_channel_factory_->CancelChannelCreation(base_channel_name_);
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::CreateChannel(const std::string& name,
|
|
ChannelCreatedCallback callback) {
|
|
if (base_channel_.get()) {
|
|
// Already have |base_channel_|. Create new multiplexed channel
|
|
// synchronously.
|
|
std::move(callback).Run(GetOrCreateChannel(name)->CreateSocket());
|
|
} else if (!base_channel_.get() && !base_channel_factory_) {
|
|
// Fail synchronously if we failed to create |base_channel_|.
|
|
std::move(callback).Run(nullptr);
|
|
} else {
|
|
// Still waiting for the |base_channel_|.
|
|
pending_channels_.emplace_back(name, std::move(callback));
|
|
|
|
// If this is the first multiplexed channel then create the base channel.
|
|
if (pending_channels_.size() == 1U) {
|
|
base_channel_factory_->CreateChannel(
|
|
base_channel_name_,
|
|
base::BindOnce(&ChannelMultiplexer::OnBaseChannelReady,
|
|
base::Unretained(this)));
|
|
}
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
|
|
for (auto it = pending_channels_.begin(); it != pending_channels_.end();
|
|
++it) {
|
|
if (it->name == name) {
|
|
pending_channels_.erase(it);
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::OnBaseChannelReady(
|
|
std::unique_ptr<P2PStreamSocket> socket) {
|
|
base_channel_factory_ = nullptr;
|
|
base_channel_ = std::move(socket);
|
|
|
|
if (base_channel_.get()) {
|
|
// Initialize reader and writer.
|
|
reader_.StartReading(
|
|
base_channel_.get(),
|
|
base::BindRepeating(&ChannelMultiplexer::OnIncomingPacket,
|
|
base::Unretained(this)),
|
|
base::BindOnce(&ChannelMultiplexer::OnBaseChannelError,
|
|
base::Unretained(this)));
|
|
writer_.Start(base::BindRepeating(&P2PStreamSocket::Write,
|
|
base::Unretained(base_channel_.get())),
|
|
base::BindOnce(&ChannelMultiplexer::OnBaseChannelError,
|
|
base::Unretained(this)));
|
|
}
|
|
|
|
DoCreatePendingChannels();
|
|
}
|
|
|
|
void ChannelMultiplexer::DoCreatePendingChannels() {
|
|
if (pending_channels_.empty()) {
|
|
return;
|
|
}
|
|
|
|
// Every time this function is called it connects a single channel and posts a
|
|
// separate task to connect other channels. This is necessary because the
|
|
// callback may destroy the multiplexer or somehow else modify
|
|
// |pending_channels_| list (e.g. call CancelChannelCreation()).
|
|
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
|
|
FROM_HERE, base::BindOnce(&ChannelMultiplexer::DoCreatePendingChannels,
|
|
weak_factory_.GetWeakPtr()));
|
|
|
|
PendingChannel c = std::move(pending_channels_.front());
|
|
pending_channels_.erase(pending_channels_.begin());
|
|
std::unique_ptr<P2PStreamSocket> socket;
|
|
if (base_channel_.get()) {
|
|
socket = GetOrCreateChannel(c.name)->CreateSocket();
|
|
}
|
|
std::move(c.callback).Run(std::move(socket));
|
|
}
|
|
|
|
ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
|
|
const std::string& name) {
|
|
std::unique_ptr<MuxChannel>& channel = channels_[name];
|
|
if (!channel) {
|
|
// Create a new channel if we haven't found existing one.
|
|
channel = std::make_unique<MuxChannel>(this, name, next_channel_id_);
|
|
++next_channel_id_;
|
|
}
|
|
|
|
return channel.get();
|
|
}
|
|
|
|
void ChannelMultiplexer::OnBaseChannelError(int error) {
|
|
for (auto it = channels_.begin(); it != channels_.end(); ++it) {
|
|
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
|
|
FROM_HERE,
|
|
base::BindOnce(&ChannelMultiplexer::NotifyBaseChannelError,
|
|
weak_factory_.GetWeakPtr(), it->second->name(), error));
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
|
|
int error) {
|
|
auto it = channels_.find(name);
|
|
if (it != channels_.end()) {
|
|
it->second->OnBaseChannelError(error);
|
|
}
|
|
}
|
|
|
|
void ChannelMultiplexer::OnIncomingPacket(
|
|
std::unique_ptr<CompoundBuffer> buffer) {
|
|
std::unique_ptr<MultiplexPacket> packet =
|
|
ParseMessage<MultiplexPacket>(buffer.get());
|
|
if (!packet) {
|
|
return;
|
|
}
|
|
|
|
DCHECK(packet->has_channel_id());
|
|
if (!packet->has_channel_id()) {
|
|
LOG(ERROR) << "Received packet without channel_id.";
|
|
return;
|
|
}
|
|
|
|
int receive_id = packet->channel_id();
|
|
MuxChannel* channel = nullptr;
|
|
auto it = channels_by_receive_id_.find(receive_id);
|
|
if (it != channels_by_receive_id_.end()) {
|
|
channel = it->second;
|
|
} else {
|
|
// This is a new |channel_id| we haven't seen before. Look it up by name.
|
|
if (!packet->has_channel_name()) {
|
|
LOG(ERROR) << "Received packet with unknown channel_id and "
|
|
"without channel_name.";
|
|
return;
|
|
}
|
|
channel = GetOrCreateChannel(packet->channel_name());
|
|
channel->set_receive_id(receive_id);
|
|
channels_by_receive_id_[receive_id] = channel;
|
|
}
|
|
|
|
channel->OnIncomingPacket(std::move(packet));
|
|
}
|
|
|
|
void ChannelMultiplexer::DoWrite(
|
|
std::unique_ptr<MultiplexPacket> packet,
|
|
base::OnceClosure done_task,
|
|
const net::NetworkTrafficAnnotationTag& traffic_annotation) {
|
|
writer_.Write(SerializeAndFrameMessage(*packet), std::move(done_task),
|
|
traffic_annotation);
|
|
}
|
|
|
|
} // namespace remoting::protocol
|