From e2ccbe8228dae27c199840ff1cfbe0e38f4fd4c0 Mon Sep 17 00:00:00 2001 From: Emilien Date: Sun, 15 Dec 2024 23:17:39 +0100 Subject: [PATCH] Move the seed in the init --- src/VIRAL.py | 7 +++-- src/log/log.txt | 66 +++++++++++++++++++++++++++++++++++++++++ src/main.py | 22 +++++++------- src/utils/OllamaChat.py | 2 +- 4 files changed, 82 insertions(+), 15 deletions(-) diff --git a/src/VIRAL.py b/src/VIRAL.py index 70a0537..a296cb8 100644 --- a/src/VIRAL.py +++ b/src/VIRAL.py @@ -31,6 +31,9 @@ def __init__( model (str): Language model for reward generation learning_method (str): Reinforcement learning method """ + if (options.get("seed") is None): + options["seed"] = random.randint(0, 1000000) + self.llm = OllamaChat( model=model, system_prompt=""" @@ -52,7 +55,7 @@ def __init__( self.learning_method = None self.memory: List[State] = [State(0)] self.logger = getLogger("VIRAL") - self._learning(self.memory[0]) + #self._learning(self.memory[0]) #self.training_callback = TrainingInfoCallback() def generate_reward_function( @@ -94,8 +97,6 @@ def generate_reward_function( #"seed": 42, # a utiliser pour la reproductibilité des résultats (important si publication) } ### INIT STAGE ### - if (additional_options.get("seed") is None): - additional_options["seed"] = random.randint(0, 1000000) for i in [1, 2]: prompt = f""" Complete the reward function for a {self.env_type.value} environment. diff --git a/src/log/log.txt b/src/log/log.txt index cad650b..80afee9 100644 --- a/src/log/log.txt +++ b/src/log/log.txt @@ -5901,3 +5901,69 @@ None - STOP immediately your completion after the last return +23:11:17 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +23:12:06 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +23:12:14 VIRAL.py:121 INFO + additional options: {'temperature': 1, 'seed': 815094} + +23:12:18 VIRAL.py:121 INFO + additional options: {'temperature': 1, 'seed': 815094} + +23:16:37 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + + +23:16:40 VIRAL.py:122 INFO + additional options: {'temperature': 1} + +23:17:09 OllamaChat.py:32 INFO + System: + You are an expert in Reinforcement Learning specialized in designing reward functions. + Strict criteria: + - Complete ONLY the reward function code + - Use Python format + - Give no additional explanations + - Focus on the Gymnasium environment + - Take into the observation of the state, the terminated and truncated boolean + - STOP immediately your completion after the last return + , Options: {'seed': 635239} + +23:17:11 VIRAL.py:122 INFO + additional options: {'temperature': 1} + +23:17:14 VIRAL.py:166 WARNING + Error syntax Syntax error in the generated code : 'return' outside function (, line 20) + +23:17:20 VIRAL.py:122 INFO + additional options: {'temperature': 1} + diff --git a/src/main.py b/src/main.py index 13fda39..220e3ef 100644 --- a/src/main.py +++ b/src/main.py @@ -33,16 +33,16 @@ def parse_logger(): logger = parse_logger() viral = VIRAL(Algo.PPO, Environments.CARTPOLE) - # res = viral.generate_reward_function( - # task_description="""Balance a pole on a cart, - # Num Observation Min Max - # 0 Cart Position -4.8 4.8 - # 1 Cart Velocity -Inf Inf - # 2 Pole Angle ~ -0.418 rad (-24°) ~ 0.418 rad (24°) - # 3 Pole Angular Velocity -Inf Inf - # Since the goal is to keep the pole upright for as long as possible. - # """, - # iterations=1, - # ) + res = viral.generate_reward_function( + task_description="""Balance a pole on a cart, + Num Observation Min Max + 0 Cart Position -4.8 4.8 + 1 Cart Velocity -Inf Inf + 2 Pole Angle ~ -0.418 rad (-24°) ~ 0.418 rad (24°) + 3 Pole Angular Velocity -Inf Inf + Since the goal is to keep the pole upright for as long as possible. + """, + iterations=1, + ) for state in viral.memory: logger.info(state) diff --git a/src/utils/OllamaChat.py b/src/utils/OllamaChat.py index 6a5f49c..9c13072 100644 --- a/src/utils/OllamaChat.py +++ b/src/utils/OllamaChat.py @@ -29,7 +29,7 @@ def __init__( self.logger = getLogger("VIRAL") if system_prompt: - self.logger.info(f"System: {system_prompt}") + self.logger.info(f"System: {system_prompt}, Options: {self.options}") self.add_message(system_prompt, role="system") def add_message(self, content: str, role: str = "user", **kwargs) -> None: