diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index e2a1beeab1..aaa65409d8 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -27,6 +27,7 @@ SchedulerGroup, ) from deepsparse.subgraph_execute import SubGraphExecutor +from deepsparse.tasks import SupportedTasks from deepsparse.utils import InferenceState, PipelineState from deepsparse.utils.subgraph import SubGraph from deepsparse.utils.time import TIMER_KEY, InferenceStages, TimerManager @@ -139,7 +140,10 @@ def create(cls, task: str, **kwargs) -> "Pipeline": "Pipeline was not created for the given task. The " "provided task should be registered using the OperatorRegistry" ) - except Exception: + except Exception as e: + if SupportedTasks.is_text_generation(task): + raise e + _LOGGER.warning(f"Could not create v2 '{task}' pipeline, trying legacy") from deepsparse.legacy import Pipeline diff --git a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py index df4e587df3..01d2a664b5 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py +++ b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py @@ -51,7 +51,10 @@ def can_operate(self, inp: Any) -> bool: if inp.get("in_generation"): return True - if kv_cache.total_num_processed_tokens >= kv_cache.capacity: + if ( + kv_cache.total_num_processed_tokens >= kv_cache.capacity + and inp.get("in_generation") is None + ): raise RuntimeError( "Not enough kv_cache capacity to run generation. Please use a larger " "sequence_length or a shorter prompt"