from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from volcenginesdkarkruntime import Ark
from .model import Assistant
from .abstract import *
from .public import *
from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ----------- vits initialization ----------- #
vits = TextToSpeech()
# ------------------------------------------- #

# ---------- logger initialization ---------- #
logger = get_logger()
# ------------------------------------------- #

# ------ concrete ASR, LLM, TTS classes ----- #

# Frame-status markers for the streaming ASR protocol.
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3


class XF_ASR(ASR):
    """Streaming speech recognition over the XF (iFlytek) websocket API."""

    def __init__(self):
        self.websocket = None                  # lazily-created websocket connection
        self.current_message = ""              # accumulated recognized text
        self.audio = ""                        # accumulated base64 audio of this segment
        self.status = FIRST_FRAME              # which frame type to send next
        self.segment_duration_threshold = 25   # segment timeout, in seconds
        self.segment_start_time = None         # loop-time when the segment started

    async def stream_recognize(self, chunk):
        """Feed one audio chunk to the recognizer.

        Returns [] while the utterance is still in progress, and a
        one-element list [{"text": ..., "audio": ...}] once the final
        chunk has been processed.

        Raises:
            SideNoiseError: the very first frame is also marked as the
                last one, so the input is treated as mere noise.
            AsrResultNoneError: the recognizer produced no text at all.
        """
        if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:
            # First frame marked as end -> classified as side noise.
            raise SideNoiseError()
        if self.websocket is None:
            # Establish a new websocket connection on first use.
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:
            # Record the start time of this segment.
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME

        audio_data = chunk['audio']
        self.audio += audio_data
        if self.status == FIRST_FRAME:
            # Send the opening frame of the stream.
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:
            # Send an intermediate frame.
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:
            # Send the closing frame and collect the final recognition result.
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            # Close in the background; the result is already complete.
            asyncio.create_task(self.websocket.close())
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:
            # Segment timed out: flush with a last frame and reset state.
            # NOTE(review): make_last_frame() is called here without audio —
            # assumes the helper provides a default payload; confirm.
            await self.websocket.send(make_last_frame())
            # BUG FIX: the received message must be JSON-decoded before
            # parse_xfasr_recv, exactly as in the LAST_FRAME branch above.
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []


class MINIMAX_LLM(LLM):
    """LLM backend using the MiniMax chat-completion HTTP streaming API."""

    def __init__(self):
        self.token = 0   # total tokens reported by the last completed response

    async def chat(self, assistant, prompt, db):
        """Stream chat completions for `prompt`, yielding message frames.

        Each yielded frame is {"is_end": bool, "code": ..., "msg": str}.
        When the conversation exceeds 80% of max_tokens, the session
        history stored on the assistant is reset and code '201' is sent.
        """
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role':'user','content':prompt})
        # Assemble the request payload.
        payload = json.dumps({
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            # Call the model with streaming enabled.
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)   # parse the LLM reply
                        msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end":True,"code":200,"msg":""}
                        if self.token > llm_info['max_tokens'] * 0.8:
                            # Over 80% of the token budget: reset the session.
                            # NOTE(review): '201' is a string while the normal
                            # code is the int 200 — kept as-is since clients
                            # may compare against the string.
                            msg_frame['code'] = '201'
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame

    def __parseChunk(self, llm_chunk):
        """Decode one SSE chunk; return its text, or raise LLMResponseEnd
        on the final usage frame / AbnormalLLMFrame on malformed input."""
        try:
            result = ""
            chunk_decoded = llm_chunk.decode('utf-8')
            chunks = chunk_decoded.split('\n\n')
            for chunk in chunks:
                if not chunk:
                    continue
                # Strip the leading "data: " SSE prefix before decoding.
                data = json.loads(chunk[6:])
                if data["object"] == "chat.completion":
                    # Final frame: record token usage and signal completion.
                    self.token = data['usage']['total_tokens']
                    raise LLMResponseEnd()
                elif data['object'] == 'chat.completion.chunk':
                    for choice in data['choices']:
                        result += choice['delta']['content']
                else:
                    raise UnkownLLMFrame()
            return result
        except (json.JSONDecodeError, KeyError):
            logger.error(llm_chunk)
            raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")


class VOLCENGINE_LLM(LLM):
    """LLM backend using the VolcEngine Ark SDK (streaming)."""

    def __init__(self):
        self.token = 0   # total tokens reported by the last completed response
        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)

    async def chat(self, assistant, prompt, db):
        """Stream chat completions for `prompt`, yielding message frames.

        Frame shape and session-reset behavior mirror MINIMAX_LLM.chat.
        """
        llm_info = json.loads(assistant.llm_info)
        model = self.__get_model(llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role':'user','content':prompt})
        stream = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            stream_options={'include_usage': True}
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                # BUG FIX: the end-frame code was written as the expression
                # 20-0 (== 20); it must be 200 as in MINIMAX_LLM.
                msg_frame = {"is_end":True,"code":200,"msg":""}
                if self.token > llm_info['max_tokens'] * 0.8:
                    # Over 80% of the token budget: reset the session.
                    msg_frame['code'] = '201'
                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                    as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                    db.commit()
                    assistant.messages = as_query.messages
                yield msg_frame

    def __get_model(self, llm_info):
        """Map the configured model name to the Ark endpoint identifier."""
        if llm_info['model'] == 'doubao-4k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
        else:
            raise UnknownVolcEngineModelError()

    def __parseChunk(self, llm_chunk):
        """Extract delta text from one Ark stream chunk; raise LLMResponseEnd
        on the usage frame, AbnormalLLMFrame on a chunk with no choices."""
        if llm_chunk.usage:
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content


class VITS_TTS(TTS):
    """Text-to-speech backed by the module-level VITS engine."""

    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# ------------------------------------------- #


# ------- ASR, LLM, TTS factory classes ----- #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()


class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()
        if llm_type == 'VOLCENGINE':
            return VOLCENGINE_LLM()


class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
# ------------------------------------------- #


# ------------ concrete services ------------ #
# Take the first result from the single-speaker asr_results as the prompt.
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()


# Record the user's audio.
class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs


# Record the TTS audio.
class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# ------------------------------------------- #


# -------------- service chains ------------- #
class UserAudioServiceChain:
    """Run each registered UserAudioService over the audio in order."""

    def __init__(self):
        self.services = []

    def add_service(self, service: UserAudioService):
        self.services.append(service)

    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio


class PromptServiceChain:
    """Build the final prompt by running each PromptService in order."""

    def __init__(self):
        self.services = []

    def add_service(self, service: PromptService):
        self.services.append(service)

    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt


class LLMMsgServiceChain:
    """Segment the LLM reply into sentences, then run each LLMMsgService."""

    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()

    def add_service(self, service: LLMMsgService):
        self.services.append(service)

    def llm_msg_process(self, llm_msg):
        kwargs = {}
        # First split the LLM reply into sentence chunks.
        llm_chunks = self.spliter.split(llm_msg)
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks


class TTSAudioServiceChain:
    """Run each registered TTSAudioService over the audio in order."""

    def __init__(self):
        self.services = []

    def add_service(self, service: TTSAudioService):
        self.services.append(service)

    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# ------------------------------------------- #


class Agent():
    """Facade wiring together ASR, LLM, TTS and their processing chains."""

    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        # Chain that pre-processes the user's audio.
        self.user_audio_service_chain = UserAudioServiceChain()
        self.user_audio_service_chain.add_service(UserAudioRecordService())   # record user audio

        self.asr = asrFactory.create_asr(asr_type)

        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)

        self.llm_msg_service_chain = LLMMsgServiceChain()

        self.tts = ttsFactory.create_tts(tts_type)

        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Pre-process the user's input audio.
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Perform streaming speech recognition.
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)

    # Build the prompt from ASR results.
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Invoke the large language model.
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)

    # Post-process the LLM reply.
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_msg_service_chain.llm_msg_process(llm_chunk)

    # Synthesize speech via TTS.
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)

    # Post-process the synthesized audio.
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode a (text, audio) pair into one binary response frame:
    # an 8-byte header with both lengths, then the JSON text, then the audio.
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp

    # Persist the recorded audio.
    def save(self):
        self.recorder.write()