-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(Pipelines) : Smart Data Frame Pipeline #735
Changes from 13 commits
1c06347
385db0d
3d30afb
51e4142
55f9700
44f4bbe
bdeeb9f
1cb977b
09386c4
69fa4be
d5e9e03
d6a6ed6
9b209ce
3f3fc61
327b8e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Any | ||
from ...helpers.logger import Logger | ||
from ..base_logic_unit import BaseLogicUnit | ||
from ..pipeline_context import PipelineContext | ||
|
||
|
||
class CacheLookup(BaseLogicUnit):
    """
    Cache Lookup of Code Stage

    Returns the code stored in the cache for the current pipeline context
    when caching is enabled and the cache holds a truthy entry for it.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Look up previously stored code for the current pipeline context.

        :param input: Your input data (not read by this stage).
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (Logger): The logger for logging.
            - 'context' (PipelineContext): The execution context.

        :return: The cached code on a cache hit, otherwise None.
        """
        ctx: PipelineContext = kwargs.get("context")
        log: Logger = kwargs.get("logger")

        # Guard clauses: bail out as soon as a cache hit becomes impossible.
        if not ctx.config.enable_cache:
            return None
        if not ctx.cache:
            return None
        if not ctx.cache.get(ctx.cache.get_cache_key(ctx)):
            return None

        log.log("Using cached response")
        # The lookup is repeated through the tracker so that the hit is
        # recorded under the "cache_hit" tag.
        cached_code = ctx.query_exec_tracker.execute_func(
            ctx.cache.get,
            ctx.cache.get_cache_key(ctx),
            tag="cache_hit",
        )
        ctx.add_intermediate_value("is_present_in_cache", True)
        return cached_code
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Any | ||
from ..base_logic_unit import BaseLogicUnit | ||
from ..pipeline_context import PipelineContext | ||
|
||
|
||
class CachePopulation(BaseLogicUnit):
    """
    Cache Population Stage

    Stores the incoming code in the cache (when caching is enabled) and
    notifies the configured callback, then passes the code through unchanged.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Save the incoming code to the cache and hand it to the callback.

        :param input: The code to store.
        :param kwargs: A dictionary of keyword arguments.
            - 'context' (PipelineContext): The execution context.

        :return: The same code that was received as ``input``.
        """
        ctx: PipelineContext = kwargs.get("context")
        generated_code = input

        caching_active = ctx.config.enable_cache and ctx.cache
        if caching_active:
            ctx.cache.set(ctx.cache.get_cache_key(ctx), generated_code)

        code_listener = ctx.config.callback
        if code_listener is not None:
            code_listener.on_code(generated_code)

        return generated_code
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import logging | ||
import traceback | ||
from typing import Any, List | ||
from ...helpers.code_manager import CodeExecutionContext | ||
from ...helpers.logger import Logger | ||
from ..base_logic_unit import BaseLogicUnit | ||
from ..pipeline_context import PipelineContext | ||
from ...prompts.correct_error_prompt import CorrectErrorPrompt | ||
|
||
|
||
class CodeExecution(BaseLogicUnit):
    """
    Code Execution Stage

    Runs the incoming code through the pipeline's code manager. When the
    error correction framework is enabled, a failed attempt is followed by a
    request to the LLM for corrected code, which is then executed instead.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Execute the code, retrying with LLM-corrected code on failure.

        :param input: The code to run.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (Logger): The logger for logging.
            - 'context' (PipelineContext): The execution context.

        :return: The value returned by the code manager's ``execute_code``
            for the first attempt that succeeds.
        :raises Exception: Re-raises the execution error when the error
            correction framework is disabled or the attempts are exhausted.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        config = pipeline_context.config
        # ``max_retries`` bounds the total number of attempts. Always make at
        # least one, so a value of 0 does not silently skip execution and
        # return None.
        max_attempts = max(1, config.max_retries)

        code_to_run = input
        result = None
        for attempt in range(1, max_attempts + 1):
            try:
                # Execute the code
                code_context = CodeExecutionContext(
                    pipeline_context.get_intermediate_value("last_prompt_id"),
                    pipeline_context.get_intermediate_value("skills"),
                )
                result = pipeline_context.get_intermediate_value(
                    "code_manager"
                ).execute_code(
                    code=code_to_run,
                    context=code_context,
                )
                break

            except Exception:
                if (
                    not config.use_error_correction_framework
                    or attempt >= max_attempts
                ):
                    # Bare raise keeps the original exception and traceback.
                    raise

                logger.log(
                    f"Failed to execute code with a correction framework "
                    f"[retry number: {attempt}]",
                    level=logging.WARNING,
                )

                traceback_error = traceback.format_exc()
                # Send the code that actually produced this traceback
                # (``code_to_run``), not the original input, so the LLM sees
                # a matching code/error pair on every retry.
                (
                    code_to_run,
                    reasoning,
                    answer,
                ) = pipeline_context.query_exec_tracker.execute_func(
                    self._retry_run_code,
                    code_to_run,
                    pipeline_context,
                    logger,
                    traceback_error,
                )

                pipeline_context.add_intermediate_value("reasoning", reasoning)
                pipeline_context.add_intermediate_value("answer", answer)

        return result

    def _retry_run_code(
        self, code: str, context: PipelineContext, logger: Logger, e: Exception
    ) -> List:
        """
        Ask the LLM for a corrected version of code that failed to execute.

        Args:
            code (str): The python code that raised the error
            context (PipelineContext) : Pipeline Context
            logger (Logger) : Logger
            e (Exception): The error to report to the LLM. Note that
                ``execute`` passes the formatted traceback text here rather
                than the exception object.

        Returns (List): The value returned by ``llm.generate_code``. Its first
            item is the corrected code; ``execute`` unpacks it as
            ``(code, reasoning, answer)``.
        """

        logger.log(f"Failed with error: {e}. Retrying", logging.ERROR)

        default_values = {
            "engine": context.dfs[0].engine,
            "code": code,
            "error_returned": e,
        }
        # NOTE(review): it was suggested to move the ``get_prompt`` function
        # into this class and set only the variables CorrectErrorPrompt
        # needs. Deferred for now because moving it made some tests fail.
        error_correcting_instruction = context.get_intermediate_value("get_prompt")(
            "correct_error",
            default_prompt=CorrectErrorPrompt(),
            default_values=default_values,
        )

        result = context.config.llm.generate_code(error_correcting_instruction)
        if context.config.callback is not None:
            context.config.callback.on_code(result[0])

        return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@milind-sinaptik add a log message when it is skipped, for debugging purposes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done