diff --git a/llmkira/middleware/chain_box/__init__.py b/llmkira/middleware/chain_box/__init__.py index 00009c1a3..e157d4a42 100644 --- a/llmkira/middleware/chain_box/__init__.py +++ b/llmkira/middleware/chain_box/__init__.py @@ -22,7 +22,7 @@ def __init__(self, uid: str = None): @classmethod def _prefix(cls, uuid: str) -> str: """auth:{auth_schema_version}:uuid""" - return f"auth:v1:{uuid}" + return f"auth:v2:{uuid}" @classmethod def from_form(cls, platform: str, user_id: str): @@ -52,6 +52,7 @@ async def read_auth(self, uuid: str) -> Optional[Chain]: logger.debug(f"Get Auth Data {_cache} {_cache}") chain = Chain.from_redis(_cache) if chain.is_expire: + logger.debug(f"Auth Expire {chain.uuid}") return None if chain.creator_uid != self.uid: logger.debug(f"Not User {self.uid} Created Auth") @@ -66,7 +67,7 @@ def __init__(self, uid: str): def _prefix(self) -> str: """chain:{auth_schema_version}:uuid""" - return f"chain:v1:{self.uid}" + return f"chain:v2:{self.uid}" async def add_task(self, chain: Chain) -> str: """ @@ -87,6 +88,8 @@ async def get_task(self) -> Optional[Chain]: :return Optional[Chain] """ cache = global_cache_runtime.get_redis() + # FIXME 优化为获取对应空间的数据 + # FIXME signas 应该添加 uuid redis_raw = await cache.lpop_data( key=self._prefix() ) @@ -94,5 +97,6 @@ async def get_task(self) -> Optional[Chain]: return None chain = Chain.from_redis(redis_raw) if chain.is_expire: + logger.debug(f"Chain Expire {chain.uuid}") return None return chain diff --git a/llmkira/receiver/function.py b/llmkira/receiver/function.py index 07b1838e5..9898ad2f1 100644 --- a/llmkira/receiver/function.py +++ b/llmkira/receiver/function.py @@ -10,13 +10,14 @@ import os from aio_pika.abc import AbstractIncomingMessage +from loguru import logger + from llmkira.middleware.chain_box import Chain, AuthReloader, ChainReloader from llmkira.middleware.env_virtual import EnvManager from llmkira.schema import RawMessage from llmkira.sdk.func_calling.register import ToolRegister from llmkira.sdk.schema import TaskBatch from llmkira.task import Task, TaskHeader -from loguru import logger class ChainFunc(object): @@ -28,7 +29,7 @@ async def reply_user(*, receiver: TaskHeader.Location ): """ - 通知用户 + 包装一下发送消息 :param platform: Default should be `task.receiver.platform` :param task: 任务 header :param text: 文本 str @@ -72,7 +73,7 @@ async def auth_chain(self, meta.verify_uuid = meta.get_verify_uuid(task_batch=task_batch) # 注册本地部署点 task_id = await AuthReloader(uid=_task_forward.receiver.uid).save_auth( - chain=Chain( + chain=Chain.create( uuid=meta.verify_uuid, creator_uid=_task_forward.receiver.uid, channel=__receiver__, @@ -81,7 +82,8 @@ async def auth_chain(self, receiver=_task_forward.receiver, task_meta=meta, message=[] - ) + ), + expire=60 * 60 * 2, ) ) await self.reply_user(platform=_task_forward.receiver.platform, @@ -202,15 +204,12 @@ async def run_pending_task(task: TaskHeader, pending_task: TaskBatch): task_batch=pending_task ) return logger.info(f"[Resign Auth] \n--auth-require {pending_task.get_batch_name()} require.") - # 获取环境变量 + # Get Env _env_dict = await EnvManager.from_uid(uid=task.receiver.uid).get_env_list(name_list=_tool_obj.env_list) assert isinstance(_env_dict, dict), "unexpected env dict? it should be dict..." - # 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息 - - # 必须提前注册,否则无法获取到结果 + # Resign Chain if (task.task_meta.resign_next_step or task.task_meta.is_complete( num_end=1)) and not _tool_obj.repeatable: - # 路由 logger.debug(f"Function {pending_task.get_batch_name()} need resign chain") await chain_func.resign_chain( task=task, @@ -218,7 +217,7 @@ async def run_pending_task(task: TaskHeader, pending_task: TaskBatch): repeatable=_tool_obj.repeatable, deploy_child=_tool_obj.deploy_child, ) - # 运行函数 + # 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息 run_result = await _tool_obj.load(task=task, receiver=task.receiver, arg=_arg, @@ -226,6 +225,7 @@ async def run_pending_task(task: TaskHeader, pending_task: TaskBatch): pending_task=pending_task, refer_llm_result=task.task_meta.llm_result ) + # 更新任务状态 await task.task_meta.complete_task(task_batch=pending_task, run_result=run_result) return run_result