diff --git a/spec/amqproxy_spec.cr b/spec/amqproxy_spec.cr index c72dfb0..0b5c215 100644 --- a/spec/amqproxy_spec.cr +++ b/spec/amqproxy_spec.cr @@ -8,11 +8,12 @@ describe AMQProxy::Server do Fiber.yield 10.times do AMQP::Client.start("amqp://localhost:5673") do |conn| - conn.channel + ch = conn.channel + ch.basic_publish "foobar", "amq.fanout", "" s.client_connections.should eq 1 s.upstream_connections.should eq 1 end - sleep 0.1 + # sleep 0.1 end s.client_connections.should eq 0 s.upstream_connections.should eq 1 @@ -64,8 +65,9 @@ describe AMQProxy::Server do begin spawn { s.listen("127.0.0.1", 5673) } Fiber.yield - AMQP::Client.start("amqp://localhost:5673?channel_max=#{UInt16::MAX}") do |conn| - conn.channel_max.should eq UInt16::MAX + max = 4000 + AMQP::Client.start("amqp://localhost:5673?channel_max=#{max}") do |conn| + conn.channel_max.should eq max conn.channel_max.times do conn.channel end diff --git a/src/amqproxy/channel_pool.cr b/src/amqproxy/channel_pool.cr index 03e3edc..44da5e0 100644 --- a/src/amqproxy/channel_pool.cr +++ b/src/amqproxy/channel_pool.cr @@ -5,7 +5,6 @@ require "./upstream" module AMQProxy class ChannelPool - getter size = 0 @tls_ctx : OpenSSL::SSL::Context::Client? @log : Logger @lock = Mutex.new @@ -15,10 +14,10 @@ module AMQProxy def initialize(@host : String, @port : Int32, tls : Bool, @log, @idle_connection_timeout : Int32) @tls_ctx = OpenSSL::SSL::Context::Client.new if tls - @upstream_channel_channel = Hash(Credentials, Channel(Tuple(Upstream, UInt16))).new do |h, k| + @upstream_channel_channel = Hash(Credentials, Channel(Tuple(Upstream, UInt16))).new do |h, credentials| chan = Channel(Tuple(Upstream, UInt16)).new(128) - spawn pool_loop(k, chan) - h[k] = chan + spawn pool_loop(credentials, chan) + h[credentials] = chan end end @@ -27,12 +26,14 @@ module AMQProxy loop do upstream = Upstream.new(@host, @port, @tls_ctx, @log, credentials) @upstreams[credentials] << upstream - loop do - channel = upstream.open_channel - upstream_channel_channel.send({upstream, channel}) - rescue Upstream::ChannelMaxReached - break + spawn(name: "upstream read loop #{@host}:#{@port}") do + begin + upstream.read_loop # blocks until upstream closes connection + ensure + @upstreams[credentials].delete(upstream) + end end + upstream.channel_loop(upstream_channel_channel) rescue ex next end @@ -59,7 +60,7 @@ module AMQProxy end end end - @size = 0 + @upstreams.clear end end @@ -72,14 +73,12 @@ module AMQProxy q.size.times do u = q.shift if u.last_used < max_connection_age - @size -= 1 begin u.close "Pooled connection closed due to inactivity" rescue ex @log.error "Problem closing upstream: #{ex.inspect}" end elsif u.closed? - @size -= 1 @log.error "Removing closed upstream connection from pool" else q.push u diff --git a/src/amqproxy/client.cr b/src/amqproxy/client.cr index 9255f2f..6d3255c 100644 --- a/src/amqproxy/client.cr +++ b/src/amqproxy/client.cr @@ -40,7 +40,6 @@ module AMQProxy when AMQ::Protocol::Frame::Connection::Close close_all_upstream_channels write AMQ::Protocol::Frame::Connection::CloseOk.new - return when AMQ::Protocol::Frame::Connection::CloseOk return when AMQ::Protocol::Frame::Channel::Open @@ -54,6 +53,8 @@ module AMQProxy upstream.unassign_channel(upstream_channel) end write AMQ::Protocol::Frame::Channel::CloseOk.new(frame.channel) + when AMQ::Protocol::Frame::Channel::CloseOk + # noop when frame.channel.zero? raise "unexpected connection frame: #{frame}" else @@ -68,12 +69,11 @@ module AMQProxy end end end - rescue ex : IO::EOFError + rescue ex : IO::Error raise Error.new("Client disconnected", ex) unless socket.closed? - rescue ex - raise ReadError.new "Client read error", ex ensure - close_socket + @outgoing_frames.close + close_all_upstream_channels end private def write_loop(socket = @socket) @@ -86,7 +86,9 @@ module AMQProxy rescue ex : IO::Error raise ex unless socket.closed? ensure - close_socket + @outgoing_frames.close + socket.close rescue nil + close_all_upstream_channels end # Send frame to client, channel id should already be remapped by the caller @@ -98,15 +100,11 @@ module AMQProxy write AMQ::Protocol::Frame::Channel::Close.new(id, 500_u16, "UPSTREAM_DISCONNECTED", 0_u16, 0_u16) end - private def close_socket - @outgoing_frames.close - socket.close rescue nil - close_all_upstream_channels - end - private def close_all_upstream_channels @channel_map.each_value do |upstream, upstream_channel| - upstream.close_channel(upstream_channel) + upstream.unassign_channel(upstream_channel) + rescue Upstream::WriteError + next # Nothing to do end @channel_map.clear end @@ -126,10 +124,6 @@ module AMQProxy 0_u16, 0_u16) end - def close_socket - @socket.close rescue nil - end - def negotiate(socket = @socket) proto = uninitialized UInt8[8] socket.read_fully(proto.to_slice) diff --git a/src/amqproxy/upstream.cr b/src/amqproxy/upstream.cr index e5a87c8..55c02f8 100644 --- a/src/amqproxy/upstream.cr +++ b/src/amqproxy/upstream.cr @@ -31,16 +31,30 @@ module AMQProxy tcp_socket end @channel_max = start(credentials) - spawn read_loop, name: "upstream read loop #{@host}:#{@port}" + end + + def channel_loop(upstream_channel_channel) : Nil + loop do + upstream_channel_channel.send({self, open_channel}) + rescue Upstream::ChannelMaxReached + break + end end def open_channel : UInt16 @channels_lock.synchronize do 1_u16.upto(@channel_max) do |i| - next if @channels.has_key? i - @channels[i] = nil - send AMQ::Protocol::Frame::Channel::Open.new(i) - return i + if @channels.has_key?(i) + if @channels[i].nil? + return i # reuse + else + next # in use + end + else + @channels[i] = nil + send AMQ::Protocol::Frame::Channel::Open.new(i) + return i + end end raise ChannelMaxReached.new end @@ -54,30 +68,26 @@ module AMQProxy def unassign_channel(channel : UInt16) @channels_lock.synchronize do - if @unsafe_channels.includes? channel - close_channel(channel) + if @unsafe_channels.delete channel + send AMQ::Protocol::Frame::Channel::Close.new(channel, 0u16, "", 0u16, 0u16) + @channels.delete channel else @channels[channel] = nil end end end - def close_channel(channel : UInt16) - @channels_lock.synchronize do - send AMQ::Protocol::Frame::Channel::Close.new(channel, 0u16, "", 0u16, 0u16) - @channels.delete channel - @unsafe_channels.delete channel - end - end - # Frames from upstream (to client) - private def read_loop(socket = @socket) + def read_loop(socket = @socket) loop do case frame = AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) when AMQ::Protocol::Frame::Heartbeat then send frame when AMQ::Protocol::Frame::Connection::Close close_all_client_channels - send AMQ::Protocol::Frame::Connection::CloseOk.new + begin + send AMQ::Protocol::Frame::Connection::CloseOk.new + rescue WriteError + end return when AMQ::Protocol::Frame::Connection::CloseOk then return when AMQ::Protocol::Frame::Channel::OpenOk