forked from killua/TakwayDisplayPlatform

#update: Mixed模型 (Mixed model)

This commit is contained in:
parent 55e2d7209c
commit 3f78fcb12c
app/concrete.py (224 changed lines)

@@ -8,10 +8,12 @@ from .exception import *
 from .dependency import get_logger
 from utils.vits_utils import TextToSpeech
 from config import Config
-import aiohttp
+import threading
+import requests
 import asyncio
 import struct
 import base64
+import time
 import json
 
 # ----------- 初始化vits ----------- #

@@ -85,7 +87,7 @@ class MINIMAX_LLM(LLM):
     def __init__(self):
         self.token = 0
 
-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
 
         # 临时方案,后续需要修改
@@ -107,20 +109,19 @@ class MINIMAX_LLM(LLM):
             'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
             'Content-Type': 'application/json'
         }
-        async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #调用大模型
-                async for chunk in response.content.iter_any():
-                    try:
-                        chunk_msg = self.__parseChunk(chunk) #解析llm返回
-                        msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                        yield msg_frame
-                    except LLMResponseEnd:
-                        msg_frame = {"is_end":True,"code":200,"msg":""}
-                        assistant.token = self.token
-                        if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
-                            msg_frame['code'] = '201'
-                            assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                        yield msg_frame
+        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) #调用大模型
+        for chunk in response.iter_lines():
+            try:
+                chunk_msg = self.__parseChunk(chunk) #解析llm返回
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                assistant.token = self.token
+                if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
+                    msg_frame['code'] = '201'
+                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                yield msg_frame
 
 
     def __parseChunk(self, llm_chunk):
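The MINIMAX_LLM hunk above swaps the aiohttp coroutine for a synchronous requests stream consumed by a plain generator. A minimal sketch of that pattern, assuming an SSE-style "data: ..." line framing (the real parsing lives in __parseChunk; url, headers and payload are placeholders, not the project's API):

    import json
    import requests

    def stream_llm(url, headers, payload):
        # stream=True keeps the HTTP connection open; iter_lines() yields each
        # line of the chunked response as bytes as soon as it arrives.
        response = requests.post(url, headers=headers, data=payload, stream=True)
        for line in response.iter_lines():
            if not line:                    # skip keep-alive blank lines
                continue
            if line.startswith(b"data:"):   # assumed SSE-style framing
                yield json.loads(line[len(b"data:"):].strip())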
@@ -150,38 +151,41 @@ class VOLCENGINE_LLM(LLM):
         self.token = 0
         self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # 临时方案,后续需要修改
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'doubao-32k-lite'
-
-        model = self.__get_model(llm_info)
-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model = model,
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens'],
-            stream_options={'include_usage': True}
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)
+
+            # 临时方案,后续需要修改
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'doubao-32k-lite'
+
+            model = self.__get_model(llm_info)
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model = model,
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens'],
+                stream_options={'include_usage': True}
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            stream.response.close()
 
 
     def __get_model(self, llm_info):
         if llm_info['model'] == 'doubao-4k-lite':
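Both provider hunks wrap the streaming body in try/except GeneratorExit so that an abandoned generator can close its upstream HTTP stream. A minimal sketch of that idiom, with open_stream and stream.close() as placeholders for whatever connection object the SDK exposes (the committed code closes stream.response instead):

    def streaming_chat(open_stream):
        stream = open_stream()
        try:
            for chunk in stream:
                yield chunk
        except GeneratorExit:
            # Raised inside the generator when the consumer drops it or calls
            # .close(); release the underlying connection before exiting.
            stream.close()
            raise

Re-raising (or simply returning) after the cleanup lets the normal generator shutdown continue.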
@@ -206,36 +210,39 @@ class ZHIPU_LLM(LLM):
         self.token = 0
         self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # 临时方案,后续需要修改
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'glm-4-air'
-
-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model = llm_info['model'],
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens']
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)
+
+            # 临时方案,后续需要修改
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'glm-4-air'
+
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model = llm_info['model'],
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens']
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            stream.response.close()
 
 
     def __parseChunk(self, llm_chunk):
         if llm_chunk.usage:
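All three providers share the same end-of-stream bookkeeping: on LLMResponseEnd the running token count is written back to the assistant, and if it exceeds 80% of max_tokens the final frame is flagged with code '201' and the stored message history is reset to the system prompt plus the latest user turn. A standalone sketch of that check (not project code; the names mirror the diff, and assistant is assumed to be the same ORM object used above):

    import json

    def end_frame(assistant, prompt, token_count, max_tokens, system_prompt):
        # Build the final frame; code '201' tells the client the session was reset.
        frame = {"is_end": True, "code": 200, "msg": ""}
        assistant.token = token_count
        if token_count > max_tokens * 0.8:   # over 80% of the budget: reset
            frame['code'] = '201'
            assistant.messages = json.dumps([
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': prompt},
            ])
        return frame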
@@ -247,45 +254,38 @@ class ZHIPU_LLM(LLM):
 
 class MIXED_LLM(LLM):
     def __init__(self):
-        self.minimax = MINIMAX_LLM()
+        # self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
         self.zhipu = ZHIPU_LLM()
 
-    async def chat(self, assistant, prompt):
-        minimax_result = self.minimax.chat(assistant, prompt)
-        volcengine_result = self.volcengine.chat(assistant, prompt)
-        zhipu_result = self.zhipu.chat(assistant, prompt)
-
-        minimax_task = asyncio.create_task(minimax_result.__anext__())
-        volcengine_task = asyncio.create_task(volcengine_result.__anext__())
-        zhipu_task = asyncio.create_task(zhipu_result.__anext__())
-
-        done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
-        first_task = done.pop()
-        if first_task == minimax_task:
-            logger.debug("使用MINIMAX模型")
-            yield await minimax_task
-            while True:
-                try:
-                    yield await minimax_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == volcengine_task:
-            logger.debug("使用豆包模型")
-            yield await volcengine_task
-            while True:
-                try:
-                    yield await volcengine_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == zhipu_task:
-            logger.debug("使用智谱模型")
-            yield await zhipu_task
-            while True:
-                try:
-                    yield await zhipu_result.__anext__()
-                except StopAsyncIteration:
-                    break
+    def chat_wrapper(self,chat_func, assistant, prompt, event, result_container):
+        gen = chat_func(assistant, prompt)
+        first_chunk = next(gen)
+        if event.is_set():
+            raise GeneratorExit()
+        else:
+            result_container.append({
+                'first_chunk': first_chunk,
+                'gen': gen
+            })
+
+    def chat(self, assistant, prompt):
+        event = threading.Event()
+        result_container = [] #大模型返回结果
+        volc_thread = threading.Thread(target=self.chat_wrapper, args=(self.volcengine.chat, assistant, prompt, event, result_container))
+        zhipu_thread = threading.Thread(target=self.chat_wrapper, args=(self.zhipu.chat, assistant, prompt, event, result_container))
+        volc_thread.start()
+        zhipu_thread.start()
+        while True:
+            if result_container:
+                yield result_container[0]['first_chunk']
+                for chunk in result_container[0]['gen']:
+                    yield chunk
+                break
+            time.sleep(0.05)
+        volc_thread.join()
+        zhipu_thread.join()
 
 
 class VITS_TTS(TTS):
     def __init__(self):
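The new MIXED_LLM races the Volcengine and Zhipu sync generators on two threads, polls a shared list until one of them produces its first chunk, then streams only from that winner. The same first-responder pattern can also be written with the standard-library executor; the sketch below is an alternative illustration, not the committed code, and chat_funcs is a hypothetical list such as [self.volcengine.chat, self.zhipu.chat]:

    from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

    def race_first_chunk(chat_funcs, assistant, prompt):
        # Start every backend's generator, block until one yields a first chunk,
        # then keep streaming from that winner only.
        pool = ThreadPoolExecutor(max_workers=len(chat_funcs))
        gens = [chat(assistant, prompt) for chat in chat_funcs]
        futures = {pool.submit(next, gen): gen for gen in gens}
        done, _ = wait(futures, return_when=FIRST_COMPLETED)
        winner_future = done.pop()
        yield winner_future.result()            # first chunk from the fastest backend
        for chunk in futures[winner_future]:    # continue with the winner's stream
            yield chunk
        pool.shutdown(wait=False)               # don't block on the slower backends

Compared with the committed polling loop, FIRST_COMPLETED avoids the 50 ms sleep, but in both versions the losing backend keeps running until its own first chunk arrives.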
main.py (5 changed lines)
@@ -209,10 +209,11 @@ async def streaming_chat(ws: WebSocket):
     llm_text = ""
     logger.debug("开始进行ASR识别")
     while len(asr_results)==0:
-        chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
+        chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=1))
         if assistant is None:
             with get_db_context() as db: #使用with语句获取数据库连接,自动关闭数据库连接
                 assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
+                logger.debug(f"接收到{assistant.name}的请求")
         if assistant is None:
             raise SessionNotFoundError()
         user_info = json.loads(assistant.user_info)
@@ -230,7 +231,7 @@ async def streaming_chat(ws: WebSocket):
     start_time = time.time()
     is_first_response = True
 
-    async for llm_frame in llm_frames:
+    for llm_frame in llm_frames:
         if is_first_response:
             end_time = time.time()
             logger.debug(f"第一帧返回耗时:{round(end_time-start_time,3)}s")
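With LLM.chat now a plain generator, the websocket handler iterates it with a regular for loop, so every blocking next() runs on the event-loop thread. If that ever becomes a problem, one hedged option (not part of this commit) is to push each next() call onto the default executor:

    import asyncio

    async def aiter_in_thread(gen):
        # Drive a synchronous generator from async code: each next() call runs
        # in the default thread pool, so the event loop stays responsive while
        # the LLM stream waits on the network.
        loop = asyncio.get_running_loop()
        done = object()                     # sentinel marking exhaustion
        while True:
            frame = await loop.run_in_executor(None, next, gen, done)
            if frame is done:
                break
            yield frame

Usage would then be "async for llm_frame in aiter_in_thread(llm_frames):" in place of the plain loop above.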
@@ -12,4 +12,5 @@ cn2an
 numba
 librosa
 aiohttp
 'volcengine-python-sdk[ark]'
+zhipuai