Skip to content

Commit

Permalink
Improve WebSocket handshake validation (crystal-lang#5327)
Browse files Browse the repository at this point in the history
  • Loading branch information
straight-shoota authored and chris-huxtable committed Jun 6, 2018
1 parent 4309343 commit 7504580
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 21 deletions.
80 changes: 72 additions & 8 deletions spec/std/http/server/handlers/websocket_handler_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ describe HTTP::WebSocketHandler do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "WS",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Upgrade" => "WS",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -47,6 +48,7 @@ describe HTTP::WebSocketHandler do
"Upgrade" => "websocket",
"Connection" => {{connection}},
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -63,16 +65,17 @@ describe HTTP::WebSocketHandler do

response.close

io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
end
{% end %}

it "gives upgrade response for case-insensitive 'WebSocket' upgrade request" do
io = IO::Memory.new
headers = HTTP::Headers{
"Upgrade" => "WebSocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Upgrade" => "WebSocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
Expand All @@ -89,6 +92,67 @@ describe HTTP::WebSocketHandler do

response.close

io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
io.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
end

it "returns bad request if Sec-WebSocket-Key is missing" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Version" => "13",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n")
end

it "returns upgrade required if Sec-WebSocket-Version is missing" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n")
end

it "returns upgrade required if Sec-WebSocket-Version is invalid" do
io = IO::Memory.new

headers = HTTP::Headers{
"Upgrade" => "websocket",
"Connection" => "Upgrade",
"Sec-WebSocket-Key" => "dGhlIHNhbXBsZSBub25jZQ==",
"Sec-WebSocket-Version" => "12",
}
request = HTTP::Request.new("GET", "/", headers: headers)
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)

handler = HTTP::WebSocketHandler.new { }
handler.call context

response.close

io.to_s.should eq("HTTP/1.1 426 Upgrade Required\r\nSec-WebSocket-Version: 13\r\nContent-Length: 0\r\n\r\n")
end
end
55 changes: 55 additions & 0 deletions spec/std/http/web_socket_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,61 @@ describe HTTP::WebSocket do
ws2.run
end

it "handshake fails if server does not switch protocols" do
port_chan = Channel(Int32).new
spawn do
http_ref = nil
http_server = http_ref = HTTP::Server.new(0) do |context|
context.response.status_code = 200
end

http_server.bind
port_chan.send(http_server.port)
http_server.listen

http_ref.not_nil!.close
end

listen_port = port_chan.receive

expect_raises(Socket::Error, "Handshake got denied. Status code was 200.") do
HTTP::WebSocket::Protocol.new("127.0.0.1", port: listen_port, path: "/")
end
end

it "handshake fails if server does not verify Sec-WebSocket-Key" do
port_chan = Channel(Int32).new
spawn do
http_ref = nil
has_been_called = false

http_server = http_ref = HTTP::Server.new(0) do |context|
response = context.response
response.status_code = 101
response.headers["Upgrade"] = "websocket"
response.headers["Connection"] = "Upgrade"
if has_been_called
response.headers["Sec-WebSocket-Accept"] = "foobar"
http_ref.not_nil!.close
else
has_been_called = true
end
end

http_server.bind
port_chan.send(http_server.port)
http_server.listen
end

listen_port = port_chan.receive

2.times do
expect_raises(Socket::Error, "Handshake got denied. Server did not verify WebSocket challenge.") do
HTTP::WebSocket::Protocol.new("127.0.0.1", port: listen_port, path: "/")
end
end
end

typeof(HTTP::WebSocket.new(URI.parse("ws://localhost")))
typeof(HTTP::WebSocket.new("localhost", "/"))
typeof(HTTP::WebSocket.new("ws://localhost"))
Expand Down
26 changes: 17 additions & 9 deletions src/http/server/handlers/websocket_handler.cr
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@ class HTTP::WebSocketHandler

def call(context)
if websocket_upgrade_request? context.request
key = context.request.headers["Sec-Websocket-Key"]
response = context.response

version = context.request.headers["Sec-WebSocket-Version"]?
unless version == WebSocket::Protocol::VERSION
response.status_code = 426
response.headers["Sec-WebSocket-Version"] = WebSocket::Protocol::VERSION
return
end

accept_code =
{% if flag?(:without_openssl) %}
Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
{% else %}
Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
{% end %}
key = context.request.headers["Sec-WebSocket-Key"]?

unless key
response.status_code = 400
return
end

accept_code = WebSocket::Protocol.key_challenge(key)

response = context.response
response.status_code = 101
response.headers["Upgrade"] = "websocket"
response.headers["Connection"] = "Upgrade"
response.headers["Sec-Websocket-Accept"] = accept_code
response.headers["Sec-WebSocket-Accept"] = accept_code
response.upgrade do |io|
ws_session = WebSocket.new(io)
@proc.call(ws_session, context)
Expand Down
23 changes: 19 additions & 4 deletions src/http/web_socket/protocol.cr
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class HTTP::WebSocket::Protocol
end

MASK_BIT = 128_u8
VERSION = 13
VERSION = "13"

record PacketInfo,
opcode : Opcode,
Expand Down Expand Up @@ -257,19 +257,26 @@ class HTTP::WebSocket::Protocol
end
{% end %}

random_key = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 })

headers["Host"] = "#{host}:#{port}"
headers["Connection"] = "Upgrade"
headers["Upgrade"] = "websocket"
headers["Sec-WebSocket-Version"] = VERSION.to_s
headers["Sec-WebSocket-Key"] = Base64.strict_encode(StaticArray(UInt8, 16).new { rand(256).to_u8 })
headers["Sec-WebSocket-Version"] = VERSION
headers["Sec-WebSocket-Key"] = random_key

path = "/" if path.empty?
handshake = HTTP::Request.new("GET", path, headers)
handshake.to_io(socket)
socket.flush
handshake_response = HTTP::Client::Response.from_io(socket)
unless handshake_response.status_code == 101
raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}")
raise Socket::Error.new("Handshake got denied. Status code was #{handshake_response.status_code}.")
end

challenge_response = Protocol.key_challenge(random_key)
unless handshake_response.headers["Sec-WebSocket-Accept"]? == challenge_response
raise Socket::Error.new("Handshake got denied. Server did not verify WebSocket challenge.")
end

new(socket, masked: true)
Expand All @@ -285,4 +292,12 @@ class HTTP::WebSocket::Protocol

raise ArgumentError.new("No host or path specified which are required.")
end

def self.key_challenge(key)
{% if flag?(:without_openssl) %}
Digest::SHA1.base64digest("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
{% else %}
Base64.strict_encode(OpenSSL::SHA1.hash("#{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
{% end %}
end
end

0 comments on commit 7504580

Please sign in to comment.