Skip to content

Commit

Permalink
[RLlib] Make Connector pipeline indexed and individual connetors gett…
Browse files Browse the repository at this point in the history
…able by int or class (or classname). (#28202)
  • Loading branch information
ArturNiederfahrenhorst authored Sep 12, 2022
1 parent de820c6 commit 7f03368
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
40 changes: 39 additions & 1 deletion rllib/connectors/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

import gym

Expand Down Expand Up @@ -404,6 +404,44 @@ def __str__(self, indentation: int = 0):
+ [c.__str__(indentation + 4) for c in self.connectors]
)

def __getitem__(self, key: Union[str, int, type]):
"""Returns a list of connectors that fit 'key'.
If key is a number n, we return a list with the nth element of this pipeline.
If key is a Connector class or a string matching the class name of a
Connector class, we return a list of all connectors in this pipeline matching
the specified class.
Args:
key: The key to index by
Returns: The Connector at index `key`.
"""
# In case key is a class
if not isinstance(key, str):
if isinstance(key, slice):
raise NotImplementedError(
"Slicing of ConnectorPipeline is currently not supported."
)
elif isinstance(key, int):
return [self.connectors[key]]
elif isinstance(key, type):
key = key.__name__
else:
raise NotImplementedError(
"Indexing by {} is currently not supported.".format(type(key))
)

results = []
for c in self.connectors:
if c.__class__.__name__ == key:
results.append(c)

if len(results) == 0:
return []

return results


@PublicAPI(stability="alpha")
def register_connector(name: str, cls: Connector):
Expand Down
4 changes: 4 additions & 0 deletions rllib/connectors/tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def test_sanity_check(self):
self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
self.assertEqual(m.connectors[1].__class__.__name__, "Mary")

self.assertEqual(m["Tom"], [m.connectors[0]])
self.assertEqual(m[0], [m.connectors[0]])
self.assertEqual(m[m.connectors[1].__class__], [m.connectors[1]])


if __name__ == "__main__":
import pytest
Expand Down

0 comments on commit 7f03368

Please sign in to comment.