Skip to content

Commit

Permalink
add Base.:(==) and Base.hash for AbstractEnv and test nash_conv on Ku…
Browse files Browse the repository at this point in the history
…hnPokerEnv (#348)

* add Base.:(==) and Base.hash for AbstractEnv

* add nash_conv test for KuhnPokerEnv with TabularRandomPolicy and get_optimal_kuhn_policy

* add get_optimal_kuhn_policy for KuhnPokerEnv

* supplement Base.hash

* supplement and move Base.:(==) and Base.hash to RLBase/src/interface.jl

* add messages about Base.:(==) for AbstractEnv

* update test

* modify Base.:(==)

Co-authored-by: peter <51195500031@stu.ecnu.edu.cn>
  • Loading branch information
peterchen96 and peter authored Jul 8, 2021
1 parent a98ecfd commit 0b6fc51
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/ReinforcementLearningBase/src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,19 @@ Make an independent copy of `env`,
@api copy(env::AbstractEnv) = deepcopy(env)
@api copyto!(dest::AbstractEnv, src::AbstractEnv)

# checking the state of all players in env is enough?
"""
Base.:(==)(env1::T, env2::T) where T<:AbstractEnv
!!! warning
Only check the state of all players in the env.
"""
function Base.:(==)(env1::T, env2::T) where T<:AbstractEnv
len = length(players(env1))
len == length(players(env2)) &&
all(state(env1, player) == state(env2, player) for player in players(env1))
end
Base.hash(env::AbstractEnv, h::UInt) = hash([state(env, player) for player in players(env)], h)

@api nameof(env::AbstractEnv) = nameof(typeof(env))

"""
Expand Down
7 changes: 7 additions & 0 deletions src/ReinforcementLearningZoo/test/cfr/nash_conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,11 @@
env = OpenSpielEnv("kuhn_poker(players=4)")
p = TabularRandomPolicy()
@test RLZoo.nash_conv(p, env) 3.4760416666666663

env = KuhnPokerEnv()
p = TabularRandomPolicy()
@test RLZoo.nash_conv(p, env) == 11 / 12

p = get_optimal_kuhn_policy(env)
@test RLZoo.nash_conv(p, env) == 0.0
end
21 changes: 21 additions & 0 deletions src/ReinforcementLearningZoo/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Random
using StableRNGs
using OpenSpiel

# used for OpenSpielEnv("kuhn_poker")
function get_optimal_kuhn_policy= 0.2)
TabularRandomPolicy(
table = Dict(
Expand All @@ -28,6 +29,26 @@ function get_optimal_kuhn_policy(α = 0.2)
)
end

# used for julia version KuhnPokerGame
function get_optimal_kuhn_policy(env::KuhnPokerEnv; α = 0.2)
TabularRandomPolicy(
table = Dict(
(:J,) => [1 - α, α],
(:J, :pass, :bet) => [1.0, 0.0],
(:Q,) => [1.0, 0.0],
(:Q, :pass, :bet) => [2.0 / 3.0 - α, 1.0 / 3.0 + α],
(:K,) => [1 - 3 * α, 3 * α],
(:K, :pass, :bet) => [0.0, 1.0],
(:J, :pass) => [2.0 / 3.0, 1.0 / 3.0],
(:J, :bet) => [1.0, 0.0],
(:Q, :pass) => [1.0, 0.0],
(:Q, :bet) => [2.0 / 3.0, 1.0 / 3.0],
(:K, :pass) => [0.0, 1.0],
(:K, :bet) => [0.0, 1.0],
),
)
end

@testset "ReinforcementLearningZoo.jl" begin
include("cfr/cfr.jl")
end

0 comments on commit 0b6fc51

Please sign in to comment.