from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from volcenginesdkarkruntime import Ark
from zhipuai import ZhipuAI
from .model import Assistant
from .abstract import *
from .public import *
from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json
# --------- Initialize VITS -------- #
vits = TextToSpeech()
# ---------------------------------- #
# -------- Initialize logger ------- #
logger = get_logger()
# ---------------------------------- #
# - Concrete ASR, LLM, TTS classes - #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

class XF_ASR(ASR):
    def __init__(self):
        super().__init__()
        self.websocket = None
        self.current_message = ""
        self.audio = ""
        self.status = FIRST_FRAME
        self.segment_duration_threshold = 25  # segment timeout: 25 seconds
        self.segment_start_time = None

    async def stream_recognize(self, assistant, chunk):
        if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:  # a first frame that is already the end frame is treated as stray noise
            raise SideNoiseError()
        if self.websocket is None:  # open a new websocket connection if none exists yet
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # first segment: record the start time
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME

        audio_data = chunk['audio']
        self.audio += audio_data
        if self.status == FIRST_FRAME:  # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame
            await self.websocket.send(make_last_frame(audio_data))
            logger.debug("Finished sending audio frames")
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            if self.current_message in [""]:  # noise-transcript filter (unreachable as written: the empty check above already raises)
                raise SideNoiseError()
            if "闭嘴" in self.current_message:  # the keyword "闭嘴" ("shut up") switches the assistant into silence mode
                self.is_slience = True
                asyncio.create_task(self.websocket.close())
                raise EnterSlienceMode()
            asyncio.create_task(self.websocket.close())
            logger.debug(f"ASR result: {self.current_message}")
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timed out: send the last frame and reset state
            await self.websocket.send(make_last_frame())
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []
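
# Expected shape of the `chunk` argument to stream_recognize, inferred from
# the access patterns above (the producer lives outside this module):
#   chunk = {
#       "audio": "<base64-encoded audio segment>",
#       "meta_info": {"is_end": bool},  # True on the caller's final segment
#   }
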
class MINIMAX_LLM(LLM):
    def __init__(self):
        self.token = 0

    async def chat(self, assistant, prompt):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        payload = json.dumps({  # assemble the request payload
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # call the LLM
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)  # parse the LLM response
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end": True, "code": 200, "msg": ""}
                        assistant.token = self.token
                        if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage exceeds 80% of the budget
                            msg_frame['code'] = 201
                            assistant.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                        yield msg_frame
    def __parseChunk(self, llm_chunk):
        try:
            result = ""
            chunk_decoded = llm_chunk.decode('utf-8')
            chunks = chunk_decoded.split('\n\n')
            for chunk in chunks:
                if not chunk:
                    continue
                data = json.loads(chunk[6:])  # strip the leading "data: " SSE prefix
                if data["object"] == "chat.completion":  # the final frame carries the usage summary
                    self.token = data['usage']['total_tokens']
                    raise LLMResponseEnd()
                elif data['object'] == 'chat.completion.chunk':
                    for choice in data['choices']:
                        result += choice['delta']['content']
                else:
                    raise UnkownLLMFrame()
            return result
        except (json.JSONDecodeError, KeyError):
            logger.error(llm_chunk)
            raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
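
# Wire format that __parseChunk expects, reconstructed from the parsing
# logic above (OpenAI-style SSE events, separated by blank lines):
#   data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "..."}}]}
#   data: {"object": "chat.completion", "usage": {"total_tokens": 123}}
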
class VOLCENGINE_LLM(LLM):
    def __init__(self):
        self.token = 0
        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)

    async def chat(self, assistant, prompt):
        llm_info = json.loads(assistant.llm_info)
        model = self.__get_model(llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        stream = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            temperature=llm_info['temperature'],
            top_p=llm_info['top_p'],
            max_tokens=llm_info['max_tokens'],
            stream_options={'include_usage': True}
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                msg_frame = {"is_end": True, "code": 200, "msg": ""}
                assistant.token = self.token
                if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage exceeds 80% of the budget
                    msg_frame['code'] = 201
                    assistant.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                yield msg_frame
    def __get_model(self, llm_info):
        if llm_info['model'] == 'doubao-4k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
        elif llm_info['model'] == 'doubao-32k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
        elif llm_info['model'] == 'doubao-32k-pro':
            return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
        else:
            raise UnknownVolcEngineModelError()
    def __parseChunk(self, llm_chunk):
        if llm_chunk.usage:  # the usage field only appears on the final chunk
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content
class ZHIPU_LLM(LLM):
    def __init__(self):
        self.token = 0
        self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)

    async def chat(self, assistant, prompt):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        stream = self.client.chat.completions.create(
            model=llm_info['model'],
            messages=messages,
            stream=True,
            temperature=llm_info['temperature'],
            top_p=llm_info['top_p'],
            max_tokens=llm_info['max_tokens']
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                msg_frame = {"is_end": True, "code": 200, "msg": ""}
                assistant.token = self.token
                if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage exceeds 80% of the budget
                    msg_frame['code'] = 201
                    assistant.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                yield msg_frame
    def __parseChunk(self, llm_chunk):
        if llm_chunk.usage:  # the usage field only appears on the final chunk
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error zhipu llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content
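
# Note: the Ark and ZhipuAI SDKs both expose an OpenAI-style streaming
# interface (delta chunks read via .choices[0].delta.content, plus a final
# usage-bearing chunk), which is why the two __parseChunk bodies above are
# identical.
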
class MIXED_LLM(LLM):
    def __init__(self):
        self.minimax = MINIMAX_LLM()
        self.volcengine = VOLCENGINE_LLM()
        self.zhipu = ZHIPU_LLM()
    # Note: this class aggregates the three backends but does not override
    # chat(), so the routing between them is not implemented in this file.
class VITS_TTS(TTS):
    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #
# - ASR, LLM, TTS factory classes - #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()

class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()
        if llm_type == 'VOLCENGINE':
            return VOLCENGINE_LLM()
        if llm_type == 'ZHIPU':
            return ZHIPU_LLM()

class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
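
# Factory usage sketch (the type strings must match the literals checked
# above; any other value silently yields None):
#   asr = ASRFactory().create_asr('XF')
#   llm = LLMFactory().create_llm('VOLCENGINE')
#   tts = TTSFactory().create_tts('VITS')
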
# --------------------------------- #
# ---- Concrete service classes --- #
# Take the first result from a single speaker's asr_results as the prompt
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()

# Record the user's audio
class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs

# Record the TTS audio
class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #
# --------- Service chains -------- #
class UserAudioServiceChain:
    def __init__(self):
        self.services = []

    def add_service(self, service: UserAudioService):
        self.services.append(service)

    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio

class PromptServiceChain:
    def __init__(self):
        self.services = []

    def add_service(self, service: PromptService):
        self.services.append(service)

    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt

class LLMMsgServiceChain:
    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()

    def add_service(self, service: LLMMsgService):
        self.services.append(service)

    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)  # first segment the LLM output into sentences
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks

class TTSAudioServiceChain:
    def __init__(self):
        self.services = []

    def add_service(self, service: TTSAudioService):
        self.services.append(service)

    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
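
# Chain composition sketch: every service in a chain receives and returns a
# (value, kwargs) pair, so custom steps can be appended next to the built-in
# recorder services (hypothetical usage):
#   chain = TTSAudioServiceChain()
#   chain.add_service(TTSAudioRecordService())
#   processed = chain.tts_audio_process(audio_bytes, recorder=recorder)
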
# --------------------------------- #
class Agent:
    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio processing chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # add the user-audio recording service

        self.asr = asrFactory.create_asr(asr_type)

        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)
        self.llm_msg_service_chain = LLMMsgServiceChain()

        self.tts = ttsFactory.create_tts(tts_type)
        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Preprocess the user's input audio
    def user_audio_process(self, audio):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Run streaming speech recognition
    async def stream_recognize(self, assistant, chunk):
        return await self.asr.stream_recognize(assistant, chunk)

    # Post-process the prompt
    def prompt_process(self, asr_results):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Call the LLM
    async def chat(self, assistant, prompt):
        return self.llm.chat(assistant, prompt)

    # Process the LLM's output
    def llm_msg_process(self, llm_chunk):
        return self.llm_msg_service_chain.llm_msg_process(llm_chunk)

    # Run TTS synthesis
    def synthetize(self, assistant, text):
        return self.tts.synthetize(assistant, text)

    # Process the synthesized audio
    def tts_audio_process(self, audio):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode a response frame: a length header, the JSON text payload, then
    # the raw audio bytes
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp

    # Save the recorded audio
    def save(self):
        self.recorder.write()
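
# --------------------------------- #
# Decoding sketch for the frame layout produced by Agent.encode above.
# _decode_frame_example is a hypothetical client-side helper, not part of
# the original module:
def _decode_frame_example(frame: bytes):
    # Header: two big-endian uint32 values, text length then audio length
    text_len, audio_len = struct.unpack('!II', frame[:8])
    text_resp = json.loads(frame[8:8 + text_len].decode('utf-8'))
    audio = frame[8 + text_len:8 + text_len + audio_len]
    return text_resp, audio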