Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import model back #75

Merged
merged 4 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions retrain.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/bin/bash

# Ctrl + C (SIGINT) or SIGTERM to exit the program.
# The script starts `gz sim -g` in the background, so the handler also stops
# any background jobs that are still running before it exits; otherwise they
# would outlive the script until their own timeout expires.
function cleanup() {
    local pids
    pids=$(jobs -p)
    if [ -n "$pids" ]; then
        # Intentionally unquoted: `jobs -p` prints one PID per line and each
        # must reach `kill` as a separate argument.
        kill $pids 2>/dev/null
    fi
    exit 1
}

trap cleanup SIGINT SIGTERM

# Time limit for training: required first argument, later handed to `timeout`
# unchanged (so it takes the same duration syntax, e.g. 10s or 6h).
# Diagnostics go to stderr so they are not mixed into captured stdout.
if [ -z "$1" ]; then
    {
        echo "Please enter the time limit for training"
        echo "Usage: ./retrain.sh <time_limit> [<partition_number>]"
        echo "Eg: ./retrain.sh 10s"
        echo "    ./retrain.sh 6h 150"
    } >&2
    exit 1
fi

# Partition number: the optional second argument, falling back to 150 when it
# is missing or empty. Exported so the processes launched below inherit it.
export GZ_PARTITION="${2:-150}"

# Source the workspace's install/setup.bash into the current shell.
source install/setup.bash

# Rerun every $1 time: build, train for at most $1, then point the config at
# the checkpoints written by that round before starting the next one.
config_file="src/reinforcement_learning/config/train.yaml"

while true; do
    # Without this check the loop would carry on after a failed build.
    if ! colcon build; then
        echo "Error: colcon build failed" >&2
        exit 1
    fi

    timeout "$1" gz sim -g &
    gui_pid=$!
    timeout "$1" ros2 launch reinforcement_learning train.launch.py

    # The launch can return before the time limit is reached. Stop and reap
    # this round's GUI job so the next iteration does not start a second one
    # next to it. Errors are ignored because the job may already have exited.
    kill "$gui_pid" 2>/dev/null
    wait "$gui_pid" 2>/dev/null

    # Get the latest folder in rl_logs (newest modification time). Done with
    # a glob and -nt instead of parsing `ls -l`, so folder names containing
    # whitespace are handled correctly.
    latest_dir=""
    for dir in ./rl_logs/*/; do
        dir=${dir%/}
        # Skip the literal pattern left by an unmatched glob, and symlinks:
        # only real directories count.
        if [ ! -d "$dir" ] || [ -L "$dir" ]; then
            continue
        fi
        if [ -z "$latest_dir" ] || [ "$dir" -nt "$latest_dir" ]; then
            latest_dir=$dir
        fi
    done

    if [ -z "$latest_dir" ]; then
        echo "Error: no log folder found in ./rl_logs" >&2
        exit 1
    fi
    latest_folder=${latest_dir##*/}

    # Set the new paths for actor_path and critic_path
    new_actor_path="rl_logs/$latest_folder/models/actor_checkpoint.pht"
    new_critic_path="rl_logs/$latest_folder/models/critic_checkpoint.pht"

    # Check if the new paths exist
    if [ ! -e "$new_actor_path" ] || [ ! -e "$new_critic_path" ]; then
        echo "Error: $new_actor_path or $new_critic_path does not exist" >&2
        exit 1
    fi

    # sed rewrites nothing when the entries are missing, so say so instead of
    # leaving the config silently unchanged.
    if ! grep -q "actor_path: '.*'" "$config_file" ||
        ! grep -q "critic_path: '.*'" "$config_file"; then
        echo "Warning: actor_path/critic_path entries not found in $config_file; config left unchanged" >&2
    fi

    # Use sed to update the actor_path and critic_path in the YAML file.
    # '#' is the delimiter because the paths contain '/'.
    sed -i \
        -e "s#actor_path: '.*'#actor_path: '$new_actor_path'#g" \
        -e "s#critic_path: '.*'#critic_path: '$new_critic_path'#g" \
        "$config_file"
done
4 changes: 3 additions & 1 deletion src/reinforcement_learning/config/train.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
train:
ros__parameters:
environment: 'CarTrack' # CarGoal, CarWall, CarBlock, CarTrack
track: 'track_1' # track_1, track_2, track_3 -> only applies for CarTrack
max_steps_exploration: 5000
track: 'track_2' # track_1, track_2, track_3 -> only applies for CarTrack
max_steps_training: 1000000
reward_range: 1.0
collision_range: 0.2
actor_path: 'rl_logs/23_07_17_01:50:30/models/actor_checkpoint.pht'
critic_path: 'rl_logs/23_07_17_01:50:30/models/critic_checkpoint.pht'
# gamma: 0.95
# tau: 0.005
# g: 5
Expand Down
23 changes: 20 additions & 3 deletions src/reinforcement_learning/reinforcement_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def main():
MAX_STEPS, \
STEP_LENGTH, \
REWARD_RANGE, \
COLLISION_RANGE = [param.value for param in params]
COLLISION_RANGE, \
ACTOR_PATH, \
CRITIC_PATH = [param.value for param in params]

if ACTOR_PATH != '' and CRITIC_PATH != '':
MAX_STEPS_EXPLORATION = 0

print(
f'---------------------------------------------\n'
Expand All @@ -60,6 +65,8 @@ def main():
f'Step Length: {STEP_LENGTH}\n'
f'Reward Range: {REWARD_RANGE}\n'
f'Collision Range: {COLLISION_RANGE}\n'
f'Critic Path: {CRITIC_PATH}\n'
f'Actor Path: {ACTOR_PATH}\n'
f'---------------------------------------------\n'
)

Expand All @@ -79,6 +86,12 @@ def main():
actor = Actor(observation_size=env.OBSERVATION_SIZE, num_actions=env.ACTION_NUM, learning_rate=ACTOR_LR)
critic = Critic(observation_size=env.OBSERVATION_SIZE, num_actions=env.ACTION_NUM, learning_rate=CRITIC_LR)

if ACTOR_PATH != '' and CRITIC_PATH != '':
print('Reading saved models into actor and critic')
actor.load_state_dict(torch.load(ACTOR_PATH))
critic.load_state_dict(torch.load(CRITIC_PATH))
print('Successfully Loaded models')

agent = TD3(
actor_network=actor,
critic_network=critic,
Expand Down Expand Up @@ -192,7 +205,9 @@ def get_params():
('max_steps', 100),
('step_length', 0.25),
('reward_range', 0.2),
('collision_range', 0.2)
('collision_range', 0.2),
('actor_path', ''),
('critic_path', '')
]
)

Expand All @@ -212,7 +227,9 @@ def get_params():
'max_steps',
'step_length',
'reward_range',
'collision_range'
'collision_range',
'actor_path',
'critic_path',
])


Expand Down