Skip to content

Commit

Permalink
add armhand environment (#19)
Browse files Browse the repository at this point in the history
add armhand environment
  • Loading branch information
colinxs authored Jan 20, 2020
1 parent ec49eb3 commit d9e0c5f
Show file tree
Hide file tree
Showing 44 changed files with 799 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/LyceumMuJoCo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ import LyceumBase: statespace,

tconstruct

using LyceumBase.Tools: perturb!
using LyceumBase.Tools: perturb!, SPoint3D

export # AbstractMuJoCoEnvironment interface (an addition to AbstractEnvironment's interface)
AbstractMuJoCoEnvironment,
Expand Down Expand Up @@ -138,6 +138,7 @@ end
####

include("suite/pointmass.jl")
include("suite/armhand/armhandpickup.jl")

include("gym/swimmer-v2.jl")
include("gym/hopper-v2.jl")
Expand Down
648 changes: 648 additions & 0 deletions src/suite/armhand/armhand.xml

Large diffs are not rendered by default.

146 changes: 146 additions & 0 deletions src/suite/armhand/armhandpickup.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
struct ArmHandPickup{S<:MJSim,O<:MultiShape} <: AbstractMuJoCoEnvironment
sim::S
obsspace::O
ball::Int
palm::Int
thumb::Int
index::Int
middle::Int
ring::Int
pinky::Int
goal::SVector{3,Float64}

function ArmHandPickup(sim::MJSim)
m = sim.m
mn, dn = sim.mn, sim.dn

obsspace = MultiShape(
d_thumb = ScalarShape(Float64),
d_index = ScalarShape(Float64),
d_middle = ScalarShape(Float64),
d_ring = ScalarShape(Float64),
d_pinky = ScalarShape(Float64),
a_thumb = ScalarShape(Float64),
a_index = ScalarShape(Float64),
a_middle = ScalarShape(Float64),
a_ring = ScalarShape(Float64),
a_pinky = ScalarShape(Float64),
a_close = ScalarShape(Float64),
handball = ScalarShape(Float64),
ballgoal = ScalarShape(Float64),
ball = VectorShape(Float64, 3),
palm = VectorShape(Float64, 3),
)

ball = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "ball")
palm = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "palm")

thumb = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "thumb_IMU")
index = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "index_IMU")
middle = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "middle_IMU")
ring = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "ring_IMU")
pinky = jl_name2id(sim.m, MuJoCo.MJCore.mjOBJ_SITE, "pinky_IMU")

goal = SA_F64[0.0, 0.2, 0.5]

new{typeof(sim),typeof(obsspace)}(
sim,
obsspace,
ball,
palm,
thumb,
index,
middle,
ring,
pinky,
goal,
)
end
end

ArmHandPickup() = first(tconstruct(ArmHandPickup, 1))

function tconstruct(::Type{ArmHandPickup}, n::Integer)
modelpath = joinpath(@__DIR__, "armhand.xml")
Tuple(ArmHandPickup(s) for s in tconstruct(MJSim, n, modelpath, skip = 3))
end


@inline getsim(env::ArmHandPickup) = env.sim


@inline obsspace(env::ArmHandPickup) = env.obsspace

@propagate_inbounds function getobs!(obs, env::ArmHandPickup)
@boundscheck checkaxes(obsspace(env), obs)

m, d = env.sim.m, env.sim.d
sx = d.site_xpos
dmin = 0.5

ball = SPoint3D(sx, env.ball)
palm = SPoint3D(sx, env.palm)
thumb = SPoint3D(sx, env.thumb)
index = SPoint3D(sx, env.index)
middle = SPoint3D(sx, env.middle)
ring = SPoint3D(sx, env.ring)
pinky = SPoint3D(sx, env.pinky)
goal = ball - env.goal

@uviews obs @inbounds begin
shaped = obsspace(env)(obs)

shaped.ball .= ball
shaped.palm .= palm .- ball

shaped.handball = _sitedist(palm, ball, dmin)
shaped.ballgoal = _sitedist(ball, goal, dmin)

shaped.d_thumb = _sitedist(thumb, ball, dmin)
shaped.d_index = _sitedist(index, ball, dmin)
shaped.d_middle = _sitedist(middle, ball, dmin)
shaped.d_ring = _sitedist(ring, ball, dmin)
shaped.d_pinky = _sitedist(pinky, ball, dmin)
shaped.a_thumb = cosine_dist(goal, thumb)
shaped.a_index = cosine_dist(goal, index)
shaped.a_middle = cosine_dist(goal, middle)
shaped.a_ring = cosine_dist(goal, ring)
shaped.a_pinky = cosine_dist(goal, pinky)
shaped.a_close = cosine_dist(middle, thumb)
end

obs
end


@propagate_inbounds function getreward(state, action, obs, env::ArmHandPickup)
@boundscheck checkaxes(obsspace(env), obs)

os = obsspace(env)(obs)
handball = os.handball / 0.5
ballgoal = os.ballgoal / 0.5

reward = -handball
if handball < 0.06
reward = 2.0 - 2 * ballgoal
reward -= 0.1 * (os.d_thumb + os.d_index + os.d_middle + os.d_ring + os.d_pinky)
end
reward
end

@propagate_inbounds function geteval(state, action, obs, env::ArmHandPickup)
@boundscheck checkaxes(obsspace(env), obs)
obsspace(env)(obs).ball[3]
end


@propagate_inbounds function randreset!(rng::Random.AbstractRNG, env::ArmHandPickup)
fastreset_nofwd!(env.sim)
env.sim.d.qpos[1] = rand(rng, Uniform(-0.15, 0.15))
env.sim.d.qpos[2] = rand(rng, Uniform(-0.1, 0.1))
forward!(env.sim)
env
end


@inline _sitedist(s1, s2, dmin) = min(euclidean(s1, s2), dmin)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/index0.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/index1.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/index2.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/index3.stl
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/middle0.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/middle1.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/middle2.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/middle3.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/palm.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/pinky0.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/pinky1.stl
Binary file not shown.
Binary file added src/suite/armhand/meshes/mesh/hybrid/pinky2.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/pinky3.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/ring0.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/ring1.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/ring2.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/ring3.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/thumb0.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/thumb1.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/thumb2.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/thumb3.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/wristx.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/wristy.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/mesh/hybrid/wristz.stl
Diff not rendered.
Binary file added src/suite/armhand/meshes/texture/marble.png
Binary file added src/suite/armhand/meshes/texture/skin.png
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ using Test, Random, Shapes, Pkg, LinearAlgebra, BenchmarkTools
using MuJoCo: TESTMODELXML

const LYCEUM_SUITE = [
(LyceumMuJoCo.PointMass, (), ())
(LyceumMuJoCo.PointMass, (), ()),
(LyceumMuJoCo.ArmHandPickup, (), ()),
]

const GYM = [
Expand Down Expand Up @@ -32,4 +33,4 @@ end
@testset "DMC" begin test_group(DMC) end
end

end
end

0 comments on commit d9e0c5f

Please sign in to comment.