
Commit

some fix and test
valentincuzin committed Jan 4, 2025
1 parent e94efeb commit 0ba76a9
Showing 8 changed files with 1,100 additions and 2,255 deletions.
2 changes: 1 addition & 1 deletion src/LLM/GenCode.py
@@ -80,7 +80,7 @@ def get_runnable_function(self, error: str = None) -> Callable:
         action = env.envs[0].action_space.sample()
         obs, _, dones, infos = env.step([action])
         infos[0]["terminated"] = False
-        is_success, is_failure = self.success_func(obs[0], infos[0])
+        is_success, is_failure = self.success_func(env.envs[0], infos[0])
         self.test_reward_function(
             reward_func, observations=obs[0], is_success=0,
             is_failure=0
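The fix passes the underlying environment (env.envs[0]) to success_func instead of the observation, which suggests the success check now reads state directly from the env. A minimal sketch of a CartPole success function compatible with this (env, info) -> (is_success, is_failure) contract; the function name and thresholds are illustrative, not taken from the repository:

def cartpole_success_func(env, info):
    """Hypothetical success check matching the new (env, info) call signature."""
    # CartPole-v1 terminates when |theta| > ~12 degrees (0.2095 rad) or |x| > 2.4.
    x, x_dot, theta, theta_dot = env.unwrapped.state
    is_failure = abs(theta) > 0.2095 or abs(x) > 2.4
    # Reaching the time limit without failing counts as success.
    is_success = bool(info.get("TimeLimit.truncated", False)) and not is_failure
    return is_success, is_failure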
2 changes: 1 addition & 1 deletion src/PolicyTrainer/CustomRewardWrapper.py
@@ -34,7 +34,7 @@ def step(self, action):
         is_success = 0
         is_failure = 0
         if terminated or truncated:
-            is_success, is_failure = self.success_func(observation, info)
+            is_success, is_failure = self.success_func(self.env, info)
             reward = self.llm_reward_function(observation, is_success, is_failure)
         else:
             reward = original_reward
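The wrapper mirrors the same change: success_func now receives self.env rather than the raw observation. A usage sketch of how such a wrapper is typically attached to a Gymnasium env; the constructor arguments are assumptions, not the repository's exact signature:

import gymnasium as gym

# Assumed constructor: CustomRewardWrapper(env, llm_reward_function, success_func).
env = gym.make("CartPole-v1")
env = CustomRewardWrapper(env,
                          llm_reward_function=reward_func,       # LLM-generated shaping function
                          success_func=cartpole_success_func)    # success check as sketched above
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())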
3 changes: 3 additions & 0 deletions src/PolicyTrainer/PolicyTrainer.py
@@ -180,7 +180,10 @@ def test_policy(
             obs, reward, term, trunc, info = env.step(actions)
             episode_rewards += reward
             done = trunc or term
+
             if done:
+                info["TimeLimit.truncated"] = trunc
+                info["terminated"] = term
                 is_success, _ = self.success_func(env, info)
                 if is_success:
                     nb_success += 1
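For reference, Gymnasium's step() distinguishes termination (a terminal MDP state) from truncation (a time limit). Writing both flags into info before calling success_func gives the success check a single dict to read in both the wrapper and the test loop. A minimal standalone sketch of the same pattern, assuming a plain Gymnasium env:

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
    if done:
        # Same keys the trainer now sets before calling success_func.
        info["TimeLimit.truncated"] = truncated  # episode cut off by the time limit
        info["terminated"] = terminated          # genuine terminal state (pole fell / cart left bounds)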
4 changes: 2 additions & 2 deletions src/VIRAL.py
@@ -2,10 +2,10 @@
 from logging import getLogger

 from Environments import EnvType
+from LLM.GenCode import GenCode
 from LLM.OllamaChat import OllamaChat
+from PolicyTrainer.PolicyTrainer import PolicyTrainer
 from State.State import State
-from LLM.GenCode import GenCode
-from PolicyTrainer.PolicyTrainer import PolicyTrainer


 class VIRAL:
2,242 changes: 0 additions & 2,242 deletions src/log/CartPole-v1_log.csv

This file was deleted.

1,072 changes: 1,072 additions & 0 deletions src/log/CartPole_v1_log.csv

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions src/main.py
@@ -40,14 +40,14 @@ def main():
     memory.
     """
     parse_logger()
-    env_type = LunarLander(Algo.PPO)
+    env_type = CartPole(Algo.PPO)
     model = 'qwen2.5-coder'
     human_feedback = False
     LoggerCSV(env_type, model)
     viral = VIRAL(
-        env_type=env_type, model=model, hf=human_feedback, training_time=500_000, numenvs=3, options=additional_options)
-    viral.generate_context(Prompt.LUNAR_LANDER)
-    viral.generate_reward_function(n_init=1, n_refine=3)
+        env_type=env_type, model=model, hf=human_feedback, training_time=30_000, numenvs=3, options=additional_options)
+    viral.generate_context(Prompt.CARTPOLE)
+    viral.generate_reward_function(n_init=1, n_refine=0)
     for state in viral.memory:
         viral.logger.info(state)

22 changes: 17 additions & 5 deletions src/main3.py
@@ -37,15 +37,27 @@ def main():
     memory.
     """
     parse_logger()
-    env_type = LunarLander(Algo.PPO)
+    env_type = CartPole(Algo.PPO)
     model = 'qwen2.5-coder'
     human_feedback = True
     LoggerCSV(env_type, model)
     viral = VIRAL(
-        env_type=env_type, model=model, hf=human_feedback, training_time=2_000, numenvs=1, options=additional_options)
-    viral.test_reward_func("""def reward_function(observations, is_success, is_failure):
-        # Your reward calculation logic here
-        return 0.0271 # Placeholder return value""")
+        env_type=env_type, model=model, hf=human_feedback, training_time=50_000, numenvs=2, options=additional_options)
+    viral.test_reward_func("""def reward_func(observations:np.ndarray, is_success:bool, is_failure:bool) -> float:
+    x, x_dot, theta, theta_dot = observations
+    if is_success:
+        return 10.0
+    elif is_failure:
+        return -10.0
+    else:
+        # Reward based on how close to vertical the pole is and how stable it is
+        proximity_to_vertical = np.cos(theta)
+        stability_factor = np.exp(-abs(theta_dot))
+        reward = proximity_to_vertical * stability_factor
+        return reward""")
     for state in viral.memory:
         viral.logger.info(state)

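As a quick sanity check of the new reward's shape (not part of the commit), the same formula can be evaluated on a few observations; values below are rounded:

import numpy as np

def reward_func(observations: np.ndarray, is_success: bool, is_failure: bool) -> float:
    x, x_dot, theta, theta_dot = observations
    if is_success:
        return 10.0
    elif is_failure:
        return -10.0
    # Peaks at 1.0 when the pole is vertical (theta = 0) and not rotating (theta_dot = 0).
    return float(np.cos(theta) * np.exp(-abs(theta_dot)))

print(reward_func(np.array([0.0, 0.0, 0.0, 0.0]), False, False))   # 1.0   (upright and still)
print(reward_func(np.array([0.0, 0.0, 0.20, 1.5]), False, False))  # ~0.22 (tilted and rotating)
print(reward_func(np.array([0.0, 0.0, 0.0, 0.0]), True, False))    # 10.0  (success overrides shaping)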
