Skip to content

Commit

Permalink
Move the default seed generation into __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
ekomlenovic committed Dec 15, 2024
1 parent a3e0ce5 commit e2ccbe8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
7 changes: 4 additions & 3 deletions src/VIRAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(
model (str): Language model for reward generation
learning_method (str): Reinforcement learning method
"""
if (options.get("seed") is None):
options["seed"] = random.randint(0, 1000000)

self.llm = OllamaChat(
model=model,
system_prompt="""
Expand All @@ -52,7 +55,7 @@ def __init__(
self.learning_method = None
self.memory: List[State] = [State(0)]
self.logger = getLogger("VIRAL")
self._learning(self.memory[0])
#self._learning(self.memory[0])
#self.training_callback = TrainingInfoCallback()

def generate_reward_function(
Expand Down Expand Up @@ -94,8 +97,6 @@ def generate_reward_function(
#"seed": 42, # use a fixed seed for reproducible results (important if publishing)
}
### INIT STAGE ###
if (additional_options.get("seed") is None):
additional_options["seed"] = random.randint(0, 1000000)
for i in [1, 2]:
prompt = f"""
Complete the reward function for a {self.env_type.value} environment.
Expand Down
66 changes: 66 additions & 0 deletions src/log/log.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5901,3 +5901,69 @@ None
- STOP immediately your completion after the last return


23:11:17 OllamaChat.py:32 INFO
System:
You are an expert in Reinforcement Learning specialized in designing reward functions.
Strict criteria:
- Complete ONLY the reward function code
- Use Python format
- Give no additional explanations
- Focus on the Gymnasium environment
- Take into the observation of the state, the terminated and truncated boolean
- STOP immediately your completion after the last return


23:12:06 OllamaChat.py:32 INFO
System:
You are an expert in Reinforcement Learning specialized in designing reward functions.
Strict criteria:
- Complete ONLY the reward function code
- Use Python format
- Give no additional explanations
- Focus on the Gymnasium environment
- Take into the observation of the state, the terminated and truncated boolean
- STOP immediately your completion after the last return


23:12:14 VIRAL.py:121 INFO
additional options: {'temperature': 1, 'seed': 815094}

23:12:18 VIRAL.py:121 INFO
additional options: {'temperature': 1, 'seed': 815094}

23:16:37 OllamaChat.py:32 INFO
System:
You are an expert in Reinforcement Learning specialized in designing reward functions.
Strict criteria:
- Complete ONLY the reward function code
- Use Python format
- Give no additional explanations
- Focus on the Gymnasium environment
- Take into the observation of the state, the terminated and truncated boolean
- STOP immediately your completion after the last return


23:16:40 VIRAL.py:122 INFO
additional options: {'temperature': 1}

23:17:09 OllamaChat.py:32 INFO
System:
You are an expert in Reinforcement Learning specialized in designing reward functions.
Strict criteria:
- Complete ONLY the reward function code
- Use Python format
- Give no additional explanations
- Focus on the Gymnasium environment
- Take into the observation of the state, the terminated and truncated boolean
- STOP immediately your completion after the last return
, Options: {'seed': 635239}

23:17:11 VIRAL.py:122 INFO
additional options: {'temperature': 1}

23:17:14 VIRAL.py:166 WARNING
Error syntax Syntax error in the generated code : 'return' outside function (<string>, line 20)

23:17:20 VIRAL.py:122 INFO
additional options: {'temperature': 1}

22 changes: 11 additions & 11 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ def parse_logger():
logger = parse_logger()

viral = VIRAL(Algo.PPO, Environments.CARTPOLE)
# res = viral.generate_reward_function(
# task_description="""Balance a pole on a cart,
# Num Observation Min Max
# 0 Cart Position -4.8 4.8
# 1 Cart Velocity -Inf Inf
# 2 Pole Angle ~ -0.418 rad (-24°) ~ 0.418 rad (24°)
# 3 Pole Angular Velocity -Inf Inf
# Since the goal is to keep the pole upright for as long as possible.
# """,
# iterations=1,
# )
res = viral.generate_reward_function(
task_description="""Balance a pole on a cart,
Num Observation Min Max
0 Cart Position -4.8 4.8
1 Cart Velocity -Inf Inf
2 Pole Angle ~ -0.418 rad (-24°) ~ 0.418 rad (24°)
3 Pole Angular Velocity -Inf Inf
Since the goal is to keep the pole upright for as long as possible.
""",
iterations=1,
)
for state in viral.memory:
logger.info(state)
2 changes: 1 addition & 1 deletion src/utils/OllamaChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
self.logger = getLogger("VIRAL")

if system_prompt:
self.logger.info(f"System: {system_prompt}")
self.logger.info(f"System: {system_prompt}, Options: {self.options}")
self.add_message(system_prompt, role="system")

def add_message(self, content: str, role: str = "user", **kwargs) -> None:
Expand Down

0 comments on commit e2ccbe8

Please sign in to comment.