
#update: Mixed model

killua 2024-06-21 21:31:48 +08:00
parent 55e2d7209c
commit 3f78fcb12c
3 changed files with 117 additions and 115 deletions

View File

@@ -8,10 +8,12 @@ from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import threading
import requests
import asyncio
import struct
import base64
import time
import json
# ----------- Initialize VITS ----------- #
@@ -85,7 +87,7 @@ class MINIMAX_LLM(LLM):
def __init__(self):
self.token = 0
async def chat(self, assistant, prompt):
def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# Temporary workaround; to be revised later
@@ -107,20 +109,19 @@ class MINIMAX_LLM(LLM):
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: # call the LLM
async for chunk in response.content.iter_any():
try:
chunk_msg = self.__parseChunk(chunk) # parse the LLM response
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) # call the LLM
for chunk in response.iter_lines():
try:
chunk_msg = self.__parseChunk(chunk) # parse the LLM response
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
def __parseChunk(self, llm_chunk):
@@ -150,38 +151,41 @@ class VOLCENGINE_LLM(LLM):
self.token = 0
self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# Temporary workaround; to be revised later
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'doubao-32k-lite'
def chat(self, assistant, prompt):
try:
llm_info = json.loads(assistant.llm_info)
model = self.__get_model(llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
stream = self.client.chat.completions.create(
model = model,
messages=messages,
stream=True,
temperature=llm_info['temperature'],
top_p=llm_info['top_p'],
max_tokens=llm_info['max_tokens'],
stream_options={'include_usage': True}
)
for chunk in stream:
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
# Temporary workaround; to be revised later
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'doubao-32k-lite'
model = self.__get_model(llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
stream = self.client.chat.completions.create(
model = model,
messages=messages,
stream=True,
temperature=llm_info['temperature'],
top_p=llm_info['top_p'],
max_tokens=llm_info['max_tokens'],
stream_options={'include_usage': True}
)
for chunk in stream:
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
except GeneratorExit:
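# the consumer stopped iterating, so close the underlying streaming HTTP response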
stream.response.close()
def __get_model(self, llm_info):
if llm_info['model'] == 'doubao-4k-lite':
@@ -206,36 +210,39 @@ class ZHIPU_LLM(LLM):
self.token = 0
self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
async def chat(self, assistant, prompt):
llm_info = json.loads(assistant.llm_info)
# Temporary workaround; to be revised later
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'glm-4-air'
def chat(self, assistant, prompt):
try:
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
stream = self.client.chat.completions.create(
model = llm_info['model'],
messages=messages,
stream=True,
temperature=llm_info['temperature'],
top_p=llm_info['top_p'],
max_tokens=llm_info['max_tokens']
)
for chunk in stream:
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
# Temporary workaround; to be revised later
user_info = json.loads(assistant.user_info)
if user_info['llm_type'] == "MIXED":
llm_info['model'] = 'glm-4-air'
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
stream = self.client.chat.completions.create(
model = llm_info['model'],
messages=messages,
stream=True,
temperature=llm_info['temperature'],
top_p=llm_info['top_p'],
max_tokens=llm_info['max_tokens']
)
for chunk in stream:
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"code":200,"msg":""}
assistant.token = self.token
if self.token > llm_info['max_tokens'] * 0.8: # reset the session if token usage exceeds 80% of the limit
msg_frame['code'] = '201'
assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
yield msg_frame
except GeneratorExit:
stream.response.close()
def __parseChunk(self, llm_chunk):
if llm_chunk.usage:
@@ -247,45 +254,38 @@ class ZHIPU_LLM(LLM):
class MIXED_LLM(LLM):
def __init__(self):
self.minimax = MINIMAX_LLM()
# self.minimax = MINIMAX_LLM()
self.volcengine = VOLCENGINE_LLM()
self.zhipu = ZHIPU_LLM()
async def chat(self, assistant, prompt):
minimax_result = self.minimax.chat(assistant, prompt)
volcengine_result = self.volcengine.chat(assistant, prompt)
zhipu_result = self.zhipu.chat(assistant, prompt)
minimax_task = asyncio.create_task(minimax_result.__anext__())
volcengine_task = asyncio.create_task(volcengine_result.__anext__())
zhipu_task = asyncio.create_task(zhipu_result.__anext__())
done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
first_task = done.pop()
if first_task == minimax_task:
logger.debug("使用MINIMAX模型")
yield await minimax_task
while True:
try:
yield await minimax_result.__anext__()
except StopAsyncIteration:
break
elif first_task == volcengine_task:
logger.debug("使用豆包模型")
yield await volcengine_task
while True:
try:
yield await volcengine_result.__anext__()
except StopAsyncIteration:
break
elif first_task == zhipu_task:
logger.debug("使用智谱模型")
yield await zhipu_task
while True:
try:
yield await zhipu_result.__anext__()
except StopAsyncIteration:
break
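# Run one backend's chat generator in a worker thread: take its first chunk and, unless a winner has already been signalled via `event`, publish the chunk together with the generator in result_container.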
def chat_wrapper(self,chat_func, assistant, prompt, event, result_container):
gen = chat_func(assistant, prompt)
first_chunk = next(gen)
if event.is_set():
raise GeneratorExit()
else:
result_container.append({
'first_chunk': first_chunk,
'gen': gen
})
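# Race the Volcengine and Zhipu backends in parallel threads and stream from whichever delivers a first chunk soonest.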
def chat(self, assistant, prompt):
event = threading.Event()
result_container = [] # results returned by the LLM backends
volc_thread = threading.Thread(target=self.chat_wrapper, args=(self.volcengine.chat, assistant, prompt, event, result_container))
zhipu_thread = threading.Thread(target=self.chat_wrapper, args=(self.zhipu.chat, assistant, prompt, event, result_container))
volc_thread.start()
zhipu_thread.start()
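# poll until one backend has produced its first chunk, then yield the remainder of that backend's stream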
while True:
if result_container:
yield result_container[0]['first_chunk']
for chunk in result_container[0]['gen']:
yield chunk
break
time.sleep(0.05)
volc_thread.join()
zhipu_thread.join()
class VITS_TTS(TTS):
def __init__(self):

View File

@@ -209,10 +209,11 @@ async def streaming_chat(ws: WebSocket):
llm_text = ""
logger.debug("开始进行ASR识别")
while len(asr_results)==0:
chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=1))
if assistant is None:
with get_db_context() as db: # the with statement closes the database connection automatically
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
logger.debug(f"接收到{assistant.name}的请求")
if assistant is None:
raise SessionNotFoundError()
user_info = json.loads(assistant.user_info)
@@ -230,7 +231,7 @@ async def streaming_chat(ws: WebSocket):
start_time = time.time()
is_first_response = True
async for llm_frame in llm_frames:
for llm_frame in llm_frames:
if is_first_response:
end_time = time.time()
logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")

View File

@@ -12,4 +12,5 @@ cn2an
numba
librosa
aiohttp
'volcengine-python-sdk[ark]'
'volcengine-python-sdk[ark]'
zhipuai