
feat: add a memory feature to the continuous streaming chat and voice call interfaces

killua4396 2024-05-21 11:36:34 +08:00
parent 872cde91e8
commit 031fa32ea0
1 changed file with 33 additions and 16 deletions

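This commit gives the continuous streaming chat and voice call handlers long-term memory: once a conversation grows past a token threshold, the history is summarized into the session's user_info and the message list is reset. For orientation, a minimal sketch of the session layout the handlers below assume — only the messages and user_info fields appear in this diff, so any other fields in the record are an assumption:

import json

# session_content as read from Redis: both fields are JSON-encoded strings.
session_content = {
    "messages": json.dumps([{"role": "system", "content": "system prompt"}], ensure_ascii=False),
    "user_info": json.dumps({"character": "", "events": []}, ensure_ascii=False),
}

messages = json.loads(session_content["messages"])    # conversation history
user_info = json.loads(session_content["user_info"])  # long-term memory: character sketch + event list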

@@ -54,7 +54,7 @@ def get_session_content(session_id,redis,db):
 def parseChunkDelta(chunk):
     try:
         if chunk == b"":
-            return ""
+            return 1,""
         decoded_data = chunk.decode('utf-8')
         parsed_data = json.loads(decoded_data[6:])
         if 'delta' in parsed_data['choices'][0]:
@@ -63,11 +63,11 @@ def parseChunkDelta(chunk):
         else:
             return parsed_data['usage']['total_tokens'] , ""
     except KeyError:
-        logger.error(f"error chunk: {chunk}")
-        return ""
+        logger.error(f"error chunk: {decoded_data}")
+        return 1,""
     except json.JSONDecodeError:
-        logger.error(f"error chunk: {chunk}")
-        return ""
+        logger.error(f"error chunk: {decoded_data}")
+        return 1,""
 
 # sentence-splitting function
 def split_string_with_punctuation(current_sentence,text,is_first,is_end):
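parseChunkDelta now returns a (token_count, text) pair on every path instead of a bare string — 1 as a placeholder count for empty or malformed chunks, and the real usage total once the stream's final chunk arrives — which is what lets the handlers below watch token growth. A hypothetical caller under that contract (the content-delta branch is truncated in the hunk above, so its return shape and the response_chunks name are assumptions):

# Assumed consumption pattern for the new (token_count, text) contract.
token_count = 0
llm_response = ""
for chunk in response_chunks:  # response_chunks: raw SSE byte chunks from the llm (assumed name)
    count, text = parseChunkDelta(chunk)
    token_count = max(token_count, count)  # the real usage total only arrives on the final chunk
    llm_response += text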
@@ -318,15 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
             redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) # update session
             is_first = True
             llm_response = ""
-            if is_end and token_count > summarizer.max_token * 0.6: # summarize when the llm token count exceeds 60% of the max tokens
+            if token_count > summarizer.max_token * 0.7: # summarize when the llm token count exceeds 70% of the max tokens
                 system_prompt = messages[0]['content']
                 summary = await summarizer.summarize(messages)
                 events = user_info['events']
                 events.append(summary['event'])
-                session_content['messages'] = [{'role':'system','content':system_prompt}]
-                session_content['user_info'] = json.dumps({'character': summary['character'], 'events': json.dumps(events,ensure_ascii=False)}, ensure_ascii=False)
+                session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
+                session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
                 redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
-                logger.debug(f"文本摘要后的session: {session_content}")
+                logger.debug(f"总结后session_content: {session_content}")
     except Exception as e:
        logger.error(f"处理llm返回结果发生错误: {str(e)}")
    chat_finished_event.set()
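This hunk is the memory mechanism itself, repeated below for the scl and voice-call handlers: past 70% of the summarizer's token budget, the history is condensed into a character description plus one event, the event list grows, and messages collapses back to the system prompt alone. A condensed sketch of that flow, lifted from the diff (summarizer, redis, and get_session_content are the module's own objects; the summary dict keys come from the lines above):

import json

async def maybe_summarize(session_id, token_count, redis, db):
    # Compress chat history into long-term memory once it nears the token budget.
    session_content = get_session_content(session_id, redis, db)
    messages = json.loads(session_content["messages"])
    user_info = json.loads(session_content["user_info"])
    if token_count > summarizer.max_token * 0.7:
        system_prompt = messages[0]['content']
        summary = await summarizer.summarize(messages)  # -> {'character': ..., 'event': ...}
        events = user_info['events']
        events.append(summary['event'])
        # keep only the system prompt; everything else now lives in user_info
        session_content['messages'] = json.dumps([{'role': 'system', 'content': system_prompt}], ensure_ascii=False)
        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
        redis.set(session_id, json.dumps(session_content, ensure_ascii=False))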
@@ -418,11 +418,10 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
                 if not isinstance(emotion_dict, str):
                     max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
                     current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
-                print(current_message)
                 await llm_input_q.put(current_message)
+                logger.debug(f"接收到用户消息: {current_message}")
                 current_message = ""
                 audio = ""
-                logger.debug(f"接收到用户消息: {current_message}")
             else:
                 asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
                 audio += aduio_frame['audio']
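The asr hunk is a logging fix rather than a memory change: the debug line used to run after current_message was cleared, so it always recorded an empty message, and a stray print shadowed it. Moving the log ahead of the reset captures the real content — a tiny self-contained illustration of the corrected order:

import logging

logger = logging.getLogger(__name__)

current_message = "你好"
logger.debug(f"接收到用户消息: {current_message}")  # new order: log while the message is still populated
current_message = ""  # only then reset state for the next utterance
audio = ""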
@@ -446,6 +445,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
+            user_info = json.loads(session_content["user_info"])
            current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
@@ -483,17 +483,25 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False
                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False) # update conversation
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) # update session
                    is_first = True
                    llm_response = ""
+                    if token_count > summarizer.max_token * 0.7: # summarize when the llm token count exceeds 70% of the max tokens
+                        system_prompt = messages[0]['content']
+                        summary = await summarizer.summarize(messages)
+                        events = user_info['events']
+                        events.append(summary['event'])
+                        session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
+                        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
+                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
+                        logger.debug(f"总结后session_content: {session_content}")
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"处理llm返回结果发生错误: {str(e)}")
            break
        # except Exception as e:
        #     logger.error(f"处理llm返回结果发生错误: {str(e)}")
        #     break
    chat_finished_event.set()
 async def streaming_chat_lasting_handler(ws,db,redis):
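One detail worth flagging from the sct hunk earlier: the old code serialized events twice (a json.dumps nested inside the outer json.dumps) and assigned messages as a raw list, so the next read would find a string where user_info['events'].append() expects a list. The handlers here use the corrected single-level encoding; a tiny illustration of the failure it avoids:

import json

events = ["user mentioned an upcoming trip"]

# old, double-encoded form: events comes back as a str, so .append() raises AttributeError
broken = json.loads(json.dumps({"events": json.dumps(events, ensure_ascii=False)}, ensure_ascii=False))
print(type(broken["events"]))  # <class 'str'>

# fixed form: one json.dumps at the outer level only
fixed = json.loads(json.dumps({"events": events}, ensure_ascii=False))
print(type(fixed["events"]))   # <class 'list'>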
@@ -615,6 +623,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
+            user_info = json.loads(session_content["user_info"])
            current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
@@ -648,12 +657,20 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False
                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False) # update conversation
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) # update session
                    is_first = True
                    llm_response = ""
+                    if token_count > summarizer.max_token * 0.7: # summarize when the llm token count exceeds 70% of the max tokens
+                        system_prompt = messages[0]['content']
+                        summary = await summarizer.summarize(messages)
+                        events = user_info['events']
+                        events.append(summary['event'])
+                        session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
+                        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
+                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
+                        logger.debug(f"总结后session_content: {session_content}")
        except asyncio.TimeoutError:
            continue
        except Exception as e: