1
0
Fork 0

feat: openvoice重构,添加语音克隆功能

This commit is contained in:
killua4396 2024-05-22 17:26:29 +08:00
parent 8369090313
commit 5cf16bf03b
6 changed files with 99 additions and 32 deletions

View File

@ -4,12 +4,13 @@ from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts from ..dependencies.tts import get_tts
from .controller_enum import * from .controller_enum import *
from ..models import UserCharacter, Session, Character, User from ..models import UserCharacter, Session, Character, User, Audio
from utils.audio_utils import VAD from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status from fastapi import WebSocket, HTTPException, status
from datetime import datetime from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config from config import get_config
import numpy as np
import uuid import uuid
import json import json
import asyncio import asyncio
@ -100,6 +101,19 @@ def update_session_activity(session_id,db):
except Exception as e: except Exception as e:
db.roolback() db.roolback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
#获取target_se
def get_emb(session_id,db):
try:
session_record = db.query(Session).filter(Session.id == session_id).first()
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
emb_npy_3d = np.reshape(emb_npy,(1,256,1))
return emb_npy_3d
except Exception as e:
logger.error("未找到音频:"+str(e))
return np.array([])
#-------------------------------------------------------- #--------------------------------------------------------
# 创建新聊天 # 创建新聊天
@ -283,6 +297,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db)
except Exception as e: except Exception as e:
logger.error(f"编辑http请求时发生错误: {str(e)}") logger.error(f"编辑http请求时发生错误: {str(e)}")
try: try:
@ -299,7 +314,17 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO: elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) #返回音频数据 await ws.send_bytes(audio) #返回音频数据
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
@ -455,6 +480,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client: async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
@ -469,7 +495,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO: elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence} response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) await ws.send_text(json.dumps(response_message, ensure_ascii=False))
@ -629,11 +665,11 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
"temperature": llm_info["temperature"], "temperature": llm_info["temperature"],
"top_p": llm_info["top_p"] "top_p": llm_info["top_p"]
}) })
headers = { headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client: async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
@ -643,7 +679,17 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
llm_response += chunk_data llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences: for sentence in sentences:
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) if target_se.size == 0:
audio,sr = tts._base_tts(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"])
else:
audio,sr = tts.synthesize(text=sentence,
noise_scale=tts_info["noise_scale"],
noise_scale_w=tts_info["noise_scale_w"],
speed=tts_info["length_scale"],
target_se=target_se)
text_response = {"type": "llm_text", "code": 200, "msg": sentence} text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据 await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
@ -673,23 +719,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
break break
voice_call_end_event.set() voice_call_end_event.set()
#语音合成及返回函数
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
logger.debug("语音合成及返回函数启动")
while not (split_finished_event.is_set() and split_result_q.empty()):
try:
sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
logger.debug(f"websocket返回:{sentence}")
except asyncio.TimeoutError:
continue
voice_call_end_event.set()
async def voice_call_handler(ws, db, redis): async def voice_call_handler(ws, db, redis):
logger.debug("voice_call websocket 连接建立") logger.debug("voice_call websocket 连接建立")
audio_q = asyncio.Queue() #音频队列 audio_q = asyncio.Queue() #音频队列

View File

@ -1,14 +1,20 @@
from ..schemas.user_schema import * from ..schemas.user_schema import *
from ..dependencies.logger import get_logger from ..dependencies.logger import get_logger
from ..dependencies.tts import get_tts
from ..models import User, Hardware, Audio from ..models import User, Hardware, Audio
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException, status from fastapi import HTTPException, status
from pydub import AudioSegment
import numpy as np
import io
#依赖注入获取logger #依赖注入获取logger
logger = get_logger() logger = get_logger()
#依赖注入获取tts
tts = get_tts()
#创建用户 #创建用户
async def create_user_handler(user:UserCrateRequest, db: Session): async def create_user_handler(user:UserCrateRequest, db: Session):
@ -156,13 +162,17 @@ async def get_hardware_handler(hardware_id, db):
#用户上传音频 #用户上传音频
async def upload_audio_handler(user_id, audio, db): async def upload_audio_handler(user_id, audio, db):
try: try:
audio_data = audio.file.read() audio_data = await audio.read()
new_audio = Audio(user_id=user_id, audio_data=audio_data) raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data)
db.add(new_audio) db.add(new_audio)
db.commit() db.commit()
db.refresh(new_audio) db.refresh(new_audio)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
print(str(e))
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat()) audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data) return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
@ -174,8 +184,11 @@ async def update_audio_handler(audio_id, audio_file, db):
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first() existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
if existing_audio is None: if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = audio_file.file.read() audio_data = await audio_file.read()
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
existing_audio.audio_data = audio_data existing_audio.audio_data = audio_data
existing_audio.emb_data = emb_data
db.commit() db.commit()
except Exception as e: except Exception as e:
db.rollback() db.rollback()

Binary file not shown.

View File

@ -30,6 +30,7 @@ class ChatServiceTest:
} }
response = requests.request("POST", url, headers=headers, data=payload) response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200: if response.status_code == 200:
print("用户创建成功")
self.user_id = response.json()['data']['user_id'] self.user_id = response.json()['data']['user_id']
else: else:
raise Exception("创建聊天时,用户创建失败") raise Exception("创建聊天时,用户创建失败")
@ -57,6 +58,21 @@ class ChatServiceTest:
else: else:
raise Exception("创建聊天时,角色创建失败") raise Exception("创建聊天时,角色创建失败")
#上传音频用于音频克隆
url = f"{self.socket}/users/audio?user_id={self.user_id}"
current_file_path = os.path.abspath(__file__)
current_file_path = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_file_path)
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(mp3_file_path, 'rb') as audio_file:
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
response = requests.post(url,files=files)
if response.status_code == 200:
self.audio_id = response.json()['data']['audio_id']
print("音频上传成功")
else:
raise Exception("音频上传失败")
#创建一个对话 #创建一个对话
url = f"{self.socket}/chats" url = f"{self.socket}/chats"
payload = json.dumps({ payload = json.dumps({
@ -66,6 +82,7 @@ class ChatServiceTest:
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
response = requests.request("POST", url, headers=headers, data=payload) response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200: if response.status_code == 200:
print("对话创建成功") print("对话创建成功")
@ -302,6 +319,11 @@ class ChatServiceTest:
else: else:
raise Exception("聊天删除测试失败") raise Exception("聊天删除测试失败")
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("音频删除测试失败")
url = f"{self.socket}/users/{self.user_id}" url = f"{self.socket}/users/{self.user_id}"
response = requests.request("DELETE", url) response = requests.request("DELETE", url)
if response.status_code != 200: if response.status_code != 200:
@ -312,7 +334,6 @@ class ChatServiceTest:
if response.status_code != 200: if response.status_code != 200:
raise Exception("角色删除测试失败") raise Exception("角色删除测试失败")
def chat_test(): def chat_test():
chat_service_test = ChatServiceTest() chat_service_test = ChatServiceTest()
chat_service_test.test_create_chat() chat_service_test.test_create_chat()

View File

@ -5,7 +5,7 @@ import os
class UserServiceTest: class UserServiceTest:
def __init__(self,socket="http://127.0.0.1:8001"): def __init__(self,socket="http://127.0.0.1:7878"):
self.socket = socket self.socket = socket
def test_user_create(self): def test_user_create(self):
@ -132,7 +132,7 @@ class UserServiceTest:
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path) current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir) tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3') wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file: with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')} files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
response = requests.post(url, files=files) response = requests.post(url, files=files)
@ -141,15 +141,15 @@ class UserServiceTest:
print("音频上传测试成功") print("音频上传测试成功")
else: else:
raise Exception("音频上传测试失败") raise Exception("音频上传测试失败")
def test_update_audio(self): def test_update_audio(self):
url = f"{self.socket}/users/audio/{self.audio_id}" url = f"{self.socket}/users/audio/{self.audio_id}"
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path) current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir) tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3') wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file: with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')} files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
response = requests.put(url, files=files) response = requests.put(url, files=files)
if response.status_code == 200: if response.status_code == 200:
print("音频上传测试成功") print("音频上传测试成功")

View File

@ -260,6 +260,11 @@ class TextToSpeech:
audio: tensor audio: tensor
sr: 生成音频的采样速率 sr: 生成音频的采样速率
""" """
if source_se is not None:
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
if target_se is not None:
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
if source_se is None: if source_se is None:
source_se = self.source_se source_se = self.source_se
if target_se is None: if target_se is None:
@ -319,7 +324,6 @@ class TextToSpeech:
tts_sr = self.base_tts_model.hps.data.sampling_rate tts_sr = self.base_tts_model.hps.data.sampling_rate
converter_sr = self.tone_color_converter.hps.data.sampling_rate converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = F.resample(audio, tts_sr, converter_sr) audio = F.resample(audio, tts_sr, converter_sr)
print(audio.dtype)
audio, sr = self._convert_tone(audio, audio, sr = self._convert_tone(audio,
source_se=source_se, source_se=source_se,
target_se=target_se, target_se=target_se,