forked from killua/TakwayPlatform
feat: openvoice重构,添加语音克隆功能
This commit is contained in:
parent
8369090313
commit
5cf16bf03b
|
@ -4,12 +4,13 @@ from ..dependencies.summarizer import get_summarizer
|
|||
from ..dependencies.asr import get_asr
|
||||
from ..dependencies.tts import get_tts
|
||||
from .controller_enum import *
|
||||
from ..models import UserCharacter, Session, Character, User
|
||||
from ..models import UserCharacter, Session, Character, User, Audio
|
||||
from utils.audio_utils import VAD
|
||||
from fastapi import WebSocket, HTTPException, status
|
||||
from datetime import datetime
|
||||
from utils.xf_asr_utils import generate_xf_asr_url
|
||||
from config import get_config
|
||||
import numpy as np
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
|
@ -100,6 +101,19 @@ def update_session_activity(session_id,db):
|
|||
except Exception as e:
|
||||
db.roolback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
#获取target_se
|
||||
def get_emb(session_id,db):
    """Load the stored voice embedding (target_se) for the user behind a session.

    Walks Session -> UserCharacter -> Audio (first Audio row whose
    ``user_id`` matches the user character's ``user_id``) and decodes the
    row's ``emb_data`` bytes into an array of shape ``(1, 256, 1)``.

    Args:
        session_id: Value compared against ``Session.id``.
        db: Database session object exposing ``query``.

    Returns:
        The embedding reshaped to ``(1, 256, 1)``, or an empty array
        (``size == 0``) when any record is missing or the stored bytes
        cannot be decoded. Failures are logged, never raised.
    """
    try:
        session_record = db.query(Session).filter(Session.id == session_id).first()
        if session_record is None:
            # Name the missing row explicitly instead of relying on an
            # AttributeError from None further down.
            raise LookupError(f"session {session_id} not found")

        user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
        if user_character_record is None:
            raise LookupError(f"user_character {session_record.user_character_id} not found")

        audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first()
        if audio_record is None or audio_record.emb_data is None:
            raise LookupError(f"no audio embedding stored for user {user_character_record.user_id}")

        # NOTE(review): the bytes are decoded as int32 here; confirm this
        # matches the dtype of the array that was serialized with
        # ``.tobytes()`` at upload time (a float32 embedding would be
        # misread bit-for-bit).
        emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.int32)
        # reshape raises ValueError if the buffer does not hold exactly
        # 256 values; that is handled by the fallback below.
        return np.reshape(emb_npy,(1,256,1))
    except Exception as e:
        logger.error("未找到音频:"+str(e))
        return np.array([])
|
||||
#--------------------------------------------------------
|
||||
|
||||
# 创建新聊天
|
||||
|
@ -283,6 +297,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
target_se = get_emb(session_id,db)
|
||||
except Exception as e:
|
||||
logger.error(f"编辑http请求时发生错误: {str(e)}")
|
||||
try:
|
||||
|
@ -299,7 +314,17 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
||||
elif response_type == RESPONSE_AUDIO:
|
||||
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
|
||||
if target_se.size == 0:
|
||||
audio,sr = tts._base_tts(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"])
|
||||
else:
|
||||
audio,sr = tts.synthesize(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"],
|
||||
target_se=target_se)
|
||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_bytes(audio) #返回音频数据
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
|
||||
|
@ -455,6 +480,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
target_se = get_emb(session_id,db)
|
||||
async with aiohttp.ClientSession() as client:
|
||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||
async for chunk in response.content.iter_any():
|
||||
|
@ -469,7 +495,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||
elif response_type == RESPONSE_AUDIO:
|
||||
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
|
||||
if target_se.size == 0:
|
||||
audio,sr = tts._base_tts(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"])
|
||||
else:
|
||||
audio,sr = tts.synthesize(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"],
|
||||
target_se=target_se)
|
||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_bytes(audio)
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||
|
@ -629,11 +665,11 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
"temperature": llm_info["temperature"],
|
||||
"top_p": llm_info["top_p"]
|
||||
})
|
||||
|
||||
headers = {
|
||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
target_se = get_emb(session_id,db)
|
||||
async with aiohttp.ClientSession() as client:
|
||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||
async for chunk in response.content.iter_any():
|
||||
|
@ -643,7 +679,17 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
llm_response += chunk_data
|
||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
||||
for sentence in sentences:
|
||||
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
|
||||
if target_se.size == 0:
|
||||
audio,sr = tts._base_tts(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"])
|
||||
else:
|
||||
audio,sr = tts.synthesize(text=sentence,
|
||||
noise_scale=tts_info["noise_scale"],
|
||||
noise_scale_w=tts_info["noise_scale_w"],
|
||||
speed=tts_info["length_scale"],
|
||||
target_se=target_se)
|
||||
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
|
||||
await ws.send_bytes(audio) #返回音频二进制流数据
|
||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
||||
|
@ -673,23 +719,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
break
|
||||
voice_call_end_event.set()
|
||||
|
||||
|
||||
#语音合成及返回函数
|
||||
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Synthesize queued sentences and stream audio plus text over the websocket.

    Repeatedly pulls a sentence from ``split_result_q`` (waiting at most
    3 seconds per attempt), synthesizes it with the module-level ``tts``,
    then sends the audio bytes followed by a JSON text frame. The loop
    ends once ``split_finished_event`` is set and the queue is empty, after
    which ``voice_call_end_event`` is set.

    Args:
        ws: Websocket used for ``send_bytes`` / ``send_text``.
        tts_info: Mapping providing "language", "speaker_id",
            "noise_scale", "noise_scale_w" and "length_scale".
        split_result_q: Awaitable queue of sentences to speak.
        split_finished_event: Event signalling that no more sentences
            will be queued.
        voice_call_end_event: Event set when this handler finishes.
    """
    logger.debug("语音合成及返回函数启动")
    while True:
        # Stop only when the producer is done AND nothing is left to speak.
        if split_finished_event.is_set() and split_result_q.empty():
            break
        try:
            sentence = await asyncio.wait_for(split_result_q.get(), timeout=3)
            # NOTE(review): language is passed before speaker_id here, while
            # other synthesize calls in this file pass speaker_id first —
            # confirm which order the TTS API expects.
            _sample_rate, audio_bytes = tts.synthesize(
                sentence,
                tts_info["language"],
                tts_info["speaker_id"],
                tts_info["noise_scale"],
                tts_info["noise_scale_w"],
                tts_info["length_scale"],
                return_bytes=True,
            )
            llm_text_frame = {"type": "llm_text", "code": 200, "msg": sentence}
            # Binary audio frame first, then the matching text frame.
            await ws.send_bytes(audio_bytes)
            await ws.send_text(json.dumps(llm_text_frame, ensure_ascii=False))
            logger.debug(f"websocket返回:{sentence}")
        except asyncio.TimeoutError:
            # Nothing arrived in time; go back and re-check the exit condition.
            continue
    voice_call_end_event.set()
|
||||
|
||||
|
||||
async def voice_call_handler(ws, db, redis):
|
||||
logger.debug("voice_call websocket 连接建立")
|
||||
audio_q = asyncio.Queue() #音频队列
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
from ..schemas.user_schema import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from ..dependencies.tts import get_tts
|
||||
from ..models import User, Hardware, Audio
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
from pydub import AudioSegment
|
||||
import numpy as np
|
||||
import io
|
||||
|
||||
|
||||
#依赖注入获取logger
|
||||
logger = get_logger()
|
||||
|
||||
#依赖注入获取tts
|
||||
tts = get_tts()
|
||||
|
||||
#创建用户
|
||||
async def create_user_handler(user:UserCrateRequest, db: Session):
|
||||
|
@ -156,13 +162,17 @@ async def get_hardware_handler(hardware_id, db):
|
|||
#用户上传音频
|
||||
async def upload_audio_handler(user_id, audio, db):
    """Store an uploaded audio clip together with its voice embedding.

    Reads the upload, decodes it as MP3, computes an embedding with the
    module-level ``tts`` and persists a new ``Audio`` row holding both the
    original bytes and the serialized embedding.

    Args:
        user_id: Owner of the audio, stored on the new ``Audio`` row.
        audio: Uploaded file object supporting ``await audio.read()``.
        db: Database session (``add`` / ``commit`` / ``refresh`` / ``rollback``).

    Returns:
        ``AudioUploadResponse`` carrying the new row's id and the upload time.

    Raises:
        HTTPException: 400 with the underlying error text if reading,
            decoding, embedding or persisting fails; the transaction is
            rolled back first.
    """
    try:
        # Read the upload exactly once; a second read of the same stream
        # would presumably return nothing and break the MP3 decode.
        audio_data = await audio.read()
        raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
        # NOTE(review): the PCM bytes are viewed as int32 and the rate is
        # fixed at 44100 — confirm both against the decoded segment's
        # actual sample width / frame rate and what audio2emb expects.
        numpy_data = np.frombuffer(raw_data, dtype=np.int32)
        emb_data = tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes()
        new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data)
        db.add(new_audio)
        db.commit()
        db.refresh(new_audio)
    except Exception as e:
        db.rollback()
        # Report through the module logger instead of stdout.
        logger.error(str(e))
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
    return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
|
||||
|
@ -174,8 +184,11 @@ async def update_audio_handler(audio_id, audio_file, db):
|
|||
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
|
||||
if existing_audio is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
|
||||
audio_data = audio_file.file.read()
|
||||
audio_data = await audio_file.read()
|
||||
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
|
||||
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
|
||||
existing_audio.audio_data = audio_data
|
||||
existing_audio.emb_data = emb_data
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
|
|
Binary file not shown.
|
@ -30,6 +30,7 @@ class ChatServiceTest:
|
|||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("用户创建成功")
|
||||
self.user_id = response.json()['data']['user_id']
|
||||
else:
|
||||
raise Exception("创建聊天时,用户创建失败")
|
||||
|
@ -57,6 +58,21 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("创建聊天时,角色创建失败")
|
||||
|
||||
#上传音频用于音频克隆
|
||||
url = f"{self.socket}/users/audio?user_id={self.user_id}"
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_file_path = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_file_path)
|
||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(mp3_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||
response = requests.post(url,files=files)
|
||||
if response.status_code == 200:
|
||||
self.audio_id = response.json()['data']['audio_id']
|
||||
print("音频上传成功")
|
||||
else:
|
||||
raise Exception("音频上传失败")
|
||||
|
||||
#创建一个对话
|
||||
url = f"{self.socket}/chats"
|
||||
payload = json.dumps({
|
||||
|
@ -66,6 +82,7 @@ class ChatServiceTest:
|
|||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("对话创建成功")
|
||||
|
@ -302,6 +319,11 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("聊天删除测试失败")
|
||||
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
raise Exception("音频删除测试失败")
|
||||
|
||||
url = f"{self.socket}/users/{self.user_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
|
@ -312,7 +334,6 @@ class ChatServiceTest:
|
|||
if response.status_code != 200:
|
||||
raise Exception("角色删除测试失败")
|
||||
|
||||
|
||||
def chat_test():
|
||||
chat_service_test = ChatServiceTest()
|
||||
chat_service_test.test_create_chat()
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
|
||||
|
||||
class UserServiceTest:
|
||||
def __init__(self,socket="http://127.0.0.1:8001"):
|
||||
def __init__(self,socket="http://127.0.0.1:7878"):
|
||||
self.socket = socket
|
||||
|
||||
def test_user_create(self):
|
||||
|
@ -132,7 +132,7 @@ class UserServiceTest:
|
|||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(wav_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
||||
response = requests.post(url, files=files)
|
||||
|
@ -147,9 +147,9 @@ class UserServiceTest:
|
|||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3')
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(wav_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
||||
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
|
||||
response = requests.put(url, files=files)
|
||||
if response.status_code == 200:
|
||||
print("音频上传测试成功")
|
||||
|
|
|
@ -260,6 +260,11 @@ class TextToSpeech:
|
|||
audio: tensor
|
||||
sr: 生成音频的采样速率
|
||||
"""
|
||||
if source_se is not None:
|
||||
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
|
||||
if target_se is not None:
|
||||
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
|
||||
|
||||
if source_se is None:
|
||||
source_se = self.source_se
|
||||
if target_se is None:
|
||||
|
@ -319,7 +324,6 @@ class TextToSpeech:
|
|||
tts_sr = self.base_tts_model.hps.data.sampling_rate
|
||||
converter_sr = self.tone_color_converter.hps.data.sampling_rate
|
||||
audio = F.resample(audio, tts_sr, converter_sr)
|
||||
print(audio.dtype)
|
||||
audio, sr = self._convert_tone(audio,
|
||||
source_se=source_se,
|
||||
target_se=target_se,
|
||||
|
|
Loading…
Reference in New Issue