Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor the http executor node #5212

Merged
merged 15 commits into from
Jun 24, 2024
87 changes: 35 additions & 52 deletions api/core/helper/ssrf_proxy.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,48 @@
"""
Proxy requests to avoid SSRF
"""

import os

from httpx import get as _get
from httpx import head as _head
from httpx import options as _options
from httpx import patch as _patch
from httpx import post as _post
from httpx import put as _put
from requests import delete as _delete
import httpx

SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')

requests_proxies = {
'http': SSRF_PROXY_HTTP_URL,
'https': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None

httpx_proxies = {
proxies = {
'http://': SSRF_PROXY_HTTP_URL,
'https://': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None

def get(url, *args, **kwargs):
    """Send a GET request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``get`` call, whose response is returned.
    """
    response = _get(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def post(url, *args, **kwargs):
    """Send a POST request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``post`` call, whose response is returned.
    """
    response = _post(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def put(url, *args, **kwargs):
    """Send a PUT request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``put`` call, whose response is returned.
    """
    response = _put(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def patch(url, *args, **kwargs):
    """Send a PATCH request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``patch`` call, whose response is returned.
    """
    response = _patch(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def delete(url, *args, **kwargs):
    """Send a DELETE request through ``requests`` using ``requests_proxies``.

    Callers pass httpx-style keyword arguments; before the call they are
    rewritten as follows:

    * ``follow_redirects`` is always removed; when truthy its value is
      passed on as ``allow_redirects``.
    * ``timeout`` of ``None`` is dropped; a non-tuple value ``t`` becomes
      ``(t, t)``; a 1-tuple is unwrapped to its single element; a tuple
      longer than two is cut to its first two elements; tuples of length
      zero or two are passed through untouched.
    """
    if 'follow_redirects' in kwargs:
        follow = kwargs.pop('follow_redirects')
        if follow:
            kwargs['allow_redirects'] = follow

    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
        if timeout is None:
            del kwargs['timeout']
        elif not isinstance(timeout, tuple):
            kwargs['timeout'] = (timeout, timeout)
        elif len(timeout) == 1:
            kwargs['timeout'] = timeout[0]
        elif len(timeout) > 2:
            kwargs['timeout'] = (timeout[0], timeout[1])
        # tuples of length 0 or 2 are left exactly as supplied

    return _delete(url=url, *args, proxies=requests_proxies, **kwargs)

def head(url, *args, **kwargs):
    """Send a HEAD request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``head`` call, whose response is returned.
    """
    response = _head(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def options(url, *args, **kwargs):
    """Send an OPTIONS request with httpx, passing ``httpx_proxies`` as the proxy mapping.

    Extra positional and keyword arguments are forwarded unchanged to the
    underlying httpx ``options`` call, whose response is returned.
    """
    response = _options(url=url, *args, proxies=httpx_proxies, **kwargs)
    return response

def make_request(method, url, **kwargs):
    """Perform an HTTP request with ``httpx.request``, applying the configured SSRF proxy.

    Proxy selection, in order of precedence:

    1. ``SSRF_PROXY_ALL_URL`` set  -> passed as httpx's ``proxy`` argument.
    2. ``proxies`` mapping truthy  -> passed as httpx's ``proxies`` argument.
    3. otherwise                   -> no proxy argument is passed.

    All remaining keyword arguments are forwarded to ``httpx.request`` and
    its response is returned.
    """
    if SSRF_PROXY_ALL_URL:
        proxy_options = {'proxy': SSRF_PROXY_ALL_URL}
    elif proxies:
        proxy_options = {'proxies': proxies}
    else:
        proxy_options = {}

    return httpx.request(method=method, url=url, **proxy_options, **kwargs)


def get(url, **kwargs):
    """Issue a GET request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='GET', url=url, **kwargs)


def post(url, **kwargs):
    """Issue a POST request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='POST', url=url, **kwargs)


def put(url, **kwargs):
    """Issue a PUT request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='PUT', url=url, **kwargs)


def patch(url, **kwargs):
    """Issue a PATCH request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='PATCH', url=url, **kwargs)


def delete(url, **kwargs):
    """Issue a DELETE request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='DELETE', url=url, **kwargs)


def head(url, **kwargs):
    """Issue a HEAD request to ``url`` via ``make_request``; kwargs are forwarded as-is."""
    return make_request(method='HEAD', url=url, **kwargs)
116 changes: 37 additions & 79 deletions api/core/tools/tool/api_tool.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import json
from json import dumps
from os import getenv
from typing import Any, Union
from typing import Any
from urllib.parse import urlencode

import httpx
import requests

import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle
Expand All @@ -18,12 +16,14 @@
int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
)


class ApiTool(Tool):
api_bundle: ApiToolBundle

"""
Api tool
"""

def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
Expand All @@ -38,16 +38,17 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
runtime=Tool.Runtime(**runtime)
)

def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:

def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any],
format_only: bool = False) -> str:
"""
validate the credentials for Api tool
"""
# assemble validate request and request parameters
headers = self.assembling_request(parameters)

if format_only:
return
return ''

response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response
Expand All @@ -68,12 +69,12 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:

if 'api_key_header' in credentials:
api_key_header = credentials['api_key_header']

if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string')

if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic' and credentials['api_key_value']:
Expand All @@ -82,20 +83,20 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom':
pass

headers[api_key_header] = credentials['api_key_value']

needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")

if parameter.default is not None and parameter.name not in parameters:
parameters[parameter.name] = parameter.default

return headers

def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str:
def validate_and_parse_response(self, response: httpx.Response) -> str:
"""
validate the response
"""
Expand All @@ -112,23 +113,20 @@ def validate_and_parse_response(self, response: Union[httpx.Response, requests.R
return json.dumps(response)
except Exception as e:
return response.text
elif isinstance(response, requests.Response):
if not response.ok:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return 'Empty response from the tool, please check your parameters and try again.'
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps(response)
except Exception as e:
return response.text
else:
raise ValueError(f'Invalid response type {type(response)}')

def do_http_request(self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]) -> httpx.Response:

@staticmethod
def get_parameter_value(parameter, parameters):
if parameter['name'] in parameters:
return parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
return (parameter.get('schema', {}) or {}).get('default', '')

def do_http_request(self, url: str, method: str, headers: dict[str, Any],
parameters: dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
Expand All @@ -141,44 +139,17 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any], parame

# check parameters
for parameter in self.api_bundle.openapi.get('parameters', []):
value = self.get_parameter_value(parameter, parameters)
if parameter['in'] == 'path':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
path_params[parameter['name']] = value

elif parameter['in'] == 'query':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
params[parameter['name']] = value

elif parameter['in'] == 'cookie':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
cookies[parameter['name']] = value

elif parameter['in'] == 'header':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
headers[parameter['name']] = value

# check if there is a request body and handle it
Expand All @@ -203,41 +174,29 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any], parame
else:
body[name] = None
break

# replace path parameters
for name, value in path_params.items():
url = url.replace(f'{{{name}}}', f'{value}')

# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers:
if headers['Content-Type'] == 'application/json':
body = dumps(body)
body = json.dumps(body)
elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
body = urlencode(body)
else:
body = body

# do http request
if method == 'get':
response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'post':
response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'put':
response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'delete':
response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True)
elif method == 'patch':
response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'head':
response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'options':
response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)

if method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body,
timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
return response
else:
raise ValueError(f'Invalid http method {method}')

return response

def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10) -> Any:
raise ValueError(f'Invalid http method {self.method}')

def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]],
max_recursive=10) -> Any:
if max_recursive <= 0:
raise Exception("Max recursion depth reached")
for option in any_of or []:
Expand Down Expand Up @@ -322,4 +281,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe

# assemble invoke message
return self.create_text_message(response)

Loading