diff --git a/modelscope_agent/tools/style_repaint.py b/modelscope_agent/tools/style_repaint.py index e5be865b..c46d0381 100644 --- a/modelscope_agent/tools/style_repaint.py +++ b/modelscope_agent/tools/style_repaint.py @@ -3,23 +3,24 @@ import json import requests +from modelscope_agent.tools.base import BaseTool, register_tool from modelscope_agent.tools.localfile2url_utils.localfile2url import \ get_upload_url -from modelscope_agent.tools.tool import Tool, ToolSchema -from pydantic import ValidationError from requests.exceptions import RequestException, Timeout MAX_RETRY_TIMES = 3 WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', '/tmp/ci_workspace') -class StyleRepaint(Tool): +@register_tool('style_repaint') +class StyleRepaint(BaseTool): description = '调用style_repaint api处理图片' name = 'style_repaint' parameters: list = [{ 'name': 'input.image_path', 'description': '用户上传的照片的相对路径', - 'required': True + 'required': True, + 'type': 'string' }, { 'name': 'input.style_index', 'description': '想要生成的风格化类型索引:\ @@ -28,38 +29,36 @@ class StyleRepaint(Tool): 2 二次元 \ 3 小清新 \ 4 未来科技 \ - 5 3D写实 \ + 5 国画古风 \ + 6 将军百战 \ + 7 炫彩卡通 \ + 8 清雅国风 \ + 9 喜迎新年 \ 用户输入数字指定风格', - 'required': True + 'required': True, + 'type': 'int' }] - def __init__(self, cfg={}): - self.cfg = cfg.get(self.name, {}) - # remote call - self.url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/image-generation/generation' - self.token = self.cfg.get('token', - os.environ.get('DASHSCOPE_API_KEY', '')) - assert self.token != '', 'dashscope api token must be acquired' - + def call(self, params: str, **kwargs) -> str: + params = self._verify_args(params) + if isinstance(params, str): + return 'Parameter Error' + remote_parsed_input = self._remote_parse_input(**params) try: - all_param = { - 'name': self.name, - 'description': self.description, - 'parameters': self.parameters - } - self.tool_schema = ToolSchema(**all_param) - except ValidationError: - raise ValueError(f'Error when parsing parameters of {self.name}') - - self._str = self.tool_schema.model_dump_json() - self._function = self.parse_pydantic_model_to_openai_function( - all_param) - - def __call__(self, *args, **kwargs): - remote_parsed_input = self._remote_parse_input(*args, **kwargs) - remote_parsed_input['input']['style_index'] = int( - remote_parsed_input['input']['style_index']) + remote_parsed_input['input']['style_index'] = int( + remote_parsed_input['input']['style_index']) + except ValueError: + raise ValueError( + 'Please reselect the style index or the corresponding style introduction' + ) remote_parsed_input = json.dumps(remote_parsed_input) + url = kwargs.get( + 'url', + 'https://dashscope.aliyuncs.com/api/v1/services/aigc/image-generation/generation' + ) + self.token = kwargs.get('token', + os.environ.get('DASHSCOPE_API_KEY', '')) + assert self.token != '', 'dashscope api token must be acquired' origin_result = None retry_times = MAX_RETRY_TIMES headers = { @@ -74,17 +73,13 @@ def __call__(self, *args, **kwargs): try: response = requests.request( - 'POST', - url=self.url, - headers=headers, - data=remote_parsed_input) + 'POST', url=url, headers=headers, data=remote_parsed_input) if response.status_code != requests.codes.ok: response.raise_for_status() origin_result = json.loads(response.content.decode('utf-8')) - self.final_result = self._parse_output( - origin_result, remote=True) + self.final_result = origin_result return self.get_stylerepaint_result() except Timeout: continue @@ -129,7 +124,7 @@ def _remote_parse_input(self, *args, **kwargs): return kwargs def get_result(self): - result_data = json.loads(json.dumps(self.final_result['result'])) + result_data = json.loads(json.dumps(self.final_result)) if 'task_id' in result_data['output']: task_id = result_data['output']['task_id'] get_url = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}' @@ -145,7 +140,7 @@ def get_result(self): response.raise_for_status() origin_result = json.loads(response.content.decode('utf-8')) - get_result = self._parse_output(origin_result, remote=True) + get_result = origin_result return get_result except Timeout: continue @@ -162,15 +157,14 @@ def get_stylerepaint_result(self): try: result = self.get_result() while True: - result_data = result.get('result', {}) + result_data = result output = result_data.get('output', {}) task_status = output.get('task_status', '') if task_status == 'SUCCEEDED': print('任务已完成') # 取出result里url的部分,提高url图片展示稳定性 - output_url = self._parse_output( - result['result']['output']['results'][0]) + output_url = result['output']['results'][0]['url'] return output_url elif task_status == 'FAILED': diff --git a/tests/tools/test_style_repaint.py b/tests/tools/test_style_repaint.py new file mode 100644 index 00000000..91300e1c --- /dev/null +++ b/tests/tools/test_style_repaint.py @@ -0,0 +1,35 @@ +import os + +from modelscope_agent.agent import Agent +from modelscope_agent.tools.style_repaint import StyleRepaint + +print(os.getcwd()) + +from modelscope_agent.prompts.role_play import RolePlay # NOQA + + +def test_style_repaint(): + # 图片默认上传到ci_workspace + params = """{'input.image_path': './WechatIMG139.jpg', 'input.style_index': 0}""" + + style_repaint = StyleRepaint() + res = style_repaint.call(params) + assert (res.startswith('http')) + + +def test_style_repaint_role(): + role_template = '你扮演一个绘画家,用尽可能丰富的描述调用工具绘制各种风格的图画。' + + llm_config = {'model': 'qwen-max', 'model_server': 'dashscope'} + + # input tool args + function_list = [{'name': 'style_repaint'}] + + bot = RolePlay( + function_list=function_list, llm=llm_config, instruction=role_template) + + response = bot.run('[上传文件WechatIMG139.jpg],我想要清雅国风') + text = '' + for chunk in response: + text += chunk + print(text)