forked from killua/TakwayPlatform
feat: 为持续流式聊天以及语音电话接口增加记忆功能
This commit is contained in:
parent
872cde91e8
commit
031fa32ea0
|
@ -54,7 +54,7 @@ def get_session_content(session_id,redis,db):
|
||||||
def parseChunkDelta(chunk):
|
def parseChunkDelta(chunk):
|
||||||
try:
|
try:
|
||||||
if chunk == b"":
|
if chunk == b"":
|
||||||
return ""
|
return 1,""
|
||||||
decoded_data = chunk.decode('utf-8')
|
decoded_data = chunk.decode('utf-8')
|
||||||
parsed_data = json.loads(decoded_data[6:])
|
parsed_data = json.loads(decoded_data[6:])
|
||||||
if 'delta' in parsed_data['choices'][0]:
|
if 'delta' in parsed_data['choices'][0]:
|
||||||
|
@ -63,11 +63,11 @@ def parseChunkDelta(chunk):
|
||||||
else:
|
else:
|
||||||
return parsed_data['usage']['total_tokens'] , ""
|
return parsed_data['usage']['total_tokens'] , ""
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"error chunk: {chunk}")
|
logger.error(f"error chunk: {decoded_data}")
|
||||||
return ""
|
return 1,""
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"error chunk: {chunk}")
|
logger.error(f"error chunk: {decoded_data}")
|
||||||
return ""
|
return 1,""
|
||||||
|
|
||||||
#断句函数
|
#断句函数
|
||||||
def split_string_with_punctuation(current_sentence,text,is_first,is_end):
|
def split_string_with_punctuation(current_sentence,text,is_first,is_end):
|
||||||
|
@ -318,15 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
is_first = True
|
is_first = True
|
||||||
llm_response = ""
|
llm_response = ""
|
||||||
if is_end and token_count > summarizer.max_token * 0.6: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||||
system_prompt = messages[0]['content']
|
system_prompt = messages[0]['content']
|
||||||
summary = await summarizer.summarize(messages)
|
summary = await summarizer.summarize(messages)
|
||||||
events = user_info['events']
|
events = user_info['events']
|
||||||
events.append(summary['event'])
|
events.append(summary['event'])
|
||||||
session_content['messages'] = [{'role':'system','content':system_prompt}]
|
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||||
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': json.dumps(events,ensure_ascii=False)}, ensure_ascii=False)
|
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||||
logger.debug(f"文本摘要后的session: {session_content}")
|
logger.debug(f"总结后session_content: {session_content}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||||
chat_finished_event.set()
|
chat_finished_event.set()
|
||||||
|
@ -418,11 +418,10 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
|
||||||
if not isinstance(emotion_dict, str):
|
if not isinstance(emotion_dict, str):
|
||||||
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
|
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
|
||||||
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
|
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
|
||||||
print(current_message)
|
|
||||||
await llm_input_q.put(current_message)
|
await llm_input_q.put(current_message)
|
||||||
|
logger.debug(f"接收到用户消息: {current_message}")
|
||||||
current_message = ""
|
current_message = ""
|
||||||
audio = ""
|
audio = ""
|
||||||
logger.debug(f"接收到用户消息: {current_message}")
|
|
||||||
else:
|
else:
|
||||||
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
|
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
|
||||||
audio += aduio_frame['audio']
|
audio += aduio_frame['audio']
|
||||||
|
@ -446,6 +445,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
try:
|
try:
|
||||||
session_content = get_session_content(session_id,redis,db)
|
session_content = get_session_content(session_id,redis,db)
|
||||||
messages = json.loads(session_content["messages"])
|
messages = json.loads(session_content["messages"])
|
||||||
|
user_info = json.loads(session_content["user_info"])
|
||||||
current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
|
current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
|
||||||
messages.append({'role': 'user', "content": current_message})
|
messages.append({'role': 'user', "content": current_message})
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
|
@ -483,17 +483,25 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
logger.debug(f"llm返回结果: {llm_response}")
|
logger.debug(f"llm返回结果: {llm_response}")
|
||||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
is_end = False
|
is_end = False
|
||||||
|
|
||||||
messages.append({'role': 'assistant', "content": llm_response})
|
messages.append({'role': 'assistant', "content": llm_response})
|
||||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
is_first = True
|
is_first = True
|
||||||
llm_response = ""
|
llm_response = ""
|
||||||
|
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||||
|
system_prompt = messages[0]['content']
|
||||||
|
summary = await summarizer.summarize(messages)
|
||||||
|
events = user_info['events']
|
||||||
|
events.append(summary['event'])
|
||||||
|
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||||
|
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||||
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||||
|
logger.debug(f"总结后session_content: {session_content}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
# logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||||
break
|
# break
|
||||||
chat_finished_event.set()
|
chat_finished_event.set()
|
||||||
|
|
||||||
async def streaming_chat_lasting_handler(ws,db,redis):
|
async def streaming_chat_lasting_handler(ws,db,redis):
|
||||||
|
@ -615,6 +623,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
try:
|
try:
|
||||||
session_content = get_session_content(session_id,redis,db)
|
session_content = get_session_content(session_id,redis,db)
|
||||||
messages = json.loads(session_content["messages"])
|
messages = json.loads(session_content["messages"])
|
||||||
|
user_info = json.loads(session_content["user_info"])
|
||||||
current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
|
current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
|
||||||
messages.append({'role': 'user', "content": current_message})
|
messages.append({'role': 'user', "content": current_message})
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
|
@ -648,12 +657,20 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
logger.debug(f"llm返回结果: {llm_response}")
|
logger.debug(f"llm返回结果: {llm_response}")
|
||||||
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||||
is_end = False
|
is_end = False
|
||||||
|
|
||||||
messages.append({'role': 'assistant', "content": llm_response})
|
messages.append({'role': 'assistant', "content": llm_response})
|
||||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
is_first = True
|
is_first = True
|
||||||
llm_response = ""
|
llm_response = ""
|
||||||
|
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||||
|
system_prompt = messages[0]['content']
|
||||||
|
summary = await summarizer.summarize(messages)
|
||||||
|
events = user_info['events']
|
||||||
|
events.append(summary['event'])
|
||||||
|
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
|
||||||
|
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
|
||||||
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||||
|
logger.debug(f"总结后session_content: {session_content}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Reference in New Issue