Skip to content

Commit

Permalink
removed unsafe_arrays and uviews
Browse files Browse the repository at this point in the history
  • Loading branch information
klowrey committed Sep 16, 2021
1 parent 3144694 commit 1bf2191
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 149 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Shapes = "175de200-b73b-11e9-28b7-9b5b306cec37"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"

[compat]
Distances = "0.8"
Expand All @@ -26,7 +25,6 @@ MuJoCo = "0.3"
Reexport = "0.2"
Shapes = "0.2"
StaticArrays = "0.12"
UnsafeArrays = "1"
julia = "1.3"

[extras]
Expand Down
1 change: 0 additions & 1 deletion src/LyceumMuJoCo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using LinearAlgebra
using Random

# 3rd party
using UnsafeArrays
using StaticArrays
using Distributions
using Reexport
Expand Down
29 changes: 9 additions & 20 deletions src/dmc/cartpole_swingup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,19 @@ end


"""
    getsim(env::CartpoleSwingup)

Return the simulation object stored in `env.sim`.
"""
@inline function getsim(env::CartpoleSwingup)
    return env.sim
end


"""
    obsspace(env::CartpoleSwingup)

Return the observation space stored in `env.obsspace`.
"""
@inline function obsspace(env::CartpoleSwingup)
    return env.obsspace
end

"""
    getobs!(obs, env::CartpoleSwingup)

Fill `obs` in place with the current observation of `env` and return `obs`.

`obs` is wrapped with `obsspace(env)(obs)` and populated from the simulation:
`pos.cart` from `qpos[:slider]`, `pos.pole_zz` and `pos.pole_xz` from the
`[:z, :z, :pole_1]` and `[:x, :z, :pole_1]` entries of `xmat`, and `vel` as a copy of
`env.sim.d.qvel`. The axes of `obs` are validated inside `@boundscheck`.
"""
@inline function getobs!(obs, env::CartpoleSwingup)
    @boundscheck checkaxes(obsspace(env), obs)

    # The observation is written exactly once through a plain shaped view; the former
    # `@uviews` wrapper (UnsafeArrays) is no longer used.
    sobs = obsspace(env)(obs)
    sobs.pos.cart = env.sim.dn.qpos[:slider]
    sobs.pos.pole_zz = env.sim.dn.xmat[:z, :z, :pole_1]
    sobs.pos.pole_xz = env.sim.dn.xmat[:x, :z, :pole_1]
    copyto!(sobs.vel, env.sim.d.qvel)
    return obs
end


@inline function getreward(state, action, obs, env::CartpoleSwingup)
@boundscheck begin
checkaxes(statespace(env), state)
Expand All @@ -97,22 +92,18 @@ end
mean(upright) * small_control * small_velocity * centered
end


"""
    geteval(state, action, obs, env::CartpoleSwingup)

Return the `pos.pole_zz` entry of `obs`, read through `obsspace(env)(obs)`.

The axes of `obs` are validated inside `@boundscheck`; `state` and `action` are unused.
"""
@inline function geteval(state, action, obs, env::CartpoleSwingup)
    @boundscheck checkaxes(obsspace(env), obs)
    shaped = obsspace(env)(obs)
    return shaped.pos.pole_zz
end


function reset!(env::CartpoleSwingup)
reset_nofwd!(env.sim)

qpos = env.sim.dn.qpos
@uviews qpos begin
qpos[:hinge_1] = pi
end
qpos[:hinge_1] = pi

forward!(env.sim)

Expand All @@ -123,15 +114,13 @@ function randreset!(rng::Random.AbstractRNG, env::CartpoleSwingup)
reset_nofwd!(env.sim)

qpos = env.sim.dn.qpos
@uviews qpos begin
qpos[:slider] = 0.01 * randn(rng)
qpos[:hinge_1] = pi + 0.01*randn(rng)
end
qpos[:slider] = 0.01 * randn(rng)
qpos[:hinge_1] = pi + 0.01*randn(rng)

randn!(rng, env.sim.d.qvel)
env.sim.d.qvel .*= 0.01

forward!(env.sim)

env
end
end
76 changes: 30 additions & 46 deletions src/gym/hopper-v2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,37 +41,30 @@ function tconstruct(::Type{HopperV2}, n::Integer)
end

"""
    getsim(env::HopperV2)

Return the simulation object stored in `env.sim`.
"""
@inline function getsim(env::HopperV2)
    return env.sim
end


"""
    statespace(env::HopperV2)

Return the state space stored in `env.statespace`.
"""
@inline function statespace(env::HopperV2)
    return env.statespace
end

"""
    getstate!(state, env::HopperV2)

Fill `state` in place from `env` and return `state`.

After checking the axes of `state` against `statespace(env)`, the shaped view of
`state` receives the simulator state via `getstate!(shaped.simstate, env.sim)` and
`env.last_torso_x` in its `last_torso_x` field.
"""
function getstate!(state, env::HopperV2)
    checkaxes(statespace(env), state)
    # Written once through a plain shaped view; no `@uviews` wrapper is needed.
    shaped = statespace(env)(state)
    getstate!(shaped.simstate, env.sim)
    shaped.last_torso_x = env.last_torso_x
    return state
end

"""
    setstate!(env::HopperV2, state)

Load `state` into `env` and return `env`.

After checking the axes of `state` against `statespace(env)`, the shaped view's
`simstate` is passed to `setstate!(env.sim, ...)` and its `last_torso_x` field is
copied to `env.last_torso_x`.
"""
function setstate!(env::HopperV2, state)
    checkaxes(statespace(env), state)
    # Applied once through a plain shaped view; no `@uviews` wrapper is needed.
    shaped = statespace(env)(state)
    setstate!(env.sim, shaped.simstate)
    env.last_torso_x = shaped.last_torso_x
    return env
end


"""
    obsspace(env::HopperV2)

Return the observation space stored in `env.obsspace`.
"""
@inline function obsspace(env::HopperV2)
    return env.obsspace
end

function getobs!(obs, env::HopperV2)
checkaxes(obsspace(env), obs)
qpos = env.sim.d.qpos
@views @uviews qpos obs begin
@views begin
shaped = obsspace(env)(obs)
copyto!(shaped.cropped_qpos, qpos[2:end])
copyto!(shaped.qvel, env.sim.d.qvel)
Expand All @@ -80,28 +73,22 @@ function getobs!(obs, env::HopperV2)
obs
end


"""
    getreward(state, action, ::Any, env::HopperV2)

Return the reward for `state` and `action`.

The reward is the change in torso x (`_torso_x` of the shaped state minus its
`last_torso_x`) divided by `timestep(env)`, plus an alive bonus of `1.0`, minus
`1e-3` times the sum of squared `action` entries. The axes of `state` and `action`
are checked against `statespace(env)` and `actionspace(env)` first.
"""
function getreward(state, action, ::Any, env::HopperV2)
    checkaxes(statespace(env), state)
    checkaxes(actionspace(env), action)
    # Computed once from a plain shaped view; no `@uviews` wrapper is needed.
    shapedstate = statespace(env)(state)
    alive_bonus = 1.0
    reward = (_torso_x(shapedstate, env) - shapedstate.last_torso_x) / timestep(env)
    reward += alive_bonus
    reward -= 1e-3 * sum(x -> x^2, action)
    return reward
end

"""
    geteval(state, ::Any, ::Any, env::HopperV2)

Return `_torso_x` of the shaped view of `state`, after checking the axes of `state`
against `statespace(env)`.
"""
function geteval(state, ::Any, ::Any, env::HopperV2)
    checkaxes(statespace(env), state)
    # Evaluated once on a plain shaped view; no `@uviews` wrapper is needed.
    return _torso_x(statespace(env)(state), env)
end


function reset!(env::HopperV2)
reset!(env.sim)
env.last_torso_x = _torso_x(env)
Expand All @@ -117,7 +104,6 @@ function randreset!(rng::AbstractRNG, env::HopperV2)
env
end


function step!(env::HopperV2)
env.last_torso_x = _torso_x(env)
step!(env.sim)
Expand All @@ -126,26 +112,24 @@ end

"""
    isdone(state, ::Any, ::Any, env::HopperV2)

Return `true` when `state` fails any of the following conditions, `false` otherwise:

- every entry of `state` is finite,
- every entry of `qpos[3:end]` and of `qvel` has absolute value below `100`,
- `_torso_height` is above `0.7`,
- the absolute value of `_torso_ang` is below `0.2`.

The axes of `state` are checked against `statespace(env)` first.
"""
function isdone(state, ::Any, ::Any, env::HopperV2)
    checkaxes(statespace(env), state)
    # Evaluated once on a plain shaped view, using `view` from Base rather than
    # `uview` from UnsafeArrays.
    shapedstate = statespace(env)(state)
    height = _torso_height(shapedstate, env)
    torso_ang = _torso_ang(shapedstate, env)
    qpos = shapedstate.simstate.qpos
    qvel = shapedstate.simstate.qvel

    within_limits = (
        all(isfinite, state) &&
        all(x -> abs(x) < 100, view(qpos, 3:length(qpos))) &&
        all(x -> abs(x) < 100, qvel) &&
        height > 0.7 &&
        abs(torso_ang) < 0.2
    )
    return !within_limits
end

# Torso accessors for HopperV2. The shaped-state methods read entries 1, 2 and 3 of
# `shapedstate.simstate.qpos`; the `env` method reads entry 1 of `env.sim.d.qpos`.
# Each method is defined exactly once (the repeated `_torso_ang` definition is dropped).
@inline _torso_x(shapedstate::ShapedView, ::HopperV2) = shapedstate.simstate.qpos[1]
@inline _torso_x(env::HopperV2) = env.sim.d.qpos[1]
@inline _torso_height(shapedstate::ShapedView, ::HopperV2) = shapedstate.simstate.qpos[2]
@inline _torso_ang(shapedstate::ShapedView, ::HopperV2) = shapedstate.simstate.qpos[3]
38 changes: 13 additions & 25 deletions src/gym/swimmer-v2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,61 +42,49 @@ function tconstruct(::Type{SwimmerV2}, n::Integer)
end

"""
    getsim(env::SwimmerV2)

Return the simulation object stored in `env.sim`.
"""
@inline function getsim(env::SwimmerV2)
    return env.sim
end


"""
    statespace(env::SwimmerV2)

Return the state space stored in `env.statespace`.
"""
@inline function statespace(env::SwimmerV2)
    return env.statespace
end

"""
    getstate!(state, env::SwimmerV2)

Fill `state` in place from `env` and return `state`.

After checking the axes of `state` against `statespace(env)`, the shaped view of
`state` receives the simulator state via `getstate!(shaped.simstate, env.sim)` and
`env.last_torso_x` in its `last_torso_x` field.
"""
function getstate!(state, env::SwimmerV2)
    checkaxes(statespace(env), state)
    # Written once through a plain shaped view; no `@uviews` wrapper is needed.
    shaped = statespace(env)(state)
    getstate!(shaped.simstate, env.sim)
    shaped.last_torso_x = env.last_torso_x
    return state
end

"""
    setstate!(env::SwimmerV2, state)

Load `state` into `env` and return `env`.

After checking the axes of `state` against `statespace(env)`, the shaped view's
`simstate` is passed to `setstate!(env.sim, ...)` and its `last_torso_x` field is
copied to `env.last_torso_x`.
"""
function setstate!(env::SwimmerV2, state)
    checkaxes(statespace(env), state)
    # Applied once through a plain shaped view; no `@uviews` wrapper is needed.
    shaped = statespace(env)(state)
    setstate!(env.sim, shaped.simstate)
    env.last_torso_x = shaped.last_torso_x
    return env
end


"""
    obsspace(env::SwimmerV2)

Return the observation space stored in `env.obsspace`.
"""
@inline function obsspace(env::SwimmerV2)
    return env.obsspace
end

"""
    getobs!(obs, env::SwimmerV2)

Fill `obs` in place with the current observation of `env` and return `obs`.

After checking the axes of `obs` against `obsspace(env)`, the shaped view of `obs`
receives `env.sim.d.qpos[3:end]` in `qpos_cropped` and `env.sim.d.qvel` in `qvel`.
"""
function getobs!(obs, env::SwimmerV2)
    checkaxes(obsspace(env), obs)
    qpos = env.sim.d.qpos
    # `@views` turns the `qpos[3:end]` slice into a view instead of a copy; the
    # former `@uviews` wrapper (UnsafeArrays) is no longer used.
    @views begin
        shaped = obsspace(env)(obs)
        copyto!(shaped.qpos_cropped, qpos[3:end])
        copyto!(shaped.qvel, env.sim.d.qvel)
    end
    return obs
end


"""
    getreward(state, action, ::Any, env::SwimmerV2)

Return the reward for `state` and `action`.

The reward is the change in torso x (`_torso_x` of the shaped state minus its
`last_torso_x`) divided by `timestep(env)`, plus a control term equal to `-1e-4`
times the sum of squared `action` entries. The axes of `state` and `action` are
checked against `statespace(env)` and `actionspace(env)` first.
"""
function getreward(state, action, ::Any, env::SwimmerV2)
    checkaxes(statespace(env), state)
    checkaxes(actionspace(env), action)
    # Computed once from a plain shaped view; no `@uviews` wrapper is needed.
    shapedstate = statespace(env)(state)
    reward_fwd = (_torso_x(shapedstate, env) - shapedstate.last_torso_x) / timestep(env)
    reward_ctrl = -1e-4 * sum(x -> x^2, action)
    return reward_fwd + reward_ctrl
end

"""
    geteval(state, action, obs, env::SwimmerV2)

Return `_torso_x` of the shaped view of `state`, after checking the axes of `state`
against `statespace(env)`. `action` and `obs` are unused.
"""
function geteval(state, action, obs, env::SwimmerV2)
    checkaxes(statespace(env), state)
    # Evaluated once on a plain shaped view; no `@uviews` wrapper is needed.
    return _torso_x(statespace(env)(state), env)
end


Expand All @@ -122,4 +110,4 @@ function step!(env::SwimmerV2)
end

# Torso x accessors for SwimmerV2: entry 1 of `qpos`, read either from a shaped state
# or directly from `env.sim.d`. Each method is defined exactly once (the repeated
# `_torso_x(env::SwimmerV2)` definition is dropped).
@inline _torso_x(shapedstate::ShapedView, ::SwimmerV2) = shapedstate.simstate.qpos[1]
@inline _torso_x(env::SwimmerV2) = env.sim.d.qpos[1]
Loading

0 comments on commit 1bf2191

Please sign in to comment.