
Commit

update the doc
valentincuzin committed Dec 23, 2024
1 parent da56194 commit 9542550
Showing 8 changed files with 197 additions and 75 deletions.
2 changes: 2 additions & 0 deletions src/LLM/OllamaChat.py
@@ -110,6 +110,8 @@ def generate_simple_response(self,
Generate a simple response without conversation history.
Args:
prompt (str): user prompt
sys_prompt (str, optional): system prompt
stream (bool, optional): Stream response in real-time
additional_options (dict, optional): Temporary generation options
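For context, a minimal usage sketch of this method; the constructor arguments and the module path are assumptions, since OllamaChat's __init__ is not visible here:

    from LLM.OllamaChat import OllamaChat  # module path assumed from src/LLM/OllamaChat.py

    chat = OllamaChat(model="qwen2.5-coder")  # constructor arguments are illustrative
    reply = chat.generate_simple_response(
        prompt="Write a dense reward function for CartPole-v1.",
        sys_prompt="You are a reinforcement learning reward engineer.",
        stream=False,
        additional_options={"temperature": 0.2},  # temporary generation options
    )
    print(reply)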
17 changes: 16 additions & 1 deletion src/PolicyTrainer/CustomRewardWrapper.py
@@ -4,12 +4,27 @@


class CustomRewardWrapper(gym.Wrapper):
def __init__(self, env, success_func: Callable = None, llm_reward_function: Callable = None):
def __init__(self, env: gym.Env, success_func: Callable = None, llm_reward_function: Callable = None):
"""init the custom reward wrapper
Args:
env (gym.Env): the current environment
success_func (Callable, optional): this function should return True if success. Defaults to None.
llm_reward_function (Callable, optional): the generated reward function. Defaults to None.
"""
super().__init__(env)
self.success_func = success_func
self.llm_reward_function = llm_reward_function

def step(self, action):
"""override the step function to integrate our reward function
Args:
action (): the action to realise
Returns:
tuple: observation, reward, terminated, truncated, info
"""
observation, original_reward, terminated, truncated, info = self.env.step(
action
)
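A minimal usage sketch of the wrapper, assuming gym here is Gymnasium; the callable signatures below are illustrative assumptions, since the rest of step() is collapsed in this view:

    import gymnasium as gym
    from PolicyTrainer.CustomRewardWrapper import CustomRewardWrapper

    def success_func(obs) -> bool:
        # illustrative: treat "pole roughly upright" as success on CartPole
        return abs(obs[2]) < 0.2

    def llm_reward_function(obs) -> float:
        # stand-in for an LLM-generated reward; the real signature is an assumption
        return 1.0 - abs(obs[2])

    env = CustomRewardWrapper(
        gym.make("CartPole-v1"),
        success_func=success_func,
        llm_reward_function=llm_reward_function,
    )
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())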
57 changes: 48 additions & 9 deletions src/PolicyTrainer/PolicyTrainer.py
@@ -12,6 +12,8 @@
from time import sleep
from queue import Empty

from stable_baselines3.common.vec_env.base_vec_env import VecEnv

from PolicyTrainer.CustomRewardWrapper import CustomRewardWrapper
from PolicyTrainer.TrainingInfoCallback import TrainingInfoCallback
from State.State import State
@@ -24,6 +26,13 @@

class PolicyTrainer:
def __init__(self, memory: list[State], env_type: EnvType, timeout: int):
"""initialise the policy trainer
Args:
memory (list[State]):
env_type (EnvType): parameter of the env
timeout (int): for the model.learn()
"""
self.logger = getLogger("VIRAL")
self.memory = memory
self.timeout = timeout
@@ -48,7 +57,12 @@ def __init__(self, memory: list[State], env_type: EnvType, timeout: int):
self._learning(self.memory[0])

def _learning(self, state: State, queue: Queue = None) -> None:
"""train a policy on an environment"""
"""train a policy on an environment
Args:
state (State):
queue (Queue, optional): handle modification to return. Defaults to None.
"""
self.logger.debug(
f"state {state.idx} begin is learning with reward function: {state.reward_func_str}"
)
@@ -71,7 +85,7 @@ def _learning(self, state: State, queue: Queue = None) -> None:
)
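The body of _learning is mostly collapsed here. Since the signature takes an optional Queue and the trainer stores a timeout, one rough, speculative sketch of how a caller might drive it in a worker process (object names and the payload format are assumptions, not the actual implementation) is:

    from multiprocessing import Process, Queue

    queue = Queue()
    worker = Process(target=policy_trainer._learning, args=(state, queue))
    worker.start()
    try:
        result = queue.get(timeout=policy_trainer.timeout)  # payload format is an assumption
    finally:
        worker.join()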

def evaluate_policy(self, idx1: int, idx2: int) -> int:
"""
""" TODO to be change, i think evaluate if the policy give is better than the original
Evaluate policy performance for multiple reward functions
Args:
@@ -134,11 +148,22 @@ def evaluate_policy(self, idx1: int, idx2: int) -> int:

def test_policy(
self,
env,
env: VecEnv,
policy,
numvenv,
nb_episodes=10,
numvenv: int,
nb_episodes: int = 100,
) -> float:
"""test a policy already train
Args:
env (VecEnv): envs
policy (): can be PPO or other RLAlgo
numvenv (int): number of env in the vec
nb_episodes (int, optional): . Defaults to 100.
Returns:
float: _description_
"""
all_rewards = []
nb_success = 0

@@ -163,7 +188,13 @@
success_rate = nb_success / nb_episodes
return success_rate
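A hedged call sketch; policy_trainer and reward_func are assumed to exist, and _generate_env_model is the helper shown further down in this file:

    envs, model, numvenv = policy_trainer._generate_env_model(reward_func)
    model.learn(total_timesteps=10_000)
    success_rate = policy_trainer.test_policy(envs, model, numvenv, nb_episodes=100)
    print(f"success rate: {success_rate:.2f}")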

def test_policy_hf(self, policy_path, nb_episodes = 100):
def test_policy_hf(self, policy_path: str, nb_episodes: int = 100):
"""visualise a policy
Args:
policy_path (str): the path of the policy to load
nb_episodes (int, optional): . Defaults to 100.
"""
env = make(self.env_name, render_mode='human')
if self.algo == Algo.PPO:
policy = PPO.load(policy_path)
@@ -175,9 +206,17 @@ def test_policy_hf(self, policy_path, nb_episodes = 100):
obs, _, term, trunc, _ = env.step(actions)
done = term or trunc
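The enclosing episode loop is collapsed by this view; offered only as a sketch, a plausible shape for it would be:

    for _ in range(nb_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            actions, _ = policy.predict(obs)
            obs, _, term, trunc, _ = env.step(actions)
            done = term or trunc
    env.close()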

def _generate_env_model(self, reward_func):
"""
Generate the environment model
def _generate_env_model(self, reward_func) -> tuple[VecEnv, PPO, int]:
"""Generate the environment model
Args:
reward_func (Callable): the generated reward function
Raises:
ValueError: if algo not implemented
Returns:
tuple[VecEnv, PPO, int]: the envs, the model, the number of envs
"""
numenvs = 2
# SubprocVecEnv, except that CUDA will be used underneath
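The rest of _generate_env_model is cut off here. A minimal sketch of what such a builder typically looks like with Stable-Baselines3, under the assumption that CustomRewardWrapper is applied to each sub-environment (not the actual implementation):

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    def _generate_env_model(self, reward_func):
        # sketch only: wrap every sub-environment with the custom reward
        numenvs = 2
        envs = make_vec_env(
            self.env_name,
            n_envs=numenvs,
            wrapper_class=CustomRewardWrapper,
            wrapper_kwargs={"llm_reward_function": reward_func},
        )
        model = PPO("MlpPolicy", envs, verbose=0)
        return envs, model, numenvs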
13 changes: 7 additions & 6 deletions src/PolicyTrainer/TrainingInfoCallback.py
@@ -15,13 +15,13 @@ def __init__(self):
self.num_envs = None

def _on_training_start(self):
"""Initialisation au début de l'entraînement"""
"""Init at the begin of the training"""
self.num_envs = self.training_env.num_envs
self.current_episode_rewards = np.zeros(self.num_envs)
self.current_episode_lengths = np.zeros(self.num_envs, dtype=int)

def _on_step(self) -> bool:
"""Méthode appelée à chaque étape de l'entraînement."""
"""call every steps"""
rewards = self.locals["rewards"]
dones = self.locals["dones"]

@@ -43,7 +43,7 @@ def _on_step(self) -> bool:
return True

def _on_training_end(self) -> None:
"""Méthode appelée à la fin de l'entraînement."""
"""call at the end of the training"""
rewards = self.training_metrics["episode_rewards"]
rewards /= np.linalg.norm(rewards)
lengths = self.training_metrics["episode_lengths"]
Expand All @@ -55,8 +55,9 @@ def _on_training_end(self) -> None:
}

def get_metrics(self):
"""
Méthode pour récupérer les métriques calculées.
:return: Dictionnaire des métriques
"""for get metrics
Returns:
dict: contain metrics harvested
"""
return self.custom_metrics
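A short usage sketch of the callback with Stable-Baselines3; the keys returned by get_metrics are not fully visible in this diff, so treat them as assumptions:

    from stable_baselines3 import PPO
    from PolicyTrainer.TrainingInfoCallback import TrainingInfoCallback

    callback = TrainingInfoCallback()
    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=10_000, callback=callback)
    metrics = callback.get_metrics()  # e.g. normalized episode rewards / lengths
    print(metrics)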
121 changes: 67 additions & 54 deletions src/State/State.py
@@ -5,65 +5,72 @@
logger = getLogger("VIRAL")

class State:
"""Represents a state in the reward function generation and evaluation process.
This class encapsulates the key components of a reward function's lifecycle,
tracking its index, implementation, policy, and performance metrics.
Attributes:
idx (int): Unique identifier for the state.
reward_func (Callable, optional): The compiled reward function.
reward_func_str (str, optional): String representation of the reward function.
policy (object, optional): The policy associated with the reward function.
performances (dict, optional): Performance metrics of the reward function.
Key Characteristics:
- Tracks the evolution of reward functions
- Provides a snapshot of a specific iteration
- Allows for dynamic updating of policy and performance
Initialization Constraints:
- Initial state (idx=0) cannot have a reward function
- Non-initial states must have both a reward function and its string representation
Methods:
- set_policy(policy): Update the associated policy
- set_performances(performances): Update performance metrics
- __repr__(): Provide a human-readable string representation of the state
Example:
# Creating a new state for a reward function
state = State(
idx=1,
reward_func=my_reward_func,
reward_func_str="def reward_func(...):",
policy=None,
perfomances=None
)
# Updating state with training results
state.set_policy(trained_policy)
state.set_performances({
'success_rate': 0.75,
'average_reward': 10.5
})
Notes:
- Designed for tracking reward function iterations
- Provides flexibility in managing function states
- Supports logging and debugging of reward function generation process
"""
def __init__(
self,
idx,
idx: int,
reward_func: Callable = None,
reward_func_str: str = None,
policy=None,
perfomances: dict = None,
):
"""
Represents a state in the reward function generation and evaluation process.
This class encapsulates the key components of a reward function's lifecycle,
tracking its index, implementation, policy, and performance metrics.
Attributes:
idx (int): Unique identifier for the state.
reward_func (Callable, optional): The compiled reward function.
reward_func_str (str, optional): String representation of the reward function.
policy (object, optional): The policy associated with the reward function.
performances (dict, optional): Performance metrics of the reward function.
Key Characteristics:
- Tracks the evolution of reward functions
- Provides a snapshot of a specific iteration
- Allows for dynamic updating of policy and performance
Initialization Constraints:
- Initial state (idx=0) cannot have a reward function
- Non-initial states must have both a reward function and its string representation
Methods:
- set_policy(policy): Update the associated policy
- set_performances(performances): Update performance metrics
- __repr__(): Provide a human-readable string representation of the state
Example:
# Creating a new state for a reward function
state = State(
idx=1,
reward_func=my_reward_func,
reward_func_str="def reward_func(...):",
policy=None,
perfomances=None
)
# Updating state with training results
state.set_policy(trained_policy)
state.set_performances({
'success_rate': 0.75,
'average_reward': 10.5
})
Notes:
- Designed for tracking reward function iterations
- Provides flexibility in managing function states
- Supports logging and debugging of reward function generation process
"""init a new state
Args:
idx (int): the index of the memory
reward_func (Callable, optional): . Defaults to None.
reward_func_str (str, optional): for printing the reward function. Defaults to None.
policy (_type_, optional): . Defaults to None.
perfomances (dict, optional): . Defaults to None.
"""
self.idx = idx
self.src: list = [self.idx]
if self.idx == 0 and (reward_func is not None or reward_func_str is not None):
logger.error("the inital state don't take reward function")
elif self.idx != 0 and (reward_func is None or reward_func_str is None):
@@ -74,14 +81,20 @@ def __init__(
self.logger_csv = getLoggerCSV()
self.performances = perfomances

def set_src(self, state):
self.src = state.src.copy()
self.src.append(self.idx)

def set_policy(self, policy):
"""set the current policy
Args:
policy ():
"""
self.policy = policy

def set_performances(self, performances: dict):
"""set performance after the test
Args:
performances (dict):
"""
self.performances = performances
self.logger_csv.to_csv(self)

21 changes: 19 additions & 2 deletions src/VIRAL.py
@@ -21,8 +21,11 @@ def __init__(
"""
Initialize VIRAL architecture for dynamic reward function generation
Args:
env_type (EnvType): the parameters of a gym Env
model (str): Language model for reward generation
learning_method (str): Reinforcement learning method
hf (bool, optional): enable human feedback
training_time (int, optional): timeout for model.learn()
options (dict, optional): options for the LLM
"""
if options.get("seed") is None:
options["seed"] = random.randint(0, 1000000)
@@ -51,6 +54,11 @@ def __init__(
)
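For orientation, a hedged construction example; the full signature is collapsed in this view, so the keyword names follow the documented Args and the values are illustrative assumptions:

    from VIRAL import VIRAL

    viral = VIRAL(
        env_type=env_type,          # an EnvType describing the target gym Env
        model="qwen2.5-coder",      # illustrative model name
        learning_method="PPO",
        hf=False,
        training_time=60,
        options={"seed": 42},
    )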

def generate_context(self, prompt_info: dict):
"""Generate more contexte for Step back prompting
Args:
prompt_info (dict): contain a task, and observation space
"""
prompt = f"{prompt_info}\nDescribe which observation can achieve the goal."
sys_prompt = (
f"You're a physics expert, specializing in {self.env_type} motion analysis.\n"
@@ -207,7 +215,16 @@ def self_refine_reward(self, idx: int) -> Callable:

return len(self.memory) - 1

def human_feedback(self, prompt: str, idx: int) -> Callable:
def human_feedback(self, prompt: str, idx: int) -> str:
"""implement human feedback
Args:
prompt (str): user prompt
idx (int): state.idx to refine
Returns:
str: return the modified prompt
"""
self.logger.info(self.memory[idx])
visualise = input("do you need to visualise policy ?\ny/n:")
if visualise.lower() in ["y", "yes"]: