diff --git a/src/reinforcement_learning/config/test.yaml b/src/reinforcement_learning/config/test.yaml
index 9597434a..06dc60e2 100644
--- a/src/reinforcement_learning/config/test.yaml
+++ b/src/reinforcement_learning/config/test.yaml
@@ -1,12 +1,12 @@
 test:
   ros__parameters:
-    environment: 'CarBeat'
+    environment: 'CarTrack'
     track: 'multi_track' # track_1, track_2, track_3, multi_track, multi_track_testing -> only applies for CarTrack
-    max_steps_evaluation: 1000000
+    number_eval_episodes: 100
     max_steps: 3000
-    actor_path: results/23_09_08_06:28:51/models/actor_checkpoint.pht
-    critic_path: results/23_09_08_06:28:51/models/critic_checkpoint.pht
-    algorithm: 'TD3'
+    actor_path: training_logs/SAC-CarTrack-23_10_25_10:41:10/models/SAC-checkpoint-15000_actor.pht
+    critic_path: training_logs/SAC-CarTrack-23_10_25_10:41:10/models/SAC-checkpoint-15000_critic.pht
+    algorithm: 'SAC'
     step_length: 0.1
     reward_range: 3.0
     collision_range: 0.2
diff --git a/src/reinforcement_learning/launch/test.launch.py b/src/reinforcement_learning/launch/test.launch.py
index 2c605683..64c1cc66 100644
--- a/src/reinforcement_learning/launch/test.launch.py
+++ b/src/reinforcement_learning/launch/test.launch.py
@@ -32,6 +32,9 @@ def generate_launch_description():
             os.path.join(pkg_environments, f'{env_launch[env]}.launch.py')),
         launch_arguments={
             'track': TextSubstitution(text=str(config['test']['ros__parameters']['track'])),
+            'car_name': TextSubstitution(text=str(config['test']['ros__parameters']['car_name']) if 'car_name' in config['test']['ros__parameters'] else 'f1tenth'),
+            'car_one': TextSubstitution(text=str(config['test']['ros__parameters']['car_name']) if 'car_name' in config['test']['ros__parameters'] else 'f1tenth'),
+            'car_two': TextSubstitution(text=str(config['test']['ros__parameters']['ftg_car_name']) if 'ftg_car_name' in config['test']['ros__parameters'] else 'ftg_car'),
         }.items() #TODO: this doesn't do anything
     )
 
diff --git a/src/reinforcement_learning/reinforcement_learning/parse_args.py b/src/reinforcement_learning/reinforcement_learning/parse_args.py
index e16ca118..893b52ce 100644
--- a/src/reinforcement_learning/reinforcement_learning/parse_args.py
+++ b/src/reinforcement_learning/reinforcement_learning/parse_args.py
@@ -7,11 +7,13 @@ def parse_args():
 
     param_node = __declare_params()
 
-    env_params = __get_env_params(param_node)
-    algorithm_params = __get_algorithm_params(param_node)
-    network_params = __get_network_params(param_node)
+    env_params, rest_env = __get_env_params(param_node)
+    algorithm_params, rest_alg = __get_algorithm_params(param_node)
+    network_params, rest_params = __get_network_params(param_node)
 
-    return env_params, algorithm_params, network_params
+    rest = {**rest_env, **rest_alg, **rest_params}
+
+    return env_params, algorithm_params, network_params, rest
 
 
 def __declare_params():
@@ -92,7 +94,12 @@
         case _:
             raise Exception(f'Environment {params_dict["environment"]} not implemented')
 
-    return config
+    # Collect all the parameters that were not used into a python dictionary
+    rest = set(params_dict.keys()).difference(set(config.dict().keys()))
+    rest = {key: params_dict[key] for key in rest}
+
+    param_node.get_logger().info(f'Rest: {rest}')
+    return config, rest
 
 def __get_algorithm_params(param_node: Node):
     params = param_node.get_parameters([
@@ -114,7 +121,11 @@
 
     config = cfg.TrainingConfig(**params_dict)
 
-    return config
+    rest = set(params_dict.keys()).difference(set(config.dict().keys()))
+    rest = {key: params_dict[key] for key in rest}
+
+    param_node.get_logger().info(f'Rest: {rest}')
+    return config, rest
 
 def __get_network_params(param_node: Node):
     params = param_node.get_parameters([
@@ -144,5 +155,9 @@
         case _:
             raise Exception(f'Algorithm {params_dict["algorithm"]} not implemented')
 
-    return config
+    rest = set(params_dict.keys()).difference(set(config.dict().keys()))
+    param_node.get_logger().info(f'Rest: {rest}')
+    rest = {key: params_dict[key] for key in rest}
+
+    return config, rest
 
diff --git a/src/reinforcement_learning/reinforcement_learning/test.py b/src/reinforcement_learning/reinforcement_learning/test.py
index 03377fef..2b885545 100644
--- a/src/reinforcement_learning/reinforcement_learning/test.py
+++ b/src/reinforcement_learning/reinforcement_learning/test.py
@@ -16,7 +16,7 @@ def main():
 
     rclpy.init()
 
-    env_config, algorithm_config, network_config = parse_args()
+    env_config, algorithm_config, network_config, rest = parse_args()
 
     print(
         f'Environment Config: ------------------------------------- \n'
@@ -30,26 +30,14 @@
     env_factory = EnvironmentFactory()
     network_factory = NetworkFactory()
 
-    match network_config['algorithm']:
-        case 'PPO':
-            config = cfg.PPOConfig(**network_config)
-        case 'DDPG':
-            config = cfg.DDPGConfig(**network_config)
-        case 'SAC':
-            config = cfg.SACConfig(**network_config)
-        case 'TD3':
-            config = cfg.TD3Config(**network_config)
-        case _:
-            raise Exception(f'Algorithm {network_config["algorithm"]} not implemented')
-
     env = env_factory.create(env_config['environment'], env_config)
-    agent = network_factory.create_network(env.OBSERVATION_SIZE, env.ACTION_NUM, config=config)
+    agent = network_factory.create_network(env.OBSERVATION_SIZE, env.ACTION_NUM, config=network_config)
 
     # Load models if both paths are provided
-    if network_config['actor_path'] and network_config['critic_path']:
+    if rest['actor_path'] and rest['critic_path']:
         print('Reading saved models into actor and critic')
-        agent.actor_net.load_state_dict(torch.load(network_config['actor_path']))
-        agent.critic_net.load_state_dict(torch.load(network_config['critic_path']))
+        agent.actor_net.load_state_dict(torch.load(rest['actor_path']))
+        agent.critic_net.load_state_dict(torch.load(rest['critic_path']))
         print('Successfully Loaded models')
     else:
         raise Exception('Both actor and critic paths must be provided')
@@ -58,7 +46,7 @@
         case 'PPO':
             ppo_evaluate(env, agent, algorithm_config)
         case _:
-            off_policy_evaluate(env, agent, algorithm_config)
+            off_policy_evaluate(env, agent, algorithm_config['number_eval_episodes'])
 
 if __name__ == '__main__':
     main()
diff --git a/src/reinforcement_learning/reinforcement_learning/train.py b/src/reinforcement_learning/reinforcement_learning/train.py
index 150febc5..008c0f77 100644
--- a/src/reinforcement_learning/reinforcement_learning/train.py
+++ b/src/reinforcement_learning/reinforcement_learning/train.py
@@ -19,7 +19,7 @@ def main():
 
     rclpy.init()
 
-    env_config, algorithm_config, network_config = parse_args()
+    env_config, algorithm_config, network_config, _ = parse_args()
 
     # Set Seeds
     torch.manual_seed(algorithm_config['seed'])
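The three __get_*_params helpers above now share one leftover-key pattern: parameters declared on the ROS node but not consumed by the typed config object are returned as a "rest" dict, which test.py reads for actor_path/critic_path. A minimal standalone sketch of that pattern follows; it assumes a pydantic-style config exposing .dict() as in the patch, and split_rest is a hypothetical name used only for illustration:

    def split_rest(params_dict: dict, config) -> tuple:
        # Hypothetical sketch, not from the patch: keys present in the declared
        # parameter dict but absent from the typed config's fields are collected
        # into 'rest' (e.g. actor_path, critic_path in test.yaml above).
        rest_keys = set(params_dict.keys()).difference(set(config.dict().keys()))
        rest = {key: params_dict[key] for key in rest_keys}
        return config, rest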