forked from killua/TakwayDisplayPlatform
#update: Mixed model
parent 55e2d7209c
commit 3f78fcb12c
app/concrete.py (224 changed lines)

@@ -8,10 +8,12 @@ from .exception import *
 from .dependency import get_logger
 from utils.vits_utils import TextToSpeech
 from config import Config
-import aiohttp
+import threading
+import requests
 import asyncio
 import struct
 import base64
+import time
 import json

 # ----------- initialize vits ----------- #
@@ -85,7 +87,7 @@ class MINIMAX_LLM(LLM):
     def __init__(self):
         self.token = 0

-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)

         # temporary workaround, to be revised later
@@ -107,20 +109,19 @@ class MINIMAX_LLM(LLM):
             'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
             'Content-Type': 'application/json'
         }
-        async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: # call the LLM
-                async for chunk in response.content.iter_any():
-                    try:
-                        chunk_msg = self.__parseChunk(chunk) # parse the LLM response
-                        msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                        yield msg_frame
-                    except LLMResponseEnd:
-                        msg_frame = {"is_end":True,"code":200,"msg":""}
-                        assistant.token = self.token
-                        if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
-                            msg_frame['code'] = '201'
-                            assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                        yield msg_frame
+        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) # call the LLM
+        for chunk in response.iter_lines():
+            try:
+                chunk_msg = self.__parseChunk(chunk) # parse the LLM response
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                assistant.token = self.token
+                if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
+                    msg_frame['code'] = '201'
+                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                yield msg_frame


     def __parseChunk(self, llm_chunk):
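
As a side note, the blocking streaming pattern MINIMAX_LLM.chat switches to above (requests.post(..., stream=True) plus iter_lines()) can be sketched in isolation. This is a minimal, hypothetical example: the url/headers/payload arguments and the "data:" line framing are placeholders, not the actual MiniMax endpoint details, which __parseChunk handles in the real code.

import json
import requests

def stream_sse(url, headers, payload):
    # One blocking POST; iter_lines() yields response lines as they arrive,
    # instead of awaiting aiohttp chunks as the removed code did.
    response = requests.post(url, headers=headers, data=payload, stream=True)
    for line in response.iter_lines():
        if not line:
            continue                    # skip keep-alive / separator lines
        if line.startswith(b"data:"):   # assumed SSE-style framing, one JSON object per event
            yield json.loads(line[len(b"data:"):])
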
@@ -150,38 +151,41 @@ class VOLCENGINE_LLM(LLM):
         self.token = 0
         self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)

-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # temporary workaround, to be revised later
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'doubao-32k-lite'
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)

-        model = self.__get_model(llm_info)
-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model = model,
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens'],
-            stream_options={'include_usage': True}
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+            # temporary workaround, to be revised later
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'doubao-32k-lite'
+
+            model = self.__get_model(llm_info)
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model = model,
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens'],
+                stream_options={'include_usage': True}
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            stream.response.close()

     def __get_model(self, llm_info):
         if llm_info['model'] == 'doubao-4k-lite':
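
The try/except GeneratorExit wrapper added above (and mirrored in ZHIPU_LLM.chat below) is the standard cleanup pattern for abandoned generators: when the consumer stops iterating, Python throws GeneratorExit into the generator at its paused yield, and the handler gets a chance to close the underlying HTTP stream. A minimal runnable sketch with a stand-in stream object (FakeStream is hypothetical, not the Ark/ZhipuAI SDK):

class FakeStream:
    # stand-in for the SDK's streaming response object
    def __init__(self):
        self.closed = False
    def __iter__(self):
        return iter(["chunk-1", "chunk-2", "chunk-3"])
    def close(self):
        self.closed = True

def streaming_chat(stream):
    try:
        for chunk in stream:
            yield chunk
    except GeneratorExit:
        stream.close()   # release the connection when the caller abandons the generator early
        raise

s = FakeStream()
g = streaming_chat(s)
print(next(g))   # chunk-1
g.close()        # throws GeneratorExit into the generator, which closes the stream
print(s.closed)  # True
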
@@ -206,36 +210,39 @@ class ZHIPU_LLM(LLM):
         self.token = 0
         self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)

-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # temporary workaround, to be revised later
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'glm-4-air'
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)

-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model = llm_info['model'],
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens']
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+            # temporary workaround, to be revised later
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'glm-4-air'
+
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model = llm_info['model'],
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens']
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: # if token usage exceeds 80%, reset the session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            stream.response.close()

     def __parseChunk(self, llm_chunk):
         if llm_chunk.usage:
@@ -247,45 +254,38 @@ class ZHIPU_LLM(LLM):

 class MIXED_LLM(LLM):
     def __init__(self):
-        self.minimax = MINIMAX_LLM()
+        # self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
         self.zhipu = ZHIPU_LLM()

-    async def chat(self, assistant, prompt):
-        minimax_result = self.minimax.chat(assistant, prompt)
-        volcengine_result = self.volcengine.chat(assistant, prompt)
-        zhipu_result = self.zhipu.chat(assistant, prompt)
-
-        minimax_task = asyncio.create_task(minimax_result.__anext__())
-        volcengine_task = asyncio.create_task(volcengine_result.__anext__())
-        zhipu_task = asyncio.create_task(zhipu_result.__anext__())
-
-        done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
-        first_task = done.pop()
-        if first_task == minimax_task:
-            logger.debug("使用MINIMAX模型")
-            yield await minimax_task
-            while True:
-                try:
-                    yield await minimax_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == volcengine_task:
-            logger.debug("使用豆包模型")
-            yield await volcengine_task
-            while True:
-                try:
-                    yield await volcengine_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == zhipu_task:
-            logger.debug("使用智谱模型")
-            yield await zhipu_task
-            while True:
-                try:
-                    yield await zhipu_result.__anext__()
-                except StopAsyncIteration:
-                    break
+    def chat_wrapper(self,chat_func, assistant, prompt, event, result_container):
+        gen = chat_func(assistant, prompt)
+        first_chunk = next(gen)
+        if event.is_set():
+            raise GeneratorExit()
+        else:
+            result_container.append({
+                'first_chunk': first_chunk,
+                'gen': gen
+            })
+
+    def chat(self, assistant, prompt):
+        event = threading.Event()
+        result_container = [] # LLM results
+        volc_thread = threading.Thread(target=self.chat_wrapper, args=(self.volcengine.chat, assistant, prompt, event, result_container))
+        zhipu_thread = threading.Thread(target=self.chat_wrapper, args=(self.zhipu.chat, assistant, prompt, event, result_container))
+        volc_thread.start()
+        zhipu_thread.start()
+        while True:
+            if result_container:
+                yield result_container[0]['first_chunk']
+                for chunk in result_container[0]['gen']:
+                    yield chunk
+                break
+            time.sleep(0.05)
+        volc_thread.join()
+        zhipu_thread.join()
+

 class VITS_TTS(TTS):
     def __init__(self):
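
The reworked MIXED_LLM above replaces the asyncio.wait race with a thread-based one: each backend's synchronous chat generator runs in its own thread, chat_wrapper records the first chunk together with the live generator, and chat() streams from whichever backend answered first while polling result_container. A self-contained sketch of that pattern, with fake_fast/fake_slow as stand-ins for the real VOLCENGINE_LLM/ZHIPU_LLM generators (details such as the event handling differ from the committed code):

import threading
import time

def fake_fast(prompt):
    # stand-in for a streaming LLM generator
    time.sleep(0.1)
    yield "fast:first"
    yield "fast:rest"

def fake_slow(prompt):
    time.sleep(0.5)
    yield "slow:first"
    yield "slow:rest"

def race(prompt, *chat_funcs):
    result_container = []        # filled by whichever backend produces a chunk first
    event = threading.Event()    # set once a winner is chosen; late backends discard their stream

    def wrapper(chat_func):
        gen = chat_func(prompt)
        first_chunk = next(gen)  # block until this backend yields its first chunk
        if event.is_set():
            gen.close()          # another backend already won
        else:
            result_container.append({'first_chunk': first_chunk, 'gen': gen})

    threads = [threading.Thread(target=wrapper, args=(f,)) for f in chat_funcs]
    for t in threads:
        t.start()
    while not result_container:  # poll, as MIXED_LLM.chat does
        time.sleep(0.05)
    event.set()
    winner = result_container[0]
    yield winner['first_chunk']
    yield from winner['gen']
    for t in threads:
        t.join()

print(list(race("hi", fake_fast, fake_slow)))  # ['fast:first', 'fast:rest']
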
main.py (5 changed lines)

@@ -209,10 +209,11 @@ async def streaming_chat(ws: WebSocket):
         llm_text = ""
         logger.debug("开始进行ASR识别")
         while len(asr_results)==0:
-            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
+            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=1))
             if assistant is None:
                 with get_db_context() as db: # use a with statement for the DB connection so it closes automatically
                     assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
+                    logger.debug(f"接收到{assistant.name}的请求")
                 if assistant is None:
                     raise SessionNotFoundError()
             user_info = json.loads(assistant.user_info)
@@ -230,7 +231,7 @@ async def streaming_chat(ws: WebSocket):
         start_time = time.time()
         is_first_response = True

-        async for llm_frame in llm_frames:
+        for llm_frame in llm_frames:
             if is_first_response:
                 end_time = time.time()
                 logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")

@@ -12,4 +12,5 @@ cn2an
 numba
 librosa
 aiohttp
-'volcengine-python-sdk[ark]'
+'volcengine-python-sdk[ark]'
+zhipuai