# TakwayDisplayPlatform/app/concrete.py
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json
# ----------- Initialize VITS ----------- #
vits = TextToSpeech()
# ---------------------------------- #
# ------ Concrete ASR, LLM, TTS classes ------ #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3
class XF_ASR(ASR):
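    """Streaming ASR client for the iFlytek (XF) websocket API.

    Audio arrives chunk by chunk: the first chunk opens the websocket and is
    sent as a first-frame, middle chunks go out as continue-frames, and the
    final chunk (or a segment timeout) triggers the last-frame, after which
    the recognized text is collected and returned.
    """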
def __init__(self):
self.websocket = None
self.current_message = ""
self.audio = ""
self.status = FIRST_FRAME
        self.segment_duration_threshold = 25  # segment timeout: 25 seconds
self.segment_start_time = None
async def stream_recognize(self, chunk):
        if self.websocket is None:  # open a new websocket connection if none exists
self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # first segment: record its start time
self.segment_start_time = asyncio.get_event_loop().time()
if chunk['meta_info']['is_end']:
self.status = LAST_FRAME
audio_data = chunk['audio']
self.audio += audio_data
        if self.status == FIRST_FRAME:  # send the first frame
await self.websocket.send(make_first_frame(audio_data))
self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame and collect the result
await self.websocket.send(make_last_frame(audio_data))
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "":
raise AsrResultNoneError()
await self.websocket.close()
print("语音识别结束,用户消息:", self.current_message)
return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timeout: send the last frame and reset state
            await self.websocket.send(make_last_frame())
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
await self.websocket.close()
self.websocket = await xf_asr_websocket_factory()
self.status = FIRST_FRAME
self.segment_start_time = current_time
return []
class MINIMAX_LLM(LLM):
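    """Streaming chat client for the MiniMax completion API.

    chat() is an async generator: it yields one frame per streamed SSE chunk
    and a final frame with is_end=True once the closing completion object
    (which carries the token usage) arrives.
    """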
def __init__(self):
self.token = 0
async def chat(self, assistant, prompt, db):
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({"role":"user","content":prompt})
assistant.messages = json.dumps(messages)
payload = json.dumps({
"model": llm_info['model'],
"stream": True,
"messages": messages,
"max_tokens": llm_info['max_tokens'],
"temperature": llm_info['temperature'],
"top_p": llm_info['top_p'],
})
headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
async for chunk in response.content.iter_any():
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"msg":""}
if self.token > llm_info['max_tokens'] * 0.8:
msg_frame['msg'] = 'max_token reached'
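                            # Near the token budget: reset the stored history to the system prompt plus the latest user turn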
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
yield msg_frame
def __parseChunk(self, llm_chunk):
result = ""
        data = json.loads(llm_chunk.decode('utf-8')[6:])  # strip the leading "data: " SSE prefix
        if data["object"] == "chat.completion":  # final frame: carries token usage
self.token = data['usage']['total_tokens']
raise LLMResponseEnd()
elif data['object'] == 'chat.completion.chunk':
for choice in data['choices']:
result += choice['delta']['content']
else:
raise UnkownLLMFrame()
return result
class VITS_TTS(TTS):
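    """TTS wrapper around the module-level VITS model loaded at import time."""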
def __init__(self):
pass
def synthetize(self, assistant, text):
tts_info = json.loads(assistant.tts_info)
return vits.synthesize(text, tts_info)
# --------------------------------- #
# ------ ASR, LLM, TTS factory classes ------ #
class ASRFactory:
def create_asr(self,asr_type:str) -> ASR:
if asr_type == 'XF':
return XF_ASR()
class LLMFactory:
def create_llm(self,llm_type:str) -> LLM:
if llm_type == 'MINIMAX':
return MINIMAX_LLM()
class TTSFactory:
def create_tts(self,tts_type:str) -> TTS:
if tts_type == 'VITS':
return VITS_TTS()
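# Note: each factory returns None implicitly when given an unknown type string.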
# --------------------------------- #
# ----------- Concrete service classes ----------- #
# Take the first single-speaker result in asr_results as the prompt
class BasicPromptService(PromptService):
def prompt_process(self, prompt:str, **kwargs):
if 'asr_result' in kwargs:
return kwargs['asr_result'][0]['text'], kwargs
else:
raise NoAsrResultsError()
class UserAudioRecordService(UserAudioService):
def user_audio_process(self, audio:str, **kwargs):
audio_decoded = base64.b64decode(audio)
kwargs['recorder'].user_audio += audio_decoded
return audio,kwargs
class TTSAudioRecordService(TTSAudioService):
def tts_audio_process(self, audio:bytes, **kwargs):
kwargs['recorder'].tts_audio += audio
return audio,kwargs
# --------------------------------- #
# ------------ Service chain classes ------------ #
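# Each chain runs its registered services in order, letting every service
# transform the payload and pass extra state along through kwargs.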
class UserAudioServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:UserAudioService):
self.services.append(service)
def user_audio_process(self, audio, **kwargs):
for service in self.services:
audio, kwargs = service.user_audio_process(audio, **kwargs)
return audio
class PromptServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:PromptService):
self.services.append(service)
def prompt_process(self, asr_results):
kwargs = {"asr_result":asr_results}
prompt = ""
for service in self.services:
prompt, kwargs = service.prompt_process(prompt, **kwargs)
return prompt
class LLMMsgServiceChain:
    def __init__(self):
self.services = []
self.spliter = SentenceSegmentation()
def add_service(self, service:LLMMsgService):
self.services.append(service)
def llm_msg_process(self, llm_msg):
kwargs = {}
llm_chunks = self.spliter.split(llm_msg)
for service in self.services:
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
return llm_chunks
class TTSAudioServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:TTSAudioService):
self.services.append(service)
def tts_audio_process(self, audio, **kwargs):
for service in self.services:
audio, kwargs = service.tts_audio_process(audio, **kwargs)
return audio
# --------------------------------- #
class Agent:
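    """Facade that wires the ASR, LLM and TTS backends together with their
    pre/post-processing service chains.

    Intended call order for one turn:
        user_audio_process -> stream_recognize -> prompt_process -> chat
        -> llm_msg_process -> synthetize -> tts_audio_process -> encode -> save
    """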
def __init__(self, asr_type:str, llm_type:str, tts_type:str):
asrFactory = ASRFactory()
llmFactory = LLMFactory()
ttsFactory = TTSFactory()
self.recorder = None
        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio processing chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # record raw user audio
self.asr = asrFactory.create_asr(asr_type)
self.prompt_service_chain = PromptServiceChain()
self.prompt_service_chain.add_service(BasicPromptService())
self.llm = llmFactory.create_llm(llm_type)
self.llm_chunk_service_chain = LLMMsgServiceChain()
self.tts = ttsFactory.create_tts(tts_type)
self.tts_audio_service_chain = TTSAudioServiceChain()
self.tts_audio_service_chain.add_service(TTSAudioRecordService())
def init_recorder(self,user_id):
self.recorder = Recorder(user_id)
    # Pre-process the user's input audio
def user_audio_process(self, audio, db):
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
    # Streaming speech recognition
async def stream_recognize(self, chunk, db):
return await self.asr.stream_recognize(chunk)
    # Prompt processing
def prompt_process(self, asr_results, db):
return self.prompt_service_chain.prompt_process(asr_results)
    # Call the LLM
    async def chat(self, assistant, prompt, db):
return self.llm.chat(assistant, prompt, db)
    # Process the LLM's streamed reply
def llm_msg_process(self, llm_chunk, db):
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)
    # TTS synthesis
def synthetize(self, assistant, text, db):
return self.tts.synthetize(assistant, text)
    # Post-process the synthesized audio
def tts_audio_process(self, audio, db):
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
    # Encode a response frame: '!II' header (text length, audio length) + JSON text + raw audio
def encode(self, text, audio):
text_resp = {"type":"text","code":200,"msg":text}
text_resp_bytes = json.dumps(text_resp,ensure_ascii=False).encode('utf-8')
header = struct.pack('!II',len(text_resp_bytes),len(audio))
final_resp = header + text_resp_bytes + audio
return final_resp
    # Persist the recorded audio
def save(self):
self.recorder.write()
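# --------------------------------- #
# Minimal client-side decode sketch for the frame produced by Agent.encode
# (assumption: the receiving client parses the same '!II' header of text and
# audio byte lengths; this helper is illustrative, not part of the platform).
def decode(frame: bytes):
    text_len, audio_len = struct.unpack('!II', frame[:8])
    text_resp = json.loads(frame[8:8 + text_len].decode('utf-8'))
    audio = frame[8 + text_len:8 + text_len + audio_len]
    return text_resp, audio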