feat: openvoice重构,添加语音克隆功能

This commit is contained in:
killua4396 2024-05-22 17:26:29 +08:00
parent 8369090313
commit 5cf16bf03b
6 changed files with 99 additions and 32 deletions

View File

@ -4,12 +4,13 @@ from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User
from ..models import UserCharacter, Session, Character, User, Audio
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import numpy as np
import uuid
import json
import asyncio
@ -100,6 +101,19 @@ def update_session_activity(session_id,db):
except Exception as e:
db.roolback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# Fetch the stored voice-clone speaker embedding ("target_se") for a session.
def get_emb(session_id,db):
    """Load the speaker embedding tied to the user behind *session_id*.

    Walks Session -> UserCharacter -> Audio to find the user's uploaded
    audio record, then decodes its ``emb_data`` blob back into a numpy
    array reshaped to (1, 256, 1) for use as ``target_se`` in synthesis.

    Returns an empty array on any failure; callers treat ``size == 0``
    as the "no cloning embedding available" signal.
    """
    try:
        session_record = db.query(Session).filter(Session.id == session_id).first()
        user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
        audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
        # NOTE(review): dtype must match what upload_audio_handler stored via
        # .tobytes(); int32 is assumed here — confirm audio2emb's output dtype.
        emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
        # Flat buffer -> (1, 256, 1), the shape consumed downstream as target_se.
        emb_npy_3d = np.reshape(emb_npy,(1,256,1))
        return emb_npy_3d
    except Exception as e:
        # Any lookup/decoding failure degrades gracefully to "no embedding".
        logger.error("未找到音频:"+str(e))
        return np.array([])
#--------------------------------------------------------
# 创建新聊天
@ -283,6 +297,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
target_se = get_emb(session_id,db)
except Exception as e:
logger.error(f"编辑http请求时发生错误: {str(e)}")
try:
@ -299,7 +314,17 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) #返回音频数据
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
@ -455,6 +480,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any():
@ -469,7 +495,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
@ -629,11 +665,11 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
"temperature": llm_info["temperature"],
"top_p": llm_info["top_p"]
})
headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any():
@ -643,7 +679,17 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
@ -673,23 +719,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
break
voice_call_end_event.set()
# Speech synthesis and reply function for the voice-call pipeline.
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Consume split sentences, synthesize each one, and stream results back.

    Pulls sentences from ``split_result_q``, synthesizes audio for each, and
    sends the audio bytes followed by the sentence text over the websocket.
    Runs until the splitting stage has signalled completion AND the queue is
    drained, then sets ``voice_call_end_event`` so the caller can tear down.
    """
    logger.debug("语音合成及返回函数启动")
    while not (split_finished_event.is_set() and split_result_q.empty()):
        try:
            # 3s timeout lets the loop periodically re-check the exit
            # condition instead of blocking forever on an empty queue.
            sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
            # NOTE(review): positional order here is (language, speaker_id),
            # but sct_llm_handler passes (speaker_id, language) — one of the
            # two call sites is likely wrong; verify synthesize's signature.
            sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            text_response = {"type": "llm_text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio)  # send the raw audio byte stream
            await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the matching text payload
            logger.debug(f"websocket返回:{sentence}")
        except asyncio.TimeoutError:
            continue
    voice_call_end_event.set()
async def voice_call_handler(ws, db, redis):
logger.debug("voice_call websocket 连接建立")
audio_q = asyncio.Queue() #音频队列

View File

@ -1,14 +1,20 @@
from ..schemas.user_schema import *
from ..dependencies.logger import get_logger
from ..dependencies.tts import get_tts
from ..models import User, Hardware, Audio
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from pydub import AudioSegment
import numpy as np
import io
#依赖注入获取logger
logger = get_logger()
#依赖注入获取tts
tts = get_tts()
#创建用户
async def create_user_handler(user:UserCrateRequest, db: Session):
@ -156,13 +162,17 @@ async def get_hardware_handler(hardware_id, db):
# User uploads an audio clip (voice-cloning reference sample).
async def upload_audio_handler(user_id, audio, db):
    """Persist an uploaded MP3 and its derived speaker embedding.

    Reads the uploaded file, decodes the MP3 container to raw PCM,
    extracts a speaker embedding via ``tts.audio2emb``, and stores both
    the original bytes and the embedding blob in a new ``Audio`` row.

    Raises HTTP 400 on any decode/DB failure (transaction rolled back).
    """
    try:
        audio_data = await audio.read()
        # Decode the MP3 container down to raw PCM bytes.
        raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
        # NOTE(review): pydub PCM is usually 16-bit samples; dtype=np.int32
        # assumes 4-byte samples — confirm, and keep in sync with the
        # matching np.frombuffer decode in get_emb.
        numpy_data = np.frombuffer(raw_data, dtype=np.int32)
        emb_data = tts.audio2emb(numpy_data, rate=44100, vad=True).tobytes()
        new_audio = Audio(user_id=user_id, audio_data=audio_data, emb_data=emb_data)
        db.add(new_audio)
        db.commit()
        db.refresh(new_audio)
    except Exception as e:
        db.rollback()
        # Fix: report through the module logger instead of a stray debug print().
        logger.error(f"用户上传音频失败: {str(e)}")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
    return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
@ -174,8 +184,11 @@ async def update_audio_handler(audio_id, audio_file, db):
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = audio_file.file.read()
audio_data = await audio_file.read()
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
existing_audio.audio_data = audio_data
existing_audio.emb_data = emb_data
db.commit()
except Exception as e:
db.rollback()

Binary file not shown.

View File

@ -30,6 +30,7 @@ class ChatServiceTest:
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("用户创建成功")
self.user_id = response.json()['data']['user_id']
else:
raise Exception("创建聊天时,用户创建失败")
@ -57,6 +58,21 @@ class ChatServiceTest:
else:
raise Exception("创建聊天时,角色创建失败")
#上传音频用于音频克隆
url = f"{self.socket}/users/audio?user_id={self.user_id}"
current_file_path = os.path.abspath(__file__)
current_file_path = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_file_path)
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(mp3_file_path, 'rb') as audio_file:
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
response = requests.post(url,files=files)
if response.status_code == 200:
self.audio_id = response.json()['data']['audio_id']
print("音频上传成功")
else:
raise Exception("音频上传失败")
#创建一个对话
url = f"{self.socket}/chats"
payload = json.dumps({
@ -66,6 +82,7 @@ class ChatServiceTest:
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("对话创建成功")
@ -302,6 +319,11 @@ class ChatServiceTest:
else:
raise Exception("聊天删除测试失败")
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("音频删除测试失败")
url = f"{self.socket}/users/{self.user_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
@ -312,7 +334,6 @@ class ChatServiceTest:
if response.status_code != 200:
raise Exception("角色删除测试失败")
def chat_test():
chat_service_test = ChatServiceTest()
chat_service_test.test_create_chat()

View File

@ -5,7 +5,7 @@ import os
class UserServiceTest:
def __init__(self,socket="http://127.0.0.1:8001"):
def __init__(self,socket="http://127.0.0.1:7878"):
self.socket = socket
def test_user_create(self):
@ -132,7 +132,7 @@ class UserServiceTest:
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
response = requests.post(url, files=files)
@ -141,15 +141,15 @@ class UserServiceTest:
print("音频上传测试成功")
else:
raise Exception("音频上传测试失败")
def test_update_audio(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
response = requests.put(url, files=files)
if response.status_code == 200:
print("音频上传测试成功")

View File

@ -260,6 +260,11 @@ class TextToSpeech:
audio: tensor
sr: 生成音频的采样速率
"""
if source_se is not None:
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
if target_se is not None:
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
if source_se is None:
source_se = self.source_se
if target_se is None:
@ -319,7 +324,6 @@ class TextToSpeech:
tts_sr = self.base_tts_model.hps.data.sampling_rate
converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = F.resample(audio, tts_sr, converter_sr)
print(audio.dtype)
audio, sr = self._convert_tone(audio,
source_se=source_se,
target_se=target_se,