diff --git a/spec/std/concurrent/select_spec.cr b/spec/std/concurrent/select_spec.cr index cd598f1b0e77..233e1a0528db 100644 --- a/spec/std/concurrent/select_spec.cr +++ b/spec/std/concurrent/select_spec.cr @@ -6,7 +6,7 @@ private def yield_to(fiber) end describe "select" do - it "select many receviers" do + it "select many receivers" do ch1 = Channel(Int32).new ch2 = Channel(Int32).new res = [] of Int32 @@ -72,6 +72,35 @@ describe "select" do res.should eq (0...10).to_a end + it "select else clause and cancel other clauses" do + ch1 = Channel::Buffered(Int32).new(1) + ch2 = Channel::Buffered(Int32).new(1) + + select + when ch1.receive + got = 1 + when ch2.receive + got = 2 + else + got = -1 + end + + got.should eq(-1) + + spawn do + ch1.send(1) + ch2.send(2) + ch1.close + ch2.close + end + + ch1.receive.should eq(1) + ch1.receive?.should be_nil + + ch2.receive.should eq(2) + ch2.receive?.should be_nil + end + it "select should work with send which started before receive, fixed #3862" do ch1 = Channel(Int32).new ch2 = Channel(Int32).new @@ -100,4 +129,23 @@ describe "select" do sleep x.should eq 1 end + + it "won't enqueue a dead/running fiber, fixed #3900" do + ch = Channel::Buffered(Int32).new(1) + + spawn do + ch.send(1) + + select + when ch.send(1) + when ch.send(2) + end + + ch.close + end + + ch.receive.should eq(1) + ch.receive.should eq(1) + ch.receive?.should be_nil + end end diff --git a/src/channel.cr b/src/channel.cr index f283d3971b0b..1411e0c76866 100644 --- a/src/channel.cr +++ b/src/channel.cr @@ -2,10 +2,12 @@ require "fiber" abstract class Channel(T) module SelectAction + getter? canceled = false + getter? waiting = false abstract def ready? abstract def execute - abstract def wait - abstract def unwait + abstract def wait : Bool + abstract def unwait(fiber : Fiber) end class ClosedError < Exception @@ -61,16 +63,16 @@ abstract class Channel(T) @receivers << Fiber.current end - protected def unwait_for_receive - @receivers.delete Fiber.current + protected def unwait_for_receive(fiber) + @receivers.delete fiber end protected def wait_for_send @senders << Fiber.current end - protected def unwait_for_send - @senders.delete Fiber.current + protected def unwait_for_send(fiber) + @senders.delete fiber end protected def raise_if_closed @@ -94,26 +96,63 @@ abstract class Channel(T) nil end + # :nodoc: def self.select(*ops : SelectAction) self.select ops end + # :nodoc: + # + # Executes all operations inside its own fiber to wait in. Postpones the fiber + # execution so the *fibers* array will always be filled with all fibers, so + # any ready operation can cancel all other fibers ASAP. def self.select(ops : Tuple | Array, has_else = false) - loop do - ops.each_with_index do |op, index| - if op.ready? - result = op.execute - return index, result + main = Fiber.current + fibers = Array(Fiber).new(ops.size) + + waiting = 0 + index = -1 + value = nil + + ops.each_with_index do |op, i| + fibers << Fiber.new(name: i.to_s) do + loop do + break if op.canceled? + + if op.ready? + # cancel other fibers before executing the op, which could switch + # the current context: + cancel_select_actions(ops, fibers, i) + index, value = i, op.execute + Crystal::Scheduler.enqueue(main) + break + end + + if has_else && (waiting += 1) == ops.size + cancel_select_actions(ops, fibers, i) + index = ops.size + Crystal::Scheduler.enqueue(main) + break + end + + op.wait + Crystal::Scheduler.reschedule end end + end - if has_else - return ops.size, nil - end + Crystal::Scheduler.enqueue(fibers) + Crystal::Scheduler.reschedule - ops.each &.wait - Crystal::Scheduler.reschedule - ops.each &.unwait + {index, value} + end + + private def self.cancel_select_actions(ops, fibers, running_index) + ops.each_with_index do |op, i| + next if i == running_index + fiber = fibers[i] + op.unwait(fiber) + Crystal::Scheduler.enqueue(fiber) if op.waiting? end end @@ -128,7 +167,7 @@ abstract class Channel(T) end # :nodoc: - struct ReceiveAction(C) + class ReceiveAction(C) include SelectAction def initialize(@channel : C) @@ -144,15 +183,17 @@ abstract class Channel(T) def wait @channel.wait_for_receive + @waiting = true end - def unwait - @channel.unwait_for_receive + def unwait(fiber) + @canceled = true + @channel.unwait_for_receive(fiber) end end # :nodoc: - struct SendAction(C, T) + class SendAction(C, T) include SelectAction def initialize(@channel : C, @value : T) @@ -168,10 +209,12 @@ abstract class Channel(T) def wait @channel.wait_for_send + @waiting = true end - def unwait - @channel.unwait_for_send + def unwait(fiber) + @canceled = true + @channel.unwait_for_send(fiber) end end end