forked from killua/TakwayPlatform
debug: 修复语音克隆输出错误的bug
This commit is contained in:
parent
3b9cc44e4c
commit
e2f3decfae
|
@ -15,6 +15,7 @@ import uuid
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import io
|
||||||
|
|
||||||
# 依赖注入获取logger
|
# 依赖注入获取logger
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
@ -107,10 +108,10 @@ def get_emb(session_id,db):
|
||||||
try:
|
try:
|
||||||
session_record = db.query(Session).filter(Session.id == session_id).first()
|
session_record = db.query(Session).filter(Session.id == session_id).first()
|
||||||
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
|
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
|
||||||
audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
|
user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
|
||||||
emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
|
audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
|
||||||
emb_npy_3d = np.reshape(emb_npy,(1,256,1))
|
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
|
||||||
return emb_npy_3d
|
return emb_npy
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("未找到音频:"+str(e))
|
logger.error("未找到音频:"+str(e))
|
||||||
return np.array([])
|
return np.array([])
|
||||||
|
@ -149,7 +150,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
||||||
"noise_scale": 0.1,
|
"noise_scale": 0.1,
|
||||||
"noise_scale_w":0.668,
|
"noise_scale_w":0.668,
|
||||||
"length_scale": 1.2,
|
"length_scale": 1.2,
|
||||||
"speend":1
|
"speed":1
|
||||||
}
|
}
|
||||||
llm_info = {
|
llm_info = {
|
||||||
"model": "abab5.5-chat",
|
"model": "abab5.5-chat",
|
||||||
|
@ -243,7 +244,7 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
|
||||||
logger.error(f"用户输入处理函数发生错误: {str(e)}")
|
logger.error(f"用户输入处理函数发生错误: {str(e)}")
|
||||||
|
|
||||||
#语音识别
|
#语音识别
|
||||||
async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
|
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
|
||||||
logger.debug("语音识别函数启动")
|
logger.debug("语音识别函数启动")
|
||||||
is_signup = False
|
is_signup = False
|
||||||
audio = ""
|
audio = ""
|
||||||
|
@ -259,6 +260,9 @@ async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_
|
||||||
current_message += ''.join(asr_result['text'])
|
current_message += ''.join(asr_result['text'])
|
||||||
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
|
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
|
||||||
current_message += ''.join(asr_result['text'])
|
current_message += ''.join(asr_result['text'])
|
||||||
|
if current_message == "":
|
||||||
|
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
|
||||||
|
return
|
||||||
current_message = asr.punctuation_correction(current_message)
|
current_message = asr.punctuation_correction(current_message)
|
||||||
emotion_dict = asr.emtion_recognition(audio) #情感辨识
|
emotion_dict = asr.emtion_recognition(audio) #情感辨识
|
||||||
if not isinstance(emotion_dict, str):
|
if not isinstance(emotion_dict, str):
|
||||||
|
@ -366,7 +370,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
|
||||||
session_id = await future_session_id #获取session_id
|
session_id = await future_session_id #获取session_id
|
||||||
update_session_activity(session_id,db)
|
update_session_activity(session_id,db)
|
||||||
response_type = await future_response_type #获取返回类型
|
response_type = await future_response_type #获取返回类型
|
||||||
asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
|
asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event))
|
||||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||||
|
|
||||||
|
|
|
@ -163,19 +163,17 @@ async def get_hardware_handler(hardware_id, db):
|
||||||
async def upload_audio_handler(user_id, audio, db):
|
async def upload_audio_handler(user_id, audio, db):
|
||||||
try:
|
try:
|
||||||
audio_data = await audio.read()
|
audio_data = await audio.read()
|
||||||
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
|
||||||
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
|
out = io.BytesIO()
|
||||||
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
|
np.save(out, emb_data)
|
||||||
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) #创建音频
|
out.seek(0)
|
||||||
db.flush()
|
emb_binary = out.read()
|
||||||
existing_user = db.query(User).filter(User.id == user_id).first()
|
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #创建音频
|
||||||
existing_user.selected_audio_id = new_audio.id #绑定音频到用户
|
|
||||||
db.add(new_audio)
|
db.add(new_audio)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_audio)
|
db.refresh(new_audio)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
print(str(e))
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
|
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
|
||||||
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
|
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
|
||||||
|
|
|
@ -63,7 +63,7 @@ class ChatServiceTest:
|
||||||
current_file_path = os.path.abspath(__file__)
|
current_file_path = os.path.abspath(__file__)
|
||||||
current_file_path = os.path.dirname(current_file_path)
|
current_file_path = os.path.dirname(current_file_path)
|
||||||
tests_dir = os.path.dirname(current_file_path)
|
tests_dir = os.path.dirname(current_file_path)
|
||||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
mp3_file_path = os.path.join(tests_dir, 'assets', 'BarbieDollsVoice.mp3')
|
||||||
with open(mp3_file_path, 'rb') as audio_file:
|
with open(mp3_file_path, 'rb') as audio_file:
|
||||||
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||||
response = requests.post(url,files=files)
|
response = requests.post(url,files=files)
|
||||||
|
@ -73,6 +73,22 @@ class ChatServiceTest:
|
||||||
else:
|
else:
|
||||||
raise Exception("音频上传失败")
|
raise Exception("音频上传失败")
|
||||||
|
|
||||||
|
#绑定音频
|
||||||
|
url = f"{self.socket}/users/audio/bind"
|
||||||
|
payload = json.dumps({
|
||||||
|
"user_id":self.user_id,
|
||||||
|
"audio_id":self.audio_id
|
||||||
|
})
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("音频绑定测试成功")
|
||||||
|
else:
|
||||||
|
raise Exception("音频绑定测试失败")
|
||||||
|
|
||||||
|
|
||||||
#创建一个对话
|
#创建一个对话
|
||||||
url = f"{self.socket}/chats"
|
url = f"{self.socket}/chats"
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
|
@ -339,10 +355,10 @@ def chat_test():
|
||||||
chat_service_test.test_create_chat()
|
chat_service_test.test_create_chat()
|
||||||
chat_service_test.test_session_id_query()
|
chat_service_test.test_session_id_query()
|
||||||
chat_service_test.test_session_content_query()
|
chat_service_test.test_session_content_query()
|
||||||
chat_service_test.test_session_update()
|
# chat_service_test.test_session_update()
|
||||||
asyncio.run(chat_service_test.test_chat_temporary())
|
asyncio.run(chat_service_test.test_chat_temporary())
|
||||||
asyncio.run(chat_service_test.test_chat_lasting())
|
# asyncio.run(chat_service_test.test_chat_lasting())
|
||||||
asyncio.run(chat_service_test.test_voice_call())
|
# asyncio.run(chat_service_test.test_voice_call())
|
||||||
chat_service_test.test_chat_delete()
|
chat_service_test.test_chat_delete()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue