1
0
Fork 0
TakwayPlatform/app/controllers/chat.py

629 lines
29 KiB
Python
Raw Normal View History

2024-05-01 17:18:30 +08:00
from ..schemas.chat import *
from ..dependencies.logger import get_logger
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import uuid
import json
import asyncio
import requests
2024-05-01 17:18:30 +08:00
# Logger obtained via dependency injection
logger = get_logger()
# -------------------- Initialize local ASR -----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer
asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -----------------------------------------------------------------
# -------------------- Initialize local VITS TTS ------------------
from utils.tts.vits_utils import TextToSpeech
tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -----------------------------------------------------------------
# Config obtained via dependency injection
Config = get_config()
# ----------------------工具函数-------------------------
# Fetch the session content (Redis cache first, DB fallback)
def get_session_content(session_id, redis, db):
    """Return the session content dict for ``session_id``.

    Reads the JSON blob from Redis when cached, otherwise falls back to the
    Session table. Raises HTTP 404 when the session exists in neither store.
    """
    # Single round trip: GET returns None on a miss, so the separate
    # EXISTS check the original performed is redundant (and racy if the
    # key expires between the two calls).
    session_content_str = redis.get(session_id)
    if session_content_str is None:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)
# Parse one chunk of the LLM's SSE streaming response
def parseChunkDelta(chunk):
    """Extract the text delta from one SSE chunk of the LLM stream.

    Returns the delta text, ``"end"`` for the terminating chunk (the one
    carrying no ``delta`` field), or ``""`` for empty / malformed chunks so
    callers can always concatenate the result safely (the original returned
    None on KeyError, which crashed ``llm_response += chunk_data``).
    """
    try:
        if chunk == b"":
            return ""
        decoded_data = chunk.decode('utf-8')
        # Strip the leading "data: " SSE prefix before parsing the JSON payload.
        parsed_data = json.loads(decoded_data[6:])
        if 'delta' in parsed_data['choices'][0]:
            delta_content = parsed_data['choices'][0]['delta']
            return delta_content['content']
        else:
            return "end"
    except (KeyError, IndexError, UnicodeDecodeError, json.JSONDecodeError):
        logger.error(f"error chunk: {chunk}")
        return ""  # never None: callers concatenate this value
2024-05-01 17:18:30 +08:00
# Sentence splitter for streamed LLM text
def split_string_with_punctuation(current_sentence, text, is_first, is_end):
    """Accumulate streamed text and cut it into sentences.

    Returns ``(sentences, pending, is_first)`` where ``pending`` is text still
    waiting for a terminating mark. The very first sentence is flushed on any
    punctuation (so audio playback can start early); afterwards only the
    full-stop style marks 。?! end a sentence. ``is_end`` flushes whatever is
    pending and resets the buffer.
    """
    try:
        sentences = []
        # Final call for this reply: flush the remainder.
        if is_end:
            if current_sentence:
                sentences.append(current_sentence)
            return sentences, '', is_first
        for ch in text:
            current_sentence += ch
            first_cut = is_first and ch in ',.?!,。?!'
            if first_cut or ch in '。?!':
                sentences.append(current_sentence)
                current_sentence = ''
                if first_cut:
                    is_first = False
        return sentences, current_sentence, is_first
    except Exception as e:
        logger.error(f"断句时出现错误: {str(e)}")
2024-05-02 10:27:21 +08:00
# VAD preprocessing: slice the buffered audio into fixed 1280-char frames
def vad_preprocess(audio):
    """Split ``audio`` into a 1280-char VAD frame plus the remainder.

    Always returns a ``(frame, rest)`` tuple. Short input is right-padded with
    'A' (the filler character the original used) so the frame length stays
    constant; the original returned a bare string in that branch, which broke
    the ``frame, rest = vad_preprocess(...)`` unpack at the call site.
    """
    if len(audio) < 1280:
        return audio + 'A' * (1280 - len(audio)), ''
    return audio[:1280], audio[1280:]
2024-05-01 17:18:30 +08:00
#--------------------------------------------------------
# Create a new chat
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    """Create a UserCharacter binding plus its Session and cache it in Redis.

    Returns a ChatCreateResponse carrying the new ``user_character_id`` and
    ``session_id``. Raises HTTP 400 when either insert fails.
    """
    # Insert the new UserCharacter record
    new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
    try:
        db.add(new_chat)
        db.commit()
        db.refresh(new_chat)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    # Look up the character and user to build the system prompt
    db_character = db.query(Character).filter(Character.id == chat.character_id).first()
    db_user = db.query(User).filter(User.id == chat.user_id).first()
    system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n
{"角色的常用问候语: " + db_character.wakeup_words}\n你需要用简单通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定\n 与你聊天的对象信息如下{db_user.persona}"""
    # Build the new Session record
    session_id = str(uuid.uuid4())
    user_id = chat.user_id
    messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False)
    tts_info = {
        "language": 0,
        "speaker_id": db_character.voice_id,
        "noise_scale": 0.1,
        "noise_scale_w": 0.668,
        "length_scale": 1.2
    }
    llm_info = {
        "model": "abab5.5-chat",
        "temperature": 1,
        "top_p": 0.9,
    }
    # Serialize the tts/llm settings as JSON strings
    tts_info_str = json.dumps(tts_info, ensure_ascii=False)
    llm_info_str = json.dumps(llm_info, ensure_ascii=False)
    user_info_str = db_user.persona
    token = 0
    content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
               "llm_info": llm_info_str, "token": token}
    new_session = Session(id=session_id, user_character_id=new_chat.id, content=json.dumps(content, ensure_ascii=False),
                          last_activity=datetime.now(), is_permanent=False)
    # Persist the Session; roll back on failure so the DB is not left dirty
    # (the original committed here without the try/rollback used for the
    # first insert above).
    try:
        db.add(new_session)
        db.commit()
        db.refresh(new_session)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    redis.set(session_id, json.dumps(content, ensure_ascii=False))
    chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
    return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data)
# Delete a chat
async def delete_chat_handler(user_character_id, db, redis):
    """Delete a chat binding together with its session (DB rows + Redis key).

    Raises HTTP 404 when the UserCharacter does not exist and HTTP 400 when
    the deletion fails.
    """
    user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first()
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found")
    session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first()
    # Guard against a dangling UserCharacter with no Session row: the
    # original dereferenced session_record.id and called db.delete(None),
    # which raised deep inside SQLAlchemy.
    if session_record is not None:
        # Best-effort removal of the cached session; a missing key is not fatal.
        try:
            redis.delete(session_record.id)
        except Exception as e:
            logger.error(f"删除Redis中Session记录时发生错误: {str(e)}")
    try:
        if session_record is not None:
            db.delete(session_record)
        db.delete(user_character_record)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat())
    return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data)
# Non-streaming chat (not implemented yet)
async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis):
    """Placeholder for the non-streaming chat endpoint; currently a no-op."""
    pass
#---------------------------------------单次流式聊天接口---------------------------------------------
# Consume client websocket frames: text goes to the LLM queue, audio to the ASR queue
async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event):
    """Read frames from the websocket until the client marks the input finished.

    The first frame resolves ``future_session_id`` and ``future_response_type``
    so the orchestrator can start the LLM task. Plain-text input short-circuits
    ASR and goes straight to ``llm_input_q``; audio frames are queued on
    ``user_input_q``.
    """
    logger.debug("用户输入处理函数启动")
    is_future_done = False
    try:
        while not user_input_finish_event.is_set():
            sct_data_json = json.loads(await ws.receive_text())
            if not is_future_done:
                # First frame: publish session id and the desired response type.
                future_session_id.set_result(sct_data_json['meta_info']['session_id'])
                if sct_data_json['meta_info']['voice_synthesize']:
                    future_response_type.set_result(RESPONSE_AUDIO)
                else:
                    future_response_type.set_result(RESPONSE_TEXT)
                is_future_done = True
            if sct_data_json['text']:
                # Text input: bypass ASR entirely and finish this round.
                await llm_input_q.put(sct_data_json['text'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                    break
            if sct_data_json['meta_info']['is_end']:
                # Final audio frame of the round.
                await user_input_q.put(sct_data_json['audio'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                    break
            await user_input_q.put(sct_data_json['audio'])
    except KeyError as ke:
        # Heartbeat frames carry 'state'/'method' instead of the normal schema,
        # so the dict lookups above raise KeyError.
        if 'state' in sct_data_json and 'method' in sct_data_json:
            logger.debug("收到心跳包")
    except Exception as e:
        logger.error(f"用户输入处理函数发生错误: {str(e)}")
2024-05-01 17:18:30 +08:00
# Speech recognition: drain the audio queue and emit one ASR transcript
async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event):
    """Feed queued audio frames into the streaming ASR and push the final
    transcript onto ``llm_input_q`` once the user's input round is finished."""
    logger.debug("语音识别函数启动")
    try:
        current_message = ""
        # NOTE(review): get() has no timeout here — if the finish event fires
        # while the queue is empty this await can hang; confirm a frame always
        # follows the event.
        while not (user_input_finish_event.is_set() and user_input_q.empty()):
            audio_data = await user_input_q.get()
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        # Flush the recognizer to collect the trailing partial result.
        asr_result = asr.streaming_recognize(b'',is_end=True)
        current_message += ''.join(asr_result['text'])
        await llm_input_q.put(current_message)
    except Exception as e:
        logger.error(f"语音识别函数发生错误: {str(e)}")
    logger.debug(f"接收到用户消息: {current_message}")
# LLM call: send the conversation to MiniMax and stream the reply over the websocket
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
    """Take one user message from ``llm_input_q``, call the MiniMax streaming
    chat-completion API, and relay the reply sentence by sentence as text
    and/or synthesized audio. Sets ``chat_finished_event`` when done."""
    logger.debug("llm调用函数启动")
    try:
        llm_response = ""
        current_sentence = ""
        is_first = True
        is_end = False
        session_content = get_session_content(session_id,redis,db)
        messages = json.loads(session_content["messages"])
        current_message = await llm_input_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens": 10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        # NOTE(review): requests is blocking; this stalls the event loop while
        # the LLM streams. Consider an async HTTP client.
        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)  # call the LLM
    except Exception as e:
        # NOTE(review): if the block above fails, ``response`` below is unbound
        # and the second block just logs a NameError.
        logger.error(f"llm调用发生错误: {str(e)}")
    try:
        for chunk in response.iter_lines():
            chunk_data = parseChunkDelta(chunk)
            is_end = chunk_data == "end"
            if not is_end:
                llm_response += chunk_data
            sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)  # sentence split
            for sentence in sentences:
                if response_type == RESPONSE_TEXT:
                    response_message = {"type": "text", "code":200, "msg": sentence}
                    await ws.send_text(json.dumps(response_message, ensure_ascii=False))  # send text reply
                elif response_type == RESPONSE_AUDIO:
                    sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
                    response_message = {"type": "text", "code":200, "msg": sentence}
                    await ws.send_bytes(audio)  # audio first
                    await ws.send_text(json.dumps(response_message, ensure_ascii=False))  # then its text
                logger.debug(f"websocket返回: {sentence}")
            if is_end:
                logger.debug(f"llm返回结果: {llm_response}")
                await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                is_end = False  # reset the end flag
                session_content = get_session_content(session_id,redis,db)
                messages = json.loads(session_content["messages"])
                messages.append({'role': 'assistant', "content": llm_response})
                session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # append assistant turn
                redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # persist session
                is_first = True
                llm_response = ""
    except Exception as e:
        logger.error(f"处理llm返回结果发生错误: {str(e)}")
    chat_finished_event.set()
2024-05-01 17:18:30 +08:00
async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
    """Orchestrate one round of temporary streaming chat over a websocket."""
    logger.debug("streaming chat temporary websocket 连接建立")
    # State shared between the three worker tasks.
    audio_queue = asyncio.Queue()      # raw audio frames from the client
    text_queue = asyncio.Queue()       # text destined for the LLM
    input_done = asyncio.Event()
    chat_done = asyncio.Event()
    session_id_future = asyncio.Future()
    response_type_future = asyncio.Future()

    # Start input consumption and ASR right away; the LLM task needs the
    # session id and response type carried by the client's first frame.
    asyncio.create_task(sct_user_input_handler(ws,audio_queue,text_queue,session_id_future,response_type_future,input_done))
    asyncio.create_task(sct_asr_handler(audio_queue,text_queue,input_done))
    session_id = await session_id_future
    response_type = await response_type_future
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
    asyncio.create_task(sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,text_queue,chat_done))

    # Poll until the LLM task reports completion, then close the socket.
    while not chat_done.is_set():
        await asyncio.sleep(1)
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    await ws.close()
    logger.debug("streaming chat temporary websocket 连接断开")
#---------------------------------------------------------------------------------------------------
#------------------------------------------持续流式聊天----------------------------------------------
# Consume client frames for the lasting-chat endpoint
async def scl_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,input_finished_event):
    """Read frames until the client sends ``is_close``. Text goes straight to
    ``llm_input_q``; audio frames are wrapped with an ``is_end`` flag and
    queued for ASR. The first frame resolves the session-id / response-type
    futures."""
    logger.debug("用户输入处理函数启动")
    is_future_done = False
    while not input_finished_event.is_set():
        try:
            scl_data_json = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=3))
            if scl_data_json['is_close']:
                input_finished_event.set()
                break
            if not is_future_done:
                # First frame: publish session id and the desired response type.
                future_session_id.set_result(scl_data_json['meta_info']['session_id'])
                if scl_data_json['meta_info']['voice_synthesize']:
                    future_response_type.set_result(RESPONSE_AUDIO)
                else:
                    future_response_type.set_result(RESPONSE_TEXT)
                is_future_done = True
            if scl_data_json['text']:
                # Text input bypasses ASR.
                await llm_input_q.put(scl_data_json['text'])
            if scl_data_json['meta_info']['is_end']:
                user_input_frame = {"audio": scl_data_json['audio'], "is_end": True}
                await user_input_q.put(user_input_frame)
            # NOTE(review): an end-of-utterance frame's audio is queued a
            # second time here with is_end=False — confirm this double put is
            # intended (cf. the sct handler, which breaks before this line).
            user_input_frame = {"audio": scl_data_json['audio'], "is_end": False}
            await user_input_q.put(user_input_frame)
        except KeyError as ke:
            # Heartbeat frames lack the normal keys.
            if 'state' in scl_data_json and 'method' in scl_data_json:
                logger.debug("收到心跳包")
            continue
        except asyncio.TimeoutError:
            # No frame within 3s — re-check the finished flag and keep waiting.
            continue
        except Exception as e:
            logger.error(f"用户输入处理函数发生错误: {str(e)}")
            break
# Speech recognition for the lasting-chat endpoint
async def scl_asr_handler(user_input_q,llm_input_q,input_finished_event,asr_finished_event):
    """Stream queued audio frames through ASR; on each end-of-utterance frame,
    push the accumulated transcript to ``llm_input_q`` and start a fresh one.
    Sets ``asr_finished_event`` on exit."""
    logger.debug("语音识别函数启动")
    current_message = ""
    while not (input_finished_event.is_set() and user_input_q.empty()):
        try:
            audio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
            if audio_frame['is_end']:
                # Flush the recognizer and hand the finished utterance to the LLM.
                asr_result = asr.streaming_recognize(audio_frame['audio'], is_end=True)
                current_message += ''.join(asr_result['text'])
                await llm_input_q.put(current_message)
                logger.debug(f"接收到用户消息: {current_message}")
                # Reset for the next round; without this every later round
                # re-sent all previous rounds' text (matches the reset done in
                # voice_call_audio_consumer).
                current_message = ""
            else:
                asr_result = asr.streaming_recognize(audio_frame['audio'])
                current_message += ''.join(asr_result['text'])
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"语音识别函数发生错误: {str(e)}")
            break
    asr_finished_event.set()
# LLM call loop for the lasting-chat endpoint
async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,asr_finished_event,chat_finished_event):
    """Repeatedly take user messages from ``llm_input_q``, call the MiniMax
    streaming chat API, and relay each reply sentence by sentence as text
    and/or synthesized audio until ASR is finished and the queue is drained.
    Sets ``chat_finished_event`` on exit."""
    logger.debug("llm调用函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    is_end = False
    while not (asr_finished_event.is_set() and llm_input_q.empty()):
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "max_tokens": 10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            # NOTE(review): requests is blocking; each LLM round stalls the
            # event loop. Consider an async HTTP client.
            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
            for chunk in response.iter_lines():
                chunk_data = parseChunkDelta(chunk)
                is_end = chunk_data == "end"
                if not is_end:
                    llm_response += chunk_data
                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
                for sentence in sentences:
                    if response_type == RESPONSE_TEXT:
                        logger.debug(f"websocket返回: {sentence}")
                        response_message = {"type": "text", "code":200, "msg": sentence}
                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))
                    elif response_type == RESPONSE_AUDIO:
                        sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
                        response_message = {"type": "text", "code":200, "msg": sentence}
                        await ws.send_bytes(audio)  # audio first
                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))  # then its text
                        logger.debug(f"websocket返回: {sentence}")
                if is_end:
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False  # reset the end flag for the next round
                    session_content = get_session_content(session_id,redis,db)
                    messages = json.loads(session_content["messages"])
                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # append assistant turn
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # persist session
                    is_first = True
                    llm_response = ""
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"处理llm返回结果发生错误: {str(e)}")
            break
    chat_finished_event.set()
async def streaming_chat_lasting_handler(ws,db,redis):
    """Orchestrate the lasting (multi-round) streaming chat over a websocket."""
    logger.debug("streaming chat lasting websocket 连接建立")
    # Queues and events shared by the three worker tasks.
    audio_queue = asyncio.Queue()      # raw audio frames from the client
    text_queue = asyncio.Queue()       # text destined for the LLM
    input_done = asyncio.Event()
    asr_done = asyncio.Event()
    chat_done = asyncio.Event()
    session_id_future = asyncio.Future()
    response_type_future = asyncio.Future()

    # Input + ASR start immediately; the LLM task waits for the metadata
    # carried by the client's first frame.
    asyncio.create_task(scl_user_input_handler(ws,audio_queue,text_queue,session_id_future,response_type_future,input_done))
    asyncio.create_task(scl_asr_handler(audio_queue,text_queue,input_done,asr_done))
    session_id = await session_id_future
    response_type = await response_type_future
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
    asyncio.create_task(scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,text_queue,asr_done,chat_done))

    # Wait for the LLM task to finish, then close the connection.
    while not chat_done.is_set():
        await asyncio.sleep(3)
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    await ws.close()
    logger.debug("streaming chat lasting websocket 连接断开")
#---------------------------------------------------------------------------------------------------
2024-05-01 17:18:30 +08:00
#--------------------------------语音通话接口--------------------------------------
# Audio producer: receive client frames and slice the audio into VAD frames
async def voice_call_audio_producer(ws,audio_q,future,input_finished_event):
    """Read voice-call frames from the websocket, buffer the audio payload and
    enqueue it in fixed 1280-char VAD frames until the client sends
    ``is_close``. The first frame resolves ``future`` with the session id."""
    logger.debug("音频数据生产函数启动")
    is_future_done = False
    audio_data = ""
    while not input_finished_event.is_set():
        try:
            voice_call_data_json = json.loads(await asyncio.wait_for(ws.receive_text(),timeout=3))
            if not is_future_done:  # first frame: publish the session id
                future.set_result(voice_call_data_json['meta_info']['session_id'])
                is_future_done = True
            if voice_call_data_json["is_close"]:
                input_finished_event.set()
                break
            else:
                audio_data += voice_call_data_json["audio"]
                # Emit fixed 1280-char frames; any remainder (<= 1280 chars)
                # stays buffered for the next websocket frame.
                while len(audio_data) > 1280:
                    vad_frame,audio_data = vad_preprocess(audio_data)
                    await audio_q.put(vad_frame)
        except KeyError as ke:
            # Heartbeat frames lack the normal keys.
            if 'state' in voice_call_data_json and 'method' in voice_call_data_json:
                logger.info(f"收到心跳包")
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"音频数据生产函数发生错误: {str(e)}")
            break
2024-05-01 17:18:30 +08:00
# Audio consumer: VAD + streaming ASR over the queued frames
async def voice_call_audio_consumer(ws,audio_q,asr_result_q,input_finished_event,asr_finished_event):
    """Run VAD over queued audio frames, feed speech into streaming ASR, and
    after 25 consecutive silent frames treat the utterance as finished: the
    transcript is queued for the LLM and echoed back to the client.
    Sets ``asr_finished_event`` on exit."""
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0
    while not (input_finished_event.is_set() and audio_q.empty()):
        try:
            audio_data = await asyncio.wait_for(audio_q.get(),timeout=3)
            if vad.is_speech(audio_data):
                # Speech frame: decay the silence counter and recognize.
                if vad_count > 0:
                    vad_count -= 1
                asr_result = asr.streaming_recognize(audio_data)
                current_message += ''.join(asr_result['text'])
            else:
                vad_count += 1
                if vad_count >= 25:  # 25 consecutive silent frames => end of utterance
                    asr_result = asr.streaming_recognize(audio_data, is_end=True)
                    if current_message:
                        logger.debug(f"检测到静默,用户输入为:{current_message}")
                        await asr_result_q.put(current_message)
                        text_response = {"type": "user_text", "code": 200, "msg": current_message}
                        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # echo the transcript
                    current_message = ""
                    vad_count = 0
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"音频数据消费者函数发生错误: {str(e)}")
            break
    asr_finished_event.set()
2024-05-01 17:18:30 +08:00
# Consume ASR results and stream LLM replies back as audio + text
async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_result_q,asr_finished_event,voice_call_end_event):
    """For each recognized utterance, call the MiniMax streaming chat API and
    send every reply sentence as synthesized audio followed by its text.
    Sets ``voice_call_end_event`` on exit."""
    logger.debug("asr结果消费以及llm返回生产函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    is_end = False
    while not (asr_finished_event.is_set() and asr_result_q.empty()):
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "max_tokens":10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
            for chunk in response.iter_lines():
                chunk_data = parseChunkDelta(chunk)
                is_end = chunk_data == "end"
                if not is_end:
                    llm_response += chunk_data
                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
                for sentence in sentences:
                    # NOTE(review): argument order here is (language, speaker_id)
                    # while the sct/scl handlers pass (speaker_id, language) —
                    # one of the two is likely wrong; confirm tts.synthesize's
                    # signature.
                    sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
                    text_response = {"type": "llm_text", "code": 200, "msg": sentence}
                    await ws.send_bytes(audio)  # binary audio stream first
                    await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # then the matching text
                    logger.debug(f"llm返回结果: {sentence}")
                if is_end:
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False
                    session_content = get_session_content(session_id,redis,db)
                    messages = json.loads(session_content["messages"])
                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # append assistant turn
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # persist session
                    is_first = True
                    llm_response = ""
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"处理llm返回结果发生错误: {str(e)}")
            break
    voice_call_end_event.set()
2024-05-01 17:18:30 +08:00
# TTS worker: synthesize queued sentences and push them to the client
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    """Drain ``split_result_q``: synthesize each sentence with the session's
    TTS settings and send the audio bytes followed by the matching text.
    Sets ``voice_call_end_event`` once the split stage is done and the queue
    is empty."""
    logger.debug("语音合成及返回函数启动")
    while not (split_finished_event.is_set() and split_result_q.empty()):
        try:
            sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
        except asyncio.TimeoutError:
            # Nothing queued within 3s — re-check the exit condition.
            continue
        sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "llm_text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # binary audio first
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # then its text
        logger.debug(f"websocket返回:{sentence}")
    voice_call_end_event.set()
2024-05-01 17:18:30 +08:00
async def voice_call_handler(ws, db, redis):
    """Orchestrate a voice call: audio producer, VAD/ASR consumer and LLM
    responder wired together via queues, events and a session-id future."""
    logger.debug("voice_call websocket 连接建立")
    frame_queue = asyncio.Queue()        # VAD-sized audio frames
    transcript_queue = asyncio.Queue()   # finished ASR utterances
    input_done = asyncio.Event()         # client closed its input stream
    asr_done = asyncio.Event()           # ASR consumer drained the frames
    call_over = asyncio.Event()          # LLM responder finished
    session_id_future = asyncio.Future() # resolved by the first client frame

    asyncio.create_task(voice_call_audio_producer(ws,frame_queue,session_id_future,input_done))
    asyncio.create_task(voice_call_audio_consumer(ws,frame_queue,transcript_queue,input_done,asr_done))

    # The LLM task needs the session's TTS/LLM settings, keyed by the
    # session id delivered in the first frame.
    session_id = await session_id_future
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
    asyncio.create_task(voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,transcript_queue,asr_done,call_over))

    # Poll until the responder signals the call is over, then hang up.
    while not call_over.is_set():
        await asyncio.sleep(3)
    await ws.close()
    logger.debug("voice_call websocket 连接断开")
#------------------------------------------------------------------------------------------