diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py index 17f2a85..baf68e5 100644 --- a/app/controllers/chat_controller.py +++ b/app/controllers/chat_controller.py @@ -4,12 +4,13 @@ from ..dependencies.summarizer import get_summarizer from ..dependencies.asr import get_asr from ..dependencies.tts import get_tts from .controller_enum import * -from ..models import UserCharacter, Session, Character, User +from ..models import UserCharacter, Session, Character, User, Audio from utils.audio_utils import VAD from fastapi import WebSocket, HTTPException, status from datetime import datetime from utils.xf_asr_utils import generate_xf_asr_url from config import get_config +import numpy as np import uuid import json import asyncio @@ -100,6 +101,19 @@ def update_session_activity(session_id,db): except Exception as e: db.roolback() raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + +#获取target_se +def get_emb(session_id,db): + try: + session_record = db.query(Session).filter(Session.id == session_id).first() + user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first() + audio_record = db.query(Audio).filter(Audio.user_id == user_character_record.user_id).first() + emb_npy = np.frombuffer(audio_record.emb_data,dtype=np.float32) + emb_npy_3d = np.reshape(emb_npy,(1,256,1)) + return emb_npy_3d + except Exception as e: + logger.error("未找到音频:"+str(e)) + return np.array([]) #-------------------------------------------------------- # 创建新聊天 @@ -283,6 +297,7 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Content-Type': 'application/json' } + target_se = get_emb(session_id,db) except Exception as e: logger.error(f"编辑http请求时发生错误: {str(e)}") try: @@ -299,7 +314,17 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis response_message = {"type": 
"text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 elif response_type == RESPONSE_AUDIO: - sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + if target_se.size == 0: + audio,sr = tts._base_tts(text=sentence, + noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"]) + else: + audio,sr = tts.synthesize(text=sentence, + noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"], + target_se=target_se) response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_bytes(audio) #返回音频数据 await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 @@ -455,6 +480,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Content-Type': 'application/json' } + target_se = get_emb(session_id,db) async with aiohttp.ClientSession() as client: async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 async for chunk in response.content.iter_any(): @@ -469,7 +495,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False)) elif response_type == RESPONSE_AUDIO: - sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + if target_se.size == 0: + audio,sr = tts._base_tts(text=sentence, + noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"]) + else: + audio,sr = tts.synthesize(text=sentence, + 
noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"], + target_se=target_se) response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_bytes(audio) await ws.send_text(json.dumps(response_message, ensure_ascii=False)) @@ -629,11 +665,11 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re "temperature": llm_info["temperature"], "top_p": llm_info["top_p"] }) - headers = { 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Content-Type': 'application/json' } + target_se = get_emb(session_id,db) async with aiohttp.ClientSession() as client: async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 async for chunk in response.content.iter_any(): @@ -643,7 +679,17 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re llm_response += chunk_data sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) for sentence in sentences: - sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) + if target_se.size == 0: + audio,sr = tts._base_tts(text=sentence, + noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"]) + else: + audio,sr = tts.synthesize(text=sentence, + noise_scale=tts_info["noise_scale"], + noise_scale_w=tts_info["noise_scale_w"], + speed=tts_info["length_scale"], + target_se=target_se) text_response = {"type": "llm_text", "code": 200, "msg": sentence} await ws.send_bytes(audio) #返回音频二进制流数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 @@ -673,23 +719,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re break voice_call_end_event.set() - -#语音合成及返回函数 -async def 
voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event): - logger.debug("语音合成及返回函数启动") - while not (split_finished_event.is_set() and split_result_q.empty()): - try: - sentence = await asyncio.wait_for(split_result_q.get(),timeout=3) - sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True) - text_response = {"type": "llm_text", "code": 200, "msg": sentence} - await ws.send_bytes(audio) #返回音频二进制流数据 - await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 - logger.debug(f"websocket返回:{sentence}") - except asyncio.TimeoutError: - continue - voice_call_end_event.set() - - async def voice_call_handler(ws, db, redis): logger.debug("voice_call websocket 连接建立") audio_q = asyncio.Queue() #音频队列 diff --git a/app/controllers/user_controller.py b/app/controllers/user_controller.py index 4ccd404..bb30fcb 100644 --- a/app/controllers/user_controller.py +++ b/app/controllers/user_controller.py @@ -1,14 +1,20 @@ from ..schemas.user_schema import * from ..dependencies.logger import get_logger +from ..dependencies.tts import get_tts from ..models import User, Hardware, Audio from datetime import datetime from sqlalchemy.orm import Session from fastapi import HTTPException, status +from pydub import AudioSegment +import numpy as np +import io #依赖注入获取logger logger = get_logger() +#依赖注入获取tts tts = get_tts() #创建用户 async def create_user_handler(user:UserCrateRequest, db: Session): @@ -156,13 +162,17 @@ async def get_hardware_handler(hardware_id, db): #用户上传音频 async def upload_audio_handler(user_id, audio, db): try: - audio_data = audio.file.read() - new_audio = Audio(user_id=user_id, audio_data=audio_data) + audio_data = await audio.read() + raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data + numpy_data = np.frombuffer(raw_data, dtype=np.int16) + emb_data = 
tts.audio2emb(numpy_data,rate=44100,vad=True).tobytes() + new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_data) db.add(new_audio) db.commit() db.refresh(new_audio) except Exception as e: db.rollback() + logger.error(str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat()) return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data) @@ -174,8 +184,11 @@ async def update_audio_handler(audio_id, audio_file, db): existing_audio = db.query(Audio).filter(Audio.id == audio_id).first() if existing_audio is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在") - audio_data = audio_file.file.read() + audio_data = await audio_file.read() + raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data + emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int16),rate=44100,vad=True).tobytes() existing_audio.audio_data = audio_data + existing_audio.emb_data = emb_data db.commit() except Exception as e: db.rollback() diff --git a/tests/assets/demo_speaker0.mp3 b/tests/assets/demo_speaker0.mp3 new file mode 100644 index 0000000..bf1e546 Binary files /dev/null and b/tests/assets/demo_speaker0.mp3 differ diff --git a/tests/unit_test/chat_test.py b/tests/unit_test/chat_test.py index ebfbf11..7ca335d 100644 --- a/tests/unit_test/chat_test.py +++ b/tests/unit_test/chat_test.py @@ -30,6 +30,7 @@ class ChatServiceTest: } response = requests.request("POST", url, headers=headers, data=payload) if response.status_code == 200: + print("用户创建成功") self.user_id = response.json()['data']['user_id'] else: raise Exception("创建聊天时,用户创建失败") @@ -57,6 +58,21 @@ class ChatServiceTest: else: raise Exception("创建聊天时,角色创建失败") + #上传音频用于音频克隆 + url = f"{self.socket}/users/audio?user_id={self.user_id}" + current_file_path = os.path.abspath(__file__) + current_file_path = 
os.path.dirname(current_file_path) + tests_dir = os.path.dirname(current_file_path) + mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3') + with open(mp3_file_path, 'rb') as audio_file: + files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')} + response = requests.post(url,files=files) + if response.status_code == 200: + self.audio_id = response.json()['data']['audio_id'] + print("音频上传成功") + else: + raise Exception("音频上传失败") + #创建一个对话 url = f"{self.socket}/chats" payload = json.dumps({ @@ -66,6 +82,7 @@ class ChatServiceTest: headers = { 'Content-Type': 'application/json' } + response = requests.request("POST", url, headers=headers, data=payload) if response.status_code == 200: print("对话创建成功") @@ -302,6 +319,11 @@ class ChatServiceTest: else: raise Exception("聊天删除测试失败") + url = f"{self.socket}/users/audio/{self.audio_id}" + response = requests.request("DELETE", url) + if response.status_code != 200: + raise Exception("音频删除测试失败") + url = f"{self.socket}/users/{self.user_id}" response = requests.request("DELETE", url) if response.status_code != 200: @@ -312,7 +334,6 @@ class ChatServiceTest: if response.status_code != 200: raise Exception("角色删除测试失败") - def chat_test(): chat_service_test = ChatServiceTest() chat_service_test.test_create_chat() diff --git a/tests/unit_test/user_test.py b/tests/unit_test/user_test.py index 4eaaca8..40b5821 100644 --- a/tests/unit_test/user_test.py +++ b/tests/unit_test/user_test.py @@ -5,7 +5,7 @@ import os class UserServiceTest: - def __init__(self,socket="http://127.0.0.1:8001"): + def __init__(self,socket="http://127.0.0.1:7878"): self.socket = socket def test_user_create(self): @@ -132,7 +132,7 @@ class UserServiceTest: current_file_path = os.path.abspath(__file__) current_dir = os.path.dirname(current_file_path) tests_dir = os.path.dirname(current_dir) - wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3') + wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3') with 
open(wav_file_path, 'rb') as audio_file: files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')} response = requests.post(url, files=files) @@ -141,15 +141,15 @@ class UserServiceTest: print("音频上传测试成功") else: raise Exception("音频上传测试失败") - + def test_update_audio(self): url = f"{self.socket}/users/audio/{self.audio_id}" current_file_path = os.path.abspath(__file__) current_dir = os.path.dirname(current_file_path) tests_dir = os.path.dirname(current_dir) - wav_file_path = os.path.join(tests_dir, 'assets', 'iat_mp3_8k.mp3') + wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3') with open(wav_file_path, 'rb') as audio_file: - files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')} + files = {'audio_file':(wav_file_path,audio_file,'audio/wav')} response = requests.put(url, files=files) if response.status_code == 200: print("音频上传测试成功") diff --git a/utils/tts/openvoice_utils.py b/utils/tts/openvoice_utils.py index c4781fe..ab0604a 100644 --- a/utils/tts/openvoice_utils.py +++ b/utils/tts/openvoice_utils.py @@ -260,6 +260,11 @@ class TextToSpeech: audio: tensor sr: 生成音频的采样速率 """ + if source_se is not None: + source_se = torch.tensor(source_se.astype(np.float32)).to(self.device) + if target_se is not None: + target_se = torch.tensor(target_se.astype(np.float32)).to(self.device) + if source_se is None: source_se = self.source_se if target_se is None: @@ -319,7 +324,6 @@ class TextToSpeech: tts_sr = self.base_tts_model.hps.data.sampling_rate converter_sr = self.tone_color_converter.hps.data.sampling_rate audio = F.resample(audio, tts_sr, converter_sr) - print(audio.dtype) audio, sr = self._convert_tone(audio, source_se=source_se, target_se=target_se,