From dba43836b619a0c69b9ae7c3521e8d9f04e7cb13 Mon Sep 17 00:00:00 2001 From: killua4396 <1223086337@qq.com> Date: Mon, 10 Jun 2024 02:28:16 +0800 Subject: [PATCH] =?UTF-8?q?debug:=20=E4=BF=AE=E5=A4=8D=E4=BA=86=E5=A4=9A?= =?UTF-8?q?=E4=B8=AAbug=201.=E5=AE=8C=E6=88=90=E4=BA=86=E7=AB=AF=E4=BE=A7?= =?UTF-8?q?=E5=8F=91=E9=80=81=E6=9D=82=E9=9F=B3=E8=AF=86=E5=88=AB=202.?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dminimax=E8=BF=94=E5=9B=9E=E5=B8=A7=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E9=94=99=E8=AF=AFbug=203.=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E8=BF=87=E6=85=A2=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 6 +++- app/concrete.py | 44 +++++++++++++++++++---------- app/dependency.py | 29 ++++++++++++++++++++ app/exception.py | 41 +++++++++++++++++++++++++++ app/public.py | 27 +----------------- config.py | 4 +-- main.py | 67 +++++++++++++++++++++++++++++---------------- test/test.py | 43 +++++++++++++++++++++++++---- utils/vits_utils.py | 20 ++++---------- 9 files changed, 193 insertions(+), 88 deletions(-) create mode 100644 app/dependency.py create mode 100644 app/exception.py diff --git a/.gitignore b/.gitignore index 2e3caf3..52f78dc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,8 @@ __pycache__/ /utils/vits_model/config.json /utils/vits_model/G_953000.pth -takway.db \ No newline at end of file +takway.db + +/storage/** + +/app.log \ No newline at end of file diff --git a/app/concrete.py b/app/concrete.py index fb812df..6fce8cd 100644 --- a/app/concrete.py +++ b/app/concrete.py @@ -2,6 +2,8 @@ from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_ from .model import Assistant from .abstract import * from .public import * +from .exception import * +from .dependency import get_logger from utils.vits_utils import TextToSpeech from config import Config import aiohttp @@ -14,6 +16,9 @@ import json vits = TextToSpeech() # ---------------------------------- # +# ---------- 初始化logger ---------- # +logger = get_logger() +# ---------------------------------- # #------ 具体 ASR, LLM, TTS 类 ------ # @@ -31,6 +36,8 @@ class XF_ASR(ASR): self.segment_start_time = None async def stream_recognize(self, chunk): + if self.status == FIRST_FRAME and chunk['meta_info']['is_end']: #如果是第一帧,且为end,则判断为杂音 + raise SideNoiseError() if self.websocket is None: #如果websocket未建立,则建立一个新的连接 self.websocket = await xf_asr_websocket_factory() if self.segment_start_time is None: #如果是第一段,则记录开始时间 @@ -50,8 +57,7 @@ class XF_ASR(ASR): self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv())) if self.current_message == "": raise AsrResultNoneError() - await self.websocket.close() - print("语音识别结束,用户消息:", self.current_message) + asyncio.create_task(self.websocket.close()) return [{"text":self.current_message, "audio":self.audio}] current_time = asyncio.get_event_loop().time() @@ -71,8 +77,7 @@ class MINIMAX_LLM(LLM): async def chat(self, assistant, prompt, db): llm_info = json.loads(assistant.llm_info) messages = json.loads(assistant.messages) - messages.append({"role":"user","content":prompt}) - assistant.messages = json.dumps(messages) + messages.append({'role':'user','content':prompt}) payload = json.dumps({ "model": llm_info['model'], "stream": True, @@ -104,17 +109,26 @@ class MINIMAX_LLM(LLM): def __parseChunk(self, llm_chunk): - result = "" - data=json.loads(llm_chunk.decode('utf-8')[6:]) - if data["object"] == "chat.completion": #如果是结束帧 - self.token = data['usage']['total_tokens'] - raise LLMResponseEnd() - elif data['object'] == 'chat.completion.chunk': - for choice in data['choices']: - result += choice['delta']['content'] - else: - raise UnkownLLMFrame() - return result + try: + result = "" + chunk_decoded = llm_chunk.decode('utf-8') + chunks = chunk_decoded.split('\n\n') + for chunk in chunks: + if not chunk: + continue + data=json.loads(chunk[6:]) + if data["object"] == "chat.completion": #如果是结束帧 + self.token = data['usage']['total_tokens'] + raise LLMResponseEnd() + elif data['object'] == 'chat.completion.chunk': + for choice in data['choices']: + result += choice['delta']['content'] + else: + raise UnkownLLMFrame() + return result + except (json.JSONDecodeError, KeyError): + logger.error(llm_chunk) + raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}") class VITS_TTS(TTS): def __init__(self): diff --git a/app/dependency.py b/app/dependency.py new file mode 100644 index 0000000..ee17b71 --- /dev/null +++ b/app/dependency.py @@ -0,0 +1,29 @@ +from config import Config +import logging + +#日志类 +class Logger: + def __init__(self): + self.logger = logging.getLogger(__name__) + self.logger.setLevel(Config.LOG_LEVEL) + self.logger.propagate = False + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + if not self.logger.handlers: # 检查是否已经有处理器 + # 输出到控制台 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + # 输出到文件 + file_handler = logging.FileHandler('app.log') + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + +logger = None + +def get_logger(): + global logger + if logger is None: + logger = Logger() + return logger.logger \ No newline at end of file diff --git a/app/exception.py b/app/exception.py new file mode 100644 index 0000000..ea02356 --- /dev/null +++ b/app/exception.py @@ -0,0 +1,41 @@ +# Asr结果为空异常 +class AsrResultNoneError(Exception): + def __init__(self, message="Asr Result is None!"): + super().__init__(message) + self.message = message + +# 如果asr_results中没有结果异常 +class NoAsrResultsError(Exception): + def __init__(self, message="No Asr Results!"): + super().__init__(message) + self.message = message + +# 未知LLM返回帧 +class UnkownLLMFrame(Exception): + def __init__(self, message="Unkown LLM Frame!"): + super().__init__(message) + self.message = message + +# 异常LLM返回帧 +class AbnormalLLMFrame(Exception): + def __init__(self, message="Abnormal LLM Frame!"): + super().__init__(message) + self.message = message + +# token超出阈值异常 +class TokenOutofRangeError(Exception): + def __init__(self, message="Token Out of Range!"): + super().__init__(message) + self.message = message + +# 接收到端侧杂音异常 +class SideNoiseError(Exception): + def __init__(self, message="Side Noise!"): + super().__init__(message) + self.message = message + +# 大模型返回结束(非异常) +class LLMResponseEnd(Exception): + def __init__(self, message="LLM Response End!"): + super().__init__(message) + self.message = message diff --git a/app/public.py b/app/public.py index 0a886be..24f92d1 100644 --- a/app/public.py +++ b/app/public.py @@ -2,32 +2,7 @@ from datetime import datetime import wave import json -# -------------- 公共类 ------------ # -class AsrResultNoneError(Exception): - def __init__(self, message="Asr Result is None!"): - super().__init__(message) - self.message = message - -class NoAsrResultsError(Exception): - def __init__(self, message="No Asr Results!"): - super().__init__(message) - self.message = message - -class LLMResponseEnd(Exception): - def __init__(self, message="LLM Response End!"): - super().__init__(message) - self.message = message - -class UnkownLLMFrame(Exception): - def __init__(self, message="Unkown LLM Frame!"): - super().__init__(message) - self.message = message - -class TokenOutofRangeError(Exception): - def __init__(self, message="Token Out of Range!"): - super().__init__(message) - self.message = message - +# -------------- 公共类 ------------ # class SentenceSegmentation(): def __init__(self,): self.is_first_sentence = True diff --git a/config.py b/config.py index 99431d2..4282b57 100644 --- a/config.py +++ b/config.py @@ -3,6 +3,7 @@ class Config: ASR = "XF" #在此处选择语音识别引擎 LLM = "MINIMAX" #在此处选择大模型 TTS = "VITS" #在此处选择语音合成引擎 + LOG_LEVEL = "DEBUG" class UVICORN: HOST = '0.0.0.0' PORT = 7878 @@ -16,5 +17,4 @@ class Config: VAD_EOS = 10000 class MINIMAX_LLM: API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg" - URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" - \ No newline at end of file + URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" \ No newline at end of file diff --git a/main.py b/main.py index ba5bea6..f88ab14 100644 --- a/main.py +++ b/main.py @@ -1,20 +1,26 @@ from fastapi import FastAPI, Depends, WebSocket, HTTPException from fastapi.middleware.cors import CORSMiddleware from config import Config -from app.concrete import Agent, AsrResultNoneError +from app.concrete import Agent from app.model import Assistant, User, get_db from app.schemas import * +from app.dependency import get_logger +from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError import uvicorn import uuid import json # 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ -def update_messages(messages, llm_text): +def update_messages(messages, kid_text,llm_text): messages = json.loads(messages) + messages.append({"role":"user","content":kid_text}) messages.append({"role":"assistant","content":llm_text}) return json.dumps(messages,ensure_ascii=False) # -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- +# 引入logger对象 +logger = get_logger() + # 创建FastAPI实例 app = FastAPI() @@ -89,6 +95,7 @@ async def update_assistant_system_prompt(id: str,request: update_assistant_syste assistant = db.query(Assistant).filter(Assistant.id == id).first() if assistant: assistant.system_prompt = request.system_prompt + assistant.messages = json.dumps([{"role":"system","content":assistant.system_prompt}],ensure_ascii=False) db.commit() return {"code":200,"msg":"success","data":{}} else: @@ -163,11 +170,13 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)): @app.websocket("/api/chat/streaming/temporary") async def streaming_chat(ws: WebSocket,db=Depends(get_db)): await ws.accept() - agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) - assistant = None - asr_results = [] - llm_text = "" + logger.debug("WebSocket连接成功") try: + agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) + assistant = None + asr_results = [] + llm_text = "" + logger.debug("开始进行ASR识别") while len(asr_results)==0: chunk = json.loads(await ws.receive_text()) if assistant is None: @@ -175,24 +184,36 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): agent.init_recorder(assistant.user_id) chunk["audio"] = agent.user_audio_process(chunk["audio"], db) asr_results = await agent.stream_recognize(chunk, db) + kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果 + logger.debug(f"ASR识别成功,识别结果为:{kid_text}") + prompt = agent.prompt_process(asr_results, db) + agent.recorder.input_text = prompt + logger.debug("开始调用大模型") + llm_frames = await agent.chat(assistant, prompt, db) + async for llm_frame in llm_frames: + resp_msgs = agent.llm_msg_process(llm_frame, db) + for resp_msg in resp_msgs: + llm_text += resp_msg + tts_audio = agent.synthetize(assistant, resp_msg, db) + agent.tts_audio_process(tts_audio, db) + await ws.send_bytes(agent.encode(resp_msg, tts_audio)) + logger.debug(f'websocket返回:{resp_msg}') + logger.debug(f"大模型返回结束,返回结果为:{llm_text}") + await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) + logger.debug("结束帧发送完毕") + assistant.messages = update_messages(assistant.messages, kid_text ,llm_text) + db.commit() + logger.debug("聊天更新成功") + agent.recorder.output_text = llm_text + agent.save() + logger.debug("音频保存成功") except AsrResultNoneError: - await ws.send_text(json.dumps({"type":"close","code":201,"msg":""}, ensure_ascii=False)) - return - prompt = agent.prompt_process(asr_results, db) - agent.recorder.input_text = prompt - llm_frames = await agent.chat(assistant, prompt, db) - async for llm_frame in llm_frames: - resp_msgs = agent.llm_msg_process(llm_frame, db) - for resp_msg in resp_msgs: - llm_text += resp_msg - tts_audio = agent.synthetize(assistant, resp_msg, db) - agent.tts_audio_process(tts_audio, db) - await ws.send_bytes(agent.encode(resp_msg, tts_audio)) - await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) - assistant.messages = update_messages(assistant.messages, llm_text) - db.commit() - agent.recorder.output_text = llm_text - agent.save() + await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False)) + except AbnormalLLMFrame as e: + await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False)) + except SideNoiseError as e: + await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False)) + logger.debug("WebSocket连接断开") await ws.close() # -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/test/test.py b/test/test.py index b9b17aa..4659ee4 100644 --- a/test/test.py +++ b/test/test.py @@ -1,14 +1,14 @@ import json +import time import base64 from datetime import datetime -import io from websocket import create_connection data = { "text": "", "audio": "", "meta_info": { - "session_id":"a36c9bb4-e813-4f0e-9c75-18e049c60f48", + "session_id":"469f4a99-12a5-45a6-bc91-353df07423b6", "stream": True, "voice_synthesize": True, "is_end": False, @@ -48,10 +48,17 @@ def send_json(): # 发送最后一个数据块和流结束信号 send_audio_chunk(websocket, b'') # 发送空数据块表示结束 # 等待并打印接收到的数据 - print("等待接收:", datetime.now()) + wait_start_time = time.time() + print("开始等待返回数据:", datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')) + audio_bytes = b'' + is_receive_first_frame = False while True: data_ws = websocket.recv() + if not is_receive_first_frame: + wait_end_time = time.time() + print("等待时间:", wait_end_time - wait_start_time) + is_receive_first_frame = True try: message_json = json.loads(data_ws) print(message_json) # 打印接收到的消息 @@ -59,10 +66,34 @@ def send_json(): break # 如果没有接收到消息,则退出循环 except Exception as e: audio_bytes += data_ws - - print(e) + # print(e) print("接收完毕:", datetime.now()) websocket.close() +# 模拟检测到杂音时,只发一帧end +def send_one_end_frame(): + websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary') + data["meta_info"]["is_end"] = True + send_audio_chunk(websocket, b'') # 发送空数据块表示结束 + audio_bytes = b'' + is_receive_first_frame = False + while True: + data_ws = websocket.recv() + if not is_receive_first_frame: + wait_end_time = time.time() + is_receive_first_frame = True + try: + message_json = json.loads(data_ws) + print(message_json) # 打印接收到的消息 + if message_json["type"] == "close": + break # 如果没有接收到消息,则退出循环 + except Exception as e: + audio_bytes += data_ws + # print(e) + websocket.close() + # 启动事件循环 -send_json() \ No newline at end of file +# send_json() + + +send_one_end_frame() \ No newline at end of file diff --git a/utils/vits_utils.py b/utils/vits_utils.py index 1ff919c..56168f8 100644 --- a/utils/vits_utils.py +++ b/utils/vits_utils.py @@ -9,21 +9,6 @@ from .vits import utils, commons from .vits.models import SynthesizerTrn from .vits.text import text_to_sequence -def tts_model_init(model_path='./vits_model', device='cuda'): - hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json')) - # hps_ms = utils.get_hparams_from_file('vits_model/config.json') - net_g_ms = SynthesizerTrn( - len(hps_ms.symbols), - hps_ms.data.filter_length // 2 + 1, - hps_ms.train.segment_size // hps_ms.data.hop_length, - n_speakers=hps_ms.data.n_speakers, - **hps_ms.model) - net_g_ms = net_g_ms.eval().to(device) - speakers = hps_ms.speakers - utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None) - # utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None) - return hps_ms, net_g_ms, speakers - class TextToSpeech: def __init__(self, model_path="./utils/vits_model", @@ -36,6 +21,11 @@ class TextToSpeech: self.device = torch.device(device) self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度 self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path) + self._init_jieba() + + def _init_jieba(self): + text = self._preprocess_text("初始化", 0) + self._generate_audio(text, 100, 0.6, 0.668, 1.0) def _tts_model_init(self, model_path): hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))