2024-06-09 22:54:13 +08:00
|
|
|
|
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
|
2024-06-12 17:18:47 +08:00
|
|
|
|
from volcenginesdkarkruntime import Ark
|
2024-06-09 22:54:13 +08:00
|
|
|
|
from .model import Assistant
|
|
|
|
|
from .abstract import *
|
|
|
|
|
from .public import *
|
2024-06-10 02:28:16 +08:00
|
|
|
|
from .exception import *
|
|
|
|
|
from .dependency import get_logger
|
2024-06-09 22:54:13 +08:00
|
|
|
|
from utils.vits_utils import TextToSpeech
|
|
|
|
|
from config import Config
|
|
|
|
|
import aiohttp
|
|
|
|
|
import asyncio
|
|
|
|
|
import struct
|
|
|
|
|
import base64
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
# ---------- initialize VITS ---------- #
# Module-level TTS engine, shared by every VITS_TTS instance.
vits = TextToSpeech()
# ------------------------------------- #

# --------- initialize logger --------- #
logger = get_logger()
# ------------------------------------- #
|
2024-06-09 22:54:13 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------ concrete ASR, LLM, TTS classes ------ #

# Frame-status markers for the XunFei streaming ASR protocol.
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3
|
|
|
|
|
|
|
|
|
|
class XF_ASR(ASR):
    """Streaming ASR backed by the XunFei (iFlytek) websocket API.

    Audio chunks are forwarded frame by frame (first / continue / last);
    recognized text accumulates in ``self.current_message`` and the raw
    (base64) audio in ``self.audio``.
    """

    def __init__(self):
        super().__init__()
        self.websocket = None            # lazily-created XF websocket connection
        self.current_message = ""        # accumulated recognition text
        self.audio = ""                  # accumulated base64 audio of the utterance
        self.status = FIRST_FRAME        # protocol state: FIRST/CONTINUE/LAST frame
        self.segment_duration_threshold = 25   # segment timeout, in seconds
        self.segment_start_time = None   # event-loop time when this segment began

    async def stream_recognize(self, chunk):
        """Feed one audio chunk into the streaming recognizer.

        Returns ``[{"text": ..., "audio": ...}]`` once the utterance is
        complete, ``[]`` while recognition is still in progress.

        Raises:
            SideNoiseError: the input is judged to be noise.
            AsrResultNoneError: recognition produced no text.
            EnterSlienceMode: the user asked the assistant to go silent.
        """
        # A lone end-chunk with no preceding audio is treated as noise.
        if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:
            raise SideNoiseError()
        if self.websocket is None:  # open a fresh connection on demand
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # remember when this segment started
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data

        if self.status == FIRST_FRAME:       # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:      # send the final frame, collect the result
            await self.websocket.send(make_last_frame(audio_data))
            logger.debug("发送完毕")
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            if self.current_message in ["啊"]:  # bare filler syllable -> noise
                raise SideNoiseError()
            if "闭嘴" in self.current_message:  # "shut up" -> enter silence mode
                self.is_slience = True
                asyncio.create_task(self.websocket.close())
                raise EnterSlienceMode()
            asyncio.create_task(self.websocket.close())
            logger.debug(f"ASR结果: {self.current_message}")
            return [{"text":self.current_message, "audio":self.audio}]

        # Segment timeout: flush the current segment and start a new one so a
        # single websocket session never runs longer than the threshold.
        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:
            await self.websocket.send(make_last_frame())
            # BUGFIX: the server reply is a JSON string and must be decoded
            # before parsing, matching the LAST_FRAME branch above (the
            # original passed the raw string straight to parse_xfasr_recv).
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []
|
|
|
|
|
|
|
|
|
|
class MINIMAX_LLM(LLM):
    """LLM backend for the MiniMax streaming chat-completion HTTP API."""

    def __init__(self):
        self.token = 0  # total tokens reported by the last completed response

    async def chat(self, assistant, prompt):
        """Stream chat deltas for ``prompt`` against the assistant's session.

        Yields frames shaped ``{"is_end": bool, "code": int, "msg": str}``.
        The final frame has ``is_end=True``; code 201 signals that the
        session history was reset because the token budget is nearly used up.
        """
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role':'user','content':prompt})
        payload = json.dumps({  # request body for the MiniMax API
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # call the model
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)  # parse the streamed SSE chunk
                        msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end":True,"code":200,"msg":""}
                        assistant.token = self.token
                        # Once >80% of the token budget is used, reset the
                        # session and signal the caller via code 201.
                        if self.token > llm_info['max_tokens'] * 0.8:
                            # BUGFIX: was the string '201', inconsistent with
                            # the integer 200 used in every other frame — an
                            # integer comparison downstream could never match.
                            msg_frame['code'] = 201
                            assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                        yield msg_frame

    def __parseChunk(self, llm_chunk):
        """Parse one raw SSE byte chunk; return the concatenated text deltas.

        Raises:
            LLMResponseEnd: on the final frame carrying usage statistics.
            UnkownLLMFrame: on an unrecognized ``object`` type.
            AbnormalLLMFrame: on malformed JSON or missing keys.
        """
        # NOTE(review): if a network chunk bundles content frames together
        # with the final "chat.completion" frame, the accumulated `result`
        # is discarded when LLMResponseEnd is raised — confirm whether the
        # server ever batches them this way.
        try:
            result = ""
            chunk_decoded = llm_chunk.decode('utf-8')
            chunks = chunk_decoded.split('\n\n')  # SSE events are blank-line separated
            for chunk in chunks:
                if not chunk:
                    continue
                data = json.loads(chunk[6:])  # strip the leading "data: " SSE prefix
                if data["object"] == "chat.completion":  # final frame: record usage
                    self.token = data['usage']['total_tokens']
                    raise LLMResponseEnd()
                elif data['object'] == 'chat.completion.chunk':
                    for choice in data['choices']:
                        result += choice['delta']['content']
                else:
                    raise UnkownLLMFrame()
            return result
        except (json.JSONDecodeError, KeyError):
            logger.error(llm_chunk)
            raise AbnormalLLMFrame(f"error llm_chunk:{llm_chunk}")
|
2024-06-09 22:54:13 +08:00
|
|
|
|
|
2024-06-12 17:18:47 +08:00
|
|
|
|
class VOLCENGINE_LLM(LLM):
    """LLM backend for VolcEngine (Doubao) via the Ark SDK."""

    def __init__(self):
        self.token = 0  # total tokens reported by the last completed response
        self.client = Ark(api_key=Config.VOLCENGINE_LLM.API_KEY)

    async def chat(self, assistant, prompt):
        """Stream chat deltas for ``prompt`` against the assistant's session.

        Yields frames shaped ``{"is_end": bool, "code": int, "msg": str}``.
        The final frame has ``is_end=True``; code 201 signals that the
        session history was reset because the token budget is nearly used up.
        """
        llm_info = json.loads(assistant.llm_info)
        model = self.__get_model(llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role':'user','content':prompt})
        stream = self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            temperature=llm_info['temperature'],
            top_p=llm_info['top_p'],
            max_tokens=llm_info['max_tokens'],
            stream_options={'include_usage': True}  # final chunk carries usage
        )
        # NOTE(review): the Ark stream is iterated synchronously inside an
        # async method, which blocks the event loop while waiting for chunks;
        # consider run_in_executor if this becomes a latency problem.
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                msg_frame = {"is_end":True,"code":200,"msg":""}
                assistant.token = self.token
                # Once >80% of the token budget is used, reset the session
                # and signal the caller via code 201.
                if self.token > llm_info['max_tokens'] * 0.8:
                    # BUGFIX: was the string '201', inconsistent with the
                    # integer 200 used in every other frame.
                    msg_frame['code'] = 201
                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                yield msg_frame

    def __get_model(self, llm_info):
        """Map the public model name to the configured VolcEngine endpoint id.

        Raises:
            UnknownVolcEngineModelError: if the model name is not recognized.
        """
        if llm_info['model'] == 'doubao-4k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_4k
        elif llm_info['model'] == 'doubao-32k-lite':
            return Config.VOLCENGINE_LLM.DOUBAO_LITE_32k
        elif llm_info['model'] == 'doubao-32k-pro':
            return Config.VOLCENGINE_LLM.DOUBAO_PRO_32k
        else:
            raise UnknownVolcEngineModelError()

    def __parseChunk(self, llm_chunk):
        """Extract the text delta from one Ark stream chunk.

        Raises:
            LLMResponseEnd: on the final chunk carrying usage statistics.
            AbnormalLLMFrame: on a chunk with neither usage nor choices.
        """
        if llm_chunk.usage:  # final chunk: record total token usage
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content
|
|
|
|
|
|
2024-06-09 22:54:13 +08:00
|
|
|
|
class VITS_TTS(TTS):
    """TTS backend backed by the module-level VITS engine."""

    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        """Synthesize ``text`` using the assistant's TTS settings."""
        config = json.loads(assistant.tts_info)
        return vits.synthesize(text, config)
|
|
|
|
|
# --------------------------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------ ASR, LLM, TTS 工厂类 ------ #
|
|
|
|
|
class ASRFactory:
    """Creates a concrete ASR backend from its type name."""

    def create_asr(self,asr_type:str) -> ASR:
        # Dispatch table; unknown types fall through to an implicit None.
        known = {'XF': XF_ASR}
        ctor = known.get(asr_type)
        if ctor is not None:
            return ctor()
|
|
|
|
|
|
|
|
|
|
class LLMFactory:
    """Creates a concrete LLM backend from its type name."""

    def create_llm(self,llm_type:str) -> LLM:
        # Dispatch table; unknown types fall through to an implicit None.
        known = {'MINIMAX': MINIMAX_LLM, 'VOLCENGINE': VOLCENGINE_LLM}
        ctor = known.get(llm_type)
        if ctor is not None:
            return ctor()
|
2024-06-09 22:54:13 +08:00
|
|
|
|
|
|
|
|
|
class TTSFactory:
    """Creates a concrete TTS backend from its type name."""

    def create_tts(self,tts_type:str) -> TTS:
        # Dispatch table; unknown types fall through to an implicit None.
        known = {'VITS': VITS_TTS}
        ctor = known.get(tts_type)
        if ctor is not None:
            return ctor()
|
|
|
|
|
# --------------------------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----------- 具体服务类 ----------- #
|
|
|
|
|
# 从单说话人的asr_results中取出第一个结果作为prompt
|
|
|
|
|
# Uses the first single-speaker ASR result's text as the prompt.
class BasicPromptService(PromptService):
    """Takes the first ASR result's text verbatim as the LLM prompt."""

    def prompt_process(self, prompt:str, **kwargs):
        try:
            asr_result = kwargs['asr_result']
        except KeyError:
            raise NoAsrResultsError()
        return asr_result[0]['text'], kwargs
|
|
|
|
|
|
2024-06-10 20:49:50 +08:00
|
|
|
|
# 记录用户音频
|
2024-06-09 22:54:13 +08:00
|
|
|
|
class UserAudioRecordService(UserAudioService):
    """Appends each incoming (base64) user audio chunk to the recorder."""

    def user_audio_process(self, audio:str, **kwargs):
        kwargs['recorder'].user_audio += base64.b64decode(audio)
        return audio,kwargs
|
2024-06-10 20:49:50 +08:00
|
|
|
|
|
|
|
|
|
# 记录TTS音频
|
2024-06-09 22:54:13 +08:00
|
|
|
|
class TTSAudioRecordService(TTSAudioService):
    """Appends each synthesized audio chunk to the recorder."""

    def tts_audio_process(self, audio:bytes, **kwargs):
        recorder = kwargs['recorder']
        recorder.tts_audio += audio
        return audio,kwargs
|
|
|
|
|
# --------------------------------- #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------ 服务链类 ----------- #
|
|
|
|
|
class UserAudioServiceChain:
    """Runs registered user-audio handlers in order, threading kwargs through."""

    def __init__(self):
        self.services = []

    def add_service(self, service:UserAudioService):
        self.services.append(service)

    def user_audio_process(self, audio, **kwargs):
        result = audio
        for handler in self.services:
            result, kwargs = handler.user_audio_process(result, **kwargs)
        return result
|
|
|
|
|
|
|
|
|
|
class PromptServiceChain:
    """Builds the LLM prompt by piping ASR results through registered handlers."""

    def __init__(self):
        self.services = []

    def add_service(self, service:PromptService):
        self.services.append(service)

    def prompt_process(self, asr_results):
        kwargs = {"asr_result":asr_results}
        result = ""
        for handler in self.services:
            result, kwargs = handler.prompt_process(result, **kwargs)
        return result
|
|
|
|
|
|
|
|
|
|
class LLMMsgServiceChain:
    """Segments LLM output into sentences, then runs registered handlers."""

    def __init__(self, ):
        self.services = []
        self.spliter = SentenceSegmentation()

    def add_service(self, service:LLMMsgService):
        self.services.append(service)

    def llm_msg_process(self, llm_msg):
        kwargs = {}
        # Sentence segmentation runs first, before any registered handler.
        chunks = self.spliter.split(llm_msg)
        for handler in self.services:
            chunks, kwargs = handler.llm_msg_process(chunks, **kwargs)
        return chunks
|
|
|
|
|
|
|
|
|
|
class TTSAudioServiceChain:
    """Runs registered TTS-audio handlers in order, threading kwargs through."""

    def __init__(self):
        self.services = []

    def add_service(self, service:TTSAudioService):
        self.services.append(service)

    def tts_audio_process(self, audio, **kwargs):
        result = audio
        for handler in self.services:
            result, kwargs = handler.tts_audio_process(result, **kwargs)
        return result
|
|
|
|
|
# --------------------------------- #
|
|
|
|
|
|
|
|
|
|
class Agent():
    """Facade wiring the ASR, LLM and TTS backends together with their
    pre/post-processing service chains for one conversation session."""

    def __init__(self, asr_type:str, llm_type:str, tts_type:str):
        asr_factory = ASRFactory()
        llm_factory = LLMFactory()
        tts_factory = TTSFactory()
        self.recorder = None

        # Chain that pre-processes incoming user audio (records it).
        self.user_audio_service_chain = UserAudioServiceChain()
        self.user_audio_service_chain.add_service(UserAudioRecordService())

        self.asr = asr_factory.create_asr(asr_type)

        # Chain that turns ASR results into an LLM prompt.
        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llm_factory.create_llm(llm_type)

        # Chain that post-processes LLM output (sentence segmentation).
        self.llm_msg_service_chain = LLMMsgServiceChain()

        self.tts = tts_factory.create_tts(tts_type)

        # Chain that post-processes synthesized audio (records it).
        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self,user_id):
        """Start a fresh recorder for this user's session."""
        self.recorder = Recorder(user_id)

    def user_audio_process(self, audio):
        """Pre-process incoming user audio through the service chain."""
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    async def stream_recognize(self, chunk):
        """Run streaming speech recognition on one audio chunk."""
        return await self.asr.stream_recognize(chunk)

    def prompt_process(self, asr_results):
        """Build the LLM prompt from ASR results."""
        return self.prompt_service_chain.prompt_process(asr_results)

    async def chat(self, assistant ,prompt):
        """Invoke the LLM; returns its streaming generator."""
        return self.llm.chat(assistant, prompt)

    def llm_msg_process(self, llm_chunk):
        """Post-process an LLM output chunk through the service chain."""
        return self.llm_msg_service_chain.llm_msg_process(llm_chunk)

    def synthetize(self, assistant, text):
        """Synthesize speech for ``text`` with the configured TTS backend."""
        return self.tts.synthetize(assistant, text)

    def tts_audio_process(self, audio):
        """Post-process synthesized audio through the service chain."""
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    def encode(self, text, audio):
        """Pack a text frame and raw audio into one binary response:
        [4B text length][4B audio length][text JSON][audio bytes]."""
        text_resp = {"type":"text","code":200,"msg":text}
        text_resp_bytes = json.dumps(text_resp,ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        return header + text_resp_bytes + audio

    def save(self):
        """Persist the session's recorded audio to disk."""
        self.recorder.write()
|
|
|
|
|
|