1
0
Fork 0

update:加了一点注释

This commit is contained in:
killua 2024-06-10 20:49:50 +08:00
parent dba43836b6
commit d885684533
3 changed files with 14 additions and 10 deletions

View File

@ -31,7 +31,7 @@ class PromptService(ABC):
class LLMMsgService(ABC):
@abstractmethod
def llm_msg_process(self, llm_chunk:list, **kwargs) -> list:
def llm_msg_process(self, llm_chunks:list, **kwargs) -> list:
pass
class TTSAudioService(ABC):

View File

@ -78,7 +78,7 @@ class MINIMAX_LLM(LLM):
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({'role':'user','content':prompt})
payload = json.dumps({
payload = json.dumps({ #整理payload
"model": llm_info['model'],
"stream": True,
"messages": messages,
@ -91,15 +91,15 @@ class MINIMAX_LLM(LLM):
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #调用大模型
async for chunk in response.content.iter_any():
try:
chunk_msg = self.__parseChunk(chunk)
chunk_msg = self.__parseChunk(chunk) #解析llm返回
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"msg":""}
if self.token > llm_info['max_tokens'] * 0.8:
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%则重置session
msg_frame['msg'] = 'max_token reached'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
@ -169,12 +169,14 @@ class BasicPromptService(PromptService):
else:
raise NoAsrResultsError()
# 记录用户音频
class UserAudioRecordService(UserAudioService):
def user_audio_process(self, audio:str, **kwargs):
audio_decoded = base64.b64decode(audio)
kwargs['recorder'].user_audio += audio_decoded
return audio,kwargs
# 记录TTS音频
class TTSAudioRecordService(TTSAudioService):
def tts_audio_process(self, audio:bytes, **kwargs):
kwargs['recorder'].tts_audio += audio
@ -214,7 +216,7 @@ class LLMMsgServiceChain:
self.services.append(service)
def llm_msg_process(self, llm_msg):
kwargs = {}
llm_chunks = self.spliter.split(llm_msg)
llm_chunks = self.spliter.split(llm_msg) #首先对llm返回进行断句
for service in self.services:
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
return llm_chunks
@ -247,7 +249,7 @@ class Agent():
self.llm = llmFactory.create_llm(llm_type)
self.llm_chunk_service_chain = LLMMsgServiceChain()
self.llm_msg_service_chain = LLMMsgServiceChain()
self.tts = ttsFactory.create_tts(tts_type)
@ -275,7 +277,7 @@ class Agent():
# 对大模型的返回进行处理
def llm_msg_process(self, llm_chunk, db):
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
# 进行TTS合成
def synthetize(self, assistant, text, db):

View File

@ -172,7 +172,7 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
await ws.accept()
logger.debug("WebSocket连接成功")
try:
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
agent = None
assistant = None
asr_results = []
llm_text = ""
@ -182,6 +182,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
if assistant is None:
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
agent.init_recorder(assistant.user_id)
if not agent:
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db)
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果