Skip to content

Commit

Permalink
Stable training
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Suarez committed Feb 8, 2025
1 parent b655669 commit 392010f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
8 changes: 5 additions & 3 deletions clean_pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import random
import psutil
import time

from threading import Thread
from collections import defaultdict, deque

Expand Down Expand Up @@ -178,7 +177,10 @@ def train(data):
experience.num_minibatches, config.minibatch_size).to(config.device)
with torch.no_grad():
for mb in range(experience.num_minibatches):
adversarial_reward[mb] = -torch.log(1 - data.policy.policy.discriminate(state[mb]).squeeze())
disc_logits = data.policy.policy.discriminate(state[mb]).squeeze()
prob = 1 / (1 + torch.exp(-disc_logits))
adversarial_reward[mb] = -torch.log(torch.maximum(
1 - prob, torch.tensor(0.0001, device=config.device)))

# TODO: Nans in adversarial reward and gae
adversarial_reward_np = adversarial_reward.cpu().numpy().ravel()
Expand Down Expand Up @@ -436,7 +438,7 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
obs_device = device if not pin else 'cpu'
self.obs=torch.zeros(batch_size, *obs_shape, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.demo=torch.zeros(batch_size, 576, dtype=obs_dtype,
self.demo=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
self.state=torch.zeros(batch_size, 358, dtype=obs_dtype,
pin_memory=pin, device=device if not pin else 'cpu')
Expand Down
72 changes: 68 additions & 4 deletions pufferlib/environments/morph/humanoid_phc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,9 +1202,19 @@ def _compute_observations(self, env_ids=None):
if env_ids is None:
env_ids = self.all_env_ids

self.state = self._compute_humanoid_obs(env_ids)
self.demo = self._compute_task_obs(env_ids)
obs = torch.cat([self.state, self.demo], dim=-1)
# This is the normalized state of the humanoid
state = self._compute_humanoid_obs(env_ids)

# This is the difference of state with the demo, but it is
# called "state" in the paper.
imitation = self._compute_task_obs(env_ids)

# Possible the original paper only uses imitation
obs = torch.cat([state, imitation], dim=-1)

# This is the normalized vector with position, rotation, velocity, and
# angular velocity for the simulated humanoid and the demo data
self.state, self.demo = self._compute_state_obs(env_ids)

if self.add_obs_noise and not self.flag_test:
obs = obs + torch.randn_like(obs) * 0.1
Expand Down Expand Up @@ -1245,6 +1255,60 @@ def _compute_humanoid_obs(self, env_ids=None):
self._has_limb_weight_obs, # Constant: False
)

def _compute_state_obs(self, env_ids=None):
if env_ids is None:
env_ids = slice(None)

body_pos = self._rigid_body_pos[env_ids]#[..., self._track_bodies_id]
body_rot = self._rigid_body_rot[env_ids]#[..., self._track_bodies_id]
body_vel = self._rigid_body_vel[env_ids]#[..., self._track_bodies_id]
body_ang_vel = self._rigid_body_ang_vel[env_ids]#[..., self._track_bodies_id]

sim_obs = compute_humanoid_observations_smpl_max(
body_pos,
body_rot,
body_vel,
body_ang_vel,
None,
None,
self._local_root_obs, # Constant: True
self._root_height_obs, # Constant: True
self._has_upright_start, # Constant: True
self._has_shape_obs, # Constant: False
self._has_limb_weight_obs, # Constant: False
)

motion_times = (
(self.progress_buf[env_ids] + 1) * self.dt
+ self._motion_start_times[env_ids]
+ self._motion_start_times_offset[env_ids]
) # Next frame, so +1

motion_res = self._get_state_from_motionlib_cache(
self._sampled_motion_ids[env_ids], motion_times, self._global_offset[env_ids]
) # pass in the env_ids such that the motion is in synced.

demo_pos = motion_res["rg_pos"]#[..., self._track_bodies_id]
demo_rot = motion_res["rb_rot"]#[..., self._track_bodies_id]
demo_vel = motion_res["body_vel"]#[..., self._track_bodies_id]
demo_ang_vel = motion_res["body_ang_vel"]#[..., self._track_bodies_id]

demo_obs = compute_humanoid_observations_smpl_max(
demo_pos,
demo_rot,
demo_vel,
demo_ang_vel,
None,
None,
True, # Constant: True
self._root_height_obs, # Constant: True
self._has_upright_start, # Constant: True
self._has_shape_obs, # Constant: False
self._has_limb_weight_obs, # Constant: False
)

return sim_obs, demo_obs

def _compute_task_obs(self, env_ids=None, save_buffer=True):
if env_ids is None:
env_ids = self.all_env_ids
Expand Down Expand Up @@ -1698,7 +1762,7 @@ def remove_base_rot(quat):
return quat_mul(quat, base_rot.repeat(shape, 1))


@torch.jit.script
#@torch.jit.script
def compute_humanoid_observations_smpl_max(
body_pos,
body_rot,
Expand Down
5 changes: 4 additions & 1 deletion pufferlib/environments/morph/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def __init__(self, env, input_dim, action_dim, demo_dim, hidden):
requires_grad=False,
)
nn.init.constant_(self.sigma, -2.9)
#self.mu = pufferlib.pytorch.layer_init(
# nn.Linear(hidden, action_dim), std=0.01)
#self.sigma = nn.Parameter(torch.zeros(1, action_dim))

### Separate Critic
self.critic_mlp = nn.Sequential(
Expand Down Expand Up @@ -71,7 +74,7 @@ def encode_observations(self, obs):

def decode_actions(self, hidden, lookup=None):
mu = self.mu(hidden)
std = torch.exp(self.sigma)
std = torch.exp(self.sigma).expand_as(mu)
probs = torch.distributions.Normal(mu, std)
value = self.value(hidden)
return probs, value
Expand Down

0 comments on commit 392010f

Please sign in to comment.