Skip to content

Commit

Permalink
fix ppo
Browse files Browse the repository at this point in the history
  • Loading branch information
retinfai committed Sep 8, 2023
1 parent 9581281 commit e491ce6
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/reinforcement_learning/reinforcement_learning/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,6 +28,7 @@ def main():
 global BATCH_SIZE
 global EVALUATE_EVERY_N_STEPS
 global EVALUATE_FOR_M_EPISODES
+global ALGORITHM

ENVIRONMENT, \
ALGORITHM, \
Expand Down Expand Up @@ -309,7 +310,11 @@ def evaluate_policy(env, agent, num_episodes):

 while not truncated and not terminated:

-    action = agent.select_action_from_policy(state, evaluation=True)
+    if ALGORITHM == 'PPO':
+        action = agent.select_action_from_policy(state)
+    else:
+        action = agent.select_action_from_policy(state, evaluation=True)
+
     action = hlp.denormalize(action, env.MAX_ACTIONS, env.MIN_ACTIONS)
     next_state, reward, terminated, truncated, _ = env.step(action)

Expand Down

0 comments on commit e491ce6

Please sign in to comment.