from ..schemas.chat import *
from ..dependencies.logger import get_logger
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import uuid
import json
import requests
import asyncio

# Logger obtained via dependency injection.
logger = get_logger()

# -------------------- Initialize local ASR -----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -----------------------------------------------------------------

# -------------------- Initialize local VITS TTS ------------------
from utils.tts.vits_utils import TextToSpeech

tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -----------------------------------------------------------------

# Application config obtained via dependency injection.
Config = get_config()


# ---------------------- Utility functions ------------------------
def get_session_content(session_id, redis, db):
    """Return the session content dict for `session_id`.

    Reads from Redis when the key exists, otherwise falls back to the
    Session table in the database.

    Raises:
        HTTPException(404): when the session exists in neither store.
    """
    session_content_str = ""
    if redis.exists(session_id):
        session_content_str = redis.get(session_id)
    else:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)


def parseChunkDelta(chunk):
    """Extract the delta text from one SSE line of the LLM streaming reply.

    `chunk` is a raw bytes line shaped like ``data: {json}``; the first six
    bytes (``data: ``) are stripped before JSON parsing.  Returns "" for
    frames that carry no delta content.

    FIX: the original crashed with JSONDecodeError on any non-JSON frame
    (e.g. an SSE ``[DONE]`` sentinel or a truncated line), killing the
    consuming task silently; such frames now yield "".
    """
    try:
        decoded_data = chunk.decode('utf-8')
        parsed_data = json.loads(decoded_data[6:])  # strip the "data: " prefix
    except (UnicodeDecodeError, json.JSONDecodeError):
        return ""
    if 'delta' in parsed_data['choices'][0]:
        delta_content = parsed_data['choices'][0]['delta']
        return delta_content['content']
    else:
        return ""


def split_string_with_punctuation(current_sentence, text, is_first):
    """Incrementally split streamed text into sentences at punctuation.

    The very first sentence of a reply is flushed at the first weak OR
    strong punctuation mark (to minimise time-to-first-audio); afterwards
    only strong marks (。?!) end a sentence.

    Args:
        current_sentence: text carried over from the previous chunk.
        text: the newly arrived chunk.
        is_first: True until the first sentence of a reply has been flushed.

    Returns:
        (complete_sentences, remaining_partial_sentence, is_first)
    """
    result = []
    for char in text:
        current_sentence += char
        if is_first and char in ',.?!,。?!':
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
        elif char in '。?!':
            result.append(current_sentence)
            current_sentence = ''
    return result, current_sentence, is_first
# -----------------------------------------------------------------


# Create a new chat
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    """Create a UserCharacter record plus its Session (DB + Redis).

    Builds the system prompt from the character's profile and the user's
    persona, seeds the message history with it, and stores the session
    content in both the database and Redis.

    Raises:
        HTTPException(400): when either DB insert fails (rolled back).
    """
    # Create the new UserCharacter record.
    new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
    try:
        db.add(new_chat)
        db.commit()
        db.refresh(new_chat)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

    # Look up the character and user, then build the system prompt.
    db_character = db.query(Character).filter(Character.id == chat.character_id).first()
    db_user = db.query(User).filter(User.id == chat.user_id).first()
    system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}。\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n {"角色的常用问候语: " + db_character.wakeup_words}。\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\n 与你聊天的对象信息如下:{db_user.persona}"""

    # Create the new Session record.
    session_id = str(uuid.uuid4())
    user_id = chat.user_id
    messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False)
    tts_info = {
        "language": 0,
        "speaker_id": db_character.voice_id,
        "noise_scale": 0.1,
        "noise_scale_w": 0.668,
        "length_scale": 1.2
    }
    llm_info = {
        "model": "abab5.5-chat",
        "temperature": 1,
        "top_p": 0.9,
    }
    # Serialise tts/llm info as JSON strings inside the session content.
    tts_info_str = json.dumps(tts_info, ensure_ascii=False)
    llm_info_str = json.dumps(llm_info, ensure_ascii=False)
    user_info_str = db_user.persona
    token = 0
    content = {"user_id": user_id, "messages": messages, "user_info": user_info_str,
               "tts_info": tts_info_str, "llm_info": llm_info_str, "token": token}
    new_session = Session(id=session_id, user_character_id=new_chat.id,
                          content=json.dumps(content, ensure_ascii=False),
                          last_activity=datetime.now(), is_permanent=False)

    # Persist the Session record in the DB and mirror it in Redis.
    db.add(new_session)
    db.commit()
    db.refresh(new_session)
    redis.set(session_id, json.dumps(content, ensure_ascii=False))

    chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id,
                                      createdAt=datetime.now().isoformat())
    return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data)


# Delete a chat
async def delete_chat_handler(user_character_id, db, redis):
    """Delete a chat: its UserCharacter row, Session row and Redis key.

    Redis deletion failures are logged and ignored (best-effort); DB
    deletion failures roll back and raise HTTPException(400).
    """
    # Look up the chat record.
    user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first()
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found")
    session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first()
    try:
        redis.delete(session_record.id)
    except Exception as e:
        logger.error(f"删除Redis中Session记录时发生错误: {str(e)}")
    try:
        db.delete(session_record)
        db.delete(user_character_record)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat())
    return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data)


# Non-streaming chat (not yet implemented).
async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis):
    pass


# --------------------------- Single-shot streaming chat endpoint ---------------------------
# User-input consumer
async def sct_user_input_handler(ws, user_input_q, llm_input_q, future_session_id,
                                 future_response_type, user_input_finish_event):
    """Receive websocket frames, resolve session/response-type futures, and
    route text to `llm_input_q` or audio to `user_input_q`.

    The first frame resolves `future_session_id` and `future_response_type`.
    A text frame or an `is_end` audio frame terminates input collection.
    """
    logger.debug("用户输入处理函数启动")
    is_future_done = False
    try:
        while not user_input_finish_event.is_set():
            sct_data_json = json.loads(await ws.receive_text())
            if not is_future_done:
                # First frame: publish session id and response type to the
                # parent coroutine via futures.
                future_session_id.set_result(sct_data_json['meta_info']['session_id'])
                if sct_data_json['meta_info']['voice_synthesize']:
                    future_response_type.set_result(RESPONSE_AUDIO)
                else:
                    future_response_type.set_result(RESPONSE_TEXT)
                is_future_done = True
            if sct_data_json['text']:
                # Plain-text input bypasses ASR entirely.
                await llm_input_q.put(sct_data_json['text'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            if sct_data_json['meta_info']['is_end']:
                await user_input_q.put(sct_data_json['audio'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            await user_input_q.put(sct_data_json['audio'])
    except KeyError as ke:
        # Heartbeat frames lack the text/meta_info keys and land here.
        # FIX: use .get() so a frame missing 'state'/'method' cannot raise a
        # secondary KeyError inside the handler.
        # NOTE(review): a heartbeat terminates this receive loop (the except
        # is outside the while); confirm whether input should resume instead.
        if sct_data_json.get('state') == 1 and sct_data_json.get('method') == 'heartbeat':
            logger.debug("收到心跳包")


# Speech recognition
async def sct_asr_handler(user_input_q, llm_input_q, user_input_finish_event):
    """Drain audio chunks from `user_input_q` through the streaming ASR and
    push the final transcript onto `llm_input_q`."""
    logger.debug("语音识别函数启动")
    current_message = ""
    while not (user_input_finish_event.is_set() and user_input_q.empty()):
        audio_data = await user_input_q.get()
        asr_result = asr.streaming_recognize(audio_data)
        current_message += ''.join(asr_result['text'])
    # Flush the recognizer to obtain any buffered tail text.
    asr_result = asr.streaming_recognize(b'', is_end=True)
    current_message += ''.join(asr_result['text'])
    await llm_input_q.put(current_message)
    logger.debug(f"接收到用户消息: {current_message}")


# LLM invocation
async def sct_llm_handler(session_id, llm_info, db, redis, llm_input_q,
                          llm_response_q, llm_response_finish_event):
    """Send the user message to the MiniMax LLM and stream delta chunks onto
    `llm_response_q`; always signals `llm_response_finish_event` at the end.

    NOTE(review): `requests` is a blocking client and stalls the event loop
    for the duration of the call/stream; consider asyncio.to_thread or an
    async HTTP client.
    """
    logger.debug("llm调用函数启动")
    session_content = get_session_content(session_id, redis, db)
    messages = json.loads(session_content["messages"])
    current_message = await llm_input_q.get()
    messages.append({'role': 'user', "content": current_message})
    payload = json.dumps({
        "model": llm_info["model"],
        "stream": True,
        "messages": messages,
        "max_tokens": 10000,
        "temperature": llm_info["temperature"],
        "top_p": llm_info["top_p"]
    })
    headers = {
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers,
                                data=payload, stream=True)
    try:
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data = parseChunkDelta(chunk)
                    await llm_response_q.put(chunk_data)
        else:
            logger.error(f"llm调用失败: {response.status_code}")
    finally:
        # FIX: the original set the event only on HTTP 200, so any LLM
        # failure left every downstream consumer waiting forever.
        llm_response_finish_event.set()


# LLM response sentence splitting
async def sct_llm_response_handler(session_id, redis, db, llm_response_q,
                                   split_result_q, llm_response_finish_event):
    """Split streamed LLM chunks into sentences for TTS and, once the reply
    is complete, append it to the session history in Redis."""
    logger.debug("llm返回处理函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (llm_response_finish_event.is_set() and llm_response_q.empty()):
        llm_chunk = await llm_response_q.get()
        llm_response += llm_chunk
        sentences, current_sentence, is_first = split_string_with_punctuation(
            current_sentence, llm_chunk, is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
    session_content = get_session_content(session_id, redis, db)
    messages = json.loads(session_content["messages"])
    messages.append({'role': 'assistant', "content": llm_response})
    session_content["messages"] = json.dumps(messages, ensure_ascii=False)  # update conversation
    redis.set(session_id, json.dumps(session_content, ensure_ascii=False))  # update session
    logger.debug(f"llm返回结果: {llm_response}")


# Text reply and speech synthesis
async def sct_response_handler(ws, tts_info, response_type, split_result_q,
                               llm_response_finish_event, chat_finish_event):
    """Send each split sentence to the client as text, or as audio bytes
    followed by the text, then signal `chat_finish_event`."""
    logger.debug("返回处理函数启动")
    while not (llm_response_finish_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        if response_type == RESPONSE_TEXT:
            response_message = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        elif response_type == RESPONSE_AUDIO:
            # FIX: argument order aligned with the other tts.synthesize call
            # sites (language before speaker_id); the original swapped them
            # here only. TODO confirm against TextToSpeech.synthesize.
            sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"],
                                       tts_info["noise_scale"], tts_info["noise_scale_w"],
                                       tts_info["length_scale"], return_bytes=True)
            response_message = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio)
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        logger.debug(f"websocket返回: {sentence}")
    chat_finish_event.set()


async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
    """Single-shot streaming chat: wires input/ASR/LLM/split/response tasks
    together with queues and events, then closes the websocket when done."""
    logger.debug("streaming chat temporary websocket 连接建立")
    user_input_q = asyncio.Queue()      # user audio input
    llm_input_q = asyncio.Queue()       # llm input (transcript or text)
    llm_response_q = asyncio.Queue()    # llm output chunks
    split_result_q = asyncio.Queue()    # sentences ready for TTS
    user_input_finish_event = asyncio.Event()
    llm_response_finish_event = asyncio.Event()
    chat_finish_event = asyncio.Event()
    future_session_id = asyncio.Future()
    future_response_type = asyncio.Future()

    asyncio.create_task(sct_user_input_handler(ws, user_input_q, llm_input_q,
                                               future_session_id, future_response_type,
                                               user_input_finish_event))
    asyncio.create_task(sct_asr_handler(user_input_q, llm_input_q, user_input_finish_event))
    session_id = await future_session_id        # session id from the first frame
    response_type = await future_response_type  # desired response type
    # Hoisted: fetch the session content once for both lookups.
    session_content = get_session_content(session_id, redis, db)
    tts_info = json.loads(session_content["tts_info"])
    llm_info = json.loads(session_content["llm_info"])
    asyncio.create_task(sct_llm_handler(session_id, llm_info, db, redis, llm_input_q,
                                        llm_response_q, llm_response_finish_event))
    asyncio.create_task(sct_llm_response_handler(session_id, redis, db, llm_response_q,
                                                 split_result_q, llm_response_finish_event))
    asyncio.create_task(sct_response_handler(ws, tts_info, response_type, split_result_q,
                                             llm_response_finish_event, chat_finish_event))
    while not chat_finish_event.is_set():
        await asyncio.sleep(1)
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    await ws.close()
    logger.debug("streaming chat temporary websocket 连接断开")
# -------------------------------------------------------------------------------------------


# Persistent streaming chat
async def streaming_chat_lasting_handler(ws, db, redis):
    """Long-lived streaming chat loop: receive one user turn (text or audio),
    call the LLM, stream text/audio back, update the session, repeat."""
    print("Websocket连接成功")
    while True:
        try:
            print("等待接受")
            data = await asyncio.wait_for(ws.receive_text(), timeout=60)
            data_json = json.loads(data)
            if data_json["is_close"]:
                close_message = {"type": "close", "code": 200, "msg": ""}
                await ws.send_text(json.dumps(close_message, ensure_ascii=False))
                print("连接关闭")
                await asyncio.sleep(0.5)
                await ws.close()
                return
        except asyncio.TimeoutError:
            print("连接超时")
            await ws.close()
            return

        current_message = ""            # accumulated user message
        response_type = RESPONSE_TEXT   # response type for this turn
        session_id = ""
        if Config.STRAM_CHAT.ASR == "LOCAL":
            try:
                while True:
                    if data_json["text"]:
                        # Non-empty text means this turn is typed input.
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # client wants audio back
                        session_id = data_json["meta_info"]["session_id"]
                        current_message = data_json['text']
                        break
                    if not data_json['meta_info']['is_end']:
                        # Still streaming audio frames.
                        asr_result = asr.streaming_recognize(data_json["audio"])
                        current_message += ''.join(asr_result['text'])
                    else:
                        # Final audio frame of this turn.
                        asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
                        session_id = data_json["meta_info"]["session_id"]
                        current_message += ''.join(asr_result['text'])
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # client wants audio back
                        break
                    data_json = json.loads(await ws.receive_text())
            except Exception as e:
                error_info = f"接收用户消息错误: {str(e)}"
                error_message = {"type": "error", "code": "500", "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
        elif Config.STRAM_CHAT.ASR == "REMOTE":
            error_info = f"远程ASR服务暂未开通"
            error_message = {"type": "error", "code": "500", "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        print(f"接收到用户消息: {current_message}")

        # Look up the Session record (Redis first, then DB).
        session_content_str = ""
        if redis.exists(session_id):
            session_content_str = redis.get(session_id)
        else:
            session_db = db.query(Session).filter(Session.id == session_id).first()
            if not session_db:
                # FIX: the original formatted this message with `str(e)` while
                # no exception was in scope, raising NameError and masking the
                # real problem; report the missing session id instead.
                error_info = f"未找到session记录: {session_id}"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            session_content_str = session_db.content
        session_content = json.loads(session_content_str)
        llm_info = json.loads(session_content["llm_info"])
        tts_info = json.loads(session_content["tts_info"])
        user_info = json.loads(session_content["user_info"])
        messages = json.loads(session_content["messages"])
        messages.append({'role': 'user', "content": current_message})
        token_count = session_content["token"]

        try:
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "tool_choice": "auto",
                "max_tokens": 10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            # NOTE(review): blocking HTTP call inside a coroutine stalls the
            # event loop while the LLM streams.
            response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers,
                                        data=payload, stream=True)
        except Exception as e:
            error_info = f"发送信息给大模型时发生错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        def split_string_with_punctuation(text):
            # Local splitter (shadows the module-level one on purpose): cuts
            # at strong punctuation and keeps any unterminated tail.
            punctuations = "!?。"
            result = []
            current_sentence = ""
            for char in text:
                current_sentence += char
                if char in punctuations:
                    result.append(current_sentence)
                    current_sentence = ""
            # Keep a trailing fragment that does not end with punctuation.
            if current_sentence and current_sentence[-1] not in punctuations:
                result.append(current_sentence)
            return result

        llm_response = ""
        response_buf = ""
        try:
            if Config.STRAM_CHAT.TTS == "LOCAL":
                if response.status_code == 200:
                    for chunk in response.iter_lines():
                        if chunk:
                            if response_type == RESPONSE_AUDIO:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                response_buf += chunk_data
                                split_buf = split_string_with_punctuation(response_buf)
                                response_buf = ""
                                if len(split_buf) != 0:
                                    for sentence in split_buf:
                                        sr, audio = tts.synthesize(sentence, tts_info["language"],
                                                                   tts_info["speaker_id"],
                                                                   tts_info["noise_scale"],
                                                                   tts_info["noise_scale_w"],
                                                                   tts_info["length_scale"],
                                                                   return_bytes=True)
                                        text_response = {"type": "text", "code": 200, "msg": sentence}
                                        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # text frame
                                        await ws.send_bytes(audio)  # raw audio frame
                            if response_type == RESPONSE_TEXT:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                text_response = {"type": "text", "code": 200, "msg": chunk_data}
                                await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # text frame
            elif Config.STRAM_CHAT.TTS == "REMOTE":
                error_info = f"暂不支持远程音频合成"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            end_response = {"type": "end", "code": 200, "msg": ""}
            await ws.send_text(json.dumps(end_response, ensure_ascii=False))  # end of this turn
            print(f"llm消息: {llm_response}")
        except Exception as e:
            error_info = f"音频合成与向前端返回时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        try:
            messages.append({'role': 'assistant', "content": llm_response})
            token_count += len(llm_response)
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)
            session_content["token"] = token_count
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))
        except Exception as e:
            error_info = f"更新session时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        print("处理完毕")


# -------------------------------- Voice-call endpoint --------------------------------
# Audio producer
async def voice_call_audio_producer(ws, audio_queue, future, shutdown_event):
    """Receive voice-call frames; resolve the session-id future on the first
    frame, enqueue audio, and set `shutdown_event` on a close frame."""
    logger.debug("音频数据生产函数启动")
    is_future_done = False
    while not shutdown_event.is_set():
        voice_call_data_json = json.loads(await ws.receive_text())
        if not is_future_done:
            # First frame: publish the session id.
            future.set_result(voice_call_data_json['meta_info']['session_id'])
            is_future_done = True
        if voice_call_data_json["is_close"]:
            shutdown_event.set()
            break
        else:
            await audio_queue.put(voice_call_data_json["audio"])  # enqueue audio chunk


# Audio consumer
async def voice_call_audio_consumer(audio_q, asr_result_q, shutdown_event):
    """Run VAD + streaming ASR over queued audio; after 25 consecutive
    silent frames, flush the transcript onto `asr_result_q`."""
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0
    while not (shutdown_event.is_set() and audio_q.empty()):
        audio_data = await audio_q.get()
        if vad.is_speech(audio_data):
            if vad_count > 0:
                vad_count -= 1
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        else:
            vad_count += 1
            if vad_count >= 25:  # 25 consecutive non-speech frames => utterance ended
                asr_result = asr.streaming_recognize(audio_data, is_end=True)
                if current_message:
                    logger.debug(f"检测到静默,用户输入为:{current_message}")
                    await asr_result_q.put(current_message)
                current_message = ""
                vad_count = 0


# ASR-result consumer / LLM-response producer
async def voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q,
                                 llm_response_q, shutdown_event):
    """For each recognised utterance, call the MiniMax LLM and stream delta
    frames onto `llm_response_q`, ending each reply with an is_end frame.

    NOTE(review): blocking `requests` call inside a coroutine (see
    sct_llm_handler).
    """
    logger.debug("asr结果消费以及llm返回生产函数启动")
    while not (shutdown_event.is_set() and asr_result_q.empty()):
        session_content = get_session_content(session_id, redis, db)
        messages = json.loads(session_content["messages"])
        current_message = await asr_result_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens": 10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers,
                                    data=payload, stream=True)
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data = parseChunkDelta(chunk)
                    llm_frame = {'message': chunk_data, 'is_end': False}
                    await llm_response_q.put(llm_frame)
            llm_frame = {'message': "", 'is_end': True}
            await llm_response_q.put(llm_frame)


# LLM response consumer
async def voice_call_llm_response_consumer(session_id, redis, db, llm_response_q,
                                           split_result_q, shutdown_event):
    """Split streamed LLM frames into sentences for TTS; on each reply's
    is_end frame, persist the full reply into the session history."""
    logger.debug("llm结果返回函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (shutdown_event.is_set() and llm_response_q.empty()):
        llm_frame = await llm_response_q.get()
        llm_response += llm_frame['message']
        sentences, current_sentence, is_first = split_string_with_punctuation(
            current_sentence, llm_frame['message'], is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
        if llm_frame['is_end']:
            is_first = True
            session_content = get_session_content(session_id, redis, db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'assistant', "content": llm_response})
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)  # update conversation
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))  # update session
            logger.debug(f"llm返回结果: {llm_response}")
            llm_response = ""
            current_sentence = ""


# Speech synthesis and reply
async def voice_call_tts_handler(ws, tts_info, split_result_q, shutdown_event):
    """Synthesize each sentence and send the audio bytes followed by the
    text frame; closes the websocket after shutdown."""
    logger.debug("语音合成及返回函数启动")
    while not (shutdown_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"],
                                   tts_info["noise_scale"], tts_info["noise_scale_w"],
                                   tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # raw audio frame
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # text frame
        logger.debug(f"websocket返回:{sentence}")
    # FIX: the original called asyncio.sleep(0.5) without awaiting it, so the
    # coroutine was created and discarded and no delay ever happened.
    await asyncio.sleep(0.5)
    await ws.close()


async def voice_call_handler(ws, db, redis):
    """Voice-call endpoint: wires producer/ASR/LLM/split/TTS tasks together
    and keeps the connection until a close frame sets `shutdown_event`."""
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()
    asr_result_q = asyncio.Queue()
    llm_response_q = asyncio.Queue()
    split_result_q = asyncio.Queue()
    shutdown_event = asyncio.Event()
    future = asyncio.Future()

    asyncio.create_task(voice_call_audio_producer(ws, audio_q, future, shutdown_event))       # audio producer
    asyncio.create_task(voice_call_audio_consumer(audio_q, asr_result_q, shutdown_event))     # audio consumer
    # Fetch session content once the first frame has supplied the id.
    session_id = await future
    session_content = get_session_content(session_id, redis, db)
    tts_info = json.loads(session_content["tts_info"])
    llm_info = json.loads(session_content["llm_info"])
    asyncio.create_task(voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q,
                                               llm_response_q, shutdown_event))               # llm handler
    asyncio.create_task(voice_call_llm_response_consumer(session_id, redis, db, llm_response_q,
                                                         split_result_q, shutdown_event))     # llm splitter
    asyncio.create_task(voice_call_tts_handler(ws, tts_info, split_result_q, shutdown_event)) # tts responder
    while not shutdown_event.is_set():
        await asyncio.sleep(5)
    await ws.close()
    logger.debug("voice_call websocket 连接断开")
# -------------------------------------------------------------------------------------