Compare commits
5 Commits

| Author | SHA1 | Date |
|---|---|---|
| | d0b4bd4b3c | |
| | 2b870c2e7d | |
| | 05ccd1c8c0 | |
| | 42767b065f | |
| | a776258f8b | |
```diff
@@ -7,11 +7,3 @@ __pycache__/
 app.log
 /utils/tts/vits_model/
 vits_model
-
-nohup.out
-
-/app/takway-ai.top.key
-/app/takway-ai.top.pem
-
-tests/assets/BarbieDollsVoice.mp3
-utils/tts/openvoice_model/checkpoint.pth
```
README.md (122 lines changed)
```diff
@@ -15,15 +15,15 @@ TakwayAI/
 │ │ ├── models.py           # database table definitions
 │ ├── schemas/              # request and response models
 │ │ ├── __init__.py
-│ │ ├── user_schemas.py     # user-related schemas
+│ │ ├── user.py             # user-related schemas
 │ │ └── ...                 # other schemas
 │ ├── controllers/          # business-logic controllers
 │ │ ├── __init__.py
-│ │ ├── user_controllers.py # user-related controllers
+│ │ ├── user.py             # user-related controllers
 │ │ └── ...                 # other controllers
 │ ├── routes/               # routes and view functions
 │ │ ├── __init__.py
-│ │ ├── user_routes.py      # user-related routes
+│ │ ├── user.py             # user-related routes
 │ │ └── ...                 # other routes
 │ ├── dependencies/         # dependency-injection helpers
 │ │ ├── __init__.py
```
````diff
@@ -64,129 +64,21 @@ TakwayAI/
 git clone http://43.132.157.186:3000/killua/TakwayPlatform.git
 ```
 
-#### (2) Create a virtual environment
+#### (2) Install dependencies
 
-Create the virtual environment:
-
 ``` shell
-conda create -n takway python=3.9
-conda activate takway
-```
-
-#### (3) Install dependencies
-
-If your local environment can reach the open internet, run these two commands directly:
-
-``` shell
-pip install git+https://github.com/myshell-ai/MeloTTS.git
-python -m unidic download
-```
-
-If it cannot,
-
-first run:
-
-``` shell
-pip install git+https://github.com/myshell-ai/MeloTTS.git
-```
-
-1. Install unidic
-
-Manually download [unidic.zip](https://cotonoha-dic.s3-ap-northeast-1.amazonaws.com/unidic-3.1.0.zip) and rename the file to unidic.zip.
-
-This example uses miniconda; plain conda should work the same way.
-
-Copy unidic.zip into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic
-
-cd into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic
-
-vim download.py
-
-In download_version(), comment out everything except the last line, and change the two arguments of that final download_and_clean() call to arbitrary values, e.g. "hello", "world".
-
-Then, in the definition of download_and_clean(), comment out the download_process() line inside it.
-
-Run `python -m unidic download`
-
-2. Configure huggingface
-
-Run:
-
-```shell
-pip install -U huggingface_hub
-export HF_ENDPOINT=https://hf-mirror.com
-```
-
-It is best to write `export HF_ENDPOINT=https://hf-mirror.com` into ~/.bashrc; otherwise the setting is lost every time the terminal restarts.
-
-3. Download nltk_data
-
-Create an nltk_data folder under /miniconda3/envs/takway/.
-
-Create corpora and taggers folders inside nltk_data.
-
-Manually download [cmudict.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip) and [averaged_perceptron_tagger.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip)
-
-Put cmudict.zip into the corpora folder.
-
-Put averaged_perceptron_tagger.zip into the taggers folder.
-
-4. Install the remaining dependencies
-
-``` shell
-cd TakwayPlatform/
+cd TakwayAI/
 pip install -r requirements.txt
 ```
 
-5. Debugging
+#### (3) Update the configuration
 
-If the exception AttributeError: module 'botocore.exceptions' has no attribute 'HTTPClientError' occurs,
-
-run the following:
-
-``` shell
-pip uninstall botocore
-pip install botocore==1.34.88
-```
-
-#### (4) Install FunASR
-
-The FunASR used by this project contains some modifications on top of the GitHub version of FunASR.
-
-``` shell
-git clone http://43.132.157.186:3000/gaohz/FunASR.git
-cd FunASR/
-pip install -v -e .
-```
-
-#### (5) Update the configuration
-
 1. Install MySQL and create a database named takway.
 2. Install Redis and set its password to takway.
 3. Open development.py under config and update the MySQL and Redis connection strings.
 
-#### (6) Import the VITS model
+#### (4) Import the VITS model
 
 Under utils/tts/, create a vits_model folder.
 
 Download the VITS model from [this link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) into that folder; only config.json and G_953000.pth are needed.
-
-#### (7) Import the OpenVoice model
-
-Under utils/tts/, create an openvoice_model folder.
-
-Download the model file from [this link](https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip) into that folder.
-
-#### (8) Set the configuration environment variable
-
-``` shell
-vim ~/.bashrc
-# append at the end:
-export MODE=development
-```
-
-#### (9) Start the server
-
-``` shell
-cd TakwayPlatform
-python main.py
-```
````
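For readers following the removed offline-install notes: the unidic step patches the installed package in place. A minimal sketch of what site-packages/unidic/download.py might look like after the described edits; the real file contains more code, and the function bodies here are hypothetical placeholders grounded only in the README's instructions:

```python
# Hypothetical post-patch shape of unidic's download.py (illustrative only).

def download_and_clean(version, url):
    """Unzip and install the manually copied unidic.zip."""
    # download_process(version, url)  # commented out: skip the network fetch
    # the unzip/cleanup steps that follow in the real file still run


def download_version():
    # every line above the final call is commented out
    download_and_clean("hello", "world")  # placeholder args, never used
```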
```diff
@@ -34,7 +34,7 @@ logger.info("Database initialized")
 
 #--------------------scheduled tasks-----------------------
 scheduler = AsyncIOScheduler()
-# scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
+scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
 @asynccontextmanager
 async def lifespan(app:FastAPI):
     scheduler.start() #start the scheduler
```
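The job enabled above uses a standard five-field crontab expression: `0 4 * * *` fires once per day at 04:00, so the Redis cache refresh now runs nightly. A quick way to sanity-check the schedule with APScheduler, which the code already depends on:

```python
from apscheduler.triggers.cron import CronTrigger

# "0 4 * * *" = minute 0, hour 4, every day of every month
trigger = CronTrigger.from_crontab("0 4 * * *")
print(trigger)  # shows the parsed cron fields
```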
```diff
@@ -1,38 +1,35 @@
 from ..schemas.chat_schema import *
 from ..dependencies.logger import get_logger
-from ..dependencies.summarizer import get_summarizer
-from ..dependencies.asr import get_asr
-from ..dependencies.tts import get_tts
 from .controller_enum import *
-from ..models import UserCharacter, Session, Character, User, Audio
+from ..models import UserCharacter, Session, Character, User
 from utils.audio_utils import VAD
 from fastapi import WebSocket, HTTPException, status
 from datetime import datetime
-from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
+from utils.xf_asr_utils import generate_xf_asr_url
 from config import get_config
-import numpy as np
-import websockets
-import struct
 import uuid
 import json
 import asyncio
-import aiohttp
-import io
+import requests
 
 # get the logger via dependency injection
 logger = get_logger()
 
-# get the context-summarization service via dependency injection
-summarizer = get_summarizer()
-
-# -----------------------get the ASR-------------------------
-asr = get_asr()
+# --------------------initialize the local ASR-----------------------
+from utils.stt.funasr_utils import FunAutoSpeechRecognizer
+
+asr = FunAutoSpeechRecognizer()
+logger.info("Local ASR initialized")
 # -------------------------------------------------------
 
-# -------------------------TTS--------------------------
-tts = get_tts()
+# --------------------initialize the local VITS----------------------
+from utils.tts.vits_utils import TextToSpeech
+
+tts = TextToSpeech(device='cpu')
+logger.info("Local TTS initialized")
 # -------------------------------------------------------
 
 
 # get the Config via dependency injection
 Config = get_config()
 
```
```diff
@@ -53,20 +50,16 @@ def get_session_content(session_id,redis,db):
 def parseChunkDelta(chunk):
     try:
         if chunk == b"":
-            return 1,""
+            return ""
         decoded_data = chunk.decode('utf-8')
         parsed_data = json.loads(decoded_data[6:])
         if 'delta' in parsed_data['choices'][0]:
             delta_content = parsed_data['choices'][0]['delta']
-            return -1, delta_content['content']
+            return delta_content['content']
         else:
-            return parsed_data['usage']['total_tokens'] , ""
+            return "end"
     except KeyError:
-        logger.error(f"error chunk: {decoded_data}")
-        return 1,""
-    except json.JSONDecodeError:
-        logger.error(f"error chunk: {decoded_data}")
-        return 1,""
+        logger.error(f"error chunk: {chunk}")
 
 #sentence-splitting helper
 def split_string_with_punctuation(current_sentence,text,is_first,is_end):
```
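Both versions of parseChunkDelta assume MiniMax-style server-sent-event chunks, i.e. a `data: ` prefix followed by JSON, which is what `decoded_data[6:]` strips. A sketch of the old tuple contract that the handlers below key off; the exact chunk layout is inferred from the parsing code, not spelled out in this diff:

```python
import json

def parse_chunk_delta(chunk: bytes) -> tuple:
    """Old-style parser: returns (token_count, text)."""
    if chunk == b"":
        return 1, ""  # callers treat token_count > 0 as end-of-stream
    payload = json.loads(chunk.decode("utf-8")[6:])  # strip the "data: " prefix
    choice = payload["choices"][0]
    if "delta" in choice:
        return -1, choice["delta"]["content"]  # -1 signals "more chunks coming"
    return payload["usage"]["total_tokens"], ""  # final chunk carries usage stats

token_count, text = parse_chunk_delta(b'data: {"choices":[{"delta":{"content":"hi"}}]}')
assert (token_count, text) == (-1, "hi")
```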
```diff
@@ -104,19 +97,6 @@ def update_session_activity(session_id,db):
     except Exception as e:
         db.roolback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-
-#fetch the target_se
-def get_emb(session_id,db):
-    try:
-        session_record = db.query(Session).filter(Session.id == session_id).first()
-        user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
-        user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
-        audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
-        emb_npy = np.load(io.BytesIO(audio_record.emb_data))
-        return emb_npy
-    except Exception as e:
-        logger.debug("audio not found: "+str(e))
-        return np.array([])
 #--------------------------------------------------------
 
 # create a new chat
```
```diff
@@ -126,6 +106,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
     try:
         db.add(new_chat)
         db.commit()
+        db.refresh(new_chat)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
```
```diff
@@ -151,23 +132,18 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
         "speaker_id":db_character.voice_id,
         "noise_scale": 0.1,
         "noise_scale_w":0.668,
-        "length_scale": 1.2,
-        "speed":1
+        "length_scale": 1.2
     }
     llm_info = {
         "model": "abab5.5-chat",
         "temperature": 1,
         "top_p": 0.9,
     }
-    user_info = {
-        "character":"",
-        "events":[]
-    }
 
     # serialize the tts and llm info to JSON strings
     tts_info_str = json.dumps(tts_info, ensure_ascii=False)
     llm_info_str = json.dumps(llm_info, ensure_ascii=False)
-    user_info_str = json.dumps(user_info, ensure_ascii=False)
+    user_info_str = db_user.persona
 
     token = 0
     content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
```
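After this change, user_info carries db_user.persona directly rather than a freshly built {'character', 'events'} object. A sketch of the session blob written to Redis; the content dict is cut off in this hunk, llm_info is inferred from the reads elsewhere in the file, and all values below are illustrative:

```python
import json

content = {
    "user_id": 1,  # illustrative
    "messages": json.dumps([{"role": "system", "content": "..."}], ensure_ascii=False),
    "user_info": "persona text, straight from db_user.persona",
    "tts_info": json.dumps({"speaker_id": 0, "noise_scale": 0.1,
                            "noise_scale_w": 0.668, "length_scale": 1.2}),
    "llm_info": json.dumps({"model": "abab5.5-chat", "temperature": 1, "top_p": 0.9}),
}
# redis.set(session_id, json.dumps(content, ensure_ascii=False))
```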
```diff
@@ -178,6 +154,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
     # persist the Session record
     db.add(new_session)
     db.commit()
+    db.refresh(new_session)
     redis.set(session_id, json.dumps(content, ensure_ascii=False))
 
     chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
```
```diff
@@ -246,75 +223,26 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
         logger.error(f"Error in the user-input handler: {str(e)}")
 
 #speech recognition
-async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
+async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
     logger.debug("ASR handler started")
-    if Config.STRAM_CHAT.ASR == "LOCAL":
-        is_signup = False
-        audio = ""
-        try:
-            current_message = ""
-            while not (user_input_finish_event.is_set() and user_input_q.empty()):
-                if not is_signup:
-                    asr.session_signup(session_id)
-                    is_signup = True
-                audio_data = await user_input_q.get()
-                audio += audio_data
-                asr_result = asr.streaming_recognize(session_id,audio_data)
-                current_message += ''.join(asr_result['text'])
-            asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
-            current_message += ''.join(asr_result['text'])
-            slice_arr = ["嗯",""]
-            if current_message in slice_arr:
-                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
-                return
-            current_message = asr.punctuation_correction(current_message)
-            # emotion_dict = asr.emtion_recognition(audio) #emotion recognition
-            # if not isinstance(emotion_dict, str):
-            #     max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
-            #     current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
-            await llm_input_q.put(current_message)
-            asr.session_signout(session_id)
-        except Exception as e:
-            asr.session_signout(session_id)
-            logger.error(f"Error in the ASR handler: {str(e)}")
-        logger.debug(f"Received user message: {current_message}")
-    elif Config.STRAM_CHAT.ASR == "XF":
-        status = FIRST_FRAME
-        xf_websocket = await xf_asr_websocket_factory() #open a websocket connection to the iFlytek ASR service
-        segment_duration_threshold = 25 #the iFlytek endpoint drops connections after 30s, so cap each segment at 25s
-        segment_start_time = asyncio.get_event_loop().time()
+    is_signup = False
+    try:
         current_message = ""
         while not (user_input_finish_event.is_set() and user_input_q.empty()):
-            try:
-                audio_data = await user_input_q.get()
-                current_time = asyncio.get_event_loop().time()
-                if current_time - segment_start_time > segment_duration_threshold:
-                    await xf_websocket.send(make_last_frame())
-                    current_message += parse_xfasr_recv(await xf_websocket.recv())
-                    await xf_websocket.close()
-                    xf_websocket = await xf_asr_websocket_factory() #rebuild the websocket connection
-                    status = FIRST_FRAME
-                    segment_start_time = current_time
-                if status == FIRST_FRAME:
-                    await xf_websocket.send(make_first_frame(audio_data))
-                    status = CONTINUE_FRAME
-                elif status == CONTINUE_FRAME:
-                    await xf_websocket.send(make_continue_frame(audio_data))
-            except websockets.exceptions.ConnectionClosedOK:
-                logger.debug("iFlytek ASR connection closed, reconnecting")
-                xf_websocket = await xf_asr_websocket_factory() #rebuild the websocket connection
-                status = FIRST_FRAME
-                segment_start_time = asyncio.get_event_loop().time()
-        await xf_websocket.send(make_last_frame())
-        current_message += parse_xfasr_recv(await xf_websocket.recv())
-        await xf_websocket.close()
-        if current_message in ["嗯", ""]:
-            await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
-            return
+            if not is_signup:
+                asr.session_signup(session_id)
+                is_signup = True
+            audio_data = await user_input_q.get()
+            asr_result = asr.streaming_recognize(session_id,audio_data)
+            current_message += ''.join(asr_result['text'])
+        asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
+        current_message += ''.join(asr_result['text'])
         await llm_input_q.put(current_message)
-        logger.debug(f"Received user message: {current_message}")
+        asr.session_signout(session_id)
+    except Exception as e:
+        asr.session_signout(session_id)
+        logger.error(f"Error in the ASR handler: {str(e)}")
+    logger.debug(f"Received user message: {current_message}")
 
 #LLM call
 async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
```
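The removed XF branch works around the iFlytek streaming-ASR service closing websockets after roughly 30 seconds: it flushes the current segment and rebuilds the connection every 25 seconds. A condensed sketch of that rotation pattern; the helper names mirror the removed imports, but their exact signatures are assumptions:

```python
import asyncio

SEGMENT_LIMIT = 25  # rotate before the server's ~30 s cutoff

async def stream_with_rotation(ws_factory, make_first, make_cont, make_last,
                               parse_recv, frames):
    """frames: async iterator of audio chunks; returns the full transcript."""
    ws, started, first, text = await ws_factory(), asyncio.get_event_loop().time(), True, ""
    async for audio in frames:
        now = asyncio.get_event_loop().time()
        if now - started > SEGMENT_LIMIT:      # flush and reopen the segment
            await ws.send(make_last())
            text += parse_recv(await ws.recv())
            await ws.close()
            ws, started, first = await ws_factory(), now, True
        await ws.send(make_first(audio) if first else make_cont(audio))
        first = False
    await ws.send(make_last())                 # close out the final segment
    text += parse_recv(await ws.recv())
    await ws.close()
    return text
```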
```diff
@@ -325,18 +253,13 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         is_first = True
         is_end = False
         session_content = get_session_content(session_id,redis,db)
-        user_info = json.loads(session_content["user_info"])
         messages = json.loads(session_content["messages"])
         current_message = await llm_input_q.get()
-        if current_message == "":
-            return
         messages.append({'role': 'user', "content": current_message})
-        messages_send = messages #copy of messages; the user info would be prepended to its last entry
-        # messages_send[-1]['content'] = f"User character: {user_info['character']}\nEvent summary: {user_info['events']}" + messages_send[-1]['content']
         payload = json.dumps({
             "model": llm_info["model"],
             "stream": True,
-            "messages": messages_send,
+            "messages": messages,
             "max_tokens": 10000,
             "temperature": llm_info["temperature"],
             "top_p": llm_info["top_p"]
```
```diff
@@ -345,48 +268,35 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
             'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
             'Content-Type': 'application/json'
         }
-        target_se = get_emb(session_id,db)
+        response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True) #call the LLM
     except Exception as e:
-        logger.error(f"Error while building the http request: {str(e)}")
+        logger.error(f"Error calling the llm: {str(e)}")
     try:
-        async with aiohttp.ClientSession() as client:
-            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #send the LLM request
-                async for chunk in response.content.iter_any():
-                    token_count, chunk_data = parseChunkDelta(chunk)
-                    is_end = token_count >0
-                    if not is_end:
-                        llm_response += chunk_data
-                    sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #sentence segmentation
-                    for sentence in sentences:
-                        if response_type == RESPONSE_TEXT:
-                            response_message = {"type": "text", "code":200, "msg": sentence}
-                            await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #send back the text
-                        elif response_type == RESPONSE_AUDIO:
-                            audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
-                            response_message = {"type": "text", "code":200, "msg": sentence}
-                            response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
-                            header = struct.pack('!II',len(response_bytes),len(audio))
-                            message_bytes = header + response_bytes + audio
-                            await ws.send_bytes(message_bytes)
-                        logger.debug(f"websocket sent: {sentence}")
-                    if is_end:
-                        logger.debug(f"llm response: {llm_response}")
-                        await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
-                        is_end = False #reset the is_end flag
-                        messages.append({'role': 'assistant', "content": llm_response})
-                        session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
-                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
-                        is_first = True
-                        llm_response = ""
-                        if token_count > summarizer.max_token * 0.7: #summarize when the llm token count exceeds 70% of the max
-                            system_prompt = messages[0]['content']
-                            summary = await summarizer.summarize(messages)
-                            events = user_info['events']
-                            events.append(summary['event'])
-                            session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
-                            session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
-                            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
-                            logger.debug(f"session_content after summarization: {session_content}")
+        for chunk in response.iter_lines():
+            chunk_data = parseChunkDelta(chunk)
+            is_end = chunk_data == "end"
+            if not is_end:
+                llm_response += chunk_data
+            sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #sentence segmentation
+            for sentence in sentences:
+                if response_type == RESPONSE_TEXT:
+                    response_message = {"type": "text", "code":200, "msg": sentence}
+                    await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #send back the text
+                elif response_type == RESPONSE_AUDIO:
+                    sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
+                    response_message = {"type": "text", "code":200, "msg": sentence}
+                    await ws.send_bytes(audio) #send back the audio
+                    await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #send back the text
+                logger.debug(f"websocket sent: {sentence}")
+            if is_end:
+                logger.debug(f"llm response: {llm_response}")
+                await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
+                is_end = False #reset the is_end flag
+                messages.append({'role': 'assistant', "content": llm_response})
+                session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
+                redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
+                is_first = True
+                llm_response = ""
     except Exception as e:
         logger.error(f"Error while processing the llm response: {str(e)}")
     chat_finished_event.set()
```
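The removed RESPONSE_AUDIO path framed each websocket message as an 8-byte header of two big-endian uint32 lengths, followed by the JSON metadata and the raw audio. The client side is not part of this diff, but the matching unpacking step follows directly from the struct.pack('!II', ...) call:

```python
import json
import struct

def unpack_frame(message: bytes):
    json_len, audio_len = struct.unpack('!II', message[:8])   # 8-byte header
    meta = json.loads(message[8:8 + json_len].decode('utf-8'))
    audio = message[8 + json_len:8 + json_len + audio_len]
    return meta, audio

# round-trip against the server-side packing used in the removed code
payload = json.dumps({"type": "text", "code": 200, "msg": "hi"}).encode('utf-8')
audio = b'\x00\x01'
frame = struct.pack('!II', len(payload), len(audio)) + payload + audio
meta, out = unpack_frame(frame)
assert meta["msg"] == "hi" and out == audio
```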
```diff
@@ -405,7 +315,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
     session_id = await future_session_id #wait for the session_id
     update_session_activity(session_id,db)
     response_type = await future_response_type #wait for the response type
-    asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event))
+    asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
     tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
     llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
 
```
```diff
@@ -462,7 +372,6 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
     logger.debug("ASR handler started")
     is_signup = False
     current_message = ""
-    audio = ""
     while not (input_finished_event.is_set() and user_input_q.empty()):
         try:
             aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
```
```diff
@@ -472,23 +381,15 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
             if aduio_frame['is_end']:
                 asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
                 current_message += ''.join(asr_result['text'])
-                current_message = asr.punctuation_correction(current_message)
-                audio += aduio_frame['audio']
-                emotion_dict =asr.emtion_recognition(audio) #emotion recognition
-                if not isinstance(emotion_dict, str):
-                    max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
-                    current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
                 await llm_input_q.put(current_message)
                 logger.debug(f"Received user message: {current_message}")
-                current_message = ""
-                audio = ""
             else:
                 asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
-                audio += aduio_frame['audio']
                 current_message += ''.join(asr_result['text'])
         except asyncio.TimeoutError:
             continue
         except Exception as e:
+            asr.session_signout(session_id)
             logger.error(f"Error in the ASR handler: {str(e)}")
             break
     asr.session_signout(session_id)
```
```diff
@@ -505,7 +406,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
         try:
             session_content = get_session_content(session_id,redis,db)
             messages = json.loads(session_content["messages"])
-            user_info = json.loads(session_content["user_info"])
             current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
             messages.append({'role': 'user', "content": current_message})
             payload = json.dumps({
```
```diff
@@ -520,49 +420,39 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                 'Content-Type': 'application/json'
             }
-            target_se = get_emb(session_id,db)
-            async with aiohttp.ClientSession() as client:
-                async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #send the LLM request
-                    async for chunk in response.content.iter_any():
-                        token_count, chunk_data = parseChunkDelta(chunk)
-                        is_end = token_count >0
-                        if not is_end:
-                            llm_response += chunk_data
-                        sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
-                        for sentence in sentences:
-                            if response_type == RESPONSE_TEXT:
-                                logger.debug(f"websocket sent: {sentence}")
-                                response_message = {"type": "text", "code":200, "msg": sentence}
-                                await ws.send_text(json.dumps(response_message, ensure_ascii=False))
-                            elif response_type == RESPONSE_AUDIO:
-                                audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
-                                response_message = {"type": "text", "code":200, "msg": sentence}
-                                await ws.send_bytes(audio)
-                                await ws.send_text(json.dumps(response_message, ensure_ascii=False))
-                                logger.debug(f"websocket sent: {sentence}")
-                        if is_end:
-                            logger.debug(f"llm response: {llm_response}")
-                            await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
-                            is_end = False
-                            messages.append({'role': 'assistant', "content": llm_response})
-                            session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
-                            redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
-                            is_first = True
-                            llm_response = ""
-                            if token_count > summarizer.max_token * 0.7: #summarize when the llm token count exceeds 70% of the max
-                                system_prompt = messages[0]['content']
-                                summary = await summarizer.summarize(messages)
-                                events = user_info['events']
-                                events.append(summary['event'])
-                                session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
-                                session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
-                                redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
-                                logger.debug(f"session_content after summarization: {session_content}")
+            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
+            for chunk in response.iter_lines():
+                chunk_data = parseChunkDelta(chunk)
+                is_end = chunk_data == "end"
+                if not is_end:
+                    llm_response += chunk_data
+                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
+                for sentence in sentences:
+                    if response_type == RESPONSE_TEXT:
+                        logger.debug(f"websocket sent: {sentence}")
+                        response_message = {"type": "text", "code":200, "msg": sentence}
+                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))
+                    elif response_type == RESPONSE_AUDIO:
+                        sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
+                        response_message = {"type": "text", "code":200, "msg": sentence}
+                        await ws.send_bytes(audio)
+                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))
+                        logger.debug(f"websocket sent: {sentence}")
+                if is_end:
+                    logger.debug(f"llm response: {llm_response}")
+                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
+                    is_end = False
+
+                    messages.append({'role': 'assistant', "content": llm_response})
+                    session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
+                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
+                    is_first = True
+                    llm_response = ""
         except asyncio.TimeoutError:
             continue
-        # except Exception as e:
-        #     logger.error(f"Error while processing the llm response: {str(e)}")
-        #     break
+        except Exception as e:
+            logger.error(f"Error while processing the llm response: {str(e)}")
+            break
     chat_finished_event.set()
 
 async def streaming_chat_lasting_handler(ws,db,redis):
```
```diff
@@ -633,7 +523,6 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
     current_message = ""
     vad_count = 0
     is_signup = False
-    audio = ""
     while not (input_finished_event.is_set() and audio_q.empty()):
         try:
             if not is_signup:
```
```diff
@@ -644,22 +533,14 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
             if vad_count > 0:
                 vad_count -= 1
                 asr_result = asr.streaming_recognize(session_id, audio_data)
-                audio += audio_data
                 current_message += ''.join(asr_result['text'])
             else:
                 vad_count += 1
                 if vad_count >= 25: #25 consecutive frames without speech: treat the utterance as finished
                     asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                     if current_message:
-                        current_message = asr.punctuation_correction(current_message)
-                        audio += audio_data
-                        emotion_dict =asr.emtion_recognition(audio) #emotion recognition
-                        if not isinstance(emotion_dict, str):
-                            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
-                            current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
                         logger.debug(f"Silence detected, user input: {current_message}")
                         await asr_result_q.put(current_message)
-                        audio = ""
                         text_response = {"type": "user_text", "code": 200, "msg": current_message}
                         await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #send back the text
                     current_message = ""
```
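Both sides of this hunk share the same end-pointing heuristic: a counter that drains while frames carry speech and fills while they do not, with 25 consecutive silent frames closing the utterance. A standalone restatement; the speech/no-speech branch feeding vad_count sits outside the hunk, so the exact framing here is an assumption:

```python
SILENCE_FRAMES = 25  # consecutive non-speech frames that end the utterance

def end_of_utterance(speech_flags):
    """speech_flags: one boolean VAD decision per audio frame."""
    vad_count = 0
    for has_speech in speech_flags:
        if has_speech:
            vad_count = max(0, vad_count - 1)
        else:
            vad_count += 1
            if vad_count >= SILENCE_FRAMES:
                return True
    return False

assert end_of_utterance([True] * 3 + [False] * 25)
assert not end_of_utterance([True, False] * 40)
```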
```diff
@@ -684,7 +565,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
         try:
             session_content = get_session_content(session_id,redis,db)
             messages = json.loads(session_content["messages"])
-            user_info = json.loads(session_content["user_info"])
             current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
             messages.append({'role': 'user', "content": current_message})
             payload = json.dumps({
```
```diff
@@ -695,43 +575,34 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
                 "temperature": llm_info["temperature"],
                 "top_p": llm_info["top_p"]
             })
 
             headers = {
                 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                 'Content-Type': 'application/json'
             }
-            target_se = get_emb(session_id,db)
-            async with aiohttp.ClientSession() as client:
-                async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #send the LLM request
-                    async for chunk in response.content.iter_any():
-                        token_count, chunk_data = parseChunkDelta(chunk)
-                        is_end = token_count >0
-                        if not is_end:
-                            llm_response += chunk_data
-                        sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
-                        for sentence in sentences:
-                            audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
-                            text_response = {"type": "llm_text", "code": 200, "msg": sentence}
-                            await ws.send_bytes(audio) #send back the raw audio stream
-                            await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #send back the text
-                            logger.debug(f"websocket sent: {sentence}")
-                        if is_end:
-                            logger.debug(f"llm response: {llm_response}")
-                            await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
-                            is_end = False
-                            messages.append({'role': 'assistant', "content": llm_response})
-                            session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
-                            redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
-                            is_first = True
-                            llm_response = ""
-                            if token_count > summarizer.max_token * 0.7: #summarize when the llm token count exceeds 70% of the max
-                                system_prompt = messages[0]['content']
-                                summary = await summarizer.summarize(messages)
-                                events = user_info['events']
-                                events.append(summary['event'])
-                                session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
-                                session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
-                                redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
-                                logger.debug(f"session_content after summarization: {session_content}")
+            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
+            for chunk in response.iter_lines():
+                chunk_data = parseChunkDelta(chunk)
+                is_end = chunk_data == "end"
+                if not is_end:
+                    llm_response += chunk_data
+                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
+                for sentence in sentences:
+                    sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
+                    text_response = {"type": "llm_text", "code": 200, "msg": sentence}
+                    await ws.send_bytes(audio) #send back the raw audio stream
+                    await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #send back the text
+                    logger.debug(f"llm response: {sentence}")
+                if is_end:
+                    logger.debug(f"llm response: {llm_response}")
+                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
+                    is_end = False
+
+                    messages.append({'role': 'assistant', "content": llm_response})
+                    session_content["messages"] = json.dumps(messages,ensure_ascii=False) #update the conversation
+                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #update the session
+                    is_first = True
+                    llm_response = ""
         except asyncio.TimeoutError:
             continue
         except Exception as e:
```
```diff
@@ -739,6 +610,23 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
             break
     voice_call_end_event.set()
 
+
+#speech synthesis and reply
+async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
+    logger.debug("TTS handler started")
+    while not (split_finished_event.is_set() and split_result_q.empty()):
+        try:
+            sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
+            sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
+            text_response = {"type": "llm_text", "code": 200, "msg": sentence}
+            await ws.send_bytes(audio) #send back the raw audio stream
+            await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #send back the text
+            logger.debug(f"websocket sent: {sentence}")
+        except asyncio.TimeoutError:
+            continue
+    voice_call_end_event.set()
+
 
 async def voice_call_handler(ws, db, redis):
     logger.debug("voice_call websocket connection established")
     audio_q = asyncio.Queue() #audio queue
```
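The added handler polls its queue with a 3-second timeout so the loop condition can re-check split_finished_event even when no sentences arrive; a bare await on get() would block forever once the producer stops. The shutdown pattern in isolation:

```python
import asyncio

async def drain(q: asyncio.Queue, finished: asyncio.Event):
    # wake up periodically so the loop can observe the finished event
    while not (finished.is_set() and q.empty()):
        try:
            item = await asyncio.wait_for(q.get(), timeout=3)
        except asyncio.TimeoutError:
            continue
        print("processed:", item)

async def main():
    q, done = asyncio.Queue(), asyncio.Event()
    await q.put("sentence one")
    done.set()           # producer is finished; drain exits once the queue empties
    await drain(q, done)

asyncio.run(main())
```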
```diff
@@ -77,32 +77,3 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
     session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
     return SessionUpdateResponse(status="success",message="Session updated successfully",data=session_update_data)
-
-#update the speaker id stored in a Session
-async def update_session_speaker_id_handler(session_id, session_data, db, redis):
-    existing_session = ""
-    if redis.exists(session_id):
-        existing_session = redis.get(session_id)
-    else:
-        existing_session = db.query(Session).filter(Session.id == session_id).first().content
-    if not existing_session:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
-
-    #update the Session fields
-    session = json.loads(existing_session)
-    session_llm_info = json.loads(session["llm_info"])
-    session_llm_info["speaker_id"] = session_data.speaker_id
-    session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
-
-    #store the Session
-
-    session_str = json.dumps(session,ensure_ascii=False)
-    redis.set(session_id, session_str)
-    try:
-        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
-        db.commit()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
-    return SessionSpeakerUpdateResponse(status="success",message="Session speaker_id updated successfully",data=session_update_data)
```
```diff
@@ -1,20 +1,14 @@
 from ..schemas.user_schema import *
 from ..dependencies.logger import get_logger
-from ..dependencies.tts import get_tts
-from ..models import User, Hardware, Audio
+from ..models import User, Hardware
 from datetime import datetime
 from sqlalchemy.orm import Session
 from fastapi import HTTPException, status
-from pydub import AudioSegment
-import numpy as np
-import io
 
 
 #get the logger via dependency injection
 logger = get_logger()
 
-#get the tts via dependency injection
-tts = get_tts("OPENVOICE")
-
 #create a user
 async def create_user_handler(user:UserCrateRequest, db: Session):
```
```diff
@@ -42,6 +36,7 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
     existing_user.persona = user.persona
     try:
         db.commit()
+        db.refresh(existing_user)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
```
```diff
@@ -122,6 +117,7 @@ async def change_bind_hardware_handler(hardware_id, user, db):
             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="hardware not found")
         existing_hardware.user_id = user.user_id
         db.commit()
+        db.refresh(existing_hardware)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
```
```diff
@@ -139,6 +135,7 @@ async def update_hardware_handler(hardware_id, hardware, db):
         existing_hardware.firmware = hardware.firmware
         existing_hardware.model = hardware.model
         db.commit()
+        db.refresh(existing_hardware)
     except Exception as e:
         db.rollback()
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
```
```diff
@@ -157,96 +154,3 @@ async def get_hardware_handler(hardware_id, db):
         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="hardware not found")
     hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
     return HardwareQueryResponse(status="success", message="Hardware info queried successfully", data=hardware_query_data)
-
-
-#user uploads an audio clip
-async def upload_audio_handler(user_id, audio, db):
-    try:
-        audio_data = await audio.read()
-        emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
-        out = io.BytesIO()
-        np.save(out, emb_data)
-        out.seek(0)
-        emb_binary = out.read()
-        new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #create the audio record
-        db.add(new_audio)
-        db.commit()
-        db.refresh(new_audio)
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
-    return AudioUploadResponse(status="success", message="User audio uploaded successfully", data=audio_upload_data)
-
-
-#user updates an audio clip
-async def update_audio_handler(audio_id, audio_file, db):
-    try:
-        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
-        if existing_audio is None:
-            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="audio not found")
-        audio_data = await audio_file.read()
-        raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
-        emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
-        existing_audio.audio_data = audio_data
-        existing_audio.emb_data = emb_data
-        db.commit()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
-    return AudioUpdateResponse(status="success", message="User audio updated successfully", data=audio_update_data)
-
-
-#user queries an audio clip
-async def download_audio_handler(audio_id, db):
-    try:
-        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    if existing_audio is None:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="audio not found")
-    audio_data = existing_audio.audio_data
-    return audio_data
-
-
-#user deletes an audio clip
-async def delete_audio_handler(audio_id, db):
-    try:
-        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
-        existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    if existing_audio is None:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="audio not found")
-    try:
-        if existing_user.selected_audio_id == audio_id:
-            existing_user.selected_audio_id = None
-        db.delete(existing_audio)
-        db.commit()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
-    return AudioDeleteResponse(status="success", message="User audio deleted successfully", data=audio_delete_data)
-
-
-#user binds an audio clip
-async def bind_audio_handler(bind_req, db):
-    try:
-        existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    if existing_user is None:
-        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="user not found")
-    try:
-        existing_user.selected_audio_id = bind_req.audio_id
-        db.commit()
-    except Exception as e:
-        db.rollback()
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
-    audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
-    return AudioBindResponse(status="success", message="User audio bound successfully", data=audio_bind_data)
```
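The removed upload handler and the removed get_emb() in the chat controller form a pair: np.save into a BytesIO yields the .npy bytes stored in Audio.emb_data, and np.load on a BytesIO restores the array. The round-trip in isolation, with a random array standing in for tts.audio2emb(...):

```python
import io
import numpy as np

emb = np.random.rand(256).astype(np.float32)  # stand-in for tts.audio2emb(...)

buf = io.BytesIO()
np.save(buf, emb)             # .npy bytes, as stored in the emb_data column
emb_binary = buf.getvalue()

restored = np.load(io.BytesIO(emb_binary))
assert np.array_equal(emb, restored)
```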
```diff
@@ -1,11 +0,0 @@
-from utils.stt.modified_funasr import ModifiedRecognizer
-from app.dependencies.logger import get_logger
-
-logger = get_logger()
-
-#initialize the global asr object
-asr = ModifiedRecognizer()
-logger.info("ASR initialized")
-
-def get_asr():
-    return asr
```
```diff
@@ -1,61 +0,0 @@
-import aiohttp
-import json
-from datetime import datetime
-from config import get_config
-
-
-# get the Config via dependency injection
-Config = get_config()
-
-class Summarizer:
-    def __init__(self):
-        self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json,里面有两个字段,一个是event,一个是character,将你总结出的事件写入event,将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
-        self.model = "abab5.5-chat"
-        self.max_token = 10000
-        self.temperature = 0.9
-        self.top_p = 1
-
-    async def summarize(self,messages):
-        context = ""
-        for message in messages:
-            if message['role'] == 'user':
-                context += "用户:"+ message['content'] + '\n'
-            elif message['role'] == 'assistant':
-                context += '玩具:'+ message['content'] + '\n'
-        payload = json.dumps({
-            "model":self.model,
-            "messages":[
-                {
-                    "role":"system",
-                    "content":self.system_prompt
-                },
-                {
-                    "role":"user",
-                    "content":context
-                }
-            ],
-            "max_tokens":self.max_token,
-            "top_p":self.top_p
-        })
-        headers = {
-            'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
-            'Content-Type': 'application/json'
-        }
-        async with aiohttp.ClientSession() as session:
-            async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
-                content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
-        try:
-            summary = {
-                'event':datetime.now().strftime("%Y-%m-%d")+":"+content['event'],
-                'character':content['character']
-            }
-        except TypeError:
-            summary = {
-                'event':datetime.now().strftime("%Y-%m-%d")+":"+content['event'][0],
-                'character':""
-            }
-        return summary
-
-def get_summarizer():
-    summarizer = Summarizer()
-    return summarizer
```
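One detail of the deleted summarize() worth noting is the double json.loads: the outer call parses the MiniMax response envelope, the inner one parses the JSON object the system prompt instructs the model to emit. In isolation, with the envelope trimmed to the fields actually read:

```python
import json

api_text = '{"choices":[{"message":{"content":"{\\"event\\":\\"x\\",\\"character\\":\\"y\\"}"}}]}'
content = json.loads(json.loads(api_text)['choices'][0]['message']['content'])
assert content == {"event": "x", "character": "y"}
```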
```diff
@@ -1,23 +0,0 @@
-
-from app.dependencies.logger import get_logger
-from config import get_config
-
-logger = get_logger()
-Config = get_config()
-
-from utils.tts.openvoice_utils import TextToSpeech
-openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
-logger.info("TTS_OPENVOICE initialized")
-
-from utils.tts.vits_utils import TextToSpeech
-vits_tts = TextToSpeech()
-logger.info("TTS_VITS initialized")
-
-
-
-#return the global tts object
-def get_tts(tts_type=Config.TTS_UTILS):
-    if tts_type == "OPENVOICE":
-        return openvoice_tts
-    elif tts_type == "VITS":
-        return vits_tts
```
@ -0,0 +1,26 @@
from app import app, Config
import uvicorn


if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
    # uvicorn.run("app.main:app", host=Config.UVICORN.HOST, port=Config.UVICORN.PORT, workers=Config.UVICORN.WORKERS)

#                       _ooOoo_                       #
#                      o8888888o                      #
#                      88" . "88                      #
#                      (| -_- |)                      #
#                      O\  =  /O                      #
#                   ____/`---'\____                   #
#                 .'  \\|     |//  `.                 #
#                /  \\|||  :  |||//  \                #
#               /  _||||| -:- |||||-  \               #
#               |   | \\\  -  /// |   |               #
#               \  .-\__  `-`  ___/-. /               #
#             ___`. .'  /--.--\  `. . __              #
#          ."" '<  `.___\_<|>_/___.'  >'"".           #
#         | | :  `- \`.;`\ _ /`;.`/ - ` : | |         #
#          \  \ `-.   \_ __\ /__ _/   .-` /  /        #
#     ======`-.____`-.___\_____/___.-`____.-'======   #
#                      `=---='                        #
#     .............................................   #
#        May the Buddha bless us: bug-free forever    #
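One detail worth noting about the commented-out variant: uvicorn can only spawn multiple workers when the application is passed as an import string rather than an object, which is why that line names "app.main:app". A sketch using the values from config.py:

```python
import uvicorn

# Multi-worker serving needs the import-string form; passing the app
# object directly makes uvicorn refuse the workers setting.
uvicorn.run("app.main:app", host="0.0.0.0", port=7878, workers=12)
```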
@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary
+from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
 from sqlalchemy.ext.declarative import declarative_base

 Base = declarative_base()

@ -36,7 +36,6 @@ class User(Base):
     avatar_id = Column(String(36), nullable=True)
     tags = Column(JSON)
     persona = Column(JSON)
-    selected_audio_id = Column(Integer, nullable=True)

     def __repr__(self):
         return f"<User(id={self.id}, tags={self.tags})>"

@ -81,12 +80,3 @@ class Session(Base):
     def __repr__(self):
         return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"

-# Audio table definition
-class Audio(Base):
-    __tablename__ = 'audio'
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    user_id = Column(Integer, ForeignKey('user.id'))
-    audio_data = Column(LargeBinary(16777215))
-    emb_data = Column(LargeBinary(65535))
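For orientation, a minimal sketch of materializing these models against the configured database (the URI is the one from config.py in this same commit; `Base` and the model classes are the ones above):

```python
from sqlalchemy import create_engine

# Create every table registered on this declarative Base.
engine = create_engine("mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4")
Base.metadata.create_all(engine)
```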
@ -27,10 +27,3 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_redis)):
 async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
     response = await update_session_handler(session_id, session_data, db, redis)
     return response

-# Endpoint for updating a session's TTS speaker settings
-@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
-async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
-    response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
-    return response
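Before its removal, that endpoint was exercised by the unit tests roughly like this (a sketch mirroring the deleted test code further down in this diff; the session_id placeholder is hypothetical):

```python
import json
import requests

session_id = "<session_id>"
response = requests.put(
    f"http://127.0.0.1:7878/sessions/tts_info/speaker_id/{session_id}",
    headers={"Content-Type": "application/json"},
    data=json.dumps({"speaker_id": 37}),
)
assert response.status_code == 200
```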
@ -1,4 +1,4 @@
-from fastapi import APIRouter, UploadFile, File, Response
+from fastapi import APIRouter, HTTPException, status
 from ..controllers.user_controller import *
 from fastapi import Depends
 from sqlalchemy.orm import Session

@ -69,38 +69,3 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
 async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
     response = await get_hardware_handler(hardware_id, db)
     return response

-# Upload user audio
-@router.post('/users/audio', response_model=AudioUploadResponse)
-async def upload_audio(user_id: int, audio_file: UploadFile = File(...), db: Session = Depends(get_db)):
-    response = await upload_audio_handler(user_id, audio_file, db)
-    return response
-
-# Update user audio
-@router.put('/users/audio/{audio_id}', response_model=AudioUpdateResponse)
-async def update_audio(audio_id: int, audio_file: UploadFile = File(...), db: Session = Depends(get_db)):
-    response = await update_audio_handler(audio_id, audio_file, db)
-    return response
-
-# Download user audio
-@router.get('/users/audio/{audio_id}')
-async def download_audio(audio_id: int, db: Session = Depends(get_db)):
-    audio_data = await download_audio_handler(audio_id, db)
-    return Response(content=audio_data, media_type='application/octet-stream', headers={"Content-Disposition": "attachment"})
-
-# Delete user audio
-@router.delete('/users/audio/{audio_id}', response_model=AudioDeleteResponse)
-async def delete_audio(audio_id: int, db: Session = Depends(get_db)):
-    response = await delete_audio_handler(audio_id, db)
-    return response
-
-# Bind an audio clip to a user
-@router.post('/users/audio/bind', response_model=AudioBindResponse)
-async def bind_audio(bind_req: AudioBindRequest, db: Session = Depends(get_db)):
-    response = await bind_audio_handler(bind_req, db)
-    return response
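For the record, the removed upload route was driven from the tests as a multipart POST with user_id in the query string; a sketch mirroring that deleted test code:

```python
import requests

# Upload an mp3 for voice cloning, then read back the new audio_id.
with open("tests/assets/demo_speaker0.mp3", "rb") as audio_file:
    files = {"audio_file": ("demo_speaker0.mp3", audio_file, "audio/mpeg")}
    response = requests.post("http://127.0.0.1:7878/users/audio?user_id=1", files=files)
audio_id = response.json()["data"]["audio_id"]
```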
@ -56,15 +56,3 @@ class SessionUpdateData(BaseModel):
 class SessionUpdateResponse(BaseResponse):
     data: Optional[SessionUpdateData]
 #--------------------------------------------------------------------------

-#------------------------------Session speaker_id update--------------------
-class SessionSpeakerUpdateRequest(BaseModel):
-    speaker_id: int
-
-class SessionSpeakerUpdateData(BaseModel):
-    updatedAt: str
-
-class SessionSpeakerUpdateResponse(BaseResponse):
-    data: Optional[SessionSpeakerUpdateData]
-#--------------------------------------------------------------------------
@ -3,6 +3,7 @@ from typing import Optional
 from .base_schema import BaseResponse


+
 #---------------------------------User creation----------------------------------
 # User creation request class
 class UserCrateRequest(BaseModel):

@ -137,46 +138,3 @@ class HardwareQueryData(BaseModel):
 class HardwareQueryResponse(BaseResponse):
     data: Optional[HardwareQueryData]
 #------------------------------------------------------------------------------

-#-------------------------------User audio upload--------------------------------
-class AudioUploadData(BaseModel):
-    audio_id: int
-    uploadedAt: str
-
-class AudioUploadResponse(BaseResponse):
-    data: Optional[AudioUploadData]
-#-------------------------------------------------------------------------------
-
-#-------------------------------User audio update--------------------------------
-class AudioUpdateData(BaseModel):
-    updatedAt: str
-
-class AudioUpdateResponse(BaseResponse):
-    data: Optional[AudioUpdateData]
-#-------------------------------------------------------------------------------
-
-#-------------------------------User audio delete--------------------------------
-class AudioDeleteData(BaseModel):
-    deletedAt: str
-
-class AudioDeleteResponse(BaseResponse):
-    data: Optional[AudioDeleteData]
-#-------------------------------------------------------------------------------
-
-#-------------------------------User audio bind----------------------------------
-class AudioBindRequest(BaseModel):
-    audio_id: int
-    user_id: int
-
-class AudioBindData(BaseModel):
-    bindedAt: str
-
-class AudioBindResponse(BaseResponse):
-    data: Optional[AudioBindData]
-#-------------------------------------------------------------------------------
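A small sketch of how one of these removed request schemas validated its payload (pydantic v1 style, matching the `BaseModel` usage above; field names come from the removed AudioBindRequest):

```python
req = AudioBindRequest(audio_id=1, user_id=1)
print(req.json())  # {"audio_id": 1, "user_id": 1}
```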
@ -1,28 +1,27 @@
 class DevelopmentConfig:
-    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
+    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
     REDIS_URL = "redis://:takway@127.0.0.1:6379/0"  # Redis connection string
     LOG_LEVEL = "DEBUG"  # log level
-    TTS_UTILS = "VITS"  # TTS engine, either OPENVOICE or VITS
     class UVICORN:
         HOST = "0.0.0.0"  # uvicorn bind address; 0.0.0.0 accepts all interfaces
-        PORT = 8001  # uvicorn port
+        PORT = 7878  # uvicorn port
         WORKERS = 12  # number of uvicorn workers (usually equal to the CPU core count)
     class XF_ASR:
-        APP_ID = "f1c121c1"  # iFLYTEK ASR APP_ID
+        APP_ID = "your_app_id"  # iFLYTEK ASR APP_ID
-        API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0"  # iFLYTEK ASR API_SECRET
+        API_SECRET = "your_api_secret"  # iFLYTEK ASR API_SECRET
-        API_KEY = "36b316c7977fa534ae1e3bf52157bb92"  # iFLYTEK ASR API_KEY
+        API_KEY = "your_api_key"  # iFLYTEK ASR API_KEY
         DOMAIN = "iat"
         LANGUAGE = "zh_cn"
         ACCENT = "mandarin"
         VAD_EOS = 10000
     class MINIMAX_LLM:
-        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
+        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
         URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
     class MINIMAX_TTA:
         API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA",
         URL = "https://api.minimax.chat/v1/t2a_pro",
         GROUP_ID = "1759482180095975904"
     class STRAM_CHAT:
-        ASR = "XF"  # ASR engine, either XF or LOCAL
+        ASR = "LOCAL"
         TTS = "LOCAL"
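This commit already swaps the real iFLYTEK credentials for placeholders; a hedged sketch of going one step further and sourcing the secrets from the environment instead of the repo (the variable names like XF_APP_ID are illustrative, not existing config keys):

```python
import os

class XF_ASR:
    # Hypothetical environment-backed variant; export XF_APP_ID etc. before start-up.
    APP_ID = os.getenv("XF_APP_ID", "your_app_id")
    API_SECRET = os.getenv("XF_API_SECRET", "your_api_secret")
    API_KEY = os.getenv("XF_API_KEY", "your_api_key")
```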
Binary file not shown.
@ -0,0 +1,284 @@
import os
import io
import numpy as np
import pyaudio
import wave
import base64

"""
audio utils for modified_funasr_demo.py
"""


def decode_str2bytes(data):
    # Decode a Base64-encoded string into raw bytes
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))


class BaseAudio:
    def __init__(self,
                 filename=None,
                 input=False,
                 output=False,
                 CHUNK=1024,
                 FORMAT=pyaudio.paInt16,
                 CHANNELS=1,
                 RATE=16000,
                 input_device_index=None,
                 output_device_index=None,
                 **kwargs):
        self.CHUNK = CHUNK
        self.FORMAT = FORMAT
        self.CHANNELS = CHANNELS
        self.RATE = RATE
        self.filename = filename
        assert input != output, "input and output cannot be the same, \
            but got input={} and output={}.".format(input, output)
        print("------------------------------------------")
        print(f"{'Input' if input else 'Output'} Audio Initialization: ")
        print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
        print("------------------------------------------")
        self.p = pyaudio.PyAudio()
        self.stream = self.p.open(format=FORMAT,
                                  channels=CHANNELS,
                                  rate=RATE,
                                  input=input,
                                  output=output,
                                  input_device_index=input_device_index,
                                  output_device_index=output_device_index,
                                  **kwargs)

    def load_audio_file(self, wav_file):
        with wave.open(wav_file, 'rb') as wf:
            params = wf.getparams()
            frames = wf.readframes(params.nframes)
            print("Audio file loaded.")
            # Audio parameters:
            # print("Channels:", params.nchannels)
            # print("Sample width:", params.sampwidth)
            # print("Frame rate:", params.framerate)
            # print("Number of frames:", params.nframes)
            # print("Compression type:", params.comptype)
        return frames

    def check_audio_type(self, audio_data, return_type=None):
        assert return_type in ['bytes', 'io', None], \
            "return_type should be 'bytes', 'io' or None."
        if isinstance(audio_data, str):
            if len(audio_data) > 50:
                audio_data = decode_str2bytes(audio_data)
            else:
                assert os.path.isfile(audio_data), \
                    "audio_data should be a file path or a bytes object."
                wf = wave.open(audio_data, 'rb')
                audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype == np.dtype('float32'):
                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
            audio_data = audio_data.tobytes()
        elif isinstance(audio_data, bytes):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
                but got {type(audio_data)}")

        if return_type is None:
            return audio_data
        return self.write_wave(None, [audio_data], return_type)

    def write_wave(self, filename, frames, return_type='io'):
        """Write audio data to a file."""
        if isinstance(frames, bytes):
            frames = [frames]
        if not isinstance(frames, list):
            raise TypeError("frames should be \
                a list of bytes or a bytes object, \
                but got {}.".format(type(frames)))

        if return_type == 'io':
            if filename is None:
                filename = io.BytesIO()
            if self.filename:
                filename = self.filename
            return self.write_wave_io(filename, frames)
        elif return_type == 'bytes':
            return self.write_wave_bytes(frames)

    def write_wave_io(self, filename, frames):
        """
        Write audio data to a file-like object.

        Args:
            filename: [string or file-like object], file path or file-like object to write
            frames: list of bytes, audio data to write
        """
        wf = wave.open(filename, 'wb')
        # Set WAV file parameters
        wf.setnchannels(self.CHANNELS)
        wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
        wf.setframerate(self.RATE)
        wf.writeframes(b''.join(frames))
        wf.close()
        if isinstance(filename, io.BytesIO):
            filename.seek(0)  # reset file pointer to beginning
        return filename

    def write_wave_bytes(self, frames):
        """Write audio data to a bytes object."""
        return b''.join(frames)


class BaseRecorder(BaseAudio):
    def __init__(self,
                 input=True,
                 base_chunk_size=None,
                 RATE=16000,
                 **kwargs):
        super().__init__(input=input, RATE=RATE, **kwargs)
        self.base_chunk_size = base_chunk_size
        if base_chunk_size is None:
            self.base_chunk_size = self.CHUNK

    def record(self,
               filename,
               duration=5,
               return_type='io',
               logger=None):
        if logger is not None:
            logger.info("Recording started.")
        else:
            print("Recording started.")
        frames = []
        for i in range(0, int(self.RATE / self.CHUNK * duration)):
            data = self.stream.read(self.CHUNK, exception_on_overflow=False)
            frames.append(data)
        if logger is not None:
            logger.info("Recording stopped.")
        else:
            print("Recording stopped.")
        return self.write_wave(filename, frames, return_type)

    def record_chunk_voice(self,
                           return_type='bytes',
                           CHUNK=None,
                           exception_on_overflow=True,
                           queue=None):
        data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK,
                                exception_on_overflow=exception_on_overflow)
        if return_type is not None:
            return self.write_wave(None, [data], return_type)
        return data
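A minimal usage sketch for the recorder above (assumes a working default input device; the parameters mirror the class defaults):

```python
# Record five seconds of 16 kHz mono audio and write it to recording.wav.
rec = BaseRecorder(input=True, RATE=16000)
rec.record("recording.wav", duration=5, return_type='io')
```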
@ -0,0 +1,39 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer


def asr_file_stream(file_path=r'.\assets\example_recording.wav'):
    # Load the audio file
    rec = BaseRecorder()
    data = rec.load_audio_file(file_path)

    # Create the model
    asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
    asr.session_signup("test")

    # Register the target speaker
    asr.initialize_speaker(r".\assets\example_recording.wav")

    # Speech recognition
    print("===============================================")
    text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
    print(f"text_dict: {text_dict}")

    if not isinstance(text_dict, str):
        print("".join(text_dict['text']))

    # Emotion recognition
    print("===============================================")
    emotion_dict = asr.recognize_emotion(data)
    print(f"emotion_dict: {emotion_dict}")
    if not isinstance(emotion_dict, str):
        max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
        print("emotion: " + emotion_dict['labels'][max_index])


asr_file_stream()
@ -0,0 +1 @@
Stores the target speaker's voice features. To change the save location, modify DEFALUT_SAVE_PATH in utils/stt/speaker_ver_utils.
Binary file not shown.
@ -0,0 +1,214 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Importing the dtw module. When using in academic works please cite:\n",
      "  T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
      "  J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "sys.path.append(\"../\")\n",
    "from utils.tts.openvoice_utils import TextToSpeech\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
      "  warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\bing\\AppData\\Local\\Temp\\jieba.cache\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load base tts model successfully!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading model cost 0.304 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv_transpose1d(\n",
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv1d(input, weight, bias, self.stride,\n",
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
      "  warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate base speech!\n",
      "**********************,tts sr 44100\n",
      "audio segment length is [torch.Size([81565])]\n",
      "True\n",
      "Loaded checkpoint 'D:\\python\\OpenVoice\\checkpoints_v2\\converter/checkpoint.pth'\n",
      "missing/unexpected keys: [] []\n",
      "load tone color converter successfully!\n"
     ]
    }
   ],
   "source": [
    "model = TextToSpeech(use_tone_convert=True, device=\"cuda\", debug=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For testing: convert the mp3 to an int32 numpy array to match the input side\n",
    "from pydub import AudioSegment\n",
    "import numpy as np\n",
    "source_audio=r\"D:\\python\\OpenVoice\\resources\\demo_speaker0.mp3\"\n",
    "audio = AudioSegment.from_file(source_audio, format=\"mp3\")\n",
    "raw_data = audio.raw_data\n",
    "audio_array = np.frombuffer(raw_data, dtype=np.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OpenVoice version: v2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\functional.py:665: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.\n",
      "Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\SpectralOps.cpp:878.)\n",
      "  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]\n",
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n"
     ]
    }
   ],
   "source": [
    "# Obtain and set the target speaker's embedding\n",
    "# audio_array: the input audio signal, as an np.ndarray\n",
    "# extract the speaker embedding\n",
    "target_se = model.audio2emb(audio_array, rate=44100, vad=True)\n",
    "# set the model's default target speaker embedding to target_se\n",
    "model.initialize_target_se(target_se)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv_transpose1d(\n",
      "c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return F.conv1d(input, weight, bias, self.stride,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate base speech!\n",
      "**********************,tts sr 44100\n",
      "audio segment length is [torch.Size([216378])]\n",
      "Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\n"
     ]
    }
   ],
   "source": [
    "# Test base_tts alone, without tone-color conversion\n",
    "text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
    "audio, sr = model._base_tts(text, speed=1)\n",
    "audio = model.tensor2numpy(audio)\n",
    "model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generate base speech!\n",
      "**********************,tts sr 44100\n",
      "audio segment length is [torch.Size([216378])]\n",
      "torch.float32\n",
      "**********************************, convert sr 22050\n",
      "tone color has been converted!\n",
      "Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo.wav\n"
     ]
    }
   ],
   "source": [
    "# Test the full pipeline, including tone-color conversion\n",
    "text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
    "audio_bytes, sr = model.tts(text, speed=1)\n",
    "audio = np.frombuffer(audio_bytes, dtype=np.int16).flatten()\n",
    "model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo.wav\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openVoice",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
main.py
@ -1,24 +1,9 @@
-from app import app, Config
-import uvicorn
+import os

-if __name__ == "__main__":
-    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
+if __name__ == '__main__':
+    script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
+
+    # Execute the script with exec()
+    with open(script_path, 'r') as file:
+        exec(file.read())
-#                       _ooOoo_                       #
-#                      o8888888o                      #
-#                      88" . "88                      #
-#                      (| -_- |)                      #
-#                      O\  =  /O                      #
-#                   ____/`---'\____                   #
-#                 .'  \\|     |//  `.                 #
-#                /  \\|||  :  |||//  \                #
-#               /  _||||| -:- |||||-  \               #
-#               |   | \\\  -  /// |   |               #
-#               \  .-\__  `-`  ___/-. /               #
-#             ___`. .'  /--.--\  `. . __              #
-#          ."" '<  `.___\_<|>_/___.'  >'"".           #
-#         | | :  `- \`.;`\ _ /`;.`/ - ` : | |         #
-#          \  \ `-.   \_ __\ /__ _/   .-` /  /        #
-#     ======`-.____`-.___\_____/___.-`____.-'======   #
-#                      `=---='                        #
-#     .............................................   #
-#        May the Buddha bless us: bug-free forever    #
@ -7,6 +7,7 @@ redis
 requests
 websockets
 numpy
+funasr
 jieba
 cn2an
 unidecode

@ -18,8 +19,3 @@ numba
 soundfile
 webrtcvad
 apscheduler
-aiohttp
-faster_whisper
-whisper_timestamped
-modelscope
-wavmark
Binary file not shown.
Binary file not shown.
@ -1,9 +1,31 @@
-from tests.unit_test.user_test import user_test
-from tests.unit_test.character_test import character_test
-from tests.unit_test.chat_test import chat_test
+from tests.unit_test.user_test import UserServiceTest
+from tests.unit_test.character_test import CharacterServiceTest
+from tests.unit_test.chat_test import ChatServiceTest
+import asyncio

 if __name__ == '__main__':
-    user_test()
-    character_test()
-    chat_test()
+    user_service_test = UserServiceTest()
+    character_service_test = CharacterServiceTest()
+    chat_service_test = ChatServiceTest()
+
+    user_service_test.test_user_create()
+    user_service_test.test_user_update()
+    user_service_test.test_user_query()
+    user_service_test.test_hardware_bind()
+    user_service_test.test_hardware_unbind()
+    user_service_test.test_user_delete()
+
+    character_service_test.test_character_create()
+    character_service_test.test_character_update()
+    character_service_test.test_character_query()
+    character_service_test.test_character_delete()
+
+    chat_service_test.test_create_chat()
+    chat_service_test.test_session_id_query()
+    chat_service_test.test_session_content_query()
+    chat_service_test.test_session_update()
+    asyncio.run(chat_service_test.test_chat_temporary())
+    asyncio.run(chat_service_test.test_chat_lasting())
+    asyncio.run(chat_service_test.test_voice_call())
+    chat_service_test.test_chat_delete()
     print("全部测试成功")
@ -2,7 +2,7 @@ import requests
 import json

 class CharacterServiceTest:
-    def __init__(self, socket="http://127.0.0.1:8001"):
+    def __init__(self, socket="http://114.214.236.207:7878"):
         self.socket = socket

     def test_character_create(self):

@ -66,14 +66,9 @@ class CharacterServiceTest:
         else:
             raise Exception("角色删除测试失败")

-def character_test():
+if __name__ == '__main__':
     character_service_test = CharacterServiceTest()
     character_service_test.test_character_create()
     character_service_test.test_character_update()
     character_service_test.test_character_query()
     character_service_test.test_character_delete()
-
-
-if __name__ == '__main__':
-    character_test()
@ -10,7 +10,7 @@ import websockets


 class ChatServiceTest:
-    def __init__(self, socket="http://127.0.0.1:7878"):
+    def __init__(self, socket="http://114.214.236.207:7878"):
         self.socket = socket


@ -30,7 +30,6 @@ class ChatServiceTest:
         }
         response = requests.request("POST", url, headers=headers, data=payload)
         if response.status_code == 200:
-            print("用户创建成功")
             self.user_id = response.json()['data']['user_id']
         else:
             raise Exception("创建聊天时,用户创建失败")

@ -58,37 +57,6 @@ class ChatServiceTest:
         else:
             raise Exception("创建聊天时,角色创建失败")

-        # Upload audio for voice cloning
-        url = f"{self.socket}/users/audio?user_id={self.user_id}"
-        current_file_path = os.path.abspath(__file__)
-        current_file_path = os.path.dirname(current_file_path)
-        tests_dir = os.path.dirname(current_file_path)
-        mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
-        with open(mp3_file_path, 'rb') as audio_file:
-            files = {'audio_file': (mp3_file_path, audio_file, 'audio/mpeg')}
-            response = requests.post(url, files=files)
-        if response.status_code == 200:
-            self.audio_id = response.json()['data']['audio_id']
-            print("音频上传成功")
-        else:
-            raise Exception("音频上传失败")
-
-        # Bind the audio
-        url = f"{self.socket}/users/audio/bind"
-        payload = json.dumps({
-            "user_id": self.user_id,
-            "audio_id": self.audio_id
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        response = requests.request("POST", url, headers=headers, data=payload)
-        if response.status_code == 200:
-            print("音频绑定测试成功")
-        else:
-            raise Exception("音频绑定测试失败")
-
         # Create a chat
         url = f"{self.socket}/chats"
         payload = json.dumps({

@ -98,7 +66,6 @@ class ChatServiceTest:
         headers = {
             'Content-Type': 'application/json'
         }
-
         response = requests.request("POST", url, headers=headers, data=payload)
         if response.status_code == 200:
             print("对话创建成功")

@ -134,8 +101,8 @@ class ChatServiceTest:
         payload = json.dumps({
             "user_id": self.user_id,
             "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
-            "user_info": "{\"character\": \"\", \"events\": [] }",
+            "user_info": "{}",
-            "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
+            "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
             "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
             "token": 0}
         )

@ -148,19 +115,6 @@ class ChatServiceTest:
         else:
             raise Exception("Session更新测试失败")

-    def test_session_speakerid_update(self):
-        url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
-        payload = json.dumps({
-            "speaker_id": 37
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        response = requests.request("PUT", url, headers=headers, data=payload)
-        if response.status_code == 200:
-            print("Session SpeakerId更新测试成功")
-        else:
-            raise Exception("Session SpeakerId更新测试失败")

     # Test one-shot chat
     async def test_chat_temporary(self):

@ -195,7 +149,7 @@ class ChatServiceTest:
             await websocket.send(message)


-        async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/temporary') as websocket:
+        async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
             chunks = read_wav_file_in_chunks(2048)  # read the PCM file and yield data chunks
             for chunk in chunks:
                 await send_audio_chunk(websocket, chunk)

@ -251,7 +205,7 @@ class ChatServiceTest:
             message = json.dumps(data)
             await websocket.send(message)

-        async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/lasting') as websocket:
+        async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
             # First send
             chunks = read_wav_file_in_chunks(2048)
             for chunk in chunks:

@ -301,7 +255,7 @@ class ChatServiceTest:
         current_dir = os.path.dirname(current_file_path)
         tests_dir = os.path.dirname(current_dir)
         file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
-        url = f"ws://127.0.0.1:7878/chat/voice_call"
+        url = f"ws://114.214.236.207:7878/chat/voice_call"
         # Send format
         ws_data = {
             "audio": "",

@ -339,6 +293,7 @@ class ChatServiceTest:
         await asyncio.gather(audio_stream(websocket))

+
     # Test chat deletion
     def test_chat_delete(self):
         url = f"{self.socket}/chats/{self.user_character_id}"

@ -348,11 +303,6 @@ class ChatServiceTest:
         else:
             raise Exception("聊天删除测试失败")

-        url = f"{self.socket}/users/audio/{self.audio_id}"
-        response = requests.request("DELETE", url)
-        if response.status_code != 200:
-            raise Exception("音频删除测试失败")
-
         url = f"{self.socket}/users/{self.user_id}"
         response = requests.request("DELETE", url)
         if response.status_code != 200:

@ -363,18 +313,17 @@ class ChatServiceTest:
         if response.status_code != 200:
             raise Exception("角色删除测试失败")

-def chat_test():
+if __name__ == '__main__':
     chat_service_test = ChatServiceTest()
     chat_service_test.test_create_chat()
     chat_service_test.test_session_id_query()
     chat_service_test.test_session_content_query()
     chat_service_test.test_session_update()
-    chat_service_test.test_session_speakerid_update()
     asyncio.run(chat_service_test.test_chat_temporary())
     asyncio.run(chat_service_test.test_chat_lasting())
     asyncio.run(chat_service_test.test_voice_call())
     chat_service_test.test_chat_delete()
-
-
-if __name__ == '__main__':
-    chat_test()
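The streaming tests above lean on small helpers such as `read_wav_file_in_chunks` whose bodies sit elsewhere in the test module; a hypothetical sketch of the chunked reader, just to make the websocket flow concrete (the file path is illustrative and the real implementation may differ):

```python
# Illustrative only: yield fixed-size chunks of a local WAV/PCM file.
def read_wav_file_in_chunks(chunk_size, path="tests/assets/chat.wav"):
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            yield chunk
```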
@ -1,12 +0,0 @@
from chat_test import chat_test
import multiprocessing

if __name__ == '__main__':
    processes = []
    for _ in range(2):
        p = multiprocessing.Process(target=chat_test)
        processes.append(p)
        p.start()

    for p in processes:
        p.join()
@ -1,11 +1,10 @@
 import requests
 import json
 import uuid
-import os


 class UserServiceTest:
-    def __init__(self, socket="http://127.0.0.1:7878"):
+    def __init__(self, socket="http://114.214.236.207:7878"):
         self.socket = socket

     def test_user_create(self):

@ -67,7 +66,7 @@ class UserServiceTest:
         mac = "08:00:20:0A:8C:6G"
         payload = json.dumps({
             "mac": mac,
-            "user_id": self.id,
+            "user_id": 1,
             "firmware": "v1.0",
             "model": "香橙派"
         })

@ -89,122 +88,12 @@ class UserServiceTest:
         else:
             raise Exception("硬件解绑测试失败")

-    def test_hardware_bind_change(self):
-        url = f"{self.socket}/users/hardware/{self.hd_id}/bindchange"
-        payload = json.dumps({
-            "user_id": self.id
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        response = requests.request("PUT", url, headers=headers, data=payload)
-        if response.status_code == 200:
-            print("硬件换绑测试成功")
-        else:
-            raise Exception("硬件换绑测试失败")
-
-    def test_hardware_update(self):
-        url = f"{self.socket}/users/hardware/{self.hd_id}/info"
-        payload = json.dumps({
-            "mac": "08:00:20:0A:8C:6G",
-            "firmware": "v1.0",
-            "model": "香橙派"
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        response = requests.request("PUT", url, headers=headers, data=payload)
-        if response.status_code == 200:
-            print("硬件信息更新测试成功")
-        else:
-            raise Exception("硬件信息更新测试失败")
-
-    def test_hardware_query(self):
-        url = f"{self.socket}/users/hardware/{self.hd_id}"
-        response = requests.request("GET", url)
-        if response.status_code == 200:
-            print("硬件查询测试成功")
-        else:
-            raise Exception("硬件查询测试失败")
-
-    def test_upload_audio(self):
-        url = f"{self.socket}/users/audio?user_id={self.id}"
-        current_file_path = os.path.abspath(__file__)
-        current_dir = os.path.dirname(current_file_path)
-        tests_dir = os.path.dirname(current_dir)
-        wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
-        with open(wav_file_path, 'rb') as audio_file:
-            files = {'audio_file': (wav_file_path, audio_file, 'audio/mpeg')}
-            response = requests.post(url, files=files)
-        if response.status_code == 200:
-            self.audio_id = response.json()["data"]['audio_id']
-            print("音频上传测试成功")
-        else:
-            raise Exception("音频上传测试失败")
-
-    def test_update_audio(self):
-        url = f"{self.socket}/users/audio/{self.audio_id}"
-        current_file_path = os.path.abspath(__file__)
-        current_dir = os.path.dirname(current_file_path)
-        tests_dir = os.path.dirname(current_dir)
-        wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
-        with open(wav_file_path, 'rb') as audio_file:
-            files = {'audio_file': (wav_file_path, audio_file, 'audio/wav')}
-            response = requests.put(url, files=files)
-        if response.status_code == 200:
-            print("音频上传测试成功")
-        else:
-            raise Exception("音频上传测试失败")
-
-    def test_bind_audio(self):
-        url = f"{self.socket}/users/audio/bind"
-        payload = json.dumps({
-            "user_id": self.id,
-            "audio_id": self.audio_id
-        })
-        headers = {
-            'Content-Type': 'application/json'
-        }
-        response = requests.request("POST", url, headers=headers, data=payload)
-        if response.status_code == 200:
-            print("音频绑定测试成功")
-        else:
-            raise Exception("音频绑定测试失败")
-
-    def test_audio_download(self):
-        url = f"{self.socket}/users/audio/{self.audio_id}"
-        response = requests.request("GET", url)
-        if response.status_code == 200:
-            print("音频下载测试成功")
-        else:
-            raise Exception("音频下载测试失败")
-
-    def test_audio_delete(self):
-        url = f"{self.socket}/users/audio/{self.audio_id}"
-        response = requests.request("DELETE", url)
-        if response.status_code == 200:
-            print("音频删除测试成功")
-        else:
-            raise Exception("音频删除测试失败")
-
-def user_test():
+if __name__ == '__main__':
     user_service_test = UserServiceTest()
     user_service_test.test_user_create()
     user_service_test.test_user_update()
     user_service_test.test_user_query()
     user_service_test.test_hardware_bind()
-    user_service_test.test_hardware_bind_change()
-    user_service_test.test_hardware_update()
-    user_service_test.test_hardware_query()
     user_service_test.test_hardware_unbind()
-    user_service_test.test_upload_audio()
-    user_service_test.test_update_audio()
-    user_service_test.test_bind_audio()
-    user_service_test.test_audio_download()
-    user_service_test.test_audio_delete()
     user_service_test.test_user_delete()
-
-
-if __name__ == '__main__':
-    user_test()
@@ -0,0 +1 @@
+./ses holds the source speaker embeddings (source se), stored as *.pth files
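For context, each file under ./ses is a base-speaker tone-color embedding, loadable as an ordinary torch tensor. A minimal sketch (the file name is illustrative; actual names follow the OpenVoice v2 checkpoint layout):

```python
import torch

# "zh.pth" is illustrative; use whichever files exist under utils/assets/ses/.
source_se = torch.load("utils/assets/ses/zh.pth", map_location="cpu")
print(type(source_se), tuple(source_se.shape))  # a small fixed-size embedding tensor
```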
@@ -29,7 +29,6 @@ class FunAutoSpeechRecognizer(STTBase):
                  **kwargs):
         super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
 
-
         self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
 
         self.encoder_chunk_look_back = encoder_chunk_look_back  # number of chunks to look back for encoder self-attention
@@ -80,9 +79,9 @@ class FunAutoSpeechRecognizer(STTBase):
     def _init_asr(self):
         # warm the model up with a chunk of random audio data
         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
-        self.session_signup("init")
-        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
-        self.session_signout("init")
+        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+        self.audio_cache = {}
+        self.asr_cache = {}
         # print("init ASR model done.")
 
         # when a chat starts using ASR, it signs up a session
@@ -109,79 +108,6 @@ class FunAutoSpeechRecognizer(STTBase):
         """
         text_dict = dict(text=[], is_end=is_end)
 
-        audio_cache = self.audio_cache[session_id]
-
-        audio_data = self.check_audio_type(audio_data)
-        if audio_cache is None:
-            audio_cache = audio_data
-        else:
-            if audio_cache.shape[0] > 0:
-                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
-
-        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
-            self.audio_cache[session_id] = audio_cache
-            return text_dict
-
-        total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
-
-        if is_end:
-            # if the audio data is the end of a sentence, \
-            # we need to add one more chunk to the end to \
-            # ensure the end of the sentence is recognized correctly.
-            auto_det_end = True
-
-        if auto_det_end:
-            total_chunk_num += 1
-
-        end_idx = None
-        for i in range(total_chunk_num):
-            if auto_det_end:
-                is_end = i == total_chunk_num - 1
-            start_idx = i*self.chunk_partial_size
-            if auto_det_end:
-                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
-            else:
-                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
-            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
-            # t_stamp = time.time()
-
-            speech_chunk = audio_cache[start_idx:end_idx]
-
-            # TODO: exceptions processes
-            # print("i:", i)
-            try:
-                res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
-            except ValueError as e:
-                print(f"ValueError: {e}")
-                continue
-            text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
-            # print(f"each chunk time: {time.time()-t_stamp}")
-
-        if is_end:
-            audio_cache = None
-        else:
-            if end_idx:
-                audio_cache = audio_cache[end_idx:]  # cut the processed part from audio_cache
-            text_dict['is_end'] = is_end
-
-        self.audio_cache[session_id] = audio_cache
-        return text_dict
-
-    def streaming_recognize_origin(self,
-                                   session_id,
-                                   audio_data,
-                                   is_end=False,
-                                   auto_det_end=False):
-        """recognize partial result
-
-        Args:
-            audio_data: bytes or numpy array, partial audio data
-            is_end: bool, whether the audio data is the end of a sentence
-            auto_det_end: bool, whether to automatically detect the end of a audio data
-        """
-        text_dict = dict(text=[], is_end=is_end)
-
         audio_cache = self.audio_cache[session_id]
         asr_cache = self.asr_cache[session_id]
 
@@ -242,3 +168,175 @@ class FunAutoSpeechRecognizer(STTBase):
         self.audio_cache[session_id] = audio_cache
         self.asr_cache[session_id] = asr_cache
         return text_dict
+
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# ####################################################### #
+# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
+# ####################################################### #
+# import io
+# import numpy as np
+# import base64
+# import wave
+# from funasr import AutoModel
+# from .base_stt import STTBase
+
+# def decode_str2bytes(data):
+#     # decode the Base64-encoded string back to bytes
+#     if data is None:
+#         return None
+#     return base64.b64decode(data.encode('utf-8'))
+
+# class FunAutoSpeechRecognizer(STTBase):
+#     def __init__(self,
+#                  model_path="paraformer-zh-streaming",
+#                  device="cuda",
+#                  RATE=16000,
+#                  cfg_path=None,
+#                  debug=False,
+#                  chunk_ms=480,
+#                  encoder_chunk_look_back=4,
+#                  decoder_chunk_look_back=1,
+#                  **kwargs):
+#         super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
+
+#         self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
+
+#         self.encoder_chunk_look_back = encoder_chunk_look_back  # number of chunks to lookback for encoder self-attention
+#         self.decoder_chunk_look_back = decoder_chunk_look_back  # number of encoder chunks to lookback for decoder cross-attention
+
+#         # [0, 8, 4] 480ms, [0, 10, 5] 600ms
+#         if chunk_ms == 480:
+#             self.chunk_size = [0, 8, 4]
+#         elif chunk_ms == 600:
+#             self.chunk_size = [0, 10, 5]
+#         else:
+#             raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
+#         self.chunk_partial_size = self.chunk_size[1] * 960
+#         self.audio_cache = None
+#         self.asr_cache = {}
+
+#         self._init_asr()
+
+#     def check_audio_type(self, audio_data):
+#         """check audio data type and convert it to bytes if necessary."""
+#         if isinstance(audio_data, bytes):
+#             pass
+#         elif isinstance(audio_data, list):
+#             audio_data = b''.join(audio_data)
+#         elif isinstance(audio_data, str):
+#             audio_data = decode_str2bytes(audio_data)
+#         elif isinstance(audio_data, io.BytesIO):
+#             wf = wave.open(audio_data, 'rb')
+#             audio_data = wf.readframes(wf.getnframes())
+#         elif isinstance(audio_data, np.ndarray):
+#             pass
+#         else:
+#             raise TypeError(f"audio_data must be bytes, list, str, \
+#                 io.BytesIO or numpy array, but got {type(audio_data)}")
+
+#         if isinstance(audio_data, bytes):
+#             audio_data = np.frombuffer(audio_data, dtype=np.int16)
+#         elif isinstance(audio_data, np.ndarray):
+#             if audio_data.dtype != np.int16:
+#                 audio_data = audio_data.astype(np.int16)
+#         else:
+#             raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
+#         return audio_data
+
+#     def _init_asr(self):
+#         # warm the model up with a chunk of random audio data
+#         init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
+#         self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#         self.audio_cache = None
+#         self.asr_cache = {}
+#         # print("init ASR model done.")
+
+#     def recognize(self, audio_data):
+#         """recognize audio data to text"""
+#         audio_data = self.check_audio_type(audio_data)
+#         result = self.asr_model.generate(input=audio_data,
+#                                          batch_size_s=300,
+#                                          hotword=self.hotwords)
+
+#         # print(result)
+#         text = ''
+#         for res in result:
+#             text += res['text']
+#         return text
+
+#     def streaming_recognize(self,
+#                             audio_data,
+#                             is_end=False,
+#                             auto_det_end=False):
+#         """recognize partial result
+
+#         Args:
+#             audio_data: bytes or numpy array, partial audio data
+#             is_end: bool, whether the audio data is the end of a sentence
+#             auto_det_end: bool, whether to automatically detect the end of a audio data
+#         """
+#         text_dict = dict(text=[], is_end=is_end)
+
+#         audio_data = self.check_audio_type(audio_data)
+#         if self.audio_cache is None:
+#             self.audio_cache = audio_data
+#         else:
+#             # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+#             if self.audio_cache.shape[0] > 0:
+#                 self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
+
+#         if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
+#             return text_dict
+
+#         total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+#         if is_end:
+#             # if the audio data is the end of a sentence, \
+#             # we need to add one more chunk to the end to \
+#             # ensure the end of the sentence is recognized correctly.
+#             auto_det_end = True
+
+#         if auto_det_end:
+#             total_chunk_num += 1
+
+#         # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+#         end_idx = None
+#         for i in range(total_chunk_num):
+#             if auto_det_end:
+#                 is_end = i == total_chunk_num - 1
+#             start_idx = i*self.chunk_partial_size
+#             if auto_det_end:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+#             else:
+#                 end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+#             # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+#             # t_stamp = time.time()
+
+#             speech_chunk = self.audio_cache[start_idx:end_idx]
+
+#             # TODO: exceptions processes
+#             try:
+#                 res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+#             except ValueError as e:
+#                 print(f"ValueError: {e}")
+#                 continue
+#             text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+#             # print(f"each chunk time: {time.time()-t_stamp}")
+
+#         if is_end:
+#             self.audio_cache = None
+#             self.asr_cache = {}
+#         else:
+#             if end_idx:
+#                 self.audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
+#             text_dict['is_end'] = is_end
+
+#         # print(f"text_dict: {text_dict}")
+#         return text_dict
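The chunk arithmetic in streaming_recognize is easiest to sanity-check in isolation. A minimal sketch of the same indexing, with illustrative values (chunk_partial_size = 8 * 960, i.e. 480 ms of 16 kHz int16 samples, as in the chunk_ms=480 configuration):

```python
# Standalone re-run of the chunk indexing used by streaming_recognize.
chunk_partial_size = 8 * 960   # 7680 samples = 480 ms at 16 kHz
cache_len = 20000              # pretend this much audio is buffered
auto_det_end = True            # final call: flush the tail as one extra chunk

total_chunk_num = int((cache_len - 1) / chunk_partial_size)  # full chunks available
if auto_det_end:
    total_chunk_num += 1       # one extra pass so the sentence tail is not dropped

for i in range(total_chunk_num):
    start_idx = i * chunk_partial_size
    end_idx = (i + 1) * chunk_partial_size if i < total_chunk_num - 1 else -1
    is_end = i == total_chunk_num - 1
    print(i, start_idx, end_idx, is_end)
# -> slices [0:7680], [7680:15360], [15360:-1]; only the last one is flagged final
```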
@@ -1,29 +1,209 @@
 from .funasr_utils import FunAutoSpeechRecognizer
-from .punctuation_utils import FUNASR, Punctuation
+from .punctuation_utils import CTTRANSFORMER, Punctuation
 from .emotion_utils import FUNASRFINETUNE, Emotion
+from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
+import os
+import numpy as np
 
-class ModifiedRecognizer():
-    def __init__(self):
-        # add the speech recognition model
-        self.asr_model = FunAutoSpeechRecognizer()
+class ModifiedRecognizer(FunAutoSpeechRecognizer):
+    def __init__(self,
+                 use_punct=True,
+                 use_emotion=False,
+                 use_speaker_ver=True):
 
-        # add the punctuation model
-        self.puctuation_model = Punctuation(**FUNASR)
+        # build the basic funasr model, used to recognize sentences without punctuation
+        super().__init__(
+            model_path="paraformer-zh-streaming",
+            device="cuda",
+            RATE=16000,
+            cfg_path=None,
+            debug=False,
+            chunk_ms=480,
+            encoder_chunk_look_back=4,
+            decoder_chunk_look_back=1)
+
+        # record which optional capabilities are enabled
+        self.use_punct = use_punct
+        self.use_emotion = use_emotion
+        self.use_speaker_ver = use_speaker_ver
+
+        # punctuation model
+        if use_punct:
+            self.puctuation_model = Punctuation(**CTTRANSFORMER)
 
         # emotion recognition model
-        self.emotion_model = Emotion(**FUNASRFINETUNE)
+        if use_emotion:
+            self.emotion_model = Emotion(**FUNASRFINETUNE)
 
-    def session_signup(self, session_id):
-        self.asr_model.session_signup(session_id)
+        # speaker verification model
+        if use_speaker_ver:
+            self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
 
-    def session_signout(self, session_id):
-        self.asr_model.session_signout(session_id)
+    def initialize_speaker(self, speaker_1_wav):
+        """
+        For speaker verification: register the input audio (speaker_1_wav) as the
+        target speaker and save its embedding locally.
+        """
+        if not self.use_speaker_ver:
+            raise NotImplementedError("no access")
+        if speaker_1_wav.endswith(".npy"):
+            self.save_speaker_path = speaker_1_wav
+        elif speaker_1_wav.endswith('.wav'):
+            self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
+                                                  os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
+            # self.save_speaker_path = DEFALUT_SAVE_PATH
+            self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
+        else:
+            raise TypeError("only support [.npy] or [.wav].")
 
-    def streaming_recognize(self, session_id, audio_data, is_end=False):
-        return self.asr_model.streaming_recognize(session_id, audio_data, is_end=is_end)
-
-    def punctuation_correction(self, sentence):
-        return self.puctuation_model.process(sentence)
+    def speaker_ver(self, speaker_2_wav):
+        """
+        For speaker verification: decide whether the input audio comes from the
+        target speaker; returns True if it does, False otherwise.
+        """
+        if not self.use_speaker_ver:
+            raise NotImplementedError("no access")
+        if not hasattr(self, "save_speaker_path"):
+            raise NotImplementedError("please initialize speaker first")
 
-    def emtion_recognition(self, audio):
-        return self.emotion_model.process(audio)
+        # self.speaker_ver_model.verfication returns the string 'yes' / 'no'
+        return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
+                                                  speaker_2_wav=speaker_2_wav) == 'yes'
+
+    def recognize(self, audio_data):
+        """
+        Non-streaming speech recognition; returns the recognized text as a str.
+        """
+        audio_data = self.check_audio_type(audio_data)
+
+        # speaker verification
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        # speech recognition
+        result = self.asr_model.generate(input=audio_data,
+                                         batch_size_s=300,
+                                         hotword=self.hotwords)
+        text = ''
+        for res in result:
+            text += res['text']
+
+        # add punctuation
+        if self.use_punct:
+            text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
+
+        return text
+
+    def recognize_emotion(self, audio_data):
+        """
+        Emotion recognition. Returns:
+        1. the string "Other People" if the speaker is not the target speaker;
+        2. otherwise a dict {"Labels": List[str], "scores": List[int]}.
+        """
+        audio_data = self.check_audio_type(audio_data)
+
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        if self.use_emotion:
+            return self.emotion_model.process(audio_data)
+        else:
+            raise NotImplementedError("no access")
+
+    def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False):
+        """recognize partial result
+
+        Args:
+            audio_data: bytes or numpy array, partial audio data
+            is_end: bool, whether the audio data is the end of a sentence
+            auto_det_end: bool, whether to automatically detect the end of a audio data
+
+        Streaming speech recognition. Returns:
+        1. the string "Other People" if the speaker is not the target speaker;
+        2. otherwise a dict {"text": List[str], "is_end": boolean}.
+        """
+        audio_cache = self.audio_cache[session_id]
+        asr_cache = self.asr_cache[session_id]
+        text_dict = dict(text=[], is_end=is_end)
+        audio_data = self.check_audio_type(audio_data)
+
+        # speaker verification
+        if self.use_speaker_ver:
+            if self.speaker_ver_model.verfication(self.save_speaker_path,
+                                                  speaker_2_wav=audio_data) == 'no':
+                return "Other People"
+
+        # speech recognition
+        if audio_cache is None:
+            audio_cache = audio_data
+        else:
+            # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
+            if audio_cache.shape[0] > 0:
+                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
+
+        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
+            self.audio_cache[session_id] = audio_cache
+            return text_dict
+
+        total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
+
+        if is_end:
+            # if the audio data is the end of a sentence, \
+            # we need to add one more chunk to the end to \
+            # ensure the end of the sentence is recognized correctly.
+            auto_det_end = True
+
+        if auto_det_end:
+            total_chunk_num += 1
+
+        # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
+        end_idx = None
+        for i in range(total_chunk_num):
+            if auto_det_end:
+                is_end = i == total_chunk_num - 1
+            start_idx = i*self.chunk_partial_size
+            if auto_det_end:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
+            else:
+                end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
+            # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
+            # t_stamp = time.time()
+
+            speech_chunk = audio_cache[start_idx:end_idx]
+
+            # TODO: exceptions processes
+            try:
+                res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
+            except ValueError as e:
+                print(f"ValueError: {e}")
+                continue
+
+            # add punctuation
+            if self.use_punct:
+                text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
+            else:
+                text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
+
+            # print(f"each chunk time: {time.time()-t_stamp}")
+
+        if is_end:
+            audio_cache = None
+            asr_cache = {}
+        else:
+            if end_idx:
+                audio_cache = self.audio_cache[end_idx:]  # cut the processed part from audio_cache
+            text_dict['is_end'] = is_end
+
+        if self.use_punct and is_end:
+            text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
+
+        self.audio_cache[session_id] = audio_cache
+        self.asr_cache[session_id] = asr_cache
+        # print(f"text_dict: {text_dict}")
+        return text_dict
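Assuming the new class is importable (the module path below is a guess; use the project's actual package layout), end-to-end usage would look roughly like this. The audio file names are placeholders, and session_signup/session_signout are assumed to be provided by the parent FunAutoSpeechRecognizer, as the hunks above suggest:

```python
# Hypothetical driver for the new ModifiedRecognizer API; paths and names are placeholders.
from utils.stt.modified_funasr import ModifiedRecognizer  # module path is an assumption

rec = ModifiedRecognizer(use_punct=True, use_emotion=False, use_speaker_ver=True)
rec.initialize_speaker("target_speaker.wav")      # register the target speaker once

rec.session_signup("demo")                        # allocate per-session caches
with open("incoming_16k_mono_int16.pcm", "rb") as f:
    while chunk := f.read(7680 * 2):              # 480 ms of 16 kHz int16 audio per read
        out = rec.streaming_recognize("demo", chunk, is_end=False)
        if out == "Other People":                 # speaker check rejected the audio
            break
        print("".join(out["text"]), end="", flush=True)
rec.session_signout("demo")
```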
@@ -1,7 +1,7 @@
 from modelscope.pipelines import pipeline
 import numpy as np
 import os
-import pdb
 ERES2NETV2 = {
     "task": 'speaker-verification',
     "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
@@ -10,7 +10,7 @@ ERES2NETV2 = {
 }
 
 # path where speaker embeddings are saved
-DEFALUT_SAVE_PATH = r".\takway\savePath"
+DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding")
 
 class speaker_verfication:
     def __init__(self,
@@ -26,9 +26,11 @@ class speaker_verfication:
                           device=device)
         self.save_embeddings = save_embeddings
 
-    def wav2embeddings(self, speaker_1_wav):
+    def wav2embeddings(self, speaker_1_wav, save_path=None):
         result = self.pipeline([speaker_1_wav], output_emb=True)
         speaker_1_emb = result['embs'][0]
+        if save_path is not None:
+            np.save(save_path, speaker_1_emb)
         return speaker_1_emb
 
     def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
@@ -53,10 +55,19 @@ class speaker_verfication:
             return "no"
 
     def verfication(self,
-                    base_emb,
-                    speaker_emb,
-                    threshold=0.333, ):
-        return np.dot(base_emb, speaker_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker_emb)) > threshold
+                    base_emb=None,
+                    speaker_1_wav=None,
+                    speaker_2_wav=None,
+                    threshold=0.333,
+                    save_path=None):
+        if base_emb is not None and speaker_1_wav is not None:
+            raise ValueError("Only need one of them, base_emb or speaker_1_wav")
+        if base_emb is not None and speaker_2_wav is not None:
+            return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
+        elif speaker_1_wav is not None and speaker_2_wav is not None:
+            return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
+        else:
+            raise NotImplementedError
 
 if __name__ == '__main__':
     verifier = speaker_verfication(**ERES2NETV2)
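The embedding comparison here is plain cosine similarity against a fixed threshold, exactly what the removed one-liner computed (whether _verifaction_from_embedding is byte-for-byte the same is not shown in this hunk). A self-contained sketch, with an illustrative embedding dimension:

```python
import numpy as np

def cosine_accept(base_emb: np.ndarray, probe_emb: np.ndarray, threshold: float = 0.333) -> bool:
    # Cosine similarity of two speaker embeddings; accept when it exceeds the threshold.
    sim = np.dot(base_emb, probe_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(probe_emb))
    return bool(sim > threshold)

rng = np.random.default_rng(0)
emb = rng.normal(size=192)                        # dimension is illustrative
print(cosine_accept(emb, emb))                    # True: identical embeddings
print(cosine_accept(emb, rng.normal(size=192)))   # almost surely False for random vectors
```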
@@ -0,0 +1,7 @@
+1. Install melo,
+   following https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md#openvoice-v2
+2. Update the paths in openvoice_utils.py (including SOURCE_SE_DIR / CACHE_PATH / OPENVOICE_TONE_COLOR_CONVERTER.converter_path), where:
+   SOURCE_SE_DIR should point to utils/assets/ses
+   OPENVOICE_TONE_COLOR_CONVERTER.converter_path should point to the downloaded model weights, from https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip
+3. Follow examples/tts_demo.ipynb when migrating the code
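A minimal sketch of what step 2 amounts to, using the portable values this same compare shows elsewhere in openvoice_utils.py; the converter path depends on where the checkpoint archive is unpacked:

```python
# openvoice_utils.py -- illustrative configuration; adjust paths to your checkout.
import os

utils_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
SOURCE_SE_DIR = os.path.join(utils_dir, "assets", "ses")   # base speaker embeddings
CACHE_PATH = "/tmp/openvoice_cache"                        # scratch dir for temporary wavs

OPENVOICE_TONE_COLOR_CONVERTER = {
    "model_type": "open_voice_converter",
    # unpacked from checkpoints_v2_0417.zip
    "converter_path": "/path/to/checkpoints_v2/converter",
}
```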
@@ -140,7 +140,6 @@ def get_se(audio_path, vc_model, target_dir='processed', vad=True):
     # if os.path.isdir(audio_path):
     #     wavs_folder = audio_path
 
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
     if vad:
         wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
     else:
@@ -1,57 +0,0 @@
-{
-    "_version_": "v2",
-    "data": {
-        "sampling_rate": 22050,
-        "filter_length": 1024,
-        "hop_length": 256,
-        "win_length": 1024,
-        "n_speakers": 0
-    },
-    "model": {
-        "zero_g": true,
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0.1,
-        "resblock": "1",
-        "resblock_kernel_sizes": [
-            3,
-            7,
-            11
-        ],
-        "resblock_dilation_sizes": [
-            [
-                1,
-                3,
-                5
-            ],
-            [
-                1,
-                3,
-                5
-            ],
-            [
-                1,
-                3,
-                5
-            ]
-        ],
-        "upsample_rates": [
-            8,
-            8,
-            2,
-            2
-        ],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [
-            16,
-            16,
-            4,
-            4
-        ],
-        "gin_channels": 256
-    }
-}
@@ -1,6 +1,7 @@
 import os
 import re
 from glob import glob
+import hashlib
 from tqdm.auto import tqdm
 import soundfile as sf
 import numpy as np
@@ -15,15 +16,11 @@ from .openvoice.api import ToneColorConverter
 from .openvoice.mel_processing import spectrogram_torch
 # torchaudio
 import torchaudio.functional as F
 
 # path where the BASE SPEAKER embeddings (source_se) are stored
-current_file_path = os.path.abspath(__file__)
-utils_dir = os.path.dirname(os.path.dirname(current_file_path))
-
-SOURCE_SE_DIR = os.path.join(utils_dir,'assets','ses')
-
+SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"
 
 # path where cache files are stored
-CACHE_PATH = r"/tmp/openvoice_cache"
+CACHE_PATH = r"D:\python\OpenVoice\processed"
 
 OPENVOICE_BASE_TTS={
     "model_type": "open_voice_base_tts",
@@ -31,11 +28,10 @@ OPENVOICE_BASE_TTS={
     "language": "ZH",
 }
 
-converter_path = os.path.join(os.path.dirname(current_file_path),'openvoice_model')
 OPENVOICE_TONE_COLOR_CONVERTER={
     "model_type": "open_voice_converter",
     # path to the model weights
-    "converter_path": converter_path,
+    "converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
 }
 
 class TextToSpeech:
@@ -122,7 +118,6 @@ class TextToSpeech:
         elif isinstance(se, torch.Tensor):
             self.target_se = se.float().to(self.device)
 
-    # convert audio to numpy
     def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
         """
         Convert an audio byte stream to numpy; numpy input is also accepted
@@ -148,14 +143,30 @@ class TextToSpeech:
         return: np.ndarray
         """
         audio_data = self.audio2numpy(audio_data)
-        if not os.path.exists(CACHE_PATH):
-            os.makedirs(CACHE_PATH)
 
         from scipy.io import wavfile
         audio_path = os.path.join(CACHE_PATH, "tmp.wav")
         wavfile.write(audio_path, rate=rate, data=audio_data)
 
         se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
+        # device = self.tone_color_converter.device
+        # version = self.tone_color_converter.version
+        # if self.debug:
+        #     print("OpenVoice version:", version)
+
+        # audio_name = f"tmp_{version}_{hashlib.sha256(audio_data.tobytes()).hexdigest()[:16].replace('/','_^')}"
+
+        # if vad:
+        #     wavs_folder = se_extractor.split_audio_vad(audio_path, target_dir=CACHE_PATH, audio_name=audio_name)
+        # else:
+        #     wavs_folder = se_extractor.split_audio_whisper(audio_data, target_dir=CACHE_PATH, audio_name=audio_name)
+
+        # audio_segs = glob(f'{wavs_folder}/*.wav')
+        # if len(audio_segs) == 0:
+        #     raise NotImplementedError('No audio segments found!')
+        # # se, _ = se_extractor.get_se(audio_data, self.tone_color_converter, CACHE_PATH, vad=False)
+        # se = self.tone_color_converter.extract_se(audio_segs)
         return se.cpu().detach().numpy()
 
     def tensor2numpy(self, audio_data: torch.Tensor):
@@ -164,14 +175,11 @@ class TextToSpeech:
         """
         return audio_data.cpu().detach().float().numpy()
 
-    def numpy2bytes(self, audio_data):
-        if isinstance(audio_data, np.ndarray):
-            if audio_data.dtype == np.dtype('float32'):
-                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
-            audio_data = audio_data.tobytes()
-            return audio_data
-        else:
-            raise TypeError("audio_data must be a numpy array")
+    def numpy2bytes(self, audio_data: np.ndarray):
+        """
+        Convert numpy audio data to bytes
+        """
+        return (audio_data*32768.0).astype(np.int32).tobytes()
 
     def _base_tts(self,
                   text: str,
@@ -263,11 +271,6 @@ class TextToSpeech:
            audio: tensor
            sr: sampling rate of the generated audio
         """
-        if source_se is not None:
-            source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
-        if target_se is not None:
-            target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
-
         if source_se is None:
             source_se = self.source_se
         if target_se is None:
@@ -293,13 +296,16 @@ class TextToSpeech:
             print("tone color has been converted!")
         return audio, sr
 
-    def synthesize(self,
+    def tts(self,
             text: str,
-            tts_info,
+            sdp_ratio=0.2,
+            noise_scale=0.6,
+            noise_scale_w=0.8,
+            speed=1.0,
+            quite=True,
             source_se: Optional[np.ndarray]=None,
             target_se: Optional[np.ndarray]=None,
-            sdp_ratio=0.2,
-            quite=True,
             tau: float=0.3,
             message: str="default"):
         """
@@ -316,14 +322,15 @@ class TextToSpeech:
         """
         audio, sr = self._base_tts(text,
                                    sdp_ratio=sdp_ratio,
-                                   noise_scale=tts_info['noise_scale'],
-                                   noise_scale_w=tts_info['noise_scale_w'],
-                                   speed=tts_info['speed'],
+                                   noise_scale=noise_scale,
+                                   noise_scale_w=noise_scale_w,
+                                   speed=speed,
                                    quite=quite)
-        if self.use_tone_convert and target_se.size>0:
+        if self.use_tone_convert:
             tts_sr = self.base_tts_model.hps.data.sampling_rate
             converter_sr = self.tone_color_converter.hps.data.sampling_rate
             audio = F.resample(audio, tts_sr, converter_sr)
+            print(audio.dtype)
             audio, sr = self._convert_tone(audio,
                                            source_se=source_se,
                                            target_se=target_se,
@@ -343,4 +350,3 @@ class TextToSpeech:
         """
         sf.write(save_path, audio, sample_rate)
         print(f"Audio saved to {save_path}")
-
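One detail worth flagging: the new numpy2bytes scales by 32768.0 and casts to int32, so each sample occupies four bytes, and any consumer that assumes 16-bit PCM must agree on that width. For comparison, the conventional float-to-16-bit-PCM round trip looks like this (a generic sketch, not code from this repository):

```python
import numpy as np

def float_to_pcm16(audio: np.ndarray) -> bytes:
    # Clip to [-1, 1] and scale to signed 16-bit PCM.
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767.0).astype(np.int16).tobytes()

def pcm16_to_float(buf: bytes) -> np.ndarray:
    return np.frombuffer(buf, dtype=np.int16).astype(np.float32) / 32768.0

x = np.sin(np.linspace(0, 2 * np.pi, 16000)).astype(np.float32)
# The round trip agrees with the input up to quantization error.
assert np.allclose(pcm16_to_float(float_to_pcm16(x)), x, atol=1e-3)
```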
@@ -2,7 +2,6 @@ import os
 import numpy as np
 import torch
 from torch import LongTensor
-from typing import Optional
 import soundfile as sf
 # vits
 from .vits import utils, commons
@@ -80,19 +79,19 @@ class TextToSpeech:
         print(f"Synthesis time: {time.time() - start_time} s")
         return audio
 
-    def synthesize(self, text, tts_info, target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
+    def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
         if not len(text):
             return "Input text must not be empty!", None
         text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
         if len(text) > 100 and self.limitation:
             return f"Input text too long! {len(text)}>100", None
-        text = self._preprocess_text(text, tts_info['language'])
-        audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
+        text = self._preprocess_text(text, language)
+        audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
         if self.debug or save_audio:
             self.save_audio(audio, self.RATE, 'output_file.wav')
         if return_bytes:
             audio = self.convert_numpy_to_bytes(audio)
-        return audio, self.RATE
+        return self.RATE, audio
 
     def convert_numpy_to_bytes(self, audio_data):
         if isinstance(audio_data, np.ndarray):
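Note that synthesize now takes the TTS parameters individually instead of a tts_info dict, and its return order flips from (audio, rate) to (rate, audio), so every call site has to change in step. A sketch of the new calling convention (parameter values are illustrative):

```python
# After this change, callers unpack the sample rate first.
rate, audio_bytes = tts.synthesize(
    "你好,世界",          # text to synthesize
    language="ZH",
    speaker_id=0,
    noise_scale=0.667,
    noise_scale_w=0.8,
    length_scale=1.0,
    return_bytes=True,
)
```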
@@ -1,9 +1,7 @@
-import websockets
 import datetime
 import hashlib
 import base64
 import hmac
-import json
 from urllib.parse import urlencode
 from wsgiref.handlers import format_date_time
 from datetime import datetime
@@ -37,33 +35,3 @@ def generate_xf_asr_url():
     }
     url = url + '?' + urlencode(v)
     return url
-
-
-def make_first_frame(buf):
-    first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
-                   "data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
-    return json.dumps(first_frame)
-
-def make_continue_frame(buf):
-    continue_frame = {"data":{"status":1,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
-    return json.dumps(continue_frame)
-
-def make_last_frame(buf):
-    last_frame = {"data":{"status":2,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
-    return json.dumps(last_frame)
-
-def parse_xfasr_recv(message):
-    code = message['code']
-    if code != 0:
-        raise Exception("Xunfei ASR error code: " + str(code))
-    else:
-        data = message['data']['result']['ws']
-        result = ""
-        for i in data:
-            for w in i['cw']:
-                result += w['w']
-        return result
-
-async def xf_asr_websocket_factory():
-    url = generate_xf_asr_url()
-    return await websockets.connect(url)
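The removed helpers encode the Xunfei (iFlytek) streaming protocol, where data.status 0/1/2 marks the first, continuation, and last frame of an utterance. A compact equivalent is kept here for reference; Config.XF_ASR.APP_ID is the project's own config object, as in the removed code:

```python
import json

def make_frame(status: int, buf: str, app_id: str = "") -> str:
    # status: 0 = first frame, 1 = continuation, 2 = last frame.
    frame = {"data": {"status": status, "format": "audio/L16;rate=16000",
                      "audio": buf, "encoding": "raw"}}
    if status == 0:  # the first frame also carries app and business parameters
        frame["common"] = {"app_id": app_id}
        frame["business"] = {"domain": "iat", "language": "zh_cn",
                             "accent": "mandarin", "vad_eos": 10000}
    return json.dumps(frame)
```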