From d885684533efc9a0ce64aec06cbf16a5a46c5d25 Mon Sep 17 00:00:00 2001 From: killua <1223086337@qq.com> Date: Mon, 10 Jun 2024 20:49:50 +0800 Subject: [PATCH] =?UTF-8?q?update:=E5=8A=A0=E4=BA=86=E4=B8=80=E7=82=B9?= =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/abstract.py | 2 +- app/concrete.py | 18 ++++++++++-------- main.py | 4 +++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/app/abstract.py b/app/abstract.py index 8f64115..13c8eef 100644 --- a/app/abstract.py +++ b/app/abstract.py @@ -31,7 +31,7 @@ class PromptService(ABC): class LLMMsgService(ABC): @abstractmethod - def llm_msg_process(self, llm_chunk:list, **kwargs) -> list: + def llm_msg_process(self, llm_chunks:list, **kwargs) -> list: pass class TTSAudioService(ABC): diff --git a/app/concrete.py b/app/concrete.py index 6fce8cd..3ac654c 100644 --- a/app/concrete.py +++ b/app/concrete.py @@ -78,7 +78,7 @@ class MINIMAX_LLM(LLM): llm_info = json.loads(assistant.llm_info) messages = json.loads(assistant.messages) messages.append({'role':'user','content':prompt}) - payload = json.dumps({ + payload = json.dumps({ #整理payload "model": llm_info['model'], "stream": True, "messages": messages, @@ -91,15 +91,15 @@ class MINIMAX_LLM(LLM): 'Content-Type': 'application/json' } async with aiohttp.ClientSession() as client: - async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: + async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #调用大模型 async for chunk in response.content.iter_any(): try: - chunk_msg = self.__parseChunk(chunk) + chunk_msg = self.__parseChunk(chunk) #解析llm返回 msg_frame = {"is_end":False,"code":200,"msg":chunk_msg} yield msg_frame except LLMResponseEnd: msg_frame = {"is_end":True,"msg":""} - if self.token > llm_info['max_tokens'] * 0.8: + if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session msg_frame['msg'] = 'max_token reached' as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first() as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}]) @@ -169,12 +169,14 @@ class BasicPromptService(PromptService): else: raise NoAsrResultsError() +# 记录用户音频 class UserAudioRecordService(UserAudioService): def user_audio_process(self, audio:str, **kwargs): audio_decoded = base64.b64decode(audio) kwargs['recorder'].user_audio += audio_decoded return audio,kwargs - + +# 记录TTS音频 class TTSAudioRecordService(TTSAudioService): def tts_audio_process(self, audio:bytes, **kwargs): kwargs['recorder'].tts_audio += audio @@ -214,7 +216,7 @@ class LLMMsgServiceChain: self.services.append(service) def llm_msg_process(self, llm_msg): kwargs = {} - llm_chunks = self.spliter.split(llm_msg) + llm_chunks = self.spliter.split(llm_msg) #首先对llm返回进行断句 for service in self.services: llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs) return llm_chunks @@ -247,7 +249,7 @@ class Agent(): self.llm = llmFactory.create_llm(llm_type) - self.llm_chunk_service_chain = LLMMsgServiceChain() + self.llm_msg_service_chain = LLMMsgServiceChain() self.tts = ttsFactory.create_tts(tts_type) @@ -275,7 +277,7 @@ class Agent(): # 对大模型的返回进行处理 def llm_msg_process(self, llm_chunk, db): - return self.llm_chunk_service_chain.llm_msg_process(llm_chunk) + return self.llm_msg_service_chain.llm_msg_process(llm_chunk) # 进行TTS合成 def synthetize(self, assistant, text, db): diff --git a/main.py b/main.py index f88ab14..54cd84f 100644 --- a/main.py +++ b/main.py @@ -172,7 +172,7 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): await ws.accept() logger.debug("WebSocket连接成功") try: - agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) + agent = None assistant = None asr_results = [] llm_text = "" @@ -182,6 +182,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): if assistant is None: assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first() agent.init_recorder(assistant.user_id) + if not agent: + agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) chunk["audio"] = agent.user_audio_process(chunk["audio"], db) asr_results = await agent.stream_recognize(chunk, db) kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果