forked from killua/TakwayDisplayPlatform
update:加了一点注释
This commit is contained in:
parent
dba43836b6
commit
d885684533
|
@ -31,7 +31,7 @@ class PromptService(ABC):
|
|||
|
||||
class LLMMsgService(ABC):
|
||||
@abstractmethod
|
||||
def llm_msg_process(self, llm_chunk:list, **kwargs) -> list:
|
||||
def llm_msg_process(self, llm_chunks:list, **kwargs) -> list:
|
||||
pass
|
||||
|
||||
class TTSAudioService(ABC):
|
||||
|
|
|
@ -78,7 +78,7 @@ class MINIMAX_LLM(LLM):
|
|||
llm_info = json.loads(assistant.llm_info)
|
||||
messages = json.loads(assistant.messages)
|
||||
messages.append({'role':'user','content':prompt})
|
||||
payload = json.dumps({
|
||||
payload = json.dumps({ #整理payload
|
||||
"model": llm_info['model'],
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
|
@ -91,15 +91,15 @@ class MINIMAX_LLM(LLM):
|
|||
'Content-Type': 'application/json'
|
||||
}
|
||||
async with aiohttp.ClientSession() as client:
|
||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
|
||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #调用大模型
|
||||
async for chunk in response.content.iter_any():
|
||||
try:
|
||||
chunk_msg = self.__parseChunk(chunk)
|
||||
chunk_msg = self.__parseChunk(chunk) #解析llm返回
|
||||
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
||||
yield msg_frame
|
||||
except LLMResponseEnd:
|
||||
msg_frame = {"is_end":True,"msg":""}
|
||||
if self.token > llm_info['max_tokens'] * 0.8:
|
||||
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||
msg_frame['msg'] = 'max_token reached'
|
||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||
|
@ -169,12 +169,14 @@ class BasicPromptService(PromptService):
|
|||
else:
|
||||
raise NoAsrResultsError()
|
||||
|
||||
# 记录用户音频
|
||||
class UserAudioRecordService(UserAudioService):
|
||||
def user_audio_process(self, audio:str, **kwargs):
|
||||
audio_decoded = base64.b64decode(audio)
|
||||
kwargs['recorder'].user_audio += audio_decoded
|
||||
return audio,kwargs
|
||||
|
||||
|
||||
# 记录TTS音频
|
||||
class TTSAudioRecordService(TTSAudioService):
|
||||
def tts_audio_process(self, audio:bytes, **kwargs):
|
||||
kwargs['recorder'].tts_audio += audio
|
||||
|
@ -214,7 +216,7 @@ class LLMMsgServiceChain:
|
|||
self.services.append(service)
|
||||
def llm_msg_process(self, llm_msg):
|
||||
kwargs = {}
|
||||
llm_chunks = self.spliter.split(llm_msg)
|
||||
llm_chunks = self.spliter.split(llm_msg) #首先对llm返回进行断句
|
||||
for service in self.services:
|
||||
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
|
||||
return llm_chunks
|
||||
|
@ -247,7 +249,7 @@ class Agent():
|
|||
|
||||
self.llm = llmFactory.create_llm(llm_type)
|
||||
|
||||
self.llm_chunk_service_chain = LLMMsgServiceChain()
|
||||
self.llm_msg_service_chain = LLMMsgServiceChain()
|
||||
|
||||
self.tts = ttsFactory.create_tts(tts_type)
|
||||
|
||||
|
@ -275,7 +277,7 @@ class Agent():
|
|||
|
||||
# 对大模型的返回进行处理
|
||||
def llm_msg_process(self, llm_chunk, db):
|
||||
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)
|
||||
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
|
||||
|
||||
# 进行TTS合成
|
||||
def synthetize(self, assistant, text, db):
|
||||
|
|
4
main.py
4
main.py
|
@ -172,7 +172,7 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
|||
await ws.accept()
|
||||
logger.debug("WebSocket连接成功")
|
||||
try:
|
||||
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
||||
agent = None
|
||||
assistant = None
|
||||
asr_results = []
|
||||
llm_text = ""
|
||||
|
@ -182,6 +182,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
|||
if assistant is None:
|
||||
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
||||
agent.init_recorder(assistant.user_id)
|
||||
if not agent:
|
||||
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
||||
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
||||
asr_results = await agent.stream_recognize(chunk, db)
|
||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||
|
|
Loading…
Reference in New Issue