#update: add mixed model scheme

This commit is contained in:
parent 6354fe49f0
commit 55e2d7209c
@@ -87,6 +87,12 @@ class MINIMAX_LLM(LLM):
     async def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
+
+        # Temporary workaround; to be revised later
+        user_info = json.loads(assistant.user_info)
+        if user_info['llm_type'] == "MIXED":
+            llm_info['model'] = 'abab6.5s-chat'
+
         messages = json.loads(assistant.messages)
         messages.append({'role':'user','content':prompt})
         payload = json.dumps({  # assemble the payload
@@ -146,6 +152,12 @@ class VOLCENGINE_LLM(LLM):
     async def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
+
+        # Temporary workaround; to be revised later
+        user_info = json.loads(assistant.user_info)
+        if user_info['llm_type'] == "MIXED":
+            llm_info['model'] = 'doubao-32k-lite'
+
         model = self.__get_model(llm_info)
         messages = json.loads(assistant.messages)
         messages.append({'role':'user','content':prompt})
@@ -196,6 +208,12 @@ class ZHIPU_LLM(LLM):
     async def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
+
+        # Temporary workaround; to be revised later
+        user_info = json.loads(assistant.user_info)
+        if user_info['llm_type'] == "MIXED":
+            llm_info['model'] = 'glm-4-air'
+
         messages = json.loads(assistant.messages)
         messages.append({'role':'user','content':prompt})
         stream = self.client.chat.completions.create(
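The same temporary override is now pasted verbatim into all three providers, differing only in the pinned model name. A sketch of one way it could be factored out when the workaround is revisited (the mapping `MIXED_FALLBACK_MODELS` and the helper `apply_mixed_override` are hypothetical names, not part of this commit):

```python
import json

# Hypothetical mapping from provider name to the model pinned under MIXED
# mode; the values mirror the ones hard-coded in the hunks above.
MIXED_FALLBACK_MODELS = {
    'MINIMAX': 'abab6.5s-chat',
    'VOLCENGINE': 'doubao-32k-lite',
    'ZHIPU': 'glm-4-air',
}

def apply_mixed_override(assistant, llm_info: dict, provider: str) -> dict:
    """If the user is in MIXED mode, pin this provider's model to its
    designated lightweight variant; otherwise leave llm_info untouched."""
    user_info = json.loads(assistant.user_info)
    if user_info.get('llm_type') == 'MIXED':
        llm_info['model'] = MIXED_FALLBACK_MODELS[provider]
    return llm_info
```

Each provider's `chat` would then call `apply_mixed_override(assistant, llm_info, 'MINIMAX')` (and so on) instead of repeating the block.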
@@ -232,6 +250,42 @@ class MIXED_LLM(LLM):
         self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
         self.zhipu = ZHIPU_LLM()
+
+    async def chat(self, assistant, prompt):
+        minimax_result = self.minimax.chat(assistant, prompt)
+        volcengine_result = self.volcengine.chat(assistant, prompt)
+        zhipu_result = self.zhipu.chat(assistant, prompt)
+
+        minimax_task = asyncio.create_task(minimax_result.__anext__())
+        volcengine_task = asyncio.create_task(volcengine_result.__anext__())
+        zhipu_task = asyncio.create_task(zhipu_result.__anext__())
+
+        done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
+        first_task = done.pop()
+        if first_task == minimax_task:
+            logger.debug("Using the MINIMAX model")
+            yield await minimax_task
+            while True:
+                try:
+                    yield await minimax_result.__anext__()
+                except StopAsyncIteration:
+                    break
+        elif first_task == volcengine_task:
+            logger.debug("Using the Doubao model")
+            yield await volcengine_task
+            while True:
+                try:
+                    yield await volcengine_result.__anext__()
+                except StopAsyncIteration:
+                    break
+        elif first_task == zhipu_task:
+            logger.debug("Using the Zhipu model")
+            yield await zhipu_task
+            while True:
+                try:
+                    yield await zhipu_result.__anext__()
+                except StopAsyncIteration:
+                    break
+
 class VITS_TTS(TTS):
     def __init__(self):
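Note that `MIXED_LLM.chat` races the three providers but never cancels the two losers: their streams keep running (and consuming tokens) in the background, and the duplicated `if`/`elif` arms grow with every new provider. A minimal sketch of a generalized alternative that also tears the losers down, assuming each provider's `chat()` is an async generator as above (`race_streams` is a hypothetical helper, not in this commit):

```python
import asyncio

async def race_streams(*streams):
    """Race async generators: stream from whichever yields its first
    frame soonest, then cancel and close the losing generators."""
    tasks = {asyncio.create_task(s.__anext__()): s for s in streams}
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    winner_task = done.pop()
    winner = tasks[winner_task]

    # Cancel the losers' pending first frames, let the cancellations
    # land, then close the losing generators so connections are released.
    losers = [t for t in tasks if t is not winner_task]
    for task in losers:
        task.cancel()
    await asyncio.gather(*losers, return_exceptions=True)
    for stream in streams:
        if stream is not winner:
            await stream.aclose()

    yield await winner_task      # the winning first frame
    async for frame in winner:   # then the rest of that stream
        yield frame
```

With such a helper, `MIXED_LLM.chat` would reduce to roughly `async for frame in race_streams(self.minimax.chat(assistant, prompt), self.volcengine.chat(assistant, prompt), self.zhipu.chat(assistant, prompt)): yield frame`.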
@@ -258,6 +312,8 @@ class LLMFactory:
         return VOLCENGINE_LLM()
     if llm_type == 'ZHIPU':
         return ZHIPU_LLM()
+    if llm_type == 'MIXED':
+        return MIXED_LLM()

 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
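If `LLMFactory` mirrors the `TTSFactory` signature shown in the same hunk (a `create_llm(self, llm_type: str)` method; the name is inferred, not shown here), the new branch would be exercised as:

```python
llm = LLMFactory().create_llm('MIXED')  # inferred method name, mirroring create_tts
```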
main.py (11 changes)
@@ -226,14 +226,19 @@ async def streaming_chat(ws: WebSocket):
         agent.recorder.input_text = prompt
         logger.debug("Starting LLM call")
         llm_frames = await agent.chat(assistant, prompt)
+
+        start_time = time.time()
+        is_first_response = True
+
         async for llm_frame in llm_frames:
+            if is_first_response:
+                end_time = time.time()
+                logger.debug(f"First frame latency: {round(end_time-start_time,3)}s")
+                is_first_response = False
             resp_msgs = agent.llm_msg_process(llm_frame)
             for resp_msg in resp_msgs:
                 llm_text += resp_msg
-                tts_start_time = time.time()
                 tts_audio = agent.synthetize(assistant, resp_msg)
-                tts_end_time = time.time()
-                logger.debug(f"TTS audio generation time: {tts_end_time-tts_start_time}s")
                 agent.tts_audio_process(tts_audio)
                 await ws.send_bytes(agent.encode(resp_msg, tts_audio))
                 logger.debug(f'websocket sent: {resp_msg}')
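One detail worth noting about the timing added above: `time.time()` is wall-clock and can jump if the system clock is adjusted, while `time.perf_counter()` is the conventional monotonic clock for measuring intervals. A self-contained sketch of the same first-frame measurement (`fake_llm` is a stand-in for the real `llm_frames` generator):

```python
import asyncio
import time

async def fake_llm():
    """Stand-in for llm_frames: emits a few frames with some delay."""
    for i in range(3):
        await asyncio.sleep(0.1)
        yield f"frame-{i}"

async def main():
    start = time.perf_counter()  # monotonic, suited to interval timing
    is_first_response = True
    async for frame in fake_llm():
        if is_first_response:
            print(f"First frame latency: {time.perf_counter() - start:.3f}s")
            is_first_response = False

asyncio.run(main())
```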