forked from killua/TakwayDisplayPlatform
update:加了一点注释
This commit is contained in:
parent
dba43836b6
commit
d885684533
|
@ -31,7 +31,7 @@ class PromptService(ABC):
|
||||||
|
|
||||||
class LLMMsgService(ABC):
|
class LLMMsgService(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def llm_msg_process(self, llm_chunk:list, **kwargs) -> list:
|
def llm_msg_process(self, llm_chunks:list, **kwargs) -> list:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class TTSAudioService(ABC):
|
class TTSAudioService(ABC):
|
||||||
|
|
|
@ -78,7 +78,7 @@ class MINIMAX_LLM(LLM):
|
||||||
llm_info = json.loads(assistant.llm_info)
|
llm_info = json.loads(assistant.llm_info)
|
||||||
messages = json.loads(assistant.messages)
|
messages = json.loads(assistant.messages)
|
||||||
messages.append({'role':'user','content':prompt})
|
messages.append({'role':'user','content':prompt})
|
||||||
payload = json.dumps({
|
payload = json.dumps({ #整理payload
|
||||||
"model": llm_info['model'],
|
"model": llm_info['model'],
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
@ -91,15 +91,15 @@ class MINIMAX_LLM(LLM):
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #调用大模型
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
try:
|
try:
|
||||||
chunk_msg = self.__parseChunk(chunk)
|
chunk_msg = self.__parseChunk(chunk) #解析llm返回
|
||||||
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
|
||||||
yield msg_frame
|
yield msg_frame
|
||||||
except LLMResponseEnd:
|
except LLMResponseEnd:
|
||||||
msg_frame = {"is_end":True,"msg":""}
|
msg_frame = {"is_end":True,"msg":""}
|
||||||
if self.token > llm_info['max_tokens'] * 0.8:
|
if self.token > llm_info['max_tokens'] * 0.8: #如果token超过80%,则重置session
|
||||||
msg_frame['msg'] = 'max_token reached'
|
msg_frame['msg'] = 'max_token reached'
|
||||||
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
|
||||||
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
|
||||||
|
@ -169,12 +169,14 @@ class BasicPromptService(PromptService):
|
||||||
else:
|
else:
|
||||||
raise NoAsrResultsError()
|
raise NoAsrResultsError()
|
||||||
|
|
||||||
|
# 记录用户音频
|
||||||
class UserAudioRecordService(UserAudioService):
|
class UserAudioRecordService(UserAudioService):
|
||||||
def user_audio_process(self, audio:str, **kwargs):
|
def user_audio_process(self, audio:str, **kwargs):
|
||||||
audio_decoded = base64.b64decode(audio)
|
audio_decoded = base64.b64decode(audio)
|
||||||
kwargs['recorder'].user_audio += audio_decoded
|
kwargs['recorder'].user_audio += audio_decoded
|
||||||
return audio,kwargs
|
return audio,kwargs
|
||||||
|
|
||||||
|
# 记录TTS音频
|
||||||
class TTSAudioRecordService(TTSAudioService):
|
class TTSAudioRecordService(TTSAudioService):
|
||||||
def tts_audio_process(self, audio:bytes, **kwargs):
|
def tts_audio_process(self, audio:bytes, **kwargs):
|
||||||
kwargs['recorder'].tts_audio += audio
|
kwargs['recorder'].tts_audio += audio
|
||||||
|
@ -214,7 +216,7 @@ class LLMMsgServiceChain:
|
||||||
self.services.append(service)
|
self.services.append(service)
|
||||||
def llm_msg_process(self, llm_msg):
|
def llm_msg_process(self, llm_msg):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
llm_chunks = self.spliter.split(llm_msg)
|
llm_chunks = self.spliter.split(llm_msg) #首先对llm返回进行断句
|
||||||
for service in self.services:
|
for service in self.services:
|
||||||
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
|
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
|
||||||
return llm_chunks
|
return llm_chunks
|
||||||
|
@ -247,7 +249,7 @@ class Agent():
|
||||||
|
|
||||||
self.llm = llmFactory.create_llm(llm_type)
|
self.llm = llmFactory.create_llm(llm_type)
|
||||||
|
|
||||||
self.llm_chunk_service_chain = LLMMsgServiceChain()
|
self.llm_msg_service_chain = LLMMsgServiceChain()
|
||||||
|
|
||||||
self.tts = ttsFactory.create_tts(tts_type)
|
self.tts = ttsFactory.create_tts(tts_type)
|
||||||
|
|
||||||
|
@ -275,7 +277,7 @@ class Agent():
|
||||||
|
|
||||||
# 对大模型的返回进行处理
|
# 对大模型的返回进行处理
|
||||||
def llm_msg_process(self, llm_chunk, db):
|
def llm_msg_process(self, llm_chunk, db):
|
||||||
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)
|
return self.llm_msg_service_chain.llm_msg_process(llm_chunk)
|
||||||
|
|
||||||
# 进行TTS合成
|
# 进行TTS合成
|
||||||
def synthetize(self, assistant, text, db):
|
def synthetize(self, assistant, text, db):
|
||||||
|
|
4
main.py
4
main.py
|
@ -172,7 +172,7 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
||||||
await ws.accept()
|
await ws.accept()
|
||||||
logger.debug("WebSocket连接成功")
|
logger.debug("WebSocket连接成功")
|
||||||
try:
|
try:
|
||||||
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
agent = None
|
||||||
assistant = None
|
assistant = None
|
||||||
asr_results = []
|
asr_results = []
|
||||||
llm_text = ""
|
llm_text = ""
|
||||||
|
@ -182,6 +182,8 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
|
||||||
if assistant is None:
|
if assistant is None:
|
||||||
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
|
||||||
agent.init_recorder(assistant.user_id)
|
agent.init_recorder(assistant.user_id)
|
||||||
|
if not agent:
|
||||||
|
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
|
||||||
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
|
||||||
asr_results = await agent.stream_recognize(chunk, db)
|
asr_results = await agent.stream_recognize(chunk, db)
|
||||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||||
|
|
Loading…
Reference in New Issue