Compare commits


5 Commits

Author SHA1 Message Date
bing d0b4bd4b3c add tts_demo.ipynb 2024-05-21 20:19:57 +08:00
bing 2b870c2e7d Add comments 2024-05-21 15:22:49 +08:00
bing 05ccd1c8c0 readme: tts readme 2024-05-21 15:07:37 +08:00
bing 42767b065f feature: openvoice utils 2024-05-21 14:56:59 +08:00
bing a776258f8b feat: add punctuation, emotion recognition, and speaker recognition utils, with examples 2024-05-13 12:55:44 +08:00
57 changed files with 5500 additions and 5455 deletions

.gitignore

@ -6,12 +6,4 @@ __pycache__/
/app.log /app.log
app.log app.log
/utils/tts/vits_model/ /utils/tts/vits_model/
vits_model vits_model
nohup.out
/app/takway-ai.top.key
/app/takway-ai.top.pem
tests/assets/BarbieDollsVoice.mp3
utils/tts/openvoice_model/checkpoint.pth

README.md

@ -15,15 +15,15 @@ TakwayAI/
│ │ ├── models.py # database models │ │ ├── models.py # database models
│ ├── schemas/ # request/response models │ ├── schemas/ # request/response models
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ ├── user_schemas.py # user-related schemas │ │ ├── user.py # user-related schemas
│ │ └── ... # other schemas │ │ └── ... # other schemas
│ ├── controllers/ # business-logic controllers │ ├── controllers/ # business-logic controllers
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ ├── user_controllers.py # user-related controllers │ │ ├── user.py # user-related controllers
│ │ └── ... # other controllers │ │ └── ... # other controllers
│ ├── routes/ # routes and view functions │ ├── routes/ # routes and view functions
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ ├── user_routes.py # user-related routes │ │ ├── user.py # user-related routes
│ │ └── ... # other routes │ │ └── ... # other routes
│ ├── dependencies/ # dependency-injection helpers │ ├── dependencies/ # dependency-injection helpers
│ │ ├── __init__.py │ │ ├── __init__.py
@ -64,129 +64,21 @@ TakwayAI/
git clone http://43.132.157.186:3000/killua/TakwayPlatform.git git clone http://43.132.157.186:3000/killua/TakwayPlatform.git
``` ```
#### (2) Create a virtual environment #### (2) Install dependencies
Create a virtual environment:
``` shell ``` shell
conda create -n takway python=3.9 cd TakwayAI/
conda activate takway
```
#### (3) Install dependencies
If your local environment has unrestricted internet access, just run the following two commands:
``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
python -m unidic download
```
If you cannot access the internet freely,
first run:
``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
```
1. Install unidic
Manually download [unidic.zip](https://cotonoha-dic.s3-ap-northeast-1.amazonaws.com/unidic-3.1.0.zip) and rename it to unidic.zip.
The example below uses miniconda; a regular conda installation should work the same way.
Copy unidic.zip into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic.
cd into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic and open download.py (e.g. `vim download.py`).
In download_version(), comment out everything except the last line, and change the two arguments of that final download_and_clean() call to arbitrary values, e.g. "hello", "world".
Then, in the definition of download_and_clean(), comment out the download_process() line.
Run `python -m unidic download` (a rough sketch of the patched file is shown below).
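For orientation only, here is a rough, hypothetical sketch of what the two patched functions in download.py might look like after these edits; the real file shipped with the unidic package contains more code and its argument names differ:
```python
# Hypothetical sketch of site-packages/unidic/download.py after the edits above.
# The real file contains additional code; only the idea of the patch is shown.

def download_and_clean(version, url):
    # ... original unpacking/bookkeeping code stays in place ...
    # download_process(...)   # <- commented out: nothing is fetched from the network;
    #                              the manually copied unidic.zip is used instead
    pass

def download_version():
    # Every original line is commented out except the final call, whose two
    # arguments are replaced with arbitrary placeholders:
    download_and_clean("hello", "world")
```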
2. Configure huggingface
Run:
```shell
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
```
It is best to add `export HF_ENDPOINT=https://hf-mirror.com` to ~/.bashrc; otherwise the setting is lost whenever a new terminal session is started.
3. Download nltk_data
Create an nltk_data folder under ~/miniconda3/envs/takway/.
Inside nltk_data, create corpora and taggers folders.
Manually download [cmudict.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip) and [averaged_perceptron_tagger.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip).
Put cmudict.zip into the corpora folder.
Put averaged_perceptron_tagger.zip into the taggers folder (a quick way to verify the placement is shown below).
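A quick way to confirm the files ended up where NLTK will look for them, assuming nltk is already installed in the takway environment (nltk.data.find raises LookupError when a resource is missing):
```python
import os
import nltk

# The env-level nltk_data folder created above; add it to the search path
# in case it is not picked up automatically.
nltk.data.path.append(os.path.expanduser("~/miniconda3/envs/takway/nltk_data"))

print(nltk.data.find("corpora/cmudict"))                     # raises LookupError if missing
print(nltk.data.find("taggers/averaged_perceptron_tagger"))  # raises LookupError if missing
```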
4. Install the remaining dependencies
``` shell
cd TakwayPlatform/
pip install -r requirements.txt pip install -r requirements.txt
``` ```
5. Debugging #### (3) Modify the configuration
If the exception AttributeError: module 'botocore.exceptions' has no attribute 'HTTPClientError' appears,
run the following commands:
``` shell
pip uninstall botocore
pip install botocore==1.34.88
```
#### (4) Install FunASR
The FunASR used by this project contains some modifications on top of the FunASR project on GitHub:
``` shell
git clone http://43.132.157.186:3000/gaohz/FunASR.git
cd FunASR/
pip install -v -e .
```
#### (5) Modify the configuration
1. Install MySQL and create a database named takway in it 1. Install MySQL and create a database named takway in it
2. Install Redis and set its password to takway 2. Install Redis and set its password to takway
3. Open development.py under config and update the MySQL and Redis connection strings 3. Open development.py under config and update the MySQL and Redis connection strings (the relevant fields are excerpted below)
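For reference, the two connection strings to edit live in config/development.py; the excerpt below follows the DevelopmentConfig shown later in this comparison (the exact credentials in your deployment will differ):
```python
# config/development.py (excerpt) -- adjust user, password and host to your setup
class DevelopmentConfig:
    # MySQL: the takway database created in step 1
    SQLALCHEMY_DATABASE_URI = "mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4"
    # Redis: password "takway" as set in step 2
    REDIS_URL = "redis://:takway@127.0.0.1:6379/0"
```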
#### (6) Import the vits model #### (4) Import the vits model
Create a vits_model folder under the utils/tts/ directory. Create a vits_model folder under the utils/tts/ directory.
Download the vits_model from [this link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) and put it into that folder; only config.json and G_953000.pth are needed. Download the vits_model from [this link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) and put it into that folder; only config.json and G_953000.pth are needed (a usage sketch follows below).
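Once config.json and G_953000.pth are in place, the model is presumably consumed through utils/tts/vits_utils.TextToSpeech; a minimal sketch based on the calls that appear in app/controllers/chat_controller.py elsewhere in this comparison (the speaker id, language and scale values here are only illustrative):
```python
# Minimal sketch; constructor and synthesize() call mirror their use in
# app/controllers/chat_controller.py in this diff.
from utils.tts.vits_utils import TextToSpeech

tts = TextToSpeech(device='cpu')  # presumably loads the files under utils/tts/vits_model/

# speaker_id=0, language="ZH" and the noise/length scales are placeholder values;
# the service takes them from the session's tts_info.
sr, audio_bytes = tts.synthesize("你好,世界", 0, "ZH", 0.1, 0.668, 1.2, return_bytes=True)
```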
#### (7) Import the openvoice model
Create an openvoice_model folder under the utils/tts/ directory.
Download the model file from [this link](https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip) and put it into that folder.
#### (8) Set the configuration environment variable
``` shell
vim ~/.bashrc
# add the following line at the end of the file
export MODE=development
```
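The MODE variable is presumably what the project's get_config() (imported as `from config import get_config` throughout this diff) uses to select DevelopmentConfig. Since config/__init__.py itself is not part of this comparison, the loader below is only a hypothetical sketch of that mechanism:
```python
# Hypothetical sketch of a MODE-driven config loader; the project's real
# config/__init__.py is not shown in this comparison and may differ.
import os

from config.development import DevelopmentConfig  # assumed module path


def get_config():
    mode = os.environ.get("MODE", "development")
    if mode == "development":
        return DevelopmentConfig
    raise ValueError(f"unsupported MODE: {mode}")
```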
#### (9) Start the program
``` shell
cd TakwayPlatform
python main.py
```


@ -34,7 +34,7 @@ logger.info("数据库初始化完成")
#--------------------设置定时任务----------------------- #--------------------设置定时任务-----------------------
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
# scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *")) scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
@asynccontextmanager @asynccontextmanager
async def lifespan(app:FastAPI): async def lifespan(app:FastAPI):
scheduler.start() #启动定时任务 scheduler.start() #启动定时任务


@ -1,38 +1,35 @@
from ..schemas.chat_schema import * from ..schemas.chat_schema import *
from ..dependencies.logger import get_logger from ..dependencies.logger import get_logger
from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts
from .controller_enum import * from .controller_enum import *
from ..models import UserCharacter, Session, Character, User, Audio from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status from fastapi import WebSocket, HTTPException, status
from datetime import datetime from datetime import datetime
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config from config import get_config
import numpy as np
import websockets
import struct
import uuid import uuid
import json import json
import asyncio import asyncio
import aiohttp import requests
import io
# 依赖注入获取logger # 依赖注入获取logger
logger = get_logger() logger = get_logger()
# 依赖注入获取context总结服务 # --------------------初始化本地ASR-----------------------
summarizer = get_summarizer() from utils.stt.funasr_utils import FunAutoSpeechRecognizer
# -----------------------获取ASR------------------------- asr = FunAutoSpeechRecognizer()
asr = get_asr() logger.info("本地ASR初始化成功")
# ------------------------------------------------------- # -------------------------------------------------------
# -------------------------TTS-------------------------- # --------------------初始化本地VITS----------------------
tts = get_tts() from utils.tts.vits_utils import TextToSpeech
tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# ------------------------------------------------------- # -------------------------------------------------------
# 依赖注入获取Config # 依赖注入获取Config
Config = get_config() Config = get_config()
@ -53,20 +50,16 @@ def get_session_content(session_id,redis,db):
def parseChunkDelta(chunk): def parseChunkDelta(chunk):
try: try:
if chunk == b"": if chunk == b"":
return 1,"" return ""
decoded_data = chunk.decode('utf-8') decoded_data = chunk.decode('utf-8')
parsed_data = json.loads(decoded_data[6:]) parsed_data = json.loads(decoded_data[6:])
if 'delta' in parsed_data['choices'][0]: if 'delta' in parsed_data['choices'][0]:
delta_content = parsed_data['choices'][0]['delta'] delta_content = parsed_data['choices'][0]['delta']
return -1, delta_content['content'] return delta_content['content']
else: else:
return parsed_data['usage']['total_tokens'] , "" return "end"
except KeyError: except KeyError:
logger.error(f"error chunk: {decoded_data}") logger.error(f"error chunk: {chunk}")
return 1,""
except json.JSONDecodeError:
logger.error(f"error chunk: {decoded_data}")
return 1,""
#断句函数 #断句函数
def split_string_with_punctuation(current_sentence,text,is_first,is_end): def split_string_with_punctuation(current_sentence,text,is_first,is_end):
@ -104,19 +97,6 @@ def update_session_activity(session_id,db):
except Exception as e: except Exception as e:
db.roolback() db.roolback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
#获取target_se
def get_emb(session_id,db):
try:
session_record = db.query(Session).filter(Session.id == session_id).first()
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
emb_npy = np.load(io.BytesIO(audio_record.emb_data))
return emb_npy
except Exception as e:
logger.debug("未找到音频:"+str(e))
return np.array([])
#-------------------------------------------------------- #--------------------------------------------------------
# 创建新聊天 # 创建新聊天
@ -126,6 +106,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
try: try:
db.add(new_chat) db.add(new_chat)
db.commit() db.commit()
db.refresh(new_chat)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -151,23 +132,18 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
"speaker_id":db_character.voice_id, "speaker_id":db_character.voice_id,
"noise_scale": 0.1, "noise_scale": 0.1,
"noise_scale_w":0.668, "noise_scale_w":0.668,
"length_scale": 1.2, "length_scale": 1.2
"speed":1
} }
llm_info = { llm_info = {
"model": "abab5.5-chat", "model": "abab5.5-chat",
"temperature": 1, "temperature": 1,
"top_p": 0.9, "top_p": 0.9,
} }
user_info = {
"character":"",
"events":[]
}
# 将tts和llm信息转化为json字符串 # 将tts和llm信息转化为json字符串
tts_info_str = json.dumps(tts_info, ensure_ascii=False) tts_info_str = json.dumps(tts_info, ensure_ascii=False)
llm_info_str = json.dumps(llm_info, ensure_ascii=False) llm_info_str = json.dumps(llm_info, ensure_ascii=False)
user_info_str = json.dumps(user_info, ensure_ascii=False) user_info_str = db_user.persona
token = 0 token = 0
content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str, content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
@ -178,6 +154,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
# 将Session记录存入 # 将Session记录存入
db.add(new_session) db.add(new_session)
db.commit() db.commit()
db.refresh(new_session)
redis.set(session_id, json.dumps(content, ensure_ascii=False)) redis.set(session_id, json.dumps(content, ensure_ascii=False))
chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat()) chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
@ -246,76 +223,27 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
logger.error(f"用户输入处理函数发生错误: {str(e)}") logger.error(f"用户输入处理函数发生错误: {str(e)}")
#语音识别 #语音识别
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event): async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
logger.debug("语音识别函数启动") logger.debug("语音识别函数启动")
if Config.STRAM_CHAT.ASR == "LOCAL": is_signup = False
is_signup = False try:
audio = ""
try:
current_message = ""
while not (user_input_finish_event.is_set() and user_input_q.empty()):
if not is_signup:
asr.session_signup(session_id)
is_signup = True
audio_data = await user_input_q.get()
audio += audio_data
asr_result = asr.streaming_recognize(session_id,audio_data)
current_message += ''.join(asr_result['text'])
asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
current_message += ''.join(asr_result['text'])
slice_arr = ["",""]
if current_message in slice_arr:
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
return
current_message = asr.punctuation_correction(current_message)
# emotion_dict = asr.emtion_recognition(audio) #情感辨识
# if not isinstance(emotion_dict, str):
# max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
# current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
await llm_input_q.put(current_message)
asr.session_signout(session_id)
except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}")
logger.debug(f"接收到用户消息: {current_message}")
elif Config.STRAM_CHAT.ASR == "XF":
status = FIRST_FRAME
xf_websocket = await xf_asr_websocket_factory() #获取一个讯飞语音识别接口websocket连接
segment_duration_threshold = 25 #设置一个连接时长上限讯飞语音接口超过30秒会自动断开连接所以该值设置成25秒
segment_start_time = asyncio.get_event_loop().time()
current_message = "" current_message = ""
while not (user_input_finish_event.is_set() and user_input_q.empty()): while not (user_input_finish_event.is_set() and user_input_q.empty()):
try: if not is_signup:
audio_data = await user_input_q.get() asr.session_signup(session_id)
current_time = asyncio.get_event_loop().time() is_signup = True
if current_time - segment_start_time > segment_duration_threshold: audio_data = await user_input_q.get()
await xf_websocket.send(make_last_frame()) asr_result = asr.streaming_recognize(session_id,audio_data)
current_message += parse_xfasr_recv(await xf_websocket.recv()) current_message += ''.join(asr_result['text'])
await xf_websocket.close() asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接 current_message += ''.join(asr_result['text'])
status = FIRST_FRAME
segment_start_time = current_time
if status == FIRST_FRAME:
await xf_websocket.send(make_first_frame(audio_data))
status = CONTINUE_FRAME
elif status == CONTINUE_FRAME:
await xf_websocket.send(make_continue_frame(audio_data))
except websockets.exceptions.ConnectionClosedOK:
logger.debug("讯飞语音识别接口连接断开,重新创建连接")
xf_websocket = await xf_asr_websocket_factory() #重建一个websocket连接
status = FIRST_FRAME
segment_start_time = asyncio.get_event_loop().time()
await xf_websocket.send(make_last_frame())
current_message += parse_xfasr_recv(await xf_websocket.recv())
await xf_websocket.close()
if current_message in ["", ""]:
await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
return
await llm_input_q.put(current_message) await llm_input_q.put(current_message)
logger.debug(f"接收到用户消息: {current_message}") asr.session_signout(session_id)
except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}")
logger.debug(f"接收到用户消息: {current_message}")
#大模型调用 #大模型调用
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event): async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
logger.debug("llm调用函数启动") logger.debug("llm调用函数启动")
@ -325,18 +253,13 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
is_first = True is_first = True
is_end = False is_end = False
session_content = get_session_content(session_id,redis,db) session_content = get_session_content(session_id,redis,db)
user_info = json.loads(session_content["user_info"])
messages = json.loads(session_content["messages"]) messages = json.loads(session_content["messages"])
current_message = await llm_input_q.get() current_message = await llm_input_q.get()
if current_message == "":
return
messages.append({'role': 'user', "content": current_message}) messages.append({'role': 'user', "content": current_message})
messages_send = messages #创造一个message副本在其中最后一条数据前面添加用户信息
# messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
payload = json.dumps({ payload = json.dumps({
"model": llm_info["model"], "model": llm_info["model"],
"stream": True, "stream": True,
"messages": messages_send, "messages": messages,
"max_tokens": 10000, "max_tokens": 10000,
"temperature": llm_info["temperature"], "temperature": llm_info["temperature"],
"top_p": llm_info["top_p"] "top_p": llm_info["top_p"]
@ -345,48 +268,35 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db) response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True) #调用大模型
except Exception as e: except Exception as e:
logger.error(f"编辑http请求时发生错误: {str(e)}") logger.error(f"llm调用发生错误: {str(e)}")
try: try:
async with aiohttp.ClientSession() as client: for chunk in response.iter_lines():
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 chunk_data = parseChunkDelta(chunk)
async for chunk in response.content.iter_any(): is_end = chunk_data == "end"
token_count, chunk_data = parseChunkDelta(chunk) if not is_end:
is_end = token_count >0 llm_response += chunk_data
if not is_end: sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句
llm_response += chunk_data for sentence in sentences:
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) #断句 if response_type == RESPONSE_TEXT:
for sentence in sentences: response_message = {"type": "text", "code":200, "msg": sentence}
if response_type == RESPONSE_TEXT: await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
response_message = {"type": "text", "code":200, "msg": sentence} elif response_type == RESPONSE_AUDIO:
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息 sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
elif response_type == RESPONSE_AUDIO: response_message = {"type": "text", "code":200, "msg": sentence}
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) await ws.send_bytes(audio) #返回音频数据
response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False)) #返回文本信息
response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8') logger.debug(f"websocket返回: {sentence}")
header = struct.pack('!II',len(response_bytes),len(audio)) if is_end:
message_bytes = header + response_bytes + audio logger.debug(f"llm返回结果: {llm_response}")
await ws.send_bytes(message_bytes) await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug(f"websocket返回: {sentence}") is_end = False #重置is_end标志位
if is_end: messages.append({'role': 'assistant', "content": llm_response})
logger.debug(f"llm返回结果: {llm_response}") session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False)) redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_end = False #重置is_end标志位 is_first = True
messages.append({'role': 'assistant', "content": llm_response}) llm_response = ""
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
is_first = True
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except Exception as e: except Exception as e:
logger.error(f"处理llm返回结果发生错误: {str(e)}") logger.error(f"处理llm返回结果发生错误: {str(e)}")
chat_finished_event.set() chat_finished_event.set()
@ -405,7 +315,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
session_id = await future_session_id #获取session_id session_id = await future_session_id #获取session_id
update_session_activity(session_id,db) update_session_activity(session_id,db)
response_type = await future_response_type #获取返回类型 response_type = await future_response_type #获取返回类型
asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event)) asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@ -462,7 +372,6 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
logger.debug("语音识别函数启动") logger.debug("语音识别函数启动")
is_signup = False is_signup = False
current_message = "" current_message = ""
audio = ""
while not (input_finished_event.is_set() and user_input_q.empty()): while not (input_finished_event.is_set() and user_input_q.empty()):
try: try:
aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3) aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
@ -472,23 +381,15 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
if aduio_frame['is_end']: if aduio_frame['is_end']:
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True) asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
current_message += ''.join(asr_result['text']) current_message += ''.join(asr_result['text'])
current_message = asr.punctuation_correction(current_message)
audio += aduio_frame['audio']
emotion_dict =asr.emtion_recognition(audio) #情感辨识
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
await llm_input_q.put(current_message) await llm_input_q.put(current_message)
logger.debug(f"接收到用户消息: {current_message}") logger.debug(f"接收到用户消息: {current_message}")
current_message = ""
audio = ""
else: else:
asr_result = asr.streaming_recognize(session_id,aduio_frame['audio']) asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
audio += aduio_frame['audio']
current_message += ''.join(asr_result['text']) current_message += ''.join(asr_result['text'])
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except Exception as e: except Exception as e:
asr.session_signout(session_id)
logger.error(f"语音识别函数发生错误: {str(e)}") logger.error(f"语音识别函数发生错误: {str(e)}")
break break
asr.session_signout(session_id) asr.session_signout(session_id)
@ -505,7 +406,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
try: try:
session_content = get_session_content(session_id,redis,db) session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"]) messages = json.loads(session_content["messages"])
user_info = json.loads(session_content["user_info"])
current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3) current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
messages.append({'role': 'user', "content": current_message}) messages.append({'role': 'user', "content": current_message})
payload = json.dumps({ payload = json.dumps({
@ -520,49 +420,39 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db) response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
async with aiohttp.ClientSession() as client: for chunk in response.iter_lines():
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 chunk_data = parseChunkDelta(chunk)
async for chunk in response.content.iter_any(): is_end = chunk_data == "end"
token_count, chunk_data = parseChunkDelta(chunk) if not is_end:
is_end = token_count >0 llm_response += chunk_data
if not is_end: sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
llm_response += chunk_data for sentence in sentences:
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) if response_type == RESPONSE_TEXT:
for sentence in sentences: logger.debug(f"websocket返回: {sentence}")
if response_type == RESPONSE_TEXT: response_message = {"type": "text", "code":200, "msg": sentence}
logger.debug(f"websocket返回: {sentence}") await ws.send_text(json.dumps(response_message, ensure_ascii=False))
response_message = {"type": "text", "code":200, "msg": sentence} elif response_type == RESPONSE_AUDIO:
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
elif response_type == RESPONSE_AUDIO: response_message = {"type": "text", "code":200, "msg": sentence}
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) await ws.send_bytes(audio)
response_message = {"type": "text", "code":200, "msg": sentence} await ws.send_text(json.dumps(response_message, ensure_ascii=False))
await ws.send_bytes(audio) logger.debug(f"websocket返回: {sentence}")
await ws.send_text(json.dumps(response_message, ensure_ascii=False)) if is_end:
logger.debug(f"websocket返回: {sentence}") logger.debug(f"llm返回结果: {llm_response}")
if is_end: await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug(f"llm返回结果: {llm_response}") is_end = False
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False messages.append({'role': 'assistant', "content": llm_response})
messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session is_first = True
is_first = True llm_response = ""
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于70%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
# except Exception as e: except Exception as e:
# logger.error(f"处理llm返回结果发生错误: {str(e)}") logger.error(f"处理llm返回结果发生错误: {str(e)}")
# break break
chat_finished_event.set() chat_finished_event.set()
async def streaming_chat_lasting_handler(ws,db,redis): async def streaming_chat_lasting_handler(ws,db,redis):
@ -633,7 +523,6 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
current_message = "" current_message = ""
vad_count = 0 vad_count = 0
is_signup = False is_signup = False
audio = ""
while not (input_finished_event.is_set() and audio_q.empty()): while not (input_finished_event.is_set() and audio_q.empty()):
try: try:
if not is_signup: if not is_signup:
@ -644,22 +533,14 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
if vad_count > 0: if vad_count > 0:
vad_count -= 1 vad_count -= 1
asr_result = asr.streaming_recognize(session_id, audio_data) asr_result = asr.streaming_recognize(session_id, audio_data)
audio += audio_data
current_message += ''.join(asr_result['text']) current_message += ''.join(asr_result['text'])
else: else:
vad_count += 1 vad_count += 1
if vad_count >= 25: #连续25帧没有语音则认为说完了 if vad_count >= 25: #连续25帧没有语音则认为说完了
asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True) asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
if current_message: if current_message:
current_message = asr.punctuation_correction(current_message)
audio += audio_data
emotion_dict =asr.emtion_recognition(audio) #情感辨识
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
logger.debug(f"检测到静默,用户输入为:{current_message}") logger.debug(f"检测到静默,用户输入为:{current_message}")
await asr_result_q.put(current_message) await asr_result_q.put(current_message)
audio = ""
text_response = {"type": "user_text", "code": 200, "msg": current_message} text_response = {"type": "user_text", "code": 200, "msg": current_message}
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
current_message = "" current_message = ""
@ -684,7 +565,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
try: try:
session_content = get_session_content(session_id,redis,db) session_content = get_session_content(session_id,redis,db)
messages = json.loads(session_content["messages"]) messages = json.loads(session_content["messages"])
user_info = json.loads(session_content["user_info"])
current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3) current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
messages.append({'role': 'user', "content": current_message}) messages.append({'role': 'user', "content": current_message})
payload = json.dumps({ payload = json.dumps({
@ -695,43 +575,34 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
"temperature": llm_info["temperature"], "temperature": llm_info["temperature"],
"top_p": llm_info["top_p"] "top_p": llm_info["top_p"]
}) })
headers = { headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
target_se = get_emb(session_id,db) response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
async with aiohttp.ClientSession() as client: for chunk in response.iter_lines():
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response: #发送大模型请求 chunk_data = parseChunkDelta(chunk)
async for chunk in response.content.iter_any(): is_end = chunk_data == "end"
token_count, chunk_data = parseChunkDelta(chunk) if not is_end:
is_end = token_count >0 llm_response += chunk_data
if not is_end: sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
llm_response += chunk_data for sentence in sentences:
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end) sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
for sentence in sentences: text_response = {"type": "llm_text", "code": 200, "msg": sentence}
audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se) await ws.send_bytes(audio) #返回音频二进制流数据
text_response = {"type": "llm_text", "code": 200, "msg": sentence} await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
await ws.send_bytes(audio) #返回音频二进制流数据 logger.debug(f"llm返回结果: {sentence}")
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据 if is_end:
logger.debug(f"websocket返回: {sentence}") logger.debug(f"llm返回结果: {llm_response}")
if is_end: await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
logger.debug(f"llm返回结果: {llm_response}") is_end = False
await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
is_end = False messages.append({'role': 'assistant', "content": llm_response})
messages.append({'role': 'assistant', "content": llm_response}) session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session is_first = True
is_first = True llm_response = ""
llm_response = ""
if token_count > summarizer.max_token * 0.7: #如果llm返回的token数大于60%的最大token数则进行文本摘要
system_prompt = messages[0]['content']
summary = await summarizer.summarize(messages)
events = user_info['events']
events.append(summary['event'])
session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
logger.debug(f"总结后session_content: {session_content}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except Exception as e: except Exception as e:
@ -739,6 +610,23 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
break break
voice_call_end_event.set() voice_call_end_event.set()
#语音合成及返回函数
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
logger.debug("语音合成及返回函数启动")
while not (split_finished_event.is_set() and split_result_q.empty()):
try:
sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
text_response = {"type": "llm_text", "code": 200, "msg": sentence}
await ws.send_bytes(audio) #返回音频二进制流数据
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
logger.debug(f"websocket返回:{sentence}")
except asyncio.TimeoutError:
continue
voice_call_end_event.set()
async def voice_call_handler(ws, db, redis): async def voice_call_handler(ws, db, redis):
logger.debug("voice_call websocket 连接建立") logger.debug("voice_call websocket 连接建立")
audio_q = asyncio.Queue() #音频队列 audio_q = asyncio.Queue() #音频队列


@ -77,32 +77,3 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat()) session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data) return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
#更新Session中的Speaker Id信息
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
existing_session = ""
if redis.exists(session_id):
existing_session = redis.get(session_id)
else:
existing_session = db.query(Session).filter(Session.id == session_id).first().content
if not existing_session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
#更新Session字段
session = json.loads(existing_session)
session_llm_info = json.loads(session["llm_info"])
session_llm_info["speaker_id"] = session_data.speaker_id
session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)
#存储Session
session_str = json.dumps(session,ensure_ascii=False)
redis.set(session_id, session_str)
try:
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)


@ -1,20 +1,14 @@
from ..schemas.user_schema import * from ..schemas.user_schema import *
from ..dependencies.logger import get_logger from ..dependencies.logger import get_logger
from ..dependencies.tts import get_tts from ..models import User, Hardware
from ..models import User, Hardware, Audio
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException, status from fastapi import HTTPException, status
from pydub import AudioSegment
import numpy as np
import io
#依赖注入获取logger #依赖注入获取logger
logger = get_logger() logger = get_logger()
#依赖注入获取tts
tts = get_tts("OPENVOICE")
#创建用户 #创建用户
async def create_user_handler(user:UserCrateRequest, db: Session): async def create_user_handler(user:UserCrateRequest, db: Session):
@ -42,6 +36,7 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
existing_user.persona = user.persona existing_user.persona = user.persona
try: try:
db.commit() db.commit()
db.refresh(existing_user)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -122,6 +117,7 @@ async def change_bind_hardware_handler(hardware_id, user, db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
existing_hardware.user_id = user.user_id existing_hardware.user_id = user.user_id
db.commit() db.commit()
db.refresh(existing_hardware)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -139,6 +135,7 @@ async def update_hardware_handler(hardware_id, hardware, db):
existing_hardware.firmware = hardware.firmware existing_hardware.firmware = hardware.firmware
existing_hardware.model = hardware.model existing_hardware.model = hardware.model
db.commit() db.commit()
db.refresh(existing_hardware)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@ -156,97 +153,4 @@ async def get_hardware_handler(hardware_id, db):
if existing_hardware is None: if existing_hardware is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model) hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data) return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
#用户上传音频
async def upload_audio_handler(user_id, audio, db):
try:
audio_data = await audio.read()
emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
out = io.BytesIO()
np.save(out, emb_data)
out.seek(0)
emb_binary = out.read()
new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary) #创建音频
db.add(new_audio)
db.commit()
db.refresh(new_audio)
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)
#用户更新音频
async def update_audio_handler(audio_id, audio_file, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = await audio_file.read()
raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
existing_audio.audio_data = audio_data
existing_audio.emb_data = emb_data
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
return AudioUpdateResponse(status="success", message="用户更新音频成功", data=audio_update_data)
#用户查询音频
async def download_audio_handler(audio_id, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
audio_data = existing_audio.audio_data
return audio_data
#用户删除音频
async def delete_audio_handler(audio_id, db):
try:
existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_audio is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
try:
if existing_user.selected_audio_id == audio_id:
existing_user.selected_audio_id = None
db.delete(existing_audio)
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)
#用户绑定音频
async def bind_audio_handler(bind_req, db):
try:
existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
if existing_user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
try:
existing_user.selected_audio_id = bind_req.audio_id
db.commit()
except Exception as e:
db.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)


@ -1,11 +0,0 @@
from utils.stt.modified_funasr import ModifiedRecognizer
from app.dependencies.logger import get_logger
logger = get_logger()
#初始化全局asr对象
asr = ModifiedRecognizer()
logger.info("ASR初始化成功")
def get_asr():
return asr


@ -1,61 +0,0 @@
import aiohttp
import json
from datetime import datetime
from config import get_config
# 依赖注入获取Config
Config = get_config()
class Summarizer:
def __init__(self):
self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json里面有两个字段一个是event一个是character将你总结出的事件写入event将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
self.model = "abab5.5-chat"
self.max_token = 10000
self.temperature = 0.9
self.top_p = 1
async def summarize(self,messages):
context = ""
for message in messages:
if message['role'] == 'user':
context += "用户:"+ message['content'] + '\n'
elif message['role'] == 'assistant':
context += '玩具:'+ message['content'] + '\n'
payload = json.dumps({
"model":self.model,
"messages":[
{
"role":"system",
"content":self.system_prompt
},
{
"role":"user",
"content":context
}
],
"max_tokens":self.max_token,
"top_p":self.top_p
})
headers = {
'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as session:
async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
try:
summary = {
'event':datetime.now().strftime("%Y-%m-%d")+""+content['event'],
'character':content['character']
}
except TypeError:
summary = {
'event':datetime.now().strftime("%Y-%m-%d")+""+content['event'][0],
'character':""
}
return summary
def get_summarizer():
summarizer = Summarizer()
return summarizer


@ -1,23 +0,0 @@
from app.dependencies.logger import get_logger
from config import get_config
logger = get_logger()
Config = get_config()
from utils.tts.openvoice_utils import TextToSpeech
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS_OPENVOICE 初始化成功")
from utils.tts.vits_utils import TextToSpeech
vits_tts = TextToSpeech()
logger.info("TTS_VITS 初始化成功")
#初始化全局tts对象
def get_tts(tts_type=Config.TTS_UTILS):
if tts_type == "OPENVOICE":
return openvoice_tts
elif tts_type == "VITS":
return vits_tts

app/main.py

@ -0,0 +1,26 @@
from app import app, Config
import uvicorn
if __name__ == "__main__":
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
#uvicorn.run("app.main:app", host=Config.UVICORN.HOST, port=Config.UVICORN.PORT, workers=Config.UVICORN.WORKERS)
# _ooOoo_ #
# o8888888o #
# 88" . "88 #
# (| -_- |) #
# O\ = /O #
# ____/`---'\____ #
# . ' \\| |// `. #
# / \\||| : |||// \ #
# / _||||| -:- |||||- \ #
# | | \\\ - /// | | #
# \ .-\__ `-` ___/-. / #
# ___`. .' /--.--\ `. . __ #
# ."" '< `.___\_<|>_/___.' >'"". #
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
# \ \ `-. \_ __\ /__ _/ .-` / / #
# ======`-.____`-.___\_____/___.-`____.-'====== #
# `=---=' #
# ............................................. #
# 佛祖保佑 永无BUG #


@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() Base = declarative_base()
@ -36,7 +36,6 @@ class User(Base):
avatar_id = Column(String(36), nullable=True) avatar_id = Column(String(36), nullable=True)
tags = Column(JSON) tags = Column(JSON)
persona = Column(JSON) persona = Column(JSON)
selected_audio_id = Column(Integer, nullable=True)
def __repr__(self): def __repr__(self):
return f"<User(id={self.id}, tags={self.tags})>" return f"<User(id={self.id}, tags={self.tags})>"
@ -81,12 +80,3 @@ class Session(Base):
def __repr__(self): def __repr__(self):
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>" return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
#音频表定义
class Audio(Base):
__tablename__ = 'audio'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('user.id'))
audio_data = Column(LargeBinary(16777215))
emb_data = Column(LargeBinary(65535))


@ -26,11 +26,4 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse) @router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)): async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_handler(session_id, session_data, db, redis) response = await update_session_handler(session_id, session_data, db, redis)
return response
#session声音信息更新接口
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
return response return response


@ -1,4 +1,4 @@
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, HTTPException, status
from ..controllers.user_controller import * from ..controllers.user_controller import *
from fastapi import Depends from fastapi import Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -68,39 +68,4 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
@router.get('/users/hardware/{hardware_id}',response_model=HardwareQueryResponse) @router.get('/users/hardware/{hardware_id}',response_model=HardwareQueryResponse)
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)): async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
response = await get_hardware_handler(hardware_id, db) response = await get_hardware_handler(hardware_id, db)
return response
#用户音频上传
@router.post('/users/audio',response_model=AudioUploadResponse)
async def upload_audio(user_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
response = await upload_audio_handler(user_id, audio_file, db)
return response
#用户音频修改
@router.put('/users/audio/{audio_id}',response_model=AudioUpdateResponse)
async def update_audio(audio_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
response = await update_audio_handler(audio_id, audio_file, db)
return response
#用户音频下载
@router.get('/users/audio/{audio_id}')
async def download_audio(audio_id:int, db: Session = Depends(get_db)):
audio_data = await download_audio_handler(audio_id, db)
return Response(content=audio_data,media_type='application/octet-stream',headers={"Content-Disposition": "attachment"})
#用户音频删除
@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
response = await delete_audio_handler(audio_id, db)
return response
#用户绑定音频
@router.post('/users/audio/bind',response_model=AudioBindResponse)
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
response = await bind_audio_handler(bind_req, db)
return response return response


@ -55,16 +55,4 @@ class SessionUpdateData(BaseModel):
#session修改响应类 #session修改响应类
class SessionUpdateResponse(BaseResponse): class SessionUpdateResponse(BaseResponse):
data: Optional[SessionUpdateData] data: Optional[SessionUpdateData]
#--------------------------------------------------------------------------
#------------------------------Session Speaker Id修改----------------------
class SessionSpeakerUpdateRequest(BaseModel):
speaker_id: int
class SessionSpeakerUpdateData(BaseModel):
updatedAt:str
class SessionSpeakerUpdateResponse(BaseResponse):
data: Optional[SessionSpeakerUpdateData]
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------


@ -3,6 +3,7 @@ from typing import Optional
from .base_schema import BaseResponse from .base_schema import BaseResponse
#---------------------------------用户创建---------------------------------- #---------------------------------用户创建----------------------------------
#用户创建请求类 #用户创建请求类
class UserCrateRequest(BaseModel): class UserCrateRequest(BaseModel):
@ -136,47 +137,4 @@ class HardwareQueryData(BaseModel):
class HardwareQueryResponse(BaseResponse): class HardwareQueryResponse(BaseResponse):
data: Optional[HardwareQueryData] data: Optional[HardwareQueryData]
#------------------------------------------------------------------------------ #------------------------------------------------------------------------------
#-------------------------------用户音频上传-------------------------------------
class AudioUploadData(BaseModel):
audio_id: int
uploadedAt: str
class AudioUploadResponse(BaseResponse):
data: Optional[AudioUploadData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频修改-------------------------------------
class AudioUpdateData(BaseModel):
updatedAt: str
class AudioUpdateResponse(BaseResponse):
data: Optional[AudioUpdateData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频删除-------------------------------------
class AudioDeleteData(BaseModel):
deletedAt: str
class AudioDeleteResponse(BaseResponse):
data: Optional[AudioDeleteData]
#-------------------------------------------------------------------------------
#-------------------------------用户音频绑定-------------------------------------
class AudioBindRequest(BaseModel):
audio_id: int
user_id: int
class AudioBindData(BaseModel):
bindedAt: str
class AudioBindResponse(BaseResponse):
data: Optional[AudioBindData]
#-------------------------------------------------------------------------------


@ -1,28 +1,27 @@
class DevelopmentConfig: class DevelopmentConfig:
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置 SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置 REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
LOG_LEVEL = "DEBUG" #日志级别 LOG_LEVEL = "DEBUG" #日志级别
TTS_UTILS = "VITS" #TTS引擎配置可选OPENVOICE或者VITS
class UVICORN: class UVICORN:
HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip HOST = "0.0.0.0" #uvicorn放行ip0.0.0.0代表所有ip
PORT = 8001 #uvicorn运行端口 PORT = 7878 #uvicorn运行端口
WORKERS = 12 #uvicorn进程数(通常与cpu核数相同) WORKERS = 12 #uvicorn进程数(通常与cpu核数相同)
class XF_ASR: class XF_ASR:
APP_ID = "f1c121c1" #讯飞语音识别APP_ID APP_ID = "your_app_id" #讯飞语音识别APP_ID
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET
API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY API_KEY = "your_api_key" #讯飞语音识别API_KEY
DOMAIN = "iat" DOMAIN = "iat"
LANGUAGE = "zh_cn" LANGUAGE = "zh_cn"
ACCENT = "mandarin" ACCENT = "mandarin"
VAD_EOS = 10000 VAD_EOS = 10000
class MINIMAX_LLM: class MINIMAX_LLM:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg" API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
class MINIMAX_TTA: class MINIMAX_TTA:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA", API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA",
URL = "https://api.minimax.chat/v1/t2a_pro", URL = "https://api.minimax.chat/v1/t2a_pro",
GROUP_ID ="1759482180095975904" GROUP_ID ="1759482180095975904"
class STRAM_CHAT: class STRAM_CHAT:
ASR = "XF" # 语音识别引擎可选XF或者LOCAL ASR = "LOCAL"
TTS = "LOCAL" TTS = "LOCAL"

Binary file not shown.

284
examples/audio_utils.py Normal file
View File

@ -0,0 +1,284 @@
import os
import io
import numpy as np
import pyaudio
import wave
import base64
"""
audio utils for modified_funasr_demo.py
"""
def decode_str2bytes(data):
# 将Base64编码的字节串解码为字节串
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class BaseAudio:
def __init__(self,
filename=None,
input=False,
output=False,
CHUNK=1024,
FORMAT=pyaudio.paInt16,
CHANNELS=1,
RATE=16000,
input_device_index=None,
output_device_index=None,
**kwargs):
self.CHUNK = CHUNK
self.FORMAT = FORMAT
self.CHANNELS = CHANNELS
self.RATE = RATE
self.filename = filename
assert input != output, "input and output cannot be the same, \
but got input={} and output={}.".format(input, output)
print("------------------------------------------")
print(f"{'Input' if input else 'Output'} Audio Initialization: ")
print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
print("------------------------------------------")
self.p = pyaudio.PyAudio()
self.stream = self.p.open(format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=input,
output=output,
input_device_index=input_device_index,
output_device_index=output_device_index,
**kwargs)
def load_audio_file(self, wav_file):
with wave.open(wav_file, 'rb') as wf:
params = wf.getparams()
frames = wf.readframes(params.nframes)
print("Audio file loaded.")
# Audio Parameters
# print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype)
return frames
def check_audio_type(self, audio_data, return_type=None):
assert return_type in ['bytes', 'io', None], \
"return_type should be 'bytes', 'io' or None."
if isinstance(audio_data, str):
if len(audio_data) > 50:
audio_data = decode_str2bytes(audio_data)
else:
assert os.path.isfile(audio_data), \
"audio_data should be a file path or a bytes object."
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
elif isinstance(audio_data, bytes):
pass
else:
raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
but got {type(audio_data)}")
if return_type is None:
return audio_data
return self.write_wave(None, [audio_data], return_type)
def write_wave(self, filename, frames, return_type='io'):
"""Write audio data to a file."""
if isinstance(frames, bytes):
frames = [frames]
if not isinstance(frames, list):
raise TypeError("frames should be \
a list of bytes or a bytes object, \
but got {}.".format(type(frames)))
if return_type == 'io':
if filename is None:
filename = io.BytesIO()
if self.filename:
filename = self.filename
return self.write_wave_io(filename, frames)
elif return_type == 'bytes':
return self.write_wave_bytes(frames)
def write_wave_io(self, filename, frames):
"""
Write audio data to a file-like object.
Args:
filename: [string or file-like object], file path or file-like object to write
frames: list of bytes, audio data to write
"""
wf = wave.open(filename, 'wb')
# 设置WAV文件的参数
wf.setnchannels(self.CHANNELS)
wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
wf.setframerate(self.RATE)
wf.writeframes(b''.join(frames))
wf.close()
if isinstance(filename, io.BytesIO):
filename.seek(0) # reset file pointer to beginning
return filename
def write_wave_bytes(self, frames):
"""Write audio data to a bytes object."""
return b''.join(frames)
class BaseRecorder(BaseAudio):
def __init__(self,
input=True,
base_chunk_size=None,
RATE=16000,
**kwargs):
super().__init__(input=input, RATE=RATE, **kwargs)
self.base_chunk_size = base_chunk_size
if base_chunk_size is None:
self.base_chunk_size = self.CHUNK
def record(self,
filename,
duration=5,
return_type='io',
logger=None):
if logger is not None:
logger.info("Recording started.")
else:
print("Recording started.")
frames = []
for i in range(0, int(self.RATE / self.CHUNK * duration)):
data = self.stream.read(self.CHUNK, exception_on_overflow=False)
frames.append(data)
if logger is not None:
logger.info("Recording stopped.")
else:
print("Recording stopped.")
return self.write_wave(filename, frames, return_type)
def record_chunk_voice(self,
return_type='bytes',
CHUNK=None,
exception_on_overflow=True,
queue=None):
data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK,
exception_on_overflow=exception_on_overflow)
if return_type is not None:
return self.write_wave(None, [data], return_type)
return data
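A quick usage note for the recorder defined above: record() returns an in-memory WAV when return_type='io' and the raw concatenated int16 frames (no WAV header) when return_type='bytes'. For example (rate and duration are arbitrary):

``` python
# Usage sketch for examples/audio_utils.py; the values are illustrative only.
from audio_utils import BaseRecorder

rec = BaseRecorder(RATE=16000)                            # opens the default input device
wav_io = rec.record(None, duration=5, return_type='io')   # io.BytesIO containing a WAV
pcm = rec.record(None, duration=5, return_type='bytes')   # raw int16 PCM, no header
```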

View File

@ -0,0 +1,39 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer
def asr_file_stream(file_path=r'.\assets\example_recording.wav'):
# 读入音频文件
rec = BaseRecorder()
data = rec.load_audio_file(file_path)
# 创建模型
asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
asr.session_signup("test")
# 记录目标说话人
asr.initialize_speaker(r".\assets\example_recording.wav")
# 语音识别
print("===============================================")
text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
print(f"text_dict: {text_dict}")
if not isinstance(text_dict, str):
print("".join(text_dict['text']))
# 情感识别
print("===============================================")
emotion_dict = asr.recognize_emotion(data)
print(f"emotion_dict: {emotion_dict}")
if not isinstance(emotion_dict, str):
max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
print("emotion: " +emotion_dict['labels'][max_index])
asr_file_stream()
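The same pipeline can be driven from a microphone instead of the bundled file by reusing the BaseRecorder shown above; a sketch (the five-second window and reusing example_recording.wav as the enrolment sample are arbitrary choices):

``` python
# Sketch: microphone variant of asr_file_stream(). Durations and paths are placeholders.
from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer

def asr_mic_stream(duration=5):
    rec = BaseRecorder()
    data = rec.record(None, duration=duration, return_type='bytes')  # raw int16 PCM

    asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
    asr.session_signup("mic")
    asr.initialize_speaker(r".\assets\example_recording.wav")  # enrol the target speaker
    result = asr.streaming_recognize("mic", data, auto_det_end=True)
    # "Other People" when speaker verification rejects the audio,
    # otherwise a dict like {"text": [...], "is_end": True}
    print(result)
```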

View File

@ -0,0 +1 @@
Stores the target speaker's voice-feature embeddings. To change the save path, edit DEFALUT_SAVE_PATH in utils/stt/speaker_ver_utils.

Binary file not shown.

214
examples/tts_demo.ipynb Normal file
View File

@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Importing the dtw module. When using in academic works please cite:\n",
" T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
" J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
"\n"
]
}
],
"source": [
"import sys\n",
"import os\n",
"sys.path.append(\"../\")\n",
"from utils.tts.openvoice_utils import TextToSpeech\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
"Building prefix dict from the default dictionary ...\n",
"Loading model from cache C:\\Users\\bing\\AppData\\Local\\Temp\\jieba.cache\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"load base tts model successfully!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading model cost 0.304 seconds.\n",
"Prefix dict has been built successfully.\n",
"Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv_transpose1d(\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv1d(input, weight, bias, self.stride,\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([81565])]\n",
"True\n",
"Loaded checkpoint 'D:\\python\\OpenVoice\\checkpoints_v2\\converter/checkpoint.pth'\n",
"missing/unexpected keys: [] []\n",
"load tone color converter successfully!\n"
]
}
],
"source": [
"model = TextToSpeech(use_tone_convert=True, device=\"cuda\", debug=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# 测试用将mp3转为int32类型的numpy对齐输入端\n",
"from pydub import AudioSegment\n",
"import numpy as np\n",
"source_audio=r\"D:\\python\\OpenVoice\\resources\\demo_speaker0.mp3\"\n",
"audio = AudioSegment.from_file(source_audio, format=\"mp3\")\n",
"raw_data = audio.raw_data\n",
"audio_array = np.frombuffer(raw_data, dtype=np.int32)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"OpenVoice version: v2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\functional.py:665: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.\n",
"Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\SpectralOps.cpp:878.)\n",
" return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"# 获取并设置目标说话人的speaker embedding\n",
"# audio_array :输入的音频信号,类型为 np.ndarray\n",
"# 获取speaker embedding\n",
"target_se = model.audio2emb(audio_array, rate=44100, vad=True)\n",
"# 将模型的默认目标说话人embedding设置为 target_se\n",
"model.initialize_target_se(target_se)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv_transpose1d(\n",
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
" return F.conv1d(input, weight, bias, self.stride,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([216378])]\n",
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\n"
]
}
],
"source": [
"# 测试base_tts不含音色转换\n",
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
"audio, sr = model._base_tts(text, speed=1)\n",
"audio = model.tensor2numpy(audio)\n",
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"generate base speech!\n",
"**********************,tts sr 44100\n",
"audio segment length is [torch.Size([216378])]\n",
"torch.float32\n",
"**********************************, convert sr 22050\n",
"tone color has been converted!\n",
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo.wav\n"
]
}
],
"source": [
"# 测试整体pipeline包含音色转换\n",
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
"audio_bytes, sr = model.tts(text, speed=1)\n",
"audio = np.frombuffer(audio_bytes, dtype=np.int16).flatten()\n",
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo.wav\" )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openVoice",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

31
main.py
View File

@ -1,24 +1,9 @@
from app import app, Config import os
import uvicorn
if __name__ == "__main__": if __name__ == '__main__':
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
# _ooOoo_ # script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
# o8888888o #
# 88" . "88 # # 使用exec函数执行脚本
# (| -_- |) # with open(script_path, 'r') as file:
# O\ = /O # exec(file.read())
# ____/`---'\____ #
# . ' \\| |// `. #
# / \\||| : |||// \ #
# / _||||| -:- |||||- \ #
# | | \\\ - /// | | #
# \ .-\__ `-` ___/-. / #
# ___`. .' /--.--\ `. . __ #
# ."" '< `.___\_<|>_/___.' >'"". #
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
# \ \ `-. \_ __\ /__ _/ .-` / / #
# ======`-.____`-.___\_____/___.-`____.-'====== #
# `=---=' #
# ............................................. #
# 佛祖保佑 永无BUG #

View File

@ -7,6 +7,7 @@ redis
requests requests
websockets websockets
numpy numpy
funasr
jieba jieba
cn2an cn2an
unidecode unidecode
@ -17,9 +18,4 @@ torch
numba numba
soundfile soundfile
webrtcvad webrtcvad
apscheduler apscheduler
aiohttp
faster_whisper
whisper_timestamped
modelscope
wavmark

Binary file not shown.

Binary file not shown.

View File

@ -1,9 +1,31 @@
from tests.unit_test.user_test import user_test from tests.unit_test.user_test import UserServiceTest
from tests.unit_test.character_test import character_test from tests.unit_test.character_test import CharacterServiceTest
from tests.unit_test.chat_test import chat_test from tests.unit_test.chat_test import ChatServiceTest
import asyncio
if __name__ == '__main__': if __name__ == '__main__':
user_test() user_service_test = UserServiceTest()
character_test() character_service_test = CharacterServiceTest()
chat_test() chat_service_test = ChatServiceTest()
user_service_test.test_user_create()
user_service_test.test_user_update()
user_service_test.test_user_query()
user_service_test.test_hardware_bind()
user_service_test.test_hardware_unbind()
user_service_test.test_user_delete()
character_service_test.test_character_create()
character_service_test.test_character_update()
character_service_test.test_character_query()
character_service_test.test_character_delete()
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
chat_service_test.test_session_update()
asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()
print("全部测试成功") print("全部测试成功")

View File

@ -2,7 +2,7 @@ import requests
import json import json
class CharacterServiceTest: class CharacterServiceTest:
def __init__(self,socket="http://127.0.0.1:8001"): def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket self.socket = socket
def test_character_create(self): def test_character_create(self):
@ -66,14 +66,9 @@ class CharacterServiceTest:
else: else:
raise Exception("角色删除测试失败") raise Exception("角色删除测试失败")
if __name__ == '__main__':
def character_test():
character_service_test = CharacterServiceTest() character_service_test = CharacterServiceTest()
character_service_test.test_character_create() character_service_test.test_character_create()
character_service_test.test_character_update() character_service_test.test_character_update()
character_service_test.test_character_query() character_service_test.test_character_query()
character_service_test.test_character_delete() character_service_test.test_character_delete()
if __name__ == '__main__':
character_test()

View File

@ -10,7 +10,7 @@ import websockets
class ChatServiceTest: class ChatServiceTest:
def __init__(self,socket="http://127.0.0.1:7878"): def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket self.socket = socket
@ -30,7 +30,6 @@ class ChatServiceTest:
} }
response = requests.request("POST", url, headers=headers, data=payload) response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200: if response.status_code == 200:
print("用户创建成功")
self.user_id = response.json()['data']['user_id'] self.user_id = response.json()['data']['user_id']
else: else:
raise Exception("创建聊天时,用户创建失败") raise Exception("创建聊天时,用户创建失败")
@ -58,37 +57,6 @@ class ChatServiceTest:
else: else:
raise Exception("创建聊天时,角色创建失败") raise Exception("创建聊天时,角色创建失败")
#上传音频用于音频克隆
url = f"{self.socket}/users/audio?user_id={self.user_id}"
current_file_path = os.path.abspath(__file__)
current_file_path = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_file_path)
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(mp3_file_path, 'rb') as audio_file:
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
response = requests.post(url,files=files)
if response.status_code == 200:
self.audio_id = response.json()['data']['audio_id']
print("音频上传成功")
else:
raise Exception("音频上传失败")
#绑定音频
url = f"{self.socket}/users/audio/bind"
payload = json.dumps({
"user_id":self.user_id,
"audio_id":self.audio_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("音频绑定测试成功")
else:
raise Exception("音频绑定测试失败")
#创建一个对话 #创建一个对话
url = f"{self.socket}/chats" url = f"{self.socket}/chats"
payload = json.dumps({ payload = json.dumps({
@ -98,7 +66,6 @@ class ChatServiceTest:
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
response = requests.request("POST", url, headers=headers, data=payload) response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200: if response.status_code == 200:
print("对话创建成功") print("对话创建成功")
@ -134,8 +101,8 @@ class ChatServiceTest:
payload = json.dumps({ payload = json.dumps({
"user_id": self.user_id, "user_id": self.user_id,
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]", "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]",
"user_info": "{\"character\": \"\", \"events\": [] }", "user_info": "{}",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}", "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}", "llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
"token": 0} "token": 0}
) )
@ -148,19 +115,6 @@ class ChatServiceTest:
else: else:
raise Exception("Session更新测试失败") raise Exception("Session更新测试失败")
def test_session_speakerid_update(self):
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
payload = json.dumps({
"speaker_id" :37
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("Session SpeakerId更新测试成功")
else:
raise Exception("Session SpeakerId更新测试失败")
#测试单次聊天 #测试单次聊天
async def test_chat_temporary(self): async def test_chat_temporary(self):
@ -195,7 +149,7 @@ class ChatServiceTest:
await websocket.send(message) await websocket.send(message)
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/temporary') as websocket: async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块 chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块
for chunk in chunks: for chunk in chunks:
await send_audio_chunk(websocket, chunk) await send_audio_chunk(websocket, chunk)
@ -251,7 +205,7 @@ class ChatServiceTest:
message = json.dumps(data) message = json.dumps(data)
await websocket.send(message) await websocket.send(message)
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/lasting') as websocket: async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
#发送第一次 #发送第一次
chunks = read_wav_file_in_chunks(2048) chunks = read_wav_file_in_chunks(2048)
for chunk in chunks: for chunk in chunks:
@ -301,7 +255,7 @@ class ChatServiceTest:
current_dir = os.path.dirname(current_file_path) current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir) tests_dir = os.path.dirname(current_dir)
file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav') file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
url = f"ws://127.0.0.1:7878/chat/voice_call" url = f"ws://114.214.236.207:7878/chat/voice_call"
#发送格式 #发送格式
ws_data = { ws_data = {
"audio" : "", "audio" : "",
@ -338,6 +292,7 @@ class ChatServiceTest:
async with websockets.connect(url) as websocket: async with websockets.connect(url) as websocket:
await asyncio.gather(audio_stream(websocket)) await asyncio.gather(audio_stream(websocket))
#测试删除聊天 #测试删除聊天
def test_chat_delete(self): def test_chat_delete(self):
@ -348,11 +303,6 @@ class ChatServiceTest:
else: else:
raise Exception("聊天删除测试失败") raise Exception("聊天删除测试失败")
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("音频删除测试失败")
url = f"{self.socket}/users/{self.user_id}" url = f"{self.socket}/users/{self.user_id}"
response = requests.request("DELETE", url) response = requests.request("DELETE", url)
if response.status_code != 200: if response.status_code != 200:
@ -362,19 +312,18 @@ class ChatServiceTest:
response = requests.request("DELETE", url) response = requests.request("DELETE", url)
if response.status_code != 200: if response.status_code != 200:
raise Exception("角色删除测试失败") raise Exception("角色删除测试失败")
def chat_test(): if __name__ == '__main__':
chat_service_test = ChatServiceTest() chat_service_test = ChatServiceTest()
chat_service_test.test_create_chat() chat_service_test.test_create_chat()
chat_service_test.test_session_id_query() chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query() chat_service_test.test_session_content_query()
chat_service_test.test_session_update() chat_service_test.test_session_update()
chat_service_test.test_session_speakerid_update()
asyncio.run(chat_service_test.test_chat_temporary()) asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting()) asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call()) asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete() chat_service_test.test_chat_delete()
if __name__ == '__main__':
chat_test()

View File

@ -1,12 +0,0 @@
from chat_test import chat_test
import multiprocessing
if __name__ == '__main__':
processes = []
for _ in range(2):
p = multiprocessing.Process(target=chat_test)
processes.append(p)
p.start()
for p in processes:
p.join()

View File

@ -1,11 +1,10 @@
import requests import requests
import json import json
import uuid import uuid
import os
class UserServiceTest: class UserServiceTest:
def __init__(self,socket="http://127.0.0.1:7878"): def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket self.socket = socket
def test_user_create(self): def test_user_create(self):
@ -67,7 +66,7 @@ class UserServiceTest:
mac = "08:00:20:0A:8C:6G" mac = "08:00:20:0A:8C:6G"
payload = json.dumps({ payload = json.dumps({
"mac":mac, "mac":mac,
"user_id":self.id, "user_id":1,
"firmware":"v1.0", "firmware":"v1.0",
"model":"香橙派" "model":"香橙派"
}) })
@ -88,123 +87,13 @@ class UserServiceTest:
print("硬件解绑测试成功") print("硬件解绑测试成功")
else: else:
raise Exception("硬件解绑测试失败") raise Exception("硬件解绑测试失败")
def test_hardware_bind_change(self): if __name__ == '__main__':
url = f"{self.socket}/users/hardware/{self.hd_id}/bindchange"
payload = json.dumps({
"user_id" : self.id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("硬件换绑测试成功")
else:
raise Exception("硬件换绑测试失败")
def test_hardware_update(self):
url = f"{self.socket}/users/hardware/{self.hd_id}/info"
payload = json.dumps({
"mac":"08:00:20:0A:8C:6G",
"firmware":"v1.0",
"model":"香橙派"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("硬件信息更新测试成功")
else:
raise Exception("硬件信息更新测试失败")
def test_hardware_query(self):
url = f"{self.socket}/users/hardware/{self.hd_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("硬件查询测试成功")
else:
raise Exception("硬件查询测试失败")
def test_upload_audio(self):
url = f"{self.socket}/users/audio?user_id={self.id}"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
response = requests.post(url, files=files)
if response.status_code == 200:
self.audio_id = response.json()["data"]['audio_id']
print("音频上传测试成功")
else:
raise Exception("音频上传测试失败")
def test_update_audio(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
with open(wav_file_path, 'rb') as audio_file:
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
response = requests.put(url, files=files)
if response.status_code == 200:
print("音频上传测试成功")
else:
raise Exception("音频上传测试失败")
def test_bind_audio(self):
url = f"{self.socket}/users/audio/bind"
payload = json.dumps({
"user_id":self.id,
"audio_id":self.audio_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("音频绑定测试成功")
else:
raise Exception("音频绑定测试失败")
def test_audio_download(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("音频下载测试成功")
else:
raise Exception("音频下载测试失败")
def test_audio_delete(self):
url = f"{self.socket}/users/audio/{self.audio_id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("音频删除测试成功")
else:
raise Exception("音频删除测试失败")
def user_test():
user_service_test = UserServiceTest() user_service_test = UserServiceTest()
user_service_test.test_user_create() user_service_test.test_user_create()
user_service_test.test_user_update() user_service_test.test_user_update()
user_service_test.test_user_query() user_service_test.test_user_query()
user_service_test.test_hardware_bind() user_service_test.test_hardware_bind()
user_service_test.test_hardware_bind_change()
user_service_test.test_hardware_update()
user_service_test.test_hardware_query()
user_service_test.test_hardware_unbind() user_service_test.test_hardware_unbind()
user_service_test.test_upload_audio()
user_service_test.test_update_audio()
user_service_test.test_bind_audio()
user_service_test.test_audio_download()
user_service_test.test_audio_delete()
user_service_test.test_user_delete() user_service_test.test_user_delete()
if __name__ == '__main__':
user_test()

1
utils/assets/README.md Normal file
View File

@ -0,0 +1 @@
./ses stores the source speaker embedding (source se) files; the format is *.pth.
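In the usual OpenVoice layout these *.pth files are plain PyTorch tensors, so a stored embedding can be inspected directly (the file name below is a placeholder):

``` python
# Sketch: peek at a stored source-se embedding. "zh.pth" is a placeholder name;
# point the path at whatever actually lives under utils/assets/ses/.
import torch

se = torch.load("utils/assets/ses/zh.pth", map_location="cpu")
print(type(se), getattr(se, "shape", None))
```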

View File

@ -1,142 +1,142 @@
import io import io
import numpy as np import numpy as np
import base64 import base64
import wave import wave
from funasr import AutoModel from funasr import AutoModel
import time import time
""" """
Base模型 Base模型
不能进行情绪分类,只能用作特征提取 不能进行情绪分类,只能用作特征提取
""" """
FUNASRBASE = { FUNASRBASE = {
"model_type": "funasr", "model_type": "funasr",
"model_path": "iic/emotion2vec_base", "model_path": "iic/emotion2vec_base",
"model_revision": "v2.0.4" "model_revision": "v2.0.4"
} }
""" """
Finetune模型 Finetune模型
输出分类结果 输出分类结果
""" """
FUNASRFINETUNE = { FUNASRFINETUNE = {
"model_type": "funasr", "model_type": "funasr",
"model_path": "iic/emotion2vec_base_finetuned" "model_path": "iic/emotion2vec_base_finetuned"
} }
def decode_str2bytes(data): def decode_str2bytes(data):
# 将Base64编码的字节串解码为字节串 # 将Base64编码的字节串解码为字节串
if data is None: if data is None:
return None return None
return base64.b64decode(data.encode('utf-8')) return base64.b64decode(data.encode('utf-8'))
class Emotion: class Emotion:
def __init__(self, def __init__(self,
model_type="funasr", model_type="funasr",
model_path="iic/emotion2vec_base", model_path="iic/emotion2vec_base",
device="cuda", device="cuda",
model_revision="v2.0.4", model_revision="v2.0.4",
**kwargs): **kwargs):
self.model_type = model_type self.model_type = model_type
self.initialize(model_type, model_path, device, model_revision, **kwargs) self.initialize(model_type, model_path, device, model_revision, **kwargs)
# 初始化模型 # 初始化模型
def initialize(self, def initialize(self,
model_type, model_type,
model_path, model_path,
device, device,
model_revision, model_revision,
**kwargs): **kwargs):
if model_type == "funasr": if model_type == "funasr":
self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs) self.emotion_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
else: else:
raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.") raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr] expected.")
# 检查输入类型 # 检查输入类型
def check_audio_type(self, def check_audio_type(self,
audio_data): audio_data):
"""check audio data type and convert it to bytes if necessary.""" """check audio data type and convert it to bytes if necessary."""
if isinstance(audio_data, bytes): if isinstance(audio_data, bytes):
pass pass
elif isinstance(audio_data, list): elif isinstance(audio_data, list):
audio_data = b''.join(audio_data) audio_data = b''.join(audio_data)
elif isinstance(audio_data, str): elif isinstance(audio_data, str):
audio_data = decode_str2bytes(audio_data) audio_data = decode_str2bytes(audio_data)
elif isinstance(audio_data, io.BytesIO): elif isinstance(audio_data, io.BytesIO):
wf = wave.open(audio_data, 'rb') wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes()) audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray): elif isinstance(audio_data, np.ndarray):
pass pass
else: else:
raise TypeError(f"audio_data must be bytes, list, str, \ raise TypeError(f"audio_data must be bytes, list, str, \
io.BytesIO or numpy array, but got {type(audio_data)}") io.BytesIO or numpy array, but got {type(audio_data)}")
if isinstance(audio_data, bytes): if isinstance(audio_data, bytes):
audio_data = np.frombuffer(audio_data, dtype=np.int16) audio_data = np.frombuffer(audio_data, dtype=np.int16)
elif isinstance(audio_data, np.ndarray): elif isinstance(audio_data, np.ndarray):
if audio_data.dtype != np.int16: if audio_data.dtype != np.int16:
audio_data = audio_data.astype(np.int16) audio_data = audio_data.astype(np.int16)
else: else:
raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}") raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
# 输入类型必须是float32 # 输入类型必须是float32
if isinstance(audio_data, np.ndarray): if isinstance(audio_data, np.ndarray):
audio_data = audio_data.astype(np.float32) audio_data = audio_data.astype(np.float32)
else: else:
raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}") raise TypeError(f"audio_data must be numpy array, but got {type(audio_data)}")
return audio_data return audio_data
def process(self, def process(self,
audio_data, audio_data,
granularity="utterance", granularity="utterance",
extract_embedding=False, extract_embedding=False,
output_dir=None, output_dir=None,
only_score=True): only_score=True):
""" """
audio_data: only float32 expected because of layernorm audio_data: only float32 expected because of layernorm
extract_embedding: save embedding if true extract_embedding: save embedding if true
output_dir: save path for embedding output_dir: save path for embedding
only_score: only return labels & scores if true only_score: only return labels & scores if true
""" """
audio_data = self.check_audio_type(audio_data) audio_data = self.check_audio_type(audio_data)
if self.model_type == 'funasr': if self.model_type == 'funasr':
result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding) result = self.emotion_model.generate(audio_data, output_dir=output_dir, granularity=granularity, extract_embedding=extract_embedding)
else: else:
pass pass
# 只保留 lables 和 scores # 只保留 lables 和 scores
if only_score: if only_score:
maintain_key = ["labels", "scores"] maintain_key = ["labels", "scores"]
for res in result: for res in result:
keys_to_remove = [k for k in res.keys() if k not in maintain_key] keys_to_remove = [k for k in res.keys() if k not in maintain_key]
for k in keys_to_remove: for k in keys_to_remove:
res.pop(k) res.pop(k)
return result[0] return result[0]
# only for test # only for test
def load_audio_file(wav_file): def load_audio_file(wav_file):
with wave.open(wav_file, 'rb') as wf: with wave.open(wav_file, 'rb') as wf:
params = wf.getparams() params = wf.getparams()
frames = wf.readframes(params.nframes) frames = wf.readframes(params.nframes)
print("Audio file loaded.") print("Audio file loaded.")
# Audio Parameters # Audio Parameters
# print("Channels:", params.nchannels) # print("Channels:", params.nchannels)
# print("Sample width:", params.sampwidth) # print("Sample width:", params.sampwidth)
# print("Frame rate:", params.framerate) # print("Frame rate:", params.framerate)
# print("Number of frames:", params.nframes) # print("Number of frames:", params.nframes)
# print("Compression type:", params.comptype) # print("Compression type:", params.comptype)
return frames return frames
if __name__ == "__main__": if __name__ == "__main__":
inputs = r".\example\test.wav" inputs = r".\example\test.wav"
inputs = load_audio_file(inputs) inputs = load_audio_file(inputs)
device = "cuda" device = "cuda"
# FUNASRBASE.update({"device": device}) # FUNASRBASE.update({"device": device})
FUNASRFINETUNE.update({"device": device}) FUNASRFINETUNE.update({"device": device})
emotion_model = Emotion(**FUNASRFINETUNE) emotion_model = Emotion(**FUNASRFINETUNE)
s = time.time() s = time.time()
result = emotion_model.process(inputs) result = emotion_model.process(inputs)
t = time.time() t = time.time()
print(t - s) print(t - s)
print(result) print(result)

View File

@ -29,7 +29,6 @@ class FunAutoSpeechRecognizer(STTBase):
**kwargs): **kwargs):
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug) super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
self.asr_model = AutoModel(model=model_path, device=device, **kwargs) self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
@ -80,9 +79,9 @@ class FunAutoSpeechRecognizer(STTBase):
def _init_asr(self): def _init_asr(self):
# 随机初始化一段音频数据 # 随机初始化一段音频数据
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16) init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
self.session_signup("init") self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init") self.audio_cache = {}
self.session_signout("init") self.asr_cache = {}
# print("init ASR model done.") # print("init ASR model done.")
# when chat trying to use asr , sign up # when chat trying to use asr , sign up
@ -109,79 +108,6 @@ class FunAutoSpeechRecognizer(STTBase):
""" """
text_dict = dict(text=[], is_end=is_end) text_dict = dict(text=[], is_end=is_end)
audio_cache = self.audio_cache[session_id]
audio_data = self.check_audio_type(audio_data)
if audio_cache is None:
audio_cache = audio_data
else:
if audio_cache.shape[0] > 0:
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
self.audio_cache[session_id] = audio_cache
return text_dict
total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = audio_cache[start_idx:end_idx]
# TODO: exceptions processes
# print("i:", i)
try:
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
except ValueError as e:
print(f"ValueError: {e}")
continue
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
audio_cache = None
else:
if end_idx:
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
self.audio_cache[session_id] = audio_cache
return text_dict
def streaming_recognize_origin(self,
session_id,
audio_data,
is_end=False,
auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
auto_det_end: bool, whether to automatically detect the end of a audio data
"""
text_dict = dict(text=[], is_end=is_end)
audio_cache = self.audio_cache[session_id] audio_cache = self.audio_cache[session_id]
asr_cache = self.asr_cache[session_id] asr_cache = self.asr_cache[session_id]
@ -241,4 +167,176 @@ class FunAutoSpeechRecognizer(STTBase):
self.audio_cache[session_id] = audio_cache self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache self.asr_cache[session_id] = asr_cache
return text_dict return text_dict
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
# import io
# import numpy as np
# import base64
# import wave
# from funasr import AutoModel
# from .base_stt import STTBase
# def decode_str2bytes(data):
# # 将Base64编码的字节串解码为字节串
# if data is None:
# return None
# return base64.b64decode(data.encode('utf-8'))
# class FunAutoSpeechRecognizer(STTBase):
# def __init__(self,
# model_path="paraformer-zh-streaming",
# device="cuda",
# RATE=16000,
# cfg_path=None,
# debug=False,
# chunk_ms=480,
# encoder_chunk_look_back=4,
# decoder_chunk_look_back=1,
# **kwargs):
# super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
# self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
# self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
# self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
# #[0, 8, 4] 480ms, [0, 10, 5] 600ms
# if chunk_ms == 480:
# self.chunk_size = [0, 8, 4]
# elif chunk_ms == 600:
# self.chunk_size = [0, 10, 5]
# else:
# raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
# self.chunk_partial_size = self.chunk_size[1] * 960
# self.audio_cache = None
# self.asr_cache = {}
# self._init_asr()
# def check_audio_type(self, audio_data):
# """check audio data type and convert it to bytes if necessary."""
# if isinstance(audio_data, bytes):
# pass
# elif isinstance(audio_data, list):
# audio_data = b''.join(audio_data)
# elif isinstance(audio_data, str):
# audio_data = decode_str2bytes(audio_data)
# elif isinstance(audio_data, io.BytesIO):
# wf = wave.open(audio_data, 'rb')
# audio_data = wf.readframes(wf.getnframes())
# elif isinstance(audio_data, np.ndarray):
# pass
# else:
# raise TypeError(f"audio_data must be bytes, list, str, \
# io.BytesIO or numpy array, but got {type(audio_data)}")
# if isinstance(audio_data, bytes):
# audio_data = np.frombuffer(audio_data, dtype=np.int16)
# elif isinstance(audio_data, np.ndarray):
# if audio_data.dtype != np.int16:
# audio_data = audio_data.astype(np.int16)
# else:
# raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
# return audio_data
# def _init_asr(self):
# # 随机初始化一段音频数据
# init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
# self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# self.audio_cache = None
# self.asr_cache = {}
# # print("init ASR model done.")
# def recognize(self, audio_data):
# """recognize audio data to text"""
# audio_data = self.check_audio_type(audio_data)
# result = self.asr_model.generate(input=audio_data,
# batch_size_s=300,
# hotword=self.hotwords)
# # print(result)
# text = ''
# for res in result:
# text += res['text']
# return text
# def streaming_recognize(self,
# audio_data,
# is_end=False,
# auto_det_end=False):
# """recognize partial result
# Args:
# audio_data: bytes or numpy array, partial audio data
# is_end: bool, whether the audio data is the end of a sentence
# auto_det_end: bool, whether to automatically detect the end of a audio data
# """
# text_dict = dict(text=[], is_end=is_end)
# audio_data = self.check_audio_type(audio_data)
# if self.audio_cache is None:
# self.audio_cache = audio_data
# else:
# # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
# if self.audio_cache.shape[0] > 0:
# self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
# if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
# return text_dict
# total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
# if is_end:
# # if the audio data is the end of a sentence, \
# # we need to add one more chunk to the end to \
# # ensure the end of the sentence is recognized correctly.
# auto_det_end = True
# if auto_det_end:
# total_chunk_num += 1
# # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
# end_idx = None
# for i in range(total_chunk_num):
# if auto_det_end:
# is_end = i == total_chunk_num - 1
# start_idx = i*self.chunk_partial_size
# if auto_det_end:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
# else:
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# # t_stamp = time.time()
# speech_chunk = self.audio_cache[start_idx:end_idx]
# # TODO: exceptions processes
# try:
# res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
# except ValueError as e:
# print(f"ValueError: {e}")
# continue
# text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# # print(f"each chunk time: {time.time()-t_stamp}")
# if is_end:
# self.audio_cache = None
# self.asr_cache = {}
# else:
# if end_idx:
# self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
# text_dict['is_end'] = is_end
# # print(f"text_dict: {text_dict}")
# return text_dict
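Unlike the single-cache legacy version preserved in the comments above, the live class keeps per-session audio_cache/asr_cache entries, so callers register a session id first and then stream chunks into it. A rough sketch of that call pattern (the random buffer stands in for real microphone data; 7680 samples matches the default chunk_ms=480 configuration, and session_signup/session_signout are the registration helpers referenced in the code):

``` python
# Sketch: drive the session-keyed streaming recognizer chunk by chunk.
# The buffer is random noise and only demonstrates the call shape.
import numpy as np
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

asr = FunAutoSpeechRecognizer()
asr.session_signup("demo")

pcm = np.random.randint(-32768, 32767, size=16000 * 3, dtype=np.int16)
step = 7680  # chunk_partial_size for the default chunk_ms=480 setting
for start in range(0, len(pcm), step):
    chunk = pcm[start:start + step]
    is_end = start + step >= len(pcm)
    out = asr.streaming_recognize("demo", chunk, is_end=is_end)
    print(out)

asr.session_signout("demo")
```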

View File

@ -1,29 +1,209 @@
from .funasr_utils import FunAutoSpeechRecognizer from .funasr_utils import FunAutoSpeechRecognizer
from .punctuation_utils import FUNASR, Punctuation from .punctuation_utils import CTTRANSFORMER, Punctuation
from .emotion_utils import FUNASRFINETUNE, Emotion from .emotion_utils import FUNASRFINETUNE, Emotion
from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
import os
class ModifiedRecognizer(): import numpy as np
def __init__(self): class ModifiedRecognizer(FunAutoSpeechRecognizer):
#增加语音识别模型 def __init__(self,
self.asr_model = FunAutoSpeechRecognizer() use_punct=True,
use_emotion=False,
use_speaker_ver=True):
#增加标点模型 # 创建基础的 funasr模型用于语音识别识别出不带标点的句子
self.puctuation_model = Punctuation(**FUNASR) super().__init__(
model_path="paraformer-zh-streaming",
device="cuda",
RATE=16000,
cfg_path=None,
debug=False,
chunk_ms=480,
encoder_chunk_look_back=4,
decoder_chunk_look_back=1)
# 记录是否具备附加功能
self.use_punct = use_punct
self.use_emotion = use_emotion
self.use_speaker_ver = use_speaker_ver
# 增加标点模型
if use_punct:
self.puctuation_model = Punctuation(**CTTRANSFORMER)
# 情绪识别模型 # 情绪识别模型
self.emotion_model = Emotion(**FUNASRFINETUNE) if use_emotion:
self.emotion_model = Emotion(**FUNASRFINETUNE)
def session_signup(self, session_id):
self.asr_model.session_signup(session_id)
def session_signout(self, session_id):
self.asr_model.session_signout(session_id)
def streaming_recognize(self, session_id, audio_data,is_end=False): # 说话人识别模型
return self.asr_model.streaming_recognize(session_id, audio_data,is_end=is_end) if use_speaker_ver:
self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
def initialize_speaker(self, speaker_1_wav):
"""
用于说话人识别将输入的音频(speaker_1_wav)设立为目标说话人并将其特征保存本地
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if speaker_1_wav.endswith(".npy"):
self.save_speaker_path = speaker_1_wav
elif speaker_1_wav.endswith('.wav'):
self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
# self.save_speaker_path = DEFALUT_SAVE_PATH
self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
else:
raise TypeError("only support [.npy] or [.wav].")
def speaker_ver(self, speaker_2_wav):
"""
用于说话人识别判断输入音频是否为目标说话人
是返回True不是返回False
"""
if not self.use_speaker_ver:
raise NotImplementedError("no access")
if not hasattr(self, "save_speaker_path"):
raise NotImplementedError("please initialize speaker first")
# self.speaker_ver_model.verfication 返回值为字符串 'yes' / 'no'
return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
speaker_2_wav=speaker_2_wav) == 'yes'
def recognize(self, audio_data):
"""
非流式语音识别返回识别出的文本返回值类型 str
"""
audio_data = self.check_audio_type(audio_data)
# 说话人识别
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# 语音识别
result = self.asr_model.generate(input=audio_data,
batch_size_s=300,
hotword=self.hotwords)
text = ''
for res in result:
text += res['text']
# 添加标点
if self.use_punct:
text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
return text
def punctuation_correction(self, sentence): def recognize_emotion(self, audio_data):
return self.puctuation_model.process(sentence) """
情感识别返回值为:
def emtion_recognition(self, audio): 1. 如果说话人非目标说话人返回字符串 "Other People"
return self.emotion_model.process(audio) 2. 如果说话人为目标说话人返回字典{"Labels": List[str], "scores": List[int]}
"""
audio_data = self.check_audio_type(audio_data)
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
if self.use_emotion:
return self.emotion_model.process(audio_data)
else:
raise NotImplementedError("no access")
def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
auto_det_end: bool, whether to automatically detect the end of a audio data
流式语音识别返回值为
1. 如果说话人非目标说话人返回字符串 "Other People"
2. 如果说话人为目标说话人,返回字典{"text": List[str], "is_end": boolean}
"""
audio_cache = self.audio_cache[session_id]
asr_cache = self.asr_cache[session_id]
text_dict = dict(text=[], is_end=is_end)
audio_data = self.check_audio_type(audio_data)
# Speaker verification
if self.use_speaker_ver:
if self.speaker_ver_model.verfication(self.save_speaker_path,
speaker_2_wav=audio_data) == 'no':
return "Other People"
# Speech recognition
if audio_cache is None:
audio_cache = audio_data
else:
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
if audio_cache.shape[0] > 0:
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
self.audio_cache[session_id] = audio_cache
return text_dict
total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)  # chunk count is based on the per-session cache
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = audio_cache[start_idx:end_idx]
# TODO: exceptions processes
try:
res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
except ValueError as e:
print(f"ValueError: {e}")
continue
# Add punctuation
if self.use_punct:
text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
else:
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
audio_cache = None
asr_cache = {}
else:
if end_idx:
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
if self.use_punct and is_end:
text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
self.audio_cache[session_id] = audio_cache
self.asr_cache[session_id] = asr_cache
# print(f"text_dict: {text_dict}")
return text_dict
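Taken together, the methods above suggest the following usage pattern. This is a minimal sketch, not the project's documented API: `FunASRWithExtras` is a stand-in name for the wrapper class defined above this hunk, the `use_*` keyword arguments mirror the flags referenced in its constructor, `session_signup`/`session_signout` are assumed to come from the parent streaming recognizer, and the file name and chunking are illustrative (16 kHz mono samples, as `RATE=16000` suggests).

```python
import numpy as np

# "FunASRWithExtras" is a hypothetical stand-in for the wrapper class defined above.
asr = FunASRWithExtras(use_punct=True, use_emotion=False, use_speaker_ver=True)

# Enroll the target speaker once; the embedding is cached as a .npy file.
asr.initialize_speaker("target_speaker.wav")

session_id = "demo-session"
asr.session_signup(session_id)  # assumed to be provided by the parent streaming recognizer

# Illustrative input: 16 kHz mono int16 PCM, fed in 100 ms chunks.
pcm = np.fromfile("utterance_16k_mono.pcm", dtype=np.int16)
chunk = 1600
for start in range(0, len(pcm), chunk):
    piece = pcm[start:start + chunk]
    is_end = start + chunk >= len(pcm)
    out = asr.streaming_recognize(session_id, piece, is_end=is_end)
    if out == "Other People":  # rejected by speaker verification
        break
    print("".join(out["text"]), "(final)" if out["is_end"] else "...")

asr.session_signout(session_id)
```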


@@ -1,119 +1,119 @@
from funasr import AutoModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

PUNCTUATION_MARK = [",", ".", "?", "!", "，", "。", "？", "！"]

"""
FUNASR
Model size: 1G
Quality: fairly good
Input type: only a single string is supported, not a list; a list input is treated as independent strings
"""
FUNASR = {
    "model_type": "funasr",
    "model_path": "ct-punc",
    "model_revision": "v2.0.4"
}

"""
CTTRANSFORMER
Model size: 275M
Quality: relatively poor
Input type: supports both str and list, and also accepts a cache for streaming
"""
CTTRANSFORMER = {
    "model_type": "ct-transformer",
    "model_path": "iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
    "model_revision": "v2.0.4"
}

class Punctuation:
    def __init__(self,
                 model_type="funasr",  # funasr | ct-transformer
                 model_path="ct-punc",
                 device="cuda",
                 model_revision="v2.0.4",
                 **kwargs):
        self.model_type = model_type
        self.initialize(model_type, model_path, device, model_revision, **kwargs)

    def initialize(self,
                   model_type,
                   model_path,
                   device,
                   model_revision,
                   **kwargs):
        if model_type == 'funasr':
            self.punc_model = AutoModel(model=model_path, device=device, model_revision=model_revision, **kwargs)
        elif model_type == 'ct-transformer':
            self.punc_model = pipeline(task=Tasks.punctuation, model=model_path, model_revision=model_revision, **kwargs)
        else:
            raise NotImplementedError(f"unsupported model type [{model_type}]. only [funasr|ct-transformer] expected.")

    def check_text_type(self,
                        text_data):
        # funasr only accepts a single str; join a list into one string
        if self.model_type == 'funasr':
            if isinstance(text_data, str):
                pass
            elif isinstance(text_data, list):
                text_data = ''.join(text_data)
            else:
                raise TypeError(f"text must be str or list, but got {type(text_data)}")
        # ct-transformer accepts a list
        # TODO: check whether splitting the string improves throughput
        elif self.model_type == 'ct-transformer':
            if isinstance(text_data, str):
                text_data = [text_data]
            elif isinstance(text_data, list):
                pass
            else:
                raise TypeError(f"text must be str or list, but got {type(text_data)}")
        else:
            pass
        return text_data

    def generate_cache(self, cache):
        new_cache = {'pre_text': ""}
        for text in cache['text']:
            if text != '':
                new_cache['pre_text'] = new_cache['pre_text'] + text
        return new_cache

    def process(self,
                text,
                append_period=False,
                cache={}):
        if text == '':
            return ''
        text = self.check_text_type(text)
        if self.model_type == 'funasr':
            result = self.punc_model.generate(text)
        elif self.model_type == 'ct-transformer':
            if cache != {}:
                cache = self.generate_cache(cache)
            result = self.punc_model(text, cache=cache)
        punced_text = ''
        for res in result:
            punced_text += res['text']
        # If the text does not already end with a punctuation mark, append one manually.
        if append_period and not punced_text[-1] in PUNCTUATION_MARK:
            punced_text += "。"
        return punced_text

if __name__ == "__main__":
    inputs = "把字符串拆分为list只|适用于ct-transformer模型|在数据处理部分|已经把list转为单个字符串"
    """
    Splitting the string into a list only applies to the ct-transformer model;
    in the data-processing step the list has already been joined into a single string.
    """
    vads = inputs.split("|")
    device = "cuda"
    CTTRANSFORMER.update({"device": device})
    puct_model = Punctuation(**CTTRANSFORMER)
    result = puct_model.process(vads)
    print(result)
    # FUNASR.update({"device":"cuda"})
    # puct_model = Punctuation(**FUNASR)
    # result = puct_model.process(vads)
    # print(result)
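For streaming use, `process` can be handed the running result dict as `cache`; `generate_cache` then folds the previously punctuated pieces into ct-transformer's `pre_text`, which is how the streaming recognizer above calls it. A minimal sketch (model weights are fetched from ModelScope on first use):

```python
# Streaming-style usage: let ct-transformer see what has already been punctuated.
punc = Punctuation(**CTTRANSFORMER)

state = {"text": [], "is_end": False}
for piece in ["今天天气", "怎么样", "要不要出门"]:
    state["text"].append(punc.process(piece, cache=state))
print("".join(state["text"]))
```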


@@ -1,75 +1,86 @@
from modelscope.pipelines import pipeline
import numpy as np
import os

ERES2NETV2 = {
    "task": 'speaker-verification',
    "model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
    "model_revision": 'v1.0.1',
    "save_embeddings": False
}

# Directory where speaker embeddings are saved
DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding")

class speaker_verfication:
    def __init__(self,
                 task='speaker-verification',
                 model_name='damo/speech_eres2netv2_sv_zh-cn_16k-common',
                 model_revision='v1.0.1',
                 device="cuda",
                 save_embeddings=False):
        self.pipeline = pipeline(
            task=task,
            model=model_name,
            model_revision=model_revision,
            device=device)
        self.save_embeddings = save_embeddings

    def wav2embeddings(self, speaker_1_wav, save_path=None):
        result = self.pipeline([speaker_1_wav], output_emb=True)
        speaker_1_emb = result['embs'][0]
        if save_path is not None:
            np.save(save_path, speaker_1_emb)
        return speaker_1_emb

    def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
        if not self.save_embeddings:
            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold)
            return result["text"]
        else:
            result = self.pipeline([speaker_1_wav, speaker_2_wav], thr=threshold, output_emb=True)
            speaker1_emb = result["embs"][0]
            speaker2_emb = result["embs"][1]
            np.save(os.path.join(save_path, "speaker_1.npy"), speaker1_emb)
            return result['outputs']["text"]

    def _verifaction_from_embedding(self, base_emb, speaker_2_wav, threshold):
        base_emb = np.load(base_emb)
        result = self.pipeline([speaker_2_wav], output_emb=True)
        speaker2_emb = result["embs"][0]
        similarity = np.dot(base_emb, speaker2_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker2_emb))
        if similarity > threshold:
            return "yes"
        else:
            return "no"

    def verfication(self,
                    base_emb=None,
                    speaker_1_wav=None,
                    speaker_2_wav=None,
                    threshold=0.333,
                    save_path=None):
        if base_emb is not None and speaker_1_wav is not None:
            raise ValueError("Only need one of them, base_emb or speaker_1_wav")
        if base_emb is not None and speaker_2_wav is not None:
            return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
        elif speaker_1_wav is not None and speaker_2_wav is not None:
            return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
        else:
            raise NotImplementedError

if __name__ == '__main__':
    verifier = speaker_verfication(**ERES2NETV2)
    verifier = speaker_verfication(save_embeddings=False)
    result = verifier.verfication(base_emb=None,
                                  speaker_1_wav=r"C:\Users\bing\Downloads\speaker1_a_cn_16k.wav",
                                  speaker_2_wav=r"C:\Users\bing\Downloads\speaker2_a_cn_16k.wav",
                                  threshold=0.333,
                                  save_path=r"D:\python\irving\takway_base-main\savePath")
    print("---")
    print(result)
    print(verifier.verfication(r"D:\python\irving\takway_base-main\savePath\speaker_1.npy",
                               speaker_2_wav=r"C:\Users\bing\Downloads\speaker1_b_cn_16k.wav",
                               threshold=0.333,
                               ))
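A small enrollment-then-verification sketch matching how the ASR wrapper uses this class; the wav/npy paths are illustrative and the target directory must already exist:

```python
verifier = speaker_verfication(**ERES2NETV2)

# Enroll: turn a reference recording into a stored embedding (paths are illustrative).
emb_path = "speaker_embedding/target_speaker.npy"
verifier.wav2embeddings("target_enroll_16k.wav", save_path=emb_path)

# Verify: compare a new recording against the stored embedding.
answer = verifier.verfication(base_emb=emb_path,
                              speaker_2_wav="unknown_16k.wav",
                              threshold=0.333)
print(answer)  # 'yes' or 'no'
```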

utils/tts/TTS_README.md Normal file

@@ -0,0 +1,7 @@
1. Install melo
   See https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md#openvoice-v2
2. Update the paths in openvoice_utils.py (including SOURCE_SE_DIR / CACHE_PATH / OPENVOICE_TONE_COLOR_CONVERTER.converter_path), where:
   SOURCE_SE_DIR should point to utils/assets/ses
   OPENVOICE_TONE_COLOR_CONVERTER.converter_path should point to the downloaded model checkpoint; download link: https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip
3. Follow examples/tts_demo.ipynb to migrate the code (a sketch is given below)
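A minimal end-to-end sketch of step 3, not a definitive implementation: it assumes the MeloTTS API as documented upstream (`melo.api.TTS`, `tts_to_file`), that the converter ships as `utils/tts/openvoice/api.py` (shown later in this diff), and that the config/checkpoint/speaker-embedding file names below match your local layout from step 2. The default `ToneColorConverter` also loads the `wavmark` watermark model, so `wavmark` must be installed.

```python
import torch
from melo.api import TTS                                   # MeloTTS, installed in step 1
from utils.tts.openvoice.api import ToneColorConverter     # module path assumed

device = "cuda:0"

# 1) Base synthesis with melo (speaker key 'ZH' assumed for the Chinese model).
melo = TTS(language="ZH", device=device)
melo.tts_to_file("你好，今天过得怎么样？", melo.hps.data.spk2id["ZH"], "tmp_base.wav", speed=1.0)

# 2) Tone color conversion; config/checkpoint paths follow step 2 above (assumed locations).
converter = ToneColorConverter("utils/tts/openvoice_model/config.json", device=device)
converter.load_ckpt("utils/tts/openvoice_model/checkpoint.pth")

source_se = torch.load("utils/assets/ses/zh.pth", map_location=device)   # file name assumed
target_se = converter.extract_se(["reference_speaker.wav"])

converter.convert("tmp_base.wav", src_se=source_se, tgt_se=target_se, output_path="converted.wav")
```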


@@ -1,202 +1,202 @@
import torch import torch
import numpy as np import numpy as np
import re import re
import soundfile import soundfile
from utils.tts.openvoice import utils from utils.tts.openvoice import utils
from utils.tts.openvoice import commons from utils.tts.openvoice import commons
import os import os
import librosa import librosa
from utils.tts.openvoice.text import text_to_sequence from utils.tts.openvoice.text import text_to_sequence
from utils.tts.openvoice.mel_processing import spectrogram_torch from utils.tts.openvoice.mel_processing import spectrogram_torch
from utils.tts.openvoice.models import SynthesizerTrn from utils.tts.openvoice.models import SynthesizerTrn
class OpenVoiceBaseClass(object): class OpenVoiceBaseClass(object):
def __init__(self, def __init__(self,
config_path, config_path,
device='cuda:0'): device='cuda:0'):
if 'cuda' in device: if 'cuda' in device:
assert torch.cuda.is_available() assert torch.cuda.is_available()
hps = utils.get_hparams_from_file(config_path) hps = utils.get_hparams_from_file(config_path)
model = SynthesizerTrn( model = SynthesizerTrn(
len(getattr(hps, 'symbols', [])), len(getattr(hps, 'symbols', [])),
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**hps.model, **hps.model,
).to(device) ).to(device)
model.eval() model.eval()
self.model = model self.model = model
self.hps = hps self.hps = hps
self.device = device self.device = device
def load_ckpt(self, ckpt_path): def load_ckpt(self, ckpt_path):
checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device)) checkpoint_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False) a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
print("Loaded checkpoint '{}'".format(ckpt_path)) print("Loaded checkpoint '{}'".format(ckpt_path))
print('missing/unexpected keys:', a, b) print('missing/unexpected keys:', a, b)
class BaseSpeakerTTS(OpenVoiceBaseClass): class BaseSpeakerTTS(OpenVoiceBaseClass):
language_marks = { language_marks = {
"english": "EN", "english": "EN",
"chinese": "ZH", "chinese": "ZH",
} }
@staticmethod @staticmethod
def get_text(text, hps, is_symbol): def get_text(text, hps, is_symbol):
text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners) text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
if hps.data.add_blank: if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0) text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm) text_norm = torch.LongTensor(text_norm)
return text_norm return text_norm
@staticmethod @staticmethod
def audio_numpy_concat(segment_data_list, sr, speed=1.): def audio_numpy_concat(segment_data_list, sr, speed=1.):
audio_segments = [] audio_segments = []
for segment_data in segment_data_list: for segment_data in segment_data_list:
audio_segments += segment_data.reshape(-1).tolist() audio_segments += segment_data.reshape(-1).tolist()
audio_segments += [0] * int((sr * 0.05)/speed) audio_segments += [0] * int((sr * 0.05)/speed)
audio_segments = np.array(audio_segments).astype(np.float32) audio_segments = np.array(audio_segments).astype(np.float32)
return audio_segments return audio_segments
@staticmethod @staticmethod
def split_sentences_into_pieces(text, language_str): def split_sentences_into_pieces(text, language_str):
texts = utils.split_sentence(text, language_str=language_str) texts = utils.split_sentence(text, language_str=language_str)
print(" > Text splitted to sentences.") print(" > Text splitted to sentences.")
print('\n'.join(texts)) print('\n'.join(texts))
print(" > ===========================") print(" > ===========================")
return texts return texts
def tts(self, text, output_path, speaker, language='English', speed=1.0): def tts(self, text, output_path, speaker, language='English', speed=1.0):
mark = self.language_marks.get(language.lower(), None) mark = self.language_marks.get(language.lower(), None)
assert mark is not None, f"language {language} is not supported" assert mark is not None, f"language {language} is not supported"
texts = self.split_sentences_into_pieces(text, mark) texts = self.split_sentences_into_pieces(text, mark)
audio_list = [] audio_list = []
for t in texts: for t in texts:
t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
t = f'[{mark}]{t}[{mark}]' t = f'[{mark}]{t}[{mark}]'
stn_tst = self.get_text(t, self.hps, False) stn_tst = self.get_text(t, self.hps, False)
device = self.device device = self.device
speaker_id = self.hps.speakers[speaker] speaker_id = self.hps.speakers[speaker]
with torch.no_grad(): with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(device) x_tst = stn_tst.unsqueeze(0).to(device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device) x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
sid = torch.LongTensor([speaker_id]).to(device) sid = torch.LongTensor([speaker_id]).to(device)
audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6, audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
audio_list.append(audio) audio_list.append(audio)
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed) audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
if output_path is None: if output_path is None:
return audio return audio
else: else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate) soundfile.write(output_path, audio, self.hps.data.sampling_rate)
class ToneColorConverter(OpenVoiceBaseClass): class ToneColorConverter(OpenVoiceBaseClass):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if kwargs.get('enable_watermark', True): if kwargs.get('enable_watermark', True):
import wavmark import wavmark
self.watermark_model = wavmark.load_model().to(self.device) self.watermark_model = wavmark.load_model().to(self.device)
else: else:
self.watermark_model = None self.watermark_model = None
self.version = getattr(self.hps, '_version_', "v1") self.version = getattr(self.hps, '_version_', "v1")
def extract_se(self, ref_wav_list, se_save_path=None): def extract_se(self, ref_wav_list, se_save_path=None):
if isinstance(ref_wav_list, str): if isinstance(ref_wav_list, str):
ref_wav_list = [ref_wav_list] ref_wav_list = [ref_wav_list]
device = self.device device = self.device
hps = self.hps hps = self.hps
gs = [] gs = []
for fname in ref_wav_list: for fname in ref_wav_list:
audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate) audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
y = torch.FloatTensor(audio_ref) y = torch.FloatTensor(audio_ref)
y = y.to(device) y = y.to(device)
y = y.unsqueeze(0) y = y.unsqueeze(0)
y = spectrogram_torch(y, hps.data.filter_length, y = spectrogram_torch(y, hps.data.filter_length,
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
center=False).to(device) center=False).to(device)
with torch.no_grad(): with torch.no_grad():
g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1) g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
gs.append(g.detach()) gs.append(g.detach())
gs = torch.stack(gs).mean(0) gs = torch.stack(gs).mean(0)
if se_save_path is not None: if se_save_path is not None:
os.makedirs(os.path.dirname(se_save_path), exist_ok=True) os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
torch.save(gs.cpu(), se_save_path) torch.save(gs.cpu(), se_save_path)
return gs return gs
def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"): def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="default"):
hps = self.hps hps = self.hps
# load audio # load audio
audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate) audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
audio = torch.tensor(audio).float() audio = torch.tensor(audio).float()
with torch.no_grad(): with torch.no_grad():
y = torch.FloatTensor(audio).to(self.device) y = torch.FloatTensor(audio).to(self.device)
y = y.unsqueeze(0) y = y.unsqueeze(0)
spec = spectrogram_torch(y, hps.data.filter_length, spec = spectrogram_torch(y, hps.data.filter_length,
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
center=False).to(self.device) center=False).to(self.device)
spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device) spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][ audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
0, 0].data.cpu().float().numpy() 0, 0].data.cpu().float().numpy()
audio = self.add_watermark(audio, message) audio = self.add_watermark(audio, message)
if output_path is None: if output_path is None:
return audio return audio
else: else:
soundfile.write(output_path, audio, hps.data.sampling_rate) soundfile.write(output_path, audio, hps.data.sampling_rate)
def add_watermark(self, audio, message): def add_watermark(self, audio, message):
if self.watermark_model is None: if self.watermark_model is None:
return audio return audio
device = self.device device = self.device
bits = utils.string_to_bits(message).reshape(-1) bits = utils.string_to_bits(message).reshape(-1)
n_repeat = len(bits) // 32 n_repeat = len(bits) // 32
K = 16000 K = 16000
coeff = 2 coeff = 2
for n in range(n_repeat): for n in range(n_repeat):
trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
if len(trunck) != K: if len(trunck) != K:
print('Audio too short, fail to add watermark') print('Audio too short, fail to add watermark')
break break
message_npy = bits[n * 32: (n + 1) * 32] message_npy = bits[n * 32: (n + 1) * 32]
with torch.no_grad(): with torch.no_grad():
signal = torch.FloatTensor(trunck).to(device)[None] signal = torch.FloatTensor(trunck).to(device)[None]
message_tensor = torch.FloatTensor(message_npy).to(device)[None] message_tensor = torch.FloatTensor(message_npy).to(device)[None]
signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor) signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze() signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
return audio return audio
def detect_watermark(self, audio, n_repeat): def detect_watermark(self, audio, n_repeat):
bits = [] bits = []
K = 16000 K = 16000
coeff = 2 coeff = 2
for n in range(n_repeat): for n in range(n_repeat):
trunck = audio[(coeff * n) * K: (coeff * n + 1) * K] trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
if len(trunck) != K: if len(trunck) != K:
print('Audio too short, fail to detect watermark') print('Audio too short, fail to detect watermark')
return 'Fail' return 'Fail'
with torch.no_grad(): with torch.no_grad():
signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0) signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze() message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
bits.append(message_decoded_npy) bits.append(message_decoded_npy)
bits = np.stack(bits).reshape(-1, 8) bits = np.stack(bits).reshape(-1, 8)
message = utils.bits_to_string(bits) message = utils.bits_to_string(bits)
return message return message
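`convert` embeds `message` (default `"default"`) as a watermark in 32-bit groups via `add_watermark`, so recovering it with `detect_watermark` needs `n_repeat` equal to the number of 32-bit groups in the original message. A rough sketch under the same path assumptions as the README above; the exact bit count depends on `utils.string_to_bits`, which is not shown in this diff:

```python
import librosa
from utils.tts.openvoice import utils                      # provides string_to_bits / bits_to_string
from utils.tts.openvoice.api import ToneColorConverter     # module path assumed

converter = ToneColorConverter("utils/tts/openvoice_model/config.json", device="cuda:0")
converter.load_ckpt("utils/tts/openvoice_model/checkpoint.pth")

audio, _ = librosa.load("converted.wav", sr=converter.hps.data.sampling_rate)

# One repeat per 32 embedded bits of the original watermark message ("default" here).
n_repeat = len(utils.string_to_bits("default").reshape(-1)) // 32
print(converter.detect_watermark(audio, n_repeat))
```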


@@ -1,465 +1,465 @@
import math import math
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from utils.tts.openvoice import commons from utils.tts.openvoice import commons
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5): def __init__(self, channels, eps=1e-5):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
self.eps = eps self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels)) self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels)) self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x): def forward(self, x):
x = x.transpose(1, -1) x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1) return x.transpose(1, -1)
@torch.jit.script @torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0] n_channels_int = n_channels[0]
in_act = input_a + input_b in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :]) t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act acts = t_act * s_act
return acts return acts
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__( def __init__(
self, self,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers, n_layers,
kernel_size=1, kernel_size=1,
p_dropout=0.0, p_dropout=0.0,
window_size=4, window_size=4,
isflow=True, isflow=True,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
self.n_heads = n_heads self.n_heads = n_heads
self.n_layers = n_layers self.n_layers = n_layers
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.window_size = window_size self.window_size = window_size
# if isflow: # if isflow:
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1) # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1) # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
# self.cond_layer = weight_norm(cond_layer, name='weight') # self.cond_layer = weight_norm(cond_layer, name='weight')
# self.gin_channels = 256 # self.gin_channels = 256
self.cond_layer_idx = self.n_layers self.cond_layer_idx = self.n_layers
if "gin_channels" in kwargs: if "gin_channels" in kwargs:
self.gin_channels = kwargs["gin_channels"] self.gin_channels = kwargs["gin_channels"]
if self.gin_channels != 0: if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default # vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = ( self.cond_layer_idx = (
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
) )
# logging.debug(self.gin_channels, self.cond_layer_idx) # logging.debug(self.gin_channels, self.cond_layer_idx)
assert ( assert (
self.cond_layer_idx < self.n_layers self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers" ), "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList() self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList() self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList() self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList() self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers): for i in range(self.n_layers):
self.attn_layers.append( self.attn_layers.append(
MultiHeadAttention( MultiHeadAttention(
hidden_channels, hidden_channels,
hidden_channels, hidden_channels,
n_heads, n_heads,
p_dropout=p_dropout, p_dropout=p_dropout,
window_size=window_size, window_size=window_size,
) )
) )
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append( self.ffn_layers.append(
FFN( FFN(
hidden_channels, hidden_channels,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
kernel_size, kernel_size,
p_dropout=p_dropout, p_dropout=p_dropout,
) )
) )
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, g=None): def forward(self, x, x_mask, g=None):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
if i == self.cond_layer_idx and g is not None: if i == self.cond_layer_idx and g is not None:
g = self.spk_emb_linear(g.transpose(1, 2)) g = self.spk_emb_linear(g.transpose(1, 2))
g = g.transpose(1, 2) g = g.transpose(1, 2)
x = x + g x = x + g
x = x * x_mask x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask) y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_1[i](x + y) x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask) y = self.ffn_layers[i](x, x_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_2[i](x + y) x = self.norm_layers_2[i](x + y)
x = x * x_mask x = x * x_mask
return x return x
class Decoder(nn.Module): class Decoder(nn.Module):
def __init__( def __init__(
self, self,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
n_heads, n_heads,
n_layers, n_layers,
kernel_size=1, kernel_size=1,
p_dropout=0.0, p_dropout=0.0,
proximal_bias=False, proximal_bias=False,
proximal_init=True, proximal_init=True,
**kwargs **kwargs
): ):
super().__init__() super().__init__()
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
self.n_heads = n_heads self.n_heads = n_heads
self.n_layers = n_layers self.n_layers = n_layers
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.proximal_bias = proximal_bias self.proximal_bias = proximal_bias
self.proximal_init = proximal_init self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList() self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList() self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList() self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList() self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList() self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList() self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers): for i in range(self.n_layers):
self.self_attn_layers.append( self.self_attn_layers.append(
MultiHeadAttention( MultiHeadAttention(
hidden_channels, hidden_channels,
hidden_channels, hidden_channels,
n_heads, n_heads,
p_dropout=p_dropout, p_dropout=p_dropout,
proximal_bias=proximal_bias, proximal_bias=proximal_bias,
proximal_init=proximal_init, proximal_init=proximal_init,
) )
) )
self.norm_layers_0.append(LayerNorm(hidden_channels)) self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append( self.encdec_attn_layers.append(
MultiHeadAttention( MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
) )
) )
self.norm_layers_1.append(LayerNorm(hidden_channels)) self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append( self.ffn_layers.append(
FFN( FFN(
hidden_channels, hidden_channels,
hidden_channels, hidden_channels,
filter_channels, filter_channels,
kernel_size, kernel_size,
p_dropout=p_dropout, p_dropout=p_dropout,
causal=True, causal=True,
) )
) )
self.norm_layers_2.append(LayerNorm(hidden_channels)) self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask): def forward(self, x, x_mask, h, h_mask):
""" """
x: decoder input x: decoder input
h: encoder output h: encoder output
""" """
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype device=x.device, dtype=x.dtype
) )
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask x = x * x_mask
for i in range(self.n_layers): for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask) y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_0[i](x + y) x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_1[i](x + y) x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask) y = self.ffn_layers[i](x, x_mask)
y = self.drop(y) y = self.drop(y)
x = self.norm_layers_2[i](x + y) x = self.norm_layers_2[i](x + y)
x = x * x_mask x = x * x_mask
return x return x
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__( def __init__(
self, self,
channels, channels,
out_channels, out_channels,
n_heads, n_heads,
p_dropout=0.0, p_dropout=0.0,
window_size=None, window_size=None,
heads_share=True, heads_share=True,
block_length=None, block_length=None,
proximal_bias=False, proximal_bias=False,
proximal_init=False, proximal_init=False,
): ):
super().__init__() super().__init__()
assert channels % n_heads == 0 assert channels % n_heads == 0
self.channels = channels self.channels = channels
self.out_channels = out_channels self.out_channels = out_channels
self.n_heads = n_heads self.n_heads = n_heads
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.window_size = window_size self.window_size = window_size
self.heads_share = heads_share self.heads_share = heads_share
self.block_length = block_length self.block_length = block_length
self.proximal_bias = proximal_bias self.proximal_bias = proximal_bias
self.proximal_init = proximal_init self.proximal_init = proximal_init
self.attn = None self.attn = None
self.k_channels = channels // n_heads self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1) self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1) self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1) self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1) self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
if window_size is not None: if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5 rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter( self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev * rel_stddev
) )
self.emb_rel_v = nn.Parameter( self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev * rel_stddev
) )
nn.init.xavier_uniform_(self.conv_q.weight) nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight) nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight) nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init: if proximal_init:
with torch.no_grad(): with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight) self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias) self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None): def forward(self, x, c, attn_mask=None):
q = self.conv_q(x) q = self.conv_q(x)
k = self.conv_k(c) k = self.conv_k(c)
v = self.conv_v(c) v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask) x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x) x = self.conv_o(x)
return x return x
def attention(self, query, key, value, mask=None): def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k] # reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2)) b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None: if self.window_size is not None:
assert ( assert (
t_s == t_t t_s == t_t
), "Relative attention is only available for self-attention." ), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys( rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings query / math.sqrt(self.k_channels), key_relative_embeddings
) )
scores_local = self._relative_position_to_absolute_position(rel_logits) scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local scores = scores + scores_local
if self.proximal_bias: if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention." assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to( scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype device=scores.device, dtype=scores.dtype
) )
if mask is not None: if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4) scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None: if self.block_length is not None:
assert ( assert (
t_s == t_t t_s == t_t
), "Local attention is only available for self-attention." ), "Local attention is only available for self-attention."
block_mask = ( block_mask = (
torch.ones_like(scores) torch.ones_like(scores)
.triu(-self.block_length) .triu(-self.block_length)
.tril(self.block_length) .tril(self.block_length)
) )
scores = scores.masked_fill(block_mask == 0, -1e4) scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn) p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value) output = torch.matmul(p_attn, value)
if self.window_size is not None: if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn) relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings( value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s self.emb_rel_v, t_s
) )
output = output + self._matmul_with_relative_values( output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings relative_weights, value_relative_embeddings
) )
output = ( output = (
output.transpose(2, 3).contiguous().view(b, d, t_t) output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t] ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn return output, p_attn
def _matmul_with_relative_values(self, x, y): def _matmul_with_relative_values(self, x, y):
""" """
x: [b, h, l, m] x: [b, h, l, m]
y: [h or 1, m, d] y: [h or 1, m, d]
ret: [b, h, l, d] ret: [b, h, l, d]
""" """
ret = torch.matmul(x, y.unsqueeze(0)) ret = torch.matmul(x, y.unsqueeze(0))
return ret return ret
def _matmul_with_relative_keys(self, x, y): def _matmul_with_relative_keys(self, x, y):
""" """
x: [b, h, l, d] x: [b, h, l, d]
y: [h or 1, m, d] y: [h or 1, m, d]
ret: [b, h, l, m] ret: [b, h, l, m]
""" """
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret return ret
def _get_relative_embeddings(self, relative_embeddings, length): def _get_relative_embeddings(self, relative_embeddings, length):
2 * self.window_size + 1 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops. # Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0) pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0) slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1 slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0: if pad_length > 0:
padded_relative_embeddings = F.pad( padded_relative_embeddings = F.pad(
relative_embeddings, relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
) )
else: else:
padded_relative_embeddings = relative_embeddings padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[ used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position :, slice_start_position:slice_end_position
] ]
return used_relative_embeddings return used_relative_embeddings
def _relative_position_to_absolute_position(self, x): def _relative_position_to_absolute_position(self, x):
""" """
x: [b, h, l, 2*l-1] x: [b, h, l, 2*l-1]
ret: [b, h, l, l] ret: [b, h, l, l]
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing. # Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1). # Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length]) x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad( x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
) )
# Reshape and slice out the padded elements. # Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 : :, :, :length, length - 1 :
] ]
return x_final return x_final
def _absolute_position_to_relative_position(self, x): def _absolute_position_to_relative_position(self, x):
""" """
x: [b, h, l, l] x: [b, h, l, l]
ret: [b, h, l, 2*l-1] ret: [b, h, l, 2*l-1]
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# pad along column # pad along column
x = F.pad( x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
) )
x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape # add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final return x_final
def _attention_bias_proximal(self, length): def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions. """Bias for self-attention to encourage attention to close positions.
Args: Args:
length: an integer scalar. length: an integer scalar.
Returns: Returns:
a Tensor with shape [1, 1, length, length] a Tensor with shape [1, 1, length, length]
""" """
r = torch.arange(length, dtype=torch.float32) r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module): class FFN(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
out_channels, out_channels,
filter_channels, filter_channels,
kernel_size, kernel_size,
p_dropout=0.0, p_dropout=0.0,
activation=None, activation=None,
causal=False, causal=False,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.filter_channels = filter_channels self.filter_channels = filter_channels
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.p_dropout = p_dropout self.p_dropout = p_dropout
self.activation = activation self.activation = activation
self.causal = causal self.causal = causal
if causal: if causal:
self.padding = self._causal_padding self.padding = self._causal_padding
else: else:
self.padding = self._same_padding self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout) self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask): def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask)) x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu": if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x) x = x * torch.sigmoid(1.702 * x)
else: else:
x = torch.relu(x) x = torch.relu(x)
x = self.drop(x) x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask)) x = self.conv_2(self.padding(x * x_mask))
return x * x_mask return x * x_mask
def _causal_padding(self, x): def _causal_padding(self, x):
if self.kernel_size == 1: if self.kernel_size == 1:
return x return x
pad_l = self.kernel_size - 1 pad_l = self.kernel_size - 1
pad_r = 0 pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]] padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding)) x = F.pad(x, commons.convert_pad_shape(padding))
return x return x
def _same_padding(self, x): def _same_padding(self, x):
if self.kernel_size == 1: if self.kernel_size == 1:
return x return x
pad_l = (self.kernel_size - 1) // 2 pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2 pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]] padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding)) x = F.pad(x, commons.convert_pad_shape(padding))
return x return x
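For orientation, a toy instantiation of the relative-position `Encoder` defined above; the hyperparameters are illustrative stand-ins for what `hps.model` would normally provide, and the import path assumes this file lives at utils/tts/openvoice/attentions.py:

```python
import torch
from utils.tts.openvoice.attentions import Encoder   # module path assumed

# Illustrative hyperparameters; the real values come from the model config (hps.model).
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1, window_size=4)

x = torch.randn(1, 192, 50)     # [batch, hidden_channels, frames]
x_mask = torch.ones(1, 1, 50)   # 1 = valid frame, 0 = padding
y = enc(x, x_mask)              # output keeps the input shape: [1, 192, 50]
print(y.shape)
```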


@@ -1,160 +1,160 @@
import math import math
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
def init_weights(m, mean=0.0, std=0.01): def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__ classname = m.__class__.__name__
if classname.find("Conv") != -1: if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std) m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1): def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2) return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape): def convert_pad_shape(pad_shape):
layer = pad_shape[::-1] layer = pad_shape[::-1]
pad_shape = [item for sublist in layer for item in sublist] pad_shape = [item for sublist in layer for item in sublist]
return pad_shape return pad_shape
def intersperse(lst, item): def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1) result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst result[1::2] = lst
return result return result
def kl_divergence(m_p, logs_p, m_q, logs_q): def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)""" """KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5 kl = (logs_q - logs_p) - 0.5
kl += ( kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
) )
return kl return kl
def rand_gumbel(shape): def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows.""" """Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples)) return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x): def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g return g
def slice_segments(x, ids_str, segment_size=4): def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size]) ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
def convert_pad_shape(pad_shape):
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape
def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x
def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm

View File

@ -1,183 +1,183 @@
import torch
import torch.utils.data
import librosa  # needed below by spectrogram_torch_conv (librosa.util.pad_center)
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output
def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output
mel_basis = {}
hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.1:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.1:
        print("max value is ", torch.max(y))
    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    # if torch.min(y) < -1.:
    #     print('min value is ', torch.min(y))
    # if torch.max(y) > 1.:
    #     print('max value is ', torch.max(y))
    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    # ******************** original ************************#
    # y = y.squeeze(1)
    # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
    #                    center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
    # ******************** ConvSTFT ************************#
    freq_cutoff = n_fft // 2 + 1
    fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
    forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
    forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
    import torch.nn.functional as F
    # if center:
    #     signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
    assert center is False
    forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
    spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
    # ******************** Verification ************************#
    spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
    assert torch.allclose(spec1, spec2, atol=1e-4)
    spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
    return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec
def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))
    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device
        )
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec

View File

@ -1,499 +1,499 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from utils.tts.openvoice import commons
from utils.tts.openvoice import modules
from utils.tts.openvoice import attentions
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils.tts.openvoice.commons import init_weights, get_padding
class TextEncoder(nn.Module):
    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)
    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
class StochasticDurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask
        if not reverse:
            flows = self.flows
            assert w is not None
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
            logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
    def forward(self, x, x_lengths, g=None, tau=1.0):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """
    def __init__(self, spec_channels, gin_channels=0, layernorm=True):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            self.layernorm = nn.LayerNorm(self.spec_channels)
        else:
            self.layernorm = None
    def forward(self, inputs, mask=None):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]
        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]
        return self.proj(out.squeeze(0))
    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
class ResidualCouplingBlock(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(modules.Flip())
    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """
    def __init__(
        self,
        n_vocab,
        spec_channels,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        zero_g=False,
        **kwargs
    ):
        super().__init__()
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
        self.n_speakers = n_speakers
        if n_speakers == 0:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
        else:
            self.enc_p = TextEncoder(n_vocab,
                                     inter_channels,
                                     hidden_channels,
                                     filter_channels,
                                     n_heads,
                                     n_layers,
                                     kernel_size,
                                     p_dropout)
            self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        self.zero_g = zero_g
    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
            + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        g_src = sid_src
        g_tgt = sid_tgt
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
        return o_hat, y_mask, (z, z_p, z_hat)
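To make the data flow of SynthesizerTrn concrete, here is a hedged sketch of the tone-color-conversion path (voice_conversion). All hyperparameters and module paths are placeholder assumptions; in the actual project they are read from the checkpoint's config.json, and the conditioning embeddings come from the tone-color extractor rather than being random.

``` python
# Hedged sketch: placeholder VITS-style hyperparameters, random "speaker embeddings",
# and assumed module paths -- not the project's real configuration.
import torch
from utils.tts.openvoice.models import SynthesizerTrn             # assumed path
from utils.tts.openvoice.mel_processing import spectrogram_torch  # assumed path

model = SynthesizerTrn(
    n_vocab=0, spec_channels=513, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock="1", resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5]] * 3, upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512, upsample_kernel_sizes=[16, 16, 4, 4],
    n_speakers=0, gin_channels=256,
).eval()

wav = torch.randn(1, 22050).clamp_(-1.0, 1.0)
spec = spectrogram_torch(wav, n_fft=1024, sampling_rate=22050,
                         hop_size=256, win_size=1024, center=False)
spec_lengths = torch.LongTensor([spec.size(-1)])

# With n_speakers == 0 the model is conditioned on tone-color embeddings directly,
# so sid_src / sid_tgt are [1, gin_channels, 1] tensors (random stand-ins here).
g_src = torch.randn(1, 256, 1)
g_tgt = torch.randn(1, 256, 1)

with torch.no_grad():
    audio, _, _ = model.voice_conversion(spec, spec_lengths, g_src, g_tgt, tau=0.3)
print(audio.shape)  # [1, 1, n_samples]
```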

File diff suppressed because it is too large

View File

@ -1,275 +1,275 @@
import os
import torch
import argparse
import gradio as gr
from zipfile import ZipFile
import langid
from openvoice import se_extractor
from openvoice.api import BaseSpeakerTTS, ToneColorConverter
parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=False, help="make link public")
args = parser.parse_args()
en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
# load models
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
# load speaker embeddings
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
# This online demo mainly supports English and Chinese
supported_languages = ['zh', 'en']
def predict(prompt, style, audio_file_pth, agree):
    # initialize an empty info string
    text_hint = ''
    # agree with the terms
    if agree == False:
        text_hint += '[ERROR] Please accept the Terms & Condition!\n'
        gr.Warning("Please accept the Terms & Condition!")
        return (
            text_hint,
            None,
            None,
        )
    # first detect the input language
    language_predicted = langid.classify(prompt)[0].strip()
    print(f"Detected language:{language_predicted}")
    if language_predicted not in supported_languages:
        text_hint += f"[ERROR] The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}\n"
        gr.Warning(
            f"The detected language {language_predicted} for your input text is not in our Supported Languages: {supported_languages}"
        )
        return (
            text_hint,
            None,
            None,
        )
    if language_predicted == "zh":
        tts_model = zh_base_speaker_tts
        source_se = zh_source_se
        language = 'Chinese'
        if style not in ['default']:
            text_hint += f"[ERROR] The style {style} is not supported for Chinese, which should be in ['default']\n"
            gr.Warning(f"The style {style} is not supported for Chinese, which should be in ['default']")
            return (
                text_hint,
                None,
                None,
            )
    else:
        tts_model = en_base_speaker_tts
        if style == 'default':
            source_se = en_source_default_se
        else:
            source_se = en_source_style_se
        language = 'English'
        if style not in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']:
            text_hint += f"[ERROR] The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']\n"
            gr.Warning(f"The style {style} is not supported for English, which should be in ['default', 'whispering', 'shouting', 'excited', 'cheerful', 'terrified', 'angry', 'sad', 'friendly']")
            return (
                text_hint,
                None,
                None,
            )
    speaker_wav = audio_file_pth
    if len(prompt) < 2:
        text_hint += f"[ERROR] Please give a longer prompt text \n"
        gr.Warning("Please give a longer prompt text")
        return (
            text_hint,
            None,
            None,
        )
    if len(prompt) > 200:
        text_hint += f"[ERROR] Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo and try for your usage \n"
        gr.Warning(
            "Text length limited to 200 characters for this demo, please try shorter text. You can clone our open-source repo for your usage"
        )
        return (
            text_hint,
            None,
            None,
        )
    # note diffusion_conditioning not used on hifigan (default mode), it will be empty but need to pass it to model.inference
    try:
        target_se, audio_name = se_extractor.get_se(speaker_wav, tone_color_converter, target_dir='processed', vad=True)
    except Exception as e:
        text_hint += f"[ERROR] Get target tone color error {str(e)} \n"
        gr.Warning(
            f"[ERROR] Get target tone color error {str(e)} \n"
        )
        return (
            text_hint,
            None,
            None,
        )
    src_path = f'{output_dir}/tmp.wav'
    tts_model.tts(prompt, src_path, speaker=style, language=language)
    save_path = f'{output_dir}/output.wav'
    # Run the tone color converter
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)
    text_hint += f'''Get response successfully \n'''
    return (
        text_hint,
        save_path,
        speaker_wav,
    )
title = "MyShell OpenVoice" title = "MyShell OpenVoice"
description = """ description = """
We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set. We introduce OpenVoice, a versatile instant voice cloning approach that requires only a short audio clip from the reference speaker to replicate their voice and generate speech in multiple languages. OpenVoice enables granular control over voice styles, including emotion, accent, rhythm, pauses, and intonation, in addition to replicating the tone color of the reference speaker. OpenVoice also achieves zero-shot cross-lingual voice cloning for languages not included in the massive-speaker training set.
""" """
markdown_table = """ markdown_table = """
<div align="center" style="margin-bottom: 10px;"> <div align="center" style="margin-bottom: 10px;">
| | | | | | | |
| :-----------: | :-----------: | :-----------: | | :-----------: | :-----------: | :-----------: |
| **OpenSource Repo** | **Project Page** | **Join the Community** | | **OpenSource Repo** | **Project Page** | **Join the Community** |
| <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | [OpenVoice](https://research.myshell.ai/open-voice) | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
</div> </div>
""" """
markdown_table_v2 = """ markdown_table_v2 = """
<div align="center" style="margin-bottom: 2px;"> <div align="center" style="margin-bottom: 2px;">
| | | | | | | | | |
| :-----------: | :-----------: | :-----------: | :-----------: | | :-----------: | :-----------: | :-----------: | :-----------: |
| **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) | | **OpenSource Repo** | <div style='text-align: center;'><a style="display:inline-block,align:center" href='https://github.com/myshell-ai/OpenVoice'><img src='https://img.shields.io/github/stars/myshell-ai/OpenVoice?style=social' /></a></div> | **Project Page** | [OpenVoice](https://research.myshell.ai/open-voice) |
| | | | | |
| :-----------: | :-----------: | | :-----------: | :-----------: |
**Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) | **Join the Community** | [![Discord](https://img.shields.io/discord/1122227993805336617?color=%239B59B6&label=%20Discord%20)](https://discord.gg/myshell) |
</div> </div>
""" """
content = """ content = """
<div> <div>
<strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong> <strong>If the generated voice does not sound like the reference voice, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/docs/QA.md'>this QnA</a>.</strong> <strong>For multi-lingual & cross-lingual examples, please refer to <a href='https://github.com/myshell-ai/OpenVoice/blob/main/demo_part2.ipynb'>this jupyter notebook</a>.</strong>
This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided. This online demo mainly supports <strong>English</strong>. The <em>default</em> style also supports <strong>Chinese</strong>. But OpenVoice can adapt to any other language as long as a base speaker is provided.
</div> </div>
""" """
wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>" wrapped_markdown_content = f"<div style='border: 1px solid #000; padding: 10px;'>{content}</div>"
examples = [ examples = [
[ [
"今天天气真好,我们一起出去吃饭吧。", "今天天气真好,我们一起出去吃饭吧。",
'default', 'default',
"resources/demo_speaker1.mp3", "resources/demo_speaker1.mp3",
True, True,
],[ ],[
"This audio is generated by open voice with a half-performance model.", "This audio is generated by open voice with a half-performance model.",
'whispering', 'whispering',
"resources/demo_speaker2.mp3", "resources/demo_speaker2.mp3",
True, True,
], ],
[ [
"He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.", "He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
'sad', 'sad',
"resources/demo_speaker0.mp3", "resources/demo_speaker0.mp3",
True, True,
], ],
] ]
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            with gr.Row():
                gr.Markdown(
                    """
                    ## <img src="https://huggingface.co/spaces/myshell-ai/OpenVoice/raw/main/logo.jpg" height="40"/>
                    """
                )
            with gr.Row():
                gr.Markdown(markdown_table_v2)
            with gr.Row():
                gr.Markdown(description)
        with gr.Column():
            gr.Video('https://github.com/myshell-ai/OpenVoice/assets/40556743/3cba936f-82bf-476c-9e52-09f0f417bb2f', autoplay=True)
    with gr.Row():
        gr.HTML(wrapped_markdown_content)

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Text Prompt",
                info="One or two sentences at a time is better. Up to 200 text characters.",
                value="He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered, flour-fattened sauce.",
            )
            style_gr = gr.Dropdown(
                label="Style",
                info="Select a style of output audio for the synthesised speech. (Chinese only supports 'default' for now.)",
                choices=['default', 'whispering', 'cheerful', 'terrified', 'angry', 'sad', 'friendly'],
                max_choices=1,
                value="default",
            )
            ref_gr = gr.Audio(
                label="Reference Audio",
                info="Click on the ✎ button to upload your own target speaker audio",
                type="filepath",
                value="resources/demo_speaker2.mp3",
            )
            tos_gr = gr.Checkbox(
                label="Agree",
                value=False,
                info="I agree to the terms of the cc-by-nc-4.0 license: https://github.com/myshell-ai/OpenVoice/blob/main/LICENSE",
            )
            tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

        with gr.Column():
            out_text_gr = gr.Text(label="Info")
            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
            ref_audio_gr = gr.Audio(label="Reference Audio Used")

            gr.Examples(examples,
                        label="Examples",
                        inputs=[input_text_gr, style_gr, ref_gr, tos_gr],
                        outputs=[out_text_gr, audio_gr, ref_audio_gr],
                        fn=predict,
                        cache_examples=False,)
            tts_button.click(predict, [input_text_gr, style_gr, ref_gr, tos_gr], outputs=[out_text_gr, audio_gr, ref_audio_gr])

demo.queue()
demo.launch(debug=True, show_api=True, share=args.share)
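For quick smoke-testing outside the browser UI, the same `predict` callable that the button and the examples are wired to can be driven directly. This is a hedged sketch, not part of the diff: the argument order is inferred from `inputs=[input_text_gr, style_gr, ref_gr, tos_gr]` and the return order from `outputs=[out_text_gr, audio_gr, ref_audio_gr]` above.

```python
# Hedged sketch: call predict() with the first example row instead of launching Gradio.
text, style, ref_path, agree = examples[0]
info, synthesised_audio, ref_used = predict(text, style, ref_path, agree)
print(info)                    # status string shown in the "Info" box
print(synthesised_audio)       # path or (sample_rate, array) accepted by gr.Audio
```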


@@ -1,153 +1,152 @@
import os
import glob
import torch
import hashlib
import librosa
import base64
from glob import glob
import numpy as np
from pydub import AudioSegment
from faster_whisper import WhisperModel
from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments

model_size = "medium"
# Run on GPU with FP16
model = None
def split_audio_whisper(audio_path, audio_name, target_dir='processed'):
    global model
    if model is None:
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    audio = AudioSegment.from_file(audio_path)
    max_len = len(audio)

    target_folder = os.path.join(target_dir, audio_name)

    segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
    segments = list(segments)

    # create directory
    os.makedirs(target_folder, exist_ok=True)
    wavs_folder = os.path.join(target_folder, 'wavs')
    os.makedirs(wavs_folder, exist_ok=True)

    # segments
    s_ind = 0
    start_time = None

    for k, w in enumerate(segments):
        # process with the time
        if k == 0:
            start_time = max(0, w.start)
        end_time = w.end

        # calculate confidence
        if len(w.words) > 0:
            confidence = sum([s.probability for s in w.words]) / len(w.words)
        else:
            confidence = 0.

        # clean text
        text = w.text.replace('...', '')

        # leave 0.08 s of headroom at the end of each segment
        audio_seg = audio[int(start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]

        # segment file name
        fname = f"{audio_name}_seg{s_ind}.wav"

        # filter out segments shorter than 1.5 s or longer than 20 s
        save = audio_seg.duration_seconds > 1.5 and \
               audio_seg.duration_seconds < 20. and \
               len(text) >= 2 and len(text) < 200

        if save:
            output_file = os.path.join(wavs_folder, fname)
            audio_seg.export(output_file, format='wav')

        if k < len(segments) - 1:
            start_time = max(0, segments[k+1].start - 0.08)

        s_ind = s_ind + 1
    return wavs_folder
def split_audio_vad(audio_path, audio_name, target_dir, split_seconds=10.0):
    SAMPLE_RATE = 16000
    audio_vad = get_audio_tensor(audio_path)
    segments = get_vad_segments(
        audio_vad,
        output_sample=True,
        min_speech_duration=0.1,
        min_silence_duration=1,
        method="silero",
    )
    segments = [(seg["start"], seg["end"]) for seg in segments]
    segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s, e in segments]
    print(segments)
    audio_active = AudioSegment.silent(duration=0)
    audio = AudioSegment.from_file(audio_path)

    for start_time, end_time in segments:
        audio_active += audio[int(start_time * 1000) : int(end_time * 1000)]

    audio_dur = audio_active.duration_seconds
    print(f'after vad: dur = {audio_dur}')

    target_folder = os.path.join(target_dir, audio_name)
    wavs_folder = os.path.join(target_folder, 'wavs')
    os.makedirs(wavs_folder, exist_ok=True)

    start_time = 0.
    count = 0
    num_splits = int(np.round(audio_dur / split_seconds))
    assert num_splits > 0, 'input audio is too short'
    interval = audio_dur / num_splits

    for i in range(num_splits):
        end_time = min(start_time + interval, audio_dur)
        if i == num_splits - 1:
            end_time = audio_dur
        output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
        audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
        audio_seg.export(output_file, format='wav')
        start_time = end_time
        count += 1
    return wavs_folder
def hash_numpy_array(audio_path):
    array, _ = librosa.load(audio_path, sr=None, mono=True)
    # Convert the array to bytes
    array_bytes = array.tobytes()
    # Calculate the hash of the array bytes
    hash_object = hashlib.sha256(array_bytes)
    hash_value = hash_object.digest()
    # Convert the hash value to base64
    base64_value = base64.b64encode(hash_value)
    return base64_value.decode('utf-8')[:16].replace('/', '_^')
def get_se(audio_path, vc_model, target_dir='processed', vad=True):
    device = vc_model.device
    version = vc_model.version
    print("OpenVoice version:", version)

    audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{version}_{hash_numpy_array(audio_path)}"
    se_path = os.path.join(target_dir, audio_name, 'se.pth')

    # if os.path.isfile(se_path):
    #     se = torch.load(se_path).to(device)
    #     return se, audio_name
    # if os.path.isdir(audio_path):
    #     wavs_folder = audio_path

-   os.environ["TOKENIZERS_PARALLELISM"] = "false"    # removed in this commit
    if vad:
        wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
    else:
        wavs_folder = split_audio_whisper(audio_path, target_dir=target_dir, audio_name=audio_name)

    audio_segs = glob(f'{wavs_folder}/*.wav')
    if len(audio_segs) == 0:
        raise NotImplementedError('No audio segments found!')

    return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
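Taken together, these helpers feed OpenVoice's tone-colour extraction: the reference recording is split into clean speech segments (silero VAD or faster-whisper), and the converter averages a speaker embedding over them. A hedged usage sketch follows; the `ToneColorConverter` import path and the config/checkpoint locations are assumptions, so adjust them to wherever this repo keeps its vendored OpenVoice code and downloaded model.

```python
# Hedged sketch: extract a speaker embedding ("tone colour") from a reference recording.
# Import path and model file locations are assumptions.
from utils.tts.openvoice.api import ToneColorConverter

converter = ToneColorConverter('path/to/openvoice_model/config.json', device='cuda:0')
converter.load_ckpt('path/to/openvoice_model/checkpoint.pth')

# vad=True uses the silero-based splitter above; vad=False falls back to faster-whisper.
target_se, audio_name = get_se('reference_speaker.mp3', converter, target_dir='processed', vad=True)
print(target_se.shape, audio_name)
```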


@@ -1,79 +1,79 @@
""" from https://github.com/keithito/tacotron """ """ from https://github.com/keithito/tacotron """
from utils.tts.openvoice.text import cleaners from utils.tts.openvoice.text import cleaners
from utils.tts.openvoice.text.symbols import symbols from utils.tts.openvoice.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa: # Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)} _symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)}
def text_to_sequence(text, symbols, cleaner_names): def text_to_sequence(text, symbols, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
text: string to convert to a sequence text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through cleaner_names: names of the cleaner functions to run the text through
Returns: Returns:
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
''' '''
sequence = [] sequence = []
symbol_to_id = {s: i for i, s in enumerate(symbols)} symbol_to_id = {s: i for i, s in enumerate(symbols)}
clean_text = _clean_text(text, cleaner_names) clean_text = _clean_text(text, cleaner_names)
print(clean_text) print(clean_text)
print(f" length:{len(clean_text)}") print(f" length:{len(clean_text)}")
for symbol in clean_text: for symbol in clean_text:
if symbol not in symbol_to_id.keys(): if symbol not in symbol_to_id.keys():
continue continue
symbol_id = symbol_to_id[symbol] symbol_id = symbol_to_id[symbol]
sequence += [symbol_id] sequence += [symbol_id]
print(f" length:{len(sequence)}") print(f" length:{len(sequence)}")
return sequence return sequence
def cleaned_text_to_sequence(cleaned_text, symbols): def cleaned_text_to_sequence(cleaned_text, symbols):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
text: string to convert to a sequence text: string to convert to a sequence
Returns: Returns:
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
''' '''
symbol_to_id = {s: i for i, s in enumerate(symbols)} symbol_to_id = {s: i for i, s in enumerate(symbols)}
sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()] sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
return sequence return sequence
from utils.tts.openvoice.text.symbols import language_tone_start_map from utils.tts.openvoice.text.symbols import language_tone_start_map
def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages): def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text. """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args: Args:
text: string to convert to a sequence text: string to convert to a sequence
Returns: Returns:
List of integers corresponding to the symbols in the text List of integers corresponding to the symbols in the text
""" """
symbol_to_id = {s: i for i, s in enumerate(symbols)} symbol_to_id = {s: i for i, s in enumerate(symbols)}
language_id_map = {s: i for i, s in enumerate(languages)} language_id_map = {s: i for i, s in enumerate(languages)}
phones = [symbol_to_id[symbol] for symbol in cleaned_text] phones = [symbol_to_id[symbol] for symbol in cleaned_text]
tone_start = language_tone_start_map[language] tone_start = language_tone_start_map[language]
tones = [i + tone_start for i in tones] tones = [i + tone_start for i in tones]
lang_id = language_id_map[language] lang_id = language_id_map[language]
lang_ids = [lang_id for i in phones] lang_ids = [lang_id for i in phones]
return phones, tones, lang_ids return phones, tones, lang_ids
def sequence_to_text(sequence): def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string''' '''Converts a sequence of IDs back to a string'''
result = '' result = ''
for symbol_id in sequence: for symbol_id in sequence:
s = _id_to_symbol[symbol_id] s = _id_to_symbol[symbol_id]
result += s result += s
return result return result
def _clean_text(text, cleaner_names): def _clean_text(text, cleaner_names):
for name in cleaner_names: for name in cleaner_names:
cleaner = getattr(cleaners, name) cleaner = getattr(cleaners, name)
if not cleaner: if not cleaner:
raise Exception('Unknown cleaner: %s' % name) raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text) text = cleaner(text)
return text return text
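As a quick orientation for how these text helpers compose, the hedged sketch below cleans a language-tagged string and maps it to symbol IDs. It assumes the `symbols` list from `symbols.py` (shown further down) and the `cjke_cleaners2` cleaner from the next file; the exact IDs depend on that symbol set.

```python
# Hedged sketch: clean tagged text, convert it to symbol IDs, then round-trip back.
from utils.tts.openvoice.text.symbols import symbols

ids = text_to_sequence("[EN]Hello world.[EN]", symbols, ["cjke_cleaners2"])
print(ids)
print(sequence_to_text(ids))  # cleaned IPA-style string, minus any symbols outside the set
```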


@@ -1,16 +1,16 @@
import re
from utils.tts.openvoice.text.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
from utils.tts.openvoice.text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2


def cjke_cleaners2(text):
    text = re.sub(r'\[ZH\](.*?)\[ZH\]',
                  lambda x: chinese_to_ipa(x.group(1)) + ' ', text)
    text = re.sub(r'\[JA\](.*?)\[JA\]',
                  lambda x: japanese_to_ipa2(x.group(1)) + ' ', text)
    text = re.sub(r'\[KO\](.*?)\[KO\]',
                  lambda x: korean_to_ipa(x.group(1)) + ' ', text)
    text = re.sub(r'\[EN\](.*?)\[EN\]',
                  lambda x: english_to_ipa2(x.group(1)) + ' ', text)
    text = re.sub(r'\s+$', '', text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text
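Note that this module only imports the English and Mandarin converters, so in its current form only `[EN]...[EN]` and `[ZH]...[ZH]` spans can be cleaned; a `[JA]` or `[KO]` span would hit an undefined `japanese_to_ipa2` / `korean_to_ipa`. A hedged example of the supported path:

```python
# Hedged sketch: mixed Chinese/English input. The output is IPA with tone marks,
# and a trailing period is appended if the text does not already end in punctuation.
mixed = "[ZH]今天天气真好[ZH] [EN]Let's go out for dinner[EN]"
print(cjke_cleaners2(mixed))
```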


@@ -1,188 +1,188 @@
""" from https://github.com/keithito/tacotron """ """ from https://github.com/keithito/tacotron """
''' '''
Cleaners are transformations that run over the input text at both training and eval time. Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use: hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text 1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode) the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data). the symbols in symbols.py to match your data).
''' '''
# Regular expression matching whitespace: # Regular expression matching whitespace:
import re import re
import inflect import inflect
from unidecode import unidecode from unidecode import unidecode
import eng_to_ipa as ipa import eng_to_ipa as ipa
_inflect = inflect.engine() _inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+') _number_re = re.compile(r'[0-9]+')
# List of (regular expression, replacement) pairs for abbreviations: # List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'), ('mrs', 'misess'),
('mr', 'mister'), ('mr', 'mister'),
('dr', 'doctor'), ('dr', 'doctor'),
('st', 'saint'), ('st', 'saint'),
('co', 'company'), ('co', 'company'),
('jr', 'junior'), ('jr', 'junior'),
('maj', 'major'), ('maj', 'major'),
('gen', 'general'), ('gen', 'general'),
('drs', 'doctors'), ('drs', 'doctors'),
('rev', 'reverend'), ('rev', 'reverend'),
('lt', 'lieutenant'), ('lt', 'lieutenant'),
('hon', 'honorable'), ('hon', 'honorable'),
('sgt', 'sergeant'), ('sgt', 'sergeant'),
('capt', 'captain'), ('capt', 'captain'),
('esq', 'esquire'), ('esq', 'esquire'),
('ltd', 'limited'), ('ltd', 'limited'),
('col', 'colonel'), ('col', 'colonel'),
('ft', 'fort'), ('ft', 'fort'),
]] ]]
# List of (ipa, lazy ipa) pairs: # List of (ipa, lazy ipa) pairs:
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'), ('r', 'ɹ'),
('æ', 'e'), ('æ', 'e'),
('ɑ', 'a'), ('ɑ', 'a'),
('ɔ', 'o'), ('ɔ', 'o'),
('ð', 'z'), ('ð', 'z'),
('θ', 's'), ('θ', 's'),
('ɛ', 'e'), ('ɛ', 'e'),
('ɪ', 'i'), ('ɪ', 'i'),
('ʊ', 'u'), ('ʊ', 'u'),
('ʒ', 'ʥ'), ('ʒ', 'ʥ'),
('ʤ', 'ʥ'), ('ʤ', 'ʥ'),
('ˈ', ''), ('ˈ', ''),
]] ]]
# List of (ipa, lazy ipa2) pairs: # List of (ipa, lazy ipa2) pairs:
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'), ('r', 'ɹ'),
('ð', 'z'), ('ð', 'z'),
('θ', 's'), ('θ', 's'),
('ʒ', 'ʑ'), ('ʒ', 'ʑ'),
('ʤ', ''), ('ʤ', ''),
('ˈ', ''), ('ˈ', ''),
]] ]]
# List of (ipa, ipa2) pairs # List of (ipa, ipa2) pairs
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'), ('r', 'ɹ'),
('ʤ', ''), ('ʤ', ''),
('ʧ', '') ('ʧ', '')
]] ]]
def expand_abbreviations(text): def expand_abbreviations(text):
for regex, replacement in _abbreviations: for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def collapse_whitespace(text): def collapse_whitespace(text):
return re.sub(r'\s+', ' ', text) return re.sub(r'\s+', ' ', text)
def _remove_commas(m): def _remove_commas(m):
return m.group(1).replace(',', '') return m.group(1).replace(',', '')
def _expand_decimal_point(m): def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ') return m.group(1).replace('.', ' point ')
def _expand_dollars(m): def _expand_dollars(m):
match = m.group(1) match = m.group(1)
parts = match.split('.') parts = match.split('.')
if len(parts) > 2: if len(parts) > 2:
return match + ' dollars' # Unexpected format return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0 dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents: if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars' dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents' cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars: elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars' dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit) return '%s %s' % (dollars, dollar_unit)
elif cents: elif cents:
cent_unit = 'cent' if cents == 1 else 'cents' cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit) return '%s %s' % (cents, cent_unit)
else: else:
return 'zero dollars' return 'zero dollars'
def _expand_ordinal(m): def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0)) return _inflect.number_to_words(m.group(0))
def _expand_number(m): def _expand_number(m):
num = int(m.group(0)) num = int(m.group(0))
if num > 1000 and num < 3000: if num > 1000 and num < 3000:
if num == 2000: if num == 2000:
return 'two thousand' return 'two thousand'
elif num > 2000 and num < 2010: elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100) return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0: elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred' return _inflect.number_to_words(num // 100) + ' hundred'
else: else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else: else:
return _inflect.number_to_words(num, andword='') return _inflect.number_to_words(num, andword='')
def normalize_numbers(text): def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text) text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text) text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text) text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text) text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text) text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text) text = re.sub(_number_re, _expand_number, text)
return text return text
def mark_dark_l(text): def mark_dark_l(text):
return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
def english_to_ipa(text): def english_to_ipa(text):
text = unidecode(text).lower() text = unidecode(text).lower()
text = expand_abbreviations(text) text = expand_abbreviations(text)
text = normalize_numbers(text) text = normalize_numbers(text)
phonemes = ipa.convert(text) phonemes = ipa.convert(text)
phonemes = collapse_whitespace(phonemes) phonemes = collapse_whitespace(phonemes)
return phonemes return phonemes
def english_to_lazy_ipa(text): def english_to_lazy_ipa(text):
text = english_to_ipa(text) text = english_to_ipa(text)
for regex, replacement in _lazy_ipa: for regex, replacement in _lazy_ipa:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def english_to_ipa2(text): def english_to_ipa2(text):
text = english_to_ipa(text) text = english_to_ipa(text)
text = mark_dark_l(text) text = mark_dark_l(text)
for regex, replacement in _ipa_to_ipa2: for regex, replacement in _ipa_to_ipa2:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text.replace('...', '') return text.replace('...', '')
def english_to_lazy_ipa2(text): def english_to_lazy_ipa2(text):
text = english_to_ipa(text) text = english_to_ipa(text)
for regex, replacement in _lazy_ipa2: for regex, replacement in _lazy_ipa2:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
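The pieces above chain as unidecode → abbreviation expansion → number normalisation → `eng_to_ipa` conversion, with the lazy-IPA / IPA2 substitution tables applied on top. A hedged sketch (requires the `inflect`, `unidecode` and `eng_to_ipa` packages):

```python
# Hedged sketch: number normalisation first, then full IPA2 conversion.
print(normalize_numbers("I paid $3.50 for 2 tickets on March 3rd."))
# -> roughly "I paid three dollars, fifty cents for two tickets on March third."
print(english_to_ipa2("Dr. Smith paid $3.50 for 2 tickets."))
# -> an IPA string; the exact symbols depend on eng_to_ipa's dictionary
```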


@@ -1,326 +1,326 @@
import os import os
import sys import sys
import re import re
from pypinyin import lazy_pinyin, BOPOMOFO from pypinyin import lazy_pinyin, BOPOMOFO
import jieba import jieba
import cn2an import cn2an
import logging import logging
# List of (Latin alphabet, bopomofo) pairs: # List of (Latin alphabet, bopomofo) pairs:
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', 'ㄟˉ'), ('a', 'ㄟˉ'),
('b', 'ㄅㄧˋ'), ('b', 'ㄅㄧˋ'),
('c', 'ㄙㄧˉ'), ('c', 'ㄙㄧˉ'),
('d', 'ㄉㄧˋ'), ('d', 'ㄉㄧˋ'),
('e', 'ㄧˋ'), ('e', 'ㄧˋ'),
('f', 'ㄝˊㄈㄨˋ'), ('f', 'ㄝˊㄈㄨˋ'),
('g', 'ㄐㄧˋ'), ('g', 'ㄐㄧˋ'),
('h', 'ㄝˇㄑㄩˋ'), ('h', 'ㄝˇㄑㄩˋ'),
('i', 'ㄞˋ'), ('i', 'ㄞˋ'),
('j', 'ㄐㄟˋ'), ('j', 'ㄐㄟˋ'),
('k', 'ㄎㄟˋ'), ('k', 'ㄎㄟˋ'),
('l', 'ㄝˊㄛˋ'), ('l', 'ㄝˊㄛˋ'),
('m', 'ㄝˊㄇㄨˋ'), ('m', 'ㄝˊㄇㄨˋ'),
('n', 'ㄣˉ'), ('n', 'ㄣˉ'),
('o', 'ㄡˉ'), ('o', 'ㄡˉ'),
('p', 'ㄆㄧˉ'), ('p', 'ㄆㄧˉ'),
('q', 'ㄎㄧㄡˉ'), ('q', 'ㄎㄧㄡˉ'),
('r', 'ㄚˋ'), ('r', 'ㄚˋ'),
('s', 'ㄝˊㄙˋ'), ('s', 'ㄝˊㄙˋ'),
('t', 'ㄊㄧˋ'), ('t', 'ㄊㄧˋ'),
('u', 'ㄧㄡˉ'), ('u', 'ㄧㄡˉ'),
('v', 'ㄨㄧˉ'), ('v', 'ㄨㄧˉ'),
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
('x', 'ㄝˉㄎㄨˋㄙˋ'), ('x', 'ㄝˉㄎㄨˋㄙˋ'),
('y', 'ㄨㄞˋ'), ('y', 'ㄨㄞˋ'),
('z', 'ㄗㄟˋ') ('z', 'ㄗㄟˋ')
]] ]]
# List of (bopomofo, romaji) pairs: # List of (bopomofo, romaji) pairs:
_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
('ㄅㄛ', 'p⁼wo'), ('ㄅㄛ', 'p⁼wo'),
('ㄆㄛ', 'pʰwo'), ('ㄆㄛ', 'pʰwo'),
('ㄇㄛ', 'mwo'), ('ㄇㄛ', 'mwo'),
('ㄈㄛ', 'fwo'), ('ㄈㄛ', 'fwo'),
('', 'p⁼'), ('', 'p⁼'),
('', ''), ('', ''),
('', 'm'), ('', 'm'),
('', 'f'), ('', 'f'),
('', 't⁼'), ('', 't⁼'),
('', ''), ('', ''),
('', 'n'), ('', 'n'),
('', 'l'), ('', 'l'),
('', 'k⁼'), ('', 'k⁼'),
('', ''), ('', ''),
('', 'h'), ('', 'h'),
('', 'ʧ⁼'), ('', 'ʧ⁼'),
('', 'ʧʰ'), ('', 'ʧʰ'),
('', 'ʃ'), ('', 'ʃ'),
('', 'ʦ`⁼'), ('', 'ʦ`⁼'),
('', 'ʦ`ʰ'), ('', 'ʦ`ʰ'),
('', 's`'), ('', 's`'),
('', 'ɹ`'), ('', 'ɹ`'),
('', 'ʦ⁼'), ('', 'ʦ⁼'),
('', 'ʦʰ'), ('', 'ʦʰ'),
('', 's'), ('', 's'),
('', 'a'), ('', 'a'),
('', 'o'), ('', 'o'),
('', 'ə'), ('', 'ə'),
('', 'e'), ('', 'e'),
('', 'ai'), ('', 'ai'),
('', 'ei'), ('', 'ei'),
('', 'au'), ('', 'au'),
('', 'ou'), ('', 'ou'),
('ㄧㄢ', 'yeNN'), ('ㄧㄢ', 'yeNN'),
('', 'aNN'), ('', 'aNN'),
('ㄧㄣ', 'iNN'), ('ㄧㄣ', 'iNN'),
('', 'əNN'), ('', 'əNN'),
('', 'aNg'), ('', 'aNg'),
('ㄧㄥ', 'iNg'), ('ㄧㄥ', 'iNg'),
('ㄨㄥ', 'uNg'), ('ㄨㄥ', 'uNg'),
('ㄩㄥ', 'yuNg'), ('ㄩㄥ', 'yuNg'),
('', 'əNg'), ('', 'əNg'),
('', 'əɻ'), ('', 'əɻ'),
('', 'i'), ('', 'i'),
('', 'u'), ('', 'u'),
('', 'ɥ'), ('', 'ɥ'),
('ˉ', ''), ('ˉ', ''),
('ˊ', ''), ('ˊ', ''),
('ˇ', '↓↑'), ('ˇ', '↓↑'),
('ˋ', ''), ('ˋ', ''),
('˙', ''), ('˙', ''),
('', ','), ('', ','),
('', '.'), ('', '.'),
('', '!'), ('', '!'),
('', '?'), ('', '?'),
('', '-') ('', '-')
]] ]]
# List of (romaji, ipa) pairs: # List of (romaji, ipa) pairs:
_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('ʃy', 'ʃ'), ('ʃy', 'ʃ'),
('ʧʰy', 'ʧʰ'), ('ʧʰy', 'ʧʰ'),
('ʧ⁼y', 'ʧ⁼'), ('ʧ⁼y', 'ʧ⁼'),
('NN', 'n'), ('NN', 'n'),
('Ng', 'ŋ'), ('Ng', 'ŋ'),
('y', 'j'), ('y', 'j'),
('h', 'x') ('h', 'x')
]] ]]
# List of (bopomofo, ipa) pairs: # List of (bopomofo, ipa) pairs:
_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
('ㄅㄛ', 'p⁼wo'), ('ㄅㄛ', 'p⁼wo'),
('ㄆㄛ', 'pʰwo'), ('ㄆㄛ', 'pʰwo'),
('ㄇㄛ', 'mwo'), ('ㄇㄛ', 'mwo'),
('ㄈㄛ', 'fwo'), ('ㄈㄛ', 'fwo'),
('', 'p⁼'), ('', 'p⁼'),
('', ''), ('', ''),
('', 'm'), ('', 'm'),
('', 'f'), ('', 'f'),
('', 't⁼'), ('', 't⁼'),
('', ''), ('', ''),
('', 'n'), ('', 'n'),
('', 'l'), ('', 'l'),
('', 'k⁼'), ('', 'k⁼'),
('', ''), ('', ''),
('', 'x'), ('', 'x'),
('', 'tʃ⁼'), ('', 'tʃ⁼'),
('', 'tʃʰ'), ('', 'tʃʰ'),
('', 'ʃ'), ('', 'ʃ'),
('', 'ts`⁼'), ('', 'ts`⁼'),
('', 'ts`ʰ'), ('', 'ts`ʰ'),
('', 's`'), ('', 's`'),
('', 'ɹ`'), ('', 'ɹ`'),
('', 'ts⁼'), ('', 'ts⁼'),
('', 'tsʰ'), ('', 'tsʰ'),
('', 's'), ('', 's'),
('', 'a'), ('', 'a'),
('', 'o'), ('', 'o'),
('', 'ə'), ('', 'ə'),
('', 'ɛ'), ('', 'ɛ'),
('', 'aɪ'), ('', 'aɪ'),
('', 'eɪ'), ('', 'eɪ'),
('', 'ɑʊ'), ('', 'ɑʊ'),
('', ''), ('', ''),
('ㄧㄢ', 'jɛn'), ('ㄧㄢ', 'jɛn'),
('ㄩㄢ', 'ɥæn'), ('ㄩㄢ', 'ɥæn'),
('', 'an'), ('', 'an'),
('ㄧㄣ', 'in'), ('ㄧㄣ', 'in'),
('ㄩㄣ', 'ɥn'), ('ㄩㄣ', 'ɥn'),
('', 'ən'), ('', 'ən'),
('', 'ɑŋ'), ('', 'ɑŋ'),
('ㄧㄥ', ''), ('ㄧㄥ', ''),
('ㄨㄥ', 'ʊŋ'), ('ㄨㄥ', 'ʊŋ'),
('ㄩㄥ', 'jʊŋ'), ('ㄩㄥ', 'jʊŋ'),
('', 'əŋ'), ('', 'əŋ'),
('', 'əɻ'), ('', 'əɻ'),
('', 'i'), ('', 'i'),
('', 'u'), ('', 'u'),
('', 'ɥ'), ('', 'ɥ'),
('ˉ', ''), ('ˉ', ''),
('ˊ', ''), ('ˊ', ''),
('ˇ', '↓↑'), ('ˇ', '↓↑'),
('ˋ', ''), ('ˋ', ''),
('˙', ''), ('˙', ''),
('', ','), ('', ','),
('', '.'), ('', '.'),
('', '!'), ('', '!'),
('', '?'), ('', '?'),
('', '-') ('', '-')
]] ]]
# List of (bopomofo, ipa2) pairs: # List of (bopomofo, ipa2) pairs:
_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('ㄅㄛ', 'pwo'), ('ㄅㄛ', 'pwo'),
('ㄆㄛ', 'pʰwo'), ('ㄆㄛ', 'pʰwo'),
('ㄇㄛ', 'mwo'), ('ㄇㄛ', 'mwo'),
('ㄈㄛ', 'fwo'), ('ㄈㄛ', 'fwo'),
('', 'p'), ('', 'p'),
('', ''), ('', ''),
('', 'm'), ('', 'm'),
('', 'f'), ('', 'f'),
('', 't'), ('', 't'),
('', ''), ('', ''),
('', 'n'), ('', 'n'),
('', 'l'), ('', 'l'),
('', 'k'), ('', 'k'),
('', ''), ('', ''),
('', 'h'), ('', 'h'),
('', ''), ('', ''),
('', 'tɕʰ'), ('', 'tɕʰ'),
('', 'ɕ'), ('', 'ɕ'),
('', ''), ('', ''),
('', 'tʂʰ'), ('', 'tʂʰ'),
('', 'ʂ'), ('', 'ʂ'),
('', 'ɻ'), ('', 'ɻ'),
('', 'ts'), ('', 'ts'),
('', 'tsʰ'), ('', 'tsʰ'),
('', 's'), ('', 's'),
('', 'a'), ('', 'a'),
('', 'o'), ('', 'o'),
('', 'ɤ'), ('', 'ɤ'),
('', 'ɛ'), ('', 'ɛ'),
('', 'aɪ'), ('', 'aɪ'),
('', 'eɪ'), ('', 'eɪ'),
('', 'ɑʊ'), ('', 'ɑʊ'),
('', ''), ('', ''),
('ㄧㄢ', 'jɛn'), ('ㄧㄢ', 'jɛn'),
('ㄩㄢ', 'yæn'), ('ㄩㄢ', 'yæn'),
('', 'an'), ('', 'an'),
('ㄧㄣ', 'in'), ('ㄧㄣ', 'in'),
('ㄩㄣ', 'yn'), ('ㄩㄣ', 'yn'),
('', 'ən'), ('', 'ən'),
('', 'ɑŋ'), ('', 'ɑŋ'),
('ㄧㄥ', ''), ('ㄧㄥ', ''),
('ㄨㄥ', 'ʊŋ'), ('ㄨㄥ', 'ʊŋ'),
('ㄩㄥ', 'jʊŋ'), ('ㄩㄥ', 'jʊŋ'),
('', 'ɤŋ'), ('', 'ɤŋ'),
('', 'əɻ'), ('', 'əɻ'),
('', 'i'), ('', 'i'),
('', 'u'), ('', 'u'),
('', 'y'), ('', 'y'),
('ˉ', '˥'), ('ˉ', '˥'),
('ˊ', '˧˥'), ('ˊ', '˧˥'),
('ˇ', '˨˩˦'), ('ˇ', '˨˩˦'),
('ˋ', '˥˩'), ('ˋ', '˥˩'),
('˙', ''), ('˙', ''),
('', ','), ('', ','),
('', '.'), ('', '.'),
('', '!'), ('', '!'),
('', '?'), ('', '?'),
('', '-') ('', '-')
]] ]]
def number_to_chinese(text): def number_to_chinese(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text) numbers = re.findall(r'\d+(?:\.?\d+)?', text)
for number in numbers: for number in numbers:
text = text.replace(number, cn2an.an2cn(number), 1) text = text.replace(number, cn2an.an2cn(number), 1)
return text return text
def chinese_to_bopomofo(text): def chinese_to_bopomofo(text):
text = text.replace('', '').replace('', '').replace('', '') text = text.replace('', '').replace('', '').replace('', '')
words = jieba.lcut(text, cut_all=False) words = jieba.lcut(text, cut_all=False)
text = '' text = ''
for word in words: for word in words:
bopomofos = lazy_pinyin(word, BOPOMOFO) bopomofos = lazy_pinyin(word, BOPOMOFO)
if not re.search('[\u4e00-\u9fff]', word): if not re.search('[\u4e00-\u9fff]', word):
text += word text += word
continue continue
for i in range(len(bopomofos)): for i in range(len(bopomofos)):
bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\', bopomofos[i]) bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\', bopomofos[i])
if text != '': if text != '':
text += ' ' text += ' '
text += ''.join(bopomofos) text += ''.join(bopomofos)
return text return text
def latin_to_bopomofo(text): def latin_to_bopomofo(text):
for regex, replacement in _latin_to_bopomofo: for regex, replacement in _latin_to_bopomofo:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def bopomofo_to_romaji(text): def bopomofo_to_romaji(text):
for regex, replacement in _bopomofo_to_romaji: for regex, replacement in _bopomofo_to_romaji:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def bopomofo_to_ipa(text): def bopomofo_to_ipa(text):
for regex, replacement in _bopomofo_to_ipa: for regex, replacement in _bopomofo_to_ipa:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def bopomofo_to_ipa2(text): def bopomofo_to_ipa2(text):
for regex, replacement in _bopomofo_to_ipa2: for regex, replacement in _bopomofo_to_ipa2:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def chinese_to_romaji(text): def chinese_to_romaji(text):
text = number_to_chinese(text) text = number_to_chinese(text)
text = chinese_to_bopomofo(text) text = chinese_to_bopomofo(text)
text = latin_to_bopomofo(text) text = latin_to_bopomofo(text)
text = bopomofo_to_romaji(text) text = bopomofo_to_romaji(text)
text = re.sub('i([aoe])', r'y\1', text) text = re.sub('i([aoe])', r'y\1', text)
text = re.sub('u([aoəe])', r'w\1', text) text = re.sub('u([aoəe])', r'w\1', text)
text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\\2', text) text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\\2', text)
return text return text
def chinese_to_lazy_ipa(text): def chinese_to_lazy_ipa(text):
text = chinese_to_romaji(text) text = chinese_to_romaji(text)
for regex, replacement in _romaji_to_ipa: for regex, replacement in _romaji_to_ipa:
text = re.sub(regex, replacement, text) text = re.sub(regex, replacement, text)
return text return text
def chinese_to_ipa(text): def chinese_to_ipa(text):
text = number_to_chinese(text) text = number_to_chinese(text)
text = chinese_to_bopomofo(text) text = chinese_to_bopomofo(text)
text = latin_to_bopomofo(text) text = latin_to_bopomofo(text)
text = bopomofo_to_ipa(text) text = bopomofo_to_ipa(text)
text = re.sub('i([aoe])', r'j\1', text) text = re.sub('i([aoe])', r'j\1', text)
text = re.sub('u([aoəe])', r'w\1', text) text = re.sub('u([aoəe])', r'w\1', text)
text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\\2', text) text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\\2', text)
return text return text
def chinese_to_ipa2(text): def chinese_to_ipa2(text):
text = number_to_chinese(text) text = number_to_chinese(text)
text = chinese_to_bopomofo(text) text = chinese_to_bopomofo(text)
text = latin_to_bopomofo(text) text = latin_to_bopomofo(text)
text = bopomofo_to_ipa2(text) text = bopomofo_to_ipa2(text)
text = re.sub(r'i([aoe])', r'j\1', text) text = re.sub(r'i([aoe])', r'j\1', text)
text = re.sub(r'u([aoəe])', r'w\1', text) text = re.sub(r'u([aoəe])', r'w\1', text)
text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\\2', text) text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\\2', text)
text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
return text return text
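End to end, Mandarin text is segmented with jieba, converted to bopomofo via pypinyin, and then mapped through the romaji/IPA tables above, with digits first spelled out by cn2an. A hedged sketch (requires `jieba`, `pypinyin` and `cn2an`; the exact output depends on segmentation):

```python
# Hedged sketch of the Mandarin front end.
print(number_to_chinese("我有2只猫"))   # -> "我有二只猫"
print(chinese_to_ipa("你好"))            # -> IPA with tone arrows, e.g. "ni↓↑ xɑʊ↓↑"
print(chinese_to_ipa2("你好"))           # -> same pipeline, with IPA tone letters (˥ ˧˥ ...)
```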


@@ -1,88 +1,88 @@
''' '''
Defines the set of symbols used in text input to the model. Defines the set of symbols used in text input to the model.
''' '''
# japanese_cleaners # japanese_cleaners
# _pad = '_' # _pad = '_'
# _punctuation = ',.!?-' # _punctuation = ',.!?-'
# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
'''# japanese_cleaners2 '''# japanese_cleaners2
_pad = '_' _pad = '_'
_punctuation = ',.!?-~…' _punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
''' '''
'''# korean_cleaners '''# korean_cleaners
_pad = '_' _pad = '_'
_punctuation = ',.!?…~' _punctuation = ',.!?…~'
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
''' '''
'''# chinese_cleaners '''# chinese_cleaners
_pad = '_' _pad = '_'
_punctuation = ',。!?—…' _punctuation = ',。!?—…'
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
''' '''
# # zh_ja_mixture_cleaners # # zh_ja_mixture_cleaners
# _pad = '_' # _pad = '_'
# _punctuation = ',.!?-~…' # _punctuation = ',.!?-~…'
# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
'''# sanskrit_cleaners '''# sanskrit_cleaners
_pad = '_' _pad = '_'
_punctuation = '' _punctuation = ''
_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
''' '''
'''# cjks_cleaners '''# cjks_cleaners
_pad = '_' _pad = '_'
_punctuation = ',.!?-~…' _punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
''' '''
'''# thai_cleaners '''# thai_cleaners
_pad = '_' _pad = '_'
_punctuation = '.!? ' _punctuation = '.!? '
_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
''' '''
# # cjke_cleaners2 # # cjke_cleaners2
_pad = '_' _pad = '_'
_punctuation = ',.!?-~…' _punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
'''# shanghainese_cleaners '''# shanghainese_cleaners
_pad = '_' _pad = '_'
_punctuation = ',.!?…' _punctuation = ',.!?…'
_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
''' '''
'''# chinese_dialect_cleaners '''# chinese_dialect_cleaners
_pad = '_' _pad = '_'
_punctuation = ',.!?~…─' _punctuation = ',.!?~…─'
_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
''' '''
# Export all symbols: # Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) symbols = [_pad] + list(_punctuation) + list(_letters)
# Special symbol ids # Special symbol ids
SPACE_ID = symbols.index(" ") SPACE_ID = symbols.index(" ")
num_ja_tones = 1 num_ja_tones = 1
num_kr_tones = 1 num_kr_tones = 1
num_zh_tones = 6 num_zh_tones = 6
num_en_tones = 4 num_en_tones = 4
language_tone_start_map = { language_tone_start_map = {
"ZH": 0, "ZH": 0,
"JP": num_zh_tones, "JP": num_zh_tones,
"EN": num_zh_tones + num_ja_tones, "EN": num_zh_tones + num_ja_tones,
'KR': num_zh_tones + num_ja_tones + num_en_tones, 'KR': num_zh_tones + num_ja_tones + num_en_tones,
} }
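The exported `symbols` list is the model's phoneme vocabulary (the active set is the cjke_cleaners2 one; the commented-out blocks are alternative cleaner vocabularies), and `language_tone_start_map` packs each language's tone IDs into one shared range. A small sketch:

```python
# Hedged sketch: inspect the vocabulary and the per-language tone offsets.
print(len(symbols), symbols[:6])       # pad and punctuation come first
print(SPACE_ID)                        # index of the space symbol
print(language_tone_start_map["EN"])   # 7 == num_zh_tones + num_ja_tones
```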


@@ -1,209 +1,209 @@
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
import numpy as np import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3 DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3 DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3 DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform( def piecewise_rational_quadratic_transform(
inputs, inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
tails=None, tails=None,
tail_bound=1.0, tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, min_derivative=DEFAULT_MIN_DERIVATIVE,
): ):
if tails is None: if tails is None:
spline_fn = rational_quadratic_spline spline_fn = rational_quadratic_spline
spline_kwargs = {} spline_kwargs = {}
else: else:
spline_fn = unconstrained_rational_quadratic_spline spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound} spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn( outputs, logabsdet = spline_fn(
inputs=inputs, inputs=inputs,
unnormalized_widths=unnormalized_widths, unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights, unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives, unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse, inverse=inverse,
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative, min_derivative=min_derivative,
**spline_kwargs **spline_kwargs
) )
return outputs, logabsdet return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6): def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps bin_locations[..., -1] += eps
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
def unconstrained_rational_quadratic_spline( def unconstrained_rational_quadratic_spline(
inputs, inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
tails="linear", tails="linear",
tail_bound=1.0, tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, min_derivative=DEFAULT_MIN_DERIVATIVE,
): ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs) outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs) logabsdet = torch.zeros_like(inputs)
if tails == "linear": if tails == "linear":
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1) constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask] outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0 logabsdet[outside_interval_mask] = 0
else: else:
raise RuntimeError("{} tails are not implemented.".format(tails)) raise RuntimeError("{} tails are not implemented.".format(tails))
( (
outputs[inside_interval_mask], outputs[inside_interval_mask],
logabsdet[inside_interval_mask], logabsdet[inside_interval_mask],
) = rational_quadratic_spline( ) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask], inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :], unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :], unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse, inverse=inverse,
left=-tail_bound, left=-tail_bound,
right=tail_bound, right=tail_bound,
bottom=-tail_bound, bottom=-tail_bound,
top=tail_bound, top=tail_bound,
min_bin_width=min_bin_width, min_bin_width=min_bin_width,
min_bin_height=min_bin_height, min_bin_height=min_bin_height,
min_derivative=min_derivative, min_derivative=min_derivative,
) )
return outputs, logabsdet return outputs, logabsdet
def rational_quadratic_spline( def rational_quadratic_spline(
inputs, inputs,
unnormalized_widths, unnormalized_widths,
unnormalized_heights, unnormalized_heights,
unnormalized_derivatives, unnormalized_derivatives,
inverse=False, inverse=False,
left=0.0, left=0.0,
right=1.0, right=1.0,
bottom=0.0, bottom=0.0,
top=1.0, top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH, min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, min_derivative=DEFAULT_MIN_DERIVATIVE,
): ):
if torch.min(inputs) < left or torch.max(inputs) > right: if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain") raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1] num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0: if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins") raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0: if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins") raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, dim=-1) widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1) cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
cumwidths = (right - left) * cumwidths + left cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left cumwidths[..., 0] = left
cumwidths[..., -1] = right cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1] widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives) derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1) heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1) cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
cumheights = (top - bottom) * cumheights + bottom cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom cumheights[..., 0] = bottom
cumheights[..., -1] = top cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1] heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse: if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None] bin_idx = searchsorted(cumheights, inputs)[..., None]
else: else:
bin_idx = searchsorted(cumwidths, inputs)[..., None] bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0] input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0] input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0] input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse: if inverse:
a = (inputs - input_cumheights) * ( a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives) ) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * ( b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta input_derivatives + input_derivatives_plus_one - 2 * input_delta
) )
c = -input_delta * (inputs - input_cumheights) c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all() assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant)) root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root) theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ( denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta * theta_one_minus_theta
) )
derivative_numerator = input_delta.pow(2) * ( derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2) input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta + 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2) + input_derivatives * (1 - root).pow(2)
) )
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet return outputs, -logabsdet
else: else:
theta = (inputs - input_cumwidths) / input_bin_widths theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta) theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * ( numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
) )
denominator = input_delta + ( denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta * theta_one_minus_theta
) )
outputs = input_cumheights + numerator / denominator outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * ( derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2) input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta + 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2) + input_derivatives * (1 - theta).pow(2)
) )
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet return outputs, logabsdet
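These functions implement the piecewise rational-quadratic (neural spline) transform used by VITS-style coupling flows: a monotone spline inside the tail bound, identity outside it, and the log-determinant returned alongside the output. A hedged round-trip check with illustrative shapes; with `num_bins = 8` and `tails="linear"`, the derivative tensor carries `num_bins - 1` values and is padded internally.

```python
# Hedged sketch: forward through the spline, invert, and compare.
import torch

x  = torch.rand(2, 5) * 2 - 1      # inputs inside the default tail_bound of 1.0
uw = torch.randn(2, 5, 8)          # unnormalised bin widths  (num_bins = 8)
uh = torch.randn(2, 5, 8)          # unnormalised bin heights
ud = torch.randn(2, 5, 7)          # unnormalised derivatives (num_bins - 1)

y, logdet = piecewise_rational_quadratic_transform(x, uw, uh, ud, tails="linear")
x_rec, inv_logdet = piecewise_rational_quadratic_transform(y, uw, uh, ud, inverse=True, tails="linear")
print(torch.allclose(x, x_rec, atol=1e-4), torch.allclose(logdet, -inv_logdet, atol=1e-4))
```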


@@ -1,194 +1,194 @@
import re
import json
import numpy as np


def get_hparams_from_file(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


class HParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()


def string_to_bits(string, pad_len=8):
    # Convert each character to its ASCII value
    ascii_values = [ord(char) for char in string]

    # Convert ASCII values to binary representation
    binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]

    # Convert binary strings to integer arrays
    bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]

    # Convert list of arrays to NumPy array
    numpy_array = np.array(bit_arrays)
    numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
    numpy_array_full[:, 2] = 1
    max_len = min(pad_len, len(numpy_array))
    numpy_array_full[:max_len] = numpy_array[:max_len]
    return numpy_array_full


def bits_to_string(bits_array):
    # Convert each row of the array to a binary string
    binary_values = [''.join(str(bit) for bit in row) for row in bits_array]

    # Convert binary strings to ASCII values
    ascii_values = [int(binary, 2) for binary in binary_values]

    # Convert ASCII values to characters
    output_string = ''.join(chr(value) for value in ascii_values)

    return output_string


def split_sentence(text, min_len=10, language_str='[EN]'):
    if language_str in ['EN']:
        sentences = split_sentences_latin(text, min_len=min_len)
    else:
        sentences = split_sentences_zh(text, min_len=min_len)
    return sentences


def split_sentences_latin(text, min_len=10):
    """Split Long sentences into list of short ones

    Args:
        str: Input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    # deal with dirty sentences
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    text = re.sub('[“”]', '"', text)
    text = re.sub('[‘’]', "'", text)
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
    text = re.sub('[\n\t ]+', ' ', text)
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # split
    sentences = [s.strip() for s in text.split('$#!')]
    if len(sentences[-1]) == 0: del sentences[-1]

    new_sentences = []
    new_sent = []
    count_len = 0
    for ind, sent in enumerate(sentences):
        # print(sent)
        new_sent.append(sent)
        count_len += len(sent.split(" "))
        if count_len > min_len or ind == len(sentences) - 1:
            count_len = 0
            new_sentences.append(' '.join(new_sent))
            new_sent = []
    return merge_short_sentences_latin(new_sentences)


def merge_short_sentences_latin(sens):
    """Avoid short sentences by merging them with the following sentence.

    Args:
        List[str]: list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short, merge them with
        # the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        if len(sens_out[-1].split(" ")) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except:
        pass
    return sens_out


def split_sentences_zh(text, min_len=10):
    text = re.sub('[。!?;]', '.', text)
    text = re.sub('[,]', ',', text)
    # 将文本中的换行符、空格和制表符替换为空格
    text = re.sub('[\n\t ]+', ' ', text)
    # 在标点符号后添加一个空格
    text = re.sub('([,.!?;])', r'\1 $#!', text)
    # 分隔句子并去除前后空格
    # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
    sentences = [s.strip() for s in text.split('$#!')]
    if len(sentences[-1]) == 0: del sentences[-1]

    new_sentences = []
    new_sent = []
    count_len = 0
    for ind, sent in enumerate(sentences):
        new_sent.append(sent)
        count_len += len(sent)
        if count_len > min_len or ind == len(sentences) - 1:
            count_len = 0
            new_sentences.append(' '.join(new_sent))
            new_sent = []
    return merge_short_sentences_zh(new_sentences)


def merge_short_sentences_zh(sens):
    # return sens
    """Avoid short sentences by merging them with the following sentence.

    Args:
        List[str]: list of input sentences.

    Returns:
        List[str]: list of output sentences.
    """
    sens_out = []
    for s in sens:
        # If the previous sentence is too short, merge them with
        # the current sentence.
        if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
            sens_out[-1] = sens_out[-1] + " " + s
        else:
            sens_out.append(s)
    try:
        if len(sens_out[-1]) <= 2:
            sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
            sens_out.pop(-1)
    except:
        pass
    return sens_out
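
A minimal usage sketch of the sentence-splitting helpers above; the import path and sample strings are illustrative assumptions, not taken from the repository:

```python
# Hypothetical import path; adjust to wherever this utils module lives in the project.
from utils.tts.openvoice.utils import split_sentence

en_text = "This is a long sentence that should be split. And another one follows."
zh_text = "今天天气很好。我们去公园散步吧!你觉得怎么样?"

# 'EN' selects the Latin splitter; any other value falls back to the Chinese splitter.
print(split_sentence(en_text, min_len=10, language_str='EN'))
print(split_sentence(zh_text, min_len=10, language_str='ZH'))
```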


@ -1,57 +0,0 @@
{
"_version_": "v2",
"data": {
"sampling_rate": 22050,
"filter_length": 1024,
"hop_length": 256,
"win_length": 1024,
"n_speakers": 0
},
"model": {
"zero_g": true,
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.1,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
8,
8,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
4,
4
],
"gin_channels": 256
}
}
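
Were this deleted configuration kept on disk (say as config.json alongside the converter checkpoint), it could be consumed with the get_hparams_from_file / HParams helpers shown earlier; a hypothetical sketch:

```python
# Hypothetical: the file name and import path are assumptions.
from utils.tts.openvoice.utils import get_hparams_from_file

hps = get_hparams_from_file("config.json")
print(hps.data.sampling_rate)   # 22050
print(hps.model.gin_channels)   # 256
print("zero_g" in hps.model)    # True, via HParams.__contains__
```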


@ -1,346 +1,352 @@
import os import os
import re import re
from glob import glob from glob import glob
from tqdm.auto import tqdm import hashlib
import soundfile as sf from tqdm.auto import tqdm
import numpy as np import soundfile as sf
import torch import numpy as np
from typing import Optional, Union import torch
# melo from typing import Optional, Union
from melo.api import TTS # melo
from melo.utils import get_text_for_tts_infer from melo.api import TTS
# openvoice from melo.utils import get_text_for_tts_infer
from .openvoice import se_extractor # openvoice
from .openvoice.api import ToneColorConverter from .openvoice import se_extractor
from .openvoice.mel_processing import spectrogram_torch from .openvoice.api import ToneColorConverter
# torchaudio from .openvoice.mel_processing import spectrogram_torch
import torchaudio.functional as F # torchaudio
import torchaudio.functional as F
# 存储 BASE SPEAKER 的 embedding(source_se) 的路径 # 存储 BASE SPEAKER 的 embedding(source_se) 的路径
current_file_path = os.path.abspath(__file__) SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"
utils_dir = os.path.dirname(os.path.dirname(current_file_path))
# 存储缓存文件的路径
SOURCE_SE_DIR = os.path.join(utils_dir,'assets','ses') CACHE_PATH = r"D:\python\OpenVoice\processed"
# 存储缓存文件的路径 OPENVOICE_BASE_TTS={
CACHE_PATH = r"/tmp/openvoice_cache" "model_type": "open_voice_base_tts",
# 转换的语言
OPENVOICE_BASE_TTS={ "language": "ZH",
"model_type": "open_voice_base_tts", }
# 转换的语言
"language": "ZH", OPENVOICE_TONE_COLOR_CONVERTER={
} "model_type": "open_voice_converter",
# 模型参数路径
converter_path = os.path.join(os.path.dirname(current_file_path),'openvoice_model') "converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
OPENVOICE_TONE_COLOR_CONVERTER={ }
"model_type": "open_voice_converter",
# 模型参数路径 class TextToSpeech:
"converter_path": converter_path, def __init__(self,
} use_tone_convert=True,
device="cuda",
class TextToSpeech: debug:bool=False,
def __init__(self, ):
use_tone_convert=True, self.debug = debug
device="cuda", self.device = device
debug:bool=False, self.use_tone_convert = use_tone_convert
): # 默认的源说话人 se
self.debug = debug self.source_se = None
self.device = device # 默认的目标说话人 se
self.use_tone_convert = use_tone_convert self.target_se = None
# 默认的源说话人 se
self.source_se = None self.initialize_base_tts(**OPENVOICE_BASE_TTS)
# 默认的目标说话人 se if self.debug:
self.target_se = None print("use tone converter is", self.use_tone_convert)
if self.use_tone_convert:
self.initialize_base_tts(**OPENVOICE_BASE_TTS) self.initialize_tone_color_converter(**OPENVOICE_TONE_COLOR_CONVERTER)
if self.debug: self.initialize_source_se()
print("use tone converter is", self.use_tone_convert)
if self.use_tone_convert:
self.initialize_tone_color_converter(**OPENVOICE_TONE_COLOR_CONVERTER) def initialize_tone_color_converter(self, **kwargs):
self.initialize_source_se() """
初始化 tone color converter
"""
def initialize_tone_color_converter(self, **kwargs): model_type = kwargs.pop('model_type')
""" self.tone_color_converter_model_type = model_type
初始化 tone color converter if model_type == 'open_voice_converter':
""" # 加载模型
model_type = kwargs.pop('model_type') converter_path = kwargs.pop('converter_path')
self.tone_color_converter_model_type = model_type self.tone_color_converter = ToneColorConverter(f'{converter_path}/config.json', self.device)
if model_type == 'open_voice_converter': self.tone_color_converter.load_ckpt(f'{converter_path}/checkpoint.pth')
# 加载模型 if self.debug:
converter_path = kwargs.pop('converter_path') print("load tone color converter successfully!")
self.tone_color_converter = ToneColorConverter(f'{converter_path}/config.json', self.device) else:
self.tone_color_converter.load_ckpt(f'{converter_path}/checkpoint.pth') raise NotImplementedError(f"only [open_voice_converter] model type expected, but get [{model_type}]. ")
if self.debug:
print("load tone color converter successfully!") def initialize_base_tts(self, **kwargs):
else: """
raise NotImplementedError(f"only [open_voice_converter] model type expected, but get [{model_type}]. ") 初始化 base tts model
"""
def initialize_base_tts(self, **kwargs): model_type = kwargs.pop('model_type')
""" self.base_tts_model_type = model_type
初始化 base tts model if model_type == "open_voice_base_tts":
""" language = kwargs.pop('language')
model_type = kwargs.pop('model_type') self.base_tts_model = TTS(language=language, device=self.device)
self.base_tts_model_type = model_type speaker_ids = self.base_tts_model.hps.data.spk2id
if model_type == "open_voice_base_tts": flag = False
language = kwargs.pop('language') for speaker_key in speaker_ids.keys():
self.base_tts_model = TTS(language=language, device=self.device) if flag:
speaker_ids = self.base_tts_model.hps.data.spk2id Warning(f'loaded model has more than one speaker, only the first speaker is used. The input speaker ids are {speaker_ids}')
flag = False break
for speaker_key in speaker_ids.keys(): self.speaker_id = speaker_ids[speaker_key]
if flag: self.speaker_key = speaker_key.lower().replace('_', '-')
Warning(f'loaded model has more than one speaker, only the first speaker is used. The input speaker ids are {speaker_ids}') flag=True
break if self.debug:
self.speaker_id = speaker_ids[speaker_key] print("load base tts model successfully!")
self.speaker_key = speaker_key.lower().replace('_', '-') # 第一次使用tts时会加载bert模型
flag=True self._base_tts("初始化bert模型。")
if self.debug: else:
print("load base tts model successfully!") raise NotImplementedError(f"only [open_voice_base_tts] model type expected, but get [{model_type}]. ")
# 第一次使用tts时会加载bert模型
self._base_tts("初始化bert模型。") def initialize_source_se(self):
else: """
raise NotImplementedError(f"only [open_voice_base_tts] model type expected, but get [{model_type}]. ") 初始化source se
"""
def initialize_source_se(self): if self.source_se is not None:
""" Warning("replace source speaker embedding with new source speaker embedding!")
初始化source se self.source_se = torch.load(os.path.join(SOURCE_SE_DIR, f"{self.speaker_key}.pth"), map_location=self.device)
"""
if self.source_se is not None: def initialize_target_se(self, se: Union[np.ndarray, torch.Tensor]):
Warning("replace source speaker embedding with new source speaker embedding!") """
self.source_se = torch.load(os.path.join(SOURCE_SE_DIR, f"{self.speaker_key}.pth"), map_location=self.device) 设置 target se
param:
def initialize_target_se(self, se: Union[np.ndarray, torch.Tensor]): se: 输入的se类型可以为np.ndarray或torch.Tensor
""" """
设置 target se if self.target_se is not None:
param: Warning("replace target source speaker embedding with new target speaker embedding!")
se: 输入的se类型可以为np.ndarray或torch.Tensor if isinstance(se, np.ndarray):
""" self.target_se = torch.tensor(se.astype(np.float32)).to(self.device)
if self.target_se is not None: elif isinstance(se, torch.Tensor):
Warning("replace target source speaker embedding with new target speaker embedding!") self.target_se = se.float().to(self.device)
if isinstance(se, np.ndarray):
self.target_se = torch.tensor(se.astype(np.float32)).to(self.device) def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
elif isinstance(se, torch.Tensor): """
self.target_se = se.float().to(self.device) 将字节流的audio转为numpy类型也可以传入numpy类型
return: np.float32
#语音转numpy """
def audio2numpy(self, audio_data: Union[bytes, np.ndarray]): # TODO 是否归一化判断
""" if isinstance(audio_data, bytes):
将字节流的audio转为numpy类型也可以传入numpy类型 audio_data = np.frombuffer(audio_data, dtype=np.int16).flatten().astype(np.float32) / 32768.0
return: np.float32 elif isinstance(audio_data, np.ndarray):
""" if audio_data.dtype != np.float32:
# TODO 是否归一化判断 audio_data = audio_data.astype(np.int16).flatten().astype(np.float32) / 32768.0
if isinstance(audio_data, bytes): else:
audio_data = np.frombuffer(audio_data, dtype=np.int16).flatten().astype(np.float32) / 32768.0 raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
elif isinstance(audio_data, np.ndarray): return audio_data
if audio_data.dtype != np.float32:
audio_data = audio_data.astype(np.int16).flatten().astype(np.float32) / 32768.0 def audio2emb(self, audio_data: Union[bytes, np.ndarray], rate=44100, vad=True):
else: """
raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}") 将输入的字节流/numpy类型的audio转为speaker embedding
return audio_data param:
audio_data: 输入的音频字节
def audio2emb(self, audio_data: Union[bytes, np.ndarray], rate=44100, vad=True): rate: 输入音频的采样率
""" vad: 是否使用vad模型
将输入的字节流/numpy类型的audio转为speaker embedding return: np.ndarray
param: """
audio_data: 输入的音频字节 audio_data = self.audio2numpy(audio_data)
rate: 输入音频的采样率
vad: 是否使用vad模型 from scipy.io import wavfile
return: np.ndarray audio_path = os.path.join(CACHE_PATH, "tmp.wav")
""" wavfile.write(audio_path, rate=rate, data=audio_data)
audio_data = self.audio2numpy(audio_data)
if not os.path.exists(CACHE_PATH): se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
os.makedirs(CACHE_PATH) # device = self.tone_color_converter.device
# version = self.tone_color_converter.version
from scipy.io import wavfile # if self.debug:
audio_path = os.path.join(CACHE_PATH, "tmp.wav") # print("OpenVoice version:", version)
wavfile.write(audio_path, rate=rate, data=audio_data)
# audio_name = f"tmp_{version}_{hashlib.sha256(audio_data.tobytes()).hexdigest()[:16].replace('/','_^')}"
se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
return se.cpu().detach().numpy()
# if vad:
def tensor2numpy(self, audio_data: torch.Tensor): # wavs_folder = se_extractor.split_audio_vad(audio_path, target_dir=CACHE_PATH, audio_name=audio_name)
""" # else:
tensor类型转numpy # wavs_folder = se_extractor.split_audio_whisper(audio_data, target_dir=CACHE_PATH, audio_name=audio_name)
"""
return audio_data.cpu().detach().float().numpy() # audio_segs = glob(f'{wavs_folder}/*.wav')
# if len(audio_segs) == 0:
def numpy2bytes(self, audio_data): # raise NotImplementedError('No audio segments found!')
if isinstance(audio_data, np.ndarray): # # se, _ = se_extractor.get_se(audio_data, self.tone_color_converter, CACHE_PATH, vad=False)
if audio_data.dtype == np.dtype('float32'): # se = self.tone_color_converter.extract_se(audio_segs)
audio_data = np.int16(audio_data * np.iinfo(np.int16).max) return se.cpu().detach().numpy()
audio_data = audio_data.tobytes()
return audio_data def tensor2numpy(self, audio_data: torch.Tensor):
else: """
raise TypeError("audio_data must be a numpy array") tensor类型转numpy
"""
def _base_tts(self, return audio_data.cpu().detach().float().numpy()
text: str,
sdp_ratio=0.2, def numpy2bytes(self, audio_data: np.ndarray):
noise_scale=0.6, """
noise_scale_w=0.8, numpy类型转bytes
speed=1.0, """
quite=True): return (audio_data*32768.0).astype(np.int32).tobytes()
"""
base语音合成 def _base_tts(self,
param: text: str,
text: 要合成的文本 sdp_ratio=0.2,
sdp_ratio: SDP在合成时的占比, 理论上此比率越高, 合成的语音语调方差越大. noise_scale=0.6,
noise_scale: 样本噪声张量的噪声标度 noise_scale_w=0.8,
noise_scale_w: 推理中随机持续时间预测器的噪声标度 speed=1.0,
speed: 说话语速 quite=True):
quite: 是否显示进度条 """
return: base语音合成
audio: tensor param:
sr: 生成音频的采样速率 text: 要合成的文本
""" sdp_ratio: SDP在合成时的占比, 理论上此比率越高, 合成的语音语调方差越大.
speaker_id = self.speaker_id noise_scale: 样本噪声张量的噪声标度
if self.base_tts_model_type != "open_voice_base_tts": noise_scale_w: 推理中随机持续时间预测器的噪声标度
raise NotImplementedError("only [open_voice_base_tts] model type expected.") speed: 说话语速
language = self.base_tts_model.language quite: 是否显示进度条
texts = self.base_tts_model.split_sentences_into_pieces(text, language, quite) return:
audio_list = [] audio: tensor
if quite: sr: 生成音频的采样速率
tx = texts """
else: speaker_id = self.speaker_id
tx = tqdm(texts) if self.base_tts_model_type != "open_voice_base_tts":
for t in tx: raise NotImplementedError("only [open_voice_base_tts] model type expected.")
if language in ['EN', 'ZH_MIX_EN']: language = self.base_tts_model.language
t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t) texts = self.base_tts_model.split_sentences_into_pieces(text, language, quite)
device = self.base_tts_model.device audio_list = []
bert, ja_bert, phones, tones, lang_ids = get_text_for_tts_infer(t, language, self.base_tts_model.hps, device, self.base_tts_model.symbol_to_id) if quite:
with torch.no_grad(): tx = texts
x_tst = phones.to(device).unsqueeze(0) else:
tones = tones.to(device).unsqueeze(0) tx = tqdm(texts)
lang_ids = lang_ids.to(device).unsqueeze(0) for t in tx:
bert = bert.to(device).unsqueeze(0) if language in ['EN', 'ZH_MIX_EN']:
ja_bert = ja_bert.to(device).unsqueeze(0) t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) device = self.base_tts_model.device
del phones bert, ja_bert, phones, tones, lang_ids = get_text_for_tts_infer(t, language, self.base_tts_model.hps, device, self.base_tts_model.symbol_to_id)
speakers = torch.LongTensor([speaker_id]).to(device) with torch.no_grad():
audio = self.base_tts_model.model.infer( x_tst = phones.to(device).unsqueeze(0)
x_tst, tones = tones.to(device).unsqueeze(0)
x_tst_lengths, lang_ids = lang_ids.to(device).unsqueeze(0)
speakers, bert = bert.to(device).unsqueeze(0)
tones, ja_bert = ja_bert.to(device).unsqueeze(0)
lang_ids, x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
bert, del phones
ja_bert, speakers = torch.LongTensor([speaker_id]).to(device)
sdp_ratio=sdp_ratio, audio = self.base_tts_model.model.infer(
noise_scale=noise_scale, x_tst,
noise_scale_w = noise_scale_w, x_tst_lengths,
length_scale = 1. / speed, speakers,
)[0][0, 0].data tones,
del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers lang_ids,
audio_list.append(audio) bert,
torch.cuda.empty_cache() ja_bert,
audio_segments = [] sdp_ratio=sdp_ratio,
sr = self.base_tts_model.hps.data.sampling_rate noise_scale=noise_scale,
for segment_data in audio_list: noise_scale_w = noise_scale_w,
audio_segments.append(segment_data.reshape(-1).contiguous()) length_scale = 1. / speed,
audio_segments.append(torch.tensor([0]*int((sr * 0.05) / speed), dtype=segment_data.dtype, device=segment_data.device)) )[0][0, 0].data
audio_segments = torch.cat(audio_segments, dim=-1) del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
if self.debug: audio_list.append(audio)
print("generate base speech!") torch.cuda.empty_cache()
print("**********************,tts sr",sr) audio_segments = []
print(f"audio segment length is [{audio_segments.shape}]") sr = self.base_tts_model.hps.data.sampling_rate
return audio_segments, sr for segment_data in audio_list:
audio_segments.append(segment_data.reshape(-1).contiguous())
def _convert_tone(self, audio_segments.append(torch.tensor([0]*int((sr * 0.05) / speed), dtype=segment_data.dtype, device=segment_data.device))
audio_data: torch.Tensor, audio_segments = torch.cat(audio_segments, dim=-1)
source_se: Optional[np.ndarray]=None, if self.debug:
target_se: Optional[np.ndarray]=None, print("generate base speech!")
tau :float=0.3, print("**********************,tts sr",sr)
message :str="default"): print(f"audio segment length is [{audio_segments.shape}]")
""" return audio_segments, sr
音色转换
param: def _convert_tone(self,
audio_data: _base_tts输出的音频数据 audio_data: torch.Tensor,
source_se: 如果为None, 则使用self.source_se source_se: Optional[np.ndarray]=None,
target_se: 如果为None, 则使用self.target_se target_se: Optional[np.ndarray]=None,
tau: tau :float=0.3,
message: 水印信息 TODO message :str="default"):
return: """
audio: tensor 音色转换
sr: 生成音频的采样速率 param:
""" audio_data: _base_tts输出的音频数据
if source_se is not None: source_se: 如果为None, 则使用self.source_se
source_se = torch.tensor(source_se.astype(np.float32)).to(self.device) target_se: 如果为None, 则使用self.target_se
if target_se is not None: tau:
target_se = torch.tensor(target_se.astype(np.float32)).to(self.device) message: 水印信息 TODO
return:
if source_se is None: audio: tensor
source_se = self.source_se sr: 生成音频的采样速率
if target_se is None: """
target_se = self.target_se if source_se is None:
source_se = self.source_se
hps = self.tone_color_converter.hps if target_se is None:
sr = hps.data.sampling_rate target_se = self.target_se
if self.debug:
print("**********************************, convert sr", sr) hps = self.tone_color_converter.hps
audio_data = audio_data.float() sr = hps.data.sampling_rate
if self.debug:
with torch.no_grad(): print("**********************************, convert sr", sr)
y = audio_data.to(self.tone_color_converter.device) audio_data = audio_data.float()
y = y.unsqueeze(0)
spec = spectrogram_torch(y, hps.data.filter_length, with torch.no_grad():
sr, hps.data.hop_length, hps.data.win_length, y = audio_data.to(self.tone_color_converter.device)
center=False).to(self.tone_color_converter.device) y = y.unsqueeze(0)
spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.tone_color_converter.device) spec = spectrogram_torch(y, hps.data.filter_length,
audio = self.tone_color_converter.model.voice_conversion(spec, spec_lengths, sid_src=source_se, sid_tgt=target_se, tau=tau)[0][ sr, hps.data.hop_length, hps.data.win_length,
0, 0].data center=False).to(self.tone_color_converter.device)
# audio = self.tone_color_converter.add_watermark(audio, message) spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.tone_color_converter.device)
if self.debug: audio = self.tone_color_converter.model.voice_conversion(spec, spec_lengths, sid_src=source_se, sid_tgt=target_se, tau=tau)[0][
print("tone color has been converted!") 0, 0].data
return audio, sr # audio = self.tone_color_converter.add_watermark(audio, message)
if self.debug:
def synthesize(self, print("tone color has been converted!")
text: str, return audio, sr
tts_info,
source_se: Optional[np.ndarray]=None, def tts(self,
target_se: Optional[np.ndarray]=None, text: str,
sdp_ratio=0.2, sdp_ratio=0.2,
quite=True, noise_scale=0.6,
tau :float=0.3, noise_scale_w=0.8,
message :str="default"): speed=1.0,
""" quite=True,
整体pipeline
_base_tts() source_se: Optional[np.ndarray]=None,
_convert_tone() target_se: Optional[np.ndarray]=None,
tensor2numpy() tau :float=0.3,
numpy2bytes() message :str="default"):
param: """
见_base_tts和_convert_tone 整体pipeline
return: _base_tts()
audio: 字节流音频数据 _convert_tone()
sr: 音频数据的采样率 tensor2numpy()
""" numpy2bytes()
audio, sr = self._base_tts(text, param:
sdp_ratio=sdp_ratio, 见_base_tts和_convert_tone
noise_scale=tts_info['noise_scale'], return:
noise_scale_w=tts_info['noise_scale_w'], audio: 字节流音频数据
speed=tts_info['speed'], sr: 音频数据的采样率
quite=quite) """
if self.use_tone_convert and target_se.size>0: audio, sr = self._base_tts(text,
tts_sr = self.base_tts_model.hps.data.sampling_rate sdp_ratio=sdp_ratio,
converter_sr = self.tone_color_converter.hps.data.sampling_rate noise_scale=noise_scale,
audio = F.resample(audio, tts_sr, converter_sr) noise_scale_w=noise_scale_w,
audio, sr = self._convert_tone(audio, speed=speed,
source_se=source_se, quite=quite)
target_se=target_se, if self.use_tone_convert:
tau=tau, tts_sr = self.base_tts_model.hps.data.sampling_rate
message=message) converter_sr = self.tone_color_converter.hps.data.sampling_rate
audio = self.tensor2numpy(audio) audio = F.resample(audio, tts_sr, converter_sr)
audio = self.numpy2bytes(audio) print(audio.dtype)
return audio, sr audio, sr = self._convert_tone(audio,
source_se=source_se,
def save_audio(self, audio, sample_rate, save_path): target_se=target_se,
""" tau=tau,
将numpy类型的音频数据保存至本地 message=message)
param: audio = self.tensor2numpy(audio)
audio: numpy类型的音频数据 audio = self.numpy2bytes(audio)
sample_rate: 数据采样率 return audio, sr
save_path: 保存路径
""" def save_audio(self, audio, sample_rate, save_path):
sf.write(save_path, audio, sample_rate) """
print(f"Audio saved to {save_path}") 将numpy类型的音频数据保存至本地
param:
audio: numpy类型的音频数据
sample_rate: 数据采样率
save_path: 保存路径
"""
sf.write(save_path, audio, sample_rate)
print(f"Audio saved to {save_path}")


@ -2,7 +2,6 @@ import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
@ -80,19 +79,19 @@ class TextToSpeech:
print(f"Synthesis time: {time.time() - start_time} s") print(f"Synthesis time: {time.time() - start_time} s")
return audio return audio
def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True): def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
if not len(text):
return "输入文本不能为空!", None
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, tts_info['language']) text = self._preprocess_text(text, language)
audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale']) audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
if self.debug or save_audio:
self.save_audio(audio, self.RATE, 'output_file.wav')
if return_bytes:
audio = self.convert_numpy_to_bytes(audio)
return audio, self.RATE return self.RATE, audio
def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):
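
The float32/int16 conversions used by convert_numpy_to_bytes above (and by numpy2bytes / audio2numpy in the OpenVoice wrapper) follow the usual 16-bit PCM convention; a small self-contained illustration with made-up sample values:

```python
import numpy as np

# float32 waveform in [-1.0, 1.0] -> little-endian int16 PCM bytes
wave = np.array([0.0, 0.5, -0.5, 0.999], dtype=np.float32)
pcm_bytes = np.int16(wave * np.iinfo(np.int16).max).tobytes()

# bytes -> float32 again (what the byte-stream consumers do on the way back in)
restored = np.frombuffer(pcm_bytes, dtype=np.int16).astype(np.float32) / 32768.0
print(restored)  # close to the original, up to 16-bit quantization
```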


@ -1,9 +1,7 @@
import websockets
import datetime
import hashlib
import base64
import hmac
import json
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
@ -37,33 +35,3 @@ def generate_xf_asr_url():
}
url = url + '?' + urlencode(v)
return url
def make_first_frame(buf):
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(first_frame)
def make_continue_frame(buf):
continue_frame = {"data":{"status":1,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(continue_frame)
def make_last_frame(buf):
last_frame = {"data":{"status":2,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(last_frame)
def parse_xfasr_recv(message):
code = message['code']
if code!=0:
raise Exception("讯飞ASR错误码"+str(code))
else:
data = message['data']['result']['ws']
result = ""
for i in data:
for w in i['cw']:
result += w['w']
return result
async def xf_asr_websocket_factory():
url = generate_xf_asr_url()
return await websockets.connect(url)
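
These helpers cover URL signing, frame construction and result parsing; below is a sketch of how they might be combined into a streaming recognition loop. The base64 step, the chunk handling and the status == 2 end-of-result check are my assumptions about the iFLYTEK (讯飞) protocol rather than code from this repository, and the sketch assumes it runs in the same module as the helpers above:

```python
# Illustrative streaming loop; framing details beyond the helpers are assumptions.
import asyncio
import base64
import json

async def recognize(pcm_chunks):
    """pcm_chunks: iterable of raw 16 kHz, 16-bit mono PCM byte chunks."""
    ws = await xf_asr_websocket_factory()
    try:
        for i, chunk in enumerate(pcm_chunks):
            buf = base64.b64encode(chunk).decode()        # assumed: API expects base64 audio
            if i == 0:
                await ws.send(make_first_frame(buf))
            else:
                await ws.send(make_continue_frame(buf))
        await ws.send(make_last_frame(""))                # signal end of audio
        result = ""
        while True:
            message = json.loads(await ws.recv())
            result += parse_xfasr_recv(message)
            if message.get("data", {}).get("status") == 2:  # assumed end-of-result flag
                break
        return result
    finally:
        await ws.close()
```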