Skip to content

Commit

Permalink
merge bitcoin#25426: add new method Sock::GetSockName() that wraps ge…
Browse files Browse the repository at this point in the history
…tsockname() and use it in GetBindAddress()
  • Loading branch information
kwvg committed Jun 11, 2024
1 parent 6b159f1 commit be19868
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,13 +423,13 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce)
}

/** Get the bind address for a socket as CAddress */
static CAddress GetBindAddress(SOCKET sock)
static CAddress GetBindAddress(const Sock& sock)
{
CAddress addr_bind;
struct sockaddr_storage sockaddr_bind;
socklen_t sockaddr_bind_len = sizeof(sockaddr_bind);
if (sock != INVALID_SOCKET) {
if (!getsockname(sock, (struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) {
if (sock.Get() != INVALID_SOCKET) {
if (!sock.GetSockName((struct sockaddr*)&sockaddr_bind, &sockaddr_bind_len)) {
addr_bind.SetSockAddr((const struct sockaddr*)&sockaddr_bind);
} else {
LogPrint(BCLog::NET, "Warning: getsockname failed\n");
Expand Down Expand Up @@ -572,7 +572,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
NodeId id = GetNewNodeId();
uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize();
if (!addr_bind.IsValid()) {
addr_bind = GetBindAddress(sock->Get());
addr_bind = GetBindAddress(*sock);
}
CNode* pnode = new CNode(id,
nLocalServices,
Expand Down Expand Up @@ -1248,7 +1248,7 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSy
addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE};
}

const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE};
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(*sock)), NODE_NONE};

NetPermissionFlags permissionFlags = NetPermissionFlags::None;
hListenSocket.AddSocketPermissionFlags(permissionFlags);
Expand Down
14 changes: 14 additions & 0 deletions src/test/fuzz/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,20 @@ int FuzzedSock::SetSockOpt(int, int, const void*, socklen_t) const
return 0;
}

int FuzzedSock::GetSockName(sockaddr* name, socklen_t* name_len) const
{
constexpr std::array getsockname_errnos{
ECONNRESET,
ENOBUFS,
};
if (m_fuzzed_data_provider.ConsumeBool()) {
SetFuzzedErrNo(m_fuzzed_data_provider, getsockname_errnos);
return -1;
}
*name_len = m_fuzzed_data_provider.ConsumeData(name, *name_len);
return 0;
}

bool FuzzedSock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{
constexpr std::array wait_errnos{
Expand Down
2 changes: 2 additions & 0 deletions src/test/fuzz/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class FuzzedSock : public Sock

int SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const override;

int GetSockName(sockaddr* name, socklen_t* name_len) const override;

bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override;

bool IsConnected(std::string& errmsg) const override;
Expand Down
6 changes: 6 additions & 0 deletions src/test/util/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ class StaticContentsSock : public Sock

int SetSockOpt(int, int, const void*, socklen_t) const override { return 0; }

int GetSockName(sockaddr* name, socklen_t* name_len) const override
{
std::memset(name, 0x0, *name_len);
return 0;
}

bool Wait(std::chrono::milliseconds timeout,
Event requested,
Event* occurred = nullptr) const override
Expand Down
5 changes: 5 additions & 0 deletions src/util/sock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt
return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
}

int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
{
return getsockname(m_socket, name, name_len);
}

bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
{
#ifdef USE_POLL
Expand Down
7 changes: 7 additions & 0 deletions src/util/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ class Sock
const void* opt_val,
socklen_t opt_len) const;

/**
* getsockname(2) wrapper. Equivalent to
* `getsockname(this->Get(), name, name_len)`. Code that uses this
* wrapper can be unit tested if this method is overridden by a mock Sock implementation.
*/
[[nodiscard]] virtual int GetSockName(sockaddr* name, socklen_t* name_len) const;

using Event = uint8_t;

/**
Expand Down

0 comments on commit be19868

Please sign in to comment.