debug: 修复了多个bug

1.完成了端侧发送杂音识别
2.修复minimax返回帧解析错误bug
3.修复了返回过慢的bug
This commit is contained in:
killua4396 2024-06-10 02:28:16 +08:00
parent 02272c8a8b
commit dba43836b6
9 changed files with 193 additions and 88 deletions

6
.gitignore vendored
View File

@ -3,4 +3,8 @@ __pycache__/
/utils/vits_model/config.json /utils/vits_model/config.json
/utils/vits_model/G_953000.pth /utils/vits_model/G_953000.pth
takway.db takway.db
/storage/**
/app.log

View File

@ -2,6 +2,8 @@ from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_
from .model import Assistant from .model import Assistant
from .abstract import * from .abstract import *
from .public import * from .public import *
from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech from utils.vits_utils import TextToSpeech
from config import Config from config import Config
import aiohttp import aiohttp
@ -14,6 +16,9 @@ import json
vits = TextToSpeech() vits = TextToSpeech()
# ---------------------------------- # # ---------------------------------- #
# ---------- 初始化logger ---------- #
logger = get_logger()
# ---------------------------------- #
#------ 具体 ASR, LLM, TTS 类 ------ # #------ 具体 ASR, LLM, TTS 类 ------ #
@ -31,6 +36,8 @@ class XF_ASR(ASR):
self.segment_start_time = None self.segment_start_time = None
async def stream_recognize(self, chunk): async def stream_recognize(self, chunk):
if self.status == FIRST_FRAME and chunk['meta_info']['is_end']: #如果是第一帧且为end则判断为杂音
raise SideNoiseError()
if self.websocket is None: #如果websocket未建立则建立一个新的连接 if self.websocket is None: #如果websocket未建立则建立一个新的连接
self.websocket = await xf_asr_websocket_factory() self.websocket = await xf_asr_websocket_factory()
if self.segment_start_time is None: #如果是第一段,则记录开始时间 if self.segment_start_time is None: #如果是第一段,则记录开始时间
@ -50,8 +57,7 @@ class XF_ASR(ASR):
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv())) self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "": if self.current_message == "":
raise AsrResultNoneError() raise AsrResultNoneError()
await self.websocket.close() asyncio.create_task(self.websocket.close())
print("语音识别结束,用户消息:", self.current_message)
return [{"text":self.current_message, "audio":self.audio}] return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time() current_time = asyncio.get_event_loop().time()
@ -71,8 +77,7 @@ class MINIMAX_LLM(LLM):
async def chat(self, assistant, prompt, db): async def chat(self, assistant, prompt, db):
llm_info = json.loads(assistant.llm_info) llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages) messages = json.loads(assistant.messages)
messages.append({"role":"user","content":prompt}) messages.append({'role':'user','content':prompt})
assistant.messages = json.dumps(messages)
payload = json.dumps({ payload = json.dumps({
"model": llm_info['model'], "model": llm_info['model'],
"stream": True, "stream": True,
@ -104,17 +109,26 @@ class MINIMAX_LLM(LLM):
def __parseChunk(self, llm_chunk): def __parseChunk(self, llm_chunk):
result = "" try:
data=json.loads(llm_chunk.decode('utf-8')[6:]) result = ""
if data["object"] == "chat.completion": #如果是结束帧 chunk_decoded = llm_chunk.decode('utf-8')
self.token = data['usage']['total_tokens'] chunks = chunk_decoded.split('\n\n')
raise LLMResponseEnd() for chunk in chunks:
elif data['object'] == 'chat.completion.chunk': if not chunk:
for choice in data['choices']: continue
result += choice['delta']['content'] data=json.loads(chunk[6:])
else: if data["object"] == "chat.completion": #如果是结束帧
raise UnkownLLMFrame() self.token = data['usage']['total_tokens']
return result raise LLMResponseEnd()
elif data['object'] == 'chat.completion.chunk':
for choice in data['choices']:
result += choice['delta']['content']
else:
raise UnkownLLMFrame()
return result
except (json.JSONDecodeError, KeyError):
logger.error(llm_chunk)
raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
class VITS_TTS(TTS): class VITS_TTS(TTS):
def __init__(self): def __init__(self):

29
app/dependency.py Normal file
View File

@ -0,0 +1,29 @@
from config import Config
import logging
#日志类
class Logger:
def __init__(self):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(Config.LOG_LEVEL)
self.logger.propagate = False
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
if not self.logger.handlers: # 检查是否已经有处理器
# 输出到控制台
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
# 输出到文件
file_handler = logging.FileHandler('app.log')
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
logger = None
def get_logger():
global logger
if logger is None:
logger = Logger()
return logger.logger

41
app/exception.py Normal file
View File

@ -0,0 +1,41 @@
# Asr结果为空异常
class AsrResultNoneError(Exception):
def __init__(self, message="Asr Result is None!"):
super().__init__(message)
self.message = message
# 如果asr_results中没有结果异常
class NoAsrResultsError(Exception):
def __init__(self, message="No Asr Results!"):
super().__init__(message)
self.message = message
# 未知LLM返回帧
class UnkownLLMFrame(Exception):
def __init__(self, message="Unkown LLM Frame!"):
super().__init__(message)
self.message = message
# 异常LLM返回帧
class AbnormalLLMFrame(Exception):
def __init__(self, message="Abnormal LLM Frame!"):
super().__init__(message)
self.message = message
# token超出阈值异常
class TokenOutofRangeError(Exception):
def __init__(self, message="Token Out of Range!"):
super().__init__(message)
self.message = message
# 接收到端侧杂音异常
class SideNoiseError(Exception):
def __init__(self, message="Side Noise!"):
super().__init__(message)
self.message = message
# 大模型返回结束(非异常)
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):
super().__init__(message)
self.message = message

View File

@ -2,32 +2,7 @@ from datetime import datetime
import wave import wave
import json import json
# -------------- 公共类 ------------ # # -------------- 公共类 ------------ #
class AsrResultNoneError(Exception):
def __init__(self, message="Asr Result is None!"):
super().__init__(message)
self.message = message
class NoAsrResultsError(Exception):
def __init__(self, message="No Asr Results!"):
super().__init__(message)
self.message = message
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):
super().__init__(message)
self.message = message
class UnkownLLMFrame(Exception):
def __init__(self, message="Unkown LLM Frame!"):
super().__init__(message)
self.message = message
class TokenOutofRangeError(Exception):
def __init__(self, message="Token Out of Range!"):
super().__init__(message)
self.message = message
class SentenceSegmentation(): class SentenceSegmentation():
def __init__(self,): def __init__(self,):
self.is_first_sentence = True self.is_first_sentence = True

View File

@ -3,6 +3,7 @@ class Config:
ASR = "XF" #在此处选择语音识别引擎 ASR = "XF" #在此处选择语音识别引擎
LLM = "MINIMAX" #在此处选择大模型 LLM = "MINIMAX" #在此处选择大模型
TTS = "VITS" #在此处选择语音合成引擎 TTS = "VITS" #在此处选择语音合成引擎
LOG_LEVEL = "DEBUG"
class UVICORN: class UVICORN:
HOST = '0.0.0.0' HOST = '0.0.0.0'
PORT = 7878 PORT = 7878
@ -16,5 +17,4 @@ class Config:
VAD_EOS = 10000 VAD_EOS = 10000
class MINIMAX_LLM: class MINIMAX_LLM:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg" API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"

67
main.py
View File

@ -1,20 +1,26 @@
from fastapi import FastAPI, Depends, WebSocket, HTTPException from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from config import Config from config import Config
from app.concrete import Agent, AsrResultNoneError from app.concrete import Agent
from app.model import Assistant, User, get_db from app.model import Assistant, User, get_db
from app.schemas import * from app.schemas import *
from app.dependency import get_logger
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
import uvicorn import uvicorn
import uuid import uuid
import json import json
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ # 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, llm_text): def update_messages(messages, kid_text,llm_text):
messages = json.loads(messages) messages = json.loads(messages)
messages.append({"role":"user","content":kid_text})
messages.append({"role":"assistant","content":llm_text}) messages.append({"role":"assistant","content":llm_text})
return json.dumps(messages,ensure_ascii=False) return json.dumps(messages,ensure_ascii=False)
# -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 引入logger对象
logger = get_logger()
# 创建FastAPI实例 # 创建FastAPI实例
app = FastAPI() app = FastAPI()
@ -89,6 +95,7 @@ async def update_assistant_system_prompt(id: str,request: update_assistant_syste
assistant = db.query(Assistant).filter(Assistant.id == id).first() assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant: if assistant:
assistant.system_prompt = request.system_prompt assistant.system_prompt = request.system_prompt
assistant.messages = json.dumps([{"role":"system","content":assistant.system_prompt}],ensure_ascii=False)
db.commit() db.commit()
return {"code":200,"msg":"success","data":{}} return {"code":200,"msg":"success","data":{}}
else: else:
@ -163,11 +170,13 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
@app.websocket("/api/chat/streaming/temporary") @app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)): async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
await ws.accept() await ws.accept()
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS) logger.debug("WebSocket连接成功")
assistant = None
asr_results = []
llm_text = ""
try: try:
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
assistant = None
asr_results = []
llm_text = ""
logger.debug("开始进行ASR识别")
while len(asr_results)==0: while len(asr_results)==0:
chunk = json.loads(await ws.receive_text()) chunk = json.loads(await ws.receive_text())
if assistant is None: if assistant is None:
@ -175,24 +184,36 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
agent.init_recorder(assistant.user_id) agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db) chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db) asr_results = await agent.stream_recognize(chunk, db)
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
logger.debug(f"ASR识别成功识别结果为{kid_text}")
prompt = agent.prompt_process(asr_results, db)
agent.recorder.input_text = prompt
logger.debug("开始调用大模型")
llm_frames = await agent.chat(assistant, prompt, db)
async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
logger.debug(f'websocket返回{resp_msg}')
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug("结束帧发送完毕")
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
db.commit()
logger.debug("聊天更新成功")
agent.recorder.output_text = llm_text
agent.save()
logger.debug("音频保存成功")
except AsrResultNoneError: except AsrResultNoneError:
await ws.send_text(json.dumps({"type":"close","code":201,"msg":""}, ensure_ascii=False)) await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
return except AbnormalLLMFrame as e:
prompt = agent.prompt_process(asr_results, db) await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
agent.recorder.input_text = prompt except SideNoiseError as e:
llm_frames = await agent.chat(assistant, prompt, db) await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
async for llm_frame in llm_frames: logger.debug("WebSocket连接断开")
resp_msgs = agent.llm_msg_process(llm_frame, db)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
assistant.messages = update_messages(assistant.messages, llm_text)
db.commit()
agent.recorder.output_text = llm_text
agent.save()
await ws.close() await ws.close()
# -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

View File

@ -1,14 +1,14 @@
import json import json
import time
import base64 import base64
from datetime import datetime from datetime import datetime
import io
from websocket import create_connection from websocket import create_connection
data = { data = {
"text": "", "text": "",
"audio": "", "audio": "",
"meta_info": { "meta_info": {
"session_id":"a36c9bb4-e813-4f0e-9c75-18e049c60f48", "session_id":"469f4a99-12a5-45a6-bc91-353df07423b6",
"stream": True, "stream": True,
"voice_synthesize": True, "voice_synthesize": True,
"is_end": False, "is_end": False,
@ -48,10 +48,17 @@ def send_json():
# 发送最后一个数据块和流结束信号 # 发送最后一个数据块和流结束信号
send_audio_chunk(websocket, b'') # 发送空数据块表示结束 send_audio_chunk(websocket, b'') # 发送空数据块表示结束
# 等待并打印接收到的数据 # 等待并打印接收到的数据
print("等待接收:", datetime.now()) wait_start_time = time.time()
print("开始等待返回数据:", datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))
audio_bytes = b'' audio_bytes = b''
is_receive_first_frame = False
while True: while True:
data_ws = websocket.recv() data_ws = websocket.recv()
if not is_receive_first_frame:
wait_end_time = time.time()
print("等待时间:", wait_end_time - wait_start_time)
is_receive_first_frame = True
try: try:
message_json = json.loads(data_ws) message_json = json.loads(data_ws)
print(message_json) # 打印接收到的消息 print(message_json) # 打印接收到的消息
@ -59,10 +66,34 @@ def send_json():
break # 如果没有接收到消息,则退出循环 break # 如果没有接收到消息,则退出循环
except Exception as e: except Exception as e:
audio_bytes += data_ws audio_bytes += data_ws
# print(e)
print(e)
print("接收完毕:", datetime.now()) print("接收完毕:", datetime.now())
websocket.close() websocket.close()
# 模拟检测到杂音时只发一帧end
def send_one_end_frame():
websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary')
data["meta_info"]["is_end"] = True
send_audio_chunk(websocket, b'') # 发送空数据块表示结束
audio_bytes = b''
is_receive_first_frame = False
while True:
data_ws = websocket.recv()
if not is_receive_first_frame:
wait_end_time = time.time()
is_receive_first_frame = True
try:
message_json = json.loads(data_ws)
print(message_json) # 打印接收到的消息
if message_json["type"] == "close":
break # 如果没有接收到消息,则退出循环
except Exception as e:
audio_bytes += data_ws
# print(e)
websocket.close()
# 启动事件循环 # 启动事件循环
send_json() # send_json()
send_one_end_frame()

View File

@ -9,21 +9,6 @@ from .vits import utils, commons
from .vits.models import SynthesizerTrn from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence from .vits.text import text_to_sequence
def tts_model_init(model_path='./vits_model', device='cuda'):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
return hps_ms, net_g_ms, speakers
class TextToSpeech: class TextToSpeech:
def __init__(self, def __init__(self,
model_path="./utils/vits_model", model_path="./utils/vits_model",
@ -36,6 +21,11 @@ class TextToSpeech:
self.device = torch.device(device) self.device = torch.device(device)
self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度 self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path) self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
self._init_jieba()
def _init_jieba(self):
text = self._preprocess_text("初始化", 0)
self._generate_audio(text, 100, 0.6, 0.668, 1.0)
def _tts_model_init(self, model_path): def _tts_model_init(self, model_path):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json')) hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))