update: added some comments

killua 2024-06-10 20:49:50 +08:00
parent dba43836b6
commit d885684533
3 changed files with 14 additions and 10 deletions

View File

@@ -31,7 +31,7 @@ class PromptService(ABC):
 
 class LLMMsgService(ABC):
     @abstractmethod
-    def llm_msg_process(self, llm_chunk:list, **kwargs) -> list:
+    def llm_msg_process(self, llm_chunks:list, **kwargs) -> list:
         pass
 
 class TTSAudioService(ABC):
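For context, this hunk only renames the parameter (llm_chunk to llm_chunks) to match how LLMMsgServiceChain calls the interface. A minimal concrete implementation under the renamed signature might look like the sketch below. The pass-through service is a hypothetical example, not code from this repo; it returns a (list, kwargs) tuple because that is what the chain's loop unpacks, despite the `-> list` annotation.

from abc import ABC, abstractmethod

class LLMMsgService(ABC):
    @abstractmethod
    def llm_msg_process(self, llm_chunks: list, **kwargs) -> list:
        pass

class StripWhitespaceMsgService(LLMMsgService):
    # hypothetical example service: trims each chunk and threads kwargs
    # through, matching `llm_chunks, kwargs = service.llm_msg_process(...)`
    def llm_msg_process(self, llm_chunks: list, **kwargs) -> list:
        return [c.strip() for c in llm_chunks], kwargs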

View File

@@ -78,7 +78,7 @@ class MINIMAX_LLM(LLM):
         llm_info = json.loads(assistant.llm_info)
         messages = json.loads(assistant.messages)
         messages.append({'role':'user','content':prompt})
-        payload = json.dumps({
+        payload = json.dumps({  # assemble the request payload
             "model": llm_info['model'],
             "stream": True,
             "messages": messages,
@@ -91,15 +91,15 @@ class MINIMAX_LLM(LLM):
             'Content-Type': 'application/json'
         }
         async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
+            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # call the LLM
                 async for chunk in response.content.iter_any():
                     try:
-                        chunk_msg = self.__parseChunk(chunk)
+                        chunk_msg = self.__parseChunk(chunk)  # parse the LLM response chunk
                         msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                         yield msg_frame
                     except LLMResponseEnd:
                         msg_frame = {"is_end":True,"msg":""}
-                        if self.token > llm_info['max_tokens'] * 0.8:
+                        if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage passes 80%
                             msg_frame['msg'] = 'max_token reached'
                             as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                             as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
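The frames yielded above drive the rest of the pipeline. As a minimal sketch of the consuming side (the generator's name `chat` and its signature are assumptions, only the frame keys follow this diff):

async def collect_reply(llm, assistant, prompt, db):
    # hypothetical consumer: concatenates streamed frames into the full reply
    text = ""
    async for frame in llm.chat(assistant, prompt, db):  # method name assumed
        if frame.get("is_end"):
            break
        text += frame["msg"]
    return text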
@@ -169,12 +169,14 @@ class BasicPromptService(PromptService):
         else:
             raise NoAsrResultsError()
 
+# record the user's audio
 class UserAudioRecordService(UserAudioService):
     def user_audio_process(self, audio:str, **kwargs):
         audio_decoded = base64.b64decode(audio)
         kwargs['recorder'].user_audio += audio_decoded
         return audio,kwargs
 
+# record the TTS audio
 class TTSAudioRecordService(TTSAudioService):
     def tts_audio_process(self, audio:bytes, **kwargs):
         kwargs['recorder'].tts_audio += audio
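Both record services accumulate raw bytes on a recorder object passed through kwargs. The recorder class itself is not shown in this diff; a plausible minimal shape, purely as an assumption inferred from the attributes used above, would be:

from dataclasses import dataclass

@dataclass
class Recorder:
    # hypothetical sketch: field names follow the attributes used in the diff
    user_id: int = 0
    user_audio: bytes = b""
    tts_audio: bytes = b""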
@@ -214,7 +216,7 @@ class LLMMsgServiceChain:
         self.services.append(service)
     def llm_msg_process(self, llm_msg):
         kwargs = {}
-        llm_chunks = self.spliter.split(llm_msg)
+        llm_chunks = self.spliter.split(llm_msg)  # first, segment the LLM response into sentences
         for service in self.services:
             llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
         return llm_chunks
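The spliter's interface is visible here only as `.split(llm_msg) -> list`. A toy segmenter with that interface, for illustration only (the real implementation is not in this diff):

import re

class RegexSpliter:
    # hypothetical splitter: breaks text after Chinese or Western
    # sentence-ending punctuation, dropping empty fragments
    def split(self, msg: str) -> list:
        return [s for s in re.split(r'(?<=[。!?.!?])', msg) if s.strip()]

# e.g. RegexSpliter().split("Hi there. How are you?")
# -> ['Hi there.', ' How are you?']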
@@ -247,7 +249,7 @@ class Agent():
         self.llm = llmFactory.create_llm(llm_type)
-        self.llm_chunk_service_chain = LLMMsgServiceChain()
+        self.llm_msg_service_chain = LLMMsgServiceChain()
         self.tts = ttsFactory.create_tts(tts_type)
@ -275,7 +277,7 @@ class Agent():
# 对大模型的返回进行处理 # 对大模型的返回进行处理
def llm_msg_process(self, llm_chunk, db): def llm_msg_process(self, llm_chunk, db):
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk) return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
# 进行TTS合成 # 进行TTS合成
def synthetize(self, assistant, text, db): def synthetize(self, assistant, text, db):

View File

@@ -172,7 +172,7 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
     await ws.accept()
     logger.debug("WebSocket connection established")
     try:
-        agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
+        agent = None
         assistant = None
         asr_results = []
         llm_text = ""
@@ -182,6 +182,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
             if assistant is None:
                 assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                 agent.init_recorder(assistant.user_id)
+            if not agent:
+                agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
             chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
             asr_results = await agent.stream_recognize(chunk, db)
             kid_text = asr_results[0]['text']  # asr_results[0] is assumed to be the child's (the primary user's) result
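One caveat with the ordering in this hunk: `agent.init_recorder(...)` sits inside the `assistant is None` branch, which fires on the very first chunk, before the `if not agent` guard has constructed the agent, so that first call would hit `None`. A sketch with the guard hoisted above first use (an assumption about the intent, not the committed order):

            # hypothetical reordering: construct the agent before anything touches it
            if not agent:
                agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
            if assistant is None:
                assistant = db.query(Assistant).filter(
                    Assistant.id == chunk['meta_info']['session_id']).first()
                agent.init_recorder(assistant.user_id)  # safe now: agent exists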