Skip to content

Commit

Permalink
Wrap sending message in ChainFunc's notify_user method
Browse files Browse the repository at this point in the history
- Wrap the sending of a message in the notify_user method of the ChainFunc class.
  • Loading branch information
sudoskys committed Nov 26, 2023
1 parent 83a5b67 commit 90332f0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
8 changes: 6 additions & 2 deletions llmkira/middleware/chain_box/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, uid: str = None):
@classmethod
def _prefix(cls, uuid: str) -> str:
"""auth:{auth_schema_version}:uuid"""
return f"auth:v1:{uuid}"
return f"auth:v2:{uuid}"

@classmethod
def from_form(cls, platform: str, user_id: str):
Expand Down Expand Up @@ -52,6 +52,7 @@ async def read_auth(self, uuid: str) -> Optional[Chain]:
logger.debug(f"Get Auth Data {_cache} {_cache}")
chain = Chain.from_redis(_cache)
if chain.is_expire:
logger.debug(f"Auth Expire {chain.uuid}")
return None
if chain.creator_uid != self.uid:
logger.debug(f"Not User {self.uid} Created Auth")
Expand All @@ -66,7 +67,7 @@ def __init__(self, uid: str):

def _prefix(self) -> str:
"""chain:{auth_schema_version}:uuid"""
return f"chain:v1:{self.uid}"
return f"chain:v2:{self.uid}"

async def add_task(self, chain: Chain) -> str:
"""
Expand All @@ -87,12 +88,15 @@ async def get_task(self) -> Optional[Chain]:
:return Optional[Chain]
"""
cache = global_cache_runtime.get_redis()
# FIXME 优化为获取对应空间的数据
# FIXME signas 应该添加 uuid
redis_raw = await cache.lpop_data(
key=self._prefix()
)
if not redis_raw:
return None
chain = Chain.from_redis(redis_raw)
if chain.is_expire:
logger.debug(f"Chain Expire {chain.uuid}")
return None
return chain
20 changes: 10 additions & 10 deletions llmkira/receiver/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import os

from aio_pika.abc import AbstractIncomingMessage
from loguru import logger

from llmkira.middleware.chain_box import Chain, AuthReloader, ChainReloader
from llmkira.middleware.env_virtual import EnvManager
from llmkira.schema import RawMessage
from llmkira.sdk.func_calling.register import ToolRegister
from llmkira.sdk.schema import TaskBatch
from llmkira.task import Task, TaskHeader
from loguru import logger


class ChainFunc(object):
Expand All @@ -28,7 +29,7 @@ async def reply_user(*,
receiver: TaskHeader.Location
):
"""
通知用户
包装一下发送消息
:param platform: Default should be `task.receiver.platform`
:param task: 任务 header
:param text: 文本 str
Expand Down Expand Up @@ -72,7 +73,7 @@ async def auth_chain(self,
meta.verify_uuid = meta.get_verify_uuid(task_batch=task_batch)
# 注册本地部署点
task_id = await AuthReloader(uid=_task_forward.receiver.uid).save_auth(
chain=Chain(
chain=Chain.create(
uuid=meta.verify_uuid,
creator_uid=_task_forward.receiver.uid,
channel=__receiver__,
Expand All @@ -81,7 +82,8 @@ async def auth_chain(self,
receiver=_task_forward.receiver,
task_meta=meta,
message=[]
)
),
expire=60 * 60 * 2,
)
)
await self.reply_user(platform=_task_forward.receiver.platform,
Expand Down Expand Up @@ -202,30 +204,28 @@ async def run_pending_task(task: TaskHeader, pending_task: TaskBatch):
task_batch=pending_task
)
return logger.info(f"[Resign Auth] \n--auth-require {pending_task.get_batch_name()} require.")
# 获取环境变量
# Get Env
_env_dict = await EnvManager.from_uid(uid=task.receiver.uid).get_env_list(name_list=_tool_obj.env_list)
assert isinstance(_env_dict, dict), "unexpected env dict? it should be dict..."
# 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息

# 必须提前注册,否则无法获取到结果
# Resign Chain
if (task.task_meta.resign_next_step or task.task_meta.is_complete(
num_end=1)) and not _tool_obj.repeatable:
# 路由
logger.debug(f"Function {pending_task.get_batch_name()} need resign chain")
await chain_func.resign_chain(
task=task,
parent_func=pending_task.get_batch_name(),
repeatable=_tool_obj.repeatable,
deploy_child=_tool_obj.deploy_child,
)
# 运行函数
# 运行函数, 传递模型的信息,以及上一条的结果的openai raw信息
run_result = await _tool_obj.load(task=task,
receiver=task.receiver,
arg=_arg,
env=_env_dict,
pending_task=pending_task,
refer_llm_result=task.task_meta.llm_result
)
# 更新任务状态
await task.task_meta.complete_task(task_batch=pending_task, run_result=run_result)
return run_result

Expand Down

0 comments on commit 90332f0

Please sign in to comment.