# TakwayDisplayPlatform/app/concrete.py
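"""Concrete ASR / LLM / TTS implementations, processing-service chains, and the
Agent facade that wires them into a speech -> prompt -> LLM -> speech pipeline."""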
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from volcenginesdkarkruntime import Ark
from .model import Assistant
from .abstract import *
from .public import *
from .exception import *
from .dependency import get_logger
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ---------- Initialize VITS ---------- #
vits = TextToSpeech()
# ------------------------------------- #
# --------- Initialize logger --------- #
logger = get_logger()
# ------------------------------------- #
# ------ Concrete ASR, LLM, TTS classes ------ #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

class XF_ASR(ASR):
    """Xunfei (iFlytek) streaming ASR: audio is pushed over a websocket as a
    first frame, continue frames, and a last frame; the recognized text is
    read back on the last frame or when a segment times out."""
    def __init__(self):
        super().__init__()
        self.websocket = None
        self.current_message = ""
        self.audio = ""
        self.status = FIRST_FRAME
        self.is_slience = False  # silence-mode flag ('slience' spelling kept from the codebase; defaulted here in case the base class does not set it)
        self.segment_duration_threshold = 25  # segment timeout, in seconds
        self.segment_start_time = None
    async def stream_recognize(self, chunk):
        if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:  # a first frame that is already the end is treated as stray noise
            raise SideNoiseError()
        if self.websocket is None:  # open a new websocket connection if none exists
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # first segment: record its start time
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data
        if self.status == FIRST_FRAME:  # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame and collect the result
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            if "进入沉默模式" in self.current_message:  # voice command: enter silence mode
                self.is_slience = True
                asyncio.create_task(self.websocket.close())
                raise EnterSlienceMode()
            if "退出沉默模式" in self.current_message:  # voice command: exit silence mode
                self.is_slience = False
                self.current_message = "已退出沉默模式"  # "silence mode exited"
            if self.is_slience:
                asyncio.create_task(self.websocket.close())
                raise SlienceMode()
            asyncio.create_task(self.websocket.close())
            return [{"text": self.current_message, "audio": self.audio}]
        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timeout: flush with a last frame and reopen the connection
            await self.websocket.send(make_last_frame())
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))  # json.loads added for consistency with the last-frame path above
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []
class MINIMAX_LLM(LLM):
    """Minimax chat-completion client that streams the reply as SSE chunks."""
    def __init__(self):
        self.token = 0
    async def chat(self, assistant, prompt, db):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        payload = json.dumps({  # assemble the request payload
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # call the LLM
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)  # parse the LLM response chunk
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end": True, "code": 200, "msg": ""}
                        if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage passes 80%
                            msg_frame['code'] = 201
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame
    def __parseChunk(self, llm_chunk):
        try:
            result = ""
            chunk_decoded = llm_chunk.decode('utf-8')
            chunks = chunk_decoded.split('\n\n')
            for chunk in chunks:
                if not chunk:
                    continue
                data = json.loads(chunk[6:])  # strip the leading "data: " SSE prefix
                if data["object"] == "chat.completion":  # final frame carries the usage stats
                    self.token = data['usage']['total_tokens']
                    raise LLMResponseEnd()
                elif data['object'] == 'chat.completion.chunk':
                    for choice in data['choices']:
                        result += choice['delta']['content']
                else:
                    raise UnkownLLMFrame()
            return result
        except (json.JSONDecodeError, KeyError):
            logger.error(llm_chunk)
            raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
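
# A sketch of the stream format __parseChunk assumes (inferred from the parsing
# logic above): each SSE event is a "data: ..." line, delta events carry the
# text, and one closing event reports token usage, e.g.
#   data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"..."}}]}
#   data: {"object":"chat.completion","usage":{"total_tokens":123}}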

class VOLCENGINE_LLM(LLM):
    """Volcengine Ark chat-completion client (Doubao models), streamed."""
    def __init__(self):
        self.token = 0
        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)
    async def chat(self, assistant, prompt, db):
        llm_info = json.loads(assistant.llm_info)
        model = self.__get_model(llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role': 'user', 'content': prompt})
        stream = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            temperature=llm_info['temperature'],
            top_p=llm_info['top_p'],
            max_tokens=llm_info['max_tokens'],
            stream_options={'include_usage': True}
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                msg_frame = {"is_end": True, "code": 200, "msg": ""}
                if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage passes 80%
                    msg_frame['code'] = 201
                    as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                    as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                    db.commit()
                    assistant.messages = as_query.messages
                yield msg_frame
    def __get_model(self, llm_info):
        if llm_info['model'] == 'doubao-4k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
        else:
            raise UnknownVolcEngineModelError()
    def __parseChunk(self, llm_chunk):
        if llm_chunk.usage:  # the final chunk carries the usage stats
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content

class VITS_TTS(TTS):
    def __init__(self):
        pass
    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #

# ------ ASR, LLM, TTS factories ------ #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()
class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()
        if llm_type == 'VOLCENGINE':
            return VOLCENGINE_LLM()
class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
# --------------------------------- #
# ------ Concrete service classes ------ #
# Take the first result from the single-speaker asr_results as the prompt
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()

# Record the user's audio
class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs

# Record the TTS audio
class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #
# ------ Service-chain classes ------ #
class UserAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: UserAudioService):
        self.services.append(service)
    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio
class PromptServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: PromptService):
        self.services.append(service)
    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt
class LLMMsgServiceChain:
    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()
    def add_service(self, service: LLMMsgService):
        self.services.append(service)
    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)  # first split the LLM reply into sentences
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks
class TTSAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: TTSAudioService):
        self.services.append(service)
    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# --------------------------------- #
class Agent():
    """Facade tying ASR, prompt processing, LLM, and TTS into one pipeline."""
    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None
        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio service chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # add the user-audio recording service
        self.asr = asrFactory.create_asr(asr_type)
        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())
        self.llm = llmFactory.create_llm(llm_type)
        self.llm_msg_service_chain = LLMMsgServiceChain()
        self.tts = ttsFactory.create_tts(tts_type)
        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())
    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)
    # Preprocess the user's input audio
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
    # Streaming speech recognition
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)
    # Prompt processing
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)
    # Call the LLM
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)
    # Post-process the LLM reply
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
    # TTS synthesis
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)
    # Post-process the synthesized audio
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
    # Encode the response: an 8-byte big-endian header (text length, audio length)
    # followed by the JSON text frame and the raw audio
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp
    # Persist the recorded audio
    def save(self):
        self.recorder.write()
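
# A minimal client-side sketch for decoding the frame produced by Agent.encode()
# (hypothetical helper, not part of this module): read the two big-endian
# uint32 lengths, then slice out the JSON text frame and the raw audio.
#
#   def decode_frame(frame: bytes):
#       text_len, audio_len = struct.unpack('!II', frame[:8])
#       text = json.loads(frame[8:8 + text_len].decode('utf-8'))
#       audio = frame[8 + text_len:8 + text_len + audio_len]
#       return text, audio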