forked from killua/TakwayPlatform
613 lines
28 KiB
Python
613 lines
28 KiB
Python
|
from ..schemas.chat import *
|
|||
|
from ..dependencies.logger import get_logger
|
|||
|
from .controller_enum import *
|
|||
|
from ..models import UserCharacter, Session, Character, User
|
|||
|
from utils.audio_utils import VAD
|
|||
|
from fastapi import WebSocket, HTTPException, status
|
|||
|
from datetime import datetime
|
|||
|
from utils.xf_asr_utils import generate_xf_asr_url
|
|||
|
from config import get_config
|
|||
|
import uuid
|
|||
|
import json
|
|||
|
import requests
|
|||
|
import asyncio
|
|||
|
|
|||
|
# Logger obtained via dependency injection; shared by every handler below.
logger = get_logger()


# -------------------- Initialize local ASR -----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

# Module-level singleton recognizer, shared by all chat/voice handlers.
# NOTE(review): shared mutable ASR state across concurrent websocket
# sessions — confirm FunAutoSpeechRecognizer is safe for concurrent use.
asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -------------------------------------------------------


# -------------------- Initialize local VITS ----------------------
from utils.tts.vits_utils import TextToSpeech

# Module-level singleton TTS engine, pinned to CPU.
tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -------------------------------------------------------


# Config obtained via dependency injection
Config = get_config()
|
|||
|
|
|||
|
# ----------------------工具函数-------------------------
|
|||
|
# Fetch session content: Redis cache first, database fallback.
def get_session_content(session_id, redis, db):
    """Return the session content dict for ``session_id``.

    Looks the session up in Redis first; on a cache miss, loads the row
    from the Session table instead.

    Raises:
        HTTPException: 404 when the session exists in neither store.
    """
    if redis.exists(session_id):
        raw = redis.get(session_id)
    else:
        record = db.query(Session).filter(Session.id == session_id).first()
        if not record:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        raw = record.content
    return json.loads(raw)
|
|||
|
|
|||
|
# Parse one streamed SSE chunk from the LLM into its incremental text.
def parseChunkDelta(chunk):
    """Extract the delta text from a single SSE line of the LLM stream.

    ``chunk`` is a bytes line shaped like ``b'data: {...json...}'``; the
    6-byte ``data: `` prefix is stripped before JSON parsing. Returns ""
    for chunks whose first choice carries no ``delta`` field.
    """
    payload = json.loads(chunk.decode('utf-8')[6:])
    first_choice = payload['choices'][0]
    if 'delta' not in first_choice:
        return ""
    return first_choice['delta']['content']
|
|||
|
|
|||
|
# Sentence splitter for streamed LLM text.
def split_string_with_punctuation(current_sentence, text, is_first):
    """Accumulate ``text`` onto ``current_sentence`` and cut sentences.

    While ``is_first`` is true, the sentence is cut at any of
    ``,.?!,。?!`` (so the very first TTS segment is produced quickly);
    afterwards only the strong terminators ``。?!`` end a sentence.

    Returns ``(sentences, remainder, is_first)`` — the completed
    sentences, the unfinished tail, and the updated first-cut flag.
    """
    sentences = []
    for ch in text:
        current_sentence += ch
        early_cut = is_first and ch in ',.?!,。?!'
        if early_cut or ch in '。?!':
            sentences.append(current_sentence)
            current_sentence = ''
            if early_cut:
                is_first = False
    return sentences, current_sentence, is_first
|
|||
|
#--------------------------------------------------------
|
|||
|
|
|||
|
# Create a new chat (UserCharacter row + Session row + Redis cache entry)
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    """Create a new chat between a user and a character.

    Inserts a UserCharacter row, builds the role-play system prompt from
    the Character and User rows, then creates a Session whose ``content``
    JSON holds the conversation state, stored in both the DB and Redis.

    Raises:
        HTTPException: 400 when the UserCharacter insert fails.
    """
    # Create the new UserCharacter record
    new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)

    try:
        db.add(new_chat)
        db.commit()
        db.refresh(new_chat)  # populate new_chat.id for the Session FK
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

    # Look up the character being chatted with and build the system prompt
    db_character = db.query(Character).filter(Character.id == chat.character_id).first()
    db_user = db.query(User).filter(User.id == chat.user_id).first()
    # NOTE(review): assumes both rows exist — a missing character or user
    # would raise AttributeError below; confirm callers guarantee this.
    system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}。\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n
    {"角色的常用问候语: " + db_character.wakeup_words}。\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\n 与你聊天的对象信息如下:{db_user.persona}"""

    # Create the new Session record
    session_id = str(uuid.uuid4())
    user_id = chat.user_id
    messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False)

    # Per-session TTS parameters (serialized into the session content)
    tts_info = {
        "language": 0,
        "speaker_id":db_character.voice_id,
        "noise_scale": 0.1,
        "noise_scale_w":0.668,
        "length_scale": 1.2
    }
    # Per-session LLM parameters
    llm_info = {
        "model": "abab5.5-chat",
        "temperature": 1,
        "top_p": 0.9,
    }

    # Serialize tts and llm info as JSON strings
    tts_info_str = json.dumps(tts_info, ensure_ascii=False)
    llm_info_str = json.dumps(llm_info, ensure_ascii=False)
    user_info_str = db_user.persona

    token = 0  # running usage counter for this session
    content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
               "llm_info": llm_info_str, "token": token}
    new_session = Session(id=session_id, user_character_id=new_chat.id, content=json.dumps(content, ensure_ascii=False),
                          last_activity=datetime.now(), is_permanent=False)

    # Persist the Session record (DB row + Redis cache)
    db.add(new_session)
    db.commit()
    db.refresh(new_session)
    redis.set(session_id, json.dumps(content, ensure_ascii=False))

    chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
    return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data)
|
|||
|
|
|||
|
|
|||
|
# Delete a chat
async def delete_chat_handler(user_character_id, db, redis):
    """Delete a chat: its UserCharacter row, Session row and Redis entry.

    The Redis delete is best-effort (failures are only logged); the DB
    deletes are transactional.

    Raises:
        HTTPException: 404 when the UserCharacter does not exist;
            400 when the DB deletes fail.
    """
    # Look up the chat record
    user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first()
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found")
    session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first()
    # NOTE(review): if no Session row exists, session_record is None and the
    # attribute access below lands in this except too — confirm a chat
    # always has a session row.
    try:
        redis.delete(session_record.id)
    except Exception as e:
        logger.error(f"删除Redis中Session记录时发生错误: {str(e)}")
    try:
        db.delete(session_record)
        db.delete(user_character_record)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat())
    return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data)
|
|||
|
|
|||
|
|
|||
|
# Non-streaming chat (not implemented yet)
async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis):
    """Placeholder for the non-streaming chat endpoint; returns None."""
    pass
|
|||
|
|
|||
|
|
|||
|
|
|||
|
|
|||
|
#--------------------------------------- Single-turn streaming chat API ---------------------------------------------
# Handle user input
async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event):
    """Consume websocket frames for one user turn.

    The first frame resolves ``future_session_id`` and
    ``future_response_type`` (audio vs text reply) for the main handler.
    A frame with non-empty ``text`` bypasses ASR and goes straight to
    ``llm_input_q``; audio frames are queued on ``user_input_q`` until a
    frame with ``is_end`` arrives. ``user_input_finish_event`` is set
    when the turn is complete.
    """
    logger.debug("用户输入处理函数启动")
    is_future_done = False
    try:
        while not user_input_finish_event.is_set():
            sct_data_json = json.loads(await ws.receive_text())
            if not is_future_done:
                # First frame only: publish session id and reply type
                future_session_id.set_result(sct_data_json['meta_info']['session_id'])
                if sct_data_json['meta_info']['voice_synthesize']:
                    future_response_type.set_result(RESPONSE_AUDIO)
                else:
                    future_response_type.set_result(RESPONSE_TEXT)
                is_future_done = True
            if sct_data_json['text']:
                # Typed text input: skip ASR entirely
                await llm_input_q.put(sct_data_json['text'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            if sct_data_json['meta_info']['is_end']:
                # Final audio frame of this turn
                await user_input_q.put(sct_data_json['audio'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            await user_input_q.put(sct_data_json['audio'])
    except KeyError as ke:
        # Frames missing the expected keys are treated as heartbeats.
        # NOTE(review): a KeyError ends this task's receive loop entirely —
        # after a heartbeat no further input is read; confirm intended.
        if sct_data_json['state'] == 1 and sct_data_json['method'] == 'heartbeat':
            logger.debug("收到心跳包")
|
|||
|
|
|||
|
# Speech recognition
async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event):
    """Run streaming ASR over queued audio chunks and forward the full
    transcript to ``llm_input_q``.

    Consumes ``user_input_q`` until the input task signals completion and
    the queue is drained, then flushes the recognizer with an empty final
    chunk to collect any trailing partial result.
    """
    logger.debug("语音识别函数启动")
    current_message = ""
    while not (user_input_finish_event.is_set() and user_input_q.empty()):
        audio_data = await user_input_q.get()
        # NOTE(review): asr.streaming_recognize is called synchronously
        # inside a coroutine — presumably it blocks the event loop while
        # recognizing; confirm against FunAutoSpeechRecognizer.
        asr_result = asr.streaming_recognize(audio_data)
        current_message += ''.join(asr_result['text'])
    # Flush the recognizer for the final partial segment
    asr_result = asr.streaming_recognize(b'',is_end=True)
    current_message += ''.join(asr_result['text'])
    await llm_input_q.put(current_message)
    logger.debug(f"接收到用户消息: {current_message}")
|
|||
|
|
|||
|
# LLM invocation
async def sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event):
    """Send the user's message plus session history to the MiniMax
    chat-completion API and stream the reply chunks into
    ``llm_response_q``.

    Sets ``llm_response_finish_event`` once the stream is fully consumed
    (or immediately when the request returned a non-200 status).
    """
    logger.debug("llm调用函数启动")
    session_content = get_session_content(session_id,redis,db)
    messages = json.loads(session_content["messages"])
    current_message = await llm_input_q.get()
    messages.append({'role': 'user', "content": current_message})
    payload = json.dumps({
        "model": llm_info["model"],
        "stream": True,
        "messages": messages,
        "max_tokens": 10000,
        "temperature": llm_info["temperature"],
        "top_p": llm_info["top_p"]
    })
    headers = {
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    # NOTE(review): requests is synchronous — this blocks the event loop
    # for the whole streamed response; consider running it in a thread.
    response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
    if response.status_code == 200:
        for chunk in response.iter_lines():
            if chunk:
                chunk_data = parseChunkDelta(chunk)
                await llm_response_q.put(chunk_data)
    llm_response_finish_event.set()
|
|||
|
|
|||
|
# Split the LLM stream into sentences
async def sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event):
    """Accumulate LLM chunks, split them into sentences for the response
    task, and persist the full assistant reply into the Redis session.
    """
    logger.debug("llm返回处理函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (llm_response_finish_event.is_set() and llm_response_q.empty()):
        llm_chunk = await llm_response_q.get()
        llm_response += llm_chunk
        sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_chunk, is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
    # NOTE(review): any trailing text left in current_sentence (reply not
    # ending in closing punctuation) is never queued — confirm intended.
    session_content = get_session_content(session_id,redis,db)
    messages = json.loads(session_content["messages"])
    messages.append({'role': 'assistant', "content": llm_response})
    session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # update the conversation
    redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
    logger.debug(f"llm返回结果: {llm_response}")
|
|||
|
|
|||
|
# Text reply and speech synthesis
async def sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event):
    """Send each completed sentence back over the websocket — as plain
    text, or as synthesized audio followed by the matching text — then
    set ``chat_finish_event`` when the reply is fully delivered.
    """
    logger.debug("返回处理函数启动")
    while not (llm_response_finish_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        if response_type == RESPONSE_TEXT:
            response_message = {"type": "text", "code":200, "msg": sentence}
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        elif response_type == RESPONSE_AUDIO:
            # NOTE(review): argument order here is (speaker_id, language),
            # but the other call sites in this file pass (language,
            # speaker_id) — one of them is likely wrong; confirm against
            # TextToSpeech.synthesize's signature.
            sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
            response_message = {"type": "text", "code":200, "msg": sentence}
            await ws.send_bytes(audio)
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        logger.debug(f"websocket返回: {sentence}")
    chat_finish_event.set()
|
|||
|
|
|||
|
async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
    """Single-turn streaming chat endpoint.

    Wires the input / ASR / LLM / sentence-split / response tasks
    together with queues and events, waits for the turn to finish, then
    sends a close frame and closes the websocket.
    """
    logger.debug("streaming chat temporary websocket 连接建立")
    user_input_q = asyncio.Queue()    # raw user audio frames
    llm_input_q = asyncio.Queue()     # transcribed or typed user text
    llm_response_q = asyncio.Queue()  # raw LLM chunks
    split_result_q = asyncio.Queue()  # sentences ready for TTS/sending

    user_input_finish_event = asyncio.Event()
    llm_response_finish_event = asyncio.Event()
    chat_finish_event = asyncio.Event()
    future_session_id = asyncio.Future()
    future_response_type = asyncio.Future()
    asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event))
    asyncio.create_task(sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event))

    session_id = await future_session_id        # resolved by the first client frame
    response_type = await future_response_type  # text vs audio reply
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

    asyncio.create_task(sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event))
    asyncio.create_task(sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event))
    asyncio.create_task(sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event))

    # Poll until the response task reports completion, then close.
    while not chat_finish_event.is_set():
        await asyncio.sleep(1)
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    await ws.close()
    logger.debug("streaming chat temporary websocket 连接断开")
|
|||
|
#---------------------------------------------------------------------------------------------------
|
|||
|
|
|||
|
|
|||
|
# Lasting (multi-turn) streaming chat
async def streaming_chat_lasting_handler(ws, db, redis):
    """Multi-turn streaming chat over one websocket connection.

    Each iteration of the outer loop handles one turn: receive user
    input (typed text or streamed audio), run local ASR when needed,
    call the MiniMax LLM with the session history, stream the reply back
    (text and/or TTS audio), and persist the updated conversation and
    character-count "token" total to Redis.
    """
    print("Websocket连接成功")
    while True:
        # ---- receive the turn's first frame (or a close request) ----
        try:
            print("等待接受")
            data = await asyncio.wait_for(ws.receive_text(), timeout=60)  # 60s idle timeout
            data_json = json.loads(data)
            if data_json["is_close"]:
                close_message = {"type": "close", "code": 200, "msg": ""}
                await ws.send_text(json.dumps(close_message, ensure_ascii=False))
                print("连接关闭")
                await asyncio.sleep(0.5)
                await ws.close()
                return;
        except asyncio.TimeoutError:
            print("连接超时")
            await ws.close()
            return;
        current_message = ""            # accumulated user message for this turn
        response_type = RESPONSE_TEXT   # reply type for this turn
        session_id = ""
        if Config.STRAM_CHAT.ASR == "LOCAL":
            try:
                while True:
                    if data_json["text"]:  # non-empty text => typed input, no ASR
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # client wants audio back
                        session_id = data_json["meta_info"]["session_id"]
                        current_message = data_json['text']
                        break

                    if not data_json['meta_info']['is_end']:  # more audio frames coming
                        asr_result = asr.streaming_recognize(data_json["audio"])
                        current_message += ''.join(asr_result['text'])
                    else:  # final audio frame of the turn
                        asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
                        session_id = data_json["meta_info"]["session_id"]
                        current_message += ''.join(asr_result['text'])
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # client wants audio back
                        break
                    data_json = json.loads(await ws.receive_text())

            except Exception as e:
                error_info = f"接收用户消息错误: {str(e)}"
                error_message = {"type": "error", "code": "500", "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
        elif Config.STRAM_CHAT.ASR == "REMOTE":
            error_info = f"远程ASR服务暂未开通"
            error_message = {"type": "error", "code": "500", "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        print(f"接收到用户消息: {current_message}")
        # ---- load the Session record (Redis first, DB fallback) ----
        session_content_str = ""
        if redis.exists(session_id):
            session_content_str = redis.get(session_id)
        else:
            session_db = db.query(Session).filter(Session.id == session_id).first()
            if not session_db:
                # NOTE(review): `e` is not defined in this branch — this
                # f-string raises NameError instead of reporting the miss.
                error_info = f"未找到session记录: {str(e)}"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            session_content_str = session_db.content

        session_content = json.loads(session_content_str)
        llm_info = json.loads(session_content["llm_info"])
        tts_info = json.loads(session_content["tts_info"])
        user_info = json.loads(session_content["user_info"])
        messages = json.loads(session_content["messages"])
        messages.append({'role': 'user', "content": current_message})
        token_count = session_content["token"]

        # ---- call the LLM (streaming) ----
        try:
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "tool_choice": "auto",
                "max_tokens": 10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            # NOTE(review): synchronous requests call inside a coroutine —
            # blocks the event loop for the whole streamed response.
            response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        except Exception as e:
            error_info = f"发送信息给大模型时发生错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        # Local sentence splitter; NOTE(review): shadows the module-level
        # function of the same name (different signature: one argument,
        # splits on 强 terminators only, keeps the trailing fragment).
        def split_string_with_punctuation(text):
            punctuations = "!?。"
            result = []
            current_sentence = ""
            for char in text:
                current_sentence += char
                if char in punctuations:
                    result.append(current_sentence)
                    current_sentence = ""
            # If the final fragment does not end with punctuation,
            # still include it in the split result.
            if current_sentence and current_sentence[-1] not in punctuations:
                result.append(current_sentence)
            return result

        llm_response = ""
        response_buf = ""

        # ---- stream the reply back to the client ----
        try:
            if Config.STRAM_CHAT.TTS == "LOCAL":
                if response.status_code == 200:
                    for chunk in response.iter_lines():
                        if chunk:
                            if response_type == RESPONSE_AUDIO:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                response_buf += chunk_data
                                split_buf = split_string_with_punctuation(response_buf)
                                response_buf = ""
                                if len(split_buf) != 0:
                                    for sentence in split_buf:
                                        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
                                        text_response = {"type": "text", "code": 200, "msg": sentence}
                                        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text part
                                        await ws.send_bytes(audio)  # send the binary audio stream
                            if response_type == RESPONSE_TEXT:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                text_response = {"type": "text", "code": 200, "msg": chunk_data}
                                await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text data

            elif Config.STRAM_CHAT.TTS == "REMOTE":
                error_info = f"暂不支持远程音频合成"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            end_response = {"type": "end", "code": 200, "msg": ""}
            await ws.send_text(json.dumps(end_response, ensure_ascii=False))  # end of this turn
            print(f"llm消息: {llm_response}")
        except Exception as e:
            error_info = f"音频合成与向前端返回时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        # ---- persist the updated conversation ----
        try:
            messages.append({'role': 'assistant', "content": llm_response})
            token_count += len(llm_response)  # NOTE(review): counts characters, not model tokens
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)
            session_content["token"] = token_count
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))
        except Exception as e:
            error_info = f"更新session时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        print("处理完毕")
|
|||
|
|
|||
|
|
|||
|
#-------------------------------- Voice-call API --------------------------------------
# Audio frame producer
async def voice_call_audio_producer(ws,audio_queue,future,shutdown_event):
    """Receive websocket frames and queue their audio payloads.

    Resolves ``future`` with the session id taken from the first frame;
    sets ``shutdown_event`` and stops when the client sends ``is_close``.
    """
    logger.debug("音频数据生产函数启动")
    is_future_done = False
    while not shutdown_event.is_set():
        voice_call_data_json = json.loads(await ws.receive_text())
        if not is_future_done:  # read session_id from the first frame only
            future.set_result(voice_call_data_json['meta_info']['session_id'])
            is_future_done = True
        if voice_call_data_json["is_close"]:
            shutdown_event.set()
            break
        else:
            await audio_queue.put(voice_call_data_json["audio"])  # enqueue for the ASR consumer
|
|||
|
|
|||
|
# Audio frame consumer
async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event):
    """Run VAD + streaming ASR over queued audio frames; once silence
    persists, flush the recognized utterance into ``asr_result_q``.
    """
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0  # silent-frame counter; speech frames decrement it by one
    while not (shutdown_event.is_set() and audio_q.empty()):
        audio_data = await audio_q.get()
        if vad.is_speech(audio_data):
            if vad_count > 0:
                vad_count -= 1  # a speech frame forgives one prior silent frame
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        else:
            vad_count += 1
            if vad_count >= 25:  # 25 accumulated silent frames: utterance finished
                asr_result = asr.streaming_recognize(audio_data, is_end=True)
                if current_message:
                    logger.debug(f"检测到静默,用户输入为:{current_message}")
                    await asr_result_q.put(current_message)
                    current_message = ""
                vad_count = 0
|
|||
|
|
|||
|
# ASR-result consumer / LLM-response producer
async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event):
    """For each recognized utterance, call the MiniMax LLM with the
    session history and stream the reply chunks into ``llm_response_q``,
    terminating each reply with an ``is_end`` marker frame.
    """
    logger.debug("asr结果消费以及llm返回生产函数启动")
    while not (shutdown_event.is_set() and asr_result_q.empty()):
        session_content = get_session_content(session_id,redis,db)
        messages = json.loads(session_content["messages"])
        current_message = await asr_result_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens":10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })

        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        # NOTE(review): synchronous requests call inside a coroutine —
        # blocks the event loop for the whole streamed response.
        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data =parseChunkDelta(chunk)
                    llm_frame = {'message':chunk_data,'is_end':False}
                    await llm_response_q.put(llm_frame)
            llm_frame = {'message':"",'is_end':True}
            await llm_response_q.put(llm_frame)
|
|||
|
|
|||
|
# LLM-response consumer: sentence splitting and session persistence
async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event):
    """Split streamed LLM chunks into sentences for TTS; on each reply's
    end marker, persist the full assistant reply to the Redis session and
    reset the accumulators for the next reply.
    """
    logger.debug("llm结果返回函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (shutdown_event.is_set() and llm_response_q.empty()):
        llm_frame = await llm_response_q.get()
        llm_response += llm_frame['message']
        sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
        if llm_frame['is_end']:
            is_first = True
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'assistant', "content": llm_response})
            session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # update the conversation
            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
            logger.debug(f"llm返回结果: {llm_response}")
            llm_response = ""
            current_sentence = ""
|
|||
|
|
|||
|
# Speech synthesis and reply
async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
    """Synthesize each queued sentence and send the audio + matching
    text frames to the client; once the call shuts down and the queue is
    drained, pause briefly and close the websocket.

    Bug fix: the final pause was a bare ``asyncio.sleep(0.5)`` — it only
    created a coroutine and never awaited it, so no delay happened (and
    a "coroutine was never awaited" warning was emitted). It is now
    awaited.
    """
    logger.debug("语音合成及返回函数启动")
    while not (shutdown_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        # NOTE(review): argument order here is (language, speaker_id),
        # while sct_response_handler passes (speaker_id, language) — one
        # of them is likely wrong; confirm TextToSpeech.synthesize.
        sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # binary audio stream
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # matching text frame
        logger.debug(f"websocket返回:{sentence}")
    await asyncio.sleep(0.5)  # was: asyncio.sleep(0.5) — never awaited
    await ws.close()
|
|||
|
|
|||
|
|
|||
|
async def voice_call_handler(ws, db, redis):
    """Voice-call endpoint: wires the producer / ASR / LLM / sentence /
    TTS tasks together with queues and a shared shutdown event, then
    closes the websocket once the client requests shutdown.
    """
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()         # raw audio frames from the client
    asr_result_q = asyncio.Queue()    # completed utterances
    llm_response_q = asyncio.Queue()  # streamed LLM chunks
    split_result_q = asyncio.Queue()  # sentences ready for TTS

    shutdown_event = asyncio.Event()
    future = asyncio.Future()
    asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,shutdown_event))  # audio producer
    asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event))  # audio consumer

    # Fetch session configuration once the first frame reveals the id
    session_id = await future
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])

    asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event))  # LLM caller
    asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event))  # sentence splitter
    asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event))  # TTS sender

    # Poll until the producer observes is_close, then close the socket.
    while not shutdown_event.is_set():
        await asyncio.sleep(5)
    await ws.close()
    logger.debug("voice_call websocket 连接断开")
|
|||
|
#------------------------------------------------------------------------------------------
|