update: add the Zhipu (ZHIPU) LLM; add an endpoint for updating max_tokens

killua4396 2024-06-20 14:12:51 +08:00
parent 5a47440e0f
commit 6354fe49f0
5 changed files with 72 additions and 7 deletions


@@ -5,7 +5,7 @@ class ASR(ABC):
     def __init__(self):
         self.is_slience = False
     @abstractmethod
-    async def stream_recognize(self, chunk):
+    async def stream_recognize(self, assistant, chunk):
         pass
 class LLM(ABC):


@@ -1,5 +1,6 @@
 from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
 from volcenginesdkarkruntime import Ark
+from zhipuai import ZhipuAI
 from .model import Assistant
 from .abstract import *
 from .public import *
@@ -37,7 +38,7 @@ class XF_ASR(ASR):
         self.segment_duration_threshold = 25  # timeout threshold: 25 seconds
         self.segment_start_time = None
-    async def stream_recognize(self, chunk):
+    async def stream_recognize(self, assistant, chunk):
         if self.status == FIRST_FRAME and chunk['meta_info']['is_end']:  # a first frame that is also the end frame is treated as stray noise
             raise SideNoiseError()
         if self.websocket is None:  # open a new websocket connection if none exists
@@ -188,6 +189,50 @@ class VOLCENGINE_LLM(LLM):
             raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
         return llm_chunk.choices[0].delta.content
+class ZHIPU_LLM(LLM):
+    def __init__(self):
+        self.token = 0
+        self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)
+    async def chat(self, assistant, prompt):
+        llm_info = json.loads(assistant.llm_info)
+        messages = json.loads(assistant.messages)
+        messages.append({'role':'user','content':prompt})
+        stream = self.client.chat.completions.create(
+            model = llm_info['model'],
+            messages=messages,
+            stream=True,
+            temperature=llm_info['temperature'],
+            top_p=llm_info['top_p'],
+            max_tokens=llm_info['max_tokens']
+        )
+        for chunk in stream:
+            try:
+                chunk_msg = self.__parseChunk(chunk)
+                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
+                yield msg_frame
+            except LLMResponseEnd:
+                msg_frame = {"is_end":True,"code":200,"msg":""}
+                assistant.token = self.token
+                if self.token > llm_info['max_tokens'] * 0.8:  # reset the session once token usage exceeds 80% of max_tokens
+                    msg_frame['code'] = '201'
+                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
+                yield msg_frame
+    def __parseChunk(self, llm_chunk):
+        if llm_chunk.usage:
+            self.token = llm_chunk.usage.total_tokens
+            raise LLMResponseEnd()
+        if not llm_chunk.choices:
+            raise AbnormalLLMFrame(f"error zhipu llm_chunk:{llm_chunk}")
+        return llm_chunk.choices[0].delta.content
 class MIXED_LLM(LLM):
     def __init__(self):
         self.minimax = MINIMAX_LLM()
         self.volcengine = VOLCENGINE_LLM()
+        self.zhipu = ZHIPU_LLM()
 class VITS_TTS(TTS):
     def __init__(self):
         pass
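ZHIPU_LLM.chat above is an async generator, so callers drain it with async for. A minimal consumption sketch, assuming an Assistant row whose llm_info and messages columns hold JSON strings and which has a system_prompt attribute; the prompt text and runner are illustrative only:

import asyncio

async def demo(assistant):
    llm = ZHIPU_LLM()
    reply = ""
    async for frame in llm.chat(assistant, "hello"):  # frames look like {"is_end", "code", "msg"}
        if frame["is_end"]:
            break  # a code of '201' here means the session was reset (usage passed 80% of max_tokens)
        reply += frame["msg"]  # accumulate the streamed delta content
    return reply

# asyncio.run(demo(assistant))  # requires a loaded Assistant instance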
@@ -211,6 +256,8 @@ class LLMFactory:
             return MINIMAX_LLM()
         if llm_type == 'VOLCENGINE':
             return VOLCENGINE_LLM()
+        if llm_type == 'ZHIPU':
+            return ZHIPU_LLM()
 class TTSFactory:
     def create_tts(self,tts_type:str) -> TTS:
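The new branch is reached through the LLM factory. The enclosing method's def line falls outside the hunk, so its name is an assumption here (create_llm, mirroring TTSFactory.create_tts above):

# Hypothetical usage; create_llm is an assumed method name, by analogy with create_tts.
llm = LLMFactory().create_llm('ZHIPU')  # returns a ZHIPU_LLM instance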
@@ -324,8 +371,8 @@ class Agent():
         return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
     # perform streaming speech recognition
-    async def stream_recognize(self, chunk):
-        return await self.asr.stream_recognize(chunk)
+    async def stream_recognize(self, assistant, chunk):
+        return await self.asr.stream_recognize(assistant,chunk)
     # prompt assembly
     def prompt_process(self, asr_results):


@@ -35,3 +35,6 @@ class update_assistant_deatil_params_request(BaseModel):
     temperature :float
     speaker_id:int
     length_scale:float
+class update_assistant_max_tokens_request(BaseModel):
+    max_tokens:int


@@ -23,3 +23,5 @@ class Config:
         DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
         DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
         DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
+    class ZHIPU_LLM:
+        API_KEY = "8e7e14bcb66e772e19825d8211b3cc76.WSqQIst0deRMfUIG"

main.py

@@ -130,6 +130,19 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deatil_params_request,db=Depends(get_db)):
         return {"code":200,"msg":"success","data":{}}
     else:
         return {"code":404,'msg':"assistant not found","data":{}}
+# update max_tokens
+@app.put("/api/assistants/{id}/max_tokens")
+async def update_assistant_max_tokens(id: str,request: update_assistant_max_tokens_request,db=Depends(get_db)):
+    assistant = db.query(Assistant).filter(Assistant.id == id).first()
+    if assistant:
+        llm_info = json.loads(assistant.llm_info)
+        llm_info['max_tokens'] = request.max_tokens
+        assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
+        db.commit()
+        return {"code":200,"msg":"success","data":{}}
+    else:
+        return {"code":404,'msg':"assistant not found","data":{}}
 # --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
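A quick way to exercise the new route from a client; a minimal sketch in which the host, port, and assistant id are assumptions, while the response shape follows the handler above:

import requests

# Illustrative call; base URL and assistant id are placeholders.
resp = requests.put(
    "http://localhost:8000/api/assistants/1/max_tokens",
    json={"max_tokens": 2048},  # body validated against update_assistant_max_tokens_request
)
print(resp.json())  # {"code": 200, "msg": "success", "data": {}} on success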
@@ -207,7 +220,7 @@ async def streaming_chat(ws: WebSocket):
         agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
         agent.init_recorder(assistant.user_id)
         chunk["audio"] = agent.user_audio_process(chunk["audio"])
-        asr_results = await agent.stream_recognize(chunk)
+        asr_results = await agent.stream_recognize(assistant, chunk)
         kid_text = asr_results[0]['text']  # index [0] of the ASR results is, by convention, the child (the primary user)
         prompt = agent.prompt_process(asr_results)
         agent.recorder.input_text = prompt