diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 228c5fa4..bde90c4f 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -265,7 +265,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1) + real_next_obs[idx] = np.concatenate( + [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 + ) step_data["dones"] = dones[np.newaxis] step_data["actions"] = actions[np.newaxis] diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index daf98a47..08dd9eab 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -203,7 +203,9 @@ def player( if "final_observation" in infos: for idx, final_obs in enumerate(infos["final_observation"]): if final_obs is not None: - real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1) + real_next_obs[idx] = np.concatenate( + [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 + ) step_data["dones"] = dones[np.newaxis] step_data["actions"] = actions[np.newaxis]