1
0
Fork 0

feat: 通过依赖注入来获取asr和tts对象

This commit is contained in:
killua4396 2024-05-22 15:26:01 +08:00
parent 646c188f78
commit fdee2e7520
4 changed files with 35 additions and 13 deletions

View File

@ -1,6 +1,8 @@
from ..schemas.chat_schema import * from ..schemas.chat_schema import *
from ..dependencies.logger import get_logger from ..dependencies.logger import get_logger
from ..dependencies.summarizer import get_summarizer from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts
from .controller_enum import * from .controller_enum import *
from ..models import UserCharacter, Session, Character, User from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD from utils.audio_utils import VAD
@ -19,21 +21,14 @@ logger = get_logger()
# 依赖注入获取context总结服务 # 依赖注入获取context总结服务
summarizer = get_summarizer() summarizer = get_summarizer()
# --------------------初始化本地ASR----------------------- # -----------------------获取ASR-------------------------
from utils.stt.modified_funasr import ModifiedRecognizer asr = get_asr()
asr = ModifiedRecognizer()
logger.info("本地ASR初始化成功")
# ------------------------------------------------------- # -------------------------------------------------------
# --------------------初始化本地VITS---------------------- # -------------------------TTS--------------------------
from utils.tts.vits_utils import TextToSpeech tts = get_tts()
tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# ------------------------------------------------------- # -------------------------------------------------------
# 依赖注入获取Config # 依赖注入获取Config
Config = get_config() Config = get_config()
@ -488,7 +483,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True is_first = True
llm_response = "" llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数则进行文本摘要 if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于70%的最大token数则进行文本摘要
system_prompt = messages[0]['content'] system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages) summary = await summarizer.summarize(messages)
events = user_info['events'] events = user_info['events']

11
app/dependencies/asr.py Normal file
View File

@ -0,0 +1,11 @@
from utils.stt.modified_funasr import ModifiedRecognizer
from app.dependencies.logger import get_logger
logger = get_logger()
#初始化全局asr对象
asr = ModifiedRecognizer()
logger.info("ASR初始化成功")
def get_asr():
return asr

11
app/dependencies/tts.py Normal file
View File

@ -0,0 +1,11 @@
from utils.tts.openvoice_utils import TextToSpeech
from app.dependencies.logger import get_logger
logger = get_logger()
#初始化全局tts对象
tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS初始化成功")
def get_tts():
return tts

View File

@ -21,4 +21,9 @@ apscheduler
aiohttp aiohttp
faster_whisper faster_whisper
whisper_timestamped whisper_timestamped
modelscope <<<<<<< Updated upstream
modelscope
=======
modelscope
wavmark
>>>>>>> Stashed changes