debug: 修复了多个bug
1.完成了端侧发送杂音识别 2.修复minimax返回帧解析错误bug 3.修复了返回过慢的bug
This commit is contained in:
parent
02272c8a8b
commit
dba43836b6
|
@ -3,4 +3,8 @@ __pycache__/
|
|||
/utils/vits_model/config.json
|
||||
/utils/vits_model/G_953000.pth
|
||||
|
||||
takway.db
|
||||
takway.db
|
||||
|
||||
/storage/**
|
||||
|
||||
/app.log
|
|
@ -2,6 +2,8 @@ from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_
|
|||
from .model import Assistant
|
||||
from .abstract import *
|
||||
from .public import *
|
||||
from .exception import *
|
||||
from .dependency import get_logger
|
||||
from utils.vits_utils import TextToSpeech
|
||||
from config import Config
|
||||
import aiohttp
|
||||
|
@ -14,6 +16,9 @@ import json
|
|||
vits = TextToSpeech()
|
||||
# ---------------------------------- #
|
||||
|
||||
# ---------- 初始化logger ---------- #
|
||||
logger = get_logger()
|
||||
# ---------------------------------- #
|
||||
|
||||
|
||||
#------ 具体 ASR, LLM, TTS 类 ------ #
|
||||
|
@ -31,6 +36,8 @@ class XF_ASR(ASR):
|
|||
self.segment_start_time = None
|
||||
|
||||
async def stream_recognize(self, chunk):
|
||||
if self.status == FIRST_FRAME and chunk['meta_info']['is_end']: #如果是第一帧,且为end,则判断为杂音
|
||||
raise SideNoiseError()
|
||||
if self.websocket is None: #如果websocket未建立,则建立一个新的连接
|
||||
self.websocket = await xf_asr_websocket_factory()
|
||||
if self.segment_start_time is None: #如果是第一段,则记录开始时间
|
||||
|
@ -50,8 +57,7 @@ class XF_ASR(ASR):
|
|||
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
|
||||
if self.current_message == "":
|
||||
raise AsrResultNoneError()
|
||||
await self.websocket.close()
|
||||
print("语音识别结束,用户消息:", self.current_message)
|
||||
asyncio.create_task(self.websocket.close())
|
||||
return [{"text":self.current_message, "audio":self.audio}]
|
||||
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
|
@ -71,8 +77,7 @@ class MINIMAX_LLM(LLM):
|
|||
async def chat(self, assistant, prompt, db):
|
||||
llm_info = json.loads(assistant.llm_info)
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({"role":"user","content":prompt})
|
||||
assistant.messages = json.dumps(messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
payload = json.dumps({
|
||||
"model": llm_info['model'],
|
||||
"stream": True,
|
||||
|
@ -104,17 +109,26 @@ class MINIMAX_LLM(LLM):
|
|||
|
||||
|
||||
def __parseChunk(self, llm_chunk):
    """Extract the reply text carried by one raw streaming chunk.

    ``llm_chunk`` is the undecoded bytes of one read from the LLM response
    stream. It is decoded as UTF-8 and split on blank lines, because one
    read may carry several ``data: {...}`` events.

    Returns:
        The concatenated ``delta.content`` of every ``chat.completion.chunk``
        event found in ``llm_chunk`` (an empty string when there is none).

    Raises:
        LLMResponseEnd: a ``chat.completion`` event (the final frame) was
            seen; ``self.token`` is set to its ``usage.total_tokens`` first.
        UnkownLLMFrame: an event carries any other ``object`` value.
        AbnormalLLMFrame: the chunk is not valid UTF-8, an event is not valid
            JSON, or an expected key is missing.
    """
    try:
        result = ""
        for event in llm_chunk.decode('utf-8').split('\n\n'):
            # Splitting leaves empty or newline-only fragments around the
            # separators; they carry no event and must not reach json.loads.
            event = event.strip()
            if not event:
                continue
            # Every event is expected to start with the 6-character
            # "data: " field prefix; the JSON payload follows it.
            data = json.loads(event[len("data: "):])
            if data["object"] == "chat.completion":  # final frame
                self.token = data['usage']['total_tokens']
                raise LLMResponseEnd()
            elif data['object'] == 'chat.completion.chunk':
                result += "".join(
                    choice['delta']['content'] for choice in data['choices']
                )
            else:
                raise UnkownLLMFrame()
        return result
    except (UnicodeDecodeError, json.JSONDecodeError, KeyError) as err:
        # Keep the raw bytes in the log so the malformed frame can be inspected.
        logger.error(llm_chunk)
        raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}") from err
|
||||
|
||||
class VITS_TTS(TTS):
|
||||
def __init__(self):
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
from config import Config
|
||||
import logging
|
||||
|
||||
# Logging wrapper
class Logger:
    """Create and configure the application's ``logging.Logger``.

    The configured logger is exposed as ``self.logger``. Its level comes from
    ``Config.LOG_LEVEL`` and every record is written both to the console and
    to ``app.log`` using the same format.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(Config.LOG_LEVEL)
        # Do not hand records on to ancestor loggers' handlers.
        self.logger.propagate = False
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

        # getLogger() returns the same object for the same name, so handlers
        # are attached only when none are present yet.
        if not self.logger.handlers:
            # Console output
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

            # File output. The project logs non-ASCII (Chinese) messages, so
            # the encoding is pinned instead of depending on the platform
            # locale default.
            file_handler = logging.FileHandler('app.log', encoding='utf-8')
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)
|
||||
|
||||
# Module-wide Logger wrapper, created on the first get_logger() call.
logger = None


def get_logger():
    """Return the shared ``logging.Logger``, building its wrapper on first use."""
    global logger
    if logger is not None:
        return logger.logger
    logger = Logger()
    return logger.logger
|
|
@ -0,0 +1,41 @@
|
|||
# The ASR result is empty
class AsrResultNoneError(Exception):
    """Raised when speech recognition ends with an empty result."""

    def __init__(self, message="Asr Result is None!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# asr_results contains no result
class NoAsrResultsError(Exception):
    """Raised when ``asr_results`` holds no result."""

    def __init__(self, message="No Asr Results!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# Unknown frame returned by the LLM
class UnkownLLMFrame(Exception):
    """Raised when an LLM response frame has an unrecognised type."""

    def __init__(self, message="Unkown LLM Frame!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# Abnormal frame returned by the LLM
class AbnormalLLMFrame(Exception):
    """Raised when an LLM response frame cannot be parsed as expected."""

    def __init__(self, message="Abnormal LLM Frame!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# Token count exceeds the threshold
class TokenOutofRangeError(Exception):
    """Raised when the token count goes beyond the allowed threshold."""

    def __init__(self, message="Token Out of Range!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# Noise received from the device side
class SideNoiseError(Exception):
    """Raised when the audio sent by the device is judged to be noise."""

    def __init__(self, message="Side Noise!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
||||
|
||||
# The LLM response has finished (not an error)
class LLMResponseEnd(Exception):
    """Signal that the LLM response stream has ended; not an error condition."""

    def __init__(self, message="LLM Response End!"):
        # Keep the text on the instance as well as in ``args``.
        self.message = message
        Exception.__init__(self, message)
|
|
@ -2,32 +2,7 @@ from datetime import datetime
|
|||
import wave
|
||||
import json
|
||||
|
||||
# -------------- 公共类 ------------ #
|
||||
class AsrResultNoneError(Exception):
|
||||
def __init__(self, message="Asr Result is None!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
class NoAsrResultsError(Exception):
|
||||
def __init__(self, message="No Asr Results!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
class LLMResponseEnd(Exception):
|
||||
def __init__(self, message="LLM Response End!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
class UnkownLLMFrame(Exception):
|
||||
def __init__(self, message="Unkown LLM Frame!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
class TokenOutofRangeError(Exception):
|
||||
def __init__(self, message="Token Out of Range!"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
# -------------- 公共类 ------------ #
|
||||
class SentenceSegmentation():
|
||||
def __init__(self,):
|
||||
self.is_first_sentence = True
|
||||
|
|
|
@ -3,6 +3,7 @@ class Config:
|
|||
ASR = "XF" #在此处选择语音识别引擎
|
||||
LLM = "MINIMAX" #在此处选择大模型
|
||||
TTS = "VITS" #在此处选择语音合成引擎
|
||||
LOG_LEVEL = "DEBUG"
|
||||
class UVICORN:
|
||||
HOST = '0.0.0.0'
|
||||
PORT = 7878
|
||||
|
@ -16,5 +17,4 @@ class Config:
|
|||
VAD_EOS = 10000
|
||||
class MINIMAX_LLM:
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
|
||||
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
|
||||
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
67
main.py
67
main.py
|
@ -1,20 +1,26 @@
|
|||
from fastapi import FastAPI, Depends, WebSocket, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from config import Config
|
||||
from app.concrete import Agent, AsrResultNoneError
|
||||
from app.concrete import Agent
|
||||
from app.model import Assistant, User, get_db
|
||||
from app.schemas import *
|
||||
from app.dependency import get_logger
|
||||
from app.exception import AsrResultNoneError, AbnormalLLMFrame, SideNoiseError
|
||||
import uvicorn
|
||||
import uuid
|
||||
import json
|
||||
|
||||
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
def update_messages(messages, llm_text):
|
||||
def update_messages(messages, kid_text, llm_text):
    """Append one user/assistant exchange to a JSON-encoded message list.

    Args:
        messages: JSON string holding the list of chat messages so far.
        kid_text: text stored as the new ``user`` message.
        llm_text: text stored as the new ``assistant`` message.

    Returns:
        The extended list serialised back to JSON, with non-ASCII characters
        written as-is rather than escaped.
    """
    history = json.loads(messages)
    history.extend([
        {"role": "user", "content": kid_text},
        {"role": "assistant", "content": llm_text},
    ])
    return json.dumps(history, ensure_ascii=False)
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# 引入logger对象
|
||||
logger = get_logger()
|
||||
|
||||
# 创建FastAPI实例
|
||||
app = FastAPI()
|
||||
|
||||
|
@ -89,6 +95,7 @@ async def update_assistant_system_prompt(id: str,request: update_assistant_syste
|
|||
assistant = db.query(Assistant).filter(Assistant.id == id).first()
|
||||
if assistant:
|
||||
assistant.system_prompt = request.system_prompt
|
||||
assistant.messages = json.dumps([{"role":"system","content":assistant.system_prompt}],ensure_ascii=False)
|
||||
db.commit()
|
||||
return {"code":200,"msg":"success","data":{}}
|
||||
else:
|
||||
|
@ -163,11 +170,13 @@ async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
|
|||
@app.websocket("/api/chat/streaming/temporary")
|
||||
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
||||
await ws.accept()
|
||||
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
||||
assistant = None
|
||||
asr_results = []
|
||||
llm_text = ""
|
||||
logger.debug("WebSocket连接成功")
|
||||
try:
|
||||
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
||||
assistant = None
|
||||
asr_results = []
|
||||
llm_text = ""
|
||||
logger.debug("开始进行ASR识别")
|
||||
while len(asr_results)==0:
|
||||
chunk = json.loads(await ws.receive_text())
|
||||
if assistant is None:
|
||||
|
@ -175,24 +184,36 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
|||
agent.init_recorder(assistant.user_id)
|
||||
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
||||
asr_results = await agent.stream_recognize(chunk, db)
|
||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||
logger.debug(f"ASR识别成功,识别结果为:{kid_text}")
|
||||
prompt = agent.prompt_process(asr_results, db)
|
||||
agent.recorder.input_text = prompt
|
||||
logger.debug("开始调用大模型")
|
||||
llm_frames = await agent.chat(assistant, prompt, db)
|
||||
async for llm_frame in llm_frames:
|
||||
resp_msgs = agent.llm_msg_process(llm_frame, db)
|
||||
for resp_msg in resp_msgs:
|
||||
llm_text += resp_msg
|
||||
tts_audio = agent.synthetize(assistant, resp_msg, db)
|
||||
agent.tts_audio_process(tts_audio, db)
|
||||
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
||||
logger.debug(f'websocket返回:{resp_msg}')
|
||||
logger.debug(f"大模型返回结束,返回结果为:{llm_text}")
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
logger.debug("结束帧发送完毕")
|
||||
assistant.messages = update_messages(assistant.messages, kid_text ,llm_text)
|
||||
db.commit()
|
||||
logger.debug("聊天更新成功")
|
||||
agent.recorder.output_text = llm_text
|
||||
agent.save()
|
||||
logger.debug("音频保存成功")
|
||||
except AsrResultNoneError:
|
||||
await ws.send_text(json.dumps({"type":"close","code":201,"msg":""}, ensure_ascii=False))
|
||||
return
|
||||
prompt = agent.prompt_process(asr_results, db)
|
||||
agent.recorder.input_text = prompt
|
||||
llm_frames = await agent.chat(assistant, prompt, db)
|
||||
async for llm_frame in llm_frames:
|
||||
resp_msgs = agent.llm_msg_process(llm_frame, db)
|
||||
for resp_msg in resp_msgs:
|
||||
llm_text += resp_msg
|
||||
tts_audio = agent.synthetize(assistant, resp_msg, db)
|
||||
agent.tts_audio_process(tts_audio, db)
|
||||
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
assistant.messages = update_messages(assistant.messages, llm_text)
|
||||
db.commit()
|
||||
agent.recorder.output_text = llm_text
|
||||
agent.save()
|
||||
await ws.send_text(json.dumps({"type":"close","code":201,"msg":"asr结果为空"}, ensure_ascii=False))
|
||||
except AbnormalLLMFrame as e:
|
||||
await ws.send_text(json.dumps({"type":"close","code":202,"msg":str(e)}, ensure_ascii=False))
|
||||
except SideNoiseError as e:
|
||||
await ws.send_text(json.dumps({"type":"close","code":203,"msg":str(e)}, ensure_ascii=False))
|
||||
logger.debug("WebSocket连接断开")
|
||||
await ws.close()
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
|
43
test/test.py
43
test/test.py
|
@ -1,14 +1,14 @@
|
|||
import json
|
||||
import time
|
||||
import base64
|
||||
from datetime import datetime
|
||||
import io
|
||||
from websocket import create_connection
|
||||
|
||||
data = {
|
||||
"text": "",
|
||||
"audio": "",
|
||||
"meta_info": {
|
||||
"session_id":"a36c9bb4-e813-4f0e-9c75-18e049c60f48",
|
||||
"session_id":"469f4a99-12a5-45a6-bc91-353df07423b6",
|
||||
"stream": True,
|
||||
"voice_synthesize": True,
|
||||
"is_end": False,
|
||||
|
@ -48,10 +48,17 @@ def send_json():
|
|||
# 发送最后一个数据块和流结束信号
|
||||
send_audio_chunk(websocket, b'') # 发送空数据块表示结束
|
||||
# 等待并打印接收到的数据
|
||||
print("等待接收:", datetime.now())
|
||||
wait_start_time = time.time()
|
||||
print("开始等待返回数据:", datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))
|
||||
|
||||
audio_bytes = b''
|
||||
is_receive_first_frame = False
|
||||
while True:
|
||||
data_ws = websocket.recv()
|
||||
if not is_receive_first_frame:
|
||||
wait_end_time = time.time()
|
||||
print("等待时间:", wait_end_time - wait_start_time)
|
||||
is_receive_first_frame = True
|
||||
try:
|
||||
message_json = json.loads(data_ws)
|
||||
print(message_json) # 打印接收到的消息
|
||||
|
@ -59,10 +66,34 @@ def send_json():
|
|||
break # 如果没有接收到消息,则退出循环
|
||||
except Exception as e:
|
||||
audio_bytes += data_ws
|
||||
|
||||
print(e)
|
||||
# print(e)
|
||||
print("接收完毕:", datetime.now())
|
||||
websocket.close()
|
||||
|
||||
# Simulate the device detecting noise: send a single frame flagged as end
def send_one_end_frame():
    """Send one empty end-flagged frame and print replies until a close frame.

    Sets ``data["meta_info"]["is_end"]`` to ``True`` on the module-level
    ``data`` dict (the change persists after the call), sends one empty audio
    chunk, then reads frames until a JSON message whose ``type`` is
    ``"close"`` arrives. The connection is closed even when sending or
    receiving fails.
    """
    websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary')
    try:
        data["meta_info"]["is_end"] = True
        send_audio_chunk(websocket, b'')  # an empty chunk marks the end of the stream
        audio_bytes = b''
        while True:
            data_ws = websocket.recv()
            try:
                message_json = json.loads(data_ws)
                print(message_json)  # show the received message
                if message_json["type"] == "close":
                    break  # the server signalled the end of the exchange
            except Exception:
                # Anything that is not a JSON message with a "type" field is
                # collected as audio payload.
                audio_bytes += data_ws
    finally:
        websocket.close()
|
||||
|
||||
# 启动事件循环
|
||||
send_json()
|
||||
# send_json()
|
||||
|
||||
|
||||
send_one_end_frame()
|
|
@ -9,21 +9,6 @@ from .vits import utils, commons
|
|||
from .vits.models import SynthesizerTrn
|
||||
from .vits.text import text_to_sequence
|
||||
|
||||
def tts_model_init(model_path='./vits_model', device='cuda'):
|
||||
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
|
||||
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
|
||||
net_g_ms = SynthesizerTrn(
|
||||
len(hps_ms.symbols),
|
||||
hps_ms.data.filter_length // 2 + 1,
|
||||
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
||||
n_speakers=hps_ms.data.n_speakers,
|
||||
**hps_ms.model)
|
||||
net_g_ms = net_g_ms.eval().to(device)
|
||||
speakers = hps_ms.speakers
|
||||
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
|
||||
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
|
||||
return hps_ms, net_g_ms, speakers
|
||||
|
||||
class TextToSpeech:
|
||||
def __init__(self,
|
||||
model_path="./utils/vits_model",
|
||||
|
@ -36,6 +21,11 @@ class TextToSpeech:
|
|||
self.device = torch.device(device)
|
||||
self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度
|
||||
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
|
||||
self._init_jieba()
|
||||
|
||||
def _init_jieba(self):
    """Synthesize a fixed short phrase once and discard the audio.

    NOTE(review): presumably a warm-up so the first real request does not pay
    jieba's start-up cost; confirm against ``_preprocess_text``. The numeric
    arguments are passed to ``_generate_audio`` positionally; their meaning
    is not visible here.
    """
    warmup_text = self._preprocess_text("初始化", 0)
    self._generate_audio(warmup_text, 100, 0.6, 0.668, 1.0)
|
||||
|
||||
def _tts_model_init(self, model_path):
|
||||
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
|
||||
|
|
Loading…
Reference in New Issue