#update: 添加mixed模型方案

This commit is contained in:
killua4396 2024-06-20 16:06:03 +08:00
parent 6354fe49f0
commit 55e2d7209c
2 changed files with 64 additions and 3 deletions

View File

@ -87,6 +87,12 @@ class MINIMAX_LLM(LLM):
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# 临时方案,后续需要修改
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'abab6.5s-chat'
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
payload = json.dumps({ #整理payload
@ -146,6 +152,12 @@ class VOLCENGINE_LLM(LLM):
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# 临时方案,后续需要修改
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'doubao-32k-lite'
model = self.__get_model(llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
@ -196,6 +208,12 @@ class ZHIPU_LLM(LLM):
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# 临时方案,后续需要修改
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'glm-4-air'
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
stream = self.client.chat.completions.create(
@ -232,6 +250,42 @@ class MIXED_LLM(LLM):
self.minimax = MINIMAX_LLM()
self.volcengine = VOLCENGINE_LLM()
self.zhipu = ZHIPU_LLM()
async def chat(self, assistant, prompt):
    """Race all three LLM backends and stream from whichever answers first.

    Starts the MINIMAX, Volcengine (Doubao) and Zhipu chat streams
    concurrently, waits for the first backend to produce its first frame,
    then yields that frame and every subsequent frame from that backend
    only.

    Args:
        assistant: assistant object holding llm/user info and message history
                   (passed through to each backend's ``chat``).
        prompt: the user prompt string for this turn.

    Yields:
        Response frames from the fastest backend's stream.
    """
    # One task per backend racing to deliver the first frame; the dict maps
    # each first-frame task back to its (log label, stream) so the winner
    # can be identified without three copy-pasted branches.
    streams = {
        "MINIMAX模型": self.minimax.chat(assistant, prompt),
        "豆包模型": self.volcengine.chat(assistant, prompt),
        "智谱模型": self.zhipu.chat(assistant, prompt),
    }
    tasks = {
        asyncio.create_task(stream.__anext__()): (label, stream)
        for label, stream in streams.items()
    }
    done, pending = await asyncio.wait(
        list(tasks), return_when=asyncio.FIRST_COMPLETED
    )
    winner = done.pop()
    label, stream = tasks[winner]

    # BUGFIX: the original abandoned the losing tasks and left their async
    # generators open — leaked coroutines and wasted backend quota. Cancel
    # the still-pending first-frame tasks, let the cancellations settle,
    # then close every losing stream explicitly.
    for task in pending:
        task.cancel()
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
    for other_label, other_stream in tasks.values():
        if other_stream is not stream:
            await other_stream.aclose()

    logger.debug(f"使用{label}")
    try:
        first_frame = await winner
    except StopAsyncIteration:
        # BUGFIX: an empty winning stream would have raised
        # StopAsyncIteration inside this async generator (a RuntimeError
        # since PEP 479); treat it as "no frames" instead.
        return
    yield first_frame
    # Drain the rest of the winning stream.
    while True:
        try:
            yield await stream.__anext__()
        except StopAsyncIteration:
            break
class VITS_TTS(TTS):
def __init__(self):
@ -258,6 +312,8 @@ class LLMFactory:
return VOLCENGINE_LLM()
if llm_type == 'ZHIPU':
return ZHIPU_LLM()
if llm_type == 'MIXED':
return MIXED_LLM()
class TTSFactory:
def create_tts(self,tts_type:str) -> TTS:

11
main.py
View File

@ -226,14 +226,19 @@ async def streaming_chat(ws: WebSocket):
agent.recorder.input_text = prompt
logger.debug("开始调用大模型")
llm_frames = await agent.chat(assistant, prompt)
start_time = time.time()
is_first_response = True
async for llm_frame in llm_frames:
if is_first_response:
end_time = time.time()
logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
is_first_response = False
resp_msgs = agent.llm_msg_process(llm_frame)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_start_time = time.time()
tts_audio = agent.synthetize(assistant, resp_msg)
tts_end_time = time.time()
logger.debug(f"TTS生成音频耗时{tts_end_time-tts_start_time}s")
agent.tts_audio_process(tts_audio)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
logger.debug(f'websocket返回{resp_msg}')