diff --git a/src/VIRAL.py b/src/VIRAL.py
index 7f4e517..ff6de1e 100644
--- a/src/VIRAL.py
+++ b/src/VIRAL.py
@@ -1,21 +1,26 @@
 import random
 import signal
 import sys
-import threading
 from logging import getLogger
 from typing import Callable, Dict, List
 
 import numpy as np
-from OllamaChat import OllamaChat
-from State import State
-
+from utils.OllamaChat import OllamaChat
+from utils.State import State
+from utils.Algo import Algo
+from utils.Environments import Environments
+from utils.CustomRewardWrapper import CustomRewardWrapper
+from utils.TrainingInfoCallback import TrainingInfoCallback
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_util import make_vec_env
+import gymnasium as gym
 
 class VIRAL:
     def __init__(
         self,
-        learning_method: Callable,
-        env,
+        learning_algo: Algo,
+        env_type: Environments,
         objectives_metrics: List[callable] = [],
         model: str = "qwen2.5-coder",
         options: dict = {},
@@ -40,32 +45,14 @@ def __init__(
             """,
             options=options,
         )
-        self.env = env
+        self.env_type: Environments = env_type
+        self.env = None
         self.objectives_metrics = objectives_metrics
-        self.learning_method = learning_method
+        self.learning_algo: Algo = learning_algo
+        self.learning_method = None
         self.memory: List[State] = [State(0)]
         self.logger = getLogger("VIRAL")
-        self.stops_threads = threading.Event()
-        self.lock = threading.Lock()
-        self.threads: list[threading.Thread] = []
-        self.threads.append(
-            threading.Thread(target=self._learning, args=[self.memory[0]])
-        )
-        self.threads[0].start()
-        signal.signal(signal.SIGTERM, self.sigterm_handler)
-        signal.signal(signal.SIGINT, self.sigterm_handler)
-
-    def sigterm_handler(self, signal, frame):
-        self.stops_threads.set()
-        for thread in self.threads:
-            if thread.is_alive():
-                thread.join()
-        if len(threading.enumerate()) > 1:
-            for thread in threading.enumerate():
-                self.logger.error(f"{thread.name},({thread.ident}) is alive")
-        else:
-            self.logger.debug("end of main thread")
-        sys.exit(0)
+        #self.training_callback = TrainingInfoCallback()
 
     def generate_reward_function(
         self, task_description: str, iterations: int = 1
@@ -93,11 +80,11 @@ def generate_reward_function(
             #"repeat_last_n": 64,  # how far back the model looks to avoid repeating its answers (64 by default, plenty for us)
-            "repeat_penalty": 1.5,  # penalty to discourage repeated answers (1.1 by default; 1.5 seems worth experimenting with)
+            #"repeat_penalty": 1.5,  # penalty to discourage repeated answers (1.1 by default; 1.5 seems worth experimenting with)
             #"stop": "stop you here"  # stops text generation early, not useful for us
-            "tfs_z": 1.2,  # reduce the impact of the least "relevant" tokens (1.0 by default to disable, 2.0 max)
+            #"tfs_z": 1.2,  # reduce the impact of the least "relevant" tokens (1.0 by default to disable, 2.0 max)
             #"top_k": 30,  # reduces the probability of generating nonsense (40 by default; 100 for more diverse answers, 10 for more "conservative" ones)
             #"top_p": 0.95,  # works with top_k; a high value gives more diverse text (0.9 by default)
@@ -110,12 +97,12 @@ def generate_reward_function(
             additional_options["seed"] = random.randint(0, 1000000)
         for i in [1, 2]:
             prompt = f"""
-            Complete the reward function for a {self.env.spec.name} environment.
Task Description: {task_description} Iteration {i+1}/{2} complete this sentence: def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: - \"\"\"Reward function for {self.env.spec.name} + \"\"\"Reward function for {self.env_type.value} Args: observations (np.ndarray): observation on the current state @@ -134,7 +121,9 @@ def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> f response = self.llm.print_Generator_and_return(response, i) reward_func, response = self._get_runnable_function(response) self.memory.append(State(i, reward_func, response)) - best_idx, worst_idx = self.evaluate_policy(1, 2) + + best_idx, worst_idx = self.evaluate_policy(1, 2) + worst_idx = 1 self.logger.debug(f"state to refine: {worst_idx}") ### SECOND STAGE ### for n in range(iterations - 1): @@ -156,11 +145,12 @@ def _get_runnable_function(self, response: str, error: str = None) -> Callable: response = self.llm.generate_response(stream=True) response = self.llm.print_Generator_and_return(response) try: + env = gym.make(self.env_type.value) response = self._get_code(response) reward_func = self._compile_reward_function(response) - state, _ = self.env.reset() - action = self.learning_method.output(state) - next_observation, _, terminated, truncated, _ = self.env.step(action) + state, _ = env.reset() + action = env.action_space.sample() + next_observation, _, terminated, truncated, _ = env.step(action) self._test_reward_function( reward_func, observations=next_observation, @@ -253,26 +243,20 @@ def self_refine_reward(self, idx: int) -> Callable: def _learning(self, state: State) -> None: """train a policy on an environment""" self.logger.debug(f"state {state.idx} begin is learning") + vec_env, model = self._generate_env_model(state.reward_func) - policy, perfs, sr, nb_ep = self.learning_method.train( - reward_func=state.reward_func, - save_name=f"model/{self.learning_method}_{self.env.spec.name}{state.idx}.model", - stop=self.stops_threads, - ) + training_callback = TrainingInfoCallback() + policy = model.learn(total_timesteps=60000, callback=training_callback) + metrics = training_callback.get_metrics() self.memory[state.idx].set_policy(policy) - observations, rewards, sr_test = self.test_policy(policy) + observations, rewards, sr_test = self.test_policy(vec_env, policy) + metrics["test_success_rate"] = sr_test + metrics["test_rewards"] = rewards perso_observations = [] for objective_metric in self.objectives_metrics: perso_observations.append(objective_metric(observations)) - self.memory[state.idx].set_performances( - { - "train_success_rate": sr, - "train_episodes": nb_ep, - "test_success_rate": sr_test, - "test_rewards": rewards, - "custom_metrics": perso_observations, - } - ) # TODO maybe add in to chat this state + self.memory[state.idx].set_performances(metrics) + print(f"state {state.idx} performances: {metrics}") self.logger.debug(f"state {state.idx} as finished is learning") def evaluate_policy(self, idx1: int, idx2: int) -> int: @@ -288,16 +272,9 @@ def evaluate_policy(self, idx1: int, idx2: int) -> int: """ if len(self.memory) < 2: self.logger.error("At least two reward functions are required.") - to_join: int = [] for i in [idx1, idx2]: if self.memory[i].performances is None: - self.threads.append( - threading.Thread(target=self._learning, args=[self.memory[i]]) - ) - self.threads[-1].start() - to_join.append(i) - for t in to_join: - self.threads[t].join() + self._learning(self.memory[i]) # TODO comparaison sur le success rate pour l'instant if ( 
self.memory[idx1].performances["test_success_rate"] @@ -309,10 +286,11 @@ def evaluate_policy(self, idx1: int, idx2: int) -> int: def test_policy( self, + env, policy, reward_func=None, nb_episodes=100, - max_t=1000, + max_t=1000 ) -> list: all_rewards = [] all_states = [] @@ -320,24 +298,21 @@ def test_policy( x_max = 0 x_min = 0 # avoid div by 0 for epi in range(1, nb_episodes + 1): - if self.stops_threads.is_set(): - break total_reward = 0 - state, _ = self.env.reset() + obs = env.reset() for i in range(1, max_t + 1): - action = policy.output(state) - next_observation, reward, terminated, truncated, _ = self.env.step( - action - ) + action = policy.predict(obs, deterministic=True)[0] + next_observation, reward, done, info = env.step(action) + #infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated + truncated = info[0]["TimeLimit.truncated"] if reward_func is not None: - reward = reward_func(next_observation, terminated, truncated) + reward = reward_func(next_observation, done, truncated) total_reward += reward state = next_observation all_states.append(state) - if terminated: - break - if truncated: - nb_success += 1 + if done: + if truncated: + nb_success += 1 break all_rewards.append(total_reward) if total_reward > x_max: @@ -348,3 +323,16 @@ def test_policy( x - x_min / x_max - x_min for x in all_rewards ] # Min-Max normalized return all_states, all_rewards, (nb_success / nb_episodes) + + def _generate_env_model(self, reward_func): + """ + Generate the environment model + """ + vec_env = make_vec_env(self.env_type.value, n_envs=1, wrapper_class=CustomRewardWrapper, wrapper_kwargs={'llm_reward_function': reward_func}) + if self.learning_algo == Algo.PPO: + model = PPO("MlpPolicy", vec_env, verbose=1) + else: + raise ValueError("The learning algorithm is not implemented.") + + return vec_env, model + \ No newline at end of file diff --git a/src/log/log.txt b/src/log/log.txt index 513fa7f..b5cc0e2 100644 --- a/src/log/log.txt +++ b/src/log/log.txt @@ -4934,3 +4934,733 @@ def reward_func(observations: np.ndarray, terminated: bool, truncated: bool) -> 12:12:12 VIRAL.py:174 WARNING Error syntax Syntax error in the generated code : unexpected indent (, line 4) +13:21:33 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:22:01 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:22:44 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:22:52 VIRAL.py:119 INFO + additional options: {'temperature': 1, 'repeat_penalty': 1.5, 'tfs_z': 1.2, 'seed': 613243} + +13:22:56 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : inconsistent use of tabs and spaces in indentation (, line 6) + +13:23:04 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : unterminated string literal (detected at line 1) (, line 1) + +13:23:12 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : unterminated string literal (detected at line 1) (, line 1) + +13:23:19 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:23:21 VIRAL.py:119 INFO + additional options: {'temperature': 1, 'repeat_penalty': 1.5, 'tfs_z': 1.2, 'seed': 643271} + +13:23:25 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : invalid syntax (, line 8) + +13:23:31 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : unindent does not match any outer indentation level (, line 11) + +13:23:37 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : unindent does not match any outer indentation level (, line 11) + +13:23:43 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : unindent does not match any outer indentation level (, line 11) + +13:23:50 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:23:52 VIRAL.py:119 INFO + additional options: {'temperature': 1, 'repeat_penalty': 1.5, 'tfs_z': 1.2, 'seed': 942914} + +13:23:53 VIRAL.py:162 WARNING + Error syntax Syntax error in the generated code : invalid syntax (, line 4) + +13:24:10 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:24:12 VIRAL.py:119 INFO + additional options: {'temperature': 1, 'seed': 76352} + +13:26:44 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:26:46 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 656100} + +13:27:42 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:27:44 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 339103} + +13:27:47 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 339103} + +13:27:48 main.py:49 INFO + state 0: +reward function: + +None + + isn't trained yet + +13:27:48 main.py:49 INFO + state 1: +reward function: + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + if terminated or truncated: + return -1.0 + else: + return 1.0 - observations[2]**2 + + isn't trained yet + +13:27:48 main.py:49 INFO + state 2: +reward function: + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + if terminated or truncated: + return -1.0 + else: + return 1.0 + observations[2]**2 + + isn't trained yet + +13:27:59 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:28:01 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 734281} + +13:28:06 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 734281} + +13:28:08 main.py:49 INFO + state 0: +reward function: + +None + + isn't trained yet + +13:28:08 main.py:49 INFO + state 1: +reward function: + +import numpy as np + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + pole_angle = observations[2] + + if terminated or truncated: + return -1.0 + else: + return 1.0 - abs(pole_angle) / np.pi * 2 + + isn't trained yet + +13:28:08 main.py:49 INFO + state 2: +reward function: + +import numpy as np + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + pole_angle = observations[2] + + if terminated or truncated: + return -1.0 + else: + return 1.0 - abs(pole_angle) / np.pi * 2 + + isn't trained yet + +13:28:32 OllamaChat.py:32 INFO + System: + You are an expert in 
Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:28:34 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 735279} + +13:28:38 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 735279} + +13:28:40 main.py:49 INFO + state 0: +reward function: + +None + + isn't trained yet + +13:28:40 main.py:49 INFO + state 1: +reward function: + +def reward_func(observations: np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + if terminated or truncated: + return -1.0 + else: + return 1.0 + + isn't trained yet + +13:28:40 main.py:49 INFO + state 2: +reward function: + +def reward_func(observations: np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + if terminated or truncated: + return -1.0 + else: + angle = observations[2] + return 1.0 - abs(angle / 0.418) + + isn't trained yet + +13:29:39 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:29:41 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 601536} + +13:29:45 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 601536} + +13:31:18 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:31:20 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 379124} + +13:31:25 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 379124} + +13:36:36 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:36:44 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 161778} + +13:36:49 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 161778} + +13:38:14 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:38:16 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 478769} + +13:38:21 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 478769} + +13:39:24 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:39:26 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 788195} + +13:39:30 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 788195} + +13:41:01 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:41:04 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 222331} + +13:41:08 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 222331} + +13:43:39 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:43:42 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 392125} + +13:43:46 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 392125} + +13:44:00 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:44:02 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 655337} + +13:44:06 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 655337} + +13:45:58 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:46:01 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 87728} + +13:46:06 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 87728} + +13:47:36 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:47:38 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 681457} + +13:47:41 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 681457} + +13:48:17 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:48:19 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 189860} + +13:48:23 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 189860} + +13:49:22 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:49:24 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 973783} + +13:49:29 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 973783} + +13:53:34 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:53:36 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 606951} + +13:53:42 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 606951} + +13:54:10 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:54:12 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 743526} + +13:54:17 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 743526} + +13:57:01 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:57:03 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 973923} + +13:57:08 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 973923} + +13:59:21 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +13:59:23 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 234892} + +13:59:29 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 234892} + +14:03:19 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +14:03:21 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 285825} + +14:03:26 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 285825} + +14:19:52 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +14:20:00 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 523047} + +14:20:05 VIRAL.py:120 INFO + additional options: {'temperature': 1, 'seed': 523047} + +14:21:07 main.py:49 INFO + state 0: +reward function: + +None + + isn't trained yet + +14:21:07 main.py:49 INFO + state 1: +reward function: + +import numpy as np + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + angle = observations[2] + if terminated or truncated: + return -1.0 # Failure or success, negative reward + else: + return np.cos(angle) # Positive reward based on pole's vertical alignment + + Performances: + +{'train_success_rate': 100.0, 'train_episodes': 266, 'custom_metrics': {}, 'test_success_rate': 0.0, 'test_rewards': [array([6.9661956], dtype=float32), array([7.9666786], dtype=float32), array([7.9743767], dtype=float32), array([6.9526296], dtype=float32), array([7.9711695], dtype=float32), array([6.9662633], dtype=float32), array([6.9566326], dtype=float32), array([6.9468813], dtype=float32), array([7.968299], dtype=float32), array([7.9710064], dtype=float32), array([7.9740477], dtype=float32), array([7.9554615], dtype=float32), array([7.9686165], dtype=float32), array([6.9505167], dtype=float32), array([7.975523], dtype=float32), array([7.974062], dtype=float32), array([7.974058], dtype=float32), array([7.975336], dtype=float32), array([6.9584856], dtype=float32), array([7.9685783], dtype=float32), array([6.9728503], dtype=float32), array([6.9664545], dtype=float32), array([7.973238], dtype=float32), array([7.9752445], dtype=float32), array([6.97116], dtype=float32), array([8.958213], dtype=float32), array([6.9739437], dtype=float32), array([7.9682894], dtype=float32), array([7.970299], dtype=float32), array([8.957311], dtype=float32), array([7.962881], dtype=float32), array([6.966643], dtype=float32), array([6.9740434], dtype=float32), array([7.970749], dtype=float32), array([7.9729576], dtype=float32), array([7.963971], dtype=float32), array([7.9719152], dtype=float32), array([7.973875], dtype=float32), array([7.960619], dtype=float32), array([7.95782], dtype=float32), array([7.97324], dtype=float32), array([7.969879], dtype=float32), array([6.9488845], dtype=float32), array([6.9505363], dtype=float32), array([6.9723186], dtype=float32), array([6.9681463], dtype=float32), array([6.951812], dtype=float32), array([6.973339], dtype=float32), array([7.969619], dtype=float32), array([6.969921], dtype=float32), array([5.965583], dtype=float32), array([7.9729977], dtype=float32), array([7.967022], dtype=float32), array([7.9752455], dtype=float32), array([6.9745774], dtype=float32), array([6.946839], dtype=float32), array([7.9725657], dtype=float32), array([6.968704], dtype=float32), array([7.971817], dtype=float32), array([7.955346], dtype=float32), array([7.9672003], dtype=float32), array([6.9571753], dtype=float32), array([7.9757013], dtype=float32), array([7.9652205], dtype=float32), 
array([6.959957], dtype=float32), array([6.958987], dtype=float32), array([6.973798], dtype=float32), array([7.965373], dtype=float32), array([6.9544606], dtype=float32), array([6.9696803], dtype=float32), array([6.9539676], dtype=float32), array([7.9676437], dtype=float32), array([7.961423], dtype=float32), array([7.9618893], dtype=float32), array([7.9730186], dtype=float32), array([6.969637], dtype=float32), array([7.9639864], dtype=float32), array([6.9628677], dtype=float32), array([7.961005], dtype=float32), array([6.9554634], dtype=float32), array([7.969965], dtype=float32), array([7.966872], dtype=float32), array([7.9683123], dtype=float32), array([7.9744987], dtype=float32), array([7.974103], dtype=float32), array([6.960107], dtype=float32), array([7.960619], dtype=float32), array([6.9713655], dtype=float32), array([7.97474], dtype=float32), array([7.973501], dtype=float32), array([7.963004], dtype=float32), array([6.9714313], dtype=float32), array([7.974412], dtype=float32), array([6.94937], dtype=float32), array([6.966172], dtype=float32), array([6.9685435], dtype=float32), array([7.9758053], dtype=float32), array([6.9507504], dtype=float32), array([7.9687405], dtype=float32), array([6.9664655], dtype=float32)]} + + Policy: + +14:21:07 main.py:49 INFO + state 2: +reward function: + +import numpy as np + +def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float: + """Reward function for CartPole-v1 + + Args: + observations (np.ndarray): observation on the current state + terminated (bool): episode is terminated due a failure + truncated (bool): episode is truncated due a success + + Returns: + float: The reward for the current step + """ + angle = observations[2] + if terminated or truncated: + return -1.0 # Failure or success, negative reward + else: + return np.cos(angle) + 0.5 * (angle**2 < 0.2)**2 # Positive reward based on pole's vertical alignment and stability + + Performances: + +{'train_success_rate': 100.0, 'train_episodes': 251, 'custom_metrics': {}, 'test_success_rate': 0.0, 'test_rewards': [array([10.95259], dtype=float32), array([13.957707], dtype=float32), array([10.969561], dtype=float32), array([12.471957], dtype=float32), array([12.472703], dtype=float32), array([12.467831], dtype=float32), array([10.973705], dtype=float32), array([12.460483], dtype=float32), array([12.457328], dtype=float32), array([12.461605], dtype=float32), array([10.963558], dtype=float32), array([12.475749], dtype=float32), array([9.464667], dtype=float32), array([12.466751], dtype=float32), array([12.455445], dtype=float32), array([10.945679], dtype=float32), array([12.470938], dtype=float32), array([12.471609], dtype=float32), array([12.470207], dtype=float32), array([10.972839], dtype=float32), array([12.467936], dtype=float32), array([10.968296], dtype=float32), array([10.961361], dtype=float32), array([12.467699], dtype=float32), array([12.462386], dtype=float32), array([12.470747], dtype=float32), array([12.452126], dtype=float32), array([12.469376], dtype=float32), array([12.464769], dtype=float32), array([13.958049], dtype=float32), array([12.472413], dtype=float32), array([10.967122], dtype=float32), array([10.956684], dtype=float32), array([12.460769], dtype=float32), array([12.473208], dtype=float32), array([10.974245], dtype=float32), array([12.466009], dtype=float32), array([12.462774], dtype=float32), array([10.968747], dtype=float32), array([12.463215], dtype=float32), array([10.948158], dtype=float32), array([10.965016], dtype=float32), 
array([13.956003], dtype=float32), array([12.453936], dtype=float32), array([12.471634], dtype=float32), array([12.467441], dtype=float32), array([12.47491], dtype=float32), array([12.465915], dtype=float32), array([10.955706], dtype=float32), array([12.472738], dtype=float32), array([10.9681015], dtype=float32), array([12.466104], dtype=float32), array([10.968498], dtype=float32), array([12.463783], dtype=float32), array([12.461851], dtype=float32), array([10.961075], dtype=float32), array([12.470657], dtype=float32), array([12.466281], dtype=float32), array([12.461803], dtype=float32), array([12.459101], dtype=float32), array([10.9648285], dtype=float32), array([12.475041], dtype=float32), array([12.465155], dtype=float32), array([12.455122], dtype=float32), array([12.473621], dtype=float32), array([10.954266], dtype=float32), array([12.472059], dtype=float32), array([12.467389], dtype=float32), array([12.461479], dtype=float32), array([12.465909], dtype=float32), array([12.473216], dtype=float32), array([10.971273], dtype=float32), array([12.455473], dtype=float32), array([12.473352], dtype=float32), array([12.472754], dtype=float32), array([12.475703], dtype=float32), array([12.471912], dtype=float32), array([12.46884], dtype=float32), array([12.464479], dtype=float32), array([10.964634], dtype=float32), array([12.456501], dtype=float32), array([12.468685], dtype=float32), array([10.966262], dtype=float32), array([12.475623], dtype=float32), array([13.954869], dtype=float32), array([10.971606], dtype=float32), array([12.47558], dtype=float32), array([12.469751], dtype=float32), array([12.475481], dtype=float32), array([12.467912], dtype=float32), array([10.948302], dtype=float32), array([12.471239], dtype=float32), array([12.461482], dtype=float32), array([13.954246], dtype=float32), array([12.453873], dtype=float32), array([10.966632], dtype=float32), array([12.463488], dtype=float32), array([12.4592085], dtype=float32), array([12.473618], dtype=float32), array([10.966131], dtype=float32)]} + + Policy: + +14:31:47 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. 
+ Strict criteria:
+ - Complete ONLY the reward function code
+ - Use Python format
+ - Give no additional explanations
+ - Focus on the Gymnasium environment
+ - Take into the observation of the state, the terminated and truncated boolean
+ - STOP immediately your completion after the last return
+
+
+14:31:54 VIRAL.py:120 INFO
+ additional options: {'temperature': 1, 'seed': 614476}
+
+14:32:00 VIRAL.py:120 INFO
+ additional options: {'temperature': 1, 'seed': 614476}
+
diff --git a/src/main.py b/src/main.py
index 8eab88c..ffdaeb5 100644
--- a/src/main.py
+++ b/src/main.py
@@ -4,14 +4,16 @@
 import gymnasium as gym
 
 from log.log_config import init_logger
-from ObjectivesMetrics import objective_metric_CartPole
+from utils.ObjectivesMetrics import objective_metric_CartPole
 from RLAlgo.DirectSearch import DirectSearch
 from RLAlgo.Reinforce import Reinforce
 from VIRAL import VIRAL
-from CustomRewardWrapper import CustomRewardWrapper
+from utils.CustomRewardWrapper import CustomRewardWrapper
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
+from utils.Environments import Environments
+from utils.Algo import Algo
 
 
 def parse_logger():
     parser = argparse.ArgumentParser()
@@ -29,13 +31,9 @@ def parse_logger():
 
 
 if __name__ == "__main__":
-    parse_logger()
-    vec_env = make_vec_env("CartPole-v1", n_envs=4)
-    learning_method = PPO("MlpPolicy", vec_env, verbose=1)
-    learning_method.learn(total_timesteps=25000)
+    logger = parse_logger()
 
-    objectives_metrics = [objective_metric_CartPole]
-    viral = VIRAL("PPO", "CartPole-v1", objectives_metrics)
+    viral = VIRAL(Algo.PPO, Environments.CARTPOLE)
     res = viral.generate_reward_function(
         task_description="""Balance a pole on a cart,
         Num Observation Min Max
diff --git a/src/test.py b/src/test.py
index 39fe3d8..5f5fb05 100644
--- a/src/test.py
+++ b/src/test.py
@@ -2,7 +2,7 @@
 import gymnasium as gym
 import numpy as np
 from stable_baselines3.common.callbacks import BaseCallback
-from CustomRewardWrapper import CustomRewardWrapper
+from utils.CustomRewardWrapper import CustomRewardWrapper
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
 class TrainingInfoCallback(BaseCallback):
@@ -111,24 +111,4 @@ def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> f
 print(f"Nombre total d'épisodes : {metrics['total_episodes']}")
 print(f"Épisodes terminés (terminated) : {metrics['terminated_count']}")
 print(f"Épisodes tronqués (truncated) : {metrics['truncated_count']}")
-# Exemple de visualisation des métriques
-import matplotlib.pyplot as plt
-
-plt.figure(figsize=(12, 5))
-
-# Graphique des récompenses par épisode
-plt.subplot(1, 2, 1)
-plt.plot(metrics['episode_rewards'])
-plt.title('Récompenses par Épisode')
-plt.xlabel("Numéro d'Épisode")
-plt.ylabel('Récompense')
-
-# Graphique de la moyenne mobile des récompenses
-plt.subplot(1, 2, 2)
-plt.plot(metrics['mean_rewards_10_episodes'])
-plt.title('Moyenne Mobile des Récompenses (10 derniers épisodes)')
-plt.xlabel("Groupe d'Épisodes")
-plt.ylabel('Récompense Moyenne')
-
-plt.tight_layout()
-plt.show()
\ No newline at end of file
+# Exemple de visualisation des métriques
\ No newline at end of file
diff --git a/src/utils/Algo.py b/src/utils/Algo.py
new file mode 100644
index 0000000..a6b3c4a
--- /dev/null
+++ b/src/utils/Algo.py
@@ -0,0 +1,6 @@
+from enum import Enum
+
+class Algo(Enum):
+    PPO = "PPO"
+    REINFORCE = "REINFORCE"
+
\ No newline at end of file
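The Algo and Environments enums above are what VIRAL._generate_env_model dispatches on. Because the VIRAL.py hunks earlier in this diff are dense, here is a condensed, illustrative sketch of how the new training path fits together. The helper name train_with_llm_reward is ours and not part of the diff; the wiring, the 60000-timestep budget, and the PPO-only dispatch mirror _generate_env_model and _learning as they appear above.

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from utils.Algo import Algo
from utils.CustomRewardWrapper import CustomRewardWrapper
from utils.TrainingInfoCallback import TrainingInfoCallback


def train_with_llm_reward(env_type, learning_algo, reward_func):
    # Illustrative helper (not in the diff): wrap the env so the LLM-generated
    # function replaces the native reward, then train PPO and collect metrics.
    vec_env = make_vec_env(
        env_type.value,
        n_envs=1,
        wrapper_class=CustomRewardWrapper,
        wrapper_kwargs={"llm_reward_function": reward_func},
    )
    if learning_algo == Algo.PPO:
        model = PPO("MlpPolicy", vec_env, verbose=1)
    else:
        raise ValueError("The learning algorithm is not implemented.")
    callback = TrainingInfoCallback()
    model.learn(total_timesteps=60000, callback=callback)
    return vec_env, model, callback.get_metrics()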
diff --git a/src/CustomRewardWrapper.py b/src/utils/CustomRewardWrapper.py
similarity index 100%
rename from src/CustomRewardWrapper.py
rename to src/utils/CustomRewardWrapper.py
diff --git a/src/utils/Environments.py b/src/utils/Environments.py
new file mode 100644
index 0000000..f435c6c
--- /dev/null
+++ b/src/utils/Environments.py
@@ -0,0 +1,5 @@
+from enum import Enum
+
+class Environments(Enum):
+    CARTPOLE = "CartPole-v1"
+    LUNAR_LANDER = "LunarLander-v3"
\ No newline at end of file
diff --git a/src/ObjectivesMetrics.py b/src/utils/ObjectivesMetrics.py
similarity index 100%
rename from src/ObjectivesMetrics.py
rename to src/utils/ObjectivesMetrics.py
diff --git a/src/OllamaChat.py b/src/utils/OllamaChat.py
similarity index 100%
rename from src/OllamaChat.py
rename to src/utils/OllamaChat.py
diff --git a/src/State.py b/src/utils/State.py
similarity index 100%
rename from src/State.py
rename to src/utils/State.py
diff --git a/src/utils/TrainingInfoCallback.py b/src/utils/TrainingInfoCallback.py
new file mode 100644
index 0000000..3b462d1
--- /dev/null
+++ b/src/utils/TrainingInfoCallback.py
@@ -0,0 +1,81 @@
+import numpy as np
+from stable_baselines3.common.callbacks import BaseCallback
+
+class TrainingInfoCallback(BaseCallback):
+    def __init__(self, verbose=0):
+        super().__init__(verbose)
+        self.training_metrics = {
+            'timesteps': [],
+            'episode_rewards': [],
+            'episode_lengths': [],
+            'mean_rewards': [],
+            'mean_lengths': [],
+            'terminated_count': 0,
+            'truncated_count': 0,
+            'total_episodes': 0
+        }
+
+        self.current_episode_reward = 0
+        self.current_episode_length = 0
+        self.custom_metrics = {}
+
+    def _on_step(self) -> bool:
+        for reward, done, truncated in zip(
+            self.locals['rewards'],
+            self.locals['dones'],
+            self.locals.get('truncateds', [False] * len(self.locals['rewards']))
+        ):
+            self.current_episode_reward += reward
+            self.current_episode_length += 1
+
+            # Check whether the episode is terminated or truncated
+            if done or truncated:
+                self.training_metrics['total_episodes'] += 1
+
+                if done:
+                    self.training_metrics['terminated_count'] += 1
+
+                if truncated:
+                    self.training_metrics['truncated_count'] += 1
+
+                # Store the episode metrics
+                self.training_metrics['episode_rewards'].append(self.current_episode_reward)
+                self.training_metrics['episode_lengths'].append(self.current_episode_length)
+
+                # Compute and store the moving averages
+                if len(self.training_metrics['episode_rewards']) > 10:
+                    mean_reward = np.mean(self.training_metrics['episode_rewards'][-10:])
+                    mean_length = np.mean(self.training_metrics['episode_lengths'][-10:])
+
+                    self.training_metrics['mean_rewards'].append(mean_reward)
+                    self.training_metrics['mean_lengths'].append(mean_length)
+
+                # Reset for the next episode
+                self.current_episode_reward = 0
+                self.current_episode_length = 0
+
+        # Store the number of timesteps
+        self.training_metrics['timesteps'].append(self.num_timesteps)
+
+        return True
+
+    def _on_training_end(self) -> None:
+        # Compute the training success rate
+        try:
+            train_success_rate = (self.training_metrics['terminated_count'] /
+                                  self.training_metrics['total_episodes']) * 100
+        except ZeroDivisionError:
+            train_success_rate = 0
+
+        # Prepare the final metrics
+        self.results = {
+            "train_success_rate": train_success_rate,
+            "train_episodes": self.training_metrics['total_episodes'],
+            "custom_metrics": self.custom_metrics
+        }
+
+    def get_metrics(self):
+        """
+        Return the collected metrics in the expected format
+        """
+        return self.results if hasattr(self, 'results') else {}
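src/CustomRewardWrapper.py is only moved to src/utils/ in this change set, so its body does not appear in the diff. For context, a minimal sketch consistent with how _generate_env_model invokes it (wrapper_class=CustomRewardWrapper, wrapper_kwargs={'llm_reward_function': ...}) might look like the following; the actual implementation in the repository may differ.

import gymnasium as gym


class CustomRewardWrapper(gym.Wrapper):
    # Hypothetical sketch: substitute the LLM-generated reward for the env's native one.
    def __init__(self, env, llm_reward_function=None):
        super().__init__(env)
        self.llm_reward_function = llm_reward_function

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.llm_reward_function is not None:
            reward = self.llm_reward_function(obs, terminated, truncated)
        return obs, reward, terminated, truncated, info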
diff --git a/src/utils.py b/src/utils/utils.py
similarity index 100%
rename from src/utils.py
rename to src/utils/utils.py
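With the refactor, the entry point in src/main.py reduces to building VIRAL from the two enums and asking it for reward functions. A minimal usage sketch, with the CartPole task description abridged (the full multi-line text lives in main.py):

from utils.Algo import Algo
from utils.Environments import Environments
from VIRAL import VIRAL

viral = VIRAL(Algo.PPO, Environments.CARTPOLE)
res = viral.generate_reward_function(
    task_description="Balance a pole on a cart, ...",  # abridged
    iterations=1,
)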