From 8fad5696446073e42ea41489cff8fb6bd8ca868f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 31 Jan 2024 19:19:48 +0000 Subject: [PATCH 1/2] only check capacity condition during prefill; already have check in generation --- .../text_generation/autoregressive_preprocess_operator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py index df4e587df3..01d2a664b5 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py +++ b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py @@ -51,7 +51,10 @@ def can_operate(self, inp: Any) -> bool: if inp.get("in_generation"): return True - if kv_cache.total_num_processed_tokens >= kv_cache.capacity: + if ( + kv_cache.total_num_processed_tokens >= kv_cache.capacity + and inp.get("in_generation") is None + ): raise RuntimeError( "Not enough kv_cache capacity to run generation. 
Please use a larger " "sequence_length or a shorter prompt" From 55793164fe7d12244f5aabb82d0deb5c2c5686a4 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 31 Jan 2024 21:03:36 +0000 Subject: [PATCH 2/2] don't try v1 if running text gen; just raise error --- src/deepsparse/pipeline.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index e2a1beeab1..aaa65409d8 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -27,6 +27,7 @@ SchedulerGroup, ) from deepsparse.subgraph_execute import SubGraphExecutor +from deepsparse.tasks import SupportedTasks from deepsparse.utils import InferenceState, PipelineState from deepsparse.utils.subgraph import SubGraph from deepsparse.utils.time import TIMER_KEY, InferenceStages, TimerManager @@ -139,7 +140,10 @@ def create(cls, task: str, **kwargs) -> "Pipeline": "Pipeline was not created for the given task. The " "provided task should be registered using the OperatorRegistry" ) - except Exception: + except Exception as e: + if SupportedTasks.is_text_generation(task): + raise e + _LOGGER.warning(f"Could not create v2 '{task}' pipeline, trying legacy") from deepsparse.legacy import Pipeline