Make VIRAL work and make utils folder
On branch main
Your branch is up to date with 'origin/main'.

Changes to be committed:
	modified:   src/VIRAL.py
	modified:   src/log/log.txt
	modified:   src/main.py
	modified:   src/test.py
	new file:   src/utils/Algo.py
	renamed:    src/CustomRewardWrapper.py -> src/utils/CustomRewardWrapper.py
	new file:   src/utils/Environments.py
	renamed:    src/ObjectivesMetrics.py -> src/utils/ObjectivesMetrics.py
	renamed:    src/OllamaChat.py -> src/utils/OllamaChat.py
	renamed:    src/State.py -> src/utils/State.py
	new file:   src/utils/TrainingInfoCallback.py
	renamed:    src/utils.py -> src/utils/utils.py
ekomlenovic committed Dec 15, 2024
1 parent 3094ad2 commit 526b3b2
Showing 12 changed files with 891 additions and 103 deletions.
134 changes: 61 additions & 73 deletions src/VIRAL.py
@@ -1,21 +1,26 @@
import random
import signal
import sys
import threading
from logging import getLogger
from typing import Callable, Dict, List

import numpy as np

from OllamaChat import OllamaChat
from State import State

from utils.OllamaChat import OllamaChat
from utils.State import State
from utils.Algo import Algo
from utils.Environments import Environments
from utils.CustomRewardWrapper import CustomRewardWrapper
from utils.TrainingInfoCallback import TrainingInfoCallback
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym

class VIRAL:
def __init__(
self,
learning_method: Callable,
env,
learning_algo: Algo,
env_type : Environments,
objectives_metrics: List[callable] = [],
model: str = "qwen2.5-coder",
options: dict = {},
@@ -40,32 +45,14 @@ def __init__(
""",
options=options,
)
self.env = env
self.env_type : Environments = env_type
self.env = None
self.objectives_metrics = objectives_metrics
self.learning_method = learning_method
self.learning_algo : Algo = learning_algo
self.learning_method = None
self.memory: List[State] = [State(0)]
self.logger = getLogger("VIRAL")
self.stops_threads = threading.Event()
self.lock = threading.Lock()
self.threads: list[threading.Thread] = []
self.threads.append(
threading.Thread(target=self._learning, args=[self.memory[0]])
)
self.threads[0].start()
signal.signal(signal.SIGTERM, self.sigterm_handler)
signal.signal(signal.SIGINT, self.sigterm_handler)

def sigterm_handler(self, signal, frame):
self.stops_threads.set()
for thread in self.threads:
if thread.is_alive():
thread.join()
if len(threading.enumerate()) > 1:
for thread in threading.enumerate():
self.logger.error(f"{thread.name},({thread.ident}) is alive")
else:
self.logger.debug("end of main thread")
sys.exit(0)
#self.training_callback = TrainingInfoCallback()

def generate_reward_function(
self, task_description: str, iterations: int = 1
@@ -93,11 +80,11 @@ def generate_reward_function(

#"repeat_last_n": 64, # how far back the model looks to avoid repeating itself (default 64, ample for us)

"repeat_penalty": 1.5, # penalty to discourage repeated responses (default 1.1; 1.5 seems worth experimenting with)
#"repeat_penalty": 1.5, # penalty to discourage repeated responses (default 1.1; 1.5 seems worth experimenting with)

#"stop": "stop you here" # stops text generation early; not useful for us

"tfs_z": 1.2, #reduce the impact of the least "relevant" tokens (default 1.0 disables it, 2.0 max)
#"tfs_z": 1.2, #reduce the impact of the least "relevant" tokens (default 1.0 disables it, 2.0 max)

#"top_k": 30, #reduces the probability of generating nonsense (default 40; 100 for more diverse responses, 10 for more "conservative" ones)
#"top_p": 0.95, #works with top_k; a high value gives more diverse text (default 0.9)
@@ -110,12 +97,12 @@ def generate_reward_function(
additional_options["seed"] = random.randint(0, 1000000)
for i in [1, 2]:
prompt = f"""
Complete the reward function for a {self.env.spec.name} environment.
Complete the reward function for a {self.env_type.value} environment.
Task Description: {task_description} Iteration {i+1}/{2}
complete this sentence:
def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float:
\"\"\"Reward function for {self.env.spec.name}
\"\"\"Reward function for {self.env_type.value}
Args:
observations (np.ndarray): observation on the current state
@@ -134,7 +121,9 @@ def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> f
response = self.llm.print_Generator_and_return(response, i)
reward_func, response = self._get_runnable_function(response)
self.memory.append(State(i, reward_func, response))
best_idx, worst_idx = self.evaluate_policy(1, 2)

best_idx, worst_idx = self.evaluate_policy(1, 2)
worst_idx = 1
self.logger.debug(f"state to refine: {worst_idx}")
### SECOND STAGE ###
for n in range(iterations - 1):
@@ -156,11 +145,12 @@ def _get_runnable_function(self, response: str, error: str = None) -> Callable:
response = self.llm.generate_response(stream=True)
response = self.llm.print_Generator_and_return(response)
try:
env = gym.make(self.env_type.value)
response = self._get_code(response)
reward_func = self._compile_reward_function(response)
state, _ = self.env.reset()
action = self.learning_method.output(state)
next_observation, _, terminated, truncated, _ = self.env.step(action)
state, _ = env.reset()
action = env.action_space.sample()
next_observation, _, terminated, truncated, _ = env.step(action)
self._test_reward_function(
reward_func,
observations=next_observation,
@@ -253,26 +243,20 @@ def self_refine_reward(self, idx: int) -> Callable:
def _learning(self, state: State) -> None:
"""train a policy on an environment"""
self.logger.debug(f"state {state.idx} begin is learning")
vec_env, model = self._generate_env_model(state.reward_func)

policy, perfs, sr, nb_ep = self.learning_method.train(
reward_func=state.reward_func,
save_name=f"model/{self.learning_method}_{self.env.spec.name}{state.idx}.model",
stop=self.stops_threads,
)
training_callback = TrainingInfoCallback()
policy = model.learn(total_timesteps=60000, callback=training_callback)
metrics = training_callback.get_metrics()
self.memory[state.idx].set_policy(policy)
observations, rewards, sr_test = self.test_policy(policy)
observations, rewards, sr_test = self.test_policy(vec_env, policy)
metrics["test_success_rate"] = sr_test
metrics["test_rewards"] = rewards
perso_observations = []
for objective_metric in self.objectives_metrics:
perso_observations.append(objective_metric(observations))
self.memory[state.idx].set_performances(
{
"train_success_rate": sr,
"train_episodes": nb_ep,
"test_success_rate": sr_test,
"test_rewards": rewards,
"custom_metrics": perso_observations,
}
) # TODO maybe add in to chat this state
self.memory[state.idx].set_performances(metrics)
print(f"state {state.idx} performances: {metrics}")
self.logger.debug(f"state {state.idx} as finished is learning")

def evaluate_policy(self, idx1: int, idx2: int) -> int:
@@ -288,16 +272,9 @@ def evaluate_policy(self, idx1: int, idx2: int) -> int:
"""
if len(self.memory) < 2:
self.logger.error("At least two reward functions are required.")
to_join: int = []
for i in [idx1, idx2]:
if self.memory[i].performances is None:
self.threads.append(
threading.Thread(target=self._learning, args=[self.memory[i]])
)
self.threads[-1].start()
to_join.append(i)
for t in to_join:
self.threads[t].join()
self._learning(self.memory[i])
# TODO: comparing on success rate only, for now
if (
self.memory[idx1].performances["test_success_rate"]
@@ -309,35 +286,33 @@ def test_policy(

def test_policy(
self,
env,
policy,
reward_func=None,
nb_episodes=100,
max_t=1000,
max_t=1000
) -> list:
all_rewards = []
all_states = []
nb_success = 0
x_max = 0
x_min = 0 # avoid div by 0
for epi in range(1, nb_episodes + 1):
if self.stops_threads.is_set():
break
total_reward = 0
state, _ = self.env.reset()
obs = env.reset()
for i in range(1, max_t + 1):
action = policy.output(state)
next_observation, reward, terminated, truncated, _ = self.env.step(
action
)
action = policy.predict(obs, deterministic=True)[0]
next_observation, reward, done, info = env.step(action)
#infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated
truncated = info[0]["TimeLimit.truncated"]
if reward_func is not None:
reward = reward_func(next_observation, terminated, truncated)
reward = reward_func(next_observation, done, truncated)
total_reward += reward
state = next_observation
all_states.append(state)
if terminated:
break
if truncated:
nb_success += 1
if done:
if truncated:
nb_success += 1
break
all_rewards.append(total_reward)
if total_reward > x_max:
@@ -348,3 +323,16 @@ def test_policy
(x - x_min) / (x_max - x_min) for x in all_rewards
] # Min-Max normalized
return all_states, all_rewards, (nb_success / nb_episodes)

def _generate_env_model(self, reward_func):
"""
Generate the environment model
"""
vec_env = make_vec_env(self.env_type.value, n_envs=1, wrapper_class=CustomRewardWrapper, wrapper_kwargs={'llm_reward_function': reward_func})
if self.learning_algo == Algo.PPO:
model = PPO("MlpPolicy", vec_env, verbose=1)
else:
raise ValueError("The learning algorithm is not implemented.")

return vec_env, model
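To illustrate the refactored entry point, here is a minimal usage sketch, assuming the remaining constructor parameters keep their defaults. The Environments.CARTPOLE member and the task description are hypothetical placeholders, not part of this diff, and only Algo.PPO is wired up in _generate_env_model.

# Hypothetical usage sketch (run from src/): Environments.CARTPOLE is an assumed
# enum member whose .value is a Gymnasium environment id such as "CartPole-v1".
from utils.Algo import Algo
from utils.Environments import Environments
from VIRAL import VIRAL

viral = VIRAL(
    learning_algo=Algo.PPO,          # only PPO is implemented in _generate_env_model
    env_type=Environments.CARTPOLE,  # assumed member; use any environment defined in utils/Environments.py
)

# Requests two candidate reward functions from the LLM, trains a PPO policy on each
# through make_vec_env + CustomRewardWrapper, then refines the weaker candidate.
viral.generate_reward_function(
    task_description="Balance the pole for as long as possible.",
    iterations=2,
)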
