from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ----------- Initialize VITS ----------- #
vits = TextToSpeech()
# ---------------------------------------- #


# ------ Concrete ASR, LLM, TTS classes ------ #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

class XF_ASR(ASR):
    def __init__(self):
        self.websocket = None
        self.current_message = ""
        self.audio = ""
        self.status = FIRST_FRAME
        self.segment_duration_threshold = 25  # segment timeout, in seconds
        self.segment_start_time = None

    async def stream_recognize(self, chunk):
        if self.websocket is None:  # open a new websocket connection if none exists yet
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # first segment: record its start time
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data

        if self.status == FIRST_FRAME:  # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame and read back the result
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            await self.websocket.close()
            print("Speech recognition finished, user message:", self.current_message)
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timeout: send a last frame and reset state
            await self.websocket.send(make_last_frame())
            # json.loads added for consistency with the last-frame path above
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []
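
# Minimal driver sketch for XF_ASR (assumption: chunks arrive as dicts shaped
# like {"audio": <str>, "meta_info": {"is_end": bool}}, which is the shape
# stream_recognize reads above; `drive` and `chunks` are hypothetical names):
#
#   asr = XF_ASR()
#   async def drive(chunks):
#       for chunk in chunks:
#           results = await asr.stream_recognize(chunk)
#           if results:  # non-empty only after the last frame has been sent
#               return results[0]["text"]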


class MINIMAX_LLM(LLM):
    def __init__(self):
        self.token = 0

    async def chat(self, assistant, prompt, db):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({"role": "user", "content": prompt})
        assistant.messages = json.dumps(messages)
        payload = json.dumps({
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end": True, "msg": ""}
                        if self.token > llm_info['max_tokens'] * 0.8:  # context nearly full: truncate the stored history
                            msg_frame['msg'] = 'max_token reached'
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame

    def __parseChunk(self, llm_chunk):
        result = ""
        data = json.loads(llm_chunk.decode('utf-8')[6:])  # strip the 'data: ' SSE prefix
        if data["object"] == "chat.completion":  # final frame: carries total token usage
            self.token = data['usage']['total_tokens']
            raise LLMResponseEnd()
        elif data['object'] == 'chat.completion.chunk':
            for choice in data['choices']:
                result += choice['delta']['content']
        else:
            raise UnkownLLMFrame()
        return result
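
    # Note: __parseChunk assumes each streamed frame is a Server-Sent Events
    # line of the form b'data: {...}'; the [6:] slice strips the b'data: '
    # prefix before json.loads. Sketch of the two frame shapes it handles
    # (field names as read above; payload values are illustrative, not
    # captured from the MiniMax API):
    #
    #   data: {"object": "chat.completion.chunk",
    #          "choices": [{"delta": {"content": "Hello"}}]}
    #   data: {"object": "chat.completion", "usage": {"total_tokens": 512}}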


class VITS_TTS(TTS):
    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #


# ------ ASR, LLM, TTS factory classes ------ #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()

class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()

class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
# --------------------------------- #
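
# Factory usage sketch (the type strings mirror the branches above; any other
# string currently falls through and returns None):
#
#   asr = ASRFactory().create_asr('XF')
#   llm = LLMFactory().create_llm('MINIMAX')
#   tts = TTSFactory().create_tts('VITS')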


# ----------- Concrete service classes ----------- #
# Take the first result from the single-speaker asr_results as the prompt
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()

class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs

class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #
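
# Note the asymmetry above: user audio arrives base64-encoded (decoded only
# for the recorder, then passed through unchanged), while TTS audio is already
# raw bytes by the time it reaches TTSAudioRecordService.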


# ------------ Service chain classes ----------- #
class UserAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: UserAudioService):
        self.services.append(service)
    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio

class PromptServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: PromptService):
        self.services.append(service)
    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt

class LLMMsgServiceChain:
    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()  # splits the LLM stream into sentence chunks
    def add_service(self, service: LLMMsgService):
        self.services.append(service)
    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks

class TTSAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: TTSAudioService):
        self.services.append(service)
    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# --------------------------------- #
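
# Chain usage sketch: services run in registration order, each stage threading
# (value, kwargs) into the next (`recorder` is a Recorder instance, as wired
# up in Agent below):
#
#   chain = TTSAudioServiceChain()
#   chain.add_service(TTSAudioRecordService())
#   audio = chain.tts_audio_process(tts_bytes, recorder=recorder)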


class Agent:
    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio processing chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # add the user-audio recording service

        self.asr = asrFactory.create_asr(asr_type)

        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)

        self.llm_chunk_service_chain = LLMMsgServiceChain()

        self.tts = ttsFactory.create_tts(tts_type)

        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Preprocess the user's input audio
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Streaming speech recognition
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)

    # Prompt processing
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Call the large language model
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)

    # Process the LLM's streamed reply
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)

    # TTS synthesis
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)

    # Process the synthesized audio
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode a response frame
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp
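
    # Decoding sketch for the frame built by encode (hypothetical receiver
    # code, not part of this class): read two big-endian uint32 lengths,
    # then the JSON text block, then the raw audio bytes.
    #
    #   text_len, audio_len = struct.unpack('!II', frame[:8])
    #   text = json.loads(frame[8:8 + text_len].decode('utf-8'))
    #   audio = frame[8 + text_len:8 + text_len + audio_len]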

    # Save the recorded audio
    def save(self):
        self.recorder.write()
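
# End-to-end usage sketch (hypothetical wiring; `assistant`, `db`, `user_id`,
# `send`, and the incoming websocket `chunk` frames come from the surrounding
# application):
#
#   agent = Agent(asr_type='XF', llm_type='MINIMAX', tts_type='VITS')
#   agent.init_recorder(user_id)
#   agent.user_audio_process(chunk['audio'], db)
#   asr_results = await agent.stream_recognize(chunk, db)
#   if asr_results:
#       prompt = agent.prompt_process(asr_results, db)
#       async for frame in await agent.chat(assistant, prompt, db):
#           for sentence in agent.llm_msg_process(frame['msg'], db):
#               audio = agent.tts_audio_process(agent.synthetize(assistant, sentence, db), db)
#               send(agent.encode(sentence, audio))
#       agent.save()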