diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 254ce9559c6..10a60a0a145 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -12,6 +12,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. +import collections import hashlib import inspect import json @@ -20,6 +21,7 @@ from typing import ( Any, ClassVar, + Counter, Dict, List, Optional, @@ -135,6 +137,7 @@ def __new__( ) cls.CONFIG_PARAMETER_NAME = arg cls.CONFIG_CLASS = arg_type + elif issubclass(arg_type, StepContext): if cls.CONTEXT_PARAMETER_NAME is not None: raise StepInterfaceError( @@ -163,14 +166,19 @@ def __new__( # tfx requires them to be unique # TODO [ENG-155]: Can we prefix inputs and outputs to avoid this # restriction? - shared_input_output_keys = set(cls.INPUT_SIGNATURE).intersection( - set(cls.OUTPUT_SIGNATURE) - ) - if shared_input_output_keys: + counter: Counter[str] = collections.Counter() + counter.update(list(cls.INPUT_SIGNATURE)) + counter.update(list(cls.OUTPUT_SIGNATURE)) + if cls.CONFIG_CLASS: + counter.update(list(cls.CONFIG_CLASS.__fields__.keys())) + + shared_keys = {k for k in counter.elements() if counter[k] > 1} + if shared_keys: raise StepInterfaceError( - f"There is an overlap in the input and output names of " - f"step '{name}': {shared_input_output_keys}. Please make " - f"sure that your input and output names are distinct." + f"The following keys are overlapping in the input, output and " + f"config parameter names of step '{name}': {shared_keys}. " + f"Please make sure that your input, output and config " + f"parameter names are unique." ) return cls