diff --git a/app/concrete.py b/app/concrete.py
index 5cc4041..032c3a0 100644
--- a/app/concrete.py
+++ b/app/concrete.py
@@ -8,10 +8,12 @@ from .exception import *
 from .dependency import get_logger
 from utils.vits_utils import TextToSpeech
 from config import Config
-import aiohttp
+import threading
+import requests
 import asyncio
 import struct
 import base64
+import time
 import json
 
 # ----------- initialize vits ----------- #
@@ -85,7 +87,7 @@ class MINIMAX_LLM(LLM):
     def __init__(self):
         self.token = 0
 
-    async def chat(self, assistant, prompt):
+    def chat(self, assistant, prompt):
         llm_info = json.loads(assistant.llm_info)
 
         # temporary workaround, to be revised later
@@ -107,20 +109,22 @@ class MINIMAX_LLM(LLM):
             'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
             'Content-Type': 'application/json'
         }
-        async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: # call the LLM
-                async for chunk in response.content.iter_any():
-                    try:
-                        chunk_msg = self.__parseChunk(chunk) # parse the LLM reply
-                        msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                        yield msg_frame
-                    except LLMResponseEnd:
-                        msg_frame = {"is_end":True,"code":200,"msg":""}
-                        assistant.token = self.token
-                        if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
-                            msg_frame['code'] = '201'
-                            assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                        yield msg_frame
+        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) # call the LLM
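+        # stream=True keeps the HTTP connection open so iter_lines() below can
+        # yield each line of the streaming response as it arrives, a synchronous
+        # replacement for the old aiohttp iter_any() loop.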
+        for chunk in response.iter_lines():
+            try:
+                chunk_msg = self.__parseChunk(chunk) # parse the LLM reply
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                assistant.token = self.token
+                if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
+                    msg_frame['code'] = '201'
+                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                yield msg_frame
 
 
     def __parseChunk(self, llm_chunk):
@@ -150,38 +151,43 @@ class VOLCENGINE_LLM(LLM):
         self.token = 0
         self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # temporary workaround, to be revised later
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'doubao-32k-lite'
-
-        model = self.__get_model(llm_info)
-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model=model,
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens'],
-            stream_options={'include_usage': True}
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)
+
+            # temporary workaround, to be revised later
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'doubao-32k-lite'
+
+            model = self.__get_model(llm_info)
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model=model,
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens'],
+                stream_options={'include_usage': True}
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            # delivered at the suspended yield when the caller close()s this
+            # generator (e.g. MIXED_LLM chose another model); release the stream
+            stream.response.close()
 
     def __get_model(self, llm_info):
         if llm_info['model'] == 'doubao-4k-lite':
@@ -206,36 +210,39 @@ class ZHIPU_LLM(LLM):
         self.token = 0
         self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
 
-    async def chat(self, assistant, prompt):
-        llm_info = json.loads(assistant.llm_info)
-
-        # temporary workaround, to be revised later
-        user_info = json.loads(assistant.user_info)
-        if user_info['llm_type'] == "MIXED":
-            llm_info['model'] = 'glm-4-air'
-
-        messages = json.loads(assistant.messages)
-        messages.append({'role':'user','content':prompt})
-        stream = self.client.chat.completions.create(
-            model=llm_info['model'],
-            messages=messages,
-            stream=True,
-            temperature=llm_info['temperature'],
-            top_p=llm_info['top_p'],
-            max_tokens=llm_info['max_tokens']
-        )
-        for chunk in stream:
-            try:
-                chunk_msg = self.__parseChunk(chunk)
-                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
-                yield msg_frame
-            except LLMResponseEnd:
-                msg_frame = {"is_end":True,"code":200,"msg":""}
-                assistant.token = self.token
-                if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
-                    msg_frame['code'] = '201'
-                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
-                yield msg_frame
+    def chat(self, assistant, prompt):
+        try:
+            llm_info = json.loads(assistant.llm_info)
+
+            # temporary workaround, to be revised later
+            user_info = json.loads(assistant.user_info)
+            if user_info['llm_type'] == "MIXED":
+                llm_info['model'] = 'glm-4-air'
+
+            messages = json.loads(assistant.messages)
+            messages.append({'role':'user','content':prompt})
+            stream = self.client.chat.completions.create(
+                model=llm_info['model'],
+                messages=messages,
+                stream=True,
+                temperature=llm_info['temperature'],
+                top_p=llm_info['top_p'],
+                max_tokens=llm_info['max_tokens']
+            )
+            for chunk in stream:
+                try:
+                    chunk_msg = self.__parseChunk(chunk)
+                    msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                    yield msg_frame
+                except LLMResponseEnd:
+                    msg_frame = {"is_end":True,"code":200,"msg":""}
+                    assistant.token = self.token
+                    if self.token > llm_info['max_tokens'] * 0.8: # if tokens exceed 80% of the limit, reset the session
+                        msg_frame['code'] = '201'
+                        assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                    yield msg_frame
+        except GeneratorExit:
+            stream.response.close() # same early-close handling as VOLCENGINE_LLM
 
     def __parseChunk(self, llm_chunk):
         if llm_chunk.usage:
@@ -247,45 +254,43 @@ class ZHIPU_LLM(LLM):
 
 class MIXED_LLM(LLM):
     def __init__(self):
-        self.minimax = MINIMAX_LLM()
+        # self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
         self.zhipu = ZHIPU_LLM()
 
-    async def chat(self, assistant, prompt):
-        minimax_result = self.minimax.chat(assistant, prompt)
-        volcengine_result = self.volcengine.chat(assistant, prompt)
-        zhipu_result = self.zhipu.chat(assistant, prompt)
-
-        minimax_task = asyncio.create_task(minimax_result.__anext__())
-        volcengine_task = asyncio.create_task(volcengine_result.__anext__())
-        zhipu_task = asyncio.create_task(zhipu_result.__anext__())
-
-        done, pending = await asyncio.wait([minimax_task, volcengine_task, zhipu_task], return_when=asyncio.FIRST_COMPLETED)
-        first_task = done.pop()
-        if first_task == minimax_task:
-            logger.debug("Using the MINIMAX model")
-            yield await minimax_task
-            while True:
-                try:
-                    yield await minimax_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == volcengine_task:
-            logger.debug("Using the Doubao model")
-            yield await volcengine_task
-            while True:
-                try:
-                    yield await volcengine_result.__anext__()
-                except StopAsyncIteration:
-                    break
-        elif first_task == zhipu_task:
-            logger.debug("Using the Zhipu model")
-            yield await zhipu_task
-            while True:
-                try:
-                    yield await zhipu_result.__anext__()
-                except StopAsyncIteration:
-                    break
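+    # Race the providers in plain threads: each wrapper pulls the first chunk,
+    # the winner registers its generator in result_container, and `event` tells
+    # the slower thread to discard its stream. This replaces the old
+    # asyncio.wait(FIRST_COMPLETED) race now that chat() is a synchronous generator.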
+    def chat_wrapper(self, chat_func, assistant, prompt, event, result_container):
+        gen = chat_func(assistant, prompt)
+        first_chunk = next(gen)
+        if event.is_set(): # another model answered first
+            gen.close() # raises GeneratorExit inside gen so it can release its stream
+        else:
+            event.set() # best-effort signal; chat() always consumes the first entry
+            result_container.append({
+                'first_chunk': first_chunk,
+                'gen': gen
+            })
+
+    def chat(self, assistant, prompt):
+        event = threading.Event()
+        result_container = [] # response from whichever LLM answers first
+        volc_thread = threading.Thread(target=self.chat_wrapper, args=(self.volcengine.chat, assistant, prompt, event, result_container))
+        zhipu_thread = threading.Thread(target=self.chat_wrapper, args=(self.zhipu.chat, assistant, prompt, event, result_container))
+        volc_thread.start()
+        zhipu_thread.start()
+        while True:
+            if result_container:
+                yield result_container[0]['first_chunk']
+                for chunk in result_container[0]['gen']:
+                    yield chunk
+                break
+            time.sleep(0.05)
+        volc_thread.join()
+        zhipu_thread.join()
+
 
 class VITS_TTS(TTS):
     def __init__(self):
diff --git a/main.py b/main.py
index b01ad48..93a3746 100644
--- a/main.py
+++ b/main.py
@@ -209,10 +209,11 @@ async def streaming_chat(ws: WebSocket):
         llm_text = ""
         logger.debug("Starting ASR recognition")
         while len(asr_results)==0:
-            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=2))
+            chunk = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=1))
             if assistant is None:
                 with get_db_context() as db: # the with-statement gets a db connection and closes it automatically
                     assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                 if assistant is None:
                     raise SessionNotFoundError()
+                logger.debug(f"Received a request from {assistant.name}")
                 user_info = json.loads(assistant.user_info)
@@ -230,7 +231,7 @@ async def streaming_chat(ws: WebSocket):
 
         start_time = time.time()
         is_first_response = True
-        async for llm_frame in llm_frames:
+        for llm_frame in llm_frames:
             if is_first_response:
                 end_time = time.time()
                 logger.debug(f"Time to first frame: {round(end_time-start_time,3)}s")
diff --git a/requirements.txt b/requirements.txt
index e6c81a2..db2e993 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ cn2an
 numba
 librosa
 aiohttp
-'volcengine-python-sdk[ark]'
\ No newline at end of file
+volcengine-python-sdk[ark]
+zhipuai
\ No newline at end of file