Skip to content

Commit

Permalink
Merge pull request #75 from UoA-CARES/Import-model-back
Browse files Browse the repository at this point in the history
Import model back
  • Loading branch information
retinfai authored Jul 24, 2023
2 parents 767cbc3 + e611cbe commit 83ef44c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
50 changes: 50 additions & 0 deletions retrain.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/bin/bash

# Ctrl + C to exit the program
function cleanup() {
exit 1
}

trap cleanup SIGINT SIGTERM

# Time limit for training
if [ -z "$1" ]; then
echo "Please enter the time limit for training"
echo "Usage: ./retrain.sh <time_limit> [<partition_number>]"
echo "Eg: 10s, 6h"
echo "Usage: ./retrain.sh 10s"
exit 1
fi

# Set the partition number
if [ -n "$2" ]; then
export GZ_PARTITION="$2"
else
export GZ_PARTITION=150
fi

. install/setup.bash

# Rerun every $1 time
while true; do
colcon build
timeout "$1" gz sim -g &
timeout "$1" ros2 launch reinforcement_learning train.launch.py

# Get the latest folder in rl_logs
latest_folder=$(ls -lt ./rl_logs | grep '^d' | head -n 1 | awk '{print $NF}')

# Set the new paths for actor_path and critic_path
new_actor_path="rl_logs/$latest_folder/models/actor_checkpoint.pht"
new_critic_path="rl_logs/$latest_folder/models/critic_checkpoint.pht"

# Check if the new paths exist
if [ ! -e "$new_actor_path" ] || [ ! -e "$new_critic_path" ]; then
echo "Error: $new_actor_path or $new_critic_path does not exist"
exit 1
fi

# Use sed to update the actor_path and critic_path in the YAML file
sed -i "s#actor_path: '.*'#actor_path: '$new_actor_path'#g" src/reinforcement_learning/config/train.yaml
sed -i "s#critic_path: '.*'#critic_path: '$new_critic_path'#g" src/reinforcement_learning/config/train.yaml
done
4 changes: 3 additions & 1 deletion src/reinforcement_learning/config/train.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
train:
ros__parameters:
environment: 'CarTrack' # CarGoal, CarWall, CarBlock, CarTrack
track: 'track_1' # track_1, track_2, track_3 -> only applies for CarTrack
max_steps_exploration: 5000
track: 'track_2' # track_1, track_2, track_3 -> only applies for CarTrack
max_steps_training: 1000000
reward_range: 1.0
collision_range: 0.2
actor_path: 'rl_logs/23_07_17_01:50:30/models/actor_checkpoint.pht'
critic_path: 'rl_logs/23_07_17_01:50:30/models/critic_checkpoint.pht'
# gamma: 0.95
# tau: 0.005
# g: 5
Expand Down
23 changes: 20 additions & 3 deletions src/reinforcement_learning/reinforcement_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def main():
MAX_STEPS, \
STEP_LENGTH, \
REWARD_RANGE, \
COLLISION_RANGE = [param.value for param in params]
COLLISION_RANGE, \
ACTOR_PATH, \
CRITIC_PATH = [param.value for param in params]

if ACTOR_PATH != '' and CRITIC_PATH != '':
MAX_STEPS_EXPLORATION = 0

print(
f'---------------------------------------------\n'
Expand All @@ -60,6 +65,8 @@ def main():
f'Step Length: {STEP_LENGTH}\n'
f'Reward Range: {REWARD_RANGE}\n'
f'Collision Range: {COLLISION_RANGE}\n'
f'Critic Path: {CRITIC_PATH}\n'
f'Actor Path: {ACTOR_PATH}\n'
f'---------------------------------------------\n'
)

Expand All @@ -79,6 +86,12 @@ def main():
actor = Actor(observation_size=env.OBSERVATION_SIZE, num_actions=env.ACTION_NUM, learning_rate=ACTOR_LR)
critic = Critic(observation_size=env.OBSERVATION_SIZE, num_actions=env.ACTION_NUM, learning_rate=CRITIC_LR)

if ACTOR_PATH != '' and CRITIC_PATH != '':
print('Reading saved models into actor and critic')
actor.load_state_dict(torch.load(ACTOR_PATH))
critic.load_state_dict(torch.load(CRITIC_PATH))
print('Successfully Loaded models')

agent = TD3(
actor_network=actor,
critic_network=critic,
Expand Down Expand Up @@ -192,7 +205,9 @@ def get_params():
('max_steps', 100),
('step_length', 0.25),
('reward_range', 0.2),
('collision_range', 0.2)
('collision_range', 0.2),
('actor_path', ''),
('critic_path', '')
]
)

Expand All @@ -212,7 +227,9 @@ def get_params():
'max_steps',
'step_length',
'reward_range',
'collision_range'
'collision_range',
'actor_path',
'critic_path',
])


Expand Down

0 comments on commit 83ef44c

Please sign in to comment.