Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
maik97 committed Sep 17, 2021
1 parent 2aaa4f3 commit b6a52f3
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions _example_agents/ppo_single_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ def __init__(
env,
epochs=10,
batch_size=64,
learning_rate=3e-45,
learning_rate=3e-4,
clipnorm=0.5,
entropy_factor = 0.0,
hidden_units = 256,
hidden_activation = 'relu',
kernel_initializer: str = 'glorot_uniform',
entropy_factor=0.0,
hidden_units=64,
hidden_activation='relu',
kernel_initializer='glorot_uniform',
logger=None,
approximate_contin=False,
):
Expand All @@ -46,6 +46,8 @@ def __init__(

self.reward_rmstd = RunningMeanStd()

kernel_initializer = tf.keras.initializers.Orthogonal()

input_layer = Input(env.observation_space.shape)
hidden_layer = Dense(hidden_units, activation=hidden_activation, kernel_initializer=kernel_initializer)(input_layer)
hidden_layer = Dense(hidden_units, activation=hidden_activation, kernel_initializer=kernel_initializer)(hidden_layer)
Expand Down Expand Up @@ -126,8 +128,9 @@ def learn(self):

sum_loss = a_loss + 0.5 * c_loss

grad = tape.gradient(sum_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))
if not tf.math.is_nan(sum_loss):
grad = tape.gradient(sum_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(grad, self.model.trainable_variables))

#self.optimizer.minimize(sum_loss, self.model.trainable_variables, tape=tape)

Expand All @@ -141,6 +144,13 @@ def learn(self):

self.memory.clear()

def save_model(self, path='test'):
    """Save the weights of ``self.model`` to ``path``.

    Only the weights are written, by delegating to
    ``self.model.save_weights``; the full model is not serialized.

    Args:
        path: Destination handed unchanged to ``save_weights``.
            Defaults to ``'test'``.
    """
    self.model.save_weights(path)

def load_model(self, path='test'):
    """Load weights into ``self.model`` from ``path``.

    Delegates to ``self.model.load_weights``; the default path matches
    the default used by ``save_model``.

    Args:
        path: Source handed unchanged to ``load_weights``.
            Defaults to ``'test'``.
    """
    self.model.load_weights(path)


def train_ppo():

Expand All @@ -151,7 +161,18 @@ def train_ppo():
agent = PPO(env, logger=StatusPrinter('test'))

trainer = Trainer(env, agent)
trainer.n_step_train(5_000_000, train_on_test=False)
trainer.n_step_train(5_000, train_on_test=False)
trainer.agent.save_model()
env.close()
del trainer
del agent
del env

env = gym.make("LunarLanderContinuous-v2")
agent = PPO(env, logger=StatusPrinter('test'))
trainer = Trainer(env, agent)
trainer.agent.load_model()
trainer.n_step_train(5_000, train_on_test=False)
trainer.test(100)

if __name__ == "__main__":
Expand Down

0 comments on commit b6a52f3

Please sign in to comment.