forked from killua/TakwayPlatform
feat: 通过依赖注入来获取asr和tts对象
This commit is contained in:
parent
646c188f78
commit
fdee2e7520
|
@ -1,6 +1,8 @@
|
|||
from ..schemas.chat_schema import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from ..dependencies.summarizer import get_summarizer
|
||||
from ..dependencies.asr import get_asr
|
||||
from ..dependencies.tts import get_tts
|
||||
from .controller_enum import *
|
||||
from ..models import UserCharacter, Session, Character, User
|
||||
from utils.audio_utils import VAD
|
||||
|
@ -19,21 +21,14 @@ logger = get_logger()
|
|||
# 依赖注入获取context总结服务
|
||||
summarizer = get_summarizer()
|
||||
|
||||
# --------------------初始化本地ASR-----------------------
|
||||
from utils.stt.modified_funasr import ModifiedRecognizer
|
||||
|
||||
asr = ModifiedRecognizer()
|
||||
logger.info("本地ASR初始化成功")
|
||||
# -----------------------获取ASR-------------------------
|
||||
asr = get_asr()
|
||||
# -------------------------------------------------------
|
||||
|
||||
# --------------------初始化本地VITS----------------------
|
||||
from utils.tts.vits_utils import TextToSpeech
|
||||
|
||||
tts = TextToSpeech(device='cpu')
|
||||
logger.info("本地TTS初始化成功")
|
||||
# -------------------------TTS--------------------------
|
||||
tts = get_tts()
|
||||
# -------------------------------------------------------
|
||||
|
||||
|
||||
# 依赖注入获取Config
|
||||
Config = get_config()
|
||||
|
||||
|
@ -488,7 +483,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
is_first = True
|
||||
llm_response = ""
|
||||
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于70%的最大token数,则进行文本摘要
|
||||
system_prompt = messages[0]['content']
|
||||
summary = await summarizer.summarize(messages)
|
||||
events = user_info['events']
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
from utils.stt.modified_funasr import ModifiedRecognizer
|
||||
from app.dependencies.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
#初始化全局asr对象
|
||||
asr = ModifiedRecognizer()
|
||||
logger.info("ASR初始化成功")
|
||||
|
||||
def get_asr():
|
||||
return asr
|
|
@ -0,0 +1,11 @@
|
|||
from utils.tts.openvoice_utils import TextToSpeech
|
||||
from app.dependencies.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
#初始化全局tts对象
|
||||
tts = TextToSpeech(use_tone_convert=True,device='cuda')
|
||||
logger.info("TTS初始化成功")
|
||||
|
||||
def get_tts():
|
||||
return tts
|
|
@ -21,4 +21,9 @@ apscheduler
|
|||
aiohttp
|
||||
faster_whisper
|
||||
whisper_timestamped
|
||||
<<<<<<< Updated upstream
|
||||
modelscope
|
||||
=======
|
||||
modelscope
|
||||
wavmark
|
||||
>>>>>>> Stashed changes
|
||||
|
|
Loading…
Reference in New Issue