update: 增加了智谱大模型,增加了修改max_tokens接口
This commit is contained in:
parent
5a47440e0f
commit
6354fe49f0
|
@ -5,7 +5,7 @@ class ASR(ABC):
|
|||
def __init__(self):
|
||||
self.is_slience = False
|
||||
@abstractmethod
|
||||
async def stream_recognize(self, chunk):
|
||||
async def stream_recognize(self, assistant, chunk):
|
||||
pass
|
||||
|
||||
class LLM(ABC):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
|
||||
from volcenginesdkarkruntime import Ark
|
||||
from zhipuai import ZhipuAI
|
||||
from .model import Assistant
|
||||
from .abstract import *
|
||||
from .public import *
|
||||
|
@ -37,7 +38,7 @@ class XF_ASR(ASR):
|
|||
self.segment_duration_threshold = 25 #超时时间为25秒
|
||||
self.segment_start_time = None
|
||||
|
||||
async def stream_recognize(self, chunk):
|
||||
async def stream_recognize(self, assistant, chunk):
|
||||
if self.status == FIRST_FRAME and chunk['meta_info']['is_end']: #如果是第一帧,且为end,则判断为杂音
|
||||
raise SideNoiseError()
|
||||
if self.websocket is None: #如果websocket未建立,则建立一个新的连接
|
||||
|
@ -188,6 +189,50 @@ class VOLCENGINE_LLM(LLM):
|
|||
raise AbnormalLLMFrame(f"error volcengine llm_chunk:{llm_chunk}")
|
||||
return llm_chunk.choices[0].delta.content
|
||||
|
||||
class ZHIPU_LLM(LLM):
    """LLM implementation backed by the ZhipuAI chat-completion API.

    Streams chat responses chunk by chunk and tracks the token usage
    reported at end-of-stream so the caller can reset the session when
    the context window is nearly full.
    """

    def __init__(self):
        # Total tokens reported by the most recently completed response.
        self.token = 0
        self.client = ZhipuAI(api_key=Config.ZHIPU_LLM.API_KEY)

    async def chat(self, assistant, prompt):
        """Stream a chat completion for ``prompt``.

        Yields frames of the form ``{"is_end": bool, "code": int, "msg": str}``.
        The final frame has ``is_end=True``; its code is 201 when token usage
        exceeded 80% of ``max_tokens`` and the message history was reset to
        just the system prompt plus the current user prompt.
        """
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({'role':'user','content':prompt})
        # NOTE(review): this is a blocking SDK call iterated synchronously
        # inside an async method — it stalls the event loop for the whole
        # stream. Consider the SDK's async client or a thread executor.
        stream = self.client.chat.completions.create(
            model=llm_info['model'],
            messages=messages,
            stream=True,
            temperature=llm_info['temperature'],
            top_p=llm_info['top_p'],
            max_tokens=llm_info['max_tokens']
        )
        for chunk in stream:
            try:
                chunk_msg = self.__parseChunk(chunk)
                msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
                yield msg_frame
            except LLMResponseEnd:
                msg_frame = {"is_end":True,"code":200,"msg":""}
                assistant.token = self.token
                if self.token > llm_info['max_tokens'] * 0.8:  # reset session when usage passes 80% of the window
                    # Fix: was the string '201'; every other code emitted by
                    # this service (200 here, 200/404 in the HTTP layer) is an
                    # int, so a string would break numeric comparisons.
                    msg_frame['code'] = 201
                    assistant.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
                yield msg_frame

    def __parseChunk(self, llm_chunk):
        """Extract the delta text from one stream chunk.

        Raises:
            LLMResponseEnd: when the usage frame arrives (end of stream);
                records total token usage as a side effect.
            AbnormalLLMFrame: when a non-usage chunk carries no choices.
        """
        if llm_chunk.usage:
            self.token = llm_chunk.usage.total_tokens
            raise LLMResponseEnd()
        if not llm_chunk.choices:
            raise AbnormalLLMFrame(f"error zhipu llm_chunk:{llm_chunk}")
        return llm_chunk.choices[0].delta.content
|
||||
|
||||
class MIXED_LLM(LLM):
    """Holds one instance of every concrete LLM backend.

    Lets a caller route a request to minimax, volcengine, or zhipu
    without constructing clients per call.
    """

    def __init__(self):
        # One client per provider, built eagerly at construction time.
        self.minimax = MINIMAX_LLM()
        self.volcengine = VOLCENGINE_LLM()
        self.zhipu = ZHIPU_LLM()
|
||||
|
||||
class VITS_TTS(TTS):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
@ -211,6 +256,8 @@ class LLMFactory:
|
|||
return MINIMAX_LLM()
|
||||
if llm_type == 'VOLCENGINE':
|
||||
return VOLCENGINE_LLM()
|
||||
if llm_type == 'ZHIPU':
|
||||
return ZHIPU_LLM()
|
||||
|
||||
class TTSFactory:
|
||||
def create_tts(self,tts_type:str) -> TTS:
|
||||
|
@ -324,8 +371,8 @@ class Agent():
|
|||
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
|
||||
|
||||
# 进行流式语音识别
|
||||
async def stream_recognize(self, chunk):
|
||||
return await self.asr.stream_recognize(chunk)
|
||||
async def stream_recognize(self, assistant, chunk):
|
||||
return await self.asr.stream_recognize(assistant,chunk)
|
||||
|
||||
# 进行Prompt加工
|
||||
def prompt_process(self, asr_results):
|
||||
|
|
|
@ -34,4 +34,7 @@ class update_assistant_deatil_params_request(BaseModel):
|
|||
model :str
|
||||
temperature :float
|
||||
speaker_id:int
|
||||
length_scale:float
|
||||
length_scale:float
|
||||
|
||||
class update_assistant_max_tokens_request(BaseModel):
    """Request body for the per-assistant max_tokens update endpoint."""

    # New token limit to store into the assistant's llm_info JSON.
    max_tokens: int
|
|
@ -22,4 +22,6 @@ class Config:
|
|||
API_KEY = "a1bf964c-5c12-4d2b-ad97-85893e14d55d"
|
||||
DOUBAO_LITE_4k = "ep-20240612075552-5c7tk"
|
||||
DOUBAO_LITE_32k = "ep-20240618130753-q85dm"
|
||||
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
|
||||
DOUBAO_PRO_32k = "ep-20240618145315-pm2c6"
|
||||
class ZHIPU_LLM:
    """Credentials for the ZhipuAI LLM backend."""

    # NOTE(review): hardcoded credential committed to source control —
    # it should be rotated and loaded from an environment variable or a
    # secret store instead of living in the repository.
    API_KEY = "8e7e14bcb66e772e19825d8211b3cc76.WSqQIst0deRMfUIG"
|
15
main.py
15
main.py
|
@ -130,6 +130,19 @@ async def update_assistant_deatil_params(id: str,request: update_assistant_deati
|
|||
return {"code":200,"msg":"success","data":{}}
|
||||
else:
|
||||
return {"code":404,'msg':"assistant not found","data":{}}
|
||||
|
||||
# Update an assistant's max_tokens
@app.put("/api/assistants/{id}/max_tokens")
async def update_assistant_max_tokens(id: str, request: update_assistant_max_tokens_request, db=Depends(get_db)):
    """Patch the ``max_tokens`` field inside the assistant's llm_info JSON.

    Returns a 200 envelope on success, a 404 envelope when no assistant
    with the given id exists.
    """
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if not assistant:
        return {"code": 404, "msg": "assistant not found", "data": {}}
    llm_info = json.loads(assistant.llm_info)
    llm_info['max_tokens'] = request.max_tokens
    assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
    db.commit()
    return {"code": 200, "msg": "success", "data": {}}
|
||||
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
@ -207,7 +220,7 @@ async def streaming_chat(ws: WebSocket):
|
|||
agent = Agent(asr_type=user_info['asr_type'], llm_type=user_info['llm_type'], tts_type=user_info['tts_type'])
|
||||
agent.init_recorder(assistant.user_id)
|
||||
chunk["audio"] = agent.user_audio_process(chunk["audio"])
|
||||
asr_results = await agent.stream_recognize(chunk)
|
||||
asr_results = await agent.stream_recognize(assistant, chunk)
|
||||
kid_text = asr_results[0]['text'] #asr结果的[0]默认为孩子(主要用户)的asr结果
|
||||
prompt = agent.prompt_process(asr_results)
|
||||
agent.recorder.input_text = prompt
|
||||
|
|
Loading…
Reference in New Issue