-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathcode_execution.py
120 lines (99 loc) · 3.81 KB
/
code_execution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import traceback
from typing import Any, List
from ...helpers.code_manager import CodeExecutionContext
from ...helpers.logger import Logger
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext
from ...prompts.correct_error_prompt import CorrectErrorPrompt
class CodeExecution(BaseLogicUnit):
    """
    Code Execution Stage.

    Runs the generated code through the pipeline's code manager. When the
    error correction framework is enabled, a failing attempt is sent back to
    the LLM for correction and the corrected code is tried again, for at most
    ``config.max_retries`` attempts in total (always at least one).
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Execute the generated code, retrying with LLM-corrected code on failure.

        :param input: The code to execute.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (Logger): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (PipelineContext): The execution context; supplies the
              config, the code manager and the intermediate values.
        :return: The value returned by the code manager's ``execute_code``.
        :raises Exception: Re-raises the execution error when error correction
            is disabled or when the last allowed attempt fails.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        code = input
        code_to_run = code
        retry_count = 0

        # Always make at least one attempt: previously a max_retries <= 0
        # skipped execution entirely and silently returned None. The loop ends
        # either by returning a result or by re-raising once the retry budget
        # (max_retries attempts in total) is exhausted.
        while True:
            try:
                # Execute the code
                code_context = CodeExecutionContext(
                    pipeline_context.get_intermediate_value("last_prompt_id"),
                    pipeline_context.get_intermediate_value("skills"),
                )
                return pipeline_context.get_intermediate_value(
                    "code_manager"
                ).execute_code(
                    code=code_to_run,
                    context=code_context,
                )
            except Exception:
                if (
                    not pipeline_context.config.use_error_correction_framework
                    or retry_count >= pipeline_context.config.max_retries - 1
                ):
                    # Bare raise keeps the original traceback intact.
                    raise

                retry_count += 1
                logger.log(
                    f"Failed to execute code with a correction framework "
                    f"[retry number: {retry_count}]",
                    level=logging.WARNING,
                )

                traceback_error = traceback.format_exc()
                # Ask for a correction of the code that actually failed: the
                # traceback belongs to `code_to_run`, which differs from the
                # original `code` after the first retry.
                [
                    code_to_run,
                    reasoning,
                    answer,
                ] = pipeline_context.query_exec_tracker.execute_func(
                    self._retry_run_code,
                    code_to_run,
                    pipeline_context,
                    logger,
                    traceback_error,
                )
                pipeline_context.add_intermediate_value("reasoning", reasoning)
                pipeline_context.add_intermediate_value("answer", answer)

    def _retry_run_code(
        self, code: str, context: PipelineContext, logger: Logger, e: Any
    ) -> List:
        """
        Ask the LLM to correct code that failed to execute.

        Args:
            code (str): The python code that failed.
            context (PipelineContext): Pipeline context; provides the
                dataframes, the "get_prompt" factory and the config (LLM and
                optional callback).
            logger (Logger): Logger used to report the failure.
            e: Details of the failure, shown to the LLM as "error_returned".
                ``execute`` passes the formatted traceback string.

        Returns:
            List: Whatever ``context.config.llm.generate_code`` returns. Its
            first element is treated as the corrected code (passed to the
            callback); ``execute`` unpacks it as ``[code, reasoning, answer]``.
        """
        logger.log(f"Failed with error: {e}. Retrying", logging.ERROR)

        # NOTE(review): assumes at least one dataframe is loaded; an empty
        # `context.dfs` raises IndexError here.
        default_values = {
            "engine": context.dfs[0].engine,
            "code": code,
            "error_returned": e,
        }
        error_correcting_instruction = context.get_intermediate_value("get_prompt")(
            "correct_error",
            default_prompt=CorrectErrorPrompt(),
            default_values=default_values,
        )

        result = context.config.llm.generate_code(error_correcting_instruction)
        if context.config.callback is not None:
            context.config.callback.on_code(result[0])

        return result