from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ----------- Initialize VITS ----------- #
vits = TextToSpeech()
# ---------------------------------- #


# ------ Concrete ASR, LLM, TTS classes ------ #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

class XF_ASR(ASR):
    def __init__(self):
        self.websocket = None
        self.current_message = ""
        self.audio = ""
        self.status = FIRST_FRAME
        self.segment_duration_threshold = 25  # segment timeout, in seconds
        self.segment_start_time = None

    async def stream_recognize(self, chunk):
        if self.websocket is None:  # open a new websocket connection if none exists
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # first segment: record its start time
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data

        if self.status == FIRST_FRAME:  # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            await self.websocket.close()
            print("Speech recognition finished, user message:", self.current_message)
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timed out: send the last frame and reset state
            await self.websocket.send(make_last_frame())
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))  # decode to JSON first, matching the last-frame path above
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []

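# A minimal driver sketch (hypothetical caller code; assumes each chunk is a
# dict shaped like {"audio": <base64 str>, "meta_info": {"is_end": bool}}, as
# consumed above):
#
#   asr = XF_ASR()
#   results = []
#   async for chunk in incoming_chunks():          # e.g. read from a websocket
#       results = await asr.stream_recognize(chunk)
#   # results stays [] until the last frame, then holds
#   # [{"text": <transcript>, "audio": <base64 audio>}]
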
class MINIMAX_LLM(LLM):
    def __init__(self):
        self.token = 0

    async def chat(self, assistant, prompt, db):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({"role": "user", "content": prompt})
        assistant.messages = json.dumps(messages)
        payload = json.dumps({
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end": True, "msg": ""}
                        if self.token > llm_info['max_tokens'] * 0.8:  # context nearly full: reset the history to system prompt + latest user message
                            msg_frame['msg'] = 'max_token reached'
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame

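    # A minimal consumption sketch (hypothetical caller code; `assistant` and
    # `db` come from the surrounding application):
    #
    #   llm = MINIMAX_LLM()
    #   async for frame in llm.chat(assistant, prompt, db):
    #       if frame['is_end']:
    #           break
    #       print(frame['msg'], end="")
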
    def __parseChunk(self, llm_chunk):
        result = ""
        data = json.loads(llm_chunk.decode('utf-8')[6:])  # strip the 6-byte "data: " SSE prefix
        if data["object"] == "chat.completion":  # final frame
            self.token = data['usage']['total_tokens']
            raise LLMResponseEnd()
        elif data['object'] == 'chat.completion.chunk':
            for choice in data['choices']:
                result += choice['delta']['content']
        else:
            raise UnkownLLMFrame()
        return result

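# __parseChunk expects MiniMax-style SSE frames, i.e. lines of the form
# "data: {...}" (hence the [6:] slice above). Illustrative shapes only, not
# captured API output:
#
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hi"}}]}
#   data: {"object": "chat.completion", "usage": {"total_tokens": 128}}
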
class VITS_TTS(TTS):
    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #


# ------ ASR, LLM, TTS factory classes ------ #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()

class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()

class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
# --------------------------------- #
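# Usage sketch: each factory maps a type string to a concrete implementation,
# e.g. ASRFactory().create_asr('XF') returns an XF_ASR instance. Note that an
# unrecognized type string falls through and returns None.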


# ----------- Concrete service classes ----------- #
# Take the first result from the single-speaker asr_results as the prompt
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()

class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs

class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #


# ------------ Service chain classes ----------- #
class UserAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: UserAudioService):
        self.services.append(service)
    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio

class PromptServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: PromptService):
        self.services.append(service)
    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt

class LLMMsgServiceChain:
    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()
    def add_service(self, service: LLMMsgService):
        self.services.append(service)
    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks

class TTSAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: TTSAudioService):
        self.services.append(service)
    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# --------------------------------- #
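# Each chain applies its services in insertion order, threading (value, kwargs)
# through every stage. A minimal sketch using the recording service defined
# above (`b64_audio` and `recorder` are assumed to come from the caller):
#
#   chain = UserAudioServiceChain()
#   chain.add_service(UserAudioRecordService())
#   audio = chain.user_audio_process(b64_audio, recorder=recorder)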

class Agent():
    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio processing chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # add the user-audio recording service

        self.asr = asrFactory.create_asr(asr_type)

        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)

        self.llm_chunk_service_chain = LLMMsgServiceChain()

        self.tts = ttsFactory.create_tts(tts_type)

        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Preprocess the user's input audio
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Run streaming speech recognition
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)

    # Build the prompt
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Call the large language model
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)

    # Post-process the LLM's output
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)

    # Run TTS synthesis
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)

    # Post-process the synthesized audio
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode a response frame
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))  # two big-endian uint32 lengths: JSON bytes, then audio bytes
        final_resp = header + text_resp_bytes + audio
        return final_resp

    # Persist the recorded audio
    def save(self):
        self.recorder.write()