Skip to content

Commit

Permalink
inspector: check Host header for local connections
Browse files Browse the repository at this point in the history
PR-URL: https://github.com/nodejs-private/node-private/pull/102/
Reviewed-By: Ben Noordhuis <info@bnoordhuis.nl>
Reviewed-By: Сковорода Никита Андреевич <chalkerx@gmail.com>
  • Loading branch information
eugeneo authored and MylesBorins committed Mar 28, 2018
1 parent fad5dcc commit fc1a610
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 38 deletions.
120 changes: 90 additions & 30 deletions src/inspector_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include "openssl/sha.h" // Sha-1 hash

#include <map>
#include <string.h>
#include <vector>

#define ACCEPT_KEY_LENGTH base64_encoded_size(20)
#define BUFFER_GROWTH_CHUNK_SIZE 1024
Expand Down Expand Up @@ -63,7 +63,7 @@ class ProtocolHandler {
virtual void Write(const std::vector<char> data) = 0;
virtual void CancelHandshake() = 0;

std::string GetHost();
std::string GetHost() const;

InspectorSocket* inspector() {
return inspector_;
Expand Down Expand Up @@ -160,6 +160,48 @@ static void generate_accept_string(const std::string& client_key,
node::base64_encode(hash, sizeof(hash), *buffer, sizeof(*buffer));
}

static bool IsOneOf(const std::string& host,
const std::vector<std::string>& hosts) {
for (const std::string& candidate : hosts) {
if (node::StringEqualNoCase(host.data(), candidate.data()))
return true;
}
return false;
}

static std::string TrimPort(const std::string& host) {
size_t last_colon_pos = host.rfind(":");
if (last_colon_pos == std::string::npos)
return host;
size_t bracket = host.rfind("]");
if (bracket == std::string::npos || last_colon_pos > bracket)
return host.substr(0, last_colon_pos);
return host;
}

static bool IsIPAddress(const std::string& host) {
if (host.length() >= 4 && host.front() == '[' && host.back() == ']')
return true;
int quads = 0;
for (char c : host) {
if (c == '.')
quads++;
else if (!isdigit(c))
return false;
}
return quads == 3;
}

// This is a value coming from the interface, it can only be IPv4 or IPv6
// address string.
static bool IsIPv4Localhost(const std::string& host) {
std::string v6_tunnel_prefix = "::ffff:";
if (host.substr(0, v6_tunnel_prefix.length()) == v6_tunnel_prefix)
return IsIPv4Localhost(host.substr(v6_tunnel_prefix.length()));
std::string localhost_net = "127.";
return host.substr(0, localhost_net.length()) == localhost_net;
}

// Constants for hybi-10 frame format.

typedef int OpCode;
Expand Down Expand Up @@ -298,7 +340,6 @@ static ws_decode_result decode_frame_hybi17(const std::vector<char>& buffer,
return closed ? FRAME_CLOSE : FRAME_OK;
}


// WS protocol
class WsHandler : public ProtocolHandler {
public:
Expand Down Expand Up @@ -400,17 +441,16 @@ class WsHandler : public ProtocolHandler {
// HTTP protocol
class HttpEvent {
public:
HttpEvent(const std::string& path, bool upgrade,
bool isGET, const std::string& ws_key) : path(path),
upgrade(upgrade),
isGET(isGET),
ws_key(ws_key) { }
HttpEvent(const std::string& path, bool upgrade, bool isGET,
const std::string& ws_key, const std::string& host)
: path(path), upgrade(upgrade), isGET(isGET), ws_key(ws_key),
host(host) { }

std::string path;
bool upgrade;
bool isGET;
std::string ws_key;
std::string current_header_;
std::string host;
};

class HttpHandler : public ProtocolHandler {
Expand Down Expand Up @@ -472,18 +512,17 @@ class HttpHandler : public ProtocolHandler {
std::vector<HttpEvent> events;
std::swap(events, events_);
for (const HttpEvent& event : events) {
bool shouldContinue = event.isGET && !event.upgrade;
if (!event.isGET) {
if (!IsAllowedHost(event.host) || !event.isGET) {
CancelHandshake();
return;
} else if (!event.upgrade) {
delegate()->OnHttpGet(event.path);
} else if (event.ws_key.empty()) {
CancelHandshake();
return;
} else {
delegate()->OnSocketUpgrade(event.path, event.ws_key);
}
if (!shouldContinue)
return;
}
}

Expand All @@ -504,16 +543,9 @@ class HttpHandler : public ProtocolHandler {
}

static int OnHeaderValue(http_parser* parser, const char* at, size_t length) {
static const char SEC_WEBSOCKET_KEY_HEADER[] = "Sec-WebSocket-Key";
HttpHandler* handler = From(parser);
handler->parsing_value_ = true;
if (handler->current_header_.size() ==
sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1 &&
node::StringEqualNoCaseN(handler->current_header_.data(),
SEC_WEBSOCKET_KEY_HEADER,
sizeof(SEC_WEBSOCKET_KEY_HEADER) - 1)) {
handler->ws_key_.append(at, length);
}
handler->headers_[handler->current_header_].append(at, length);
return 0;
}

Expand All @@ -540,23 +572,53 @@ class HttpHandler : public ProtocolHandler {
static int OnMessageComplete(http_parser* parser) {
// Event needs to be fired after the parser is done.
HttpHandler* handler = From(parser);
handler->events_.push_back(HttpEvent(handler->path_, parser->upgrade,
parser->method == HTTP_GET,
handler->ws_key_));
handler->events_.push_back(
HttpEvent(handler->path_, parser->upgrade, parser->method == HTTP_GET,
handler->HeaderValue("Sec-WebSocket-Key"),
handler->HeaderValue("Host")));
handler->path_ = "";
handler->ws_key_ = "";
handler->parsing_value_ = false;
handler->headers_.clear();
handler->current_header_ = "";

return 0;
}

std::string HeaderValue(const std::string& header) const {
bool header_found = false;
std::string value;
for (const auto& header_value : headers_) {
if (node::StringEqualNoCaseN(header_value.first.data(), header.data(),
header.length())) {
if (header_found)
return "";
value = header_value.second;
header_found = true;
}
}
return value;
}

bool IsAllowedHost(const std::string& host_with_port) const {
std::string host = TrimPort(host_with_port);
if (host.empty())
return false;
if (IsIPAddress(host))
return true;
std::string socket_host = GetHost();
if (IsIPv4Localhost(socket_host)) {
return IsOneOf(host, { "localhost" });
} else if (socket_host == "::1") {
return IsOneOf(host, { "localhost", "localhost6" });
}
return true;
}

bool parsing_value_;
http_parser parser_;
http_parser_settings parser_settings;
std::vector<HttpEvent> events_;
std::string current_header_;
std::string ws_key_;
std::map<std::string, std::string> headers_;
std::string path_;
};

Expand All @@ -579,7 +641,7 @@ InspectorSocket::Delegate* ProtocolHandler::delegate() {
return tcp_->delegate();
}

std::string ProtocolHandler::GetHost() {
std::string ProtocolHandler::GetHost() const {
char ip[INET6_ADDRSTRLEN];
sockaddr_storage addr;
int len = sizeof(addr);
Expand Down Expand Up @@ -622,8 +684,6 @@ TcpHolder::Pointer TcpHolder::Accept(
if (err == 0) {
return { result, DisconnectAndDispose };
} else {
fprintf(stderr, "[%s:%d@%s]\n", __FILE__, __LINE__, __FUNCTION__);

delete result;
return { nullptr, nullptr };
}
Expand Down
47 changes: 39 additions & 8 deletions test/cctest/test_inspector_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ struct read_expects {
};

static const char HANDSHAKE_REQ[] = "GET /ws/path HTTP/1.1\r\n"
"Host: localhost:9222\r\n"
"Host: localhost:9229\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: aaa==\r\n"
Expand Down Expand Up @@ -504,7 +504,7 @@ TEST_F(InspectorSocketTest, ExtraTextBeforeRequest) {

TEST_F(InspectorSocketTest, RequestWithoutKey) {
const char BROKEN_REQUEST[] = "GET / HTTP/1.1\r\n"
"Host: localhost:9222\r\n"
"Host: localhost:9229\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n";
Expand Down Expand Up @@ -619,24 +619,23 @@ TEST_F(InspectorSocketTest, ReportsHttpGet) {
delegate->SetDelegate(ReportsHttpGet_handshake);

const char GET_REQ[] = "GET /some/path HTTP/1.1\r\n"
"Host: localhost:9222\r\n"
"Host: localhost:9229\r\n"
"Sec-WebSocket-Key: aaa==\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n";
send_in_chunks(GET_REQ, sizeof(GET_REQ) - 1);

expect_nothing_on_client();

const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n"
"Host: localhost:9222\r\n\r\n";
"Host: localhost:9229\r\n\r\n";
send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1);

expect_on_client(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1);
const char GET_REQS[] = "GET /some/path2 HTTP/1.1\r\n"
"Host: localhost:9222\r\n"
"Host: localhost:9229\r\n"
"Sec-WebSocket-Key: aaa==\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n"
"GET /close HTTP/1.1\r\n"
"Host: localhost:9222\r\n"
"Host: localhost:9229\r\n"
"Sec-WebSocket-Key: aaa==\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n";
send_in_chunks(GET_REQS, sizeof(GET_REQS) - 1);
Expand Down Expand Up @@ -696,7 +695,7 @@ static void GetThenHandshake_handshake(enum inspector_handshake_event state,
TEST_F(InspectorSocketTest, GetThenHandshake) {
delegate->SetDelegate(GetThenHandshake_handshake);
const char WRITE_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n"
"Host: localhost:9222\r\n\r\n";
"Host: localhost:9229\r\n\r\n";
send_in_chunks(WRITE_REQUEST, sizeof(WRITE_REQUEST) - 1);

expect_on_client(TEST_SUCCESS, sizeof(TEST_SUCCESS) - 1);
Expand Down Expand Up @@ -826,4 +825,36 @@ TEST_F(InspectorSocketTest, NoCloseResponseFromClient) {
delegate->WaitForDispose();
}

static bool delegate_called = false;

void shouldnt_be_called(enum inspector_handshake_event state,
const std::string& path, bool* cont) {
delegate_called = true;
}

void expect_failure_no_delegate(const std::string& request) {
delegate->SetDelegate(shouldnt_be_called);
delegate_called = false;
send_in_chunks(request.c_str(), request.length());
expect_handshake_failure();
SPIN_WHILE(delegate != nullptr);
ASSERT_FALSE(delegate_called);
}

TEST_F(InspectorSocketTest, HostCheckedForGET) {
const char GET_REQUEST[] = "GET /respond/withtext HTTP/1.1\r\n"
"Host: notlocalhost:9229\r\n\r\n";
expect_failure_no_delegate(GET_REQUEST);
}

TEST_F(InspectorSocketTest, HostCheckedForUPGRADE) {
const char UPGRADE_REQUEST[] = "GET /ws/path HTTP/1.1\r\n"
"Host: nonlocalhost:9229\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: aaa==\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n";
expect_failure_no_delegate(UPGRADE_REQUEST);
}

} // anonymous namespace

0 comments on commit fc1a610

Please sign in to comment.