diff --git a/app/controllers/chat_controller.py b/app/controllers/chat_controller.py
index 5a5e887..6520d4a 100644
--- a/app/controllers/chat_controller.py
+++ b/app/controllers/chat_controller.py
@@ -1,5 +1,6 @@
 from ..schemas.chat_schema import *
 from ..dependencies.logger import get_logger
+from ..dependencies.summarizer import get_summarizer
 from .controller_enum import *
 from ..models import UserCharacter, Session, Character, User
 from utils.audio_utils import VAD
@@ -15,6 +16,11 @@ import aiohttp
+import copy
+
 # logger obtained via dependency injection
 logger = get_logger()
 
+# context-summarization service obtained via dependency injection
+summarizer = get_summarizer()
+
 # -------------------- initialize local ASR -----------------------
 from utils.stt.modified_funasr import ModifiedRecognizer
@@ -55,9 +59,9 @@ def parseChunkDelta(chunk):
         parsed_data = json.loads(decoded_data[6:])
         if 'delta' in parsed_data['choices'][0]:
             delta_content = parsed_data['choices'][0]['delta']
-            return delta_content['content']
+            return -1, delta_content['content']
         else:
-            return "end"
+            return parsed_data['usage']['total_tokens'], ""
     except KeyError:
         logger.error(f"error chunk: {chunk}")
-        return ""
+        return -1, ""  # keep the tuple contract even on malformed chunks, or unpacking at the call sites fails
@@ -142,11 +146,15 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
         "temperature": 1,
         "top_p": 0.9,
     }
+    user_info = {
+        "character": "",
+        "events": []
+    }
 
     # serialize the tts and llm info to JSON strings
     tts_info_str = json.dumps(tts_info, ensure_ascii=False)
     llm_info_str = json.dumps(llm_info, ensure_ascii=False)
-    user_info_str = db_user.persona
+    user_info_str = json.dumps(user_info, ensure_ascii=False)
     token = 0
     content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
@@ -262,13 +270,16 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
     is_first = True
     is_end = False
     session_content = get_session_content(session_id,redis,db)
+    user_info = json.loads(session_content["user_info"])
     messages = json.loads(session_content["messages"])
     current_message = await llm_input_q.get()
     messages.append({'role': 'user', "content": current_message})
+    messages_send = copy.deepcopy(messages)  # make a real copy of messages (a plain assignment would alias and pollute the stored history), then prepend the user info to its last entry
+    messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
     payload = json.dumps({
         "model": llm_info["model"],
         "stream": True,
-        "messages": messages,
+        "messages": messages_send,
         "max_tokens": 10000,
         "temperature": llm_info["temperature"],
         "top_p": llm_info["top_p"]
@@ -283,8 +294,8 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         async with aiohttp.ClientSession() as client:
             async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                     sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)  # sentence splitting
@@ -307,6 +318,15 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                         redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
                         is_first = True
                         llm_response = ""
+                    if is_end and token_count > summarizer.max_token * 0.6:  # if the LLM's token count exceeds 60% of the max, summarize the conversation
+                        system_prompt = messages[0]['content']
+                        summary = await summarizer.summarize(messages)
+                        events = user_info['events']
+                        events.append(summary['event'])
+                        session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}], ensure_ascii=False)  # keep messages a JSON string, as the json.loads above expects
+                        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)  # store events as a plain array; double-encoding it would break the next events.append
+                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
+                        logger.debug(f"文本摘要后的session: {session_content}")
     except Exception as e:
         logger.error(f"处理llm返回结果发生错误: {str(e)}")
     chat_finished_event.set()
@@ -443,8 +463,8 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         async with aiohttp.ClientSession() as client:
             async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                     sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
@@ -613,8 +633,8 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
         async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                 async for chunk in response.content.iter_any():
-                    chunk_data = parseChunkDelta(chunk)
-                    is_end = chunk_data == "end"
+                    token_count, chunk_data = parseChunkDelta(chunk)
+                    is_end = token_count > 0
                     if not is_end:
                         llm_response += chunk_data
                     sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
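A minimal sketch of the new parseChunkDelta contract (the two SSE payloads below are illustrative stand-ins for MiniMax stream chunks, not captured traffic; it assumes the patched function above is in scope): a mid-stream delta yields (-1, text), while the final chunk, the only one carrying usage.total_tokens, yields a positive count with an empty delta, which is what flips is_end in all three handlers.

    # Illustrative SSE payloads; "data: " is exactly 6 characters, matching decoded_data[6:] above.
    delta_chunk = b'data: {"choices":[{"delta":{"content":"hi"}}]}'
    final_chunk = b'data: {"choices":[{"finish_reason":"stop"}],"usage":{"total_tokens":1234}}'
    assert parseChunkDelta(delta_chunk) == (-1, "hi")   # mid-stream: no token count yet
    assert parseChunkDelta(final_chunk) == (1234, "")   # final chunk: usage present, empty delta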
diff --git a/app/dependencies/summarizer.py b/app/dependencies/summarizer.py
new file mode 100644
index 0000000..dfba03b
--- /dev/null
+++ b/app/dependencies/summarizer.py
@@ -0,0 +1,63 @@
+import aiohttp
+import json
+from datetime import datetime
+from config import get_config
+
+
+# Config obtained via dependency injection
+Config = get_config()
+
+class Summarizer:
+    def __init__(self):
+        self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json,里面有两个字段,一个是event,一个是character,将你总结出的事件写入event,将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
+        self.model = "abab5.5-chat"
+        self.max_token = 10000
+        self.temperature = 0.9
+        self.top_p = 1
+
+    async def summarize(self, messages):
+        context = ""
+        for message in messages:
+            if message['role'] == 'user':
+                context += "用户:" + message['content'] + '\n'
+            elif message['role'] == 'assistant':
+                context += '玩具:' + message['content'] + '\n'
+        payload = json.dumps({
+            "model": self.model,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": self.system_prompt
+                },
+                {
+                    "role": "user",
+                    "content": context
+                }
+            ],
+            "max_tokens": self.max_token,
+            "temperature": self.temperature,
+            "top_p": self.top_p
+        })
+        headers = {
+            'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
+            'Content-Type': 'application/json'
+        }
+        async with aiohttp.ClientSession() as session:
+            async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
+                content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
+        try:
+            summary = {
+                'event': datetime.now().strftime("%Y-%m-%d") + ":" + content['event'],
+                'character': content['character']
+            }
+        except TypeError:
+            # the model occasionally returns event as a list; fall back to its first element
+            summary = {
+                'event': datetime.now().strftime("%Y-%m-%d") + ":" + content['event'][0],
+                'character': ""
+            }
+        return summary
+
+def get_summarizer():
+    summarizer = Summarizer()
+    return summarizer
\ No newline at end of file
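A hypothetical usage sketch for the new Summarizer (the transcript and the asyncio driver are invented for illustration; it assumes Config.MINIMAX_LLM.URL and API_KEY are configured, since summarize() performs a real HTTP call):

    import asyncio
    from app.dependencies.summarizer import get_summarizer

    async def demo():
        summarizer = get_summarizer()
        messages = [
            {"role": "system", "content": "..."},  # system lines are skipped when building the context
            {"role": "user", "content": "We played tag at kindergarten today"},
            {"role": "assistant", "content": "That sounds like so much fun!"},
        ]
        summary = await summarizer.summarize(messages)
        print(summary)  # expected shape: {'event': 'YYYY-MM-DD:<event>', 'character': '<trait>'}

    asyncio.run(demo())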
diff --git a/tests/unit_test/chat_test.py b/tests/unit_test/chat_test.py
index 409c894..ebfbf11 100644
--- a/tests/unit_test/chat_test.py
+++ b/tests/unit_test/chat_test.py
@@ -101,7 +101,7 @@ class ChatServiceTest:
         payload = json.dumps({
             "user_id": self.user_id,
             "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
-            "user_info": "{}",
+            "user_info": "{\"character\": \"\", \"events\": []}",
             "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
             "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
             "token": 0})
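For reference, a standalone sketch of the user_info round trip that the controller and this test fixture now share (the event string is invented). It shows why events must stay a JSON array rather than a doubly-encoded string: otherwise the events.append(...) in sct_llm_handler would be operating on a str.

    import json

    # Shape written by create_chat_handler and by the test payload above.
    user_info_str = json.dumps({"character": "", "events": []}, ensure_ascii=False)

    # Shape read back at the top of sct_llm_handler.
    user_info = json.loads(user_info_str)
    user_info["events"].append("2024-06-01:played tag at kindergarten")  # works only because events is a real list
    assert json.loads(json.dumps(user_info, ensure_ascii=False))["events"] == ["2024-06-01:played tag at kindergarten"]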