Skip to content

Commit

Permalink
(v3.3.1) - Sinergym observation normalization improved (#408)
Browse files Browse the repository at this point in the history
* Update NormalizeObservation wrapper paramters in sinergym/utils/wrappers.py

* Activate and deactivate automatic calibration update in NormalizeObservation in sinergym/utils/wrappers.py

* Added getters and setters for mean and variation in automatic normalization calibration in NormalizeObservation

* API reference updated with new functionality

* Added documentation for these features in wrappers section

* Tests created for observation normalization and new feaures

* Updated train and load model scripts

* Added note in traning and loading model notebook.

* Extra: Added 2ZoneDatacenter issue in documentation (buildings section)

* Updated Sinergym version from 3.3.0 to 3.3.1
  • Loading branch information
AlejandroCN7 authored Apr 22, 2024
1 parent 4dc3380 commit 9dc5a90
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/source/pages/buildings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ The default actuators, output and meters variables are:
| Facility Total HVAC Electricity Demand Rate | HVAC_electricity_demand_rate | Whole Building |
+-----------------------------------------------+-----------------------------------------------+-----------------------------------+

.. warning:: Since the update to EnergyPlus version 23.1.0, it appears that temperature setpoints are not correctly
applied in the East zone. The issue is currently under investigation. In the meantime, the default
reward functions only apply to the control of the West zone to maintain result consistency. For more
information about this issue, visit `#395 <https://github.com/ugr-sail/sinergym/issues/395>`__.

**************************
Small Datacenter
**************************
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sinergym.utils.wrappers.NormalizeObservation
sinergym.utils.wrappers.NormalizeObservation
============================================

.. currentmodule:: sinergym.utils.wrappers
Expand All @@ -16,12 +16,16 @@ sinergym.utils.wrappers.NormalizeObservation
.. autosummary::

~NormalizeObservation.__init__
~NormalizeObservation.activate_update
~NormalizeObservation.class_name
~NormalizeObservation.close
~NormalizeObservation.deactivate_update
~NormalizeObservation.get_wrapper_attr
~NormalizeObservation.normalize
~NormalizeObservation.render
~NormalizeObservation.reset
~NormalizeObservation.set_mean
~NormalizeObservation.set_var
~NormalizeObservation.step
~NormalizeObservation.wrapper_spec

Expand All @@ -35,12 +39,14 @@ sinergym.utils.wrappers.NormalizeObservation

~NormalizeObservation.action_space
~NormalizeObservation.logger
~NormalizeObservation.mean
~NormalizeObservation.metadata
~NormalizeObservation.np_random
~NormalizeObservation.observation_space
~NormalizeObservation.render_mode
~NormalizeObservation.reward_range
~NormalizeObservation.spec
~NormalizeObservation.unwrapped
~NormalizeObservation.var


15 changes: 15 additions & 0 deletions docs/source/pages/wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ It's based on the
Initially, it may not be precise and the values might often be out of range, so use this wrapper
with caution.

However, *Sinergym* enhances its functionality with some additional features:

- It includes the last unnormalized observation as an attribute, which is very useful for logging.

- It provides access to the means and variations used for normalization calibration, addressing the low-level
issues found in the original wrapper.

- Similarly, these calibration values can be set via a method. Refer to the :ref:`API reference` for more information.

- The automatic calibration can be enabled or disabled as you interact with the environment, allowing the
calibration to remain static instead of adaptive.

These functionalities are crucial when evaluating models trained using this wrapper.
For more details, visit `#407 <https://github.com/ugr-sail/sinergym/issues/407>`__.

***********************
LoggerWrapper
***********************
Expand Down
4 changes: 3 additions & 1 deletion examples/drl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53825,7 +53825,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll create the Gym environment, but it's **important to wrap the environment with the same wrappers used during training**. We can use the evaluation experiment name to rename the environment."
"We'll create the Gym environment, but it's **important to wrap the environment with the same wrappers used during training**. We can use the evaluation experiment name to rename the environment.\n",
"\n",
"**Note**: If you are loading a pre-trained model and using the observation space normalization wrapper, you should save the means and variations calibrated during the training process for a fair evaluation. Check the documentation on the wrapper and the training and model loading scripts for more information."
]
},
{
Expand Down
34 changes: 33 additions & 1 deletion scripts/eval/load_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sinergym.utils.constants import *
from sinergym.utils.rewards import *
from sinergym.utils.wrappers import *
from sinergym.utils.common import is_wrapped

# ---------------------------------------------------------------------------- #
# Parameters #
Expand Down Expand Up @@ -101,7 +102,7 @@
# parse str parameters to sinergym Callable or Objects if it is
# required
if isinstance(value, str):
if 'sinergym.' in value:
if '.' in value:
parameters[name] = eval(value)
env = wrapper_class(env=env, ** parameters)

Expand Down Expand Up @@ -176,6 +177,13 @@
sum(rewards))
env.close()

# Save normalization calibration if exists
if is_wrapped(env, NormalizeObservation) and conf.get('wandb'):
wandb.config.mean = env.get_wrapper_attr('mean')
wandb.config.var = env.get_wrapper_attr('var')
wandb.config.automatic_update = env.get_wrapper_attr(
'automatic_update')

# ---------------------------------------------------------------------------- #
# Wandb Artifacts #
# ---------------------------------------------------------------------------- #
Expand Down Expand Up @@ -219,6 +227,30 @@

except Exception as err:
print("Error in process detected")

# Save normalization calibration if exists
if is_wrapped(env, NormalizeObservation) and conf.get('wandb'):
wandb.config.mean = env.get_wrapper_attr('mean')
wandb.config.var = env.get_wrapper_attr('var')
wandb.config.automatic_update = env.get_wrapper_attr(
'automatic_update')

# Save current wandb artifacts state
if conf.get('wandb'):
if conf['wandb'].get('evaluation_registry'):
artifact = wandb.Artifact(
name=conf['wandb']['evaluation_registry']['artifact_name'],
type=conf['wandb']['evaluation_registry']['artifact_type'])
artifact.add_dir(
env.get_wrapper_attr('workspace_path'),
name='evaluation_output/')

run.log_artifact(artifact)

# wandb has finished
run.finish()

# Auto delete
if conf.get('cloud'):
if conf['cloud'].get('auto_delete'):
print('Deleting remote container')
Expand Down
5 changes: 3 additions & 2 deletions scripts/eval/load_agent_example.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"model": "alex_ugr/sinergym/training:latest",
"wrappers": {
"NormalizeAction": {},
"NormalizeObservation": {},
"NormalizeReward": {},
"NormalizeObservation": {"mean": null,
"var": null,
"automatic_update": false},
"LoggerWrapper": {
"logger_class": "sinergym.utils.logger.CSVLogger",
"flag": true
Expand Down
38 changes: 37 additions & 1 deletion scripts/train/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sinergym.utils.logger import WandBOutputFormat
from sinergym.utils.rewards import *
from sinergym.utils.wrappers import *
from sinergym.utils.common import is_wrapped

# ---------------------------------------------------------------------------- #
# Function to process configuration #
Expand Down Expand Up @@ -172,7 +173,7 @@ def process_algorithm_parameters(alg_params: dict):
# parse str parameters to sinergym Callable or Objects if it is
# required
if isinstance(value, str):
if 'sinergym.' in value:
if '.' in value:
parameters[name] = eval(value)
env = wrapper_class(env=env, ** parameters)
if eval_env is not None:
Expand Down Expand Up @@ -321,6 +322,12 @@ def process_algorithm_parameters(alg_params: dict):
callback=callback,
log_interval=conf['algorithm']['log_interval'])
model.save(env.get_wrapper_attr('workspace_path') + '/model')
# Save normalization calibration if exists
if is_wrapped(env, NormalizeObservation) and conf.get('wandb'):
wandb.config.mean = env.get_wrapper_attr('mean')
wandb.config.var = env.get_wrapper_attr('var')
wandb.config.automatic_update = env.get_wrapper_attr(
'automatic_update')

# If the algorithm doesn't reset or close the environment, this script will do it in
# order to correctly log all the simulation data (Energyplus + Sinergym
Expand Down Expand Up @@ -378,6 +385,35 @@ def process_algorithm_parameters(alg_params: dict):
# If there is some error in the code, delete remote container if exists
except Exception as err:
print("Error in process detected")

# Current model state save
model.save(env.get_wrapper_attr('workspace_path') + '/model')

# Save normalization calibration if exists
if is_wrapped(env, NormalizeObservation) and conf.get('wandb'):
wandb.config.mean = env.get_wrapper_attr('mean')
wandb.config.var = env.get_wrapper_attr('var')
wandb.config.automatic_update = env.get_wrapper_attr(
'automatic_update')

# Save current wandb artifacts state
if conf.get('wandb'):
artifact = wandb.Artifact(
name=conf['wandb']['artifact_name'],
type=conf['wandb']['artifact_type'])
artifact.add_dir(
env.get_wrapper_attr('workspace_path'),
name='training_output/')
if conf.get('evaluation'):
artifact.add_dir(
eval_env.get_wrapper_attr('workspace_path'),
name='evaluation_output/')
run.log_artifact(artifact)

# wandb has finished
run.finish()

# Auto delete
if conf.get('cloud'):
if conf['cloud'].get('auto_delete'):
print('Deleting remote container')
Expand Down
72 changes: 64 additions & 8 deletions sinergym/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,33 @@ class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):

def __init__(self,
env: EplusEnv,
epsilon: float = 1e-8):
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
automatic_update: bool = True,
epsilon: float = 1e-8,
mean: np.float64 = None,
var: np.float64 = None):
"""Initializes the NormalizationWrapper. Mean and var values can be None andbeing updated during interaction with environment.
Args:
env (Env): The environment to apply the wrapper
epsilon (float): A stability parameter that is used when scaling the observations. Defaults to 1e-8
env (EplusEnv): The environment to apply the wrapper.
automatic_update (bool, optional): Whether or not to update the mean and variance values automatically. Defaults to True.
epsilon (float, optional): A stability parameter used when scaling the observations. Defaults to 1e-8.
mean (np.float64, optional): The mean value used for normalization. Defaults to None.
var (np.float64, optional): The variance value used for normalization. Defaults to None.
"""
gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
gym.utils.RecordConstructorArgs.__init__(
self, epsilon=epsilon, mean=mean, var=var)
gym.Wrapper.__init__(self, env)
self.num_envs = 1
self.is_vector_env = False
self.automatic_update = automatic_update

self.unwrapped_observation = None
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.obs_rms.mean = mean if mean is not None else self.obs_rms.mean
self.obs_rms.var = var if var is not None else self.obs_rms.var
self.epsilon = epsilon

self.logger.info('wrapper initialized.')
self.logger.info('Wrapper initialized.')

def step(self, action):
"""Steps through the environment and normalizes the observation."""
Expand All @@ -101,9 +111,55 @@ def reset(self, **kwargs):

return self.normalize(np.array([obs]))[0], info

def deactivate_update(self):
"""
Deactivates the automatic update of the normalization wrapper.
After calling this method, the normalization wrapper will not update its calibration automatically.
"""
self.automatic_update = False

def activate_update(self):
"""
Activates the automatic update of the normalization wrapper.
After calling this method, the normalization wrapper will update its calibration automatically.
"""
self.automatic_update = True

@property
def mean(self):
"""Returns the mean value of the observations."""
return self.obs_rms.mean

@property
def var(self):
"""Returns the variance value of the observations."""
return self.obs_rms.var

def set_mean(self, mean: np.float64):
"""Sets the mean value of the observations."""
try:
assert len(mean) == self.observation_space.shape[0]
except AssertionError as err:
self.logger.error(
'Mean values must have the same shape than environment observation space.')
raise err
self.obs_rms.mean = mean

def set_var(self, var: np.float64):
"""Sets the variance value of the observations."""
try:
assert len(var) == self.observation_space.shape[0]
except AssertionError as err:
self.logger.error(
'Variance values must have the same shape than environment observation space.')
raise err
self.obs_rms.var = var

def normalize(self, obs):
"""Normalizes the observation using the running mean and variance of the observations."""
self.obs_rms.update(obs)
"""Normalizes the observation using the running mean and variance of the observations.
If automatic_update is enabled, the running mean and variance will be updated too."""
if self.automatic_update:
self.obs_rms.update(obs)
return (obs - self.obs_rms.mean) / \
np.sqrt(self.obs_rms.var + self.epsilon)

Expand Down
2 changes: 1 addition & 1 deletion sinergym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.0
3.3.1
43 changes: 43 additions & 0 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,49 @@ def test_discretize_wrapper(env_wrapper_discretize):
assert not hasattr(original_env, 'action_mapping')


def test_normalize_observation_wrapper(env_wrapper_normalization):

# Spaces
env = env_wrapper_normalization
assert not env.is_discrete
assert hasattr(env, 'unwrapped_observation')

# Normalization calibration
assert hasattr(env, 'mean')
old_mean = env.get_wrapper_attr('mean').copy()
assert hasattr(env, 'var')
old_var = env.get_wrapper_attr('var').copy()
assert len(env.get_wrapper_attr('mean')) == env.observation_space.shape[0]
assert len(env.get_wrapper_attr('var')) == env.observation_space.shape[0]

# reset
obs, _ = env.reset()

# Spaces
assert (obs != env.get_wrapper_attr('unwrapped_observation')).any()
assert env.observation_space.contains(
env.get_wrapper_attr('unwrapped_observation'))

# Calibration
assert (old_mean != env.get_wrapper_attr('mean')).any()
assert (old_var != env.get_wrapper_attr('var')).any()
old_mean = env.get_wrapper_attr('mean').copy()
old_var = env.get_wrapper_attr('var').copy()
env.get_wrapper_attr('deactivate_update')()
a = env.action_space.sample()
env.step(a)
assert (old_mean == env.get_wrapper_attr('mean')).all()
assert (old_var == env.get_wrapper_attr('var')).all()
env.get_wrapper_attr('activate_update')()
env.step(a)
assert (old_mean != env.get_wrapper_attr('mean')).any()
assert (old_var != env.get_wrapper_attr('var')).any()
env.get_wrapper_attr('set_mean')(old_mean)
env.get_wrapper_attr('set_var')(old_var)
assert (old_mean == env.get_wrapper_attr('mean')).all()
assert (old_var == env.get_wrapper_attr('var')).all()


def test_normalize_action_wrapper(env_normalize_action_wrapper):

env = env_normalize_action_wrapper
Expand Down

0 comments on commit 9dc5a90

Please sign in to comment.