
#update: Mixed model

killua 2024-06-21 21:31:48 +08:00
parent 55e2d7209c
commit 3f78fcb12c
3 changed files with 117 additions and 115 deletions
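In short: MINIMAX is temporarily disabled, each provider's chat() becomes a plain synchronous generator (requests with stream=True instead of aiohttp), and MIXED_LLM now races Volcengine and Zhipu on worker threads, streaming from whichever yields its first chunk first while a threading.Event tells the loser to stop. Below is a minimal, self-contained sketch of that racing pattern; the fake_* generators and their delays are invented for illustration, and the real code races the providers' chat() generators instead.

import threading
import time

def race(gen_funcs):
    # Start one thread per candidate generator and stream from the first
    # one to produce a chunk; losers see the event and shut themselves down.
    event = threading.Event()          # set once a winner is chosen
    winner = []                        # shared container: [(first_chunk, gen)]

    def wrapper(gen_func):
        gen = gen_func()
        first = next(gen)              # blocks until this candidate responds
        if event.is_set():             # another candidate already won
            gen.close()                # raises GeneratorExit inside gen
        else:
            event.set()
            winner.append((first, gen))

    threads = [threading.Thread(target=wrapper, args=(f,)) for f in gen_funcs]
    for t in threads:
        t.start()
    while not winner:                  # poll, as the commit does (50 ms)
        time.sleep(0.05)
    first, gen = winner[0]
    yield first
    yield from gen
    for t in threads:
        t.join()

def fake_fast():
    time.sleep(0.1)
    yield "fast-1"
    yield "fast-2"

def fake_slow():
    time.sleep(0.5)
    yield "slow-1"

for chunk in race([fake_fast, fake_slow]):
    print(chunk)                       # fast-1, then fast-2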

View File

@@ -8,10 +8,12 @@ from .exception import *
 from .dependency import get_logger
 from utils.vits_utils import TextToSpeech
 from config import Config
-import aiohttp
+import threading
+import requests
 import asyncio
 import struct
 import base64
+import time
 import json
 
 # ----------- initialize VITS ----------- #
@@ -85,7 +87,7 @@ class MINIMAX_LLM(LLM):
     def __init__(self):
         self.token = 0
 
-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
         # temporary solution, to be revised later
@@ -107,9 +109,8 @@ class MINIMAX_LLM(LLM):
             'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
             'Content-Type': 'application/json'
         }
-        async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: # call the LLM
-                async for chunk in response.content.iter_any():
+        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) # call the LLM
+        for chunk in response.iter_lines():
             try:
                 chunk_msg = self.__parseChunk(chunk) # parse the LLM response
                 msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
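Note the change in chunking unit here: aiohttp's iter_any() yields whatever bytes have arrived, while requests' iter_lines() yields one complete line per iteration, which suits an SSE-style stream where each event arrives as a single data: ... line. A hedged sketch of consuming such a stream; the data:-prefixed framing is an assumption about the provider's wire format, and the real parsing lives in __parseChunk, whose body this diff does not show.

import json
import requests

def stream_lines(url, headers, payload):
    # stream=True keeps the connection open so lines can be read as they arrive
    response = requests.post(url, headers=headers, data=payload, stream=True)
    for line in response.iter_lines():       # one complete line per iteration
        if not line:                         # skip keep-alive blank lines
            continue
        if line.startswith(b"data:"):        # assumed SSE-style framing
            yield json.loads(line[len(b"data:"):].strip())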
@@ -150,7 +151,8 @@ class VOLCENGINE_LLM(LLM):
         self.token = 0
         self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
+        try:
         llm_info = json.loads(assistant.llm_info)
         # temporary solution, to be revised later
@@ -182,6 +184,8 @@ class VOLCENGINE_LLM(LLM):
                 msg_frame['code'] = '201'
                 assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                 yield msg_frame
+        except GeneratorExit:
+            stream.response.close()
 
     def __get_model(self, llm_info):
         if llm_info['model'] == 'doubao-4k-lite':
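The new except GeneratorExit branch is what makes the mixed model's early abandonment safe: when a suspended generator is closed (or garbage-collected), Python raises GeneratorExit at its yield point, giving it one chance to release resources, here the provider's HTTP stream via stream.response.close(). A minimal demonstration of the mechanism:

def streaming():
    try:
        while True:
            yield "chunk"
    except GeneratorExit:
        # Runs when the consumer calls gen.close(); yielding again here
        # would raise RuntimeError, so just clean up and return.
        print("closing the underlying stream")

gen = streaming()
print(next(gen))   # chunk
gen.close()        # closing the underlying stream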
@@ -206,7 +210,8 @@ class ZHIPU_LLM(LLM):
         self.token = 0
         self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
+        try:
         llm_info = json.loads(assistant.llm_info)
         # temporary solution, to be revised later
@@ -236,6 +241,8 @@ class ZHIPU_LLM(LLM):
                 msg_frame['code'] = '201'
                 assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                 yield msg_frame
+        except GeneratorExit:
+            stream.response.close()
 
     def __parseChunk(self, llm_chunk):
         if llm_chunk.usage:
@@ -247,45 +254,38 @@ class ZHIPU_LLM(LLM):
 class MIXED_LLM(LLM):
     def __init__(self):
-        self.minimax = MINIMAX_LLM()
+        # self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
         self.zhipu = ZHIPU_LLM()
 
-    async def chat(self, assistant, prompt):
-        minimax_result = self.minimax.chat(assistant, prompt)
-        volcengine_result = self.volcengine.chat(assistant, prompt)
-        zhipu_result = self.zhipu.chat(assistant, prompt)
-        minimax_task = asyncio.create_task(minimax_result.__anext__())
-        volcengine_task = asyncio.create_task(volcengine_result.__anext__())
-        zhipu_task = asyncio.create_task(zhipu_result.__anext__())
-        done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
-        first_task = done.pop()
-        if first_task == minimax_task:
-            logger.debug("Using the MINIMAX model")
-            yield await minimax_task
-            while True:
-                try:
-                    yield await minimax_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == volcengine_task:
-            logger.debug("Using the Doubao model")
-            yield await volcengine_task
-            while True:
-                try:
-                    yield await volcengine_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == zhipu_task:
-            logger.debug("Using the Zhipu model")
-            yield await zhipu_task
-            while True:
-                try:
-                    yield await zhipu_result.__anext__()
-                except StopAsyncIteration:
-                    break
+    def chat_wrapper(self, chat_func, assistant, prompt, event, result_container):
+        gen = chat_func(assistant, prompt)
+        first_chunk = next(gen)
+        if event.is_set():
+            raise GeneratorExit()
+        else:
+            result_container.append({
+                'first_chunk': first_chunk,
+                'gen': gen
+            })
+
+    def chat(self, assistant, prompt):
+        event = threading.Event()
+        result_container = [] # results returned by the LLMs
+        volc_thread = threading.Thread(target=self.chat_wrapper, args=(self.volcengine.chat, assistant, prompt, event, result_container))
+        zhipu_thread = threading.Thread(target=self.chat_wrapper, args=(self.zhipu.chat, assistant, prompt, event, result_container))
+        volc_thread.start()
+        zhipu_thread.start()
+        while True:
+            if result_container:
+                yield result_container[0]['first_chunk']
+                for chunk in result_container[0]['gen']:
+                    yield chunk
+                break
+            time.sleep(0.05)
+        volc_thread.join()
+        zhipu_thread.join()
 
 class VITS_TTS(TTS):
     def __init__(self):
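Consumption stays generator-shaped, so callers are unaffected by the async-to-threads rewrite. A usage sketch; assistant and prompt stand in for the real objects built in the WebSocket route, and the frame fields match the msg_frame dicts above.

llm = MIXED_LLM()
for msg_frame in llm.chat(assistant, prompt):
    if msg_frame['code'] == 200:
        print(msg_frame['msg'], end='', flush=True)   # stream partial text
    if msg_frame['is_end']:
        break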

View File

@@ -209,10 +209,11 @@ async def streaming_chat(ws: WebSocket):
         llm_text = ""
         logger.debug("Starting ASR recognition")
         while len(asr_results)==0:
-            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
+            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=1))
             if assistant is None:
                 with get_db_context() as db: # use a with statement so the database connection is closed automatically
                     assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
+                    logger.debug(f"Received a request from {assistant.name}")
                 if assistant is None:
                     raise SessionNotFoundError()
         user_info = json.loads(assistant.user_info)
@@ -230,7 +231,7 @@ async def streaming_chat(ws: WebSocket):
         start_time = time.time()
         is_first_response = True
-        async for llm_frame in llm_frames:
+        for llm_frame in llm_frames:
             if is_first_response:
                 end_time = time.time()
                 logger.debug(f"First frame returned in {round(end_time-start_time,3)}s")

View File

@@ -13,3 +13,4 @@ numba
 librosa
 aiohttp
 'volcengine-python-sdk[ark]'
+zhipuai