from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ----------- Initialize VITS ----------- #
vits = TextToSpeech()
# --------------------------------------- #

# ------ Concrete ASR, LLM, TTS classes ------ #
# Frame markers for the XunFei (iFlytek) streaming-ASR websocket protocol.
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3


class XF_ASR(ASR):
    """Streaming speech recognition backed by the XunFei websocket ASR API."""

    def __init__(self):
        self.websocket = None          # lazily-opened websocket connection
        self.current_message = ""      # transcript accumulated so far
        self.audio = ""                # base64 audio accumulated for this utterance
        self.status = FIRST_FRAME      # which frame type the next chunk is sent as
        self.segment_duration_threshold = 25   # segment timeout, in seconds
        self.segment_start_time = None

    async def stream_recognize(self, chunk):
        """Feed one audio chunk into the recognizer.

        `chunk` is a dict with `chunk['audio']` (base64 audio) and
        `chunk['meta_info']['is_end']` (True on the utterance's final chunk).

        Returns `[{"text": ..., "audio": ...}]` once the utterance ends,
        otherwise an empty list.  Raises AsrResultNoneError when the final
        transcript is empty.
        """
        if self.websocket is None:               # open a fresh connection on demand
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:      # first chunk of this segment: start the clock
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data

        if self.status == FIRST_FRAME:           # send the opening frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:      # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:          # send the closing frame, collect the result
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            await self.websocket.close()
            print("语音识别结束,用户消息:", self.current_message)
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:
            # Segment timed out: flush the current segment with a last frame,
            # then reopen a connection for the next one.
            await self.websocket.send(make_last_frame())
            # BUGFIX: the server reply must be JSON-decoded before parsing,
            # exactly as the LAST_FRAME branch above does.
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []


class MINIMAX_LLM(LLM):
    """LLM backend streaming chat completions from the MiniMax API."""

    def __init__(self):
        self.token = 0   # total_tokens reported by the last completed response

    async def chat(self, assistant, prompt, db):
        """Stream the model's reply as `{"is_end", "code", "msg"}` frames.

        Appends `prompt` to the assistant's message history, POSTs a streaming
        completion request, and yields one frame per received delta.  On the
        final frame, if token usage exceeds 80% of `max_tokens`, the stored
        history is trimmed back to system prompt + latest user prompt.
        """
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({"role": "user", "content": prompt})
        assistant.messages = json.dumps(messages)
        payload = json.dumps({
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        # Final SSE frame: signal end-of-stream to the caller.
                        msg_frame = {"is_end": True, "msg": ""}
                        if self.token > llm_info['max_tokens'] * 0.8:
                            # Near the token budget: trim persisted history to
                            # system prompt + latest user prompt.
                            msg_frame['msg'] = 'max_token reached'
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([
                                {'role': 'system', 'content': assistant.system_prompt},
                                {'role': 'user', 'content': prompt}
                            ])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame

    def __parseChunk(self, llm_chunk):
        """Parse one SSE chunk from the MiniMax stream.

        Returns the concatenated delta text.  Raises LLMResponseEnd on the
        terminating `chat.completion` frame (recording token usage), and
        UnkownLLMFrame on anything unrecognized.
        """
        result = ""
        # Strip the 6-byte "data: " SSE prefix before decoding the JSON body.
        data = json.loads(llm_chunk.decode('utf-8')[6:])
        if data["object"] == "chat.completion":          # terminating frame with usage stats
            self.token = data['usage']['total_tokens']
            raise LLMResponseEnd()
        elif data['object'] == 'chat.completion.chunk':  # incremental delta frame
            for choice in data['choices']:
                result += choice['delta']['content']
        else:
            raise UnkownLLMFrame()
        return result


class VITS_TTS(TTS):
    """Text-to-speech backend delegating to the module-level VITS model."""

    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        # NOTE: method name follows the project-wide TTS interface spelling.
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #

# ------ ASR, LLM, TTS factory classes ------ #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        """Instantiate the ASR implementation named by `asr_type`."""
        if asr_type == 'XF':
            return XF_ASR()
        # BUGFIX: previously fell through and returned None for unknown types,
        # deferring the failure to an opaque AttributeError at first use.
        raise ValueError(f"unknown ASR type: {asr_type}")


class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        """Instantiate the LLM implementation named by `llm_type`."""
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()
        # BUGFIX: fail fast on unrecognized type instead of returning None.
        raise ValueError(f"unknown LLM type: {llm_type}")


class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        """Instantiate the TTS implementation named by `tts_type`."""
        if tts_type == 'VITS':
            return VITS_TTS()
        # BUGFIX: fail fast on unrecognized type instead of returning None.
        raise ValueError(f"unknown TTS type: {tts_type}")
# --------------------------------- #

# ----------- Concrete service classes ----------- #

class BasicPromptService(PromptService):
    """Takes the first single-speaker ASR result as the prompt."""

    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()


class UserAudioRecordService(UserAudioService):
    """Appends decoded user audio to the recorder passed via kwargs."""

    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs


class TTSAudioRecordService(TTSAudioService):
    """Appends synthesized audio bytes to the recorder passed via kwargs."""

    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #

# ------------ Service chain classes ----------- #

class UserAudioServiceChain:
    """Runs registered UserAudioService steps over incoming user audio."""

    def __init__(self):
        self.services = []

    def add_service(self, service: UserAudioService):
        self.services.append(service)

    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio


class PromptServiceChain:
    """Runs registered PromptService steps to build the LLM prompt."""

    def __init__(self):
        self.services = []

    def add_service(self, service: PromptService):
        self.services.append(service)

    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt


class LLMMsgServiceChain:
    """Splits LLM output into sentences, then runs LLMMsgService steps."""

    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()   # sentence splitter for the LLM stream

    def add_service(self, service: LLMMsgService):
        self.services.append(service)

    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks


class TTSAudioServiceChain:
    """Runs registered TTSAudioService steps over synthesized audio."""

    def __init__(self):
        self.services = []

    def add_service(self, service: TTSAudioService):
        self.services.append(service)

    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# --------------------------------- #


class Agent:
    """Facade wiring ASR, LLM, TTS backends and their processing chains."""

    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        # User-audio pipeline: record raw user audio.
        self.user_audio_service_chain = UserAudioServiceChain()
        self.user_audio_service_chain.add_service(UserAudioRecordService())

        self.asr = asrFactory.create_asr(asr_type)

        # Prompt pipeline: take the first ASR result as the prompt.
        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)
        self.llm_chunk_service_chain = LLMMsgServiceChain()
        self.tts = ttsFactory.create_tts(tts_type)

        # TTS-audio pipeline: record synthesized audio.
        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Pre-process incoming user audio.
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Streaming speech recognition.
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)

    # Prompt construction from ASR results.
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Invoke the large language model (returns its async generator).
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)

    # Post-process LLM output.
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)

    # TTS synthesis.
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)

    # Post-process synthesized audio.
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode one response: a length-prefixed header, the JSON text part, then raw audio.
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        # '!II' = network byte order, two unsigned 32-bit lengths.
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp

    # Persist the recorded audio.
    def save(self):
        self.recorder.write()