Skip to content

Commit

Permalink
update the logger, add parameter
Browse files: browse the repository at this point in the history
  • Loading branch information
Cuzin-Rambaud committed Dec 21, 2024
1 parent c1b61f6 commit b4fa444
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 234 deletions.
5 changes: 3 additions & 2 deletions src/PolicyTrainer/PolicyTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@


class PolicyTrainer:
def __init__(self, memory: list[State], env_type: EnvType):
def __init__(self, memory: list[State], env_type: EnvType, timeout: int):
self.logger = getLogger("VIRAL")
self.memory = memory
self.timeout = timeout
self.algo = env_type.algo
self.env_name = str(env_type)
self.success_func = env_type.success_func
Expand All @@ -46,7 +47,7 @@ def _learning(self, state: State, queue: Queue = None) -> None:
self.logger.debug(f"state {state.idx} begin is learning with reward function: {state.reward_func_str}")
vec_env, model, numvenv = self._generate_env_model(state.reward_func)
training_callback = TrainingInfoCallback()
policy = model.learn(total_timesteps=25000, callback=training_callback)
policy = model.learn(total_timesteps=self.timeout, callback=training_callback)
policy.save(f"model/policy{state.idx}.model")
metrics = training_callback.get_metrics()
self.logger.debug(f"{state.idx} TRAINING METRICS: {metrics}")
Expand Down
13 changes: 7 additions & 6 deletions src/VIRAL.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class VIRAL:
def __init__(
self,
env_type: EnvType,
model: str = "qwen2.5-coder",
model: str,
training_time : int = 25000,
options: dict = {},
):
"""
Expand Down Expand Up @@ -45,11 +46,11 @@ def __init__(
self.logger.info(f"additional options: {options}")
self.memory: list[State] = [State(0)]
self.policy_trainer: PolicyTrainer = PolicyTrainer(
self.memory, self.env_type
self.memory, self.env_type, timeout=training_time
)

def generate_reward_function(
self, task_description: str, iterations: int = 1
self, task_description: str, n_init: int = 2, n_refine: int = 1
) -> list[State]:
"""
Generate and iteratively improve a reward function using a Language Model (LLM).
Expand Down Expand Up @@ -99,10 +100,10 @@ def generate_reward_function(
- Logging at various stages for debugging and tracking
"""
### INIT STAGE ###
for i in [1, 2]:
for i in range(1, n_init+1): # TODO make it work for 4_init
prompt = f"""
Complete the reward function for a {self.env_type} environment.
Task Description: {task_description} Iteration {i+1}/{2}
Task Description: {task_description} Iteration {i}/{n_init+1}
complete this sentence:
def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> float:
Expand All @@ -126,7 +127,7 @@ def reward_func(observations:np.ndarray, terminated: bool, truncated: bool) -> f

best_idx, worst_idx = self.policy_trainer.evaluate_policy(1, 2)
### SECOND STAGE ###
for n in range(iterations - 1):
for _ in range(n_refine - 1):
self.logger.debug(f"state to refine: {best_idx}")
new_idx = self.self_refine_reward(best_idx)
best_idx, worst_idx = self.policy_trainer.evaluate_policy(best_idx, new_idx)
Expand Down
66 changes: 32 additions & 34 deletions src/analyse.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -42,69 +42,60 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>env</th>\n",
" <th>llm</th>\n",
" <th>reward_function</th>\n",
" <th>src</th>\n",
" <th>SR</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>CartPole-v1</td>\n",
" <td>qwen2.5-coder</td>\n",
" <td>NaN</td>\n",
" <td>[0]</td>\n",
" <td>0.85</td>\n",
" <td>0.66</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>CartPole-v1</td>\n",
" <td>qwen2.5-coder</td>\n",
" <td>def reward_func(observations: np.ndarray, term...</td>\n",
" <td>[2]</td>\n",
" <td>0.98</td>\n",
" <td>import numpy as np\\n\\ndef reward_func(observat...</td>\n",
" <td>0.82</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>CartPole-v1</td>\n",
" <td>qwen2.5-coder</td>\n",
" <td>def reward_func(observations: np.ndarray, term...</td>\n",
" <td>[1]</td>\n",
" <td>0.99</td>\n",
" <td>import numpy as np\\n\\ndef reward_func(observat...</td>\n",
" <td>0.82</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>CartPole-v1</td>\n",
" <td>qwen2.5-coder</td>\n",
" <td>def reward_func(observations: np.ndarray, term...</td>\n",
" <td>[1, 3]</td>\n",
" <td>0.79</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>qwen2.5-coder</td>\n",
" <td>def reward_func(observations: np.ndarray, term...</td>\n",
" <td>[1, 4]</td>\n",
" <td>0.00</td>\n",
" <td>import numpy as np\\n\\ndef reward_func(observat...</td>\n",
" <td>0.10</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" llm reward_function src \\\n",
"0 qwen2.5-coder NaN [0] \n",
"1 qwen2.5-coder def reward_func(observations: np.ndarray, term... [2] \n",
"2 qwen2.5-coder def reward_func(observations: np.ndarray, term... [1] \n",
"3 qwen2.5-coder def reward_func(observations: np.ndarray, term... [1, 3] \n",
"4 qwen2.5-coder def reward_func(observations: np.ndarray, term... [1, 4] \n",
" env llm \\\n",
"0 CartPole-v1 qwen2.5-coder \n",
"1 CartPole-v1 qwen2.5-coder \n",
"2 CartPole-v1 qwen2.5-coder \n",
"3 CartPole-v1 qwen2.5-coder \n",
"\n",
" SR \n",
"0 0.85 \n",
"1 0.98 \n",
"2 0.99 \n",
"3 0.79 \n",
"4 0.00 "
" reward_function SR \n",
"0 NaN 0.66 \n",
"1 import numpy as np\\n\\ndef reward_func(observat... 0.82 \n",
"2 import numpy as np\\n\\ndef reward_func(observat... 0.82 \n",
"3 import numpy as np\\n\\ndef reward_func(observat... 0.10 "
]
},
"execution_count": 16,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -113,6 +104,13 @@
"data = pd.read_csv('log/CartPole-v1_log.csv', delimiter=';')\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit b4fa444

Please sign in to comment.