1
0
Fork 0

Compare commits

..

32 Commits

Author SHA1 Message Date
killua4396 5975d4687f debug: 增加了讯飞语音识别接口的timeout机制,防止语音识别未完成时连接断开 2024-06-05 14:57:18 +08:00
killua 2378fee258 update: 增加杂音屏蔽词,暂时停止记忆、情感识别功能 2024-06-05 10:32:40 +08:00
killua4396 09ffeb6ab6 feat: 增加讯飞asr功能 2024-06-04 18:05:23 +08:00
killua4396 30fdb9c6bd feat: 增加修改speaker_id接口 2024-05-29 18:41:35 +08:00
killua4396 ce033dca2b feat: 后端封装tts,可以通过配置文件切换openvoice和vits 2024-05-24 15:08:55 +08:00
killua4396 e2f3decfae debug: 修复语音克隆输出错误的bug 2024-05-23 15:19:21 +08:00
killua 3b9cc44e4c update: 更新部署文档 2024-05-23 11:48:05 +08:00
killua 0756725cd1 uptdate: 更新部署文档和依赖文档 2024-05-23 10:44:14 +08:00
killua4396 b3e0f26937 update: 在tts_info中加入了speed字段 2024-05-23 10:05:10 +08:00
killua4396 387c277c28 feat: 用户表添加selected_audio_id字段,添加用户音频绑定接口 2024-05-23 09:57:53 +08:00
Killua777 773b48471a 更新部署文档 2024-05-22 17:31:33 +08:00
killua4396 e64c4b0839 update: 暂时取消redis持久化定时任务 2024-05-22 17:27:06 +08:00
killua4396 5cf16bf03b feat: openvoice重构,添加语音克隆功能 2024-05-22 17:26:29 +08:00
killua4396 8369090313 feat: 增加openvoice库 2024-05-22 15:28:23 +08:00
killua4396 fdee2e7520 feat: 通过依赖注入来获取asr和tts对象 2024-05-22 15:26:01 +08:00
Killua777 646c188f78 update: 修改部署文档与依赖文档 2024-05-22 10:04:07 +08:00
killua4396 5b3205e1b7 update: 在用户测试中添加音频除非是 2024-05-21 11:37:23 +08:00
killua4396 031fa32ea0 feat: 为持续流式聊天以及语音电话接口增加记忆功能 2024-05-21 11:36:34 +08:00
killua4396 872cde91e8 feat: 为单次流式聊天接口增加记忆功能 2024-05-20 21:38:44 +08:00
killua 0593bf0482 update: 将app启动位置至根目录下 2024-05-18 20:39:50 +08:00
Killua777 88bcfd03a6 update: 更新md部署文档 2024-05-18 17:58:12 +08:00
killua4396 54d13fba87 feat: 增加情感检测与标点识别 2024-05-18 17:10:09 +08:00
killua4396 017997a33e update: 更换minimax apikey 2024-05-18 16:16:34 +08:00
killua4396 f426878f9e update: 修改requests为aiohttp,实现异步http请求 2024-05-18 11:09:44 +08:00
killua b26f3192bc update: 更新了gitignore文档,忽略nohup.out 2024-05-17 15:59:57 +08:00
killua 0df6ecf394 update: 更新了默认的配置参数
update: 更新了测试代码中的默认ip与端口
2024-05-16 21:18:18 +08:00
killua4396 ff88b94778 update: 更新requirement 2024-05-16 15:31:00 +08:00
killua4396 4a256fa506 feat: 增加了用户音频的增删查改 2024-05-16 13:28:47 +08:00
gaohz 4322b03418 [update] 更新funasr parformer_streaming cache管理,增加seesion_id字段 2024-05-15 23:01:17 +08:00
gaohz a943a281d5 [update] 更新funasr parformer_streaming cache管理,增加seesion_id字段 2024-05-15 22:38:12 +08:00
IrvingGao d930e71410 [update] 更新funasr parformer_streaming cache管理,增加seesion_id字段 2024-05-15 21:56:10 +08:00
killua4396 bd893a4599 update: 更新程序,增加了并发测试程序 2024-05-13 15:14:54 +08:00
57 changed files with 5454 additions and 5499 deletions

8
.gitignore vendored
View File

@ -7,3 +7,11 @@ __pycache__/
app.log
/utils/tts/vits_model/
vits_model
nohup.out
/app/takway-ai.top.key
/app/takway-ai.top.pem
tests/assets/BarbieDollsVoice.mp3
utils/tts/openvoice_model/checkpoint.pth

122
README.md
View File

@ -15,15 +15,15 @@ TakwayAI/
│ │ ├── models.py # 数据库定义
│ ├── schemas/ # 请求和响应模型
│ │ ├── __init__.py
│ │ ├── user.py # 用户相关schema
│ │ ├── user_schemas.py # 用户相关schema
│ │ └── ... # 其他schema
│ ├── controllers/ # 业务逻辑控制器
│ │ ├── __init__.py
│ │ ├── user.py # 用户相关控制器
│ │ ├── user_controllers.py # 用户相关控制器
│ │ └── ... # 其他控制器
│ ├── routes/ # 路由和视图函数
│ │ ├── __init__.py
│ │ ├── user.py # 用户相关路由
│ │ ├── user_routes.py # 用户相关路由
│ │ └── ... # 其他路由
│ ├── dependencies/ # 依赖注入相关
│ │ ├── __init__.py
@ -64,21 +64,129 @@ TakwayAI/
git clone http://43.132.157.186:3000/killua/TakwayPlatform.git
```
#### (2) 安装依赖
#### (2) 创建虚拟环境
创建虚拟环境
``` shell
cd TakwayAI/
conda create -n takway python=3.9
conda activate takway
```
#### (3) 安装依赖
如果你的本地环境可以科学上网,则直接运行下面两行指令
``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
python -m unidic download
```
如果不能科学上网
则先运行
``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
```
1. unidic安装
然后手动下载[unidic.zip](https://cotonoha-dic.s3-ap-northeast-1.amazonaws.com/unidic-3.1.0.zip)并手动改名为unidic.zip
这边以miniconda举例如果用的是conda应该也是一样的
将unidic.zip拷贝入~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic
cd进入~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic
vim download.py
将函数download_version()中除了最后一行全部注释掉并且把最后一行的download_and_clean()的两个参数任意修改,比如"hello","world"
再将download_and_clean()函数定义位置注释掉该函数中的download_process()行
运行`python -m unidic download`
2. huggingface配置
运行命令
```shell
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
```
最好把`export HF_ENDPOINT=https://hf-mirror.com`写入~/.bashrc不然每次重启控制台终端就会失效
3. nltk_data下载
在/miniconda3/envs/takway/下创建nltk_data文件夹
在nltk_data文件夹下创建corpora和taggers文件夹
手动下载[cmudict.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip)和[averaged_perceptron_tragger.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip)
将cmudict.zip放入corpora文件夹下
将averaged_perceptron_tragger.zip放入taggers文件夹下
4. 下载其他依赖
``` shell
cd TakwayPlatform/
pip install -r requirements.txt
```
#### (3) 修改配置
5. debug
若出现AttributeError: module 'botocore.exceptions' has no attribute 'HTTPClientError'异常
则执行下述命令
``` shell
pip uninstall botocore
pip install botocore==1.34.88
```
#### (4) 安装FunASR
本项目使用的FunASRE在github上的FunASR的基础上做了一些修改
``` shell
git clone http://43.132.157.186:3000/gaohz/FunASR.git
cd FunASR/
pip install -v -e .
```
#### (5) 修改配置
1. 安装mysql在mysql中创建名为takway的数据库
2. 安装redis将密码设置为takway
3. 打开config中的development.py文件修改mysql和redis连接字符串
#### (4) 导入vits模型
#### (6) 导入vits模型
在utils/tts/目录下创建vits_model文件夹
从[链接](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model)下载 vits_model并放入该文件夹下只需下载config.json和G_953000.pth即可
#### (7) 导入openvoice模型
在utils/tts/目录下创建openvoice_model文件夹
从[链接](https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip)下载model文件并放入该文件夹下
#### (8) 添加配置环境变量
``` shell
vim ~/.bashrc
在最后添加
export MODE=development
```
#### (9) 启动程序
``` shell
cd TakwayPlatform
python main.py
```

View File

@ -34,7 +34,7 @@ logger.info("数据库初始化完成")
#--------------------设置定时任务-----------------------
scheduler = AsyncIOScheduler()
scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
# scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
@asynccontextmanager
async def lifespan(app:FastAPI):
scheduler.start() #启动定时任务

View File

@ -1,35 +1,38 @@
from ..schemas.chat_schema import *
from ..dependencies.logger import get_logger
from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User
from ..models import UserCharacter, Session, Character, User, Audio
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from config import get_config
import numpy as np
import websockets
import struct
import uuid
import json
import asyncio
import requests
import aiohttp
import io
# 依赖注入获取logger
logger = get_logger()
# --------------------初始化本地ASR-----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer
# 依赖注入获取context总结服务
summarizer = get_summarizer()
asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -----------------------获取ASR-------------------------
asr = get_asr()
# -------------------------------------------------------
# --------------------初始化本地VITS----------------------
from utils.tts.vits_utils import TextToSpeech
tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -------------------------TTS--------------------------
tts = get_tts()
# -------------------------------------------------------
# 依赖注入获取Config
Config = get_config()
@ -50,16 +53,20 @@ def get_session_content(session_id,redis,db):
def parseChunkDelta(chunk):
try:
if chunk == b"":
return ""
return 1,""
decoded_data = chunk.decode('utf-8')
parsed_data = json.loads(decoded_data[6:])
if 'delta' in parsed_data['choices'][0]:
delta_content = parsed_data['choices'][0]['delta']
return delta_content['content']
return -1, delta_content['content']
else:
return "end"
return parsed_data['usage']['total_tokens'] , ""
except KeyError:
logger.error(f"error chunk: {chunk}")
logger.error(f"error chunk: {decoded_data}")
return 1,""
except json.JSONDecodeError:
logger.error(f"error chunk: {decoded_data}")
return 1,""
#断句函数
def split_string_with_punctuation(current_sentence,text,is_first,is_end):
@ -97,6 +104,19 @@ def update_session_activity(session_id,db):
except Exception as e:
db.roolback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
#获取target_se
def get_emb(session_id,db):
try:
session_record = db.query(Session).filter(Session.id == session_id).first()
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
return emb_npy
except Exception as e:
logger.debug("未找到音频:"+str(e))
return np.array([])
#--------------------------------------------------------
# 创建新聊天
@ -106,7 +126,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
try:
db.add(new_chat)
db.commit()
db.refresh(new_chat)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -132,18 +151,23 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
"speaker_id":db_character.voice_id,
"noise_scale": 0.1,
"noise_scale_w":0.668,
"length_scale": 1.2
"length_scale": 1.2,
"speed":1
}
llm_info = {
"model": "abab5.5-chat",
"temperature": 1,
"top_p": 0.9,
}
user_info = {
"character":"",
"events":[]
}
# 将tts和llm信息转化为json字符串
tts_info_str = json.dumps(tts_info, ensure_ascii=False)
llm_info_str = json.dumps(llm_info, ensure_ascii=False)
user_info_str = db_user.persona
user_info_str = json.dumps(user_info, ensure_ascii=False)
token = 0
content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
@ -154,7 +178,6 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
# 将Session记录存入
db.add(new_session)
db.commit()
db.refresh(new_session)
redis.set(session_id, json.dumps(content, ensure_ascii=False))
chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
@ -223,26 +246,75 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
logger.error(f"用户输入处理函数发生错误: {str(e)}")
#语音识别
async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
logger.debug("语音识别函数启动")
is_signup = False
try:
if Config.STRAM_CHAT.ASR == "LOCAL":
is_signup = False
audio = ""
try:
current_message = ""
while not (user_input_finish_event.is_set() and user_input_q.empty()):
if not is_signup:
asr.session_signup(session_id)
is_signup = True
audio_data = await user_input_q.get()
audio += audio_data
asr_result = asr.streaming_recognize(session_id,audio_data)
current_message += ''.join(asr_result['text'])
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
current_message += ''.join(asr_result['text'])
slice_arr = ["",""]
if current_message in slice_arr:
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
return
current_message = asr.punctuation_correction(current_message)
# emotion_dict = asr.emtion_recognition(audio) #情感辨识
# if not isinstance(emotion_dict, str):
# max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
# current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
await llm_input_q.put(current_message)
asr.session_signout(session_id)
except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}")
logger.debug(f"接收到用户消息: {current_message}")
elif Config.STRAM_CHAT.ASR == "XF":
status = FIRST_FRAME
xf_websocket = await xf_asr_websocket_factory() #获取一个讯飞语音识别接口websocket连接
segment_duration_threshold = 25 #设置一个连接时长上限讯飞语音接口超过30秒会自动断开连接所以该值设置成25秒
segment_start_time = asyncio.get_event_loop().time()
current_message = ""
while not (user_input_finish_event.is_set() and user_input_q.empty()):
if not is_signup:
asr.session_signup(session_id)
is_signup = True
audio_data = await user_input_q.get()
asr_result = asr.streaming_recognize(session_id,audio_data)
current_message += ''.join(asr_result['text'])
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
current_message += ''.join(asr_result['text'])
try:
audio_data = await user_input_q.get()
current_time = asyncio.get_event_loop().time()
if current_time - segment_start_time > segment_duration_threshold:
await xf_websocket.send(make_last_frame())
current_message += parse_xfasr_recv(await xf_websocket.recv())
await xf_websocket.close()
xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接
status = FIRST_FRAME
segment_start_time = current_time
if status == FIRST_FRAME:
await xf_websocket.send(make_first_frame(audio_data))
status = CONTINUE_FRAME
elif status == CONTINUE_FRAME:
await xf_websocket.send(make_continue_frame(audio_data))
except websockets.exceptions.ConnectionClosedOK:
logger.debug("讯飞语音识别接口连接断开,重新创建连接")
xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接
status = FIRST_FRAME
segment_start_time = asyncio.get_event_loop().time()
await xf_websocket.send(make_last_frame())
current_message += parse_xfasr_recv(await xf_websocket.recv())
await xf_websocket.close()
if current_message in ["", ""]:
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
return
await llm_input_q.put(current_message)
asr.session_signout(session_id)
except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}")
logger.debug(f"接收到用户消息: {current_message}")
logger.debug(f"接收到用户消息: {current_message}")
#大模型调用
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
@ -253,13 +325,18 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
is_first = True
is_end = False
session_content = get_session_content(session_id,redis,db)
user_info = json.loads(session_content["user_info"])
messages = json.loads(session_content["messages"])
current_message = await llm_input_q.get()
if current_message == "":
return
messages.append({'role': 'user', "content": current_message})
messages_send = messages #创造一个message副本在其中最后一条数据前面添加用户信息
# messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
payload = json.dumps({
"model": llm_info["model"],
"stream": True,
"messages": messages,
"messages": messages_send,
"max_tokens": 10000,
"temperature": llm_info["temperature"],
"top_p": llm_info["top_p"]
@ -268,35 +345,48 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True) #调用大模型
target_se = get_emb(session_id,db)
except Exception as e:
logger.error(f"llm调用发生错误: {str(e)}")
logger.error(f"编辑http请求时发生错误: {str(e)}")
try:
for chunk in response.iter_lines():
chunk_data = parseChunkDelta(chunk)
is_end = chunk_data == "end"
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句
for sentence in sentences:
if response_type == RESPONSE_TEXT:
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio) #返回音频数据
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False #重置is_end标志位
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any():
token_count, chunk_data = parseChunkDelta(chunk)
is_end = token_count >0
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句
for sentence in sentences:
if response_type == RESPONSE_TEXT:
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
elif response_type == RESPONSE_AUDIO:
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
header = struct.pack('!II',len(response_bytes),len(audio))
message_bytes = header + response_bytes + audio
await ws.send_bytes(message_bytes)
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False #重置is_end标志位
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except Exception as e:
logger.error(f"处理llm返回结果发生错误: {str(e)}")
chat_finished_event.set()
@ -315,7 +405,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
session_id = await future_session_id #获取session_id
update_session_activity(session_id,db)
response_type = await future_response_type #获取返回类型
asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event))
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@ -372,6 +462,7 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
logger.debug("语音识别函数启动")
is_signup = False
current_message = ""
audio = ""
while not (input_finished_event.is_set() and user_input_q.empty()):
try:
aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
@ -381,15 +472,23 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
if aduio_frame['is_end']:
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
current_message += ''.join(asr_result['text'])
current_message = asr.punctuation_correction(current_message)
audio += aduio_frame['audio']
emotion_dict =asr.emtion_recognition(audio) #情感辨识
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
await llm_input_q.put(current_message)
logger.debug(f"接收到用户消息: {current_message}")
current_message = ""
audio = ""
else:
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
audio += aduio_frame['audio']
current_message += ''.join(asr_result['text'])
except asyncio.TimeoutError:
continue
except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}")
break
asr.session_signout(session_id)
@ -406,6 +505,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
try:
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"])
user_info = json.loads(session_content["user_info"])
current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
messages.append({'role': 'user', "content": current_message})
payload = json.dumps({
@ -420,39 +520,49 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
for chunk in response.iter_lines():
chunk_data = parseChunkDelta(chunk)
is_end = chunk_data == "end"
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
if response_type == RESPONSE_TEXT:
logger.debug(f"websocket返回: {sentence}")
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO:
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any():
token_count, chunk_data = parseChunkDelta(chunk)
is_end = token_count >0
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
if response_type == RESPONSE_TEXT:
logger.debug(f"websocket返回: {sentence}")
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
elif response_type == RESPONSE_AUDIO:
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
response_message = {"type": "text", "code":200, "msg": sentence}
await ws.send_bytes(audio)
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于70%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"处理llm返回结果发生错误: {str(e)}")
break
# except Exception as e:
# logger.error(f"处理llm返回结果发生错误: {str(e)}")
# break
chat_finished_event.set()
async def streaming_chat_lasting_handler(ws,db,redis):
@ -523,6 +633,7 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
current_message = ""
vad_count = 0
is_signup = False
audio = ""
while not (input_finished_event.is_set() and audio_q.empty()):
try:
if not is_signup:
@ -533,14 +644,22 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
if vad_count > 0:
vad_count -= 1
asr_result = asr.streaming_recognize(session_id, audio_data)
audio += audio_data
current_message += ''.join(asr_result['text'])
else:
vad_count += 1
if vad_count >= 25: #连续25帧没有语音则认为说完了
asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
if current_message:
current_message = asr.punctuation_correction(current_message)
audio += audio_data
emotion_dict =asr.emtion_recognition(audio) #情感辨识
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
logger.debug(f"检测到静默,用户输入为:{current_message}")
await asr_result_q.put(current_message)
audio = ""
text_response = {"type": "user_text", "code": 200, "msg": current_message}
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
current_message = ""
@ -565,6 +684,7 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
try:
session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"])
user_info = json.loads(session_content["user_info"])
current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
messages.append({'role': 'user', "content": current_message})
payload = json.dumps({
@ -575,34 +695,43 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
"temperature": llm_info["temperature"],
"top_p": llm_info["top_p"]
})
headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
for chunk in response.iter_lines():
chunk_data = parseChunkDelta(chunk)
is_end = chunk_data == "end"
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
logger.debug(f"llm返回结果: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
target_se = get_emb(session_id,db)
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求
async for chunk in response.content.iter_any():
token_count, chunk_data = parseChunkDelta(chunk)
is_end = token_count >0
if not is_end:
llm_response += chunk_data
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
for sentence in sentences:
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
logger.debug(f"websocket返回: {sentence}")
if is_end:
logger.debug(f"llm返回结果: {llm_response}")
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False
messages.append({'role': 'assistant', "content": llm_response})
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except asyncio.TimeoutError:
continue
except Exception as e:
@ -610,23 +739,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
break
voice_call_end_event.set()
#语音合成及返回函数
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
logger.debug("语音合成及返回函数启动")
while not (split_finished_event.is_set() and split_result_q.empty()):
try:
sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
logger.debug(f"websocket返回:{sentence}")
except asyncio.TimeoutError:
continue
voice_call_end_event.set()
async def voice_call_handler(ws, db, redis):
logger.debug("voice_call websocket 连接建立")
audio_q = asyncio.Queue() #音频队列

View File

@ -77,3 +77,32 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
#更新Session中的Speaker Id信息
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
existing_session = ""
if redis.exists(session_id):
existing_session = redis.get(session_id)
else:
existing_session = db.query(Session).filter(Session.id == session_id).first().content
if not existing_session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
#更新Session字段
session = json.loads(existing_session)
session_llm_info = json.loads(session["llm_info"])
session_llm_info["speaker_id"] = session_data.speaker_id
session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
#存储Session
session_str = json.dumps(session,ensure_ascii=False)
redis.set(session_id, session_str)
try:
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)

View File

@ -1,14 +1,20 @@
from ..schemas.user_schema import *
from ..dependencies.logger import get_logger
from ..models import User, Hardware
from ..dependencies.tts import get_tts
from ..models import User, Hardware, Audio
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from pydub import AudioSegment
import numpy as np
import io
#依赖注入获取logger
logger = get_logger()
#依赖注入获取tts
tts = get_tts("OPENVOICE")
#创建用户
async def create_user_handler(user:UserCrateRequest, db: Session):
@ -36,7 +42,6 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
existing_user.persona = user.persona
try:
db.commit()
db.refresh(existing_user)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -117,7 +122,6 @@ async def change_bind_hardware_handler(hardware_id, user, db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
existing_hardware.user_id = user.user_id
db.commit()
db.refresh(existing_hardware)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -135,7 +139,6 @@ async def update_hardware_handler(hardware_id, hardware, db):
existing_hardware.firmware = hardware.firmware
existing_hardware.model = hardware.model
db.commit()
db.refresh(existing_hardware)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -154,3 +157,96 @@ async def get_hardware_handler(hardware_id, db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
#用户上传音频
async def upload_audio_handler(user_id, audio, db):
try:
audio_data = await audio.read()
emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
out = io.BytesIO()
np.save(out, emb_data)
out.seek(0)
emb_binary = out.read()
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #创建音频
db.add(new_audio)
db.commit()
db.refresh(new_audio)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
#用户更新音频
async def update_audio_handler(audio_id, audio_file, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = await audio_file.read()
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
existing_audio.audio_data = audio_data
existing_audio.emb_data = emb_data
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
return AudioUpdateResponse(status="success", message="用户更新音频成功", data=audio_update_data)
#用户查询音频
async def download_audio_handler(audio_id, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = existing_audio.audio_data
return audio_data
#用户删除音频
async def delete_audio_handler(audio_id, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
try:
if existing_user.selected_audio_id == audio_id:
existing_user.selected_audio_id = None
db.delete(existing_audio)
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
#用户绑定音频
async def bind_audio_handler(bind_req, db):
try:
existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
try:
existing_user.selected_audio_id = bind_req.audio_id
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)

11
app/dependencies/asr.py Normal file
View File

@ -0,0 +1,11 @@
from utils.stt.modified_funasr import ModifiedRecognizer
from app.dependencies.logger import get_logger
logger = get_logger()
#初始化全局asr对象
asr = ModifiedRecognizer()
logger.info("ASR初始化成功")
def get_asr():
return asr

View File

@ -0,0 +1,61 @@
import aiohttp
import json
from datetime import datetime
from config import get_config
# 依赖注入获取Config
Config = get_config()
class Summarizer:
def __init__(self):
self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json里面有两个字段一个是event一个是character将你总结出的事件写入event将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
self.model = "abab5.5-chat"
self.max_token = 10000
self.temperature = 0.9
self.top_p = 1
async def summarize(self,messages):
context = ""
for message in messages:
if message['role'] == 'user':
context += "用户:"+ message['content'] + '\n'
elif message['role'] == 'assistant':
context += '玩具:'+ message['content'] + '\n'
payload = json.dumps({
"model":self.model,
"messages":[
{
"role":"system",
"content":self.system_prompt
},
{
"role":"user",
"content":context
}
],
"max_tokens":self.max_token,
"top_p":self.top_p
})
headers = {
'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as session:
async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
try:
summary = {
'event':datetime.now().strftime("%Y-%m-%d")+""+content['event'],
'character':content['character']
}
except TypeError:
summary = {
'event':datetime.now().strftime("%Y-%m-%d")+""+content['event'][0],
'character':""
}
return summary
def get_summarizer():
summarizer = Summarizer()
return summarizer

23
app/dependencies/tts.py Normal file
View File

@ -0,0 +1,23 @@
from app.dependencies.logger import get_logger
from config import get_config
logger = get_logger()
Config = get_config()
from utils.tts.openvoice_utils import TextToSpeech
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS_OPENVOICE 初始化成功")
from utils.tts.vits_utils import TextToSpeech
vits_tts = TextToSpeech()
logger.info("TTS_VITS 初始化成功")
#初始化全局tts对象
def get_tts(tts_type=Config.TTS_UTILS):
if tts_type == "OPENVOICE":
return openvoice_tts
elif tts_type == "VITS":
return vits_tts

View File

@ -1,26 +0,0 @@
from app import app, Config
import uvicorn
if __name__ == "__main__":
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
#uvicorn.run("app.main:app", host=Config.UVICORN.HOST, port=Config.UVICORN.PORT, workers=Config.UVICORN.WORKERS)
# _ooOoo_ #
# o8888888o #
# 88" . "88 #
# (| -_- |) #
# O\ = /O #
# ____/`---'\____ #
# . ' \\| |// `. #
# / \\||| : |||// \ #
# / _||||| -:- |||||- \ #
# | | \\\ - /// | | #
# \ .-\__ `-` ___/-. / #
# ___`. .' /--.--\ `. . __ #
# ."" '< `.___\_<|>_/___.' >'"". #
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
# \ \ `-. \_ __\ /__ _/ .-` / / #
# ======`-.____`-.___\_____/___.-`____.-'====== #
# `=---=' #
# ............................................. #
# 佛祖保佑 永无BUG #

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
@ -36,6 +36,7 @@ class User(Base):
avatar_id = Column(String(36), nullable=True)
tags = Column(JSON)
persona = Column(JSON)
selected_audio_id = Column(Integer, nullable=True)
def __repr__(self):
return f"<User(id={self.id}, tags={self.tags})>"
@ -80,3 +81,12 @@ class Session(Base):
def __repr__(self):
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
#音频表定义
class Audio(Base):
__tablename__ = 'audio'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('user.id'))
audio_data = Column(LargeBinary(16777215))
emb_data = Column(LargeBinary(65535))

View File

@ -27,3 +27,10 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_handler(session_id, session_data, db, redis)
return response
#session声音信息更新接口
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
return response

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, UploadFile, File, Response
from ..controllers.user_controller import *
from fastapi import Depends
from sqlalchemy.orm import Session
@ -69,3 +69,38 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
response = await get_hardware_handler(hardware_id, db)
return response
#用户音频上传
@router.post('/users/audio',response_model=AudioUploadResponse)
async def upload_audio(user_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
response = await upload_audio_handler(user_id, audio_file, db)
return response
#用户音频修改
@router.put('/users/audio/{audio_id}',response_model=AudioUpdateResponse)
async def update_audio(audio_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
response = await update_audio_handler(audio_id, audio_file, db)
return response
#用户音频下载
@router.get('/users/audio/{audio_id}')
async def download_audio(audio_id:int, db: Session = Depends(get_db)):
audio_data = await download_audio_handler(audio_id, db)
return Response(content=audio_data,media_type='application/octet-stream',headers={"Content-Disposition": "attachment"})
#用户音频删除
@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
response = await delete_audio_handler(audio_id, db)
return response
#用户绑定音频
@router.post('/users/audio/bind',response_model=AudioBindResponse)
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
response = await bind_audio_handler(bind_req, db)
return response

View File

@ -56,3 +56,15 @@ class SessionUpdateData(BaseModel):
class SessionUpdateResponse(BaseResponse):
data: Optional[SessionUpdateData]
#--------------------------------------------------------------------------
#------------------------------Session Speaker Id修改----------------------
class SessionSpeakerUpdateRequest(BaseModel):
speaker_id: int
class SessionSpeakerUpdateData(BaseModel):
updatedAt:str
class SessionSpeakerUpdateResponse(BaseResponse):
data: Optional[SessionSpeakerUpdateData]
#--------------------------------------------------------------------------

View File

@ -3,7 +3,6 @@ from typing import Optional
from .base_schema import BaseResponse
#---------------------------------用户创建----------------------------------
#用户创建请求类
class UserCrateRequest(BaseModel):
@ -138,3 +137,46 @@ class HardwareQueryData(BaseModel):
class HardwareQueryResponse(BaseResponse):
data: Optional[HardwareQueryData]
#------------------------------------------------------------------------------
#-------------------------------用户音频上传-------------------------------------
class AudioUploadData(BaseModel):
audio_id: int
uploadedAt: str
class AudioUploadResponse(BaseResponse):
data: Optional[AudioUploadData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频修改-------------------------------------
class AudioUpdateData(BaseModel):
updatedAt: str
class AudioUpdateResponse(BaseResponse):
data: Optional[AudioUpdateData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频删除-------------------------------------
class AudioDeleteData(BaseModel):
deletedAt: str
class AudioDeleteResponse(BaseResponse):
data: Optional[AudioDeleteData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频绑定-------------------------------------
class AudioBindRequest(BaseModel):
audio_id: int
user_id: int
class AudioBindData(BaseModel):
bindedAt: str
class AudioBindResponse(BaseResponse):
data: Optional[AudioBindData]
#-------------------------------------------------------------------------------

View File

@ -1,27 +1,28 @@
class DevelopmentConfig:
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
LOG_LEVEL = "DEBUG" #日志级别
TTS_UTILS = "VITS" #TTS引擎配置可选OPENVOICE或者VITS
class UVICORN:
HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip
PORT = 7878 #uvicorn运行端口
PORT = 8001 #uvicorn运行端口
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
class XF_ASR:
APP_ID = "your_app_id" #讯飞语音识别APP_ID
API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET
API_KEY = "your_api_key" #讯飞语音识别API_KEY
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY
DOMAIN = "iat"
LANGUAGE = "zh_cn"
ACCENT = "mandarin"
VAD_EOS = 10000
class MINIMAX_LLM:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
class MINIMAX_TTA:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA",
URL = "https://api.minimax.chat/v1/t2a_pro",
GROUP_ID ="1759482180095975904"
class STRAM_CHAT:
ASR = "LOCAL"
ASR = "XF" # 语音识别引擎可选XF或者LOCAL
TTS = "LOCAL"

View File

@ -1,284 +0,0 @@
import os
import io
import numpy as np
import pyaudio
import wave
import base64
"""
audio utils for modified_funasr_demo.py
"""
def decode_str2bytes(data):
# 将Base64编码的字节串解码为字节串
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class BaseAudio:
def __init__(self,
filename=None,
input=False,
output=False,
CHUNK=1024,
FORMAT=pyaudio.paInt16,
CHANNELS=1,
RATE=16000,
input_device_index=None,
output_device_index=None,
**kwargs):
self.CHUNK = CHUNK
self.FORMAT = FORMAT
self.CHANNELS = CHANNELS
self.RATE = RATE
self.filename = filename
assert input!= output, "input and output cannot be the same, \
but got input={} and output={}.".format(input, output)
print("------------------------------------------")
print(f"{'Input' if input else 'Output'} Audio Initialization: ")
print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
print("------------------------------------------")
self.p = pyaudio.PyAudio()
self.stream = self.p.open(format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=input,
output=output,
input_device_index=input_device_index,
output_device_index=output_device_index,
**kwargs)
def load_audio_file(self, wav_file):
with wave.open(wav_file, 'rb') as wf:
params = wf.getparams()
frames = wf.readframes(params.nframes)
print("Audio file loaded.")
# Audio Parameters
# print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype)
return frames
def check_audio_type(self, audio_data, return_type=None):
assert return_type in ['bytes', 'io', None], \
"return_type should be 'bytes', 'io' or None."
if isinstance(audio_data, str):
if len(audio_data) > 50:
audio_data = decode_str2bytes(audio_data)
else:
assert os.path.isfile(audio_data), \
"audio_data should be a file path or a bytes object."
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
elif isinstance(audio_data, bytes):
pass
else:
raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
but got {type(audio_data)}")
if return_type == None:
return audio_data
return self.write_wave(None, [audio_data], return_type)
def write_wave(self, filename, frames, return_type='io'):
"""Write audio data to a file."""
if isinstance(frames, bytes):
frames = [frames]
if not isinstance(frames, list):
raise TypeError("frames should be \
a list of bytes or a bytes object, \
but got {}.".format(type(frames)))
if return_type == 'io':
if filename is None:
filename = io.BytesIO()
if self.filename:
filename = self.filename
return self.write_wave_io(filename, frames)
elif return_type == 'bytes':
return self.write_wave_bytes(frames)
def write_wave_io(self, filename, frames):
"""
Write audio data to a file-like object.
Args:
filename: [string or file-like object], file path or file-like object to write
frames: list of bytes, audio data to write
"""
wf = wave.open(filename, 'wb')
# 设置WAV文件的参数
wf.setnchannels(self.CHANNELS)
wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
wf.setframerate(self.RATE)
wf.writeframes(b''.join(frames))
wf.close()
if isinstance(filename, io.BytesIO):
filename.seek(0) # reset file pointer to beginning
return filename
def write_wave_bytes(self, frames):
"""Write audio data to a bytes object."""
return b''.join(frames)
class BaseAudio:
def __init__(self,
filename=None,
input=False,
output=False,
CHUNK=1024,
FORMAT=pyaudio.paInt16,
CHANNELS=1,
RATE=16000,
input_device_index=None,
output_device_index=None,
**kwargs):
self.CHUNK = CHUNK
self.FORMAT = FORMAT
self.CHANNELS = CHANNELS
self.RATE = RATE
self.filename = filename
assert input!= output, "input and output cannot be the same, \
but got input={} and output={}.".format(input, output)
print("------------------------------------------")
print(f"{'Input' if input else 'Output'} Audio Initialization: ")
print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
print("------------------------------------------")
self.p = pyaudio.PyAudio()
self.stream = self.p.open(format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=input,
output=output,
input_device_index=input_device_index,
output_device_index=output_device_index,
**kwargs)
def load_audio_file(self, wav_file):
with wave.open(wav_file, 'rb') as wf:
params = wf.getparams()
frames = wf.readframes(params.nframes)
print("Audio file loaded.")
# Audio Parameters
# print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype)
return frames
def check_audio_type(self, audio_data, return_type=None):
assert return_type in ['bytes', 'io', None], \
"return_type should be 'bytes', 'io' or None."
if isinstance(audio_data, str):
if len(audio_data) > 50:
audio_data = decode_str2bytes(audio_data)
else:
assert os.path.isfile(audio_data), \
"audio_data should be a file path or a bytes object."
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
elif isinstance(audio_data, bytes):
pass
else:
raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
but got {type(audio_data)}")
if return_type == None:
return audio_data
return self.write_wave(None, [audio_data], return_type)
def write_wave(self, filename, frames, return_type='io'):
"""Write audio data to a file."""
if isinstance(frames, bytes):
frames = [frames]
if not isinstance(frames, list):
raise TypeError("frames should be \
a list of bytes or a bytes object, \
but got {}.".format(type(frames)))
if return_type == 'io':
if filename is None:
filename = io.BytesIO()
if self.filename:
filename = self.filename
return self.write_wave_io(filename, frames)
elif return_type == 'bytes':
return self.write_wave_bytes(frames)
def write_wave_io(self, filename, frames):
"""
Write audio data to a file-like object.
Args:
filename: [string or file-like object], file path or file-like object to write
frames: list of bytes, audio data to write
"""
wf = wave.open(filename, 'wb')
# 设置WAV文件的参数
wf.setnchannels(self.CHANNELS)
wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
wf.setframerate(self.RATE)
wf.writeframes(b''.join(frames))
wf.close()
if isinstance(filename, io.BytesIO):
filename.seek(0) # reset file pointer to beginning
return filename
def write_wave_bytes(self, frames):
"""Write audio data to a bytes object."""
return b''.join(frames)
class BaseRecorder(BaseAudio):
def __init__(self,
input=True,
base_chunk_size=None,
RATE=16000,
**kwargs):
super().__init__(input=input, RATE=RATE, **kwargs)
self.base_chunk_size = base_chunk_size
if base_chunk_size is None:
self.base_chunk_size = self.CHUNK
def record(self,
filename,
duration=5,
return_type='io',
logger=None):
if logger is not None:
logger.info("Recording started.")
else:
print("Recording started.")
frames = []
for i in range(0, int(self.RATE / self.CHUNK * duration)):
data = self.stream.read(self.CHUNK, exception_on_overflow=False)
frames.append(data)
if logger is not None:
logger.info("Recording stopped.")
else:
print("Recording stopped.")
return self.write_wave(filename, frames, return_type)
def record_chunk_voice(self,
return_type='bytes',
CHUNK=None,
exception_on_overflow=True,
queue=None):
data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK,
exception_on_overflow=exception_on_overflow)
if return_type is not None:
return self.write_wave(None, [data], return_type)
return data

View File

@ -1,39 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer
def asr_file_stream(file_path=r'.\assets\example_recording.wav'):
# 读入音频文件
rec = BaseRecorder()
data = rec.load_audio_file(file_path)
# 创建模型
asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
asr.session_signup("test")
# 记录目标说话人
asr.initialize_speaker(r".\assets\example_recording.wav")
# 语音识别
print("===============================================")
text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
print(f"text_dict: {text_dict}")
if not isinstance(text_dict, str):
print("".join(text_dict['text']))
# 情感识别
print("===============================================")
emotion_dict = asr.recognize_emotion(data)
print(f"emotion_dict: {emotion_dict}")
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
print("emotion: " +emotion_dict['labels'][max_index])
asr_file_stream()

View File

@ -1 +0,0 @@
存储目标说话人的语音特征,如要修改路径,请修改 utils/stt/speaker_ver_utils中的DEFALUT_SAVE_PATH

View File

@ -1,214 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Importing the dtw module. When using in academic works please cite:\n",
" T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
" J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
"\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.append(\"../\")\n",
"from utils.tts.openvoice_utils import TextToSpeech\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
"Building prefix dict from the default dictionary ...\n",
"Loading model from cache C:\\Users\\bing\\AppData\\Local\\Temp\\jieba.cache\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"load base tts model successfully!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading model cost 0.304 seconds.\n",
"Prefix dict has been built successfully.\n",
"Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv_transpose1d(\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv1d(input, weight, bias, self.stride,\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([81565])]\n",
"True\n",
"Loaded checkpoint 'D:\\python\\OpenVoice\\checkpoints_v2\\converter/checkpoint.pth'\n",
"missing/unexpected keys: [] []\n",
"load tone color converter successfully!\n"
]
}
],
"source": [
"model = TextToSpeech(use_tone_convert=True, device=\"cuda\", debug=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# 测试用将mp3转为int32类型的numpy对齐输入端\n",
"from pydub import AudioSegment\n",
"import numpy as np\n",
"source_audio=r\"D:\\python\\OpenVoice\\resources\\demo_speaker0.mp3\"\n",
"audio = AudioSegment.from_file(source_audio, format=\"mp3\")\n",
"raw_data = audio.raw_data\n",
"audio_array = np.frombuffer(raw_data, dtype=np.int32)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenVoice version: v2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\functional.py:665: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.\n",
"Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\SpectralOps.cpp:878.)\n",
" return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"# 获取并设置目标说话人的speaker embedding\n",
"# audio_array :输入的音频信号,类型为 np.ndarray\n",
"# 获取speaker embedding\n",
"target_se = model.audio2emb(audio_array, rate=44100, vad=True)\n",
"# 将模型的默认目标说话人embedding设置为 target_se\n",
"model.initialize_target_se(target_se)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv_transpose1d(\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv1d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([216378])]\n",
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\n"
]
}
],
"source": [
"# 测试base_tts不含音色转换\n",
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
"audio, sr = model._base_tts(text, speed=1)\n",
"audio = model.tensor2numpy(audio)\n",
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([216378])]\n",
"torch.float32\n",
"**********************************, convert sr 22050\n",
"tone color has been converted!\n",
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo.wav\n"
]
}
],
"source": [
"# 测试整体pipeline包含音色转换\n",
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
"audio_bytes, sr = model.tts(text, speed=1)\n",
"audio = np.frombuffer(audio_bytes, dtype=np.int16).flatten()\n",
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo.wav\" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openVoice",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

31
main.py
View File

@ -1,9 +1,24 @@
import os
from app import app, Config
import uvicorn
if __name__ == '__main__':
script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
# 使用exec函数执行脚本
with open(script_path, 'r') as file:
exec(file.read())
if __name__ == "__main__":
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
# _ooOoo_ #
# o8888888o #
# 88" . "88 #
# (| -_- |) #
# O\ = /O #
# ____/`---'\____ #
# . ' \\| |// `. #
# / \\||| : |||// \ #
# / _||||| -:- |||||- \ #
# | | \\\ - /// | | #
# \ .-\__ `-` ___/-. / #
# ___`. .' /--.--\ `. . __ #
# ."" '< `.___\_<|>_/___.' >'"". #
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
# \ \ `-. \_ __\ /__ _/ .-` / / #
# ======`-.____`-.___\_____/___.-`____.-'====== #
# `=---=' #
# ............................................. #
# 佛祖保佑 永无BUG #

View File

@ -7,7 +7,6 @@ redis
requests
websockets
numpy
funasr
jieba
cn2an
unidecode
@ -19,3 +18,8 @@ numba
soundfile
webrtcvad
apscheduler
aiohttp
faster_whisper
whisper_timestamped
modelscope
wavmark

Binary file not shown.

BIN
tests/assets/iat_mp3_8k.mp3 Normal file

Binary file not shown.

View File

@ -1,31 +1,9 @@
from tests.unit_test.user_test import UserServiceTest
from tests.unit_test.character_test import CharacterServiceTest
from tests.unit_test.chat_test import ChatServiceTest
import asyncio
from tests.unit_test.user_test import user_test
from tests.unit_test.character_test import character_test
from tests.unit_test.chat_test import chat_test
if __name__ == '__main__':
user_service_test = UserServiceTest()
character_service_test = CharacterServiceTest()
chat_service_test = ChatServiceTest()
user_service_test.test_user_create()
user_service_test.test_user_update()
user_service_test.test_user_query()
user_service_test.test_hardware_bind()
user_service_test.test_hardware_unbind()
user_service_test.test_user_delete()
character_service_test.test_character_create()
character_service_test.test_character_update()
character_service_test.test_character_query()
character_service_test.test_character_delete()
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
chat_service_test.test_session_update()
asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()
user_test()
character_test()
chat_test()
print("全部测试成功")

View File

@ -2,7 +2,7 @@ import requests
import json
class CharacterServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
def __init__(self,socket="http://127.0.0.1:8001"):
self.socket = socket
def test_character_create(self):
@ -66,9 +66,14 @@ class CharacterServiceTest:
else:
raise Exception("角色删除测试失败")
if __name__ == '__main__':
def character_test():
character_service_test = CharacterServiceTest()
character_service_test.test_character_create()
character_service_test.test_character_update()
character_service_test.test_character_query()
character_service_test.test_character_delete()
if __name__ == '__main__':
character_test()

View File

@ -10,7 +10,7 @@ import websockets
class ChatServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
def __init__(self,socket="http://127.0.0.1:7878"):
self.socket = socket
@ -30,6 +30,7 @@ class ChatServiceTest:
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("用户创建成功")
self.user_id = response.json()['data']['user_id']
else:
raise Exception("创建聊天时,用户创建失败")
@ -57,6 +58,37 @@ class ChatServiceTest:
else:
raise Exception("创建聊天时,角色创建失败")
#上传音频用于音频克隆
url = f"{self.socket}/users/audio?user_id={self.user_id}"
current_file_path = os.path.abspath(__file__)
current_file_path = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_file_path)
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(mp3_file_path, 'rb') as audio_file:
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
response = requests.post(url,files=files)
if response.status_code == 200:
self.audio_id = response.json()['data']['audio_id']
print("音频上传成功")
else:
raise Exception("音频上传失败")
#绑定音频
url = f"{self.socket}/users/audio/bind"
payload = json.dumps({
"user_id":self.user_id,
"audio_id":self.audio_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("音频绑定测试成功")
else:
raise Exception("音频绑定测试失败")
#创建一个对话
url = f"{self.socket}/chats"
payload = json.dumps({
@ -66,6 +98,7 @@ class ChatServiceTest:
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("对话创建成功")
@ -101,8 +134,8 @@ class ChatServiceTest:
payload = json.dumps({
"user_id": self.user_id,
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]",
"user_info": "{}",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
"user_info": "{\"character\": \"\", \"events\": [] }",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
"token": 0}
)
@ -115,6 +148,19 @@ class ChatServiceTest:
else:
raise Exception("Session更新测试失败")
def test_session_speakerid_update(self):
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
payload = json.dumps({
"speaker_id" :37
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("Session SpeakerId更新测试成功")
else:
raise Exception("Session SpeakerId更新测试失败")
#测试单次聊天
async def test_chat_temporary(self):
@ -149,7 +195,7 @@ class ChatServiceTest:
await websocket.send(message)
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/temporary') as websocket:
chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块
for chunk in chunks:
await send_audio_chunk(websocket, chunk)
@ -205,7 +251,7 @@ class ChatServiceTest:
message = json.dumps(data)
await websocket.send(message)
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/lasting') as websocket:
#发送第一次
chunks = read_wav_file_in_chunks(2048)
for chunk in chunks:
@ -255,7 +301,7 @@ class ChatServiceTest:
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
url = f"ws://114.214.236.207:7878/chat/voice_call"
url = f"ws://127.0.0.1:7878/chat/voice_call"
#发送格式
ws_data = {
"audio" : "",
@ -293,7 +339,6 @@ class ChatServiceTest:
await asyncio.gather(audio_stream(websocket))
#测试删除聊天
def test_chat_delete(self):
url = f"{self.socket}/chats/{self.user_character_id}"
@ -303,6 +348,11 @@ class ChatServiceTest:
else:
raise Exception("聊天删除测试失败")
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("音频删除测试失败")
url = f"{self.socket}/users/{self.user_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
@ -313,17 +363,18 @@ class ChatServiceTest:
if response.status_code != 200:
raise Exception("角色删除测试失败")
if __name__ == '__main__':
def chat_test():
chat_service_test = ChatServiceTest()
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
chat_service_test.test_session_update()
chat_service_test.test_session_speakerid_update()
asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()
if __name__ == '__main__':
chat_test()

View File

@ -0,0 +1,12 @@
from chat_test import chat_test
import multiprocessing
if __name__ == '__main__':
processes = []
for _ in range(2):
p = multiprocessing.Process(target=chat_test)
processes.append(p)
p.start()
for p in processes:
p.join()

View File

@ -1,10 +1,11 @@
import requests
import json
import uuid
import os
class UserServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
def __init__(self,socket="http://127.0.0.1:7878"):
self.socket = socket
def test_user_create(self):
@ -66,7 +67,7 @@ class UserServiceTest:
mac = "08:00:20:0A:8C:6G"
payload = json.dumps({
"mac":mac,
"user_id":1,
"user_id":self.id,
"firmware":"v1.0",
"model":"香橙派"
})
@ -88,12 +89,122 @@ class UserServiceTest:
else:
raise Exception("硬件解绑测试失败")
if __name__ == '__main__':
def test_hardware_bind_change(self):
url = f"{self.socket}/users/hardware/{self.hd_id}/bindchange"
payload = json.dumps({
"user_id" : self.id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("硬件换绑测试成功")
else:
raise Exception("硬件换绑测试失败")
def test_hardware_update(self):
url = f"{self.socket}/users/hardware/{self.hd_id}/info"
payload = json.dumps({
"mac":"08:00:20:0A:8C:6G",
"firmware":"v1.0",
"model":"香橙派"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("硬件信息更新测试成功")
else:
raise Exception("硬件信息更新测试失败")
def test_hardware_query(self):
url = f"{self.socket}/users/hardware/{self.hd_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("硬件查询测试成功")
else:
raise Exception("硬件查询测试失败")
def test_upload_audio(self):
url = f"{self.socket}/users/audio?user_id={self.id}"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
response = requests.post(url, files=files)
if response.status_code == 200:
self.audio_id = response.json()["data"]['audio_id']
print("音频上传测试成功")
else:
raise Exception("音频上传测试失败")
def test_update_audio(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
response = requests.put(url, files=files)
if response.status_code == 200:
print("音频上传测试成功")
else:
raise Exception("音频上传测试失败")
def test_bind_audio(self):
url = f"{self.socket}/users/audio/bind"
payload = json.dumps({
"user_id":self.id,
"audio_id":self.audio_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("音频绑定测试成功")
else:
raise Exception("音频绑定测试失败")
def test_audio_download(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("音频下载测试成功")
else:
raise Exception("音频下载测试失败")
def test_audio_delete(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("音频删除测试成功")
else:
raise Exception("音频删除测试失败")
def user_test():
user_service_test = UserServiceTest()
user_service_test.test_user_create()
user_service_test.test_user_update()
user_service_test.test_user_query()
user_service_test.test_hardware_bind()
user_service_test.test_hardware_bind_change()
user_service_test.test_hardware_update()
user_service_test.test_hardware_query()
user_service_test.test_hardware_unbind()
user_service_test.test_upload_audio()
user_service_test.test_update_audio()
user_service_test.test_bind_audio()
user_service_test.test_audio_download()
user_service_test.test_audio_delete()
user_service_test.test_user_delete()
if __name__ == '__main__':
user_test()

View File

@ -1 +0,0 @@
./ses 保存 source se 的 embedding 路径,格式为 *.pth

View File

@ -29,6 +29,7 @@ class FunAutoSpeechRecognizer(STTBase):
**kwargs):
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
@ -79,9 +80,9 @@ class FunAutoSpeechRecognizer(STTBase):
def _init_asr(self):
# 随机初始化一段音频数据
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
self.audio_cache = {}
self.asr_cache = {}
self.session_signup("init")
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
self.session_signout("init")
# print("init ASR model done.")
# when chat trying to use asr , sign up
@ -108,6 +109,79 @@ class FunAutoSpeechRecognizer(STTBase):
"""
text_dict = dict(text=[], is_end=is_end)
audio_cache = self.audio_cache[session_id]
audio_data = self.check_audio_type(audio_data)
if audio_cache is None:
audio_cache = audio_data
else:
if audio_cache.shape[0] > 0:
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
self.audio_cache[session_id] = audio_cache
return text_dict
total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = audio_cache[start_idx:end_idx]
# TODO: exceptions processes
# print("i:", i)
try:
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
except ValueError as e:
print(f"ValueError: {e}")
continue
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
audio_cache = None
else:
if end_idx:
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
self.audio_cache[session_id] = audio_cache
return text_dict
def streaming_recognize_origin(self,
session_id,
audio_data,
is_end=False,
auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
auto_det_end: bool, whether to automatically detect the end of a audio data
"""
text_dict = dict(text=[], is_end=is_end)
audio_cache = self.audio_cache[session_id]
asr_cache = self.asr_cache[session_id]
@ -168,175 +242,3 @@ class FunAutoSpeechRecognizer(STTBase):
self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache
return text_dict
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
# import io
# import numpy as np
# import base64
# import wave
# from funasr import AutoModel
# from .base_stt import STTBase
# def decode_str2bytes(data):
# # 将Base64编码的字节串解码为字节串
# if data is None:
# return None
# return base64.b64decode(data.encode('utf-8'))
# class FunAutoSpeechRecognizer(STTBase):
# def __init__(self,
# model_path="paraformer-zh-streaming",
# device="cuda",
# RATE=16000,
# cfg_path=None,
# debug=False,
# chunk_ms=480,
# encoder_chunk_look_back=4,
# decoder_chunk_look_back=1,
# **kwargs):
# super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
# self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
# self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
# self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
# #[0, 8, 4] 480ms, [0, 10, 5] 600ms
# if chunk_ms == 480:
# self.chunk_size = [0, 8, 4]
# elif chunk_ms == 600:
# self.chunk_size = [0, 10, 5]
# else:
# raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
# self.chunk_partial_size = self.chunk_size[1] * 960
# self.audio_cache = None
# self.asr_cache = {}
# self._init_asr()
# def check_audio_type(self, audio_data):
# """check audio data type and convert it to bytes if necessary."""
# if isinstance(audio_data, bytes):
# pass
# elif isinstance(audio_data, list):
# audio_data = b''.join(audio_data)
# elif isinstance(audio_data, str):
# audio_data = decode_str2bytes(audio_data)
# elif isinstance(audio_data, io.BytesIO):
# wf = wave.open(audio_data, 'rb')
# audio_data = wf.readframes(wf.getnframes())
# elif isinstance(audio_data, np.ndarray):
# pass
# else:
# raise TypeError(f"audio_data must be bytes, list, str, \
# io.BytesIO or numpy array, but got {type(audio_data)}")
# if isinstance(audio_data, bytes):
# audio_data = np.frombuffer(audio_data, dtype=np.int16)
# elif isinstance(audio_data, np.ndarray):
# if audio_data.dtype != np.int16:
# audio_data = audio_data.astype(np.int16)
# else:
# raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
# return audio_data
# def _init_asr(self):
# # 随机初始化一段音频数据
# init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
# self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# self.audio_cache = None
# self.asr_cache = {}
# # print("init ASR model done.")
# def recognize(self, audio_data):
# """recognize audio data to text"""
# audio_data = self.check_audio_type(audio_data)
# result = self.asr_model.generate(input=audio_data,
# batch_size_s=300,
# hotword=self.hotwords)
# # print(result)
# text = ''
# for res in result:
# text += res['text']
# return text
# def streaming_recognize(self,
# audio_data,
# is_end=False,
# auto_det_end=False):
# """recognize partial result
# Args:
# audio_data: bytes or numpy array, partial audio data
# is_end: bool, whether the audio data is the end of a sentence
# auto_det_end: bool, whether to automatically detect the end of a audio data
# """
# text_dict = dict(text=[], is_end=is_end)
# audio_data = self.check_audio_type(audio_data)
# if self.audio_cache is None:
# self.audio_cache = audio_data
# else:
# # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
# if self.audio_cache.shape[0] > 0:
# self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
# if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
# return text_dict
# total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
# if is_end:
# # if the audio data is the end of a sentence, \
# # we need to add one more chunk to the end to \
# # ensure the end of the sentence is recognized correctly.
# auto_det_end = True
# if auto_det_end:
# total_chunk_num += 1
# # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
# end_idx = None
# for i in range(total_chunk_num):
# if auto_det_end:
# is_end = i == total_chunk_num - 1
# start_idx = i*self.chunk_partial_size
# if auto_det_end:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
# else:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# # t_stamp = time.time()
# speech_chunk = self.audio_cache[start_idx:end_idx]
# # TODO: exceptions processes
# try:
# res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# except ValueError as e:
# print(f"ValueError: {e}")
# continue
# text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# # print(f"each chunk time: {time.time()-t_stamp}")
# if is_end:
# self.audio_cache = None
# self.asr_cache = {}
# else:
# if end_idx:
# self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
# text_dict['is_end'] = is_end
# # print(f"text_dict: {text_dict}")
# return text_dict

View File

@ -1,209 +1,29 @@
from .funasr_utils import FunAutoSpeechRecognizer
from .punctuation_utils import CTTRANSFORMER, Punctuation
from .punctuation_utils import FUNASR, Punctuation
from .emotion_utils import FUNASRFINETUNE, Emotion
from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
import os
import numpy as np
class ModifiedRecognizer(FunAutoSpeechRecognizer):
def __init__(self,
use_punct=True,
use_emotion=False,
use_speaker_ver=True):
class ModifiedRecognizer():
def __init__(self):
#增加语音识别模型
self.asr_model = FunAutoSpeechRecognizer()
# 创建基础的 funasr模型用于语音识别识别出不带标点的句子
super().__init__(
model_path="paraformer-zh-streaming",
device="cuda",
RATE=16000,
cfg_path=None,
debug=False,
chunk_ms=480,
encoder_chunk_look_back=4,
decoder_chunk_look_back=1)
# 记录是否具备附加功能
self.use_punct = use_punct
self.use_emotion = use_emotion
self.use_speaker_ver = use_speaker_ver
# 增加标点模型
if use_punct:
self.puctuation_model = Punctuation(**CTTRANSFORMER)
#增加标点模型
self.puctuation_model = Punctuation(**FUNASR)
# 情绪识别模型
if use_emotion:
self.emotion_model = Emotion(**FUNASRFINETUNE)
self.emotion_model = Emotion(**FUNASRFINETUNE)
# 说话人识别模型
if use_speaker_ver:
self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
def session_signup(self, session_id):
self.asr_model.session_signup(session_id)
def initialize_speaker(self, speaker_1_wav):
"""
用于说话人识别将输入的音频(speaker_1_wav)设立为目标说话人并将其特征保存本地
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if speaker_1_wav.endswith(".npy"):
self.save_speaker_path = speaker_1_wav
elif speaker_1_wav.endswith('.wav'):
self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
# self.save_speaker_path = DEFALUT_SAVE_PATH
self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
else:
raise TypeError("only support [.npy] or [.wav].")
def session_signout(self, session_id):
self.asr_model.session_signout(session_id)
def streaming_recognize(self, session_id, audio_data,is_end=False):
return self.asr_model.streaming_recognize(session_id, audio_data,is_end=is_end)
def speaker_ver(self, speaker_2_wav):
"""
用于说话人识别判断输入音频是否为目标说话人
是返回True不是返回False
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if not hasattr(self, "save_speaker_path"):
raise NotImplementedError("please initialize speaker first")
def punctuation_correction(self, sentence):
return self.puctuation_model.process(sentence)
# self.speaker_ver_model.verfication 返回值为字符串 'yes' / 'no'
return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
speaker_2_wav=speaker_2_wav) == 'yes'
def recognize(self, audio_data):
"""
非流式语音识别返回识别出的文本返回值类型 str
"""
audio_data = self.check_audio_type(audio_data)
# 说话人识别
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# 语音识别
result = self.asr_model.generate(input=audio_data,
batch_size_s=300,
hotword=self.hotwords)
text = ''
for res in result:
text += res['text']
# 添加标点
if self.use_punct:
text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
return text
def recognize_emotion(self, audio_data):
"""
情感识别返回值为:
1. 如果说话人非目标说话人返回字符串 "Other People"
2. 如果说话人为目标说话人返回字典{"Labels": List[str], "scores": List[int]}
"""
audio_data = self.check_audio_type(audio_data)
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
if self.use_emotion:
return self.emotion_model.process(audio_data)
else:
raise NotImplementedError("no access")
def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
auto_det_end: bool, whether to automatically detect the end of a audio data
流式语音识别返回值为
1. 如果说话人非目标说话人返回字符串 "Other People"
2. 如果说话人为目标说话人返回字典{"test": List[str], "is_end": boolean}
"""
audio_cache = self.audio_cache[session_id]
asr_cache = self.asr_cache[session_id]
text_dict = dict(text=[], is_end=is_end)
audio_data = self.check_audio_type(audio_data)
# 说话人识别
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# 语音识别
if audio_cache is None:
audio_cache = audio_data
else:
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
if audio_cache.shape[0] > 0:
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
self.audio_cache[session_id] = audio_cache
return text_dict
total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = audio_cache[start_idx:end_idx]
# TODO: exceptions processes
try:
res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
except ValueError as e:
print(f"ValueError: {e}")
continue
# 增添标点
if self.use_punct:
text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
else:
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
audio_cache = None
asr_cache = {}
else:
if end_idx:
audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
if self.use_punct and is_end:
text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache
# print(f"text_dict: {text_dict}")
return text_dict
def emtion_recognition(self, audio):
return self.emotion_model.process(audio)

View File

@ -1,7 +1,7 @@
from modelscope.pipelines import pipeline
import numpy as np
import os
import pdb
ERES2NETV2 = {
"task": 'speaker-verification',
"model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
@ -10,7 +10,7 @@ ERES2NETV2 = {
}
# 保存 embedding 的路径
DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding")
DEFALUT_SAVE_PATH = r".\takway\savePath"
class speaker_verfication:
def __init__(self,
@ -26,11 +26,9 @@ class speaker_verfication:
device=device)
self.save_embeddings = save_embeddings
def wav2embeddings(self, speaker_1_wav, save_path=None):
def wav2embeddings(self, speaker_1_wav):
result = self.pipeline([speaker_1_wav], output_emb=True)
speaker_1_emb = result['embs'][0]
if save_path is not None:
np.save(save_path, speaker_1_emb)
return speaker_1_emb
def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
@ -55,19 +53,10 @@ class speaker_verfication:
return "no"
def verfication(self,
base_emb=None,
speaker_1_wav=None,
speaker_2_wav=None,
threshold=0.333,
save_path=None):
if base_emb is not None and speaker_1_wav is not None:
raise ValueError("Only need one of them, base_emb or speaker_1_wav")
if base_emb is not None and speaker_2_wav is not None:
return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
elif speaker_1_wav is not None and speaker_2_wav is not None:
return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
else:
raise NotImplementedError
base_emb,
speaker_emb,
threshold=0.333, ):
return np.dot(base_emb, speaker_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker_emb)) > threshold
if __name__ == '__main__':
verifier = speaker_verfication(**ERES2NETV2)

View File

@ -1,7 +0,0 @@
1. 安装 melo
参考 https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md#openvoice-v2
2. 修改 openvoice_utils.py 中的路径 (包括 SOURCE_SE_DIR / CACHE_PATH / OPENVOICE_TONE_COLOR_CONVERTER.converter_path )
其中:
SOURCE_SE_DIR 改为 utils/assets/ses
OPENVOICE_TONE_COLOR_CONVERTER.converter_path 修改为下载到的模型参数路径,下载链接为 https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip
3. 参考 examples/tts_demo.ipynb 进行代码迁移

View File

@ -140,6 +140,7 @@ def get_se(audio_path, vc_model, target_dir='processed', vad=True):
# if os.path.isdir(audio_path):
# wavs_folder = audio_path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if vad:
wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
else:

View File

@ -0,0 +1,57 @@
{
"_version_": "v2",
"data": {
"sampling_rate": 22050,
"filter_length": 1024,
"hop_length": 256,
"win_length": 1024,
"n_speakers": 0
},
"model": {
"zero_g": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
4,
4
],
"gin_channels": 256
}
}

View File

@ -1,7 +1,6 @@
import os
import re
from glob import glob
import hashlib
from tqdm.auto import tqdm
import soundfile as sf
import numpy as np
@ -16,11 +15,15 @@ from .openvoice.api import ToneColorConverter
from .openvoice.mel_processing import spectrogram_torch
# torchaudio
import torchaudio.functional as F
# 存储 BASE SPEAKER 的 embedding(source_se) 的路径
SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"
current_file_path = os.path.abspath(__file__)
utils_dir = os.path.dirname(os.path.dirname(current_file_path))
SOURCE_SE_DIR = os.path.join(utils_dir,'assets','ses')
# 存储缓存文件的路径
CACHE_PATH = r"D:\python\OpenVoice\processed"
CACHE_PATH = r"/tmp/openvoice_cache"
OPENVOICE_BASE_TTS={
"model_type": "open_voice_base_tts",
@ -28,10 +31,11 @@ OPENVOICE_BASE_TTS={
"language": "ZH",
}
converter_path = os.path.join(os.path.dirname(current_file_path),'openvoice_model')
OPENVOICE_TONE_COLOR_CONVERTER={
"model_type": "open_voice_converter",
# 模型参数路径
"converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
"converter_path": converter_path,
}
class TextToSpeech:
@ -118,6 +122,7 @@ class TextToSpeech:
elif isinstance(se, torch.Tensor):
self.target_se = se.float().to(self.device)
#语音转numpy
def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
"""
将字节流的audio转为numpy类型也可以传入numpy类型
@ -143,30 +148,14 @@ class TextToSpeech:
return: np.ndarray
"""
audio_data = self.audio2numpy(audio_data)
if not os.path.exists(CACHE_PATH):
os.makedirs(CACHE_PATH)
from scipy.io import wavfile
audio_path = os.path.join(CACHE_PATH, "tmp.wav")
wavfile.write(audio_path, rate=rate, data=audio_data)
se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
# device = self.tone_color_converter.device
# version = self.tone_color_converter.version
# if self.debug:
# print("OpenVoice version:", version)
# audio_name = f"tmp_{version}_{hashlib.sha256(audio_data.tobytes()).hexdigest()[:16].replace('/','_^')}"
# if vad:
# wavs_folder = se_extractor.split_audio_vad(audio_path, target_dir=CACHE_PATH, audio_name=audio_name)
# else:
# wavs_folder = se_extractor.split_audio_whisper(audio_data, target_dir=CACHE_PATH, audio_name=audio_name)
# audio_segs = glob(f'{wavs_folder}/*.wav')
# if len(audio_segs) == 0:
# raise NotImplementedError('No audio segments found!')
# # se, _ = se_extractor.get_se(audio_data, self.tone_color_converter, CACHE_PATH, vad=False)
# se = self.tone_color_converter.extract_se(audio_segs)
return se.cpu().detach().numpy()
def tensor2numpy(self, audio_data: torch.Tensor):
@ -175,11 +164,14 @@ class TextToSpeech:
"""
return audio_data.cpu().detach().float().numpy()
def numpy2bytes(self, audio_data: np.ndarray):
"""
numpy类型转bytes
"""
return (audio_data*32768.0).astype(np.int32).tobytes()
def numpy2bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
def _base_tts(self,
text: str,
@ -271,6 +263,11 @@ class TextToSpeech:
audio: tensor
sr: 生成音频的采样速率
"""
if source_se is not None:
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
if target_se is not None:
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)
if source_se is None:
source_se = self.source_se
if target_se is None:
@ -296,16 +293,13 @@ class TextToSpeech:
print("tone color has been converted!")
return audio, sr
def tts(self,
def synthesize(self,
text: str,
sdp_ratio=0.2,
noise_scale=0.6,
noise_scale_w=0.8,
speed=1.0,
quite=True,
tts_info,
source_se: Optional[np.ndarray]=None,
target_se: Optional[np.ndarray]=None,
sdp_ratio=0.2,
quite=True,
tau :float=0.3,
message :str="default"):
"""
@ -322,15 +316,14 @@ class TextToSpeech:
"""
audio, sr = self._base_tts(text,
sdp_ratio=sdp_ratio,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
speed=speed,
noise_scale=tts_info['noise_scale'],
noise_scale_w=tts_info['noise_scale_w'],
speed=tts_info['speed'],
quite=quite)
if self.use_tone_convert:
if self.use_tone_convert and target_se.size>0:
tts_sr = self.base_tts_model.hps.data.sampling_rate
converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = F.resample(audio, tts_sr, converter_sr)
print(audio.dtype)
audio, sr = self._convert_tone(audio,
source_se=source_se,
target_se=target_se,
@ -350,3 +343,4 @@ class TextToSpeech:
"""
sf.write(save_path, audio, sample_rate)
print(f"Audio saved to {save_path}")

View File

@ -2,6 +2,7 @@ import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
@ -79,19 +80,19 @@ class TextToSpeech:
print(f"Synthesis time: {time.time() - start_time} s")
return audio
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
if not len(text):
return "输入文本不能为空!", None
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, language)
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
text = self._preprocess_text(text, tts_info['language'])
audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
if self.debug or save_audio:
self.save_audio(audio, self.RATE, 'output_file.wav')
if return_bytes:
audio = self.convert_numpy_to_bytes(audio)
return self.RATE, audio
return audio, self.RATE
def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):

View File

@ -1,7 +1,9 @@
import websockets
import datetime
import hashlib
import base64
import hmac
import json
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
@ -35,3 +37,33 @@ def generate_xf_asr_url():
}
url = url + '?' + urlencode(v)
return url
def make_first_frame(buf):
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(first_frame)
def make_continue_frame(buf):
continue_frame = {"data":{"status":1,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(continue_frame)
def make_last_frame(buf):
last_frame = {"data":{"status":2,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(last_frame)
def parse_xfasr_recv(message):
code = message['code']
if code!=0:
raise Exception("讯飞ASR错误码"+str(code))
else:
data = message['data']['result']['ws']
result = ""
for i in data:
for w in i['cw']:
result += w['w']
return result
async def xf_asr_websocket_factory():
url = generate_xf_asr_url()
return await websockets.connect(url)