diff --git a/app/abstract.py b/app/abstract.py index 13c8eef..e622f70 100644 --- a/app/abstract.py +++ b/app/abstract.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod #------ 抽象 ASR, LLM, TTS 类 ------ # class ASR(ABC): + def __init__(self): + self.is_slience = False @abstractmethod async def stream_recognize(self, chunk): pass diff --git a/app/concrete.py b/app/concrete.py index b218811..8263b38 100644 --- a/app/concrete.py +++ b/app/concrete.py @@ -29,6 +29,7 @@ LAST_FRAME =3 class XF_ASR(ASR): def __init__(self): + super().__init__() self.websocket = None self.current_message = "" self.audio = "" @@ -58,6 +59,16 @@ class XF_ASR(ASR): self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv())) if self.current_message == "": raise AsrResultNoneError() + if "进入沉默模式" in self.current_message: + self.is_slience = True + asyncio.create_task(self.websocket.close()) + raise EnterSlienceMode() + if "退出沉默模式" in self.current_message: + self.is_slience = False + self.current_message = "已退出沉默模式" + if self.is_slience: + asyncio.create_task(self.websocket.close()) + raise SlienceMode() asyncio.create_task(self.websocket.close()) return [{"text":self.current_message, "audio":self.audio}] @@ -145,6 +156,9 @@ class VOLCENGINE_LLM(LLM): model = model, messages=messages, stream=True, + temperature=llm_info['temperature'], + top_p=llm_info['top_p'], + max_tokens=llm_info['max_tokens'], stream_options={'include_usage': True} ) for chunk in stream: diff --git a/app/exception.py b/app/exception.py index dd0e83b..f5fdbcd 100644 --- a/app/exception.py +++ b/app/exception.py @@ -51,3 +51,15 @@ class LLMResponseEnd(Exception): def __init__(self, message="LLM Response End!"): super().__init__(message) self.message = message + +# 进入静音模式(非异常) +class EnterSlienceMode(Exception): + def __init__(self, message="Enter Slience Mode!"): + super().__init__(message) + self.message = message + +# 处于静音模式(非异常) +class SlienceMode(Exception): + def __init__(self, message="Slience Mode!"): + super().__init__(message) + self.message = message diff --git a/main.py b/main.py index 200136b..4962d82 100644 --- a/main.py +++ b/main.py @@ -221,6 +221,12 @@ async def streaming_chat(ws: WebSocket,db=Depends(get_db)): agent.recorder.output_text = llm_text agent.save() logger.debug("音频保存成功") + except EnterSlienceMode: + tts_audio = agent.synthetize(assistant, "已进入沉默模式", db) + await ws.send_bytes(agent.encode("已进入沉默模式", tts_audio)) + await ws.send_text(json.dumps({"type":"info","code":201,"msg":"进入沉默模式"}, ensure_ascii=False)) + except SlienceMode: + await ws.send_text(json.dumps({"type":"info","code":201,"msg":"处于沉默模式"}, ensure_ascii=False)) except AsrResultNoneError: await ws.send_text(json.dumps({"type":"error","code":501,"msg":"asr结果为空"}, ensure_ascii=False)) except AbnormalLLMFrame as e: