From 55e2d7209cae2cd57e2810e88bd8b28f07a5b981 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Thu, 20 Jun 2024 16:06:03 +0800 Subject: [PATCH] =?UTF-8?q?#update:=20=E6=B7=BB=E5=8A=A0mixed=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=96=B9=E6=A1=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/concrete.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 11 +++++++--- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/app/concrete.py b/app/concrete.py index eeb45b7..5cc4041 100644 --- a/app/concrete.py +++ b/app/concrete.py @@ -87,6 +87,12 @@ class MINIMAX_LLM(LLM): async def chat(self, assistant, prompt): llm_info = json.loads(assistant.llm_info) + + # 临时方案,后续需要修改 + user_info = json.loads(assistant.user_info) + if user_info['llm_type'] == "MIXED": + llm_info['model'] = 'abab6.5s-chat' + messages = json.loads(assistant.messages) messages.append({'role':'user','content':prompt}) payload = json.dumps({ #整理payload @@ -146,6 +152,12 @@ class VOLCENGINE_LLM(LLM): async def chat(self, assistant, prompt): llm_info = json.loads(assistant.llm_info) + + # 临时方案,后续需要修改 + user_info = json.loads(assistant.user_info) + if user_info['llm_type'] == "MIXED": + llm_info['model'] = 'doubao-32k-lite' + model = self.__get_model(llm_info) messages = json.loads(assistant.messages) messages.append({'role':'user','content':prompt}) @@ -196,6 +208,12 @@ class ZHIPU_LLM(LLM): async def chat(self, assistant, prompt): llm_info = json.loads(assistant.llm_info) + + # 临时方案,后续需要修改 + user_info = json.loads(assistant.user_info) + if user_info['llm_type'] == "MIXED": + llm_info['model'] = 'glm-4-air' + messages = json.loads(assistant.messages) messages.append({'role':'user','content':prompt}) stream = self.client.chat.completions.create( @@ -232,6 +250,42 @@ class MIXED_LLM(LLM): self.minimax = MINIMAX_LLM() self.volcengine = VOLCENGINE_LLM() self.zhipu = ZHIPU_LLM() + + async def chat(self, assistant, prompt): + 
async def chat(self, assistant, prompt):
    """Stream the reply of whichever backend LLM responds first.

    Kicks off the MINIMAX, Volcengine (豆包) and Zhipu chat streams
    concurrently, waits until one of them produces its first chunk, then
    forwards that backend's entire stream to the caller.

    Fixes vs. the original:
      * the two losing first-chunk tasks are cancelled and their async
        generators closed — the original left all three backends streaming
        (and billing) even though only one stream was consumed, leaking
        un-awaited tasks;
      * a winner whose stream is already exhausted (first ``__anext__``
        raised ``StopAsyncIteration``) now ends this generator cleanly
        instead of turning into ``RuntimeError`` inside the async generator;
      * the manual ``while True / __anext__`` drain is replaced by the
        idiomatic ``async for``.

    Args:
        assistant: assistant record; the backend ``chat`` implementations
            read its serialized llm_info / messages fields (opaque here).
        prompt: user utterance forwarded to every backend.

    Yields:
        Response chunks from the fastest backend, in order.
    """
    # (log label, open stream) per backend; labels reproduce the original
    # logger.debug messages byte-for-byte.
    sources = [
        ("使用MINIMAX模型", self.minimax.chat(assistant, prompt)),
        ("使用豆包模型", self.volcengine.chat(assistant, prompt)),
        ("使用智谱模型", self.zhipu.chat(assistant, prompt)),
    ]
    # Race the first chunk of each stream.
    first_chunk_tasks = [
        asyncio.create_task(stream.__anext__()) for _, stream in sources
    ]
    done, _pending = await asyncio.wait(
        first_chunk_tasks, return_when=asyncio.FIRST_COMPLETED
    )
    winner = done.pop()
    winner_idx = first_chunk_tasks.index(winner)

    # FIX: shut down the losers — cancel their pending first-chunk tasks
    # (retrieving the result so no "exception was never retrieved" warning
    # fires) and close their generators to release the underlying streams.
    for idx, task in enumerate(first_chunk_tasks):
        if idx == winner_idx:
            continue
        task.cancel()
        try:
            await task
        except (asyncio.CancelledError, StopAsyncIteration):
            pass
        except Exception:
            # Best-effort teardown: a failing loser must not kill the
            # winning stream.
            pass
        await sources[idx][1].aclose()

    label, stream = sources[winner_idx]
    logger.debug(label)
    try:
        first_chunk = winner.result()
    except StopAsyncIteration:
        return  # winning backend produced an empty stream
    yield first_chunk
    async for chunk in stream:
        yield chunk
tts_end_time = time.time() - logger.debug(f"TTS生成音频耗时:{tts_end_time-tts_start_time}s") agent.tts_audio_process(tts_audio) await ws.send_bytes(agent.encode(resp_msg, tts_audio)) logger.debug(f'websocket返回:{resp_msg}')