Skip to content

Commit

Permalink
Merge pull request #106 from UoA-CARES/fix-termination-condition
Browse files Browse the repository at this point in the history
  • Loading branch information
retinfai authored Aug 24, 2023
2 parents 78d470a + 15fae90 commit a55248d
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 12 deletions.
8 changes: 5 additions & 3 deletions src/environments/environments/CarTrackEnvironment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ class CarTrackEnvironment(F1tenthEnvironment):
def __init__(self,
car_name,
reward_range=1,
max_steps=50,
max_steps=500,
collision_range=0.2,
step_length=0.5,
track='track_1',
observation_mode='full'
observation_mode='full',
):
super().__init__('car_track', car_name, max_steps, step_length)

Expand Down Expand Up @@ -117,8 +117,10 @@ def step(self, action):

def is_terminated(self, state):
    """Return True when the episode should end.

    The episode terminates when the car has collided with an obstacle,
    has flipped over, or has passed every goal on the track (the fix
    this commit introduces: finishing all goals now ends the episode).

    NOTE(review): state layout assumed — state[8:] are the lidar ranges
    fed to the collision check and state[2:6] the orientation quaternion
    used by the flip check — confirm against the observation builder.
    """
    # The scraped diff contained both the old and the new continuation
    # lines; this is the reconstructed post-merge condition.
    return has_collided(state[8:], self.COLLISION_RANGE) \
        or has_flipped_over(state[2:6]) \
        or self.goal_number >= len(self.all_goals)


def generate_goal(self, number):
    # Announce the goal being spawned, then hand back its waypoint.
    # Indexing wraps around so goal numbers beyond the list length
    # cycle through the track's goals again.
    print("Goal", number, "spawned")
    goal_index = number % len(self.all_goals)
    return self.all_goals[goal_index]
Expand Down
55 changes: 55 additions & 0 deletions src/environments/environments/waypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Per-track waypoint tables used to spawn goals along each race track.

Each track name maps to an ordered list of waypoints the car is driven
through in sequence.
"""
from collections import namedtuple

# x, y: position on the track plane; Y: presumably the yaw/heading (in
# radians, values stay within [-pi, pi]) — TODO confirm against the spawn code.
Waypoint = namedtuple('Waypoint', ['x', 'y', 'Y'])

# Template: Waypoint(x, y, Y)
waypoints = {
    'track_1': [
        Waypoint(-14.11, 3.71, -1.81),
        Waypoint(-15.75, -1.71, -1.96),
        Waypoint(-16.71, -5.48, -1.35),
        Waypoint(-7.53, -5.43, -0.27),
        Waypoint(-0.22, -8.55, -0.87),
        Waypoint(1.9, -12.05, -0.66),
        Waypoint(5.5, -12.38, 0.79),
        Waypoint(7.32, -7.3, 0.97),
        Waypoint(9.62, -1.45, 1.23),
        Waypoint(10.76, 4.08, 1.64),
        Waypoint(8.48, 9.77, 2.07),
        Waypoint(4.93, 13.27, -3.06),
        Waypoint(1.5, 9.96, -1.9),
        Waypoint(-0.89, 5.67, -2.94),
        Waypoint(-5.52, 10.83, 2.02),
        Waypoint(-6.67, 15.67, 2.38),
        Waypoint(-10.84, 13.99, -1.93),
        # NOTE(review): large jump from the previous point (-10.84, 13.99)
        # to (12.81, 8.14) — looks like a possible typo; verify on the track.
        Waypoint(12.81, 8.14, -1.93),
    ],
    # NOTE(review): key is 'budapest' but the sanity_check config sets
    # track: 'budapest_track' — confirm which name the lookup uses.
    'budapest': [
        Waypoint(0, 0, -0.66),
        Waypoint(7.27, -5.88, -0.66),
        Waypoint(14.19, -7.85, 0.74),
        Waypoint(13.44, -0.49, 2.45),
        Waypoint(7.74, 4.32, 2.04),
        Waypoint(15.5, 4.11, -0.9),
        Waypoint(21.58, -0.55, 0.76),
        Waypoint(31.42, 9.76, 0.76),
        Waypoint(38.02, 17.33, 1.26),
        Waypoint(34.74, 26.22, 2.16),
        Waypoint(34.83, 39.9, 1.7),
        Waypoint(25.33, 46.92, 1.7),
        # NOTE(review): (58.7, 58.7) is far off the neighbouring points
        # and x == y exactly — possible copy/paste error; verify.
        Waypoint(58.7, 58.7, 2.17),
        Waypoint(14.53, 68.36, 2.59),
        Waypoint(7.71, 64.15, -1.48),
        Waypoint(7.8, 50.07, -2.17),
        Waypoint(-3.65, 30.22, -1.95),
        Waypoint(-2.49, 22.02, -1.2),
        Waypoint(-1.34, 17.12, -2.05),
        Waypoint(-9.16, 18.16, 2.49),
        Waypoint(-19.28, 26.05, 2.82),
        Waypoint(-28.16, 27.8, 3.11),
        Waypoint(-29.03, 23.69, -0.79),
        Waypoint(-20.17, 16.62, -0.75),
        Waypoint(-7.26, 5.93, -0.75),
    ]

}
4 changes: 2 additions & 2 deletions src/environments/worlds/budapest_track.sdf
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
<!-- <mesh filename="package://environments/meshes/track.dae" scale="0.001 0.001 0.001"/> -->
<mesh>
<uri>model://src/environments/meshes/budapest_track.stl</uri>
<scale>0.01 0.01 0.01</scale>
<scale>0.007 0.007 0.007</scale>
</mesh>
</geometry>
<material>
Expand All @@ -101,7 +101,7 @@
<!-- <mesh filename="package://environments/meshes/track.dae" scale="0.001 0.001 0.001" /> -->
<mesh>
<uri>model://src/environments/meshes/budapest_track.stl</uri>
<scale>0.01 0.01 0.01</scale>
<scale>0.007 0.007 0.007</scale>
</mesh>
</geometry>
</collision>
Expand Down
2 changes: 1 addition & 1 deletion src/reinforcement_learning/config/sanity_check.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
sanity_check:
ros__parameters:
environment: 'CarTrack' # CarGoal, CarWall, CarBlock, CarTrack
track: 'track_1' # track_1, track_2, track_3 -> only applies for CarTrack
track: 'budapest_track' # track_1, track_2, track_3, budapest_track -> only applies for CarTrack
max_steps_exploration: 5000
max_steps_training: 1000000
reward_range: 1.0
Expand Down
5 changes: 3 additions & 2 deletions src/reinforcement_learning/config/train.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
train:
ros__parameters:
environment: 'CarBeat' # CarGoal, CarWall, CarBlock, CarTrack, CarBeat
environment: 'CarTrack' # CarGoal, CarWall, CarBlock, CarTrack, CarBeat
track: 'track_2' # track_1, track_2, track_3 -> only applies for CarTrack
algorithm: 'TD3'
max_steps_exploration: 1000
max_steps_training: 1000000
reward_range: 3.0
reward_range: 3.0
collision_range: 0.2
observation_mode: 'lidar_only'
evaluate_every_n_steps: 2000
evaluate_for_m_episodes: 3
# actor_path & critic_path must exist, it can't be commented
# actor_path: 'rl_logs/23_08_02_17:59:13/models/actor_checkpoint.pht'
# critic_path: 'rl_logs/23_08_02_17:59:13/models/critic_checkpoint.pht'
Expand Down
58 changes: 54 additions & 4 deletions src/reinforcement_learning/reinforcement_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def main():
global MAX_STEPS_PER_BATCH
global G
global BATCH_SIZE
global EVALUATE_EVERY_N_STEPS
global EVALUATE_FOR_M_EPISODES

ENVIRONMENT, \
ALGORITHM, \
Expand All @@ -47,7 +49,9 @@ def main():
ACTOR_PATH, \
CRITIC_PATH, \
MAX_STEPS_PER_BATCH, \
OBSERVATION_MODE = [param.value for param in params]
OBSERVATION_MODE, \
EVALUATE_EVERY_N_STEPS, \
EVALUATE_FOR_M_EPISODES = [param.value for param in params]

if ACTOR_PATH != '' and CRITIC_PATH != '':
MAX_STEPS_EXPLORATION = 0
Expand All @@ -73,7 +77,9 @@ def main():
f'Critic Path: {CRITIC_PATH}\n'
f'Actor Path: {ACTOR_PATH}\n'
f'Max Steps per Batch: {MAX_STEPS_PER_BATCH}\n'
f'Observation Mode: {OBSERVATION_MODE}'
f'Observation Mode: {OBSERVATION_MODE}\n'
f'Evaluate Every N Steps: {EVALUATE_EVERY_N_STEPS}\n'
f'Evaluate For M Episodes: {EVALUATE_FOR_M_EPISODES}\n'
f'---------------------------------------------\n'
)

Expand Down Expand Up @@ -130,7 +136,9 @@ def main():
'reward_range': REWARD_RANGE,
'collision_range': COLLISION_RANGE,
'max_steps_per_batch': MAX_STEPS_PER_BATCH,
'observation_mode': OBSERVATION_MODE
'observation_mode': OBSERVATION_MODE,
'evaluate_every_n_steps': EVALUATE_EVERY_N_STEPS,
'evaluate_for_m_episodes': EVALUATE_FOR_M_EPISODES
}

if (ENVIRONMENT == 'CarTrack'):
Expand Down Expand Up @@ -179,12 +187,18 @@ def train(env, agent, record: Record):
experiences = (experiences['state'], experiences['action'], experiences['reward'], experiences['next_state'], experiences['done'])
agent.train_policy(experiences)

evaluation_reward = None

if total_step_counter % EVALUATE_EVERY_N_STEPS == 0:
evaluation_reward = evaluate_policy(env, agent, EVALUATE_FOR_M_EPISODES)

record.log(
out=done or truncated,
Step=total_step_counter,
Episode=episode_num,
Step_Reward=reward,
Episode_Reward=episode_reward if (done or truncated) else None,
Evaluation_Reward=evaluation_reward
)

if done or truncated:
Expand Down Expand Up @@ -258,6 +272,38 @@ def train_ppo(env, agent, record):
episode_timesteps = 0
episode_num += 1

def evaluate_policy(env, agent, num_episodes):
    """Run the agent's deterministic policy for `num_episodes` episodes.

    Each episode is rolled out until the environment reports it as
    terminated or truncated; per-episode rewards are accumulated and the
    average over all evaluation episodes is returned.

    Args:
        env: environment exposing `reset()` and `step(action)` with the
            gymnasium-style 5-tuple return (obs, reward, terminated,
            truncated, info).
        agent: policy exposing `select_action_from_policy(state)`.
        num_episodes: number of evaluation episodes to run.

    Returns:
        Average episode reward as a float (0.0 when num_episodes <= 0,
        instead of dividing by zero).
    """
    # Robustness: the original divided by len(history) and crashed with
    # ZeroDivisionError when asked for zero episodes.
    if num_episodes <= 0:
        return 0.0

    episode_reward_history = []

    print('Beginning Evaluation----------------------------')

    for ep in range(num_episodes):
        state, _ = env.reset()

        episode_reward = 0

        truncated = False
        terminated = False

        while not truncated and not terminated:

            action = agent.select_action_from_policy(state)
            next_state, reward, terminated, truncated, _ = env.step(action)

            episode_reward += reward
            state = next_state

        print(f'Evaluation Episode {ep + 1} Completed with a Reward of {episode_reward}')
        episode_reward_history.append(episode_reward)

    avg_reward = sum(episode_reward_history) / len(episode_reward_history)

    print(f'Evaluation Completed: Avg Reward over {num_episodes} Episodes is {avg_reward} ----------------------------')

    return avg_reward

def get_params():
'''
This function fetches the hyperparameters passed in through the launch files
Expand Down Expand Up @@ -287,7 +333,9 @@ def get_params():
('actor_path', ''),
('critic_path', ''),
('max_steps_per_batch', 5000),
('observation_mode', 'no_position')
('observation_mode', 'no_position'),
('evaluate_every_n_steps', 2000),
('evaluate_for_m_episodes', 5),
]
)

Expand All @@ -313,6 +361,8 @@ def get_params():
'critic_path',
'max_steps_per_batch',
'observation_mode',
'evaluate_every_n_steps',
'evaluate_for_m_episodes'
])


Expand Down

0 comments on commit a55248d

Please sign in to comment.