From 6354fe49f03f5c27eb187af6f86b295cb33b61a8 Mon Sep 17 00:00:00 2001
From: killua4396 <1223086337@qq.com>
Date: Thu, 20 Jun 2024 14:12:51 +0800
Subject: [PATCH] update: add the Zhipu LLM; add an endpoint for modifying
 max_tokens
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/abstract.py |  2 +-
 app/concrete.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++---
 app/schemas.py  |  5 ++++-
 config.py       |  4 +++-
 main.py         | 15 +++++++++++++-
 5 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/app/abstract.py b/app/abstract.py
index 13947b7..5c6f121 100644
--- a/app/abstract.py
+++ b/app/abstract.py
@@ -5,7 +5,7 @@ class ASR(ABC):
     def __init__(self):
         self.is_slience = False
     @abstractmethod
-    async def stream_recognize(self, chunk):
+    async def stream_recognize(self, assistant, chunk):
         pass
 
 class LLM(ABC):
diff --git a/app/concrete.py b/app/concrete.py
index 3c19988..eeb45b7 100644
--- a/app/concrete.py
+++ b/app/concrete.py
@@ -1,5 +1,6 @@
 from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
 from volcenginesdkarkruntime import Ark
+from zhipuai import ZhipuAI
 from .model import Assistant
 from .abstract import *
 from .public import *
@@ -37,7 +38,7 @@ class XF_ASR(ASR):
         self.segment_duration_threshold = 25  # timeout threshold: 25 seconds
         self.segment_start_time = None
 
-    async def stream_recognize(self, chunk):
+    async def stream_recognize(self, assistant, chunk):
         if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:  # a first frame already marked as end is treated as side noise
             raise SideNoiseError()
         if self.websocket is None:  # if no websocket connection exists yet, establish a new one
@@ -188,6 +189,50 @@ class VOLCENGINE_LLM(LLM):
             raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
         return llm_chunk.choices[0].delta.content
 
+class ZHIPU_LLM(LLM):
+    def __init__(self):
+        self.token = 0
+        self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
+
+    async def chat(self, assistant, prompt):
+        llm_info = json.loads(assistant.llm_info)
+        messages = json.loads(assistant.messages)
+        messages.append({'role':'user','content':prompt})
+        stream = self.client.chat.completions.create(
+            model=llm_info['model'],
+            messages=messages,
+            stream=True,
+            temperature=llm_info['temperature'],
+            top_p=llm_info['top_p'],
+            max_tokens=llm_info['max_tokens']
+        )
+        for chunk in stream:
+            try:
+                chunk_msg = self.__parseChunk(chunk)
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                assistant.token = self.token
+                if self.token > llm_info['max_tokens'] * 0.8:  # if token usage exceeds 80% of max_tokens, reset the session
+                    msg_frame['code'] = 201
+                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                yield msg_frame
+
+    def __parseChunk(self, llm_chunk):
+        if llm_chunk.usage:
+            self.token = llm_chunk.usage.total_tokens
+            raise LLMResponseEnd()
+        if not llm_chunk.choices:
+            raise AbnormalLLMFrame(f"error zhipu llm_chunk:{llm_chunk}")
+        return llm_chunk.choices[0].delta.content
+
+class MIXED_LLM(LLM):
+    def __init__(self):
+        self.minimax = MINIMAX_LLM()
+        self.volcengine = VOLCENGINE_LLM()
+        self.zhipu = ZHIPU_LLM()
+
 class VITS_TTS(TTS):
     def __init__(self):
         pass
@@ -211,6 +256,8 @@ class LLMFactory:
             return MINIMAX_LLM()
         if llm_type == 'VOLCENGINE':
             return VOLCENGINE_LLM()
+        if llm_type == 'ZHIPU':
+            return ZHIPU_LLM()
 
 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
@@ -324,8 +371,8 @@ class Agent():
         return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
 
     # perform streaming speech recognition
-    async def stream_recognize(self, chunk):
-        return await self.asr.stream_recognize(chunk)
+    async def stream_recognize(self, assistant, chunk):
+        return await self.asr.stream_recognize(assistant, chunk)
 
     # perform prompt processing
     def prompt_process(self, asr_results):
diff --git a/app/schemas.py b/app/schemas.py
index ffb2aea..d597836 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -34,4 +34,7 @@ class update_assistant_deatil_params_request(BaseModel):
     model :str
     temperature :float
     speaker_id:int
-    length_scale:float
\ No newline at end of file
+    length_scale:float
+
+class update_assistant_max_tokens_request(BaseModel):
+    max_tokens:int
\ No newline at end of file
diff --git a/config.py b/config.py
index a31ce8d..f85c2e8 100644
--- a/config.py
+++ b/config.py
@@ -22,4 +22,6 @@ class Config:
         API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
         DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
         DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
-        DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
\ No newline at end of file
+        DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
+    class ZHIPU_LLM:
+        API_KEY = "8e7e14bcb66e772e19825d8211b3cc76.WSqQIst0deRMfUIG"
\ No newline at end of file
diff --git a/main.py b/main.py
index dd97d21..1571f9a 100644
--- a/main.py
+++ b/main.py
@@ -130,6 +130,19 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
         return {"code":200,"msg":"success","data":{}}
     else:
         return {"code":404,'msg':"assistant not found","data":{}}
+
+# update max_tokens
+@app.put("/api/assistants/{id}/max_tokens")
+async def update_assistant_max_tokens(id: str, request: update_assistant_max_tokens_request, db=Depends(get_db)):
+    assistant = db.query(Assistant).filter(Assistant.id == id).first()
+    if assistant:
+        llm_info = json.loads(assistant.llm_info)
+        llm_info['max_tokens'] = request.max_tokens
+        assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
+        db.commit()
+        return {"code":200,"msg":"success","data":{}}
+    else:
+        return {"code":404,'msg':"assistant not found","data":{}}
 
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 
@@ -207,7 +220,7 @@ async def streaming_chat(ws: WebSocket):
             agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
             agent.init_recorder(assistant.user_id)
         chunk["audio"] = agent.user_audio_process(chunk["audio"])
-        asr_results = await agent.stream_recognize(chunk)
+        asr_results = await agent.stream_recognize(assistant, chunk)
         kid_text = asr_results[0]['text']  # asr_results[0] is by default the child's (the primary user's) ASR result
         prompt = agent.prompt_process(asr_results)
         agent.recorder.input_text = prompt
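
For reference, a minimal sketch of how the new ZHIPU_LLM.chat async generator can be consumed. The FakeAssistant stub, the "glm-4" model name, and the parameter values below are hypothetical, chosen only to mirror the JSON shapes the patch reads from Assistant.llm_info and Assistant.messages; they are not part of the patch itself.

# Hypothetical harness for ZHIPU_LLM.chat; assumes the patched app/concrete.py
# is importable and Config.ZHIPU_LLM.API_KEY is set as in the patched config.py.
import asyncio
import json

from app.concrete import ZHIPU_LLM

class FakeAssistant:
    # Stand-in for the real Assistant model; fields mirror what chat() reads/writes.
    def __init__(self):
        self.system_prompt = "You are a helpful assistant."
        self.llm_info = json.dumps({
            "model": "glm-4",        # assumed model name, not taken from the diff
            "temperature": 0.9,
            "top_p": 0.7,
            "max_tokens": 2000,
        })
        self.messages = json.dumps([{"role": "system", "content": self.system_prompt}])
        self.token = 0

async def main():
    llm = ZHIPU_LLM()
    assistant = FakeAssistant()
    async for frame in llm.chat(assistant, "Hello"):
        if frame["is_end"]:
            # code 201 on the final frame signals that the session was reset
            # because token usage exceeded 80% of max_tokens
            break
        print(frame["msg"], end="", flush=True)

asyncio.run(main())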
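Likewise, a sketch of exercising the new PUT /api/assistants/{id}/max_tokens endpoint; the host, port, assistant id, and max_tokens value are placeholders.

# Hypothetical client call for the new max_tokens endpoint.
import requests

resp = requests.put(
    "http://localhost:8000/api/assistants/1/max_tokens",
    json={"max_tokens": 4096},  # body shape matches update_assistant_max_tokens_request
)
print(resp.json())  # expected on success: {"code": 200, "msg": "success", "data": {}}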