feat: openvoice重构,添加语音克隆功能
This commit is contained in:
parent
8369090313
commit
5cf16bf03b
|
@ -4,12 +4,13 @@ from ..dependencies.summarizer import get_summarizer
|
||||||
from ..dependencies.asr import get_asr
|
from ..dependencies.asr import get_asr
|
||||||
from ..dependencies.tts import get_tts
|
from ..dependencies.tts import get_tts
|
||||||
from .controller_enum import *
|
from .controller_enum import *
|
||||||
from ..models import UserCharacter, Session, Character, User
|
from ..models import UserCharacter, Session, Character, User, Audio
|
||||||
from utils.audio_utils import VAD
|
from utils.audio_utils import VAD
|
||||||
from fastapi import WebSocket, HTTPException, status
|
from fastapi import WebSocket, HTTPException, status
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from utils.xf_asr_utils import generate_xf_asr_url
|
from utils.xf_asr_utils import generate_xf_asr_url
|
||||||
from config import get_config
|
from config import get_config
|
||||||
|
import numpy as np
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -100,6 +101,19 @@ def update_session_activity(session_id,db):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.roolback()
|
db.roolback()
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
|
||||||
|
#获取target_se
|
||||||
|
def get_emb(session_id,db):
|
||||||
|
try:
|
||||||
|
session_record = db.query(Session).filter(Session.id == session_id).first()
|
||||||
|
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
|
||||||
|
audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
|
||||||
|
emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
|
||||||
|
emb_npy_3d = np.reshape(emb_npy,(1,256,1))
|
||||||
|
return emb_npy_3d
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("未找到音频:"+str(e))
|
||||||
|
return np.array([])
|
||||||
#--------------------------------------------------------
|
#--------------------------------------------------------
|
||||||
|
|
||||||
# 创建新聊天
|
# 创建新聊天
|
||||||
|
@ -283,6 +297,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
target_se = get_emb(session_id,db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"编辑http请求时发生错误: {str(e)}")
|
logger.error(f"编辑http请求时发生错误: {str(e)}")
|
||||||
try:
|
try:
|
||||||
|
@ -299,7 +314,17 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
||||||
elif response_type == RESPONSE_AUDIO:
|
elif response_type == RESPONSE_AUDIO:
|
||||||
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
|
if target_se.size == 0:
|
||||||
|
audio,sr = tts._base_tts(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"])
|
||||||
|
else:
|
||||||
|
audio,sr = tts.synthesize(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"],
|
||||||
|
target_se=target_se)
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_bytes(audio) #返回音频数据
|
await ws.send_bytes(audio) #返回音频数据
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
||||||
|
@ -455,6 +480,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
target_se = get_emb(session_id,db)
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
|
@ -469,7 +495,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||||
elif response_type == RESPONSE_AUDIO:
|
elif response_type == RESPONSE_AUDIO:
|
||||||
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
|
if target_se.size == 0:
|
||||||
|
audio,sr = tts._base_tts(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"])
|
||||||
|
else:
|
||||||
|
audio,sr = tts.synthesize(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"],
|
||||||
|
target_se=target_se)
|
||||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||||
await ws.send_bytes(audio)
|
await ws.send_bytes(audio)
|
||||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||||
|
@ -629,11 +665,11 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
"temperature": llm_info["temperature"],
|
"temperature": llm_info["temperature"],
|
||||||
"top_p": llm_info["top_p"]
|
"top_p": llm_info["top_p"]
|
||||||
})
|
})
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
target_se = get_emb(session_id,db)
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
|
@ -643,7 +679,17 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
llm_response += chunk_data
|
llm_response += chunk_data
|
||||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
|
if target_se.size == 0:
|
||||||
|
audio,sr = tts._base_tts(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"])
|
||||||
|
else:
|
||||||
|
audio,sr = tts.synthesize(text=sentence,
|
||||||
|
noise_scale=tts_info["noise_scale"],
|
||||||
|
noise_scale_w=tts_info["noise_scale_w"],
|
||||||
|
speed=tts_info["length_scale"],
|
||||||
|
target_se=target_se)
|
||||||
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
||||||
await ws.send_bytes(audio) #返回音频二进制流数据
|
await ws.send_bytes(audio) #返回音频二进制流数据
|
||||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
||||||
|
@ -673,23 +719,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
break
|
break
|
||||||
voice_call_end_event.set()
|
voice_call_end_event.set()
|
||||||
|
|
||||||
|
|
||||||
#语音合成及返回函数
|
|
||||||
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
|
|
||||||
logger.debug("语音合成及返回函数启动")
|
|
||||||
while not (split_finished_event.is_set() and split_result_q.empty()):
|
|
||||||
try:
|
|
||||||
sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
|
|
||||||
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
|
|
||||||
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
|
||||||
await ws.send_bytes(audio) #返回音频二进制流数据
|
|
||||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
|
||||||
logger.debug(f"websocket返回:{sentence}")
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
voice_call_end_event.set()
|
|
||||||
|
|
||||||
|
|
||||||
async def voice_call_handler(ws, db, redis):
|
async def voice_call_handler(ws, db, redis):
|
||||||
logger.debug("voice_call websocket 连接建立")
|
logger.debug("voice_call websocket 连接建立")
|
||||||
audio_q = asyncio.Queue() #音频队列
|
audio_q = asyncio.Queue() #音频队列
|
||||||
|
|
|
@ -1,14 +1,20 @@
|
||||||
from ..schemas.user_schema import *
|
from ..schemas.user_schema import *
|
||||||
from ..dependencies.logger import get_logger
|
from ..dependencies.logger import get_logger
|
||||||
|
from ..dependencies.tts import get_tts
|
||||||
from ..models import User, Hardware, Audio
|
from ..models import User, Hardware, Audio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import numpy as np
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
#依赖注入获取logger
|
#依赖注入获取logger
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
#依赖注入获取tts
|
||||||
|
tts = get_tts()
|
||||||
|
|
||||||
#创建用户
|
#创建用户
|
||||||
async def create_user_handler(user:UserCrateRequest, db: Session):
|
async def create_user_handler(user:UserCrateRequest, db: Session):
|
||||||
|
@ -156,13 +162,17 @@ async def get_hardware_handler(hardware_id, db):
|
||||||
#用户上传音频
|
#用户上传音频
|
||||||
async def upload_audio_handler(user_id, audio, db):
|
async def upload_audio_handler(user_id, audio, db):
|
||||||
try:
|
try:
|
||||||
audio_data = audio.file.read()
|
audio_data = await audio.read()
|
||||||
new_audio = Audio(user_id=user_id, audio_data=audio_data)
|
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
||||||
|
numpy_data = np.frombuffer(raw_data, dtype=np.int32)
|
||||||
|
emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
|
||||||
|
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data)
|
||||||
db.add(new_audio)
|
db.add(new_audio)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(new_audio)
|
db.refresh(new_audio)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
print(str(e))
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
|
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
|
||||||
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
|
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
|
||||||
|
@ -174,8 +184,11 @@ async def update_audio_handler(audio_id, audio_file, db):
|
||||||
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
||||||
if existing_audio is None:
|
if existing_audio is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
||||||
audio_data = audio_file.file.read()
|
audio_data = await audio_file.read()
|
||||||
|
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
||||||
|
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
|
||||||
existing_audio.audio_data = audio_data
|
existing_audio.audio_data = audio_data
|
||||||
|
existing_audio.emb_data = emb_data
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|
Binary file not shown.
|
@ -30,6 +30,7 @@ class ChatServiceTest:
|
||||||
}
|
}
|
||||||
response = requests.request("POST", url, headers=headers, data=payload)
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
print("用户创建成功")
|
||||||
self.user_id = response.json()['data']['user_id']
|
self.user_id = response.json()['data']['user_id']
|
||||||
else:
|
else:
|
||||||
raise Exception("创建聊天时,用户创建失败")
|
raise Exception("创建聊天时,用户创建失败")
|
||||||
|
@ -57,6 +58,21 @@ class ChatServiceTest:
|
||||||
else:
|
else:
|
||||||
raise Exception("创建聊天时,角色创建失败")
|
raise Exception("创建聊天时,角色创建失败")
|
||||||
|
|
||||||
|
#上传音频用于音频克隆
|
||||||
|
url = f"{self.socket}/users/audio?user_id={self.user_id}"
|
||||||
|
current_file_path = os.path.abspath(__file__)
|
||||||
|
current_file_path = os.path.dirname(current_file_path)
|
||||||
|
tests_dir = os.path.dirname(current_file_path)
|
||||||
|
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||||
|
with open(mp3_file_path, 'rb') as audio_file:
|
||||||
|
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||||
|
response = requests.post(url,files=files)
|
||||||
|
if response.status_code == 200:
|
||||||
|
self.audio_id = response.json()['data']['audio_id']
|
||||||
|
print("音频上传成功")
|
||||||
|
else:
|
||||||
|
raise Exception("音频上传失败")
|
||||||
|
|
||||||
#创建一个对话
|
#创建一个对话
|
||||||
url = f"{self.socket}/chats"
|
url = f"{self.socket}/chats"
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
|
@ -66,6 +82,7 @@ class ChatServiceTest:
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.request("POST", url, headers=headers, data=payload)
|
response = requests.request("POST", url, headers=headers, data=payload)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("对话创建成功")
|
print("对话创建成功")
|
||||||
|
@ -302,6 +319,11 @@ class ChatServiceTest:
|
||||||
else:
|
else:
|
||||||
raise Exception("聊天删除测试失败")
|
raise Exception("聊天删除测试失败")
|
||||||
|
|
||||||
|
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||||
|
response = requests.request("DELETE", url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception("音频删除测试失败")
|
||||||
|
|
||||||
url = f"{self.socket}/users/{self.user_id}"
|
url = f"{self.socket}/users/{self.user_id}"
|
||||||
response = requests.request("DELETE", url)
|
response = requests.request("DELETE", url)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
@ -312,7 +334,6 @@ class ChatServiceTest:
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise Exception("角色删除测试失败")
|
raise Exception("角色删除测试失败")
|
||||||
|
|
||||||
|
|
||||||
def chat_test():
|
def chat_test():
|
||||||
chat_service_test = ChatServiceTest()
|
chat_service_test = ChatServiceTest()
|
||||||
chat_service_test.test_create_chat()
|
chat_service_test.test_create_chat()
|
||||||
|
|
|
@ -5,7 +5,7 @@ import os
|
||||||
|
|
||||||
|
|
||||||
class UserServiceTest:
|
class UserServiceTest:
|
||||||
def __init__(self,socket="http://127.0.0.1:8001"):
|
def __init__(self,socket="http://127.0.0.1:7878"):
|
||||||
self.socket = socket
|
self.socket = socket
|
||||||
|
|
||||||
def test_user_create(self):
|
def test_user_create(self):
|
||||||
|
@ -132,7 +132,7 @@ class UserServiceTest:
|
||||||
current_file_path = os.path.abspath(__file__)
|
current_file_path = os.path.abspath(__file__)
|
||||||
current_dir = os.path.dirname(current_file_path)
|
current_dir = os.path.dirname(current_file_path)
|
||||||
tests_dir = os.path.dirname(current_dir)
|
tests_dir = os.path.dirname(current_dir)
|
||||||
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
|
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||||
with open(wav_file_path, 'rb') as audio_file:
|
with open(wav_file_path, 'rb') as audio_file:
|
||||||
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
||||||
response = requests.post(url, files=files)
|
response = requests.post(url, files=files)
|
||||||
|
@ -147,9 +147,9 @@ class UserServiceTest:
|
||||||
current_file_path = os.path.abspath(__file__)
|
current_file_path = os.path.abspath(__file__)
|
||||||
current_dir = os.path.dirname(current_file_path)
|
current_dir = os.path.dirname(current_file_path)
|
||||||
tests_dir = os.path.dirname(current_dir)
|
tests_dir = os.path.dirname(current_dir)
|
||||||
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
|
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||||
with open(wav_file_path, 'rb') as audio_file:
|
with open(wav_file_path, 'rb') as audio_file:
|
||||||
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
|
||||||
response = requests.put(url, files=files)
|
response = requests.put(url, files=files)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("音频上传测试成功")
|
print("音频上传测试成功")
|
||||||
|
|
|
@ -260,6 +260,11 @@ class TextToSpeech:
|
||||||
audio: tensor
|
audio: tensor
|
||||||
sr: 生成音频的采样速率
|
sr: 生成音频的采样速率
|
||||||
"""
|
"""
|
||||||
|
if source_se is not None:
|
||||||
|
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
|
||||||
|
if target_se is not None:
|
||||||
|
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
|
||||||
|
|
||||||
if source_se is None:
|
if source_se is None:
|
||||||
source_se = self.source_se
|
source_se = self.source_se
|
||||||
if target_se is None:
|
if target_se is None:
|
||||||
|
@ -319,7 +324,6 @@ class TextToSpeech:
|
||||||
tts_sr = self.base_tts_model.hps.data.sampling_rate
|
tts_sr = self.base_tts_model.hps.data.sampling_rate
|
||||||
converter_sr = self.tone_color_converter.hps.data.sampling_rate
|
converter_sr = self.tone_color_converter.hps.data.sampling_rate
|
||||||
audio = F.resample(audio, tts_sr, converter_sr)
|
audio = F.resample(audio, tts_sr, converter_sr)
|
||||||
print(audio.dtype)
|
|
||||||
audio, sr = self._convert_tone(audio,
|
audio, sr = self._convert_tone(audio,
|
||||||
source_se=source_se,
|
source_se=source_se,
|
||||||
target_se=target_se,
|
target_se=target_se,
|
||||||
|
|
Loading…
Reference in New Issue