From 1098b1118503e2772b25e1bb4b9b93b3700a9adf Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 31 Aug 2022 14:27:40 +0200 Subject: [PATCH 1/3] initial Signed-off-by: Artur Niederfahrenhorst --- rllib/connectors/connector.py | 40 +++++++++++++++++++++++- rllib/connectors/tests/test_connector.py | 4 +++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py index b438fceec6f7e..364b4703e53a0 100644 --- a/rllib/connectors/connector.py +++ b/rllib/connectors/connector.py @@ -3,7 +3,7 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union import gym @@ -404,6 +404,44 @@ def __str__(self, indentation: int = 0): + [c.__str__(indentation + 4) for c in self.connectors] ) + def __getitem__(self, key: Union[str, int, type]): + """Returns a list of connectors that fit 'key'. + + If key is a number n, we return a list with the nth element of this pipeline. + If key is a Connector class or a string matching the class name of a + Connector class, we return a list of all connectors in this pipeline matching + the specified class. + + Args: + key: The key to index by + + Returns: The Connector at index `key`. + """ + # In case key is a class + if not isinstance(key, str): + if isinstance(key, slice): + raise NotImplementedError( + "Slicing of ConnectorPipeline is not " "supported." + ) + elif isinstance(key, int): + return [self.connectors[key]] + elif isinstance(key, type): + key = key.__name__ + else: + raise NotImplementedError( + "Indexing by {} not supported.".format(type(key)) + ) + + results = [] + for c in self.connectors: + if c.__class__.__name__ == key: + results.append(c) + + if len(results) == 0: + raise IndexError + + return results + @PublicAPI(stability="alpha") def register_connector(name: str, cls: Connector): diff --git a/rllib/connectors/tests/test_connector.py b/rllib/connectors/tests/test_connector.py index 0aea0fad5f4e6..2116eac8bad5f 100644 --- a/rllib/connectors/tests/test_connector.py +++ b/rllib/connectors/tests/test_connector.py @@ -51,6 +51,10 @@ def test_sanity_check(self): self.assertEqual(m.connectors[0].__class__.__name__, "Tom") self.assertEqual(m.connectors[1].__class__.__name__, "Mary") + self.assertEqual(m["Tom"], [m.connectors[0]]) + self.assertEqual(m[0], [m.connectors[0]]) + self.assertEqual(m[m.connectors[1].__class__], [m.connectors[1]]) + if __name__ == "__main__": import pytest From 29ca5b02d09ec44e2624da17a3976fc72d1caeae Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 31 Aug 2022 22:45:55 +0200 Subject: [PATCH 2/3] jun's nit Signed-off-by: Artur Niederfahrenhorst --- rllib/connectors/connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py index 364b4703e53a0..1dd8d0e4ee6ba 100644 --- a/rllib/connectors/connector.py +++ b/rllib/connectors/connector.py @@ -421,7 +421,7 @@ def __getitem__(self, key: Union[str, int, type]): if not isinstance(key, str): if isinstance(key, slice): raise NotImplementedError( - "Slicing of ConnectorPipeline is not " "supported." + "Slicing of ConnectorPipeline is currently not supported." ) elif isinstance(key, int): return [self.connectors[key]] @@ -429,7 +429,7 @@ def __getitem__(self, key: Union[str, int, type]): key = key.__name__ else: raise NotImplementedError( - "Indexing by {} not supported.".format(type(key)) + "Indexing by {} is currently not supported.".format(type(key)) ) results = [] @@ -438,7 +438,7 @@ def __getitem__(self, key: Union[str, int, type]): results.append(c) if len(results) == 0: - raise IndexError + raise [] return results From 946dd17de7218a42b6377bab0900eed5a6365d72 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 8 Sep 2022 12:49:44 +0200 Subject: [PATCH 3/3] return [] instead of raising it Signed-off-by: Artur Niederfahrenhorst --- rllib/connectors/connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py index 1dd8d0e4ee6ba..62390ba7b6c92 100644 --- a/rllib/connectors/connector.py +++ b/rllib/connectors/connector.py @@ -438,7 +438,7 @@ def __getitem__(self, key: Union[str, int, type]): results.append(c) if len(results) == 0: - raise [] + return [] return results