forked from killua/TakwayPlatform
feat: 为持续流式聊天以及语音电话接口增加记忆功能
This commit is contained in:
parent
872cde91e8
commit
031fa32ea0
|
@ -54,7 +54,7 @@ def get_session_content(session_id,redis,db):
|
|||
def parseChunkDelta(chunk):
|
||||
try:
|
||||
if chunk == b"":
|
||||
return ""
|
||||
return 1,""
|
||||
decoded_data = chunk.decode('utf-8')
|
||||
parsed_data = json.loads(decoded_data[6:])
|
||||
if 'delta' in parsed_data['choices'][0]:
|
||||
|
@ -63,11 +63,11 @@ def parseChunkDelta(chunk):
|
|||
else:
|
||||
return parsed_data['usage']['total_tokens'] , ""
|
||||
except KeyError:
|
||||
logger.error(f"error chunk: {chunk}")
|
||||
return ""
|
||||
logger.error(f"error chunk: {decoded_data}")
|
||||
return 1,""
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"error chunk: {chunk}")
|
||||
return ""
|
||||
logger.error(f"error chunk: {decoded_data}")
|
||||
return 1,""
|
||||
|
||||
#断句函数
|
||||
def split_string_with_punctuation(current_sentence,text,is_first,is_end):
|
||||
|
@ -318,15 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
is_first = True
|
||||
llm_response = ""
|
||||
if is_end and token_count > summarizer.max_token * 0.6: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||
system_prompt = messages[0]['content']
|
||||
summary = await summarizer.summarize(messages)
|
||||
events = user_info['events']
|
||||
events.append(summary['event'])
|
||||
session_content['messages'] = [{'role':'system','content':system_prompt}]
|
||||
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': json.dumps(events,ensure_ascii=False)}, ensure_ascii=False)
|
||||
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||
logger.debug(f"文本摘要后的session: {session_content}")
|
||||
logger.debug(f"总结后session_content: {session_content}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||
chat_finished_event.set()
|
||||
|
@ -418,11 +418,10 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
|
|||
if not isinstance(emotion_dict, str):
|
||||
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
|
||||
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
|
||||
print(current_message)
|
||||
await llm_input_q.put(current_message)
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
current_message = ""
|
||||
audio = ""
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
else:
|
||||
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
|
||||
audio += aduio_frame['audio']
|
||||
|
@ -446,6 +445,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
try:
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
user_info = json.loads(session_content["user_info"])
|
||||
current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
|
||||
messages.append({'role': 'user', "content": current_message})
|
||||
payload = json.dumps({
|
||||
|
@ -483,17 +483,25 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
|||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
is_end = False
|
||||
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
is_first = True
|
||||
llm_response = ""
|
||||
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||
system_prompt = messages[0]['content']
|
||||
summary = await summarizer.summarize(messages)
|
||||
events = user_info['events']
|
||||
events.append(summary['event'])
|
||||
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||
logger.debug(f"总结后session_content: {session_content}")
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||
break
|
||||
# except Exception as e:
|
||||
# logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||
# break
|
||||
chat_finished_event.set()
|
||||
|
||||
async def streaming_chat_lasting_handler(ws,db,redis):
|
||||
|
@ -615,6 +623,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
try:
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
user_info = json.loads(session_content["user_info"])
|
||||
current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
|
||||
messages.append({'role': 'user', "content": current_message})
|
||||
payload = json.dumps({
|
||||
|
@ -648,12 +657,20 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
|||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
is_end = False
|
||||
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
is_first = True
|
||||
llm_response = ""
|
||||
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||
system_prompt = messages[0]['content']
|
||||
summary = await summarizer.summarize(messages)
|
||||
events = user_info['events']
|
||||
events.append(summary['event'])
|
||||
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||
logger.debug(f"总结后session_content: {session_content}")
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
|
|
Loading…
Reference in New Issue