feat: add memory to the single-round streaming chat interface
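
Give the single-round streaming chat handler (sct_llm_handler) a long-term memory: parseChunkDelta now also reports the total token count carried by the final stream chunk, and once a reply pushes the session past 60% of the summarizer's token budget, the accumulated messages are condensed by the new Summarizer service into a user_info record (a character sketch plus a list of dated events). That record is prepended to the next outgoing request, and the session's message history is reset to the system prompt.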

killua4396 2024-05-20 21:38:44 +08:00
parent 0593bf0482
commit 872cde91e8
3 changed files with 92 additions and 11 deletions


@@ -1,5 +1,6 @@
 from ..schemas.chat_schema import *
 from ..dependencies.logger import get_logger
+from ..dependencies.summarizer import get_summarizer
 from .controller_enum import *
 from ..models import UserCharacter, Session, Character, User
 from utils.audio_utils import VAD
@@ -15,6 +16,9 @@ import aiohttp
 # Get the logger via dependency injection
 logger = get_logger()
+# Get the context-summarization service via dependency injection
+summarizer = get_summarizer()
 # -------------------- Initialize local ASR -----------------------
 from utils.stt.modified_funasr import ModifiedRecognizer
@@ -55,9 +59,9 @@ def parseChunkDelta(chunk):
             parsed_data = json.loads(decoded_data[6:])
             if 'delta' in parsed_data['choices'][0]:
                 delta_content = parsed_data['choices'][0]['delta']
-                return delta_content['content']
+                return -1, delta_content['content']
             else:
-                return "end"
+                return parsed_data['usage']['total_tokens'], ""
     except KeyError:
         logger.error(f"error chunk: {chunk}")
-        return ""
+        return -1, ""  # keep the (token_count, content) contract in the error path
@@ -142,11 +146,15 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
         "temperature": 1,
         "top_p": 0.9,
     }
+    user_info = {
+        "character": "",
+        "events": []
+    }
     # Serialize the tts and llm info to JSON strings
     tts_info_str = json.dumps(tts_info, ensure_ascii=False)
     llm_info_str = json.dumps(llm_info, ensure_ascii=False)
-    user_info_str = db_user.persona
+    user_info_str = json.dumps(user_info, ensure_ascii=False)
     token = 0
     content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
@@ -262,13 +270,16 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         is_first = True
         is_end = False
         session_content = get_session_content(session_id,redis,db)
+        user_info = json.loads(session_content["user_info"])
         messages = json.loads(session_content["messages"])
         current_message = await llm_input_q.get()
         messages.append({'role': 'user', "content": current_message})
+        messages_send = [dict(m) for m in messages]  # copy the messages so the user info is prepended only to the outgoing request
+        messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
         payload = json.dumps({
             "model": llm_info["model"],
             "stream": True,
-            "messages": messages,
+            "messages": messages_send,
             "max_tokens": 10000,
             "temperature": llm_info["temperature"],
             "top_p": llm_info["top_p"]
@@ -283,8 +294,8 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         async with aiohttp.ClientSession() as client:
             async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                         sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)  # sentence splitting
@@ -307,6 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                             redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
                             is_first = True
                             llm_response = ""
+        if is_end and token_count > summarizer.max_token * 0.6:  # summarize once the reply reports more than 60% of the max token count
+            system_prompt = messages[0]['content']
+            summary = await summarizer.summarize(messages)
+            events = user_info['events']
+            events.append(summary['event'])
+            session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}], ensure_ascii=False)  # reset the history to the system prompt
+            session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
+            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
+            logger.debug(f"session after summarization: {session_content}")
     except Exception as e:
         logger.error(f"error while processing the llm response: {str(e)}")
     chat_finished_event.set()
@@ -443,8 +463,8 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         async with aiohttp.ClientSession() as client:
             async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                         sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
@@ -613,8 +633,8 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
         async with aiohttp.ClientSession() as client:
             async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                         sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
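
The new parseChunkDelta contract is worth spelling out. Below is a minimal, self-contained sketch of the updated function with hypothetical SSE chunks (the payload shapes are illustrative, not captured Minimax traffic):

import json

def parseChunkDelta(chunk: bytes):
    # Contract from this commit: (-1, text) while the stream is running,
    # (total_tokens, "") on the final usage-bearing chunk.
    decoded_data = chunk.decode("utf-8")
    parsed_data = json.loads(decoded_data[6:])  # strip the leading "data: "
    if 'delta' in parsed_data['choices'][0]:
        return -1, parsed_data['choices'][0]['delta']['content']
    return parsed_data['usage']['total_tokens'], ""

delta = b'data: {"choices": [{"delta": {"content": "hi"}}]}'
final = b'data: {"choices": [{"finish_reason": "stop"}], "usage": {"total_tokens": 6200}}'
print(parseChunkDelta(delta))  # (-1, 'hi')
print(parseChunkDelta(final))  # (6200, '') -> is_end becomes True in the handlers

Using -1 as the in-flight sentinel lets each handler derive is_end with a single comparison (token_count > 0) instead of matching a magic "end" string against regular content.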


@@ -0,0 +1,61 @@
+import aiohttp
+import json
+from datetime import datetime
+from config import get_config
+
+# Get the Config via dependency injection
+Config = get_config()
+
+class Summarizer:
+    def __init__(self):
+        self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json里面有两个字段一个是event一个是character将你总结出的事件写入event将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
+        self.model = "abab5.5-chat"
+        self.max_token = 10000
+        self.temperature = 0.9
+        self.top_p = 1
+
+    async def summarize(self, messages):
+        # Flatten the conversation into a plain-text transcript
+        context = ""
+        for message in messages:
+            if message['role'] == 'user':
+                context += "用户:" + message['content'] + '\n'
+            elif message['role'] == 'assistant':
+                context += "玩具:" + message['content'] + '\n'
+        payload = json.dumps({
+            "model": self.model,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": self.system_prompt
+                },
+                {
+                    "role": "user",
+                    "content": context
+                }
+            ],
+            "max_tokens": self.max_token,
+            "temperature": self.temperature,
+            "top_p": self.top_p
+        })
+        headers = {
+            'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
+            'Content-Type': 'application/json'
+        }
+        async with aiohttp.ClientSession() as session:
+            async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
+                # The model answers with a JSON object inside the message content,
+                # hence the second json.loads
+                content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
+        try:
+            summary = {
+                'event': datetime.now().strftime("%Y-%m-%d") + " " + content['event'],
+                'character': content['character']
+            }
+        except TypeError:
+            # The model occasionally returns event as a list; fall back to its first element
+            summary = {
+                'event': datetime.now().strftime("%Y-%m-%d") + " " + content['event'][0],
+                'character': ""
+            }
+        return summary
+
+def get_summarizer():
+    summarizer = Summarizer()
+    return summarizer
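
One non-obvious detail: the model is instructed to emit the summary as a JSON object serialized inside the message content, which is why summarize() parses twice. A sketch with a made-up response body:

import json

# Hypothetical HTTP body: the outer JSON is the chat-completion envelope,
# the inner JSON is the summary object requested by the system prompt.
raw_body = '{"choices": [{"message": {"content": "{\\"event\\": \\"在幼儿园看葫芦娃,老鹰抓小鸡\\", \\"character\\": \\"活泼可爱\\"}"}}]}'
content = json.loads(json.loads(raw_body)['choices'][0]['message']['content'])
print(content['event'])      # 在幼儿园看葫芦娃,老鹰抓小鸡
print(content['character'])  # 活泼可爱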


@@ -101,7 +101,7 @@ class ChatServiceTest:
         payload = json.dumps({
             "user_id": self.user_id,
             "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]",
-            "user_info": "{}",
+            "user_info": "{\"character\": \"\", \"events\": [] }",
             "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
             "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
             "token": 0})