forked from killua/TakwayPlatform
feat: 为单次流式聊天接口增加记忆功能
This commit is contained in:
parent
0593bf0482
commit
872cde91e8
|
@ -1,5 +1,6 @@
|
||||||
from ..schemas.chat_schema import *
|
from ..schemas.chat_schema import *
|
||||||
from ..dependencies.logger import get_logger
|
from ..dependencies.logger import get_logger
|
||||||
|
from ..dependencies.summarizer import get_summarizer
|
||||||
from .controller_enum import *
|
from .controller_enum import *
|
||||||
from ..models import UserCharacter, Session, Character, User
|
from ..models import UserCharacter, Session, Character, User
|
||||||
from utils.audio_utils import VAD
|
from utils.audio_utils import VAD
|
||||||
|
@ -15,6 +16,9 @@ import aiohttp
|
||||||
# 依赖注入获取logger
|
# 依赖注入获取logger
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
# 依赖注入获取context总结服务
|
||||||
|
summarizer = get_summarizer()
|
||||||
|
|
||||||
# --------------------初始化本地ASR-----------------------
|
# --------------------初始化本地ASR-----------------------
|
||||||
from utils.stt.modified_funasr import ModifiedRecognizer
|
from utils.stt.modified_funasr import ModifiedRecognizer
|
||||||
|
|
||||||
|
@ -55,9 +59,9 @@ def parseChunkDelta(chunk):
|
||||||
parsed_data = json.loads(decoded_data[6:])
|
parsed_data = json.loads(decoded_data[6:])
|
||||||
if 'delta' in parsed_data['choices'][0]:
|
if 'delta' in parsed_data['choices'][0]:
|
||||||
delta_content = parsed_data['choices'][0]['delta']
|
delta_content = parsed_data['choices'][0]['delta']
|
||||||
return delta_content['content']
|
return -1, delta_content['content']
|
||||||
else:
|
else:
|
||||||
return "end"
|
return parsed_data['usage']['total_tokens'] , ""
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"error chunk: {chunk}")
|
logger.error(f"error chunk: {chunk}")
|
||||||
return ""
|
return ""
|
||||||
|
@ -142,11 +146,15 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
||||||
"temperature": 1,
|
"temperature": 1,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
}
|
}
|
||||||
|
user_info = {
|
||||||
|
"character":"",
|
||||||
|
"events":[]
|
||||||
|
}
|
||||||
|
|
||||||
# 将tts和llm信息转化为json字符串
|
# 将tts和llm信息转化为json字符串
|
||||||
tts_info_str = json.dumps(tts_info, ensure_ascii=False)
|
tts_info_str = json.dumps(tts_info, ensure_ascii=False)
|
||||||
llm_info_str = json.dumps(llm_info, ensure_ascii=False)
|
llm_info_str = json.dumps(llm_info, ensure_ascii=False)
|
||||||
user_info_str = db_user.persona
|
user_info_str = json.dumps(user_info, ensure_ascii=False)
|
||||||
|
|
||||||
token = 0
|
token = 0
|
||||||
content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
|
content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
|
||||||
|
@ -262,13 +270,16 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
is_first = True
|
is_first = True
|
||||||
is_end = False
|
is_end = False
|
||||||
session_content = get_session_content(session_id,redis,db)
|
session_content = get_session_content(session_id,redis,db)
|
||||||
|
user_info = json.loads(session_content["user_info"])
|
||||||
messages = json.loads(session_content["messages"])
|
messages = json.loads(session_content["messages"])
|
||||||
current_message = await llm_input_q.get()
|
current_message = await llm_input_q.get()
|
||||||
messages.append({'role': 'user', "content": current_message})
|
messages.append({'role': 'user', "content": current_message})
|
||||||
|
messages_send = messages #创造一个message副本,在其中最后一条数据前面添加用户信息
|
||||||
|
messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
"model": llm_info["model"],
|
"model": llm_info["model"],
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"messages": messages,
|
"messages": messages_send,
|
||||||
"max_tokens": 10000,
|
"max_tokens": 10000,
|
||||||
"temperature": llm_info["temperature"],
|
"temperature": llm_info["temperature"],
|
||||||
"top_p": llm_info["top_p"]
|
"top_p": llm_info["top_p"]
|
||||||
|
@ -283,8 +294,8 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
chunk_data = parseChunkDelta(chunk)
|
token_count, chunk_data = parseChunkDelta(chunk)
|
||||||
is_end = chunk_data == "end"
|
is_end = token_count >0
|
||||||
if not is_end:
|
if not is_end:
|
||||||
llm_response += chunk_data
|
llm_response += chunk_data
|
||||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句
|
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句
|
||||||
|
@ -307,6 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||||
is_first = True
|
is_first = True
|
||||||
llm_response = ""
|
llm_response = ""
|
||||||
|
if is_end and token_count > summarizer.max_token * 0.6: #如果llm返回的token数大于60%的最大token数,则进行文本摘要
|
||||||
|
system_prompt = messages[0]['content']
|
||||||
|
summary = await summarizer.summarize(messages)
|
||||||
|
events = user_info['events']
|
||||||
|
events.append(summary['event'])
|
||||||
|
session_content['messages'] = [{'role':'system','content':system_prompt}]
|
||||||
|
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': json.dumps(events,ensure_ascii=False)}, ensure_ascii=False)
|
||||||
|
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
|
||||||
|
logger.debug(f"文本摘要后的session: {session_content}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
logger.error(f"处理llm返回结果发生错误: {str(e)}")
|
||||||
chat_finished_event.set()
|
chat_finished_event.set()
|
||||||
|
@ -443,8 +463,8 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
chunk_data = parseChunkDelta(chunk)
|
token_count, chunk_data = parseChunkDelta(chunk)
|
||||||
is_end = chunk_data == "end"
|
is_end = token_count >0
|
||||||
if not is_end:
|
if not is_end:
|
||||||
llm_response += chunk_data
|
llm_response += chunk_data
|
||||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
||||||
|
@ -613,8 +633,8 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
|
||||||
async with aiohttp.ClientSession() as client:
|
async with aiohttp.ClientSession() as client:
|
||||||
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
chunk_data = parseChunkDelta(chunk)
|
token_count, chunk_data = parseChunkDelta(chunk)
|
||||||
is_end = chunk_data == "end"
|
is_end = token_count >0
|
||||||
if not is_end:
|
if not is_end:
|
||||||
llm_response += chunk_data
|
llm_response += chunk_data
|
||||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
# 依赖注入获取Config
|
||||||
|
Config = get_config()
|
||||||
|
|
||||||
|
class Summarizer:
    """Condense a chat history into an event summary and a user character.

    The conversation is flattened to plain text and sent to the MiniMax chat
    endpoint together with a system prompt that asks for a JSON object with
    the two string fields ``event`` and ``character``.
    """

    def __init__(self):
        self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json,里面有两个字段,一个是event,一个是character,将你总结出的事件写入event,将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
        self.model = "abab5.5-chat"
        self.max_token = 10000
        self.temperature = 0.9
        self.top_p = 1

    @staticmethod
    def _build_context(messages):
        """Flatten user/assistant turns into one transcript string.

        Messages with any other role (e.g. ``system``) are left out.
        """
        labels = {'user': "用户:", 'assistant': '玩具:'}
        return "".join(
            labels[message['role']] + message['content'] + '\n'
            for message in messages
            if message['role'] in labels
        )

    @staticmethod
    def _as_text(value):
        """Coerce a JSON value returned by the model into a plain string."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            # The model sometimes answers with a list although a string was
            # requested; keep every item instead of only the first one.
            return ",".join(str(item) for item in value)
        return str(value)

    @staticmethod
    def _parse_summary(raw_content, date_prefix):
        """Turn the model's reply text into the summary dict.

        Args:
            raw_content: Text of the assistant message; expected to be a JSON
                object, possibly surrounded by extra text such as a markdown
                code fence.
            date_prefix: Date string placed in front of the event text.

        Returns:
            ``{'event': '<date_prefix>:<event>', 'character': '<character>'}``.

        Raises:
            ValueError: If no JSON object with an ``event`` field can be
                extracted (``json.JSONDecodeError`` is a ``ValueError``).
        """
        try:
            content = json.loads(raw_content)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span so replies wrapped in
            # prose or a code fence are still usable.
            start = raw_content.find('{')
            end = raw_content.rfind('}')
            if start == -1 or end <= start:
                raise ValueError(
                    f"summarizer reply is not JSON: {raw_content[:200]!r}"
                ) from None
            content = json.loads(raw_content[start:end + 1])

        if not isinstance(content, dict) or 'event' not in content:
            raise ValueError(
                f"summarizer reply lacks an 'event' field: {raw_content[:200]!r}"
            )

        return {
            'event': date_prefix + ":" + Summarizer._as_text(content['event']),
            'character': Summarizer._as_text(content.get('character')),
        }

    async def summarize(self, messages):
        """Summarize ``messages`` through the LLM.

        Args:
            messages: Chat history as a list of ``{'role': ..., 'content': ...}``
                dicts; only ``user`` and ``assistant`` entries are summarized.

        Returns:
            A dict with the keys ``event`` (prefixed with today's date as
            ``YYYY-MM-DD:``) and ``character``.

        Raises:
            ValueError: If the HTTP response or the model's reply cannot be
                interpreted as the expected JSON structure.
        """
        payload = json.dumps({
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": self._build_context(messages)},
            ],
            "max_tokens": self.max_token,
            # Previously configured in __init__ but never sent to the API.
            "temperature": self.temperature,
            "top_p": self.top_p,
        })
        headers = {
            'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
            'Content-Type': 'application/json',
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
                body = await response.text()

        try:
            raw_content = json.loads(body)['choices'][0]['message']['content']
        except (ValueError, KeyError, IndexError, TypeError) as err:
            # An API error body has no usable 'choices'; report it clearly
            # instead of leaking an unrelated KeyError/TypeError.
            raise ValueError(f"unexpected summarizer response: {body[:200]!r}") from err

        return self._parse_summary(raw_content, datetime.now().strftime("%Y-%m-%d"))
|
||||||
|
|
||||||
|
def get_summarizer():
    """Dependency provider: build and return a new ``Summarizer``."""
    return Summarizer()
|
|
@ -101,7 +101,7 @@ class ChatServiceTest:
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
"user_id": self.user_id,
|
"user_id": self.user_id,
|
||||||
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
||||||
"user_info": "{}",
|
"user_info": "{\"character\": \"\", \"events\": [] }",
|
||||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
|
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
|
||||||
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
||||||
"token": 0}
|
"token": 0}
|
||||||
|
|
Loading…
Reference in New Issue