1
0
Fork 0
This commit is contained in:
Killua777 2024-06-09 22:54:13 +08:00
commit 4c255ebd59
51 changed files with 3424 additions and 0 deletions

1
app/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import *

41
app/abstract.py Normal file
View File

@ -0,0 +1,41 @@
from abc import ABC, abstractmethod
#------ 抽象 ASR, LLM, TTS 类 ------ #
class ASR(ABC):
@abstractmethod
async def stream_recognize(self, chunk):
pass
class LLM(ABC):
@abstractmethod
def chat(self, assistant, prompt, db):
pass
class TTS(ABC):
@abstractmethod
def synthetize(self, assistant, text):
pass
# --------------------------------- #
# ----------- 抽象服务类 ----------- #
class UserAudioService(ABC):
@abstractmethod
def user_audio_process(self, audio:str, **kwargs) -> str:
pass
class PromptService(ABC):
@abstractmethod
def prompt_process(self, prompt:str, **kwargs) -> str:
pass
class LLMMsgService(ABC):
@abstractmethod
def llm_msg_process(self, llm_chunk:list, **kwargs) -> list:
pass
class TTSAudioService(ABC):
@abstractmethod
def tts_audio_process(self, audio:bytes, **kwargs) -> bytes:
pass
# --------------------------------- #

285
app/concrete.py Normal file
View File

@ -0,0 +1,285 @@
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from .model import Assistant
from .abstract import *
from .public import *
from utils.vits_utils import TextToSpeech
from config import Config
import aiohttp
import asyncio
import struct
import base64
import json
# ----------- 初始化vits ----------- #
vits = TextToSpeech()
# ---------------------------------- #
#------ 具体 ASR, LLM, TTS 类 ------ #
FIRST_FRAME = 1
CONTINUE_FRAME =2
LAST_FRAME =3
class XF_ASR(ASR):
def __init__(self):
self.websocket = None
self.current_message = ""
self.audio = ""
self.status = FIRST_FRAME
self.segment_duration_threshold = 25 #超时时间为25秒
self.segment_start_time = None
async def stream_recognize(self, chunk):
if self.websocket is None: #如果websocket未建立则建立一个新的连接
self.websocket = await xf_asr_websocket_factory()
if self.segment_start_time is None: #如果是第一段,则记录开始时间
self.segment_start_time = asyncio.get_event_loop().time()
if chunk['meta_info']['is_end']:
self.status = LAST_FRAME
audio_data = chunk['audio']
self.audio += audio_data
if self.status == FIRST_FRAME: #发送第一帧
await self.websocket.send(make_first_frame(audio_data))
self.status = CONTINUE_FRAME
elif self.status == CONTINUE_FRAME: #发送中间帧
await self.websocket.send(make_continue_frame(audio_data))
elif self.status == LAST_FRAME: #发送最后一帧
await self.websocket.send(make_last_frame(audio_data))
self.current_message += parse_xfasr_recv(json.loads(await self.websocket.recv()))
if self.current_message == "":
raise AsrResultNoneError()
await self.websocket.close()
print("语音识别结束,用户消息:", self.current_message)
return [{"text":self.current_message, "audio":self.audio}]
current_time = asyncio.get_event_loop().time()
if current_time - self.segment_start_time > self.segment_duration_threshold: #超时,发送最后一帧并重置状态
await self.websocket.send(make_last_frame())
self.current_message += parse_xfasr_recv(await self.websocket.recv())
await self.websocket.close()
self.websocket = await xf_asr_websocket_factory()
self.status = FIRST_FRAME
self.segment_start_time = current_time
return []
class MINIMAX_LLM(LLM):
def __init__(self):
self.token = 0
async def chat(self, assistant, prompt, db):
llm_info = json.loads(assistant.llm_info)
messages = json.loads(assistant.messages)
messages.append({"role":"user","content":prompt})
assistant.messages = json.dumps(messages)
payload = json.dumps({
"model": llm_info['model'],
"stream": True,
"messages": messages,
"max_tokens": llm_info['max_tokens'],
"temperature": llm_info['temperature'],
"top_p": llm_info['top_p'],
})
headers = {
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
'Content-Type': 'application/json'
}
async with aiohttp.ClientSession() as client:
async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:
async for chunk in response.content.iter_any():
try:
chunk_msg = self.__parseChunk(chunk)
msg_frame = {"is_end":False,"code":200,"msg":chunk_msg}
yield msg_frame
except LLMResponseEnd:
msg_frame = {"is_end":True,"msg":""}
if self.token > llm_info['max_tokens'] * 0.8:
msg_frame['msg'] = 'max_token reached'
as_query = db.query(Assistant).filter(Assistant.id == assistant.id).first()
as_query.messages = json.dumps([{'role':'system','content':assistant.system_prompt},{'role':'user','content':prompt}])
db.commit()
assistant.messages = as_query.messages
yield msg_frame
def __parseChunk(self, llm_chunk):
result = ""
data=json.loads(llm_chunk.decode('utf-8')[6:])
if data["object"] == "chat.completion": #如果是结束帧
self.token = data['usage']['total_tokens']
raise LLMResponseEnd()
elif data['object'] == 'chat.completion.chunk':
for choice in data['choices']:
result += choice['delta']['content']
else:
raise UnkownLLMFrame()
return result
class VITS_TTS(TTS):
def __init__(self):
pass
def synthetize(self, assistant, text):
tts_info = json.loads(assistant.tts_info)
return vits.synthesize(text, tts_info)
# --------------------------------- #
# ------ ASR, LLM, TTS 工厂类 ------ #
class ASRFactory:
def create_asr(self,asr_type:str) -> ASR:
if asr_type == 'XF':
return XF_ASR()
class LLMFactory:
def create_llm(self,llm_type:str) -> LLM:
if llm_type == 'MINIMAX':
return MINIMAX_LLM()
class TTSFactory:
def create_tts(self,tts_type:str) -> TTS:
if tts_type == 'VITS':
return VITS_TTS()
# --------------------------------- #
# ----------- 具体服务类 ----------- #
# 从单说话人的asr_results中取出第一个结果作为prompt
class BasicPromptService(PromptService):
def prompt_process(self, prompt:str, **kwargs):
if 'asr_result' in kwargs:
return kwargs['asr_result'][0]['text'], kwargs
else:
raise NoAsrResultsError()
class UserAudioRecordService(UserAudioService):
def user_audio_process(self, audio:str, **kwargs):
audio_decoded = base64.b64decode(audio)
kwargs['recorder'].user_audio += audio_decoded
return audio,kwargs
class TTSAudioRecordService(TTSAudioService):
def tts_audio_process(self, audio:bytes, **kwargs):
kwargs['recorder'].tts_audio += audio
return audio,kwargs
# --------------------------------- #
# ------------ 服务链类 ----------- #
class UserAudioServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:UserAudioService):
self.services.append(service)
def user_audio_process(self, audio, **kwargs):
for service in self.services:
audio, kwargs = service.user_audio_process(audio, **kwargs)
return audio
class PromptServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:PromptService):
self.services.append(service)
def prompt_process(self, asr_results):
kwargs = {"asr_result":asr_results}
prompt = ""
for service in self.services:
prompt, kwargs = service.prompt_process(prompt, **kwargs)
return prompt
class LLMMsgServiceChain:
def __init__(self, ):
self.services = []
self.spliter = SentenceSegmentation()
def add_service(self, service:LLMMsgService):
self.services.append(service)
def llm_msg_process(self, llm_msg):
kwargs = {}
llm_chunks = self.spliter.split(llm_msg)
for service in self.services:
llm_chunks , kwargs = service.llm_msg_process(llm_chunks, **kwargs)
return llm_chunks
class TTSAudioServiceChain:
def __init__(self):
self.services = []
def add_service(self, service:TTSAudioService):
self.services.append(service)
def tts_audio_process(self, audio, **kwargs):
for service in self.services:
audio, kwargs = service.tts_audio_process(audio, **kwargs)
return audio
# --------------------------------- #
class Agent():
def __init__(self, asr_type:str, llm_type:str, tts_type:str):
asrFactory = ASRFactory()
llmFactory = LLMFactory()
ttsFactory = TTSFactory()
self.recorder = None
self.user_audio_service_chain = UserAudioServiceChain() #创建用户音频处理服务链
self.user_audio_service_chain.add_service(UserAudioRecordService()) #添加用户音频记录服务
self.asr = asrFactory.create_asr(asr_type)
self.prompt_service_chain = PromptServiceChain()
self.prompt_service_chain.add_service(BasicPromptService())
self.llm = llmFactory.create_llm(llm_type)
self.llm_chunk_service_chain = LLMMsgServiceChain()
self.tts = ttsFactory.create_tts(tts_type)
self.tts_audio_service_chain = TTSAudioServiceChain()
self.tts_audio_service_chain.add_service(TTSAudioRecordService())
def init_recorder(self,user_id):
self.recorder = Recorder(user_id)
# 对用户输入的音频进行预处理
def user_audio_process(self, audio, db):
return self.user_audio_service_chain.user_audio_process(audio, recorder=self.recorder)
# 进行流式语音识别
async def stream_recognize(self, chunk, db):
return await self.asr.stream_recognize(chunk)
# 进行Prompt加工
def prompt_process(self, asr_results, db):
return self.prompt_service_chain.prompt_process(asr_results)
# 进行大模型调用
async def chat(self, assistant ,prompt, db):
return self.llm.chat(assistant, prompt, db)
# 对大模型的返回进行处理
def llm_msg_process(self, llm_chunk, db):
return self.llm_chunk_service_chain.llm_msg_process(llm_chunk)
# 进行TTS合成
def synthetize(self, assistant, text, db):
return self.tts.synthetize(assistant, text)
# 对合成后的音频进行处理
def tts_audio_process(self, audio, db):
return self.tts_audio_service_chain.tts_audio_process(audio, recorder=self.recorder)
# 编码
def encode(self, text, audio):
text_resp = {"type":"text","code":200,"msg":text}
text_resp_bytes = json.dumps(text_resp,ensure_ascii=False).encode('utf-8')
header = struct.pack('!II',len(text_resp_bytes),len(audio))
final_resp = header + text_resp_bytes + audio
return final_resp
# 保存音频
def save(self):
self.recorder.write()

39
app/model.py Normal file
View File

@ -0,0 +1,39 @@
from sqlalchemy import create_engine, Column, Integer, String, CHAR
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from config import Config
engine = create_engine(Config.SQLITE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
MESSAGE_LENGTH_LIMIT = 2**31-1
class User(Base):
__tablename__ = "user"
id = Column(CHAR(36), primary_key=True)
name = Column(String(32))
email = Column(String(64))
password = Column(String(128))
class Assistant(Base):
__tablename__ = "assistant"
id = Column(CHAR(36), primary_key=True)
user_id = Column(CHAR(36))
name = Column(String(32))
system_prompt = Column(String(512))
messages = Column(String(MESSAGE_LENGTH_LIMIT))
user_info = Column(String(256))
llm_info = Column(String(256))
tts_info = Column(String(256))
token = Column(Integer)
Base.metadata.create_all(bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

79
app/public.py Normal file
View File

@ -0,0 +1,79 @@
from datetime import datetime
import wave
import json
# -------------- 公共类 ------------ #
class AsrResultNoneError(Exception):
def __init__(self, message="Asr Result is None!"):
super().__init__(message)
self.message = message
class NoAsrResultsError(Exception):
def __init__(self, message="No Asr Results!"):
super().__init__(message)
self.message = message
class LLMResponseEnd(Exception):
def __init__(self, message="LLM Response End!"):
super().__init__(message)
self.message = message
class UnkownLLMFrame(Exception):
def __init__(self, message="Unkown LLM Frame!"):
super().__init__(message)
self.message = message
class TokenOutofRangeError(Exception):
def __init__(self, message="Token Out of Range!"):
super().__init__(message)
self.message = message
class SentenceSegmentation():
def __init__(self,):
self.is_first_sentence = True
self.cache = ""
def __sentenceSegmentation(self, llm_frame):
results = []
if llm_frame['is_end']:
if self.cache:
results.append(self.cache)
self.cache = ""
return results
for char in llm_frame['msg']:
self.cache += char
if self.is_first_sentence and char in ',.?!,。?!':
results.append(self.cache)
self.cache = ""
self.is_first_sentence = False
elif char in '。?!':
results.append(self.cache)
self.cache = ""
return results
def split(self, llm_chunk):
return self.__sentenceSegmentation(llm_chunk)
class Recorder:
def __init__(self, user_id):
self.input_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'i.wav'
self.output_wav_path = 'storage/wav/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.wav'
self.out_put_text_path = 'storage/record/'+ datetime.now().strftime('%Y%m%d%H%M%S') + 'U' + user_id + 'o.txt'
self.input_sr = 16000
self.output_sr = 22050
self.user_audio = b''
self.tts_audio = b''
self.input_text = ""
self.output_text = ""
def write(self):
record = {"input_wav":self.input_wav_path,"input_text":self.input_text,"input_sr":self.input_sr,"output_wav":self.output_wav_path,"output_text":self.output_text,"output_sr":self.output_sr}
with wave.open(self.input_wav_path, 'wb') as wav_file:
wav_file.setparams((1, 2, self.input_sr, 0, 'NONE', 'not compressed'))
wav_file.writeframes(self.user_audio)
with wave.open(self.output_wav_path, 'wb') as wav_file:
wav_file.setparams((1, 2, self.output_sr, 0, 'NONE', 'not compressed'))
wav_file.writeframes(self.tts_audio)
with open(self.out_put_text_path, 'w', encoding='utf-8') as file:
file.write(json.dumps(record, ensure_ascii=False))
# ---------------------------------- #

36
app/schemas.py Normal file
View File

@ -0,0 +1,36 @@
from pydantic import BaseModel
class create_assistant_request(BaseModel):
user_id: str
name: str
system_prompt : str
user_info : str
llm_info : str
tts_info : str
class update_assistant_request(BaseModel):
name: str
system_prompt : str
messages: str
user_info : str
llm_info : str
tts_info : str
class create_user_request(BaseModel):
name: str
email: str
password: str
class update_user_request(BaseModel):
name: str
email: str
password: str
class update_assistant_system_prompt_request(BaseModel):
system_prompt:str
class update_assistant_deatil_params_request(BaseModel):
model :str
temperature :float
speaker_id:int
length_scale:float

20
config.py Normal file
View File

@ -0,0 +1,20 @@
class Config:
SQLITE_URL = 'sqlite:///takway.db'
ASR = "XF" #在此处选择语音识别引擎
LLM = "MINIMAX" #在此处选择大模型
TTS = "VITS" #在此处选择语音合成引擎
class UVICORN:
HOST = '0.0.0.0'
PORT = 7878
class XF_ASR:
APP_ID = "f1c121c1" #讯飞语音识别APP_ID
API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0" #讯飞语音识别API_SECRET
API_KEY = "36b316c7977fa534ae1e3bf52157bb92" #讯飞语音识别API_KEY
DOMAIN = "iat"
LANGUAGE = "zh_cn"
ACCENT = "mandarin"
VAD_EOS = 10000
class MINIMAX_LLM:
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"

208
main.py Normal file
View File

@ -0,0 +1,208 @@
from fastapi import FastAPI, Depends, WebSocket, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from config import Config
from app.concrete import Agent, AsrResultNoneError
from app.model import Assistant, User, get_db
from app.schemas import *
import uvicorn
import uuid
import json
# 公共函数 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------
def update_messages(messages, llm_text):
messages = json.loads(messages)
messages.append({"role":"assistant","content":llm_text})
return json.dumps(messages,ensure_ascii=False)
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 创建FastAPI实例
app = FastAPI()
# 增删查改 assiatant----------------------------------------------------------------------------------------------------------------------------------------------------------------
# 创建一个assistant
@app.post("/api/assistants")
async def create_assistant(request: create_assistant_request,db=Depends(get_db)):
id = str(uuid.uuid4())
messages = json.dumps([{"role":"system","content":request.system_prompt}],ensure_ascii=False)
assistant = Assistant(id=id,user_id=request.user_id, name=request.name, system_prompt=request.system_prompt, messages=messages,
user_info=request.user_info, llm_info=request.llm_info, tts_info=request.tts_info, token=0)
db.add(assistant)
db.commit()
return {"code":200,"msg":"success","data":{"id":id}}
# 删除一个assistant
@app.delete("/api/assistants/{id}")
async def delete_assistant(id: str,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
db.delete(assistant)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# 更新一个assistant
@app.put("/api/assistants/{id}")
async def update_assistant(id: str,request: update_assistant_request,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
assistant.name = request.name
assistant.system_prompt = request.system_prompt
assistant.messages = request.messages
assistant.user_info = request.user_info
assistant.llm_info = request.llm_info
assistant.tts_info = request.tts_info
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# 获取一个assistant
@app.get("/api/assistants/{id}")
async def get_assistant(id: str,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
return {"code":200,"msg":"success","data":assistant}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# 获取所有的assistant名称和id
@app.get("/api/assistants")
async def get_all_assistants_name_id(db=Depends(get_db)):
assistants = db.query(Assistant.id, Assistant.name).all()
return {"code":200,"msg":"success","data":[{"id": assistant.id, "name": assistant.name} for assistant in assistants]}
# 重置一个assistant的消息
@app.post("/api/assistants/{id}/reset_msg")
async def reset_assistant_msg(id: str,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
assistant.messages = json.dumps([{"role":"system","content":assistant.system_prompt}],ensure_ascii=False)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# 修改一个assistant的system_prompt
@app.put("/api/assistants/{id}/system_prompt")
async def update_assistant_system_prompt(id: str,request: update_assistant_system_prompt_request,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
assistant.system_prompt = request.system_prompt
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# 更新具体参数
@app.put("/api/assistants/{id}/deatil_params")
async def update_assistant_deatil_params(id: str,request: update_assistant_deatil_params_request,db=Depends(get_db)):
assistant = db.query(Assistant).filter(Assistant.id == id).first()
if assistant:
llm_info = json.loads(assistant.llm_info)
tts_info = json.loads(assistant.tts_info)
llm_info['model'] = request.model
llm_info['temperature'] = request.temperature
tts_info['speaker_id'] = request.speaker_id
tts_info['length_scale'] = request.length_scale
assistant.llm_info = json.dumps(llm_info, ensure_ascii=False)
assistant.tts_info = json.dumps(tts_info, ensure_ascii=False)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
return {"code":404,'msg':"assistant not found","data":{}}
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 用户增删改查接口 ----------------------------------------------------------------------------------------------------------------------------------------------------------------
# 添加用户
@app.post("/api/users")
async def create_user(request: create_user_request,db=Depends(get_db)):
id = str(uuid.uuid4())
user = User(id=id, name=request.name, email=request.email, password=request.password)
db.add(user)
db.commit()
return {"code":200,"msg":"success","data":{"id":id}}
# 删除用户
@app.delete("/api/users/{id}")
async def delete_user(id: str,db=Depends(get_db)):
user = db.query(User).filter(User.id == id).first()
if user:
db.delete(user)
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
raise HTTPException(status_code=404, detail="user not found")
# 获取用户
@app.get("/api/users/{id}")
async def get_user(id: str,db=Depends(get_db)):
user = db.query(User).filter(User.id == id).first()
if user:
return {"code":200,"msg":"success","data":user}
else:
raise HTTPException(status_code=404, detail="user not found")
# 更新用户
@app.put("/api/users/{id}")
async def update_user(id: str,request: update_user_request,db=Depends(get_db)):
user = db.query(User).filter(User.id == id).first()
if user:
user.name = request.name
user.email = request.email
user.password = request.password
db.commit()
return {"code":200,"msg":"success","data":{}}
else:
raise HTTPException(status_code=404, detail="user not found")
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
# 流式聊天websokct接口 ------------------------------------------------------------------------------------------------------------------------------------------------------------
@app.websocket("/api/chat/streaming/temporary")
async def streaming_chat(ws: WebSocket,db=Depends(get_db)):
await ws.accept()
agent = Agent(asr_type=Config.ASR, llm_type=Config.LLM, tts_type=Config.TTS)
assistant = None
asr_results = []
llm_text = ""
try:
while len(asr_results)==0:
chunk = json.loads(await ws.receive_text())
if assistant is None:
assistant = db.query(Assistant).filter(Assistant.id == chunk['meta_info']['session_id']).first()
agent.init_recorder(assistant.user_id)
chunk["audio"] = agent.user_audio_process(chunk["audio"], db)
asr_results = await agent.stream_recognize(chunk, db)
except AsrResultNoneError:
await ws.send_text(json.dumps({"type":"close","code":201,"msg":""}, ensure_ascii=False))
return
prompt = agent.prompt_process(asr_results, db)
agent.recorder.input_text = prompt
llm_frames = await agent.chat(assistant, prompt, db)
async for llm_frame in llm_frames:
resp_msgs = agent.llm_msg_process(llm_frame, db)
for resp_msg in resp_msgs:
llm_text += resp_msg
tts_audio = agent.synthetize(assistant, resp_msg, db)
agent.tts_audio_process(tts_audio, db)
await ws.send_bytes(agent.encode(resp_msg, tts_audio))
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
assistant.messages = update_messages(assistant.messages, llm_text)
db.commit()
agent.recorder.output_text = llm_text
agent.save()
await ws.close()
# --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有源,也可以指定特定源
allow_credentials=True,
allow_methods=["*"], # 允许所有方法
allow_headers=["*"], # 允许所有头
)
# 启动服务
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)

BIN
takway.db Normal file

Binary file not shown.

BIN
test/example_recording.wav Normal file

Binary file not shown.

68
test/test.py Normal file
View File

@ -0,0 +1,68 @@
import json
import base64
from datetime import datetime
import io
from websocket import create_connection
data = {
"text": "",
"audio": "",
"meta_info": {
"session_id":"a36c9bb4-e813-4f0e-9c75-18e049c60f48",
"stream": True,
"voice_synthesize": True,
"is_end": False,
"encoding": "raw"
}
}
def read_pcm_file_in_chunks(chunk_size):
with open('example_recording.wav', 'rb') as pcm_file:
while True:
data = pcm_file.read(chunk_size)
if not data:
break
yield data
def send_audio_chunk(websocket, chunk):
# 将PCM数据进行Base64编码
encoded_data = base64.b64encode(chunk).decode('utf-8')
# 更新data字典中的"audio"键的值为Base64编码后的音频数据
data["audio"] = encoded_data
# 将JSON数据对象转换为JSON字符串
message = json.dumps(data)
# 发送JSON字符串到WebSocket接口
websocket.send(message)
def send_json():
websocket = create_connection('ws://114.214.236.207:7878/api/chat/streaming/temporary')
chunks = read_pcm_file_in_chunks(1024) # 读取PCM文件并生成数据块
for chunk in chunks:
send_audio_chunk(websocket, chunk)
# print("发送数据块:", len(chunk))
import time; time.sleep(0.01)
# threading.Event().wait(0.01) # 等待0.01秒
# 设置data字典中的"is_end"键为True表示音频流结束
data["meta_info"]["is_end"] = True
# 发送最后一个数据块和流结束信号
send_audio_chunk(websocket, b'') # 发送空数据块表示结束
# 等待并打印接收到的数据
print("等待接收:", datetime.now())
audio_bytes = b''
while True:
data_ws = websocket.recv()
try:
message_json = json.loads(data_ws)
print(message_json) # 打印接收到的消息
if message_json["type"] == "close":
break # 如果没有接收到消息,则退出循环
except Exception as e:
audio_bytes += data_ws
print(e)
print("接收完毕:", datetime.now())
websocket.close()
# 启动事件循环
send_json()

1
utils/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import *

2
utils/vits/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .text import *
from .monotonic_align import *

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

302
utils/vits/attentions.py Normal file
View File

@ -0,0 +1,302 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
# import commons
# from modules import LayerNorm
from utils.vits import commons
from utils.vits.modules import LayerNorm
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x

172
utils/vits/commons.py Normal file
View File

@ -0,0 +1,172 @@
import math
import torch
from torch.nn import functional as F
import torch.jit
def script_method(fn, _rcb=None):
return fn
def script(obj, optimize=True, _frames_up=0, _rcb=None):
return obj
torch.jit.script_method = script_method
torch.jit.script = script
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm

535
utils/vits/models.py Normal file
View File

@ -0,0 +1,535 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
# import commons
# import modules
# import attentions
# import monotonic_align
from utils.vits import commons, modules, attentions, monotonic_align
from utils.vits.commons import init_weights, get_padding
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
# from commons import init_weights, get_padding
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
filter_channels = in_channels # it needs to be removed from future version.
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = modules.Log()
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class TextEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x)
else:
xs += self.resblocks[i*self.num_kernels+j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2,3,5,7,11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
n_vocab,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
**kwargs):
super().__init__()
self.n_vocab = n_vocab
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.use_sdp = use_sdp
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
if use_sdp:
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
else:
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
if n_speakers > 1:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
w = attn.sum(2)
if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g)
l_length = l_length / torch.sum(x_mask)
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
if self.use_sdp:
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
else:
logw = self.dp(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = commons.generate_path(w_ceil, attn_mask)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
assert self.n_speakers > 0, "n_speakers have to be larger than 0."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)

390
utils/vits/modules.py Normal file
View File

@ -0,0 +1,390 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm
# import commons
# from commons import init_weights, get_padding
# from transforms import piecewise_rational_quadratic_transform
from utils.vits import commons
from utils.vits.commons import init_weights, get_padding
from utils.vits.transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 0."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers-1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dialted and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
groups=channels, dilation=dilation, padding=padding
))
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size,
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(
x_in,
g_l,
n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:,self.hidden_channels:,:]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels,1))
self.logs = nn.Parameter(torch.zeros(channels,1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1,2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ConvFlow(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins:]
x1, logabsdet = piecewise_rational_quadratic_transform(x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails='linear',
tail_bound=self.tail_bound
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1,2])
if not reverse:
return x, logdet
else:
return x

View File

@ -0,0 +1,20 @@
from numpy import zeros, int32, float32
from torch import from_numpy
from .core import maximum_path_jit
def maximum_path(neg_cent, mask):
""" numba optimized version.
neg_cent: [b, t_t, t_s]
mask: [b, t_t, t_s]
"""
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)

View File

@ -0,0 +1,36 @@
import numba
@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
nopython=True, nogil=True)
def maximum_path_jit(paths, values, t_ys, t_xs):
b = paths.shape[0]
max_neg_val = -1e9
for i in range(int(b)):
path = paths[i]
value = values[i]
t_y = t_ys[i]
t_x = t_xs[i]
v_prev = v_cur = 0.0
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1

19
utils/vits/text/LICENSE Normal file
View File

@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,57 @@
""" from https://github.com/keithito/tacotron """
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
def text_to_sequence(text, symbols, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
for symbol in clean_text:
if symbol not in _symbol_to_id.keys():
continue
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
s = _id_to_symbol[symbol_id]
result += s
return result
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

475
utils/vits/text/cleaners.py Normal file
View File

@ -0,0 +1,475 @@
""" from https://github.com/keithito/tacotron """
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
# import pyopenjtalk
from jamo import h2j, j2hcj
from pypinyin import lazy_pinyin, BOPOMOFO
import jieba, cn2an
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
('', 'ㄱㅅ'),
('', 'ㄴㅈ'),
('', 'ㄴㅎ'),
('', 'ㄹㄱ'),
('', 'ㄹㅁ'),
('', 'ㄹㅂ'),
('', 'ㄹㅅ'),
('', 'ㄹㅌ'),
('', 'ㄹㅍ'),
('', 'ㄹㅎ'),
('', 'ㅂㅅ'),
('', 'ㅗㅏ'),
('', 'ㅗㅐ'),
('', 'ㅗㅣ'),
('', 'ㅜㅓ'),
('', 'ㅜㅔ'),
('', 'ㅜㅣ'),
('', 'ㅡㅣ'),
('', 'ㅣㅏ'),
('', 'ㅣㅐ'),
('', 'ㅣㅓ'),
('', 'ㅣㅔ'),
('', 'ㅣㅗ'),
('', 'ㅣㅜ')
]]
# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', '에이'),
('b', ''),
('c', ''),
('d', ''),
('e', ''),
('f', '에프'),
('g', ''),
('h', '에이치'),
('i', '아이'),
('j', '제이'),
('k', '케이'),
('l', ''),
('m', ''),
('n', ''),
('o', ''),
('p', ''),
('q', ''),
('r', '아르'),
('s', '에스'),
('t', ''),
('u', ''),
('v', '브이'),
('w', '더블유'),
('x', '엑스'),
('y', '와이'),
('z', '제트')
]]
# List of (Latin alphabet, bopomofo) pairs:
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', 'ㄟˉ'),
('b', 'ㄅㄧˋ'),
('c', 'ㄙㄧˉ'),
('d', 'ㄉㄧˋ'),
('e', 'ㄧˋ'),
('f', 'ㄝˊㄈㄨˋ'),
('g', 'ㄐㄧˋ'),
('h', 'ㄝˇㄑㄩˋ'),
('i', 'ㄞˋ'),
('j', 'ㄐㄟˋ'),
('k', 'ㄎㄟˋ'),
('l', 'ㄝˊㄛˋ'),
('m', 'ㄝˊㄇㄨˋ'),
('n', 'ㄣˉ'),
('o', 'ㄡˉ'),
('p', 'ㄆㄧˉ'),
('q', 'ㄎㄧㄡˉ'),
('r', 'ㄚˋ'),
('s', 'ㄝˊㄙˋ'),
('t', 'ㄊㄧˋ'),
('u', 'ㄧㄡˉ'),
('v', 'ㄨㄧˉ'),
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
('x', 'ㄝˉㄎㄨˋㄙˋ'),
('y', 'ㄨㄞˋ'),
('z', 'ㄗㄟˋ')
]]
# List of (bopomofo, romaji) pairs:
_bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('ㄅㄛ', 'p⁼wo'),
('ㄆㄛ', 'pʰwo'),
('ㄇㄛ', 'mwo'),
('ㄈㄛ', 'fwo'),
('', 'p⁼'),
('', ''),
('', 'm'),
('', 'f'),
('', 't⁼'),
('', ''),
('', 'n'),
('', 'l'),
('', 'k⁼'),
('', ''),
('', 'h'),
('', 'ʧ⁼'),
('', 'ʧʰ'),
('', 'ʃ'),
('', 'ʦ`⁼'),
('', 'ʦ`ʰ'),
('', 's`'),
('', 'ɹ`'),
('', 'ʦ⁼'),
('', 'ʦʰ'),
('', 's'),
('', 'a'),
('', 'o'),
('', 'ə'),
('', 'e'),
('', 'ai'),
('', 'ei'),
('', 'au'),
('', 'ou'),
('ㄧㄢ', 'yeNN'),
('', 'aNN'),
('ㄧㄣ', 'iNN'),
('', 'əNN'),
('', 'aNg'),
('ㄧㄥ', 'iNg'),
('ㄨㄥ', 'uNg'),
('ㄩㄥ', 'yuNg'),
('', 'əNg'),
('', 'əɻ'),
('', 'i'),
('', 'u'),
('', 'ɥ'),
('ˉ', ''),
('ˊ', ''),
('ˇ', '↓↑'),
('ˋ', ''),
('˙', ''),
('', ','),
('', '.'),
('', '!'),
('', '?'),
('', '-')
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def japanese_to_romaji_with_accent(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text)
text = ''
for i, sentence in enumerate(sentences):
if re.match(_japanese_characters, sentence):
if text!='':
text+=' '
labels = pyopenjtalk.extract_fullcontext(sentence)
for n, label in enumerate(labels):
phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
if phoneme not in ['sil','pau']:
text += phoneme.replace('ch','ʧ').replace('sh','ʃ').replace('cl','Q')
else:
continue
n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
a3 = int(re.search(r"\+(\d+)/", label).group(1))
if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil','pau']:
a2_next=-1
else:
a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
# Accent phrase boundary
if a3 == 1 and a2_next == 1:
text += ' '
# Falling
elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
text += ''
# Rising
elif a2 == 1 and a2_next == 2:
text += ''
if i<len(marks):
text += unidecode(marks[i]).replace(' ','')
return text
def latin_to_hangul(text):
for regex, replacement in _latin_to_hangul:
text = re.sub(regex, replacement, text)
return text
def divide_hangul(text):
for regex, replacement in _hangul_divided:
text = re.sub(regex, replacement, text)
return text
def hangul_number(num, sino=True):
'''Reference https://github.com/Kyubyong/g2pK'''
num = re.sub(',', '', num)
if num == '0':
return ''
if not sino and num == '20':
return '스무'
digits = '123456789'
names = '일이삼사오육칠팔구'
digit2name = {d: n for d, n in zip(digits, names)}
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
spelledout = []
for i, digit in enumerate(num):
i = len(num) - i - 1
if sino:
if i == 0:
name = digit2name.get(digit, '')
elif i == 1:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
else:
if i == 0:
name = digit2mod.get(digit, '')
elif i == 1:
name = digit2dec.get(digit, '')
if digit == '0':
if i % 4 == 0:
last_three = spelledout[-min(3, len(spelledout)):]
if ''.join(last_three) == '':
spelledout.append('')
continue
else:
spelledout.append('')
continue
if i == 2:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
elif i == 3:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
elif i == 4:
name = digit2name.get(digit, '') + ''
name = name.replace('일만', '')
elif i == 5:
name = digit2name.get(digit, '') + ''
name = name.replace('일십', '')
elif i == 6:
name = digit2name.get(digit, '') + ''
name = name.replace('일백', '')
elif i == 7:
name = digit2name.get(digit, '') + ''
name = name.replace('일천', '')
elif i == 8:
name = digit2name.get(digit, '') + ''
elif i == 9:
name = digit2name.get(digit, '') + ''
elif i == 10:
name = digit2name.get(digit, '') + ''
elif i == 11:
name = digit2name.get(digit, '') + ''
elif i == 12:
name = digit2name.get(digit, '') + ''
elif i == 13:
name = digit2name.get(digit, '') + ''
elif i == 14:
name = digit2name.get(digit, '') + ''
elif i == 15:
name = digit2name.get(digit, '') + ''
spelledout.append(name)
return ''.join(elem for elem in spelledout)
def number_to_hangul(text):
'''Reference https://github.com/Kyubyong/g2pK'''
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
for token in tokens:
num, classifier = token
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
spelledout = hangul_number(num, sino=False)
else:
spelledout = hangul_number(num, sino=True)
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
# digit by digit for remaining digits
digits = '0123456789'
names = '영일이삼사오육칠팔구'
for d, n in zip(digits, names):
text = text.replace(d, n)
return text
def number_to_chinese(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
for number in numbers:
text = text.replace(number, cn2an.an2cn(number),1)
return text
def chinese_to_bopomofo(text):
text=text.replace('','').replace('','').replace('','')
words=jieba.lcut(text,cut_all=False)
text=''
for word in words:
bopomofos=lazy_pinyin(word,BOPOMOFO)
if not re.search('[\u4e00-\u9fff]',word):
text+=word
continue
for i in range(len(bopomofos)):
if re.match('[\u3105-\u3129]',bopomofos[i][-1]):
bopomofos[i]+='ˉ'
if text!='':
text+=' '
text+=''.join(bopomofos)
return text
def latin_to_bopomofo(text):
for regex, replacement in _latin_to_bopomofo:
text = re.sub(regex, replacement, text)
return text
def bopomofo_to_romaji(text):
for regex, replacement in _bopomofo_to_romaji:
text = re.sub(regex, replacement, text)
return text
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def japanese_cleaners(text):
text=japanese_to_romaji_with_accent(text)
if re.match('[A-Za-z]',text[-1]):
text += '.'
return text
def japanese_cleaners2(text):
return japanese_cleaners(text).replace('ts','ʦ').replace('...','')
def korean_cleaners(text):
'''Pipeline for Korean text'''
text = latin_to_hangul(text)
text = number_to_hangul(text)
text = j2hcj(h2j(text))
text = divide_hangul(text)
if re.match('[\u3131-\u3163]',text[-1]):
text += '.'
return text
def chinese_cleaners(text):
'''Pipeline for Chinese text'''
text=number_to_chinese(text)
text=chinese_to_bopomofo(text)
text=latin_to_bopomofo(text)
if re.match('[ˉˊˇˋ˙]',text[-1]):
text += ''
return text
def zh_ja_mixture_cleaners(text):
chinese_texts=re.findall(r'\[ZH\].*?\[ZH\]',text)
japanese_texts=re.findall(r'\[JA\].*?\[JA\]',text)
for chinese_text in chinese_texts:
cleaned_text=number_to_chinese(chinese_text[4:-4])
cleaned_text=chinese_to_bopomofo(cleaned_text)
cleaned_text=latin_to_bopomofo(cleaned_text)
cleaned_text=bopomofo_to_romaji(cleaned_text)
cleaned_text=re.sub('i[aoe]',lambda x:'y'+x.group(0)[1:],cleaned_text)
cleaned_text=re.sub('u[aoəe]',lambda x:'w'+x.group(0)[1:],cleaned_text)
cleaned_text=re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ`'+x.group(2),cleaned_text).replace('ɻ','ɹ`')
cleaned_text=re.sub('([ʦs][⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ'+x.group(2),cleaned_text)
text = text.replace(chinese_text,cleaned_text+' ',1)
for japanese_text in japanese_texts:
cleaned_text=japanese_to_romaji_with_accent(japanese_text[4:-4]).replace('ts','ʦ').replace('u','ɯ').replace('...','')
text = text.replace(japanese_text,cleaned_text+' ',1)
text=text[:-1]
if re.match('[A-Za-zɯɹəɥ→↓↑]',text[-1]):
text += '.'
return text

View File

@ -0,0 +1,39 @@
'''
Defines the set of symbols used in text input to the model.
'''
'''# japanese_cleaners
_pad = '_'
_punctuation = ',.!?-'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
'''
'''# japanese_cleaners2
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
'''
'''# korean_cleaners
_pad = '_'
_punctuation = ',.!?…~'
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
'''
'''# chinese_cleaners
_pad = '_'
_punctuation = ',。!?—…'
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
'''
# zh_ja_mixture_cleaners
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters)
# Special symbol ids
SPACE_ID = symbols.index(" ")

193
utils/vits/transforms.py Normal file
View File

@ -0,0 +1,193 @@
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {
'tails': tails,
'tail_bound': tail_bound
}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
) - 1
def unconstrained_rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails='linear',
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == 'linear':
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError('{} tails are not implemented.'.format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative
)
return outputs, logabsdet
def rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0., right=1., bottom=0., top=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain')
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins')
if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins')
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (((inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta)
+ input_heights * (input_delta - input_derivatives)))
b = (input_heights * input_derivatives
- (inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta))
c = - input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2)
+ input_derivatives * theta_one_minus_theta)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

225
utils/vits/utils.py Normal file
View File

@ -0,0 +1,225 @@
import os
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
import librosa
import torch
MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
learning_rate = checkpoint_dict['learning_rate']
if optimizer is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
if hasattr(model, 'module'):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict= {}
for k, v in state_dict.items():
try:
new_state_dict[k] = saved_state_dict[k]
except:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
logger.info("Loaded checkpoint '{}' (iteration {})" .format(
checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10,2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_audio_to_torch(full_path, target_sampling_rate):
audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
return torch.FloatTensor(audio.astype(np.float32))
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
help='JSON file for configuration')
parser.add_argument('-m', '--model', type=str, required=True,
help='Model name')
args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
config_path = args.config
config_save_path = os.path.join(model_dir, "config.json")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams =HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams =HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
))
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]))
else:
open(path, "w").write(cur_hash)
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir):
os.makedirs(model_dir)
h = logging.FileHandler(os.path.join(model_dir, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
class HParams():
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()

114
utils/vits_utils.py Normal file
View File

@ -0,0 +1,114 @@
import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence
def tts_model_init(model_path='./vits_model', device='cuda'):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
return hps_ms, net_g_ms, speakers
class TextToSpeech:
def __init__(self,
model_path="./utils/vits_model",
device='cuda',
RATE=22050,
debug=False,
):
self.debug = debug
self.RATE = RATE
self.device = torch.device(device)
self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
def _tts_model_init(self, model_path):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(self.device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
if self.debug:
print("Model loaded.")
return hps_ms, net_g_ms, speakers
def _get_text(self, text):
text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
if self.hps_ms.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = LongTensor(text_norm)
return text_norm, clean_text
def _preprocess_text(self, text, language):
if language == 0:
return f"[ZH]{text}[ZH]"
elif language == 1:
return f"[JA]{text}[JA]"
return text
def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
import time
start_time = time.time()
stn_tst, clean_text = self._get_text(text)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(self.device)
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
speaker_id = LongTensor([speaker_id]).to(self.device)
audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
if self.debug:
print(f"Synthesis time: {time.time() - start_time} s")
return audio
def synthesize(self, text, tts_info,target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
if not len(text):
return b''
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, tts_info['language'])
audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
if return_bytes:
audio = self.convert_numpy_to_bytes(audio)
return audio
def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
sf.write(file_name, audio, samplerate=sample_rate)
print(f"VITS Audio saved to {file_name}")

67
utils/xf_asr_utils.py Normal file
View File

@ -0,0 +1,67 @@
import websockets
import datetime
import hashlib
import base64
import hmac
import json
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
from time import mktime
from config import Config
def generate_xf_asr_url():
#设置讯飞流式听写API相关参数
APIKey = Config.XF_ASR.API_KEY
APISecret = Config.XF_ASR.API_SECRET
#鉴权并创建websocket_url
url = 'wss://ws-api.xfyun.cn/v2/iat'
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
signature_sha = hmac.new(APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": "ws-api.xfyun.cn"
}
url = url + '?' + urlencode(v)
return url
def make_first_frame(buf):
first_frame = {"common" : {"app_id":Config.XF_ASR.APP_ID},"business" : {"domain":"iat","language":"zh_cn","accent":"mandarin","vad_eos":10000},
"data":{"status":0,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(first_frame)
def make_continue_frame(buf):
continue_frame = {"data":{"status":1,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(continue_frame)
def make_last_frame(buf):
last_frame = {"data":{"status":2,"format":"audio/L16;rate=16000","audio":buf,"encoding":"raw"}}
return json.dumps(last_frame)
def parse_xfasr_recv(message):
code = message['code']
if code!=0:
raise Exception("讯飞ASR错误码"+str(code))
else:
data = message['data']['result']['ws']
result = ""
for i in data:
for w in i['cw']:
result += w['w']
return result
async def xf_asr_websocket_factory():
url = generate_xf_asr_url()
return await websockets.connect(url)