Skip to content

Commit

Permalink
merge bitcoin#24357: make setsockopt() and SetSocketNoDelay() mockabl…
Browse files Browse the repository at this point in the history
…e/testable
  • Loading branch information
kwvg committed Jun 11, 2024
1 parent 9c751ef commit 6b159f1
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 20 deletions.
21 changes: 17 additions & 4 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,11 @@ void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,

// According to the internet TCP_NODELAY is not carried into accepted sockets
// on all platforms. Set it again here just to be sure.
SetSocketNoDelay(sock->Get());
const int on{1};
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
LogPrint(BCLog::NET, "connection from %s: unable to set TCP_NODELAY, continuing anyway\n",
addr.ToString());
}

// Don't accept connections from banned peers.
bool banned = m_banman && m_banman->IsBanned(addr);
Expand Down Expand Up @@ -3219,17 +3223,26 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,

// Allow binding if the port is still in TIME_WAIT state after
// the program was closed and restarted.
setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
if (sock->SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting SO_REUSEADDR on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}

// some systems don't have IPV6_V6ONLY but are always v6only; others do have the option
// and enable it by default or not. Try to enable it, if possible.
if (addrBind.IsIPv6()) {
#ifdef IPV6_V6ONLY
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting IPV6_V6ONLY on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}
#endif
#ifdef WIN32
int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED;
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
if (sock->SetSockOpt(IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int)) == SOCKET_ERROR) {
strError = strprintf(Untranslated("Error setting IPV6_PROTECTION_LEVEL on socket: %s, continuing anyway"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
}
#endif
}

Expand Down
27 changes: 13 additions & 14 deletions src/netbase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,10 +498,11 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
return nullptr;
}

auto sock = std::make_unique<Sock>(hSocket);

// Ensure that waiting for I/O on this socket won't result in undefined
// behavior.
if (!IsSelectableSocket(hSocket)) {
CloseSocket(hSocket);
if (!IsSelectableSocket(sock->Get())) {
LogPrintf("Cannot create connection: non-selectable socket created (fd >= FD_SETSIZE ?)\n");
return nullptr;
}
Expand All @@ -510,19 +511,24 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family)
int set = 1;
// Set the no-sigpipe option on the socket for BSD systems, other UNIXes
// should use the MSG_NOSIGNAL flag for every send.
setsockopt(hSocket, SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int));
if (sock->SetSockOpt(SOL_SOCKET, SO_NOSIGPIPE, (void*)&set, sizeof(int)) == SOCKET_ERROR) {
LogPrintf("Error setting SO_NOSIGPIPE on socket: %s, continuing anyway\n",
NetworkErrorString(WSAGetLastError()));
}
#endif

// Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
SetSocketNoDelay(hSocket);
const int on{1};
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
}

// Set the non-blocking option on the socket.
if (!SetSocketNonBlocking(hSocket)) {
CloseSocket(hSocket);
if (!SetSocketNonBlocking(sock->Get())) {
LogPrintf("Error setting socket to non-blocking: %s\n", NetworkErrorString(WSAGetLastError()));
return nullptr;
}
return std::make_unique<Sock>(hSocket);
return sock;
}

std::function<std::unique_ptr<Sock>(const CService&)> CreateSock = CreateSockTCP;
Expand Down Expand Up @@ -729,13 +735,6 @@ bool SetSocketNonBlocking(const SOCKET& hSocket)
return true;
}

bool SetSocketNoDelay(const SOCKET& hSocket)
{
int set = 1;
int rc = setsockopt(hSocket, IPPROTO_TCP, TCP_NODELAY, (const char*)&set, sizeof(int));
return rc == 0;
}

void InterruptSocks5(bool interrupt)
{
interruptSocks5Recv = interrupt;
Expand Down
2 changes: 0 additions & 2 deletions src/netbase.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,6 @@ bool ConnectThroughProxy(const Proxy& proxy, const std::string& strDest, uint16_

/** Enable non-blocking mode for a socket */
bool SetSocketNonBlocking(const SOCKET& hSocket);
/** Set the TCP_NODELAY flag on a socket */
bool SetSocketNoDelay(const SOCKET& hSocket);
void InterruptSocks5(bool interrupt);

/**
Expand Down
13 changes: 13 additions & 0 deletions src/test/fuzz/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,19 @@ int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* op
return 0;
}

int FuzzedSock::SetSockOpt(int, int, const void*, socklen_t) const
{
constexpr std::array setsockopt_errnos{
ENOMEM,
ENOBUFS,
};
if (m_fuzzed_data_provider.ConsumeBool()) {
SetFuzzedErrNo(m_fuzzed_data_provider, setsockopt_errnos);
return -1;
}
return 0;
}

bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{
constexpr std::array wait_errnos{
Expand Down
2 changes: 2 additions & 0 deletions src/test/fuzz/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class FuzzedSock : public Sock

int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override;

int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const override;

bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;

bool IsConnected(std::string& errmsg) const override;
Expand Down
2 changes: 2 additions & 0 deletions src/test/util/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ class StaticContentsSock : public Sock
return 0;
}

int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; }

bool Wait(std::chrono::milliseconds timeout,
Event requested,
Event* occurred = nullptr) const override
Expand Down
5 changes: 5 additions & 0 deletions src/util/sock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len)
return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
}

int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
{
return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
}

bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{
#ifdef USE_POLL
Expand Down
10 changes: 10 additions & 0 deletions src/util/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ class Sock
void* opt_val,
socklen_t* opt_len) const;

/**
* setsockopt(2) wrapper. Equivalent to
* `setsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int SetSockOpt(int level,
int opt_name,
const void* opt_val,
socklen_t opt_len) const;

using Event = uint8_t;

/**
Expand Down

0 comments on commit 6b159f1

Please sign in to comment.