debug: fix a bug where voice cloning produced incorrect output

killua4396 2024-05-23 15:19:21 +08:00
parent 3b9cc44e4c
commit e2f3decfae
3 changed files with 37 additions and 19 deletions


@@ -15,6 +15,7 @@ import uuid
 import json
 import asyncio
 import aiohttp
+import io

 # 依赖注入获取logger
 logger = get_logger()
@@ -107,10 +108,10 @@ def get_emb(session_id,db):
     try:
         session_record = db.query(Session).filter(Session.id == session_id).first()
         user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
-        audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
-        emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
-        emb_npy_3d = np.reshape(emb_npy,(1,256,1))
-        return emb_npy_3d
+        user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
+        audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
+        emb_npy = np.load(io.BytesIO(audio_record.emb_data))
+        return emb_npy
     except Exception as e:
         logger.error("未找到音频:"+str(e))
         return np.array([])
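The core of the fix shows in this hunk together with the upload handler further down: the speaker embedding is now looked up through the user's selected_audio_id and deserialized with np.load, instead of reinterpreting raw bytes as int32 and reshaping them to a hard-coded (1,256,1). np.save/np.load keep the array's dtype and shape inside the stored bytes, so the read side no longer has to guess them. A minimal sketch of that round-trip; the float32 dtype and the (1, 256, 1) shape below are illustrative assumptions, not values confirmed by this diff:

import io
import numpy as np

# Stand-in for the embedding returned by tts.audio2emb(); dtype/shape are assumed.
emb = np.random.rand(1, 256, 1).astype(np.float32)

# Write side (upload_audio_handler): np.save records dtype and shape in the payload.
buf = io.BytesIO()
np.save(buf, emb)
emb_binary = buf.getvalue()  # bytes stored in Audio.emb_data

# Read side (get_emb): np.load restores the identical array.
restored = np.load(io.BytesIO(emb_binary))
assert restored.dtype == emb.dtype and restored.shape == emb.shape

# If the embedding is a float array, the old np.frombuffer(..., dtype=np.int32)
# read would reinterpret its bytes as integers and garble the cloned voice.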
@@ -149,7 +150,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
         "noise_scale": 0.1,
         "noise_scale_w":0.668,
         "length_scale": 1.2,
-        "speend":1
+        "speed":1
     }
     llm_info = {
         "model": "abab5.5-chat",
@@ -243,7 +244,7 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
         logger.error(f"用户输入处理函数发生错误: {str(e)}")

 #语音识别
-async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
+async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("语音识别函数启动")
     is_signup = False
     audio = ""
@@ -259,6 +260,9 @@ async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_
             current_message += ''.join(asr_result['text'])
     asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
     current_message += ''.join(asr_result['text'])
+    if current_message == "":
+        await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
+        return
     current_message = asr.punctuation_correction(current_message)
     emotion_dict = asr.emtion_recognition(audio) #情感辨识
     if not isinstance(emotion_dict, str):
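With the WebSocket handle now passed into sct_asr_handler, the handler can tell the client that nothing was recognized: if the accumulated transcription is empty it sends a JSON message of type "close" with code 201 and returns early instead of pushing an empty prompt to the LLM. A hypothetical sketch of how a client might react to that message; the websockets library usage, function name, and URI handling are assumptions, not part of this commit:

import json
import websockets

async def wait_for_reply(uri):
    # Connect to the streaming chat endpoint (URI assumed) and read messages.
    async with websockets.connect(uri) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("type") == "close" and msg.get("code") == 201:
                # Server recognized no speech in the uploaded audio; stop waiting.
                return None
            # ... handle normal text/audio reply messages here ...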
@@ -366,7 +370,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
         session_id = await future_session_id #获取session_id
         update_session_activity(session_id,db)
         response_type = await future_response_type #获取返回类型
-        asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
+        asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event))
         tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
         llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])


@@ -163,19 +163,17 @@ async def get_hardware_handler(hardware_id, db):
 async def upload_audio_handler(user_id, audio, db):
     try:
         audio_data = await audio.read()
-        raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
-        numpy_data = np.frombuffer(raw_data, dtype=np.int32)
-        emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
-        new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频
-        db.flush()
-        existing_user = db.query(User).filter(User.id == user_id).first()
-        existing_user.selected_audio_id = new_audio.id #绑定音频到用户
+        emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
+        out = io.BytesIO()
+        np.save(out, emb_data)
+        out.seek(0)
+        emb_binary = out.read()
+        new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #创建音频
         db.add(new_audio)
         db.commit()
         db.refresh(new_audio)
     except Exception as e:
         db.rollback()
-        print(str(e))
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
     audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
     return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
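The upload handler no longer binds the freshly uploaded audio to the user: the db.flush() call and the existing_user.selected_audio_id assignment are gone, along with the stray print(). Binding is now a separate step, which the updated test below exercises with a POST to /users/audio/bind. A hypothetical sketch of what that bind handler presumably does, reconstructed from the assignment removed here; the function name and exact signature are assumptions, since the handler itself is not shown in this diff:

async def bind_audio_handler(user_id, audio_id, db):
    try:
        # Point the user's selected_audio_id at the chosen audio record
        # (what upload_audio_handler previously did inline).
        user_record = db.query(User).filter(User.id == user_id).first()
        user_record.selected_audio_id = audio_id
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))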


@@ -63,7 +63,7 @@ class ChatServiceTest:
         current_file_path = os.path.abspath(__file__)
         current_file_path = os.path.dirname(current_file_path)
         tests_dir = os.path.dirname(current_file_path)
-        mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
+        mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3')
         with open(mp3_file_path, 'rb') as audio_file:
             files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
             response = requests.post(url,files=files)
@@ -73,6 +73,22 @@ class ChatServiceTest:
         else:
             raise Exception("音频上传失败")

+        #绑定音频
+        url = f"{self.socket}/users/audio/bind"
+        payload = json.dumps({
+            "user_id":self.user_id,
+            "audio_id":self.audio_id
+        })
+        headers = {
+            'Content-Type': 'application/json'
+        }
+        response = requests.request("POST", url, headers=headers, data=payload)
+        if response.status_code == 200:
+            print("音频绑定测试成功")
+        else:
+            raise Exception("音频绑定测试失败")
+
         #创建一个对话
         url = f"{self.socket}/chats"
         payload = json.dumps({
@@ -339,10 +355,10 @@ def chat_test():
     chat_service_test.test_create_chat()
     chat_service_test.test_session_id_query()
     chat_service_test.test_session_content_query()
-    chat_service_test.test_session_update()
+    # chat_service_test.test_session_update()
     asyncio.run(chat_service_test.test_chat_temporary())
-    asyncio.run(chat_service_test.test_chat_lasting())
-    asyncio.run(chat_service_test.test_voice_call())
+    # asyncio.run(chat_service_test.test_chat_lasting())
+    # asyncio.run(chat_service_test.test_voice_call())
     chat_service_test.test_chat_delete()