forked from killua/TakwayDisplayPlatform
#update: 添加mixed模型方案
This commit is contained in:
parent
6354fe49f0
commit
55e2d7209c
|
@ -87,6 +87,12 @@ class MINIMAX_LLM(LLM):
|
|||
|
||||
async def chat(self, assistant, prompt):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
|
||||
# 临时方案,后续需要修改
|
||||
user_info = json.loads(assistant.user_info)
|
||||
if user_info['llm_type'] == "MIXED":
|
||||
llm_info['model'] = 'abab6.5s-chat'
|
||||
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
payload = json.dumps({ #整理payload
|
||||
|
@ -146,6 +152,12 @@ class VOLCENGINE_LLM(LLM):
|
|||
|
||||
async def chat(self, assistant, prompt):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
|
||||
# 临时方案,后续需要修改
|
||||
user_info = json.loads(assistant.user_info)
|
||||
if user_info['llm_type'] == "MIXED":
|
||||
llm_info['model'] = 'doubao-32k-lite'
|
||||
|
||||
model = self.__get_model(llm_info)
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
|
@ -196,6 +208,12 @@ class ZHIPU_LLM(LLM):
|
|||
|
||||
async def chat(self, assistant, prompt):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
|
||||
# 临时方案,后续需要修改
|
||||
user_info = json.loads(assistant.user_info)
|
||||
if user_info['llm_type'] == "MIXED":
|
||||
llm_info['model'] = 'glm-4-air'
|
||||
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
stream = self.client.chat.completions.create(
|
||||
|
@ -233,6 +251,42 @@ class MIXED_LLM(LLM):
|
|||
self.volcengine = VOLCENGINE_LLM()
|
||||
self.zhipu = ZHIPU_LLM()
|
||||
|
||||
async def chat(self, assistant, prompt):
    """Race the three backends and stream the reply of whichever answers first.

    ``self.minimax``, ``self.volcengine`` and ``self.zhipu`` are each asked
    for a reply through their own ``chat(assistant, prompt)`` async
    generator.  The first backend that successfully produces a frame wins:
    its first frame and every following frame are yielded unchanged.

    A backend that raises, or finishes without producing any frame, before
    a winner is known is skipped and the race continues with the remaining
    backends.  Every losing request is cancelled and every backend generator
    is closed before this generator finishes, including when the consumer
    stops iterating early.

    Args:
        assistant: Passed through untouched to each backend's ``chat``.
        prompt: Passed through untouched to each backend's ``chat``.

    Yields:
        The frames produced by the winning backend, in order.  Nothing is
        yielded when every backend finishes without producing a frame.

    Raises:
        Exception: The error of the last failing backend (in priority
            order) when no backend produced a frame and at least one failed.
    """
    # Declaration order doubles as the tie-break when several backends
    # deliver their first frame in the same event-loop iteration.
    candidates = (
        ("使用MINIMAX模型", self.minimax),
        ("使用豆包模型", self.volcengine),
        ("使用智谱模型", self.zhipu),
    )

    races = {}  # first-frame task -> (debug label, backend async generator)
    winner = None
    last_error = None
    try:
        for label, backend in candidates:
            stream = backend.chat(assistant, prompt)
            races[asyncio.create_task(stream.__anext__())] = (label, stream)

        pending = set(races)
        while pending and winner is None:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in races:  # dict order == priority order
                if task not in done or task.cancelled():
                    continue
                error = task.exception()
                if error is None:
                    winner = task
                    break
                # StopAsyncIteration only means "no frames"; it is not an
                # error worth reporting if another backend can still answer.
                if not isinstance(error, StopAsyncIteration):
                    last_error = error
                    logger.debug(
                        f"MIXED backend failed before its first frame: {error!r}"
                    )

        if winner is None:
            if last_error is not None:
                raise last_error
            return

        # Stop paying for the slower requests as soon as a winner is known.
        for task in pending:
            task.cancel()

        label, stream = races[winner]
        logger.debug(label)
        yield winner.result()
        async for frame in stream:
            yield frame
    finally:
        for task in races:
            if not task.done():
                task.cancel()
        # Wait for the cancellations to land (and retrieve any stored
        # exception) before closing: aclose() on a generator that is still
        # executing raises RuntimeError.
        await asyncio.gather(*races, return_exceptions=True)
        for _, stream in races.values():
            try:
                await stream.aclose()
            except Exception as close_error:
                # Best-effort cleanup must not mask the reply or the real error.
                logger.debug(
                    f"MIXED backend stream close failed: {close_error!r}"
                )
|
||||
|
||||
class VITS_TTS(TTS):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -258,6 +312,8 @@ class LLMFactory:
|
|||
return VOLCENGINE_LLM()
|
||||
if llm_type == 'ZHIPU':
|
||||
return ZHIPU_LLM()
|
||||
if llm_type == 'MIXED':
|
||||
return MIXED_LLM()
|
||||
|
||||
class TTSFactory:
|
||||
def create_tts(self,tts_type:str) -> TTS:
|
||||
|
|
11
main.py
11
main.py
|
@ -226,14 +226,19 @@ async def streaming_chat(ws: WebSocket):
|
|||
agent.recorder.input_text = prompt
|
||||
logger.debug("开始调用大模型")
|
||||
llm_frames = await agent.chat(assistant, prompt)
|
||||
|
||||
start_time = time.time()
|
||||
is_first_response = True
|
||||
|
||||
async for llm_frame in llm_frames:
|
||||
if is_first_response:
|
||||
end_time = time.time()
|
||||
logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
|
||||
is_first_response = False
|
||||
resp_msgs = agent.llm_msg_process(llm_frame)
|
||||
for resp_msg in resp_msgs:
|
||||
llm_text += resp_msg
|
||||
tts_start_time = time.time()
|
||||
tts_audio = agent.synthetize(assistant, resp_msg)
|
||||
tts_end_time = time.time()
|
||||
logger.debug(f"TTS生成音频耗时:{tts_end_time-tts_start_time}s")
|
||||
agent.tts_audio_process(tts_audio)
|
||||
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
||||
logger.debug(f'websocket返回:{resp_msg}')
|
||||
|
|
Loading…
Reference in New Issue