debug: fixed multiple bugs

1. Implemented recognition of device-side noise on send
2. Fixed a bug in parsing MiniMax response frames
3. Fixed a bug that made responses return too slowly
killua4396 2024-06-10 02:28:16 +08:00
parent 02272c8a8b
commit dba43836b6
9 changed files with 193 additions and 88 deletions

.gitignore vendored

@@ -3,4 +3,8 @@ __pycache__/
/utils/vits_model/config.json
/utils/vits_model/G_953000.pth
takway.db
/storage/**
/app.log

app/concrete.py

@@ -2,6 +2,8 @@ from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_
from .model import Assistant
from .abstract import *
from .public import *
from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
@@ -14,6 +16,9 @@ import json
vits = TextToSpeech()
# ---------------------------------- #
# ---------- initialize logger ---------- #
logger = get_logger()
# ---------------------------------- #
# ------ Concrete ASR, LLM, TTS classes ------ #
@@ -31,6 +36,8 @@ class XF_ASR(ASR):
self.segment_start_time = None
async def stream_recognize(self, chunk):
if self.status == FIRST_FRAME and chunk['meta_info']['is_end']: # a first frame already marked is_end is treated as device-side noise
raise SideNoiseError()
if self.websocket is None: # open a new connection if no websocket exists yet
self.websocket = await xf_asr_websocket_factory()
if self.segment_start_time is None: # record the start time of the first segment
@@ -50,8 +57,7 @@
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "":
raise AsrResultNoneError()
await self.websocket.close()
print("语音识别结束,用户消息:", self.current_message)
asyncio.create_task(self.websocket.close())
return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time()
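This hunk carries two of the commit's fixes: a session whose very first frame is already flagged is_end now raises SideNoiseError instead of opening a websocket for what is only device-side noise, and the Xunfei websocket is closed via asyncio.create_task rather than awaited, so the recognized text is returned without waiting out the close handshake. A minimal sketch of that fire-and-forget close pattern (finish_recognition is an illustrative name, not the project's API):

import asyncio

async def finish_recognition(websocket, message: str):
    # Awaiting websocket.close() blocks the caller for a network
    # round-trip; scheduling it as a background task lets the
    # recognized text be returned immediately.
    asyncio.create_task(websocket.close())
    return [{"text": message}]
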
@@ -71,8 +77,7 @@ class MINIMAX_LLM(LLM):
async def chat(self, assistant, prompt, db):
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({"role":"user","content":prompt})
assistant.messages = json.dumps(messages)
messages.append({'role':'user','content':prompt})
payload = json.dumps({
"model": llm_info['model'],
"stream": True,
@@ -104,17 +109,26 @@
def __parseChunk(self, llm_chunk):
result = ""
data=json.loads(llm_chunk.decode('utf-8')[6:])
if data["object"] == "chat.completion": #如果是结束帧
self.token = data['usage']['total_tokens']
raise LLMResponseEnd()
elif data['object'] == 'chat.completion.chunk':
for choice in data['choices']:
result += choice['delta']['content']
else:
raise UnkownLLMFrame()
return result
try:
result = ""
chunk_decoded = llm_chunk.decode('utf-8')
chunks = chunk_decoded.split('\n\n')
for chunk in chunks:
if not chunk:
continue
data=json.loads(chunk[6:])
if data["object"] == "chat.completion": #如果是结束帧
self.token = data['usage']['total_tokens']
raise LLMResponseEnd()
elif data['object'] == 'chat.completion.chunk':
for choice in data['choices']:
result += choice['delta']['content']
else:
raise UnkownLLMFrame()
return result
except (json.JSONDecodeError, KeyError):
logger.error(llm_chunk)
raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
class VITS_TTS(TTS):
def __init__(self):
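The rewritten __parseChunk above is the MiniMax frame-parsing fix: one network read from the streaming API may carry several server-sent events separated by blank lines, each prefixed with "data: ", so the old single json.loads over llm_chunk[6:] failed whenever two events arrived together. A standalone sketch of the same splitting logic, with an invented two-event payload for illustration:

import json

def parse_sse_payload(raw: bytes) -> list:
    # Split one network read into individual SSE events; the trailing
    # separator yields an empty piece, which is skipped.
    events = []
    for event in raw.decode('utf-8').split('\n\n'):
        if not event:
            continue
        events.append(json.loads(event[len('data: '):]))
    return events

raw = (b'data: {"object": "chat.completion.chunk", "choices": []}\n\n'
       b'data: {"object": "chat.completion", "usage": {"total_tokens": 42}}\n\n')
print(parse_sse_payload(raw))  # two parsed events from a single read
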

app/dependency.py Normal file

@@ -0,0 +1,29 @@
from config import Config
import logging
# logger class
class Logger:
def __init__(self):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(Config.LOG_LEVEL)
self.logger.propagate = False
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
if not self.logger.handlers: # avoid attaching handlers more than once
# log to the console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
# log to a file
file_handler = logging.FileHandler('app.log')
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
logger = None
def get_logger():
global logger
if logger is None:
logger = Logger()
return logger.logger
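get_logger() lazily builds one shared Logger, and together with the handler check above it guarantees log lines are not duplicated however many modules import it. A minimal usage sketch:

from app.dependency import get_logger

logger = get_logger()
logger.debug("module initialized")  # written to both the console and app.log

# Repeated calls return the same underlying logging.Logger,
# so handlers are attached exactly once per process.
assert get_logger() is logger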

app/exception.py Normal file

@@ -0,0 +1,41 @@
# raised when the ASR result is empty
class AsrResultNoneError(Exception):
def __init__(self, message="Asr Result is None!"):
super().__init__(message)
self.message = message
# raised when asr_results contains no results
class NoAsrResultsError(Exception):
def __init__(self, message="No Asr Results!"):
super().__init__(message)
self.message = message
# unknown LLM response frame
class UnkownLLMFrame(Exception):
def __init__(self, message="Unkown LLM Frame!"):
super().__init__(message)
self.message = message
# abnormal LLM response frame
class AbnormalLLMFrame(Exception):
def __init__(self, message="Abnormal LLM Frame!"):
super().__init__(message)
self.message = message
# raised when the token count exceeds the threshold
class TokenOutofRangeError(Exception):
def __init__(self, message="Token Out of Range!"):
super().__init__(message)
self.message = message
# raised when device-side noise is received
class SideNoiseError(Exception):
def __init__(self, message="Side Noise!"):
super().__init__(message)
self.message = message
# LLM response finished (not an error)
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):
super().__init__(message)
self.message = message
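These exceptions double as control flow for the websocket protocol: main.py (further below) maps each one to a close-frame code (200 normal end, 201 empty ASR result, 202 abnormal LLM frame, 203 device-side noise). A self-contained sketch of that mapping; the fallback code 500 is an assumption, not taken from the project:

import json

class AsrResultNoneError(Exception): pass
class AbnormalLLMFrame(Exception): pass
class SideNoiseError(Exception): pass

CLOSE_CODES = {AsrResultNoneError: 201, AbnormalLLMFrame: 202, SideNoiseError: 203}

def close_frame(exc: Exception) -> str:
    # Translate a pipeline exception into the close frame sent to the client.
    code = CLOSE_CODES.get(type(exc), 500)
    return json.dumps({"type": "close", "code": code, "msg": str(exc)}, ensure_ascii=False)

print(close_frame(SideNoiseError("Side Noise!")))  # {"type": "close", "code": 203, ...}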

app/public.py

@@ -2,32 +2,7 @@ from datetime import datetime
import wave
import json
# -------------- Public classes ------------ #
class AsrResultNoneError(Exception):
def __init__(self, message="Asr Result is None!"):
super().__init__(message)
self.message = message
class NoAsrResultsError(Exception):
def __init__(self, message="No Asr Results!"):
super().__init__(message)
self.message = message
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):
super().__init__(message)
self.message = message
class UnkownLLMFrame(Exception):
def __init__(self, message="Unkown LLM Frame!"):
super().__init__(message)
self.message = message
class TokenOutofRangeError(Exception):
def __init__(self, message="Token Out of Range!"):
super().__init__(message)
self.message = message
# -------------- Public classes ------------ #
class SentenceSegmentation():
def __init__(self,):
self.is_first_sentence = True

config.py

@@ -3,6 +3,7 @@ class Config:
ASR = "XF" #在此处选择语音识别引擎
LLM = "MINIMAX" #在此处选择大模型
TTS = "VITS" #在此处选择语音合成引擎
LOG_LEVEL = "DEBUG"
class UVICORN:
HOST = '0.0.0.0'
PORT = 7878
@@ -16,5 +17,4 @@ class Config:
VAD_EOS = 10000
class MINIMAX_LLM:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"

main.py

@@ -1,20 +1,26 @@
from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent, AsrResultNoneError
from app.concrete import Agent
from app.model import Assistant, User, get_db
from app.schemas import *
from app.dependency import get_logger
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
import uvicorn
import uuid
import json
# Helper functions ------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, llm_text):
def update_messages(messages, kid_text, llm_text):
messages = json.loads(messages)
messages.append({"role":"user","content":kid_text})
messages.append({"role":"assistant","content":llm_text})
return json.dumps(messages,ensure_ascii=False)
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# get the logger instance
logger = get_logger()
# create the FastAPI instance
app = FastAPI()
@@ -89,6 +95,7 @@ async def update_assistant_system_prompt(id: str,request: update_assistant_syste
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
assistant.system_prompt = request.system_prompt
assistant.messages = json.dumps([{"role":"system","content":assistant.system_prompt}],ensure_ascii=False)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
@@ -163,11 +170,13 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
await ws.accept()
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
assistant = None
asr_results = []
llm_text = ""
logger.debug("WebSocket连接成功")
try:
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
assistant = None
asr_results = []
llm_text = ""
logger.debug("开始进行ASR识别")
while len(asr_results)==0:
chunk = json.loads(await ws.receive_text())
if assistant is None:
@@ -175,24 +184,36 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db)
kid_text = asr_results[0]['text'] # asr_results[0] is by default the child's (primary user's) result
logger.debug(f"ASR succeeded, result: {kid_text}")
prompt = agent.prompt_process(asr_results, db)
agent.recorder.input_text = prompt
logger.debug("开始调用大模型")
llm_frames = await agent.chat(assistant, prompt, db)
async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
logger.debug(f'websocket sent: {resp_msg}')
logger.debug(f"LLM response finished, full text: {llm_text}")
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug("结束帧发送完毕")
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
db.commit()
logger.debug("聊天更新成功")
agent.recorder.output_text = llm_text
agent.save()
logger.debug("音频保存成功")
except AsrResultNoneError:
await ws.send_text(json.dumps({"type":"close","code":201,"msg":""}, ensure_ascii=False))
return
prompt = agent.prompt_process(asr_results, db)
agent.recorder.input_text = prompt
llm_frames = await agent.chat(assistant, prompt, db)
async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
assistant.messages = update_messages(assistant.messages, llm_text)
db.commit()
agent.recorder.output_text = llm_text
agent.save()
await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
except AbnormalLLMFrame as e:
await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
except SideNoiseError as e:
await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
logger.debug("WebSocket连接断开")
await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
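A subtle fix in this file: update_messages previously appended only the assistant turn, while the user turn was written eagerly inside MINIMAX_LLM.chat, so a request that failed mid-stream left the history out of sync. Both turns are now appended together after a successful exchange. A small demonstration with invented sample text:

import json

def update_messages(messages, kid_text, llm_text):
    messages = json.loads(messages)
    messages.append({"role": "user", "content": kid_text})
    messages.append({"role": "assistant", "content": llm_text})
    return json.dumps(messages, ensure_ascii=False)

history = json.dumps([{"role": "system", "content": "You are a friendly companion."}])
history = update_messages(history, "hello", "Hi there! Want to play a game?")
print(history)  # the system turn followed by a matched user/assistant pair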

(test client script; filename not shown)

@@ -1,14 +1,14 @@
import json
import time
import base64
from datetime import datetime
import io
from websocket import create_connection
data = {
"text": "",
"audio": "",
"meta_info": {
"session_id":"a36c9bb4-e813-4f0e-9c75-18e049c60f48",
"session_id":"469f4a99-12a5-45a6-bc91-353df07423b6",
"stream": True,
"voice_synthesize": True,
"is_end": False,
@@ -48,10 +48,17 @@ def send_json():
# send the final chunk and the end-of-stream signal
send_audio_chunk(websocket, b'') # an empty chunk marks the end
# wait for and print the received data
print("等待接收:", datetime.now())
wait_start_time = time.time()
print("开始等待返回数据:", datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))
audio_bytes = b''
is_receive_first_frame = False
while True:
data_ws = websocket.recv()
if not is_receive_first_frame:
wait_end_time = time.time()
print("等待时间:", wait_end_time - wait_start_time)
is_receive_first_frame = True
try:
message_json = json.loads(data_ws)
print(message_json) # print the received message
@ -59,10 +66,34 @@ def send_json():
break # exit the loop when the close frame arrives
except Exception as e:
audio_bytes += data_ws
print(e)
# print(e)
print("接收完毕:", datetime.now())
websocket.close()
# simulate noise detection by sending only a single end frame
def send_one_end_frame():
websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary')
data["meta_info"]["is_end"] = True
send_audio_chunk(websocket, b'') # an empty chunk marks the end
audio_bytes = b''
is_receive_first_frame = False
while True:
data_ws = websocket.recv()
if not is_receive_first_frame:
wait_end_time = time.time()
is_receive_first_frame = True
try:
message_json = json.loads(data_ws)
print(message_json) # print the received message
if message_json["type"] == "close":
break # exit the loop when the close frame arrives
except Exception as e:
audio_bytes += data_ws
# print(e)
websocket.close()
# run the test client
send_json()
# send_json()
send_one_end_frame()

utils/vits_utils.py

@@ -9,21 +9,6 @@ from .vits import utils, commons
from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence
def tts_model_init(model_path='./vits_model', device='cuda'):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
return hps_ms, net_g_ms, speakers
class TextToSpeech:
def __init__(self,
model_path="./utils/vits_model",
@@ -36,6 +21,11 @@ class TextToSpeech:
self.device = torch.device(device)
self.limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length on Hugging Face Spaces
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
self._init_jieba()
def _init_jieba(self):
text = self._preprocess_text("初始化", 0)
self._generate_audio(text, 100, 0.6, 0.668, 1.0)
def _tts_model_init(self, model_path):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
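The new _init_jieba above runs one throwaway synthesis in the constructor so that jieba's dictionary and other lazily initialized state are loaded at startup rather than on the first user request, one likely source of the slow responses named in the commit message. A generic sketch of the warm-up pattern (warm_up and tts.synthesize are illustrative names, not this project's API):

import time

def warm_up(tts) -> float:
    # One dummy synthesis forces lazy initialization to happen now,
    # so the first real request does not pay that cost.
    start = time.time()
    tts.synthesize("初始化")  # throwaway input; the output is discarded
    return time.time() - start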