init

commit 4c255ebd59

@@ -0,0 +1 @@
from . import *

@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod

# ------ Abstract ASR, LLM, TTS classes ------ #
class ASR(ABC):
    @abstractmethod
    async def stream_recognize(self, chunk):
        pass

class LLM(ABC):
    @abstractmethod
    def chat(self, assistant, prompt, db):
        pass

class TTS(ABC):
    @abstractmethod
    def synthetize(self, assistant, text):
        pass
# --------------------------------- #


# ----------- Abstract service classes ----------- #
class UserAudioService(ABC):
    @abstractmethod
    def user_audio_process(self, audio: str, **kwargs) -> str:
        pass

class PromptService(ABC):
    @abstractmethod
    def prompt_process(self, prompt: str, **kwargs) -> str:
        pass

class LLMMsgService(ABC):
    @abstractmethod
    def llm_msg_process(self, llm_chunk: list, **kwargs) -> list:
        pass

class TTSAudioService(ABC):
    @abstractmethod
    def tts_audio_process(self, audio: bytes, **kwargs) -> bytes:
        pass
# --------------------------------- #

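A minimal sketch (not part of this commit) of how a new processing step plugs into these interfaces: subclass the relevant ABC and return the processed value together with kwargs, so the service chains defined later can compose it with other services. The truncation length is an illustrative choice, not a project default.

class TruncatePromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        # Cap overly long prompts before they reach the LLM
        return prompt[:512], kwargs
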
@@ -0,0 +1,285 @@
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json

# ----------- Initialize VITS ----------- #
vits = TextToSpeech()
# ---------------------------------- #


# ------ Concrete ASR, LLM, TTS classes ------ #
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

class XF_ASR(ASR):
    def __init__(self):
        self.websocket = None
        self.current_message = ""
        self.audio = ""
        self.status = FIRST_FRAME
        self.segment_duration_threshold = 25  # segment timeout: 25 seconds
        self.segment_start_time = None

    async def stream_recognize(self, chunk):
        if self.websocket is None:  # open a new connection if none exists yet
            self.websocket = await xf_asr_websocket_factory()
        if self.segment_start_time is None:  # record the start time of the first segment
            self.segment_start_time = asyncio.get_event_loop().time()
        if chunk['meta_info']['is_end']:
            self.status = LAST_FRAME
        audio_data = chunk['audio']
        self.audio += audio_data

        if self.status == FIRST_FRAME:  # send the first frame
            await self.websocket.send(make_first_frame(audio_data))
            self.status = CONTINUE_FRAME
        elif self.status == CONTINUE_FRAME:  # send an intermediate frame
            await self.websocket.send(make_continue_frame(audio_data))
        elif self.status == LAST_FRAME:  # send the last frame
            await self.websocket.send(make_last_frame(audio_data))
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            if self.current_message == "":
                raise AsrResultNoneError()
            await self.websocket.close()
            print("ASR finished, user message:", self.current_message)
            return [{"text": self.current_message, "audio": self.audio}]

        current_time = asyncio.get_event_loop().time()
        if current_time - self.segment_start_time > self.segment_duration_threshold:  # timed out: send the last frame and reset state
            await self.websocket.send(make_last_frame())
            self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
            await self.websocket.close()
            self.websocket = await xf_asr_websocket_factory()
            self.status = FIRST_FRAME
            self.segment_start_time = current_time
        return []


class MINIMAX_LLM(LLM):
    def __init__(self):
        self.token = 0

    async def chat(self, assistant, prompt, db):
        llm_info = json.loads(assistant.llm_info)
        messages = json.loads(assistant.messages)
        messages.append({"role": "user", "content": prompt})
        assistant.messages = json.dumps(messages)
        payload = json.dumps({
            "model": llm_info['model'],
            "stream": True,
            "messages": messages,
            "max_tokens": llm_info['max_tokens'],
            "temperature": llm_info['temperature'],
            "top_p": llm_info['top_p'],
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
                async for chunk in response.content.iter_any():
                    try:
                        chunk_msg = self.__parseChunk(chunk)
                        msg_frame = {"is_end": False, "code": 200, "msg": chunk_msg}
                        yield msg_frame
                    except LLMResponseEnd:
                        msg_frame = {"is_end": True, "msg": ""}
                        if self.token > llm_info['max_tokens'] * 0.8:
                            msg_frame['msg'] = 'max_token reached'
                            as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
                            as_query.messages = json.dumps([{'role': 'system', 'content': assistant.system_prompt}, {'role': 'user', 'content': prompt}])
                            db.commit()
                            assistant.messages = as_query.messages
                        yield msg_frame

    def __parseChunk(self, llm_chunk):
        result = ""
        data = json.loads(llm_chunk.decode('utf-8')[6:])
        if data["object"] == "chat.completion":  # end-of-stream frame
            self.token = data['usage']['total_tokens']
            raise LLMResponseEnd()
        elif data['object'] == 'chat.completion.chunk':
            for choice in data['choices']:
                result += choice['delta']['content']
        else:
            raise UnkownLLMFrame()
        return result


class VITS_TTS(TTS):
    def __init__(self):
        pass

    def synthetize(self, assistant, text):
        tts_info = json.loads(assistant.tts_info)
        return vits.synthesize(text, tts_info)
# --------------------------------- #


# ------ ASR, LLM, TTS factory classes ------ #
class ASRFactory:
    def create_asr(self, asr_type: str) -> ASR:
        if asr_type == 'XF':
            return XF_ASR()

class LLMFactory:
    def create_llm(self, llm_type: str) -> LLM:
        if llm_type == 'MINIMAX':
            return MINIMAX_LLM()

class TTSFactory:
    def create_tts(self, tts_type: str) -> TTS:
        if tts_type == 'VITS':
            return VITS_TTS()
# --------------------------------- #


# ----------- Concrete service classes ----------- #
# Take the first result from the single-speaker asr_results as the prompt
class BasicPromptService(PromptService):
    def prompt_process(self, prompt: str, **kwargs):
        if 'asr_result' in kwargs:
            return kwargs['asr_result'][0]['text'], kwargs
        else:
            raise NoAsrResultsError()

class UserAudioRecordService(UserAudioService):
    def user_audio_process(self, audio: str, **kwargs):
        audio_decoded = base64.b64decode(audio)
        kwargs['recorder'].user_audio += audio_decoded
        return audio, kwargs

class TTSAudioRecordService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        kwargs['recorder'].tts_audio += audio
        return audio, kwargs
# --------------------------------- #


# ------------ Service chain classes ----------- #
class UserAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: UserAudioService):
        self.services.append(service)
    def user_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.user_audio_process(audio, **kwargs)
        return audio

class PromptServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: PromptService):
        self.services.append(service)
    def prompt_process(self, asr_results):
        kwargs = {"asr_result": asr_results}
        prompt = ""
        for service in self.services:
            prompt, kwargs = service.prompt_process(prompt, **kwargs)
        return prompt

class LLMMsgServiceChain:
    def __init__(self):
        self.services = []
        self.spliter = SentenceSegmentation()
    def add_service(self, service: LLMMsgService):
        self.services.append(service)
    def llm_msg_process(self, llm_msg):
        kwargs = {}
        llm_chunks = self.spliter.split(llm_msg)
        for service in self.services:
            llm_chunks, kwargs = service.llm_msg_process(llm_chunks, **kwargs)
        return llm_chunks

class TTSAudioServiceChain:
    def __init__(self):
        self.services = []
    def add_service(self, service: TTSAudioService):
        self.services.append(service)
    def tts_audio_process(self, audio, **kwargs):
        for service in self.services:
            audio, kwargs = service.tts_audio_process(audio, **kwargs)
        return audio
# --------------------------------- #

class Agent():
    def __init__(self, asr_type: str, llm_type: str, tts_type: str):
        asrFactory = ASRFactory()
        llmFactory = LLMFactory()
        ttsFactory = TTSFactory()
        self.recorder = None

        self.user_audio_service_chain = UserAudioServiceChain()  # build the user-audio service chain
        self.user_audio_service_chain.add_service(UserAudioRecordService())  # add the user-audio recording service

        self.asr = asrFactory.create_asr(asr_type)

        self.prompt_service_chain = PromptServiceChain()
        self.prompt_service_chain.add_service(BasicPromptService())

        self.llm = llmFactory.create_llm(llm_type)

        self.llm_chunk_service_chain = LLMMsgServiceChain()

        self.tts = ttsFactory.create_tts(tts_type)

        self.tts_audio_service_chain = TTSAudioServiceChain()
        self.tts_audio_service_chain.add_service(TTSAudioRecordService())

    def init_recorder(self, user_id):
        self.recorder = Recorder(user_id)

    # Pre-process the user's input audio
    def user_audio_process(self, audio, db):
        return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)

    # Streaming speech recognition
    async def stream_recognize(self, chunk, db):
        return await self.asr.stream_recognize(chunk)

    # Prompt processing
    def prompt_process(self, asr_results, db):
        return self.prompt_service_chain.prompt_process(asr_results)

    # Call the LLM
    async def chat(self, assistant, prompt, db):
        return self.llm.chat(assistant, prompt, db)

    # Process the LLM output
    def llm_msg_process(self, llm_chunk, db):
        return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)

    # TTS synthesis
    def synthetize(self, assistant, text, db):
        return self.tts.synthetize(assistant, text)

    # Process the synthesized audio
    def tts_audio_process(self, audio, db):
        return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)

    # Encode a response frame
    def encode(self, text, audio):
        text_resp = {"type": "text", "code": 200, "msg": text}
        text_resp_bytes = json.dumps(text_resp, ensure_ascii=False).encode('utf-8')
        header = struct.pack('!II', len(text_resp_bytes), len(audio))
        final_resp = header + text_resp_bytes + audio
        return final_resp

    # Save the recordings
    def save(self):
        self.recorder.write()

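An illustrative sketch (not in this commit) of extending the pipeline: the service chains accept any subclass of the abstract services, so a post-processing step can be added without touching Agent. It assumes 16-bit PCM output (Recorder writes with sample width 2); the 1.2 gain factor is a made-up example value.

import numpy as np

class TTSVolumeService(TTSAudioService):
    def tts_audio_process(self, audio: bytes, **kwargs):
        # Interpret the TTS output as 16-bit PCM, boost the volume, and clip
        samples = np.frombuffer(audio, dtype=np.int16).astype(np.int32)
        samples = np.clip(samples * 1.2, -32768, 32767).astype(np.int16)
        return samples.tobytes(), kwargs

# Usage: agent.tts_audio_service_chain.add_service(TTSVolumeService())
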
@@ -0,0 +1,39 @@
from sqlalchemy import create_engine, Column, Integer, String, CHAR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from config import Config

engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

MESSAGE_LENGTH_LIMIT = 2**31 - 1

class User(Base):
    __tablename__ = "user"
    id = Column(CHAR(36), primary_key=True)
    name = Column(String(32))
    email = Column(String(64))
    password = Column(String(128))

class Assistant(Base):
    __tablename__ = "assistant"
    id = Column(CHAR(36), primary_key=True)
    user_id = Column(CHAR(36))
    name = Column(String(32))
    system_prompt = Column(String(512))
    messages = Column(String(MESSAGE_LENGTH_LIMIT))
    user_info = Column(String(256))
    llm_info = Column(String(256))
    tts_info = Column(String(256))
    token = Column(Integer)

Base.metadata.create_all(bind=engine)

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

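A usage sketch (illustration only): get_db is a generator dependency meant for FastAPI's Depends, but SessionLocal can also be used directly, for example in a one-off maintenance script against the same takway.db file.

if __name__ == "__main__":
    db = SessionLocal()
    try:
        # List every assistant's id and name
        for assistant in db.query(Assistant).all():
            print(assistant.id, assistant.name)
    finally:
        db.close()
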
@@ -0,0 +1,79 @@
from datetime import datetime
import wave
import json

# -------------- Shared classes ------------ #
class AsrResultNoneError(Exception):
    def __init__(self, message="Asr Result is None!"):
        super().__init__(message)
        self.message = message

class NoAsrResultsError(Exception):
    def __init__(self, message="No Asr Results!"):
        super().__init__(message)
        self.message = message

class LLMResponseEnd(Exception):
    def __init__(self, message="LLM Response End!"):
        super().__init__(message)
        self.message = message

class UnkownLLMFrame(Exception):
    def __init__(self, message="Unknown LLM Frame!"):
        super().__init__(message)
        self.message = message

class TokenOutofRangeError(Exception):
    def __init__(self, message="Token Out of Range!"):
        super().__init__(message)
        self.message = message

class SentenceSegmentation():
    def __init__(self):
        self.is_first_sentence = True
        self.cache = ""

    def __sentenceSegmentation(self, llm_frame):
        results = []
        if llm_frame['is_end']:
            if self.cache:
                results.append(self.cache)
                self.cache = ""
            return results
        for char in llm_frame['msg']:
            self.cache += char
            if self.is_first_sentence and char in ',.?!,。?!':
                results.append(self.cache)
                self.cache = ""
                self.is_first_sentence = False
            elif char in '。?!':
                results.append(self.cache)
                self.cache = ""
        return results

    def split(self, llm_chunk):
        return self.__sentenceSegmentation(llm_chunk)

class Recorder:
    def __init__(self, user_id):
        self.input_wav_path = 'storage/wav/' + datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
        self.output_wav_path = 'storage/wav/' + datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
        self.out_put_text_path = 'storage/record/' + datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
        self.input_sr = 16000
        self.output_sr = 22050
        self.user_audio = b''
        self.tts_audio = b''
        self.input_text = ""
        self.output_text = ""

    def write(self):
        record = {"input_wav": self.input_wav_path, "input_text": self.input_text, "input_sr": self.input_sr, "output_wav": self.output_wav_path, "output_text": self.output_text, "output_sr": self.output_sr}
        with wave.open(self.input_wav_path, 'wb') as wav_file:
            wav_file.setparams((1, 2, self.input_sr, 0, 'NONE', 'not compressed'))
            wav_file.writeframes(self.user_audio)
        with wave.open(self.output_wav_path, 'wb') as wav_file:
            wav_file.setparams((1, 2, self.output_sr, 0, 'NONE', 'not compressed'))
            wav_file.writeframes(self.tts_audio)
        with open(self.out_put_text_path, 'w', encoding='utf-8') as file:
            file.write(json.dumps(record, ensure_ascii=False))
# ---------------------------------- #

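A worked example (illustration only) of the segmenter's behavior: it flushes the first clause at any punctuation mark so TTS can start as early as possible, and afterwards only at sentence-final marks. The frames mirror the {"is_end", "msg"} dicts produced by MINIMAX_LLM.chat.

seg = SentenceSegmentation()
print(seg.split({"is_end": False, "msg": "你好,很高"}))    # ['你好,']  first clause flushed early
print(seg.split({"is_end": False, "msg": "兴认识你。再"}))  # ['很高兴认识你。']
print(seg.split({"is_end": True, "msg": ""}))               # ['再']  remaining cache flushed at end
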
@@ -0,0 +1,36 @@
from pydantic import BaseModel

class create_assistant_request(BaseModel):
    user_id: str
    name: str
    system_prompt: str
    user_info: str
    llm_info: str
    tts_info: str

class update_assistant_request(BaseModel):
    name: str
    system_prompt: str
    messages: str
    user_info: str
    llm_info: str
    tts_info: str

class create_user_request(BaseModel):
    name: str
    email: str
    password: str

class update_user_request(BaseModel):
    name: str
    email: str
    password: str

class update_assistant_system_prompt_request(BaseModel):
    system_prompt: str

class update_assistant_deatil_params_request(BaseModel):
    model: str
    temperature: float
    speaker_id: int
    length_scale: float

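An illustrative request body for POST /api/assistants. Note that llm_info and tts_info are JSON strings nested inside the outer JSON, since concrete.py json.loads() them; the field values shown are plausible examples inferred from how they are read, not mandated defaults.

example_body = {
    "user_id": "5f1c7a2e-0000-0000-0000-000000000000",  # hypothetical user id
    "name": "demo",
    "system_prompt": "You are a helpful assistant.",
    "user_info": "{}",
    "llm_info": "{\"model\": \"abab5.5-chat\", \"max_tokens\": 4096, \"temperature\": 0.9, \"top_p\": 0.9}",
    "tts_info": "{\"speaker_id\": 0, \"length_scale\": 1.0}"
}
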
@@ -0,0 +1,20 @@
class Config:
    SQLITE_URL = 'sqlite:///takway.db'
    ASR = "XF"        # select the ASR engine here
    LLM = "MINIMAX"   # select the LLM here
    TTS = "VITS"      # select the TTS engine here
    class UVICORN:
        HOST = '0.0.0.0'
        PORT = 7878
    class XF_ASR:
        APP_ID = "f1c121c1"  # iFLYTEK ASR APP_ID
        API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0"  # iFLYTEK ASR API_SECRET
        API_KEY = "36b316c7977fa534ae1e3bf52157bb92"  # iFLYTEK ASR API_KEY
        DOMAIN = "iat"
        LANGUAGE = "zh_cn"
        ACCENT = "mandarin"
        VAD_EOS = 10000
    class MINIMAX_LLM:
        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
        URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"

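A hedged alternative sketch (not in this commit): the hardcoded credentials above could instead be read from the environment, keeping secrets out of version control. The variable names are illustrative, not an existing convention of this project.

import os

XF_API_KEY = os.environ.get("XF_API_KEY", "")            # hypothetical variable name
MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "")  # hypothetical variable name
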
@@ -0,0 +1,208 @@
from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent, AsrResultNoneError
from app.model import Assistant, User, get_db
from app.schemas import *
import uvicorn
import uuid
import json

# Shared helpers --------------------------------------------------------------
def update_messages(messages, llm_text):
    messages = json.loads(messages)
    messages.append({"role": "assistant", "content": llm_text})
    return json.dumps(messages, ensure_ascii=False)
# ------------------------------------------------------------------------------

# Create the FastAPI instance
app = FastAPI()

# Assistant CRUD ---------------------------------------------------------------
# Create an assistant
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request, db=Depends(get_db)):
    id = str(uuid.uuid4())
    messages = json.dumps([{"role": "system", "content": request.system_prompt}], ensure_ascii=False)
    assistant = Assistant(id=id, user_id=request.user_id, name=request.name, system_prompt=request.system_prompt, messages=messages,
                          user_info=request.user_info, llm_info=request.llm_info, tts_info=request.tts_info, token=0)
    db.add(assistant)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": id}}

# Delete an assistant
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        db.delete(assistant)
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}

# Update an assistant
@app.put("/api/assistants/{id}")
async def update_assistant(id: str, request: update_assistant_request, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        assistant.name = request.name
        assistant.system_prompt = request.system_prompt
        assistant.messages = request.messages
        assistant.user_info = request.user_info
        assistant.llm_info = request.llm_info
        assistant.tts_info = request.tts_info
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}

# Get an assistant
@app.get("/api/assistants/{id}")
async def get_assistant(id: str, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        return {"code": 200, "msg": "success", "data": assistant}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}

# Get all assistant names and ids
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
    assistants = db.query(Assistant.id, Assistant.name).all()
    return {"code": 200, "msg": "success", "data": [{"id": assistant.id, "name": assistant.name} for assistant in assistants]}

# Reset an assistant's messages
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        assistant.messages = json.dumps([{"role": "system", "content": assistant.system_prompt}], ensure_ascii=False)
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}

# Update an assistant's system_prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str, request: update_assistant_system_prompt_request, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        assistant.system_prompt = request.system_prompt
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}

# Update detail parameters
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str, request: update_assistant_deatil_params_request, db=Depends(get_db)):
    assistant = db.query(Assistant).filter(Assistant.id == id).first()
    if assistant:
        llm_info = json.loads(assistant.llm_info)
        tts_info = json.loads(assistant.tts_info)
        llm_info['model'] = request.model
        llm_info['temperature'] = request.temperature
        tts_info['speaker_id'] = request.speaker_id
        tts_info['length_scale'] = request.length_scale
        assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
        assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        return {"code": 404, "msg": "assistant not found", "data": {}}
# ------------------------------------------------------------------------------


# User CRUD endpoints ----------------------------------------------------------
# Create a user
@app.post("/api/users")
async def create_user(request: create_user_request, db=Depends(get_db)):
    id = str(uuid.uuid4())
    user = User(id=id, name=request.name, email=request.email, password=request.password)
    db.add(user)
    db.commit()
    return {"code": 200, "msg": "success", "data": {"id": id}}

# Delete a user
@app.delete("/api/users/{id}")
async def delete_user(id: str, db=Depends(get_db)):
    user = db.query(User).filter(User.id == id).first()
    if user:
        db.delete(user)
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        raise HTTPException(status_code=404, detail="user not found")

# Get a user
@app.get("/api/users/{id}")
async def get_user(id: str, db=Depends(get_db)):
    user = db.query(User).filter(User.id == id).first()
    if user:
        return {"code": 200, "msg": "success", "data": user}
    else:
        raise HTTPException(status_code=404, detail="user not found")

# Update a user
@app.put("/api/users/{id}")
async def update_user(id: str, request: update_user_request, db=Depends(get_db)):
    user = db.query(User).filter(User.id == id).first()
    if user:
        user.name = request.name
        user.email = request.email
        user.password = request.password
        db.commit()
        return {"code": 200, "msg": "success", "data": {}}
    else:
        raise HTTPException(status_code=404, detail="user not found")
# ------------------------------------------------------------------------------

# Streaming chat WebSocket endpoint --------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket, db=Depends(get_db)):
    await ws.accept()
    agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
    assistant = None
    asr_results = []
    llm_text = ""
    try:
        while len(asr_results) == 0:
            chunk = json.loads(await ws.receive_text())
            if assistant is None:
                assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
                agent.init_recorder(assistant.user_id)
            chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
            asr_results = await agent.stream_recognize(chunk, db)
    except AsrResultNoneError:
        await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
        return
    prompt = agent.prompt_process(asr_results, db)
    agent.recorder.input_text = prompt
    llm_frames = await agent.chat(assistant, prompt, db)
    async for llm_frame in llm_frames:
        resp_msgs = agent.llm_msg_process(llm_frame, db)
        for resp_msg in resp_msgs:
            llm_text += resp_msg
            tts_audio = agent.synthetize(assistant, resp_msg, db)
            agent.tts_audio_process(tts_audio, db)
            await ws.send_bytes(agent.encode(resp_msg, tts_audio))
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    assistant.messages = update_messages(assistant.messages, llm_text)
    db.commit()
    agent.recorder.output_text = llm_text
    agent.save()
    await ws.close()
# ------------------------------------------------------------------------------

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],     # allow all origins; a specific origin list can be used instead
    allow_credentials=True,
    allow_methods=["*"],     # allow all methods
    allow_headers=["*"],     # allow all headers
)

# Start the server
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)

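A usage sketch (illustrative; assumes the server is reachable on localhost:7878 as configured above and that the `requests` package is installed) exercising the assistant lifecycle over the REST endpoints:

import requests

resp = requests.post("http://localhost:7878/api/assistants", json={
    "user_id": "u-1", "name": "demo", "system_prompt": "You are helpful.",
    "user_info": "{}",
    "llm_info": "{\"model\": \"abab5.5-chat\", \"max_tokens\": 4096, \"temperature\": 0.9, \"top_p\": 0.9}",
    "tts_info": "{\"speaker_id\": 0, \"length_scale\": 1.0}",
})
assistant_id = resp.json()["data"]["id"]
# Clear the conversation history back to just the system prompt
requests.post(f"http://localhost:7878/api/assistants/{assistant_id}/reset_msg")
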
Binary file not shown.
@@ -0,0 +1,68 @@
import json
import base64
import time
from datetime import datetime
from websocket import create_connection

data = {
    "text": "",
    "audio": "",
    "meta_info": {
        "session_id": "a36c9bb4-e813-4f0e-9c75-18e049c60f48",
        "stream": True,
        "voice_synthesize": True,
        "is_end": False,
        "encoding": "raw"
    }
}

def read_pcm_file_in_chunks(chunk_size):
    with open('example_recording.wav', 'rb') as pcm_file:
        while True:
            data = pcm_file.read(chunk_size)
            if not data:
                break
            yield data

def send_audio_chunk(websocket, chunk):
    # Base64-encode the PCM data
    encoded_data = base64.b64encode(chunk).decode('utf-8')
    # Put the encoded audio into the "audio" field of the data dict
    data["audio"] = encoded_data
    # Serialize the data dict to a JSON string
    message = json.dumps(data)
    # Send the JSON string over the WebSocket
    websocket.send(message)


def send_json():
    websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary')
    chunks = read_pcm_file_in_chunks(1024)  # read the PCM file in chunks
    for chunk in chunks:
        send_audio_chunk(websocket, chunk)
        # print("sent chunk:", len(chunk))
        time.sleep(0.01)  # wait 0.01 s between chunks
    # Set "is_end" to True to mark the end of the audio stream
    data["meta_info"]["is_end"] = True
    # Send a final empty chunk together with the end-of-stream flag
    send_audio_chunk(websocket, b'')
    # Wait for and print the responses
    print("Waiting to receive:", datetime.now())
    audio_bytes = b''
    while True:
        data_ws = websocket.recv()
        try:
            message_json = json.loads(data_ws)
            print(message_json)  # print the received text message
            if message_json["type"] == "close":
                break  # exit the loop on the close message
        except Exception:
            # Binary frames carry the synthesized audio; accumulate them
            audio_bytes += data_ws
    print("Receiving finished:", datetime.now())
    websocket.close()

# Run the client
send_json()

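A sketch (not in the original script) of decoding one binary frame: each frame the loop above accumulates is Agent.encode output, an 8-byte big-endian header ('!II') giving the JSON-text length and audio length, followed by the JSON text and the raw audio bytes.

import struct

def decode_frame(frame: bytes):
    text_len, audio_len = struct.unpack('!II', frame[:8])
    text = json.loads(frame[8:8 + text_len])
    audio = frame[8 + text_len:8 + text_len + audio_len]
    return text, audio
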
@@ -0,0 +1 @@
from . import *

@@ -0,0 +1,2 @@
from .text import *
from .monotonic_align import *

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,302 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

# import commons
# from modules import LayerNorm
from utils.vits import commons
from utils.vits.modules import LayerNorm


class Encoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class Decoder(nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

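A shape sanity check (illustration only) for the relative-position trick used above: pad, flatten, re-pad and re-view to turn [b, h, l, 2l-1] relative logits into [b, h, l, l] absolute scores without gather ops. F.pad with a plain tuple is equivalent to commons.convert_pad_shape here.

if __name__ == "__main__":
    b, h, l = 1, 2, 5
    rel = torch.randn(b, h, l, 2 * l - 1)
    x = F.pad(rel, (0, 1))                    # [b, h, l, 2l]
    x_flat = x.view(b, h, l * 2 * l)
    x_flat = F.pad(x_flat, (0, l - 1))
    absolute = x_flat.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1:]
    assert absolute.shape == (b, h, l, l)
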
@@ -0,0 +1,172 @@
import math
import torch
from torch.nn import functional as F
import torch.jit


def script_method(fn, _rcb=None):
    return fn


def script(obj, optimize=True, _frames_up=0, _rcb=None):
    return obj


torch.jit.script_method = script_method
torch.jit.script = script


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (num_timescales - 1))
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1. / norm_type)
    return total_norm

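A quick usage illustration (not in the original file) for two helpers above: intersperse inserts a blank token between and around symbols, as VITS does for text sequences, and sequence_mask builds boolean padding masks from lengths.

if __name__ == "__main__":
    print(intersperse([1, 2, 3], 0))   # [0, 1, 0, 2, 0, 3, 0]
    lengths = torch.tensor([2, 4])
    print(sequence_mask(lengths))      # [[True, True, False, False], [True, True, True, True]]
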
@ -0,0 +1,535 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
# import commons
|
||||
# import modules
|
||||
# import attentions
|
||||
# import monotonic_align
|
||||
from utils.vits import commons, modules, attentions, monotonic_align
|
||||
from utils.vits.commons import init_weights, get_padding
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
# from commons import init_weights, get_padding
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
||||
super().__init__()
|
||||
filter_channels = in_channels # it needs to be removed from future version.
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.log_flow = modules.Log()
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
||||
x = torch.detach(x)
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.convs(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
assert w is not None
|
||||
|
||||
logdet_tot_q = 0
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
logdet_tot_q += logdet_q
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
||||
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
logdet_tot += logdet
|
||||
z = torch.cat([z0, z1], 1)
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
logw = z0
|
||||
return logw
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
x = torch.detach(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
n_vocab,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout):
|
||||
super().__init__()
|
||||
self.n_vocab = n_vocab
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths):
|
||||
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return x, m, logs, x_mask
|
||||
|
||||
|
||||


class ResidualCouplingBlock(nn.Module):
  def __init__(self,
      channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      n_flows=4,
      gin_channels=0):
    super().__init__()
    self.channels = channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.n_flows = n_flows
    self.gin_channels = gin_channels

    self.flows = nn.ModuleList()
    for i in range(n_flows):
      self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
      self.flows.append(modules.Flip())

  def forward(self, x, x_mask, g=None, reverse=False):
    if not reverse:
      for flow in self.flows:
        x, _ = flow(x, x_mask, g=g, reverse=reverse)
    else:
      for flow in reversed(self.flows):
        x = flow(x, x_mask, g=g, reverse=reverse)
    return x


class PosteriorEncoder(nn.Module):
  def __init__(self,
      in_channels,
      out_channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      gin_channels=0):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.gin_channels = gin_channels

    self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
    self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

  def forward(self, x, x_lengths, g=None):
    x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
    x = self.pre(x) * x_mask
    x = self.enc(x, x_mask, g=g)
    stats = self.proj(x) * x_mask
    m, logs = torch.split(stats, self.out_channels, dim=1)
    z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
    return z, m, logs, x_mask
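

def _flow_invertibility_sketch():
  """Illustrative check, not part of the original file: the residual coupling
  flow run forward and then in reverse should reproduce its input up to
  numerical error. Channel sizes below are toy placeholders."""
  flow = ResidualCouplingBlock(channels=192, hidden_channels=192,
                               kernel_size=5, dilation_rate=1, n_layers=4)
  z = torch.randn(2, 192, 40)
  z_mask = torch.ones(2, 1, 40)
  z_p = flow(z, z_mask)                    # forward direction
  z_rec = flow(z_p, z_mask, reverse=True)  # inverse direction
  assert torch.allclose(z, z_rec, atol=1e-4)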


class Generator(torch.nn.Module):
  def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
    super(Generator, self).__init__()
    self.num_kernels = len(resblock_kernel_sizes)
    self.num_upsamples = len(upsample_rates)
    self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
    resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

    self.ups = nn.ModuleList()
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
      self.ups.append(weight_norm(
        ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                        k, u, padding=(k-u)//2)))

    self.resblocks = nn.ModuleList()
    for i in range(len(self.ups)):
      ch = upsample_initial_channel//(2**(i+1))
      for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
        self.resblocks.append(resblock(ch, k, d))

    self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
    self.ups.apply(init_weights)

    if gin_channels != 0:
      self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

  def forward(self, x, g=None):
    x = self.conv_pre(x)
    if g is not None:
      x = x + self.cond(g)

    for i in range(self.num_upsamples):
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      x = self.ups[i](x)
      xs = None
      for j in range(self.num_kernels):
        if xs is None:
          xs = self.resblocks[i*self.num_kernels+j](x)
        else:
          xs += self.resblocks[i*self.num_kernels+j](x)
      x = xs / self.num_kernels
    x = F.leaky_relu(x)
    x = self.conv_post(x)
    x = torch.tanh(x)

    return x

  def remove_weight_norm(self):
    print('Removing weight norm...')
    for l in self.ups:
      remove_weight_norm(l)
    for l in self.resblocks:
      l.remove_weight_norm()
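

def _generator_upsampling_sketch():
  """Illustrative note, not part of the original file: the decoder upsamples
  by the product of upsample_rates, which must equal the STFT hop length so
  that one latent frame maps to one hop of audio. All values here are toy
  placeholders, not this project's config."""
  rates = [8, 8, 2, 2]  # product = 256 samples per latent frame
  total = 1
  for r in rates:
    total *= r
  g = Generator(initial_channel=192, resblock='1',
                resblock_kernel_sizes=[3, 7, 11],
                resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                upsample_rates=rates, upsample_initial_channel=256,
                upsample_kernel_sizes=[16, 16, 4, 4])
  wav = g(torch.randn(1, 192, 10))
  assert wav.shape == (1, 1, 10 * total)  # 10 latent frames -> 2560 samples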


class DiscriminatorP(torch.nn.Module):
  def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
    super(DiscriminatorP, self).__init__()
    self.period = period
    self.use_spectral_norm = use_spectral_norm
    norm_f = weight_norm if not use_spectral_norm else spectral_norm
    self.convs = nn.ModuleList([
      norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
      norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
      norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
      norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
      norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
    ])
    self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

  def forward(self, x):
    fmap = []

    # 1d to 2d
    b, c, t = x.shape
    if t % self.period != 0:  # pad first
      n_pad = self.period - (t % self.period)
      x = F.pad(x, (0, n_pad), "reflect")
      t = t + n_pad
    x = x.view(b, c, t // self.period, self.period)

    for l in self.convs:
      x = l(x)
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      fmap.append(x)
    x = self.conv_post(x)
    fmap.append(x)
    x = torch.flatten(x, 1, -1)

    return x, fmap


class DiscriminatorS(torch.nn.Module):
  def __init__(self, use_spectral_norm=False):
    super(DiscriminatorS, self).__init__()
    norm_f = weight_norm if not use_spectral_norm else spectral_norm
    self.convs = nn.ModuleList([
      norm_f(Conv1d(1, 16, 15, 1, padding=7)),
      norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
      norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
      norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
      norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
      norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
    ])
    self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

  def forward(self, x):
    fmap = []

    for l in self.convs:
      x = l(x)
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      fmap.append(x)
    x = self.conv_post(x)
    fmap.append(x)
    x = torch.flatten(x, 1, -1)

    return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
  def __init__(self, use_spectral_norm=False):
    super(MultiPeriodDiscriminator, self).__init__()
    periods = [2, 3, 5, 7, 11]

    discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
    discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
    self.discriminators = nn.ModuleList(discs)

  def forward(self, y, y_hat):
    y_d_rs = []
    y_d_gs = []
    fmap_rs = []
    fmap_gs = []
    for i, d in enumerate(self.discriminators):
      y_d_r, fmap_r = d(y)
      y_d_g, fmap_g = d(y_hat)
      y_d_rs.append(y_d_r)
      y_d_gs.append(y_d_g)
      fmap_rs.append(fmap_r)
      fmap_gs.append(fmap_g)

    return y_d_rs, y_d_gs, fmap_rs, fmap_gs
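

def _discriminator_output_sketch():
  """Illustrative sketch, not part of the original file: the multi-period
  discriminator scores a real and a generated waveform and also returns the
  per-layer feature maps that VITS training uses for the feature-matching
  loss. Input sizes are arbitrary toy values."""
  mpd = MultiPeriodDiscriminator()
  y = torch.randn(1, 1, 8192)      # real audio segment
  y_hat = torch.randn(1, 1, 8192)  # generated audio segment
  y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
  # one score tensor and one feature-map list per sub-discriminator:
  # 1 DiscriminatorS + 5 DiscriminatorP (periods 2, 3, 5, 7, 11)
  assert len(y_d_rs) == 6 and len(fmap_gs) == 6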


class SynthesizerTrn(nn.Module):
  """
  Synthesizer for Training
  """

  def __init__(self,
      n_vocab,
      spec_channels,
      segment_size,
      inter_channels,
      hidden_channels,
      filter_channels,
      n_heads,
      n_layers,
      kernel_size,
      p_dropout,
      resblock,
      resblock_kernel_sizes,
      resblock_dilation_sizes,
      upsample_rates,
      upsample_initial_channel,
      upsample_kernel_sizes,
      n_speakers=0,
      gin_channels=0,
      use_sdp=True,
      **kwargs):

    super().__init__()
    self.n_vocab = n_vocab
    self.spec_channels = spec_channels
    self.inter_channels = inter_channels
    self.hidden_channels = hidden_channels
    self.filter_channels = filter_channels
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.kernel_size = kernel_size
    self.p_dropout = p_dropout
    self.resblock = resblock
    self.resblock_kernel_sizes = resblock_kernel_sizes
    self.resblock_dilation_sizes = resblock_dilation_sizes
    self.upsample_rates = upsample_rates
    self.upsample_initial_channel = upsample_initial_channel
    self.upsample_kernel_sizes = upsample_kernel_sizes
    self.segment_size = segment_size
    self.n_speakers = n_speakers
    self.gin_channels = gin_channels

    self.use_sdp = use_sdp

    self.enc_p = TextEncoder(n_vocab,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout)
    self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
    self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

    if use_sdp:
      self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
    else:
      self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

    if n_speakers > 1:
      self.emb_g = nn.Embedding(n_speakers, gin_channels)

  def forward(self, x, x_lengths, y, y_lengths, sid=None):

    x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
    if self.n_speakers > 0:
      g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
    else:
      g = None

    z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
    z_p = self.flow(z, y_mask, g=g)

    with torch.no_grad():
      # negative cross-entropy
      s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
      neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
      neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
      neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
      neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
      neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

      attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
      attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

    w = attn.sum(2)
    if self.use_sdp:
      l_length = self.dp(x, x_mask, w, g=g)
      l_length = l_length / torch.sum(x_mask)
    else:
      logw_ = torch.log(w + 1e-6) * x_mask
      logw = self.dp(x, x_mask, g=g)
      l_length = torch.sum((logw - logw_)**2, [1, 2]) / torch.sum(x_mask)  # for averaging

    # expand prior
    m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
    logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

    z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
    o = self.dec(z_slice, g=g)
    return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

  def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
    x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
    if self.n_speakers > 0:
      g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
    else:
      g = None

    if self.use_sdp:
      logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
    else:
      logw = self.dp(x, x_mask, g=g)
    w = torch.exp(logw) * x_mask * length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
    attn = commons.generate_path(w_ceil, attn_mask)

    m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
    logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']

    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
    z = self.flow(z_p, y_mask, g=g, reverse=True)
    o = self.dec((z * y_mask)[:, :, :max_len], g=g)
    return o, attn, y_mask, (z, z_p, m_p, logs_p)

  def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
    assert self.n_speakers > 0, "n_speakers has to be larger than 0."
    g_src = self.emb_g(sid_src).unsqueeze(-1)
    g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
    z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
    z_p = self.flow(z, y_mask, g=g_src)
    z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
    o_hat = self.dec(z_hat * y_mask, g=g_tgt)
    return o_hat, y_mask, (z, z_p, z_hat)
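

def _synthesizer_infer_sketch():
  """Illustrative end-to-end sketch, not part of the original file: how a TTS
  wrapper is expected to drive SynthesizerTrn.infer. Every hyperparameter
  here is a toy placeholder, not the project's actual config; the noise and
  length scales are the commonly used VITS inference defaults."""
  net_g = SynthesizerTrn(
    n_vocab=100, spec_channels=513, segment_size=32,
    inter_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock='1', resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2], upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4])
  net_g.eval()
  x = torch.randint(0, 100, (1, 20))  # symbol ids from text_to_sequence
  x_lengths = torch.LongTensor([20])
  with torch.no_grad():
    audio, attn, y_mask, _ = net_g.infer(x, x_lengths, noise_scale=0.667,
                                         noise_scale_w=0.8, length_scale=1.0)
  # audio: [1, 1, T] waveform at the model's sampling rate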

@@ -0,0 +1,390 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm

# import commons
# from commons import init_weights, get_padding
# from transforms import piecewise_rational_quadratic_transform
from utils.vits import commons
from utils.vits.commons import init_weights, get_padding
from utils.vits.transforms import piecewise_rational_quadratic_transform

LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
  def __init__(self, channels, eps=1e-5):
    super().__init__()
    self.channels = channels
    self.eps = eps

    self.gamma = nn.Parameter(torch.ones(channels))
    self.beta = nn.Parameter(torch.zeros(channels))

  def forward(self, x):
    x = x.transpose(1, -1)
    x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
    return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
    super().__init__()
    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.n_layers = n_layers
    self.p_dropout = p_dropout
    assert n_layers > 1, "Number of layers should be larger than 1."

    self.conv_layers = nn.ModuleList()
    self.norm_layers = nn.ModuleList()
    self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
    self.norm_layers.append(LayerNorm(hidden_channels))
    self.relu_drop = nn.Sequential(
      nn.ReLU(),
      nn.Dropout(p_dropout))
    for _ in range(n_layers-1):
      self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
      self.norm_layers.append(LayerNorm(hidden_channels))
    self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
    self.proj.weight.data.zero_()
    self.proj.bias.data.zero_()

  def forward(self, x, x_mask):
    x_org = x
    for i in range(self.n_layers):
      x = self.conv_layers[i](x * x_mask)
      x = self.norm_layers[i](x)
      x = self.relu_drop(x)
    x = x_org + self.proj(x)
    return x * x_mask


class DDSConv(nn.Module):
  """
  Dilated and Depth-Separable Convolution
  """
  def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
    super().__init__()
    self.channels = channels
    self.kernel_size = kernel_size
    self.n_layers = n_layers
    self.p_dropout = p_dropout

    self.drop = nn.Dropout(p_dropout)
    self.convs_sep = nn.ModuleList()
    self.convs_1x1 = nn.ModuleList()
    self.norms_1 = nn.ModuleList()
    self.norms_2 = nn.ModuleList()
    for i in range(n_layers):
      dilation = kernel_size ** i
      padding = (kernel_size * dilation - dilation) // 2
      self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
          groups=channels, dilation=dilation, padding=padding
      ))
      self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
      self.norms_1.append(LayerNorm(channels))
      self.norms_2.append(LayerNorm(channels))

  def forward(self, x, x_mask, g=None):
    if g is not None:
      x = x + g
    for i in range(self.n_layers):
      y = self.convs_sep[i](x * x_mask)
      y = self.norms_1[i](y)
      y = F.gelu(y)
      y = self.convs_1x1[i](y)
      y = self.norms_2[i](y)
      y = F.gelu(y)
      y = self.drop(y)
      x = x + y
    return x * x_mask


class WN(torch.nn.Module):
  def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
    super(WN, self).__init__()
    assert(kernel_size % 2 == 1)
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.gin_channels = gin_channels
    self.p_dropout = p_dropout

    self.in_layers = torch.nn.ModuleList()
    self.res_skip_layers = torch.nn.ModuleList()
    self.drop = nn.Dropout(p_dropout)

    if gin_channels != 0:
      cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
      self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

    for i in range(n_layers):
      dilation = dilation_rate ** i
      padding = int((kernel_size * dilation - dilation) / 2)
      in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
                                 dilation=dilation, padding=padding)
      in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
      self.in_layers.append(in_layer)

      # last one is not necessary
      if i < n_layers - 1:
        res_skip_channels = 2 * hidden_channels
      else:
        res_skip_channels = hidden_channels

      res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
      res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
      self.res_skip_layers.append(res_skip_layer)

  def forward(self, x, x_mask, g=None, **kwargs):
    output = torch.zeros_like(x)
    n_channels_tensor = torch.IntTensor([self.hidden_channels])

    if g is not None:
      g = self.cond_layer(g)

    for i in range(self.n_layers):
      x_in = self.in_layers[i](x)
      if g is not None:
        cond_offset = i * 2 * self.hidden_channels
        g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :]
      else:
        g_l = torch.zeros_like(x_in)

      acts = commons.fused_add_tanh_sigmoid_multiply(
        x_in,
        g_l,
        n_channels_tensor)
      acts = self.drop(acts)

      res_skip_acts = self.res_skip_layers[i](acts)
      if i < self.n_layers - 1:
        res_acts = res_skip_acts[:, :self.hidden_channels, :]
        x = (x + res_acts) * x_mask
        output = output + res_skip_acts[:, self.hidden_channels:, :]
      else:
        output = output + res_skip_acts
    return output * x_mask

  def remove_weight_norm(self):
    if self.gin_channels != 0:
      torch.nn.utils.remove_weight_norm(self.cond_layer)
    for l in self.in_layers:
      torch.nn.utils.remove_weight_norm(l)
    for l in self.res_skip_layers:
      torch.nn.utils.remove_weight_norm(l)


class ResBlock1(torch.nn.Module):
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
    super(ResBlock1, self).__init__()
    self.convs1 = nn.ModuleList([
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                         padding=get_padding(kernel_size, dilation[0]))),
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                         padding=get_padding(kernel_size, dilation[1]))),
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                         padding=get_padding(kernel_size, dilation[2])))
    ])
    self.convs1.apply(init_weights)

    self.convs2 = nn.ModuleList([
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                         padding=get_padding(kernel_size, 1))),
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                         padding=get_padding(kernel_size, 1))),
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                         padding=get_padding(kernel_size, 1)))
    ])
    self.convs2.apply(init_weights)

  def forward(self, x, x_mask=None):
    for c1, c2 in zip(self.convs1, self.convs2):
      xt = F.leaky_relu(x, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c1(xt)
      xt = F.leaky_relu(xt, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c2(xt)
      x = xt + x
    if x_mask is not None:
      x = x * x_mask
    return x

  def remove_weight_norm(self):
    for l in self.convs1:
      remove_weight_norm(l)
    for l in self.convs2:
      remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
  def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
    super(ResBlock2, self).__init__()
    self.convs = nn.ModuleList([
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                         padding=get_padding(kernel_size, dilation[0]))),
      weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                         padding=get_padding(kernel_size, dilation[1])))
    ])
    self.convs.apply(init_weights)

  def forward(self, x, x_mask=None):
    for c in self.convs:
      xt = F.leaky_relu(x, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c(xt)
      x = xt + x
    if x_mask is not None:
      x = x * x_mask
    return x

  def remove_weight_norm(self):
    for l in self.convs:
      remove_weight_norm(l)


class Log(nn.Module):
  def forward(self, x, x_mask, reverse=False, **kwargs):
    if not reverse:
      y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
      logdet = torch.sum(-y, [1, 2])
      return y, logdet
    else:
      x = torch.exp(x) * x_mask
      return x


class Flip(nn.Module):
  def forward(self, x, *args, reverse=False, **kwargs):
    x = torch.flip(x, [1])
    if not reverse:
      logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
      return x, logdet
    else:
      return x


class ElementwiseAffine(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.channels = channels
    self.m = nn.Parameter(torch.zeros(channels, 1))
    self.logs = nn.Parameter(torch.zeros(channels, 1))

  def forward(self, x, x_mask, reverse=False, **kwargs):
    if not reverse:
      y = self.m + torch.exp(self.logs) * x
      y = y * x_mask
      logdet = torch.sum(self.logs * x_mask, [1, 2])
      return y, logdet
    else:
      x = (x - self.m) * torch.exp(-self.logs) * x_mask
      return x


class ResidualCouplingLayer(nn.Module):
  def __init__(self,
      channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      p_dropout=0,
      gin_channels=0,
      mean_only=False):
    assert channels % 2 == 0, "channels should be divisible by 2"
    super().__init__()
    self.channels = channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.half_channels = channels // 2
    self.mean_only = mean_only

    self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
    self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
    self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
    self.post.weight.data.zero_()
    self.post.bias.data.zero_()

  def forward(self, x, x_mask, g=None, reverse=False):
    x0, x1 = torch.split(x, [self.half_channels]*2, 1)
    h = self.pre(x0) * x_mask
    h = self.enc(h, x_mask, g=g)
    stats = self.post(h) * x_mask
    if not self.mean_only:
      m, logs = torch.split(stats, [self.half_channels]*2, 1)
    else:
      m = stats
      logs = torch.zeros_like(m)

    if not reverse:
      x1 = m + x1 * torch.exp(logs) * x_mask
      x = torch.cat([x0, x1], 1)
      logdet = torch.sum(logs, [1, 2])
      return x, logdet
    else:
      x1 = (x1 - m) * torch.exp(-logs) * x_mask
      x = torch.cat([x0, x1], 1)
      return x


class ConvFlow(nn.Module):
  def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
    super().__init__()
    self.in_channels = in_channels
    self.filter_channels = filter_channels
    self.kernel_size = kernel_size
    self.n_layers = n_layers
    self.num_bins = num_bins
    self.tail_bound = tail_bound
    self.half_channels = in_channels // 2

    self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
    self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
    self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
    self.proj.weight.data.zero_()
    self.proj.bias.data.zero_()

  def forward(self, x, x_mask, g=None, reverse=False):
    x0, x1 = torch.split(x, [self.half_channels]*2, 1)
    h = self.pre(x0)
    h = self.convs(h, x_mask, g=g)
    h = self.proj(h) * x_mask

    b, c, t = x0.shape
    h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

    unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
    unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
    unnormalized_derivatives = h[..., 2 * self.num_bins:]

    x1, logabsdet = piecewise_rational_quadratic_transform(x1,
        unnormalized_widths,
        unnormalized_heights,
        unnormalized_derivatives,
        inverse=reverse,
        tails='linear',
        tail_bound=self.tail_bound
    )

    x = torch.cat([x0, x1], 1) * x_mask
    logdet = torch.sum(logabsdet * x_mask, [1, 2])
    if not reverse:
      return x, logdet
    else:
      return x
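

def _convflow_param_layout_sketch():
  """Illustrative note, not part of the original file: ConvFlow's projection
  emits num_bins * 3 - 1 values per half-channel and timestep, split into K
  bin widths, K bin heights, and K - 1 interior knot derivatives of the
  rational-quadratic spline; the two boundary derivatives are fixed constants
  supplied by the 'linear' tails in transforms.py."""
  num_bins = 10
  widths, heights, derivs = num_bins, num_bins, num_bins - 1
  assert widths + heights + derivs == num_bins * 3 - 1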

@@ -0,0 +1,20 @@
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit


def maximum_path(neg_cent, mask):
  """ numba optimized version.
  neg_cent: [b, t_t, t_s]
  mask: [b, t_t, t_s]
  """
  device = neg_cent.device
  dtype = neg_cent.dtype
  neg_cent = neg_cent.data.cpu().numpy().astype(float32)
  path = zeros(neg_cent.shape, dtype=int32)

  t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
  t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
  maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
  return from_numpy(path).to(device=device, dtype=dtype)
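

def _maximum_path_sketch():
  """Illustrative usage, not part of the original file: monotonic alignment
  search over a toy score matrix. torch is imported locally here because the
  module itself only pulls in from_numpy."""
  import torch
  neg_cent = torch.randn(1, 6, 4)      # [b, t_t, t_s] alignment scores
  mask = torch.ones(1, 6, 4)           # full attention mask (t_s <= t_t)
  path = maximum_path(neg_cent, mask)  # hard monotonic path, same shape
  assert path.shape == neg_cent.shape
  assert (path.sum(2) == 1).all()      # exactly one text symbol per frame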
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

@@ -0,0 +1,36 @@
import numba


@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
           nopython=True, nogil=True)
def maximum_path_jit(paths, values, t_ys, t_xs):
  b = paths.shape[0]
  max_neg_val = -1e9
  for i in range(int(b)):
    path = paths[i]
    value = values[i]
    t_y = t_ys[i]
    t_x = t_xs[i]

    v_prev = v_cur = 0.0
    index = t_x - 1

    for y in range(t_y):
      for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
        if x == y:
          v_cur = max_neg_val
        else:
          v_cur = value[y - 1, x]
        if x == 0:
          if y == 0:
            v_prev = 0.
          else:
            v_prev = max_neg_val
        else:
          v_prev = value[y - 1, x - 1]
        value[y, x] += max(v_prev, v_cur)

    for y in range(t_y - 1, -1, -1):
      path[y, index] = 1
      if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
        index = index - 1

@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@@ -0,0 +1,57 @@
""" from https://github.com/keithito/tacotron """
from . import cleaners
from .symbols import symbols


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


def text_to_sequence(text, symbols, cleaner_names):
  '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
      symbols: list of symbols used to encode the text
      cleaner_names: names of the cleaner functions to run the text through
    Returns:
      List of integers corresponding to the symbols in the text, and the cleaned text
  '''
  _symbol_to_id = {s: i for i, s in enumerate(symbols)}
  sequence = []

  clean_text = _clean_text(text, cleaner_names)
  for symbol in clean_text:
    if symbol not in _symbol_to_id.keys():
      continue
    symbol_id = _symbol_to_id[symbol]
    sequence += [symbol_id]
  return sequence, clean_text


def cleaned_text_to_sequence(cleaned_text):
  '''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      cleaned_text: string to convert to a sequence
    Returns:
      List of integers corresponding to the symbols in the text
  '''
  sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
  return sequence


def sequence_to_text(sequence):
  '''Converts a sequence of IDs back to a string'''
  result = ''
  for symbol_id in sequence:
    s = _id_to_symbol[symbol_id]
    result += s
  return result


def _clean_text(text, cleaner_names):
  for name in cleaner_names:
    cleaner = getattr(cleaners, name)
    if not cleaner:
      raise Exception('Unknown cleaner: %s' % name)
    text = cleaner(text)
  return text
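

def _text_to_sequence_sketch():
  """Illustrative usage, not part of the original file: clean a Chinese
  sentence and map it to symbol ids. 'chinese_cleaners' is one of the cleaner
  functions defined in cleaners.py; the input text is an arbitrary example."""
  seq, cleaned = text_to_sequence('你好123', symbols, ['chinese_cleaners'])
  # cleaned is the bopomofo string produced by the cleaner;
  # seq holds one id per symbol that appears in the symbol table
  assert all(0 <= i < len(symbols) for i in seq)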
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

@@ -0,0 +1,475 @@
""" from https://github.com/keithito/tacotron """

'''
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
'''

import re
from unidecode import unidecode
# import pyopenjtalk
from jamo import h2j, j2hcj
from pypinyin import lazy_pinyin, BOPOMOFO
import jieba, cn2an


# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
  ('mrs', 'misess'),
  ('mr', 'mister'),
  ('dr', 'doctor'),
  ('st', 'saint'),
  ('co', 'company'),
  ('jr', 'junior'),
  ('maj', 'major'),
  ('gen', 'general'),
  ('drs', 'doctors'),
  ('rev', 'reverend'),
  ('lt', 'lieutenant'),
  ('hon', 'honorable'),
  ('sgt', 'sergeant'),
  ('capt', 'captain'),
  ('esq', 'esquire'),
  ('ltd', 'limited'),
  ('col', 'colonel'),
  ('ft', 'fort'),
]]

# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
  ('ㄳ', 'ㄱㅅ'),
  ('ㄵ', 'ㄴㅈ'),
  ('ㄶ', 'ㄴㅎ'),
  ('ㄺ', 'ㄹㄱ'),
  ('ㄻ', 'ㄹㅁ'),
  ('ㄼ', 'ㄹㅂ'),
  ('ㄽ', 'ㄹㅅ'),
  ('ㄾ', 'ㄹㅌ'),
  ('ㄿ', 'ㄹㅍ'),
  ('ㅀ', 'ㄹㅎ'),
  ('ㅄ', 'ㅂㅅ'),
  ('ㅘ', 'ㅗㅏ'),
  ('ㅙ', 'ㅗㅐ'),
  ('ㅚ', 'ㅗㅣ'),
  ('ㅝ', 'ㅜㅓ'),
  ('ㅞ', 'ㅜㅔ'),
  ('ㅟ', 'ㅜㅣ'),
  ('ㅢ', 'ㅡㅣ'),
  ('ㅑ', 'ㅣㅏ'),
  ('ㅒ', 'ㅣㅐ'),
  ('ㅕ', 'ㅣㅓ'),
  ('ㅖ', 'ㅣㅔ'),
  ('ㅛ', 'ㅣㅗ'),
  ('ㅠ', 'ㅣㅜ')
]]

# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
  ('a', '에이'),
  ('b', '비'),
  ('c', '시'),
  ('d', '디'),
  ('e', '이'),
  ('f', '에프'),
  ('g', '지'),
  ('h', '에이치'),
  ('i', '아이'),
  ('j', '제이'),
  ('k', '케이'),
  ('l', '엘'),
  ('m', '엠'),
  ('n', '엔'),
  ('o', '오'),
  ('p', '피'),
  ('q', '큐'),
  ('r', '아르'),
  ('s', '에스'),
  ('t', '티'),
  ('u', '유'),
  ('v', '브이'),
  ('w', '더블유'),
  ('x', '엑스'),
  ('y', '와이'),
  ('z', '제트')
]]

# List of (Latin alphabet, bopomofo) pairs:
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
  ('a', 'ㄟˉ'),
  ('b', 'ㄅㄧˋ'),
  ('c', 'ㄙㄧˉ'),
  ('d', 'ㄉㄧˋ'),
  ('e', 'ㄧˋ'),
  ('f', 'ㄝˊㄈㄨˋ'),
  ('g', 'ㄐㄧˋ'),
  ('h', 'ㄝˇㄑㄩˋ'),
  ('i', 'ㄞˋ'),
  ('j', 'ㄐㄟˋ'),
  ('k', 'ㄎㄟˋ'),
  ('l', 'ㄝˊㄛˋ'),
  ('m', 'ㄝˊㄇㄨˋ'),
  ('n', 'ㄣˉ'),
  ('o', 'ㄡˉ'),
  ('p', 'ㄆㄧˉ'),
  ('q', 'ㄎㄧㄡˉ'),
  ('r', 'ㄚˋ'),
  ('s', 'ㄝˊㄙˋ'),
  ('t', 'ㄊㄧˋ'),
  ('u', 'ㄧㄡˉ'),
  ('v', 'ㄨㄧˉ'),
  ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
  ('x', 'ㄝˉㄎㄨˋㄙˋ'),
  ('y', 'ㄨㄞˋ'),
  ('z', 'ㄗㄟˋ')
]]


# List of (bopomofo, romaji) pairs:
_bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
  ('ㄅㄛ', 'p⁼wo'),
  ('ㄆㄛ', 'pʰwo'),
  ('ㄇㄛ', 'mwo'),
  ('ㄈㄛ', 'fwo'),
  ('ㄅ', 'p⁼'),
  ('ㄆ', 'pʰ'),
  ('ㄇ', 'm'),
  ('ㄈ', 'f'),
  ('ㄉ', 't⁼'),
  ('ㄊ', 'tʰ'),
  ('ㄋ', 'n'),
  ('ㄌ', 'l'),
  ('ㄍ', 'k⁼'),
  ('ㄎ', 'kʰ'),
  ('ㄏ', 'h'),
  ('ㄐ', 'ʧ⁼'),
  ('ㄑ', 'ʧʰ'),
  ('ㄒ', 'ʃ'),
  ('ㄓ', 'ʦ`⁼'),
  ('ㄔ', 'ʦ`ʰ'),
  ('ㄕ', 's`'),
  ('ㄖ', 'ɹ`'),
  ('ㄗ', 'ʦ⁼'),
  ('ㄘ', 'ʦʰ'),
  ('ㄙ', 's'),
  ('ㄚ', 'a'),
  ('ㄛ', 'o'),
  ('ㄜ', 'ə'),
  ('ㄝ', 'e'),
  ('ㄞ', 'ai'),
  ('ㄟ', 'ei'),
  ('ㄠ', 'au'),
  ('ㄡ', 'ou'),
  ('ㄧㄢ', 'yeNN'),
  ('ㄢ', 'aNN'),
  ('ㄧㄣ', 'iNN'),
  ('ㄣ', 'əNN'),
  ('ㄤ', 'aNg'),
  ('ㄧㄥ', 'iNg'),
  ('ㄨㄥ', 'uNg'),
  ('ㄩㄥ', 'yuNg'),
  ('ㄥ', 'əNg'),
  ('ㄦ', 'əɻ'),
  ('ㄧ', 'i'),
  ('ㄨ', 'u'),
  ('ㄩ', 'ɥ'),
  ('ˉ', '→'),
  ('ˊ', '↑'),
  ('ˇ', '↓↑'),
  ('ˋ', '↓'),
  ('˙', ''),
  (',', ','),
  ('。', '.'),
  ('!', '!'),
  ('?', '?'),
  ('—', '-')
]]


def expand_abbreviations(text):
  for regex, replacement in _abbreviations:
    text = re.sub(regex, replacement, text)
  return text


def lowercase(text):
  return text.lower()


def collapse_whitespace(text):
  return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
  return unidecode(text)


def japanese_to_romaji_with_accent(text):
  '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
  # NOTE: this function requires pyopenjtalk, whose import is commented out at
  # the top of this module; re-enable it before calling the Japanese path.
  sentences = re.split(_japanese_marks, text)
  marks = re.findall(_japanese_marks, text)
  text = ''
  for i, sentence in enumerate(sentences):
    if re.match(_japanese_characters, sentence):
      if text != '':
        text += ' '
      labels = pyopenjtalk.extract_fullcontext(sentence)
      for n, label in enumerate(labels):
        phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
        if phoneme not in ['sil', 'pau']:
          text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
        else:
          continue
        n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
        a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
        a2 = int(re.search(r"\+(\d+)\+", label).group(1))
        a3 = int(re.search(r"\+(\d+)/", label).group(1))
        if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
          a2_next = -1
        else:
          a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
        # Accent phrase boundary
        if a3 == 1 and a2_next == 1:
          text += ' '
        # Falling
        elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
          text += '↓'
        # Rising
        elif a2 == 1 and a2_next == 2:
          text += '↑'
    if i < len(marks):
      text += unidecode(marks[i]).replace(' ', '')
  return text


def latin_to_hangul(text):
  for regex, replacement in _latin_to_hangul:
    text = re.sub(regex, replacement, text)
  return text


def divide_hangul(text):
  for regex, replacement in _hangul_divided:
    text = re.sub(regex, replacement, text)
  return text


def hangul_number(num, sino=True):
  '''Reference https://github.com/Kyubyong/g2pK'''
  num = re.sub(',', '', num)

  if num == '0':
    return '영'
  if not sino and num == '20':
    return '스무'

  digits = '123456789'
  names = '일이삼사오육칠팔구'
  digit2name = {d: n for d, n in zip(digits, names)}

  modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
  decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
  digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
  digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}

  spelledout = []
  for i, digit in enumerate(num):
    i = len(num) - i - 1
    if sino:
      if i == 0:
        name = digit2name.get(digit, '')
      elif i == 1:
        name = digit2name.get(digit, '') + '십'
        name = name.replace('일십', '십')
    else:
      if i == 0:
        name = digit2mod.get(digit, '')
      elif i == 1:
        name = digit2dec.get(digit, '')
    if digit == '0':
      if i % 4 == 0:
        last_three = spelledout[-min(3, len(spelledout)):]
        if ''.join(last_three) == '':
          spelledout.append('')
          continue
      else:
        spelledout.append('')
        continue
    if i == 2:
      name = digit2name.get(digit, '') + '백'
      name = name.replace('일백', '백')
    elif i == 3:
      name = digit2name.get(digit, '') + '천'
      name = name.replace('일천', '천')
    elif i == 4:
      name = digit2name.get(digit, '') + '만'
      name = name.replace('일만', '만')
    elif i == 5:
      name = digit2name.get(digit, '') + '십'
      name = name.replace('일십', '십')
    elif i == 6:
      name = digit2name.get(digit, '') + '백'
      name = name.replace('일백', '백')
    elif i == 7:
      name = digit2name.get(digit, '') + '천'
      name = name.replace('일천', '천')
    elif i == 8:
      name = digit2name.get(digit, '') + '억'
    elif i == 9:
      name = digit2name.get(digit, '') + '십'
    elif i == 10:
      name = digit2name.get(digit, '') + '백'
    elif i == 11:
      name = digit2name.get(digit, '') + '천'
    elif i == 12:
      name = digit2name.get(digit, '') + '조'
    elif i == 13:
      name = digit2name.get(digit, '') + '십'
    elif i == 14:
      name = digit2name.get(digit, '') + '백'
    elif i == 15:
      name = digit2name.get(digit, '') + '천'
    spelledout.append(name)
  return ''.join(elem for elem in spelledout)


def number_to_hangul(text):
  '''Reference https://github.com/Kyubyong/g2pK'''
  tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
  for token in tokens:
    num, classifier = token
    if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
      spelledout = hangul_number(num, sino=False)
    else:
      spelledout = hangul_number(num, sino=True)
    text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
  # digit by digit for remaining digits
  digits = '0123456789'
  names = '영일이삼사오육칠팔구'
  for d, n in zip(digits, names):
    text = text.replace(d, n)
  return text


def number_to_chinese(text):
  numbers = re.findall(r'\d+(?:\.?\d+)?', text)
  for number in numbers:
    text = text.replace(number, cn2an.an2cn(number), 1)
  return text


def chinese_to_bopomofo(text):
  text = text.replace('、', ',').replace(';', ',').replace(':', ',')
  words = jieba.lcut(text, cut_all=False)
  text = ''
  for word in words:
    bopomofos = lazy_pinyin(word, BOPOMOFO)
    if not re.search('[\u4e00-\u9fff]', word):
      text += word
      continue
    for i in range(len(bopomofos)):
      if re.match('[\u3105-\u3129]', bopomofos[i][-1]):
        bopomofos[i] += 'ˉ'
    if text != '':
      text += ' '
    text += ''.join(bopomofos)
  return text


def latin_to_bopomofo(text):
  for regex, replacement in _latin_to_bopomofo:
    text = re.sub(regex, replacement, text)
  return text


def bopomofo_to_romaji(text):
  for regex, replacement in _bopomofo_to_romaji:
    text = re.sub(regex, replacement, text)
  return text


def basic_cleaners(text):
  '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
  text = lowercase(text)
  text = collapse_whitespace(text)
  return text


def transliteration_cleaners(text):
  '''Pipeline for non-English text that transliterates to ASCII.'''
  text = convert_to_ascii(text)
  text = lowercase(text)
  text = collapse_whitespace(text)
  return text


def japanese_cleaners(text):
  text = japanese_to_romaji_with_accent(text)
  if re.match('[A-Za-z]', text[-1]):
    text += '.'
  return text


def japanese_cleaners2(text):
  return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')


def korean_cleaners(text):
  '''Pipeline for Korean text'''
  text = latin_to_hangul(text)
  text = number_to_hangul(text)
  text = j2hcj(h2j(text))
  text = divide_hangul(text)
  if re.match('[\u3131-\u3163]', text[-1]):
    text += '.'
  return text


def chinese_cleaners(text):
  '''Pipeline for Chinese text'''
  text = number_to_chinese(text)
  text = chinese_to_bopomofo(text)
  text = latin_to_bopomofo(text)
  if re.match('[ˉˊˇˋ˙]', text[-1]):
    text += '。'
  return text


def zh_ja_mixture_cleaners(text):
  chinese_texts = re.findall(r'\[ZH\].*?\[ZH\]', text)
  japanese_texts = re.findall(r'\[JA\].*?\[JA\]', text)
  for chinese_text in chinese_texts:
    cleaned_text = number_to_chinese(chinese_text[4:-4])
    cleaned_text = chinese_to_bopomofo(cleaned_text)
    cleaned_text = latin_to_bopomofo(cleaned_text)
    cleaned_text = bopomofo_to_romaji(cleaned_text)
    cleaned_text = re.sub('i[aoe]', lambda x: 'y' + x.group(0)[1:], cleaned_text)
    cleaned_text = re.sub('u[aoəe]', lambda x: 'w' + x.group(0)[1:], cleaned_text)
    cleaned_text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑]+)', lambda x: x.group(1) + 'ɹ`' + x.group(2), cleaned_text).replace('ɻ', 'ɹ`')
    cleaned_text = re.sub('([ʦs][⁼ʰ]?)([→↓↑]+)', lambda x: x.group(1) + 'ɹ' + x.group(2), cleaned_text)
    text = text.replace(chinese_text, cleaned_text + ' ', 1)
  for japanese_text in japanese_texts:
    cleaned_text = japanese_to_romaji_with_accent(japanese_text[4:-4]).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…')
    text = text.replace(japanese_text, cleaned_text + ' ', 1)
  text = text[:-1]
  if re.match('[A-Za-zɯɹəɥ→↓↑]', text[-1]):
    text += '.'
  return text
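

def _zh_ja_mixture_sketch():
  """Illustrative usage, not part of the original file: mixed-language input
  marks each span with [ZH]...[ZH] or [JA]...[JA] tags, and the cleaner maps
  every span to the shared romanized symbol set. The input below is an
  arbitrary example; the Japanese branch additionally needs pyopenjtalk,
  whose import is commented out at the top of this module."""
  romanized = zh_ja_mixture_cleaners('[ZH]你好[ZH]')
  # the exact romanization depends on the installed jieba/pypinyin versions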

@@ -0,0 +1,39 @@
'''
Defines the set of symbols used in text input to the model.
'''

'''# japanese_cleaners
_pad = '_'
_punctuation = ',.!?-'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
'''

'''# japanese_cleaners2
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
'''

'''# korean_cleaners
_pad = '_'
_punctuation = ',.!?…~'
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
'''

'''# chinese_cleaners
_pad = '_'
_punctuation = ',。!?—…'
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
'''

# zh_ja_mixture_cleaners
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '


# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters)

# Special symbol ids
SPACE_ID = symbols.index(" ")
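

def _symbol_id_roundtrip_sketch():
  """Illustrative check, not part of the original file: symbol ids are just
  positions in the exported list, so indexing inverts the mapping."""
  sym = symbols[SPACE_ID]
  assert sym == ' ' and symbols.index(sym) == SPACE_ID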

@@ -0,0 +1,193 @@
import torch
from torch.nn import functional as F

import numpy as np


DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3


def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):

  if tails is None:
    spline_fn = rational_quadratic_spline
    spline_kwargs = {}
  else:
    spline_fn = unconstrained_rational_quadratic_spline
    spline_kwargs = {
      'tails': tails,
      'tail_bound': tail_bound
    }

  outputs, logabsdet = spline_fn(
      inputs=inputs,
      unnormalized_widths=unnormalized_widths,
      unnormalized_heights=unnormalized_heights,
      unnormalized_derivatives=unnormalized_derivatives,
      inverse=inverse,
      min_bin_width=min_bin_width,
      min_bin_height=min_bin_height,
      min_derivative=min_derivative,
      **spline_kwargs
  )
  return outputs, logabsdet


def searchsorted(bin_locations, inputs, eps=1e-6):
  bin_locations[..., -1] += eps
  return torch.sum(
      inputs[..., None] >= bin_locations,
      dim=-1
  ) - 1


def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
  inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
  outside_interval_mask = ~inside_interval_mask

  outputs = torch.zeros_like(inputs)
  logabsdet = torch.zeros_like(inputs)

  if tails == 'linear':
    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    constant = np.log(np.exp(1 - min_derivative) - 1)
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    logabsdet[outside_interval_mask] = 0
  else:
    raise RuntimeError('{} tails are not implemented.'.format(tails))

  outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
      inputs=inputs[inside_interval_mask],
      unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
      unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
      unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
      inverse=inverse,
      left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
      min_bin_width=min_bin_width,
      min_bin_height=min_bin_height,
      min_derivative=min_derivative
  )

  return outputs, logabsdet


def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
  if torch.min(inputs) < left or torch.max(inputs) > right:
    raise ValueError('Input to a transform is not within its domain')

  num_bins = unnormalized_widths.shape[-1]

  if min_bin_width * num_bins > 1.0:
    raise ValueError('Minimal bin width too large for the number of bins')
  if min_bin_height * num_bins > 1.0:
    raise ValueError('Minimal bin height too large for the number of bins')

  widths = F.softmax(unnormalized_widths, dim=-1)
  widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
  cumwidths = torch.cumsum(widths, dim=-1)
  cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
  cumwidths = (right - left) * cumwidths + left
  cumwidths[..., 0] = left
  cumwidths[..., -1] = right
  widths = cumwidths[..., 1:] - cumwidths[..., :-1]

  derivatives = min_derivative + F.softplus(unnormalized_derivatives)

  heights = F.softmax(unnormalized_heights, dim=-1)
  heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
  cumheights = torch.cumsum(heights, dim=-1)
  cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
  cumheights = (top - bottom) * cumheights + bottom
  cumheights[..., 0] = bottom
  cumheights[..., -1] = top
  heights = cumheights[..., 1:] - cumheights[..., :-1]

  if inverse:
    bin_idx = searchsorted(cumheights, inputs)[..., None]
  else:
    bin_idx = searchsorted(cumwidths, inputs)[..., None]

  input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
  input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

  input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
  delta = heights / widths
  input_delta = delta.gather(-1, bin_idx)[..., 0]

  input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
  input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

  input_heights = heights.gather(-1, bin_idx)[..., 0]

  if inverse:
    a = (((inputs - input_cumheights) * (input_derivatives
                                         + input_derivatives_plus_one
                                         - 2 * input_delta)
          + input_heights * (input_delta - input_derivatives)))
    b = (input_heights * input_derivatives
         - (inputs - input_cumheights) * (input_derivatives
                                          + input_derivatives_plus_one
                                          - 2 * input_delta))
    c = - input_delta * (inputs - input_cumheights)

    discriminant = b.pow(2) - 4 * a * c
    assert (discriminant >= 0).all()

    root = (2 * c) / (-b - torch.sqrt(discriminant))
    outputs = root * input_bin_widths + input_cumwidths

    theta_one_minus_theta = root * (1 - root)
    denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                 * theta_one_minus_theta)
    derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
                                                 + 2 * input_delta * theta_one_minus_theta
                                                 + input_derivatives * (1 - root).pow(2))
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    return outputs, -logabsdet
  else:
    theta = (inputs - input_cumwidths) / input_bin_widths
    theta_one_minus_theta = theta * (1 - theta)

    numerator = input_heights * (input_delta * theta.pow(2)
                                 + input_derivatives * theta_one_minus_theta)
    denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                 * theta_one_minus_theta)
    outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                 + 2 * input_delta * theta_one_minus_theta
                                                 + input_derivatives * (1 - theta).pow(2))
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    return outputs, logabsdet
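

def _spline_invertibility_sketch():
  """Illustrative check, not part of the original file: running the spline
  forward and then with inverse=True should recover the inputs, and the two
  elementwise log-determinants should cancel. Shapes are arbitrary."""
  torch.manual_seed(0)
  x = torch.randn(4, 8)
  num_bins = 10
  uw = torch.randn(4, 8, num_bins)
  uh = torch.randn(4, 8, num_bins)
  ud = torch.randn(4, 8, num_bins - 1)  # K - 1 interior derivatives
  y, logdet = piecewise_rational_quadratic_transform(
      x, uw, uh, ud, inverse=False, tails='linear', tail_bound=5.0)
  x_rec, logdet_inv = piecewise_rational_quadratic_transform(
      y, uw, uh, ud, inverse=True, tails='linear', tail_bound=5.0)
  assert torch.allclose(x, x_rec, atol=1e-4)
  assert torch.allclose(logdet + logdet_inv, torch.zeros_like(logdet), atol=1e-4)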
@ -0,0 +1,225 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import librosa
|
||||
import torch
|
||||
|
||||
MATPLOTLIB_FLAG = False
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    saved_state_dict = checkpoint_dict['model']
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    # Copy weights key by key; parameters missing from the checkpoint keep
    # their freshly initialized values instead of aborting the load.
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, 'module'):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration

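A quick, self-contained round trip through load_checkpoint is sketched below; the tiny nn.Linear model and the /tmp path are assumptions of this example, not files in the repository.

def _demo_load_checkpoint():
    # Hypothetical round trip: save a minimal checkpoint dict, then restore it.
    from torch import nn
    m = nn.Linear(4, 4)
    path = '/tmp/ckpt_demo.pth'
    torch.save({'model': m.state_dict(), 'iteration': 1,
                'learning_rate': 2e-4, 'optimizer': {}}, path)
    m2, _, lr, it = load_checkpoint(path, nn.Linear(4, 4))
    assert it == 1 and lr == 2e-4
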
def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")  # headless backend: render to buffers, not windows
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is removed in recent NumPy; frombuffer is the supported
    # way to read the rendered RGB canvas.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data

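Both plotting helpers return an HWC uint8 image, the shape TensorBoard's add_image expects. A minimal smoke test (random data, names local to this sketch, assuming a matplotlib version that still provides tostring_rgb):

def _demo_plot_spectrogram():
    # Hypothetical smoke test: an 80-bin mel-style matrix becomes an RGB image.
    img = plot_spectrogram_to_numpy(np.random.rand(80, 200))
    assert img.ndim == 3 and img.shape[-1] == 3
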
def load_audio_to_torch(full_path, target_sampling_rate):
    # librosa resamples to target_sampling_rate and downmixes to mono.
    audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
    return torch.FloatTensor(audio.astype(np.float32))


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding='utf-8') as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text

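Filelists are delimiter-separated, one utterance per line. A sketch of the expected shape; the file name and its contents are assumptions of this example:

def _demo_load_filepaths_and_text():
    # Hypothetical two-column filelist: audio path | transcript.
    with open('/tmp/filelist_demo.txt', 'w', encoding='utf-8') as f:
        f.write('wavs/001.wav|こんにちは\n')
    rows = load_filepaths_and_text('/tmp/filelist_demo.txt')
    assert rows == [['wavs/001.wav', 'こんにちは']]
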
def get_hparams(init=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams

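A self-contained round trip through get_hparams_from_file; the config fields shown are assumptions of this sketch, chosen to mirror the keys tts_model_init reads further down the diff:

def _demo_hparams_roundtrip():
    import tempfile
    cfg = {"train": {"segment_size": 8192}, "data": {"hop_length": 256}}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(cfg, f)
        path = f.name
    h = get_hparams_from_file(path)
    # Nested dicts become nested HParams, so attribute access works throughout.
    assert h.train.segment_size // h.data.hop_length == 32
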
def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        with open(path, "w") as f:
            f.write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger

class HParams():
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
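HParams mixes dict-style and attribute-style access, recursing into nested dicts. A short sketch with made-up field names:

def _demo_hparams():
    h = HParams(sampling_rate=22050, model={"hidden_channels": 192})
    assert h.sampling_rate == h["sampling_rate"] == 22050
    assert "model" in h and h.model.hidden_channels == 192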
@ -0,0 +1,114 @@
import os
import time
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence


def tts_model_init(model_path='./vits_model', device='cuda'):
    # Module-level variant of TextToSpeech._tts_model_init below.
    hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
    net_g_ms = SynthesizerTrn(
        len(hps_ms.symbols),
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        n_speakers=hps_ms.data.n_speakers,
        **hps_ms.model)
    net_g_ms = net_g_ms.eval().to(device)
    speakers = hps_ms.speakers
    utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
    return hps_ms, net_g_ms, speakers

class TextToSpeech:
    def __init__(self,
                 model_path="./utils/vits_model",
                 device='cuda',
                 RATE=22050,
                 debug=False,
                 ):
        self.debug = debug
        self.RATE = RATE
        self.device = torch.device(device)
        self.limitation = os.getenv("SYSTEM") == "spaces"  # limit text and audio length on Hugging Face Spaces
        self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)

    def _tts_model_init(self, model_path):
        hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
        net_g_ms = SynthesizerTrn(
            len(hps_ms.symbols),
            hps_ms.data.filter_length // 2 + 1,
            hps_ms.train.segment_size // hps_ms.data.hop_length,
            n_speakers=hps_ms.data.n_speakers,
            **hps_ms.model)
        net_g_ms = net_g_ms.eval().to(self.device)
        speakers = hps_ms.speakers
        utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
        if self.debug:
            print("Model loaded.")
        return hps_ms, net_g_ms, speakers

    def _get_text(self, text):
        text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
        if self.hps_ms.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = LongTensor(text_norm)
        return text_norm, clean_text

    def _preprocess_text(self, text, language):
        # 0 = Chinese, 1 = Japanese; anything else is passed through unchanged.
        if language == 0:
            return f"[ZH]{text}[ZH]"
        elif language == 1:
            return f"[JA]{text}[JA]"
        return text

    def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
        start_time = time.time()
        stn_tst, clean_text = self._get_text(text)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
            speaker_id = LongTensor([speaker_id]).to(self.device)
            audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
                                        length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
        if self.debug:
            print(f"Synthesis time: {time.time() - start_time} s")
        return audio

    def synthesize(self, text, tts_info, target_se: Optional[np.ndarray] = None, save_audio=False, return_bytes=True):
        # target_se and save_audio are currently unused; kept for interface compatibility.
        if not len(text):
            return b''
        text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
        if len(text) > 100 and self.limitation:
            # Note: this branch returns (message, None) rather than bytes.
            return f"Input text is too long! {len(text)}>100", None
        text = self._preprocess_text(text, tts_info['language'])
        audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
        if return_bytes:
            audio = self.convert_numpy_to_bytes(audio)
        return audio

    def convert_numpy_to_bytes(self, audio_data):
        if isinstance(audio_data, np.ndarray):
            if audio_data.dtype == np.dtype('float32'):
                # Scale float32 [-1, 1] samples to 16-bit PCM before serializing.
                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
            audio_data = audio_data.tobytes()
            return audio_data
        else:
            raise TypeError("audio_data must be a numpy array")

    def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
        sf.write(file_name, audio, samplerate=sample_rate)
        print(f"VITS Audio saved to {file_name}")
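A hedged usage sketch for TextToSpeech, commented out because it needs the model files under ./utils/vits_model; the noise and length values are illustrative assumptions, and tts_info must carry exactly the five keys read in synthesize:

# tts = TextToSpeech(device='cuda')
# tts_info = {'language': 0, 'speaker_id': 0, 'noise_scale': 0.667,
#             'noise_scale_w': 0.8, 'length_scale': 1.0}
# pcm = tts.synthesize("你好,世界", tts_info)   # 16-bit PCM bytes at tts.RATE
# wav = tts.synthesize("你好,世界", tts_info, return_bytes=False)
# tts.save_audio(wav, tts.RATE)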
@ -0,0 +1,67 @@
import websockets
import hashlib
import base64
import hmac
import json
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
from time import mktime
from config import Config


def generate_xf_asr_url():
    # iFLYTEK streaming ASR (iat) API credentials
    APIKey = Config.XF_ASR.API_KEY
    APISecret = Config.XF_ASR.API_SECRET

    # Sign the request (HMAC-SHA256 over host, date and request line) and
    # build the authenticated websocket URL
    url = 'wss://ws-api.xfyun.cn/v2/iat'
    now = datetime.now()
    date = format_date_time(mktime(now.timetuple()))
    signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
    signature_origin += "date: " + date + "\n"
    signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
    signature_sha = hmac.new(APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                             digestmod=hashlib.sha256).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
    authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
        APIKey, "hmac-sha256", "host date request-line", signature_sha)
    authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
    v = {
        "authorization": authorization,
        "date": date,
        "host": "ws-api.xfyun.cn"
    }
    url = url + '?' + urlencode(v)
    return url

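A small sanity check on the signed URL (requires valid keys in Config; the parse helpers are standard library):

def _demo_signed_url():
    from urllib.parse import urlparse, parse_qs
    q = parse_qs(urlparse(generate_xf_asr_url()).query)
    # The iat endpoint authenticates against exactly these three query fields.
    assert set(q) == {"authorization", "date", "host"}
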
def make_first_frame(buf):
    # status 0: opens the session and carries the common/business parameters.
    first_frame = {"common": {"app_id": Config.XF_ASR.APP_ID}, "business": {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos": 10000},
                   "data": {"status": 0, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(first_frame)


def make_continue_frame(buf):
    continue_frame = {"data": {"status": 1, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(continue_frame)


def make_last_frame(buf=""):
    # status 2: closes the session; an empty payload is allowed so that a
    # timeout can flush the recognizer without new audio.
    last_frame = {"data": {"status": 2, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(last_frame)


def parse_xfasr_recv(message):
    code = message['code']
    if code != 0:
        raise Exception("xfyun ASR error code: " + str(code))
    else:
        data = message['data']['result']['ws']
        result = ""
        for i in data:
            for w in i['cw']:
                result += w['w']
        return result


async def xf_asr_websocket_factory():
    url = generate_xf_asr_url()
    return await websockets.connect(url)
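An end-to-end sketch of the frame protocol, commented out because it needs live credentials; chunk0 and chunk1 are hypothetical names standing in for base64-encoded 16 kHz / 16-bit PCM:

# async def _demo_stream():
#     ws = await xf_asr_websocket_factory()
#     await ws.send(make_first_frame(chunk0))
#     await ws.send(make_continue_frame(chunk1))
#     await ws.send(make_last_frame())   # empty final frame flushes the session
#     text = parse_xfasr_recv(json.loads(await ws.recv()))
#     await ws.close()
#     return text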