
Repository initialization

Commit 83cbe007ba by Killua777, 2024-05-01 17:18:30 +08:00
60 changed files with 5190 additions and 0 deletions

.gitignore (vendored, new file)

@@ -0,0 +1,9 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
/app.log
app.log
/utils/tts/vits_model/
vits_model

Dockerfile (new file, contents not shown)

README.md (new file)

@@ -0,0 +1,58 @@
# TakwayAI Backend
### Project architecture
The AI backend uses the FastAPI framework with a layered architecture:
```
TakwayAI/
├── app/
│   ├── __init__.py        # Application initialization and configuration
│   ├── main.py            # Application entry point
│   ├── models/            # Data model definitions
│   │   ├── __init__.py
│   │   ├── models.py      # Database table definitions
│   ├── schemas/           # Request and response models
│   │   ├── __init__.py
│   │   ├── user.py        # User-related schemas
│   │   └── ...            # Other schemas
│   ├── controllers/       # Business-logic controllers
│   │   ├── __init__.py
│   │   ├── user.py        # User-related controllers
│   │   └── ...            # Other controllers
│   ├── routes/            # Routes and view functions
│   │   ├── __init__.py
│   │   ├── user.py        # User-related routes
│   │   └── ...            # Other routes
│   ├── dependencies/      # Dependency-injection helpers
│   │   ├── __init__.py
│   │   ├── database.py    # Database dependency
│   │   └── ...            # Other dependencies
│   └── exceptions/        # Custom exception handling
│       ├── __init__.py
│       └── ...            # Custom exception classes
├── tests/                 # Test code
│   ├── __init__.py
│   └── ...
├── config/                # Configuration files
│   ├── __init__.py
│   ├── production.py      # Production configuration
│   ├── development.py     # Development configuration
│   └── ...                # Other environment configurations
├── utils/                 # Utility functions and helper classes
│   ├── __init__.py
│   ├── stt/               # Speech-to-text utilities
│   ├── tts/               # Text-to-speech utilities
│   └── ...                # Other utilities
│
├── main.py                # Startup script
├── app.log                # Log file
├── Dockerfile             # Dockerfile for building the Docker image
├── requirements.txt       # Project dependency list
└── README.md              # Project description
```
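The layers are wired together through FastAPI's dependency injection: a route receives a validated schema object, obtains a database session from `app/dependencies/database.py`, and delegates to a controller. A minimal sketch of that flow, using hypothetical `Item` names rather than this repo's actual models:

```python
# Sketch of the route -> controller -> dependency flow used in this repo.
# ItemCreateRequest/ItemCreateResponse are hypothetical stand-ins for the
# schemas defined under app/schemas/.
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ..dependencies.database import get_db  # yields a session, closes it after the request

router = APIRouter()

class ItemCreateRequest(BaseModel):
    name: str

class ItemCreateResponse(BaseModel):
    status: str
    message: str

async def create_item_handler(item: ItemCreateRequest, db: Session):
    # controller layer: business logic and persistence live here
    return ItemCreateResponse(status="success", message=f"created {item.name}")

@router.post("/items", response_model=ItemCreateResponse)
async def create_item(item: ItemCreateRequest, db: Session = Depends(get_db)):
    # route layer: validation and dependency wiring only
    return await create_item_handler(item, db)
```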

app/__init__.py (new file)

@@ -0,0 +1,53 @@
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine
from .models import Base
from .routes.character import router as character_router
from .routes.user import router as user_router
from .routes.session import router as session_router
from .routes.chat import router as chat_router
from .dependencies.logger import get_logger
from config import get_config

#---------------------- Initialize the logger ------------------------
logger = get_logger()
logger.info("日志初始化完成")
#-----------------------------------------------------
#---------------------- Load the configuration -----------------------
Config = get_config()
logger.info("配置获取完成")
#-----------------------------------------------------
#-------------------- Initialize the database ------------------------
engine = create_engine(Config.SQLALCHEMY_DATABASE_URI)
Base.metadata.create_all(bind=engine)
logger.info("数据库初始化完成")
#------------------------------------------------------
#-------------------- Create the FastAPI instance --------------------
app = FastAPI()
logger.info("FastAPI实例创建完成")
#------------------------------------------------------
#--------------------- Register the routers --------------------------
app.include_router(character_router)
app.include_router(user_router)
app.include_router(session_router)
app.include_router(chat_router)
logger.info("路由初始化完成")
#-------------------------------------------------------
#------------------- Configure CORS middleware ------------------------
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # allow all origins; specific origins can be listed instead
    allow_credentials=True,
    allow_methods=["*"],      # allow all HTTP methods
    allow_headers=["*"],      # allow all headers
)
#-------------------------------------------------------

app/controllers/__init__.py (new file)

@@ -0,0 +1 @@
from . import *

app/controllers/character.py (new file)

@@ -0,0 +1,89 @@
from ..schemas.character import *
from ..dependencies.logger import get_logger
from sqlalchemy.orm import Session
from ..models import Character
from datetime import datetime
from fastapi import HTTPException, status

# Get the logger via dependency injection
logger = get_logger()

# Create a new character
async def create_character_handler(character: CharacterCreateRequest, db: Session):
    new_character = Character(voice_id=character.voice_id, avatar_id=character.avatar_id, background_ids=character.background_ids, name=character.name,
                              wakeup_words=character.wakeup_words, world_scenario=character.world_scenario, description=character.description, emojis=character.emojis, dialogues=character.dialogues)
    try:
        db.add(new_character)
        db.commit()
        db.refresh(new_character)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    character_create_data = CharacterCreateData(character_id=new_character.id, createdAt=datetime.now().isoformat())
    return CharacterCreateResponse(status="success", message="创建角色成功", data=character_create_data)

# Update a character
async def update_character_handler(character_id: int, character: CharacterUpdateRequest, db: Session):
    existing_character = db.query(Character).filter(Character.id == character_id).first()
    if not existing_character:
        # raise (not return) so FastAPI turns this into a 404 response
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
    existing_character.voice_id = character.voice_id
    existing_character.avatar_id = character.avatar_id
    existing_character.background_ids = character.background_ids
    existing_character.name = character.name
    existing_character.wakeup_words = character.wakeup_words
    existing_character.world_scenario = character.world_scenario
    existing_character.description = character.description
    existing_character.emojis = character.emojis
    existing_character.dialogues = character.dialogues
    try:
        db.add(existing_character)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    character_update_data = CharacterUpdateData(updatedAt=datetime.now().isoformat())
    return CharacterUpdateResponse(status="success", message="更新角色成功", data=character_update_data)

# Query a character
async def get_character_handler(character_id: int, db: Session):
    try:
        existing_character = db.query(Character).filter(Character.id == character_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if not existing_character:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
    character_query_data = CharacterQueryData(voice_id=existing_character.voice_id, avatar_id=existing_character.avatar_id, background_ids=existing_character.background_ids, name=existing_character.name,
                                              wakeup_words=existing_character.wakeup_words, world_scenario=existing_character.world_scenario, description=existing_character.description, emojis=existing_character.emojis, dialogues=existing_character.dialogues)
    return CharacterQueryResponse(status="success", message="查询角色成功", data=character_query_data)

# Delete a character
async def delete_character_handler(character_id: int, db: Session):
    try:
        existing_character = db.query(Character).filter(Character.id == character_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if not existing_character:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
    try:
        db.delete(existing_character)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    character_delete_data = CharacterDeleteData(id=character_id, deletedAt=datetime.now().isoformat())
    return CharacterDeleteResponse(status="success", message="删除角色成功", data=character_delete_data)

app/controllers/chat.py (new file)

@@ -0,0 +1,613 @@
from ..schemas.chat import *
from ..dependencies.logger import get_logger
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import uuid
import json
import requests
import asyncio

# Get the logger via dependency injection
logger = get_logger()

# -------------------- Initialize local ASR -----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -------------------------------------------------------
# -------------------- Initialize local VITS ----------------------
from utils.tts.vits_utils import TextToSpeech

tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -------------------------------------------------------

# Get the configuration
Config = get_config()

# ---------------------- Utility functions -------------------------
# Fetch session content, preferring the Redis cache over the database
def get_session_content(session_id, redis, db):
    session_content_str = ""
    if redis.exists(session_id):
        session_content_str = redis.get(session_id)
    else:
        session_db = db.query(Session).filter(Session.id == session_id).first()
        if not session_db:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
        session_content_str = session_db.content
    return json.loads(session_content_str)

# Parse one chunk of the LLM's streaming (SSE) response; [6:] strips the "data: " prefix
def parseChunkDelta(chunk):
    decoded_data = chunk.decode('utf-8')
    parsed_data = json.loads(decoded_data[6:])
    if 'delta' in parsed_data['choices'][0]:
        delta_content = parsed_data['choices'][0]['delta']
        return delta_content['content']
    else:
        return ""

# Split streamed text into sentences at punctuation; the first sentence is cut eagerly
def split_string_with_punctuation(current_sentence, text, is_first):
    result = []
    for char in text:
        current_sentence += char
        if is_first and char in ',.?!,。?!':
            result.append(current_sentence)
            current_sentence = ''
            is_first = False
        elif char in '。?!':
            result.append(current_sentence)
            current_sentence = ''
    return result, current_sentence, is_first
#--------------------------------------------------------

# Create a new chat
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    # Create a new UserCharacter record
    new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
    try:
        db.add(new_chat)
        db.commit()
        db.refresh(new_chat)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    # Look up the character for this chat and build the system prompt
    db_character = db.query(Character).filter(Character.id == chat.character_id).first()
    db_user = db.query(User).filter(User.id == chat.user_id).first()
    system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n
    {"角色的常用问候语: " + db_character.wakeup_words}\n你需要用简单通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定\n 与你聊天的对象信息如下{db_user.persona}"""
    # Create a new Session record
    session_id = str(uuid.uuid4())
    user_id = chat.user_id
    messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False)
    tts_info = {
        "language": 0,
        "speaker_id": db_character.voice_id,
        "noise_scale": 0.1,
        "noise_scale_w": 0.668,
        "length_scale": 1.2
    }
    llm_info = {
        "model": "abab5.5-chat",
        "temperature": 1,
        "top_p": 0.9,
    }
    # Serialize the tts and llm info as JSON strings
    tts_info_str = json.dumps(tts_info, ensure_ascii=False)
    llm_info_str = json.dumps(llm_info, ensure_ascii=False)
    user_info_str = db_user.persona
    token = 0
    content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
               "llm_info": llm_info_str, "token": token}
    new_session = Session(id=session_id, user_character_id=new_chat.id, content=json.dumps(content, ensure_ascii=False),
                          last_activity=datetime.now(), is_permanent=False)
    # Persist the Session record
    db.add(new_session)
    db.commit()
    db.refresh(new_session)
    redis.set(session_id, json.dumps(content, ensure_ascii=False))
    chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
    return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data)

# Delete a chat
async def delete_chat_handler(user_character_id, db, redis):
    # Look up the chat record
    user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first()
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found")
    session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first()
    try:
        redis.delete(session_record.id)
    except Exception as e:
        logger.error(f"删除Redis中Session记录时发生错误: {str(e)}")
    try:
        db.delete(session_record)
        db.delete(user_character_record)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat())
    return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data)

# Non-streaming chat (not implemented yet)
async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis):
    pass

#--------------------------------------- Single-turn streaming chat ---------------------------------------------
# Handle user input from the websocket
async def sct_user_input_handler(ws, user_input_q, llm_input_q, future_session_id, future_response_type, user_input_finish_event):
    logger.debug("用户输入处理函数启动")
    is_future_done = False
    try:
        while not user_input_finish_event.is_set():
            sct_data_json = json.loads(await ws.receive_text())
            if not is_future_done:
                future_session_id.set_result(sct_data_json['meta_info']['session_id'])
                if sct_data_json['meta_info']['voice_synthesize']:
                    future_response_type.set_result(RESPONSE_AUDIO)
                else:
                    future_response_type.set_result(RESPONSE_TEXT)
                is_future_done = True
            if sct_data_json['text']:
                await llm_input_q.put(sct_data_json['text'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            if sct_data_json['meta_info']['is_end']:
                await user_input_q.put(sct_data_json['audio'])
                if not user_input_finish_event.is_set():
                    user_input_finish_event.set()
                break
            await user_input_q.put(sct_data_json['audio'])
    except KeyError as ke:
        if sct_data_json['state'] == 1 and sct_data_json['method'] == 'heartbeat':
            logger.debug("收到心跳包")

# Speech recognition
async def sct_asr_handler(user_input_q, llm_input_q, user_input_finish_event):
    logger.debug("语音识别函数启动")
    current_message = ""
    while not (user_input_finish_event.is_set() and user_input_q.empty()):
        audio_data = await user_input_q.get()
        asr_result = asr.streaming_recognize(audio_data)
        current_message += ''.join(asr_result['text'])
    asr_result = asr.streaming_recognize(b'', is_end=True)
    current_message += ''.join(asr_result['text'])
    await llm_input_q.put(current_message)
    logger.debug(f"接收到用户消息: {current_message}")

# Call the LLM
async def sct_llm_handler(session_id, llm_info, db, redis, llm_input_q, llm_response_q, llm_response_finish_event):
    logger.debug("llm调用函数启动")
    session_content = get_session_content(session_id, redis, db)
    messages = json.loads(session_content["messages"])
    current_message = await llm_input_q.get()
    messages.append({'role': 'user', "content": current_message})
    payload = json.dumps({
        "model": llm_info["model"],
        "stream": True,
        "messages": messages,
        "max_tokens": 10000,
        "temperature": llm_info["temperature"],
        "top_p": llm_info["top_p"]
    })
    headers = {
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
    if response.status_code == 200:
        for chunk in response.iter_lines():
            if chunk:
                chunk_data = parseChunkDelta(chunk)
                await llm_response_q.put(chunk_data)
    llm_response_finish_event.set()

# Split the LLM's streamed response into sentences
async def sct_llm_response_handler(session_id, redis, db, llm_response_q, split_result_q, llm_response_finish_event):
    logger.debug("llm返回处理函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (llm_response_finish_event.is_set() and llm_response_q.empty()):
        llm_chunk = await llm_response_q.get()
        llm_response += llm_chunk
        sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_chunk, is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
    session_content = get_session_content(session_id, redis, db)
    messages = json.loads(session_content["messages"])
    messages.append({'role': 'assistant', "content": llm_response})
    session_content["messages"] = json.dumps(messages, ensure_ascii=False)  # update the conversation
    redis.set(session_id, json.dumps(session_content, ensure_ascii=False))  # update the session
    logger.debug(f"llm返回结果: {llm_response}")

# Send text back and synthesize speech
async def sct_response_handler(ws, tts_info, response_type, split_result_q, llm_response_finish_event, chat_finish_event):
    logger.debug("返回处理函数启动")
    while not (llm_response_finish_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        if response_type == RESPONSE_TEXT:
            response_message = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        elif response_type == RESPONSE_AUDIO:
            sr, audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            response_message = {"type": "text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio)
            await ws.send_text(json.dumps(response_message, ensure_ascii=False))
        logger.debug(f"websocket返回: {sentence}")
    chat_finish_event.set()

async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
    logger.debug("streaming chat temporary websocket 连接建立")
    user_input_q = asyncio.Queue()    # raw user input (audio frames)
    llm_input_q = asyncio.Queue()     # text destined for the LLM
    llm_response_q = asyncio.Queue()  # streamed LLM output
    split_result_q = asyncio.Queue()  # sentences destined for TTS

    user_input_finish_event = asyncio.Event()
    llm_response_finish_event = asyncio.Event()
    chat_finish_event = asyncio.Event()
    future_session_id = asyncio.Future()
    future_response_type = asyncio.Future()

    asyncio.create_task(sct_user_input_handler(ws, user_input_q, llm_input_q, future_session_id, future_response_type, user_input_finish_event))
    asyncio.create_task(sct_asr_handler(user_input_q, llm_input_q, user_input_finish_event))
    session_id = await future_session_id        # wait for the session_id
    response_type = await future_response_type  # wait for the response type
    tts_info = json.loads(get_session_content(session_id, redis, db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id, redis, db)["llm_info"])
    asyncio.create_task(sct_llm_handler(session_id, llm_info, db, redis, llm_input_q, llm_response_q, llm_response_finish_event))
    asyncio.create_task(sct_llm_response_handler(session_id, redis, db, llm_response_q, split_result_q, llm_response_finish_event))
    asyncio.create_task(sct_response_handler(ws, tts_info, response_type, split_result_q, llm_response_finish_event, chat_finish_event))

    while not chat_finish_event.is_set():
        await asyncio.sleep(1)
    await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
    await ws.close()
    logger.debug("streaming chat temporary websocket 连接断开")
#---------------------------------------------------------------------------------------------------

# Persistent streaming chat
async def streaming_chat_lasting_handler(ws, db, redis):
    print("Websocket连接成功")
    while True:
        try:
            print("等待接受")
            data = await asyncio.wait_for(ws.receive_text(), timeout=60)
            data_json = json.loads(data)
            if data_json["is_close"]:
                close_message = {"type": "close", "code": 200, "msg": ""}
                await ws.send_text(json.dumps(close_message, ensure_ascii=False))
                print("连接关闭")
                await asyncio.sleep(0.5)
                await ws.close()
                return
        except asyncio.TimeoutError:
            print("连接超时")
            await ws.close()
            return

        current_message = ""           # accumulated user message
        response_type = RESPONSE_TEXT  # response type for this turn
        session_id = ""
        if Config.STRAM_CHAT.ASR == "LOCAL":
            try:
                while True:
                    if data_json["text"]:  # non-empty text means this turn is text input
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # voice_synthesize decides the response type
                        session_id = data_json["meta_info"]["session_id"]
                        current_message = data_json['text']
                        break
                    if not data_json['meta_info']['is_end']:  # more audio frames to come
                        asr_result = asr.streaming_recognize(data_json["audio"])
                        current_message += ''.join(asr_result['text'])
                    else:  # final frame
                        asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
                        session_id = data_json["meta_info"]["session_id"]
                        current_message += ''.join(asr_result['text'])
                        if data_json["meta_info"]["voice_synthesize"]:
                            response_type = RESPONSE_AUDIO  # voice_synthesize decides the response type
                        break
                    data_json = json.loads(await ws.receive_text())
            except Exception as e:
                error_info = f"接收用户消息错误: {str(e)}"
                error_message = {"type": "error", "code": "500", "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
        elif Config.STRAM_CHAT.ASR == "REMOTE":
            error_info = "远程ASR服务暂未开通"
            error_message = {"type": "error", "code": "500", "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        print(f"接收到用户消息: {current_message}")

        # Look up the Session record
        session_content_str = ""
        if redis.exists(session_id):
            session_content_str = redis.get(session_id)
        else:
            session_db = db.query(Session).filter(Session.id == session_id).first()
            if not session_db:
                error_info = "未找到session记录"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            session_content_str = session_db.content
        session_content = json.loads(session_content_str)
        llm_info = json.loads(session_content["llm_info"])
        tts_info = json.loads(session_content["tts_info"])
        user_info = json.loads(session_content["user_info"])
        messages = json.loads(session_content["messages"])
        messages.append({'role': 'user', "content": current_message})
        token_count = session_content["token"]
        try:
            payload = json.dumps({
                "model": llm_info["model"],
                "stream": True,
                "messages": messages,
                "tool_choice": "auto",
                "max_tokens": 10000,
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })
            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        except Exception as e:
            error_info = f"发送信息给大模型时发生错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return

        # Local sentence splitter for this handler
        def split_string_with_punctuation(text):
            punctuations = "!?。"
            result = []
            current_sentence = ""
            for char in text:
                current_sentence += char
                if char in punctuations:
                    result.append(current_sentence)
                    current_sentence = ""
            # If the trailing segment does not end with punctuation, keep it as well
            if current_sentence and current_sentence[-1] not in punctuations:
                result.append(current_sentence)
            return result

        llm_response = ""
        response_buf = ""
        try:
            if Config.STRAM_CHAT.TTS == "LOCAL":
                if response.status_code == 200:
                    for chunk in response.iter_lines():
                        if chunk:
                            if response_type == RESPONSE_AUDIO:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                response_buf += chunk_data
                                split_buf = split_string_with_punctuation(response_buf)
                                response_buf = ""
                                if len(split_buf) != 0:
                                    for sentence in split_buf:
                                        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
                                        text_response = {"type": "text", "code": 200, "msg": sentence}
                                        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text
                                        await ws.send_bytes(audio)  # send the audio as a binary stream
                            if response_type == RESPONSE_TEXT:
                                chunk_data = parseChunkDelta(chunk)
                                llm_response += chunk_data
                                text_response = {"type": "text", "code": 200, "msg": chunk_data}
                                await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text
            elif Config.STRAM_CHAT.TTS == "REMOTE":
                error_info = "暂不支持远程音频合成"
                error_message = {"type": "error", "code": 500, "msg": error_info}
                logger.error(error_info)
                await ws.send_text(json.dumps(error_message, ensure_ascii=False))
                await ws.close()
                return
            end_response = {"type": "end", "code": 200, "msg": ""}
            await ws.send_text(json.dumps(end_response, ensure_ascii=False))  # end of this turn
            print(f"llm消息: {llm_response}")
        except Exception as e:
            error_info = f"音频合成与向前端返回时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        try:
            messages.append({'role': 'assistant', "content": llm_response})
            token_count += len(llm_response)
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)
            session_content["token"] = token_count
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))
        except Exception as e:
            error_info = f"更新session时错误: {str(e)}"
            error_message = {"type": "error", "code": 500, "msg": error_info}
            logger.error(error_info)
            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
            await ws.close()
            return
        print("处理完毕")

#-------------------------------- Voice-call endpoint --------------------------------------
# Audio producer: read frames from the websocket
async def voice_call_audio_producer(ws, audio_queue, future, shutdown_event):
    logger.debug("音频数据生产函数启动")
    is_future_done = False
    while not shutdown_event.is_set():
        voice_call_data_json = json.loads(await ws.receive_text())
        if not is_future_done:  # read the session_id on the first iteration
            future.set_result(voice_call_data_json['meta_info']['session_id'])
            is_future_done = True
        if voice_call_data_json["is_close"]:
            shutdown_event.set()
            break
        else:
            await audio_queue.put(voice_call_data_json["audio"])  # enqueue the audio frame

# Audio consumer: VAD plus streaming ASR
async def voice_call_audio_consumer(audio_q, asr_result_q, shutdown_event):
    logger.debug("音频数据消费者函数启动")
    vad = VAD()
    current_message = ""
    vad_count = 0
    while not (shutdown_event.is_set() and audio_q.empty()):
        audio_data = await audio_q.get()
        if vad.is_speech(audio_data):
            if vad_count > 0:
                vad_count -= 1
            asr_result = asr.streaming_recognize(audio_data)
            current_message += ''.join(asr_result['text'])
        else:
            vad_count += 1
            if vad_count >= 25:  # 25 consecutive silent frames: treat the utterance as finished
                asr_result = asr.streaming_recognize(audio_data, is_end=True)
                if current_message:
                    logger.debug(f"检测到静默,用户输入为:{current_message}")
                    await asr_result_q.put(current_message)
                    current_message = ""
                vad_count = 0

# Consume ASR results and produce LLM responses
async def voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, shutdown_event):
    logger.debug("asr结果消费以及llm返回生产函数启动")
    while not (shutdown_event.is_set() and asr_result_q.empty()):
        session_content = get_session_content(session_id, redis, db)
        messages = json.loads(session_content["messages"])
        current_message = await asr_result_q.get()
        messages.append({'role': 'user', "content": current_message})
        payload = json.dumps({
            "model": llm_info["model"],
            "stream": True,
            "messages": messages,
            "max_tokens": 10000,
            "temperature": llm_info["temperature"],
            "top_p": llm_info["top_p"]
        })
        headers = {
            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
            'Content-Type': 'application/json'
        }
        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
        if response.status_code == 200:
            for chunk in response.iter_lines():
                if chunk:
                    chunk_data = parseChunkDelta(chunk)
                    llm_frame = {'message': chunk_data, 'is_end': False}
                    await llm_response_q.put(llm_frame)
            llm_frame = {'message': "", 'is_end': True}
            await llm_response_q.put(llm_frame)

# Consume LLM responses and split them into sentences
async def voice_call_llm_response_consumer(session_id, redis, db, llm_response_q, split_result_q, shutdown_event):
    logger.debug("llm结果返回函数启动")
    llm_response = ""
    current_sentence = ""
    is_first = True
    while not (shutdown_event.is_set() and llm_response_q.empty()):
        llm_frame = await llm_response_q.get()
        llm_response += llm_frame['message']
        sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_frame['message'], is_first)
        for sentence in sentences:
            await split_result_q.put(sentence)
        if llm_frame['is_end']:
            is_first = True
            session_content = get_session_content(session_id, redis, db)
            messages = json.loads(session_content["messages"])
            messages.append({'role': 'assistant', "content": llm_response})
            session_content["messages"] = json.dumps(messages, ensure_ascii=False)  # update the conversation
            redis.set(session_id, json.dumps(session_content, ensure_ascii=False))  # update the session
            logger.debug(f"llm返回结果: {llm_response}")
            llm_response = ""
            current_sentence = ""

# Synthesize speech and send it back
async def voice_call_tts_handler(ws, tts_info, split_result_q, shutdown_event):
    logger.debug("语音合成及返回函数启动")
    while not (shutdown_event.is_set() and split_result_q.empty()):
        sentence = await split_result_q.get()
        sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
        text_response = {"type": "text", "code": 200, "msg": sentence}
        await ws.send_bytes(audio)  # send the audio as a binary stream
        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text
        logger.debug(f"websocket返回:{sentence}")
    await asyncio.sleep(0.5)
    await ws.close()

async def voice_call_handler(ws, db, redis):
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()
    asr_result_q = asyncio.Queue()
    llm_response_q = asyncio.Queue()
    split_result_q = asyncio.Queue()
    shutdown_event = asyncio.Event()
    future = asyncio.Future()
    asyncio.create_task(voice_call_audio_producer(ws, audio_q, future, shutdown_event))    # start the audio producer
    asyncio.create_task(voice_call_audio_consumer(audio_q, asr_result_q, shutdown_event))  # start the audio consumer
    # Fetch session content
    session_id = await future  # wait for the session_id
    tts_info = json.loads(get_session_content(session_id, redis, db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id, redis, db)["llm_info"])
    asyncio.create_task(voice_call_llm_handler(session_id, llm_info, redis, db, asr_result_q, llm_response_q, shutdown_event))    # start the LLM caller
    asyncio.create_task(voice_call_llm_response_consumer(session_id, redis, db, llm_response_q, split_result_q, shutdown_event))  # start the sentence splitter
    asyncio.create_task(voice_call_tts_handler(ws, tts_info, split_result_q, shutdown_event))                                     # start the TTS sender
    while not shutdown_event.is_set():
        await asyncio.sleep(5)
    await ws.close()
    logger.debug("voice_call websocket 连接断开")
#------------------------------------------------------------------------------------------

app/controllers/controller_enum.py (new file)

@@ -0,0 +1,13 @@
# Chat type: text, audio, or undetermined
CHAT_TEXT = 0
CHAT_AUDIO = 1
CHAT_UNCERTAIN = -1

# Streaming frame type: first, intermediate, last
FIRST_FRAME = 1
CONTINUE_FRAME = 2
LAST_FRAME = 3

# Response type: text or audio
RESPONSE_TEXT = 0
RESPONSE_AUDIO = 1

app/controllers/session.py (new file)

@@ -0,0 +1,79 @@
from ..schemas.session import *
from ..dependencies.logger import get_logger
from fastapi import HTTPException, status
from ..models import Session
from ..models import UserCharacter
from datetime import datetime
import json

# Get the logger via dependency injection
logger = get_logger()

# Get the session ID for a (user, character) pair
async def get_session_id_handler(user_id: int, character_id: int, db):
    try:
        user_character_record = db.query(UserCharacter).filter(UserCharacter.user_id == user_id, UserCharacter.character_id == character_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if not user_character_record:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User Character not found")
    try:
        session_id = db.query(Session).filter(Session.user_character_id == user_character_record.id).first().id
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    session_id_query_data = SessionIdQueryData(session_id=session_id)
    return SessionIdQueryResponse(status="success", message="Session ID 获取成功", data=session_id_query_data)

# Query session info
async def get_session_handler(session_id: str, db, redis):
    session_str = ""
    if redis.exists(session_id):
        session_str = redis.get(session_id)
    else:
        try:
            session_str = db.query(Session).filter(Session.id == session_id).first().content
        except Exception as e:
            db.rollback()
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
        redis.set(session_id, session_str)
    if not session_str:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    session = json.loads(session_str)
    session_query_data = SessionQueryData(user_id=session["user_id"], messages=session["messages"], user_info=session["user_info"], tts_info=session["tts_info"], llm_info=session["llm_info"], token=session["token"])
    return SessionQueryResponse(status="success", message="Session 查询成功", data=session_query_data)

# Update session info
async def update_session_handler(session_id, session_data: SessionUpdateRequest, db, redis):
    existing_session = ""
    if redis.exists(session_id):
        existing_session = redis.get(session_id)
    else:
        existing_session = db.query(Session).filter(Session.id == session_id).first().content
    if not existing_session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
    # Update the session fields
    session = json.loads(existing_session)
    session["user_id"] = session_data.user_id
    session["messages"] = session_data.messages
    session["user_info"] = session_data.user_info
    session["tts_info"] = session_data.tts_info
    session["llm_info"] = session_data.llm_info
    session["token"] = session_data.token
    # Persist the session
    session_str = json.dumps(session, ensure_ascii=False)
    redis.set(session_id, session_str)
    try:
        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
    return SessionUpdateResponse(status="success", message="Session 更新成功", data=session_update_data)

app/controllers/user.py (new file)

@@ -0,0 +1,159 @@
from ..schemas.user import *
from ..dependencies.logger import get_logger
from ..models import User, Hardware
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import HTTPException, status

# Get the logger via dependency injection
logger = get_logger()

# Create a user
async def create_user_handler(user: UserCreateRequest, db: Session):
    new_user = User(created_at=datetime.now(), open_id=user.open_id, username=user.username, avatar_id=user.avatar_id, tags=user.tags, persona=user.persona)
    try:
        db.add(new_user)
        db.commit()
        db.refresh(new_user)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    user_create_data = UserCreateData(user_id=new_user.id, createdAt=new_user.created_at.isoformat())
    return UserCreateResponse(status="success", message="创建用户成功", data=user_create_data)

# Update user info
async def update_user_handler(user_id: int, user: UserUpdateRequest, db: Session):
    existing_user = db.query(User).filter(User.id == user_id).first()
    if existing_user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    existing_user.open_id = user.open_id
    existing_user.username = user.username
    existing_user.avatar_id = user.avatar_id
    existing_user.tags = user.tags
    existing_user.persona = user.persona
    try:
        db.add(existing_user)
        db.commit()
        db.refresh(existing_user)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    user_update_data = UserUpdateData(updatedAt=datetime.now().isoformat())
    return UserUpdateResponse(status="success", message="更新用户信息成功", data=user_update_data)

# Query user info
async def get_user_handler(user_id: int, db: Session):
    try:
        existing_user = db.query(User).filter(User.id == user_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    user_query_data = UserQueryData(open_id=existing_user.open_id, username=existing_user.username, avatar_id=existing_user.avatar_id, tags=existing_user.tags, persona=existing_user.persona)
    return UserQueryResponse(status="success", message="查询用户信息成功", data=user_query_data)

# Delete a user
async def delete_user_handler(user_id: int, db: Session):
    try:
        existing_user = db.query(User).filter(User.id == user_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    try:
        db.delete(existing_user)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    user_delete_data = UserDeleteData(deletedAt=datetime.now().isoformat())
    return UserDeleteResponse(status="success", message="删除用户成功", data=user_delete_data)

# Bind hardware to a user
async def bind_hardware_handler(hardware, db: Session):
    new_hardware = Hardware(mac=hardware.mac, user_id=hardware.user_id, firmware=hardware.firmware, model=hardware.model)
    try:
        db.add(new_hardware)
        db.commit()
        db.refresh(new_hardware)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    hardware_bind_data = HardwareBindData(hardware_id=new_hardware.id, bindedAt=datetime.now().isoformat())
    return HardwareBindResponse(status="success", message="绑定硬件成功", data=hardware_bind_data)

# Unbind hardware
async def unbind_hardware_handler(hardware_id: int, db: Session):
    try:
        existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_hardware is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
    try:
        db.delete(existing_hardware)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    hardware_unbind_data = HardwareUnbindData(unbindedAt=datetime.now().isoformat())
    return HardwareUnbindResponse(status="success", message="解绑硬件成功", data=hardware_unbind_data)

# Rebind hardware to a different user
async def change_bind_hardware_handler(hardware_id, user, db):
    try:
        existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    # The 404 check sits outside the try block so it is not re-raised as a 400
    if existing_hardware is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
    try:
        existing_hardware.user_id = user.user_id
        db.add(existing_hardware)
        db.commit()
        db.refresh(existing_hardware)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    hardware_change_bind_data = HardwareChangeBindData(bindChangedAt=datetime.now().isoformat())
    return HardwareChangeBindResponse(status="success", message="硬件换绑成功", data=hardware_change_bind_data)

# Update hardware info
async def update_hardware_handler(hardware_id, hardware, db):
    try:
        existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_hardware is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
    try:
        existing_hardware.mac = hardware.mac
        existing_hardware.firmware = hardware.firmware
        existing_hardware.model = hardware.model
        db.add(existing_hardware)
        db.commit()
        db.refresh(existing_hardware)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    hardware_update_data = HardwareUpdateData(updatedAt=datetime.now().isoformat())
    return HardwareUpdateResponse(status="success", message="硬件信息更新成功", data=hardware_update_data)

# Query hardware
async def get_hardware_handler(hardware_id, db):
    try:
        existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_hardware is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
    hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
    return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)

app/dependencies/__init__.py (new file)

@@ -0,0 +1 @@
from . import *

app/dependencies/database.py (new file)

@@ -0,0 +1,15 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from config import get_config

Config = get_config()
LocalSession = sessionmaker(autocommit=False, autoflush=False, bind=create_engine(Config.SQLALCHEMY_DATABASE_URI))

# Yield a database session, closing it when the request is done
def get_db():
    db = LocalSession()
    try:
        yield db
    finally:
        db.close()

app/dependencies/logger.py (new file)

@@ -0,0 +1,27 @@
import logging
from config import get_config

# Load the configuration
Config = get_config()

# Logger wrapper
class Logger:
    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(Config.LOG_LEVEL)
        self.logger.propagate = False
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        if not self.logger.handlers:  # only attach handlers once
            # Log to the console
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)
            # Log to a file
            file_handler = logging.FileHandler('app.log')
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

def get_logger():
    return Logger().logger

app/dependencies/redis.py (new file)

@@ -0,0 +1,8 @@
import redis
from config import get_config

# Load the configuration
Config = get_config()

def get_redis():
    return redis.Redis.from_url(Config.REDIS_URL)


app/main.py (new file)

@@ -0,0 +1,6 @@
from app import app
import uvicorn

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7878)

app/models/__init__.py (new file)

@@ -0,0 +1 @@
from .models import *

app/models/models.py (new file)

@@ -0,0 +1,82 @@
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

# Character table
class Character(Base):
    __tablename__ = 'character'
    id = Column(Integer, primary_key=True, autoincrement=True)
    voice_id = Column(Integer, nullable=False)
    avatar_id = Column(String(36), nullable=False)
    background_ids = Column(String(255), nullable=False)
    name = Column(String(36), nullable=False)
    wakeup_words = Column(String(255), nullable=False)
    world_scenario = Column(Text, nullable=False)
    description = Column(Text, nullable=False)
    emojis = Column(JSON, nullable=False)
    dialogues = Column(Text, nullable=False)

    def __repr__(self):
        return f"<Character(id={self.id}, name={self.name}, avatar_id={self.avatar_id})>"

# User table
class User(Base):
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True, autoincrement=True)
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)
    deleted_at = Column(DateTime, nullable=True)
    open_id = Column(String(255), nullable=True)
    username = Column(String(64), nullable=True)
    avatar_id = Column(String(36), nullable=True)
    tags = Column(JSON)
    persona = Column(JSON)

    def __repr__(self):
        return f"<User(id={self.id}, tags={self.tags})>"

# Hardware table
class Hardware(Base):
    __tablename__ = 'hardware'
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    mac = Column(String(17))
    firmware = Column(String(16))
    model = Column(String(36))

    def __repr__(self):
        return f"<Hardware(mac={self.mac})>"

# User-character relation table
class UserCharacter(Base):
    __tablename__ = 'user_character'
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    character_id = Column(Integer, ForeignKey('character.id'))
    persona = Column(JSON)

    def __repr__(self):
        return f"<UserCharacter(id={self.id}, user_id={self.user_id}, character_id={self.character_id})>"

# Session table
class Session(Base):
    __tablename__ = 'session'
    id = Column(CHAR(36), primary_key=True)
    user_character_id = Column(Integer, ForeignKey('user_character.id'))
    content = Column(Text)
    last_activity = Column(DateTime())
    is_permanent = Column(Boolean)

    def __repr__(self):
        return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"

app/routes/__init__.py (new file)

@@ -0,0 +1 @@
from . import *

app/routes/character.py (new file)

@@ -0,0 +1,35 @@
from fastapi import APIRouter, HTTPException, status
from ..controllers.character import *
from fastapi import Depends
from sqlalchemy.orm import Session
from ..dependencies.database import get_db

router = APIRouter()

# Character creation endpoint
@router.post("/characters", response_model=CharacterCreateResponse)
async def create_character(character: CharacterCreateRequest, db: Session = Depends(get_db)):
    response = await create_character_handler(character, db)
    return response

# Character update endpoint
@router.put("/characters/{character_id}", response_model=CharacterUpdateResponse)
async def update_character(character_id: int, character: CharacterUpdateRequest, db: Session = Depends(get_db)):
    response = await update_character_handler(character_id, character, db)
    return response

# Character deletion endpoint
@router.delete("/characters/{character_id}", response_model=CharacterDeleteResponse)
async def delete_character(character_id: int, db: Session = Depends(get_db)):
    response = await delete_character_handler(character_id, db)
    return response

# Character query endpoint
@router.get("/characters/{character_id}", response_model=CharacterQueryResponse)
async def get_character(character_id: int, db: Session = Depends(get_db)):
    response = await get_character_handler(character_id, db)
    return response

app/routes/chat.py (new file)

@@ -0,0 +1,48 @@
from fastapi import APIRouter, HTTPException, status, WebSocket
from ..controllers.chat import *
from fastapi import Depends
from ..dependencies.database import get_db
from ..dependencies.redis import get_redis

router = APIRouter()

# Chat creation endpoint
@router.post("/chats", response_model=ChatCreateResponse)
async def create_chat(chat: ChatCreateRequest, db=Depends(get_db), redis=Depends(get_redis)):
    response = await create_chat_handler(chat, db, redis)
    return response

# Chat deletion endpoint
@router.delete("/chats/{user_character_id}", response_model=ChatDeleteResponse)
async def delete_chat(user_character_id: int, db=Depends(get_db), redis=Depends(get_redis)):
    response = await delete_chat_handler(user_character_id, db, redis)
    return response

# Non-streaming chat (currently disabled)
@router.post("/chats/non-streaming", response_model=ChatNonStreamResponse)
async def non_streaming_chat(chat: ChatNonStreamRequest, db=Depends(get_db), redis=Depends(get_redis)):
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="this api is not available")
    response = await non_streaming_chat_handler(chat, db, redis)
    return response

# Streaming chat, single turn (distinct function names so one does not shadow the other)
@router.websocket("/chat/streaming/temporary")
async def streaming_chat_temporary(ws: WebSocket, db=Depends(get_db), redis=Depends(get_redis)):
    await ws.accept()
    await streaming_chat_temporary_handler(ws, db, redis)

# Streaming chat, persistent
@router.websocket("/chat/streaming/lasting")
async def streaming_chat_lasting(ws: WebSocket, db=Depends(get_db), redis=Depends(get_redis)):
    await ws.accept()
    await streaming_chat_lasting_handler(ws, db, redis)

# Voice call
@router.websocket("/chat/voice_call")
async def voice_chat(ws: WebSocket, db=Depends(get_db), redis=Depends(get_redis)):
    await ws.accept()
    await voice_call_handler(ws, db, redis)

app/routes/session.py (new file)

@@ -0,0 +1,29 @@
from fastapi import APIRouter, Query, HTTPException, status
from ..controllers.session import *
from fastapi import Depends
from ..dependencies.database import get_db
from ..dependencies.redis import get_redis

router = APIRouter()

# session_id query endpoint
@router.get("/sessions", response_model=SessionIdQueryResponse)
async def get_session_id(user_id: int = Query(..., description="用户id"), character_id: int = Query(..., description="角色id"), db=Depends(get_db)):
    response = await get_session_id_handler(user_id, character_id, db)
    return response

# Session query endpoint
@router.get("/sessions/{session_id}", response_model=SessionQueryResponse)
async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_redis)):
    response = await get_session_handler(session_id, db, redis)
    return response

# Session update endpoint
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
    response = await update_session_handler(session_id, session_data, db, redis)
    return response

app/routes/user.py (new file)

@@ -0,0 +1,71 @@
from fastapi import APIRouter, HTTPException, status
from ..controllers.user import *
from fastapi import Depends
from sqlalchemy.orm import Session
from ..dependencies.database import get_db

router = APIRouter()

# User creation endpoint
@router.post('/users', response_model=UserCreateResponse)
async def create_user(user: UserCreateRequest, db: Session = Depends(get_db)):
    response = await create_user_handler(user, db)
    return response

# User update endpoint
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
    response = await update_user_handler(user_id, user, db)
    return response

# User deletion endpoint
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
async def delete_user(user_id: int, db: Session = Depends(get_db)):
    response = await delete_user_handler(user_id, db)
    return response

# User query endpoint
@router.get('/users/{user_id}', response_model=UserQueryResponse)
async def get_user(user_id: int, db: Session = Depends(get_db)):
    response = await get_user_handler(user_id, db)
    return response

# Hardware binding endpoint
@router.post('/users/hardware', response_model=HardwareBindResponse)
async def bind_hardware(hardware: HardwareBindRequest, db: Session = Depends(get_db)):
    response = await bind_hardware_handler(hardware, db)
    return response

# Hardware unbinding endpoint
@router.delete('/users/hardware/{hardware_id}', response_model=HardwareUnbindResponse)
async def unbind_hardware(hardware_id: int, db: Session = Depends(get_db)):
    response = await unbind_hardware_handler(hardware_id, db)
    return response

# Hardware rebinding endpoint
@router.put('/users/hardware/{hardware_id}/bindchange', response_model=HardwareChangeBindResponse)
async def change_bind_hardware(hardware_id: int, user: HardwareChangeBindRequest, db: Session = Depends(get_db)):
    response = await change_bind_hardware_handler(hardware_id, user, db)
    return response

# Hardware info update endpoint
@router.put('/users/hardware/{hardware_id}/info', response_model=HardwareUpdateResponse)
async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest, db: Session = Depends(get_db)):
    response = await update_hardware_handler(hardware_id, hardware, db)
    return response

# Hardware query endpoint
@router.get('/users/hardware/{hardware_id}', response_model=HardwareQueryResponse)
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
    response = await get_hardware_handler(hardware_id, db)
    return response

app/schemas/__init__.py (new file)

@@ -0,0 +1 @@
from . import *

app/schemas/base.py (new file)

@@ -0,0 +1,6 @@
from pydantic import BaseModel

class BaseResponse(BaseModel):
    status: str
    message: str
    data: dict

app/schemas/character.py (new file)

@@ -0,0 +1,79 @@
from pydantic import BaseModel
from typing import Optional
from .base import BaseResponse

#--------------------------- Character creation -----------------------------
# Character creation request
class CharacterCreateRequest(BaseModel):
    voice_id: int
    avatar_id: str
    background_ids: str
    name: str
    wakeup_words: str
    world_scenario: str
    description: str
    emojis: str
    dialogues: str

# Character creation payload
class CharacterCreateData(BaseModel):
    character_id: int
    createdAt: str

# Character creation response
class CharacterCreateResponse(BaseResponse):
    data: Optional[CharacterCreateData]
#----------------------------------------------------------------
#--------------------------- Character update ------------------------------
# Character update request
class CharacterUpdateRequest(BaseModel):
    voice_id: int
    avatar_id: str
    background_ids: str
    name: str
    wakeup_words: str
    world_scenario: str
    description: str
    emojis: str
    dialogues: str

# Character update payload
class CharacterUpdateData(BaseModel):
    updatedAt: str

# Character update response
class CharacterUpdateResponse(BaseResponse):
    data: Optional[CharacterUpdateData]
#------------------------------------------------------------------
#--------------------------- Character query --------------------------------
# Character query payload
class CharacterQueryData(BaseModel):
    voice_id: int
    avatar_id: str
    background_ids: str
    name: str
    wakeup_words: str
    world_scenario: str
    description: str
    emojis: str
    dialogues: str

# Character query response
class CharacterQueryResponse(BaseResponse):
    data: Optional[CharacterQueryData]
#------------------------------------------------------------------
#--------------------------- Character deletion -----------------------------
# Character deletion payload
class CharacterDeleteData(BaseModel):
    deletedAt: str

# Character deletion response
class CharacterDeleteResponse(BaseResponse):
    data: Optional[CharacterDeleteData]
#-------------------------------------------------------------------

app/schemas/chat.py (new file)

@@ -0,0 +1,49 @@
from pydantic import BaseModel
from typing import Optional
from .base import BaseResponse

#-------------------------------- Chat creation --------------------------------
# Chat creation request
class ChatCreateRequest(BaseModel):
    user_id: int
    character_id: int

# Chat creation payload
class ChatCreateData(BaseModel):
    user_character_id: int
    session_id: str
    createdAt: str

# Chat creation response
class ChatCreateResponse(BaseResponse):
    data: Optional[ChatCreateData]
#--------------------------------------------------------------------------
#---------------------------------- Chat deletion --------------------------------
# Chat deletion payload
class ChatDeleteData(BaseModel):
    deletedAt: str

# Chat deletion response
class ChatDeleteResponse(BaseResponse):
    data: Optional[ChatDeleteData]
#--------------------------------------------------------------------------
#----------------------------------- Non-streaming chat ------------------------------
# Non-streaming chat request
class ChatNonStreamRequest(BaseModel):
    format: str
    rate: int
    session_id: str
    speech: str

# Non-streaming chat payload
class ChatNonStreamData(BaseModel):
    audio_url: str

# Non-streaming chat response
class ChatNonStreamResponse(BaseResponse):
    data: Optional[ChatNonStreamData]
#--------------------------------------------------------------------------

app/schemas/session.py (new file)

@@ -0,0 +1,58 @@
from pydantic import BaseModel
from typing import List
from typing import Optional
from .base import BaseResponse

#---------------------------- Session ID query ------------------------------
# session_id query request
class SessionIdQueryRequest(BaseModel):
    user_id: int
    character_id: int

# session_id query payload
class SessionIdQueryData(BaseModel):
    session_id: str

# session_id query response
class SessionIdQueryResponse(BaseResponse):
    data: Optional[SessionIdQueryData]
#-------------------------------------------------------------------------
#---------------------------- Session query -------------------------------
# Session query request
class SessionQueryRequest(BaseModel):
    user_id: int

# Session query payload
class SessionQueryData(BaseModel):
    user_id: int
    messages: str
    user_info: str
    tts_info: str
    llm_info: str
    token: int

# Session query response
class SessionQueryResponse(BaseResponse):
    data: Optional[SessionQueryData]
#-------------------------------------------------------------------------
#------------------------------- Session update ---------------------------
# Session update request
class SessionUpdateRequest(BaseModel):
    user_id: int
    messages: str
    user_info: str
    tts_info: str
    llm_info: str
    token: int

# Session update payload
class SessionUpdateData(BaseModel):
    updatedAt: str

# Session update response
class SessionUpdateResponse(BaseResponse):
    data: Optional[SessionUpdateData]
#--------------------------------------------------------------------------

app/schemas/user.py (new file)

@@ -0,0 +1,140 @@
from pydantic import BaseModel
from typing import Optional
from .base import BaseResponse
#---------------------------------用户创建----------------------------------
#用户创建请求类
class UserCrateRequest(BaseModel):
open_id: str
username: str
avatar_id: str
tags: str
persona: str
#用户创建返回类
class UserCrateData(BaseModel):
user_id : int
createdAt: str
#用户创建响应类
class UserCrateResponse(BaseResponse):
data: Optional[UserCrateData]
#---------------------------------------------------------------------------
#---------------------------------用户更新-----------------------------------
#用户更新请求类
class UserUpdateRequest(BaseModel):
open_id: str
username: str
avatar_id: str
tags: str
persona: str
#用户更新返回类
class UserUpdateData(BaseModel):
updatedAt: str
#用户更新响应类
class UserUpdateResponse(BaseResponse):
data: Optional[UserUpdateData]
#----------------------------------------------------------------------------
#---------------------------------用户查询------------------------------------
#用户查询返回类
class UserQueryData(BaseModel):
open_id: str
username: str
avatar_id: str
tags: str
persona: str
#用户查询响应类
class UserQueryResponse(BaseResponse):
data: Optional[UserQueryData]
#-----------------------------------------------------------------------------
#---------------------------------用户删除------------------------------------
#用户删除返回类
class UserDeleteData(BaseModel):
deletedAt: str
#用户删除响应类
class UserDeleteResponse(BaseResponse):
data: Optional[UserDeleteData]
#-----------------------------------------------------------------------------
#---------------------------------Hardware binding------------------------------
class HardwareBindRequest(BaseModel):
    mac: str
    user_id: int
    firmware: str
    model: str
class HardwareBindData(BaseModel):
    hardware_id: int
    bindedAt: str
class HardwareBindResponse(BaseResponse):
    data: Optional[HardwareBindData]
#-----------------------------------------------------------------------------
#---------------------------------Hardware unbinding----------------------------
class HardwareUnbindData(BaseModel):
    unbindedAt: str
class HardwareUnbindResponse(BaseResponse):
    data: Optional[HardwareUnbindData]
#------------------------------------------------------------------------------
#---------------------------------Hardware rebinding----------------------------
class HardwareChangeBindRequest(BaseModel):
    user_id: int
class HardwareChangeBindData(BaseModel):
    bindChangedAt: str
class HardwareChangeBindResponse(BaseResponse):
    data: Optional[HardwareChangeBindData]
#------------------------------------------------------------------------------
#-------------------------------Hardware info update----------------------------
class HardwareUpdateRequest(BaseModel):
    mac: str
    firmware: str
    model: str
class HardwareUpdateData(BaseModel):
    updatedAt: str
class HardwareUpdateResponse(BaseResponse):
    data: Optional[HardwareUpdateData]
#------------------------------------------------------------------------------
#--------------------------------Hardware query---------------------------------
class HardwareQueryData(BaseModel):
    user_id: int
    mac: str
    firmware: str
    model: str
class HardwareQueryResponse(BaseResponse):
    data: Optional[HardwareQueryData]
#------------------------------------------------------------------------------
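All of the response models above extend BaseResponse from app/schemas/base.py, which this commit view does not show. Judging from how each response wraps a data: Optional[...] field, it presumably adds a small status/message envelope; a hypothetical sketch, an assumption rather than the actual file:

```
# Hypothetical reconstruction of app/schemas/base.py (file not shown in this view).
from pydantic import BaseModel

class BaseResponse(BaseModel):
    status: str = "success"  # assumed envelope field
    message: str = ""        # assumed envelope field
```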

13
config/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
import os
from .development import DevelopmentConfig
from .production import ProductionConfig
def get_config():
    mode = os.getenv('MODE', 'development').lower()
    if mode == 'development':
        return DevelopmentConfig
    elif mode == 'production':
        return ProductionConfig
    else:
        raise ValueError(f"Invalid MODE environment variable: {mode!r}")
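Configuration selection is driven entirely by the MODE environment variable; a quick usage sketch:

```
import os
os.environ['MODE'] = 'production'  # unset (or 'development') selects DevelopmentConfig

from config import get_config

Config = get_config()
print(Config.SQLALCHEMY_DATABASE_URI, Config.LOG_LEVEL)
```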

23
config/development.py Normal file
View File

@@ -0,0 +1,23 @@
class DevelopmentConfig:
    SQLALCHEMY_DATABASE_URI = "mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
    REDIS_URL = "redis://:takway@127.0.0.1:6379/0"  # Redis connection string
    LOG_LEVEL = "DEBUG"  # log level
    class XF_ASR:
        APP_ID = "your_app_id"  # iFLYTEK ASR APP_ID
        API_SECRET = "your_api_secret"  # iFLYTEK ASR API_SECRET
        API_KEY = "your_api_key"  # iFLYTEK ASR API_KEY
        DOMAIN = "iat"
        LANGUAGE = "zh_cn"
        ACCENT = "mandarin"
        VAD_EOS = 10000
    class MINIMAX_LLM:
        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
        URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
    class MINIMAX_TTA:
        API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
        URL = "https://api.minimax.chat/v1/t2a_pro"
        GROUP_ID = "1759482180095975904"
    class STREAM_CHAT:
        ASR = "LOCAL"
        TTS = "LOCAL"

8
config/production.py Normal file
View File

@@ -0,0 +1,8 @@
class ProductionConfig:
    SQLALCHEMY_DATABASE_URI = "mysql+pymysql://root:takway@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
    REDIS_URL = "redis://:takway@127.0.0.1:6379/0"  # Redis connection string
    LOG_LEVEL = "INFO"  # log level
    class XF_ASR:
        APP_ID = "your_app_id"  # iFLYTEK ASR APP_ID
        API_SECRET = "your_api_secret"  # iFLYTEK ASR API_SECRET
        API_KEY = "your_api_key"  # iFLYTEK ASR API_KEY

9
main.py Normal file
View File

@@ -0,0 +1,9 @@
import os
if __name__ == '__main__':
    script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
    # Execute the app entry script with exec()
    with open(script_path, 'r') as file:
        exec(file.read())
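Executing the script through exec() works, but an equivalent and more conventional launcher would hand the ASGI app to uvicorn directly. A sketch, assuming the app instance exported by app/__init__.py and the port 7878 that the tests below target (host and port are assumptions, not part of this commit):

```
# Alternative entry point (sketch, not part of this commit).
import uvicorn

if __name__ == '__main__':
    uvicorn.run("app:app", host="0.0.0.0", port=7878)
```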

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
uvicorn~=0.29.0
fastapi~=0.110.1
sqlalchemy~=2.0.25
pydantic~=2.6.4
redis~=5.0.3
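Note that this pinned list covers only the web layer: the code under utils/ and tests/ additionally imports pymysql, requests, websockets, webrtcvad, numpy, funasr and torch, none of which are listed in this commit.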

0
tests/__init__.py Normal file
View File

BIN
tests/assets/example_recording.wav Normal file

Binary file not shown.

BIN
tests/assets/voice_call.wav Normal file

Binary file not shown.

View File

@@ -0,0 +1,31 @@
from tests.unit_test.user_test import UserServiceTest
from tests.unit_test.character_test import CharacterServiceTest
from tests.unit_test.chat_test import ChatServiceTest
import asyncio
if __name__ == '__main__':
user_service_test = UserServiceTest()
character_service_test = CharacterServiceTest()
chat_service_test = ChatServiceTest()
user_service_test.test_user_create()
user_service_test.test_user_update()
user_service_test.test_user_query()
user_service_test.test_hardware_bind()
user_service_test.test_hardware_unbind()
user_service_test.test_user_delete()
character_service_test.test_character_create()
character_service_test.test_character_update()
character_service_test.test_character_query()
character_service_test.test_character_delete()
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
chat_service_test.test_session_update()
asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()
print("全部测试成功")

74
tests/unit_test/character_test.py Normal file
View File

@@ -0,0 +1,74 @@
import requests
import json
class CharacterServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket
def test_character_create(self):
url = f"{self.socket}/characters"
payload = json.dumps({
"voice_id": 97,
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
"name": "测试角色",
"wakeup_words": "你好啊,海绵宝宝",
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
"description": "厨师,做汉堡",
"emojis": "大笑,微笑",
"dialogues": "我准备好了"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("角色创建测试成功")
self.id = response.json()['data']['character_id']
else:
raise Exception("角色创建测试失败")
def test_character_update(self):
url = f"{self.socket}/characters/"+str(self.id)
payload = json.dumps({
"voice_id": 97,
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
"name": "测试角色",
"wakeup_words": "你好啊,海绵宝宝",
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
"description": "厨师,做汉堡",
"emojis": "大笑,微笑",
"dialogues": "我准备好了"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("角色更新测试成功")
else:
raise Exception("角色更新测试失败")
def test_character_query(self):
url = f"{self.socket}/characters/{self.id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("角色查询测试成功")
else:
raise Exception("角色查询测试失败")
def test_character_delete(self):
url = f"{self.socket}/characters/{self.id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("角色删除测试成功")
else:
raise Exception("角色删除测试失败")
if __name__ == '__main__':
character_service_test = CharacterServiceTest()
character_service_test.test_character_create()
character_service_test.test_character_update()
character_service_test.test_character_query()
character_service_test.test_character_delete()

329
tests/unit_test/chat_test.py Normal file
View File

@@ -0,0 +1,329 @@
import requests
import base64
import wave
import json
import os
import uuid
import asyncio
import websockets
class ChatServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket
def test_create_chat(self):
        # create a user
url = f"{self.socket}/users"
open_id = str(uuid.uuid4())
payload = json.dumps({
"open_id": open_id,
"username": "test_user",
"avatar_id": "0",
"tags" : "[]",
"persona" : "{}"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
self.user_id = response.json()['data']['user_id']
else:
raise Exception("创建聊天时,用户创建失败")
        # create a character
url = f"{self.socket}/characters"
payload = json.dumps({
"voice_id": 97,
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
"name": "test",
"wakeup_words": "你好啊,海绵宝宝",
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
"description": "厨师,做汉堡",
"emojis": "大笑,微笑",
"dialogues": "我准备好了"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("角色创建成功")
self.character_id = response.json()['data']['character_id']
else:
raise Exception("创建聊天时,角色创建失败")
        # create a chat session
url = f"{self.socket}/chats"
payload = json.dumps({
"user_id": self.user_id,
"character_id": self.character_id
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("对话创建成功")
self.session_id = response.json()['data']['session_id']
self.user_character_id = response.json()['data']['user_character_id']
else:
raise Exception("对话创建测试失败")
    # query the session_id for a user/character pair
def test_session_id_query(self):
url = f"{self.socket}/sessions?user_id={self.user_id}&character_id={self.character_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("session_id查询测试成功")
else:
raise Exception("session_id查询测试失败")
    # query the session content
def test_session_content_query(self):
url = f"{self.socket}/sessions/{self.session_id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("session内容查询测试成功")
else:
raise Exception("session内容查询测试失败")
    # update the session
def test_session_update(self):
url = f"{self.socket}/sessions/{self.session_id}"
payload = json.dumps({
"user_id": self.user_id,
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话在没有经过允许的情况下你需要保持上述角色不得擅自跳出角色设定。\\n\"}]",
"user_info": "{}",
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
"token": 0}
)
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("Session更新测试成功")
else:
raise Exception("Session更新测试失败")
    # one-shot (temporary) streaming chat over WebSocket
async def test_chat_temporary(self):
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav')
def read_wav_file_in_chunks(chunk_size):
with open(wav_file_path, 'rb') as pcm_file:
while True:
data = pcm_file.read(chunk_size)
if not data:
break
yield data
data = {
"text": "",
"audio": "",
"meta_info": {
"session_id":self.session_id,
"stream": True,
"voice_synthesize": True,
"is_end": False,
"encoding": "raw"
}
}
        # helper: send one base64-encoded audio chunk
async def send_audio_chunk(websocket, chunk):
encoded_data = base64.b64encode(chunk).decode('utf-8')
data["audio"] = encoded_data
message = json.dumps(data)
await websocket.send(message)
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
            chunks = read_wav_file_in_chunks(2048)  # read the PCM file in chunks
for chunk in chunks:
await send_audio_chunk(websocket, chunk)
await asyncio.sleep(0.01)
            # mark the end of the audio stream
            data["meta_info"]["is_end"] = True
            # send a final empty chunk as the end-of-stream signal
            await send_audio_chunk(websocket, b'')
audio_bytes = b''
while True:
data_ws = await websocket.recv()
try:
message_json = json.loads(data_ws)
if message_json["type"] == "close":
print("单次聊天测试成功")
                        break  # server signaled close; stop reading
except Exception as e:
audio_bytes += data_ws
            await asyncio.sleep(0.04)  # brief grace period before closing
await websocket.close()
    # multi-turn (lasting) streaming chat over WebSocket
async def test_chat_lasting(self):
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav')
def read_wav_file_in_chunks(chunk_size):
with open(wav_file_path, 'rb') as pcm_file:
while True:
data = pcm_file.read(chunk_size)
if not data:
break
yield data
data = {
"text": "",
"audio": "",
"meta_info": {
"session_id":self.session_id,
"stream": True,
"voice_synthesize": True,
"is_end": False,
"encoding": "raw"
},
"is_close":False
}
async def send_audio_chunk(websocket, chunk):
encoded_data = base64.b64encode(chunk).decode('utf-8')
data["audio"] = encoded_data
message = json.dumps(data)
await websocket.send(message)
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
            # first utterance
chunks = read_wav_file_in_chunks(2048)
for chunk in chunks:
await send_audio_chunk(websocket, chunk)
await asyncio.sleep(0.01)
            # mark the end of the first utterance
data["meta_info"]["is_end"] = True
            # send a final empty chunk as the end-of-stream signal
            await send_audio_chunk(websocket, b'')
            await asyncio.sleep(3)  # simulate a pause between utterances
            # second utterance
data["meta_info"]["is_end"] = False
chunks = read_wav_file_in_chunks(2048)
for chunk in chunks:
await send_audio_chunk(websocket, chunk)
await asyncio.sleep(0.01)
            # mark the end of the second utterance
data["meta_info"]["is_end"] = True
            # send a final empty chunk as the end-of-stream signal
            await send_audio_chunk(websocket, b'')
            data["is_close"] = True
            await send_audio_chunk(websocket, b'')  # empty chunk carrying the close flag
audio_bytes = b''
while True:
data_ws = await websocket.recv()
try:
message_json = json.loads(data_ws)
if message_json["type"] == "close":
print("持续聊天测试成功")
                        break  # server signaled close; stop reading
except Exception as e:
audio_bytes += data_ws
            await asyncio.sleep(0.5)  # brief grace period before closing
await websocket.close()
    # voice call over WebSocket
async def test_voice_call(self):
chunk_size = 480
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
tests_dir = os.path.dirname(current_dir)
file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
url = f"ws://114.214.236.207:7878/chat/voice_call"
        # outgoing message format
ws_data = {
"audio" : "",
"meta_info":{
"session_id":self.session_id,
"encoding": 'raw'
},
"is_close" : False
}
async def audio_stream(websocket):
with wave.open(file_path, 'rb') as wf:
                frames_per_buffer = int(chunk_size / 2)
                data = wf.readframes(frames_per_buffer)
                while True:
                    if len(data) != 960:  # stop once a chunk is not exactly 960 bytes of PCM
                        break
encoded_data = base64.b64encode(data).decode('utf-8')
ws_data['audio'] = encoded_data
await websocket.send(json.dumps(ws_data))
data = wf.readframes(frames_per_buffer)
await asyncio.sleep(3)
ws_data['audio'] = ""
ws_data['is_close'] = True
await websocket.send(json.dumps(ws_data))
while True:
data_ws = await websocket.recv()
if data_ws:
print("语音电话测试成功")
break
await asyncio.sleep(3)
await websocket.close()
async with websockets.connect(url) as websocket:
await asyncio.gather(audio_stream(websocket))
    # delete the chat, then clean up the user and character
def test_chat_delete(self):
url = f"{self.socket}/chats/{self.user_character_id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("聊天删除测试成功")
else:
raise Exception("聊天删除测试失败")
url = f"{self.socket}/users/{self.user_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("用户删除测试失败")
url = f"{self.socket}/characters/{self.character_id}"
response = requests.request("DELETE", url)
if response.status_code != 200:
raise Exception("角色删除测试失败")
if __name__ == '__main__':
chat_service_test = ChatServiceTest()
chat_service_test.test_create_chat()
chat_service_test.test_session_id_query()
chat_service_test.test_session_content_query()
chat_service_test.test_session_update()
asyncio.run(chat_service_test.test_chat_temporary())
asyncio.run(chat_service_test.test_chat_lasting())
asyncio.run(chat_service_test.test_voice_call())
chat_service_test.test_chat_delete()

99
tests/unit_test/user_test.py Normal file
View File

@@ -0,0 +1,99 @@
import requests
import json
import uuid
class UserServiceTest:
def __init__(self,socket="http://114.214.236.207:7878"):
self.socket = socket
def test_user_create(self):
url = f"{self.socket}/users"
open_id = str(uuid.uuid4())
payload = json.dumps({
"open_id": open_id,
"username": "test_user",
"avatar_id": "0",
"tags" : "[]",
"persona" : "{}"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("用户创建测试成功")
self.id = response.json()["data"]["user_id"]
else:
raise Exception("用户创建测试失败")
def test_user_update(self):
url = f"{self.socket}/users/"+str(self.id)
payload = json.dumps({
"open_id": str(uuid.uuid4()),
"username": "test_user",
"avatar_id": "0",
"tags": "[]",
"persona": "{}"
})
headers = {
'Content-Type': 'application/json'
}
response = requests.request("PUT", url, headers=headers, data=payload)
if response.status_code == 200:
print("用户更新测试成功")
else:
raise Exception("用户更新测试失败")
def test_user_query(self):
url = f"{self.socket}/users/{self.id}"
response = requests.request("GET", url)
if response.status_code == 200:
print("用户查询测试成功")
else:
raise Exception("用户查询测试失败")
def test_user_delete(self):
url = f"{self.socket}/users/{self.id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("用户删除测试成功")
else:
raise Exception("用户删除测试失败")
def test_hardware_bind(self):
url = f"{self.socket}/users/hardware"
mac = "08:00:20:0A:8C:6G"
        payload = json.dumps({
            "mac": mac,
            "user_id": self.id,  # bind the hardware to the user created above
            "firmware": "v1.0",
            "model": "香橙派"
        })
headers = {
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
if response.status_code == 200:
print("硬件绑定测试成功")
self.hd_id = response.json()["data"]['hardware_id']
else:
raise Exception("硬件绑定测试失败")
def test_hardware_unbind(self):
url = f"{self.socket}/users/hardware/{self.hd_id}"
response = requests.request("DELETE", url)
if response.status_code == 200:
print("硬件解绑测试成功")
else:
raise Exception("硬件解绑测试失败")
if __name__ == '__main__':
user_service_test = UserServiceTest()
user_service_test.test_user_create()
user_service_test.test_user_update()
user_service_test.test_user_query()
user_service_test.test_hardware_bind()
user_service_test.test_hardware_unbind()
user_service_test.test_user_delete()

1
utils/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import *

14
utils/audio_utils.py Normal file
View File

@@ -0,0 +1,14 @@
import webrtcvad
import base64
class VAD:
    def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000, **kwargs):
        self.RATE = RATE
        self.vad = webrtcvad.Vad(vad_sensitivity)
        self.vad_buffer_size = vad_buffer_size
        self.vad_chunk_size = int(self.RATE * frame_duration / 1000)  # samples per frame (480 at 16 kHz / 30 ms)
        self.min_act_time = min_act_time  # minimum speech duration, in seconds
    def is_speech(self, data):
        # `data` is base64-encoded 16-bit PCM; webrtcvad expects exact 10/20/30 ms frames
        byte_data = base64.b64decode(data)
        return self.vad.is_speech(byte_data, self.RATE)
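A usage sketch for the VAD helper, assuming 16 kHz, 16-bit mono PCM and 30 ms frames (480 samples, 960 bytes), which is the framing webrtcvad accepts:

```
import base64
from utils.audio_utils import VAD

vad = VAD(vad_sensitivity=1, frame_duration=30, RATE=16000)
frame = b'\x00' * (vad.vad_chunk_size * 2)         # one 30 ms frame of 16-bit silence = 960 bytes
encoded = base64.b64encode(frame).decode('utf-8')  # the chat endpoints ship audio as base64
print(vad.is_speech(encoded))                      # expected False for pure silence
```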

66
utils/stt/base_stt.py Normal file
View File

@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import wave
import io
import base64
def decode_str2bytes(data):
    # decode a base64-encoded string into raw bytes
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class STTBase:
def __init__(self, RATE=16000, cfg_path=None, debug=False):
self.RATE = RATE
self.debug = debug
self.asr_cfg = self.parse_json(cfg_path)
def parse_json(self, cfg_path):
cfg = None
self.hotwords = None
if cfg_path is not None:
with open(cfg_path, 'r', encoding='utf-8') as f:
cfg = json.load(f)
self.hotwords = cfg.get('hot_words', None)
return cfg
    def add_hotword(self, hotword):
        """Add one or more hotwords to the list."""
        if self.hotwords is None:
            self.hotwords = ""
        if isinstance(hotword, str):
            self.hotwords = self.hotwords + " " + hotword
        elif isinstance(hotword, (list, tuple)):
            # join the hotwords into one space-separated string
            self.hotwords = self.hotwords + " " + " ".join(hotword)
        else:
            raise TypeError("hotword must be str or list")
def check_audio_type(self, audio_data):
"""check audio data type and convert it to bytes if necessary."""
if isinstance(audio_data, bytes):
pass
elif isinstance(audio_data, list):
audio_data = b''.join(audio_data)
elif isinstance(audio_data, str):
audio_data = decode_str2bytes(audio_data)
elif isinstance(audio_data, io.BytesIO):
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
else:
raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
return audio_data
def text_postprecess(self, result, data_id='text'):
"""postprecess recognized result."""
text = result[data_id]
if isinstance(text, list):
text = ''.join(text)
return text.replace(' ', '')
def recognize(self, audio_data, queue=None):
"""recognize audio data to text"""
pass
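To make the base-class contract concrete, a minimal hypothetical subclass; EchoSTT is illustrative only and not part of the repository:

```
from utils.stt.base_stt import STTBase

class EchoSTT(STTBase):
    """Hypothetical recognizer used only to illustrate the contract."""
    def recognize(self, audio_data, queue=None):
        audio_data = self.check_audio_type(audio_data)  # normalize input to raw bytes
        return f"received {len(audio_data)} bytes"

stt = EchoSTT(RATE=16000)
stt.add_hotword(["你好", "海绵宝宝"])  # hotwords accumulate into one space-separated string
print(stt.recognize(b'\x00' * 3200))
```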

170
utils/stt/funasr_utils.py Normal file
View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
import io
import numpy as np
import base64
import wave
from funasr import AutoModel
from .base_stt import STTBase
def decode_str2bytes(data):
    # decode a base64-encoded string into raw bytes
if data is None:
return None
return base64.b64decode(data.encode('utf-8'))
class FunAutoSpeechRecognizer(STTBase):
def __init__(self,
model_path="paraformer-zh-streaming",
device="cuda",
RATE=16000,
cfg_path=None,
debug=False,
chunk_ms=480,
encoder_chunk_look_back=4,
decoder_chunk_look_back=1,
**kwargs):
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
#[0, 8, 4] 480ms, [0, 10, 5] 600ms
if chunk_ms == 480:
self.chunk_size = [0, 8, 4]
elif chunk_ms == 600:
self.chunk_size = [0, 10, 5]
else:
raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
self.chunk_partial_size = self.chunk_size[1] * 960
self.audio_cache = None
self.asr_cache = {}
self._init_asr()
def check_audio_type(self, audio_data):
"""check audio data type and convert it to bytes if necessary."""
if isinstance(audio_data, bytes):
pass
elif isinstance(audio_data, list):
audio_data = b''.join(audio_data)
elif isinstance(audio_data, str):
audio_data = decode_str2bytes(audio_data)
elif isinstance(audio_data, io.BytesIO):
wf = wave.open(audio_data, 'rb')
audio_data = wf.readframes(wf.getnframes())
elif isinstance(audio_data, np.ndarray):
pass
else:
raise TypeError(f"audio_data must be bytes, list, str, \
io.BytesIO or numpy array, but got {type(audio_data)}")
if isinstance(audio_data, bytes):
audio_data = np.frombuffer(audio_data, dtype=np.int16)
elif isinstance(audio_data, np.ndarray):
if audio_data.dtype != np.int16:
audio_data = audio_data.astype(np.int16)
else:
raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
return audio_data
    def _init_asr(self):
        # warm up the streaming model with a random chunk of audio
        init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
        self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
        self.audio_cache = None
        self.asr_cache = {}
        # print("init ASR model done.")
def recognize(self, audio_data):
"""recognize audio data to text"""
audio_data = self.check_audio_type(audio_data)
result = self.asr_model.generate(input=audio_data,
batch_size_s=300,
hotword=self.hotwords)
# print(result)
text = ''
for res in result:
text += res['text']
return text
def streaming_recognize(self,
audio_data,
is_end=False,
auto_det_end=False):
"""recognize partial result
Args:
audio_data: bytes or numpy array, partial audio data
is_end: bool, whether the audio data is the end of a sentence
            auto_det_end: bool, whether to automatically detect the end of the audio data
"""
text_dict = dict(text=[], is_end=is_end)
audio_data = self.check_audio_type(audio_data)
if self.audio_cache is None:
self.audio_cache = audio_data
else:
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
if self.audio_cache.shape[0] > 0:
self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
return text_dict
total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
if is_end:
# if the audio data is the end of a sentence, \
# we need to add one more chunk to the end to \
# ensure the end of the sentence is recognized correctly.
auto_det_end = True
if auto_det_end:
total_chunk_num += 1
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
end_idx = None
for i in range(total_chunk_num):
if auto_det_end:
is_end = i == total_chunk_num - 1
start_idx = i*self.chunk_partial_size
if auto_det_end:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
else:
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
# t_stamp = time.time()
speech_chunk = self.audio_cache[start_idx:end_idx]
# TODO: exceptions processes
try:
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
except ValueError as e:
print(f"ValueError: {e}")
continue
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
# print(f"each chunk time: {time.time()-t_stamp}")
if is_end:
self.audio_cache = None
self.asr_cache = {}
else:
if end_idx:
self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
text_dict['is_end'] = is_end
# print(f"text_dict: {text_dict}")
return text_dict
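A streaming-usage sketch for the recognizer above. It assumes the funasr package and the paraformer-zh-streaming model are available locally; the silence buffer merely stands in for real microphone audio:

```
import numpy as np
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

asr = FunAutoSpeechRecognizer(model_path="paraformer-zh-streaming", device="cpu")
pcm = np.zeros(16000 * 3, dtype=np.int16)  # 3 s of silence at 16 kHz
step = asr.chunk_partial_size              # 7680 samples = 480 ms per partial chunk
for start in range(0, len(pcm), step):
    chunk = pcm[start:start + step]
    result = asr.streaming_recognize(chunk, is_end=(start + step >= len(pcm)))
    print(result)                          # {'text': [...], 'is_end': bool}
```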

2
utils/tts/vits/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .text import *
from .monotonic_align import *

302
utils/tts/vits/attentions.py Normal file
View File

@@ -0,0 +1,302 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
# import commons
# from modules import LayerNorm
from utils.tts.vits import commons
from utils.tts.vits.modules import LayerNorm
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class Decoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
super().__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.drop = nn.Dropout(p_dropout)
self.self_attn_layers = nn.ModuleList()
self.norm_layers_0 = nn.ModuleList()
self.encdec_attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
self.ffn_layers = nn.ModuleList()
self.norm_layers_2 = nn.ModuleList()
for i in range(self.n_layers):
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask, h, h_mask):
"""
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
super().__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.p_dropout = p_dropout
self.window_size = window_size
self.heads_share = heads_share
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
self.conv_o = nn.Conv1d(channels, out_channels, 1)
self.drop = nn.Dropout(p_dropout)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
nn.init.xavier_uniform_(self.conv_v.weight)
if proximal_init:
with torch.no_grad():
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
"""
x: [b, h, l, m]
y: [h or 1, m, d]
ret: [b, h, l, d]
"""
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
"""
x: [b, h, l, d]
y: [h or 1, m, d]
ret: [b, h, l, m]
"""
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
"""
x: [b, h, l, 2*l-1]
ret: [b, h, l, l]
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
return x_final
def _absolute_position_to_relative_position(self, x):
"""
x: [b, h, l, l]
ret: [b, h, l, 2*l-1]
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
return x_final
def _attention_bias_proximal(self, length):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
return x * x_mask
def _causal_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
def _same_padding(self, x):
if self.kernel_size == 1:
return x
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, commons.convert_pad_shape(padding))
return x
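A shape-level smoke test for the encoder (a sketch; channel sizes are arbitrary, and the import path assumes this file is utils/tts/vits/attentions.py, as the imports in models.py below suggest):

```
import torch
from utils.tts.vits.attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=2,
              kernel_size=3, p_dropout=0.1)
x = torch.randn(1, 192, 50)    # [batch, channels, time]
x_mask = torch.ones(1, 1, 50)  # all positions valid
print(enc(x, x_mask).shape)    # torch.Size([1, 192, 50])
```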

172
utils/tts/vits/commons.py Normal file
View File

@@ -0,0 +1,172 @@
import math
import torch
from torch.nn import functional as F
import torch.jit
def script_method(fn, _rcb=None):
return fn
def script(obj, optimize=True, _frames_up=0, _rcb=None):
return obj
torch.jit.script_method = script_method
torch.jit.script = script
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
return kl
def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))
def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g
def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret
def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str
def get_timing_signal_1d(
length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(num_timescales - 1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2,3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
norm_type = float(norm_type)
if clip_value is not None:
clip_value = float(clip_value)
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm.item() ** norm_type
if clip_value is not None:
p.grad.data.clamp_(min=-clip_value, max=clip_value)
total_norm = total_norm ** (1. / norm_type)
return total_norm
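Two of these helpers are easy to sanity-check in isolation (a sketch):

```
import torch
from utils.tts.vits import commons

# convert_pad_shape turns per-dimension [left, right] pairs into F.pad's
# flat, last-dimension-first argument: [[0, 0], [0, 0], [1, 0]] -> [1, 0, 0, 0, 0, 0]
print(commons.convert_pad_shape([[0, 0], [0, 0], [1, 0]]))

# sequence_mask builds a boolean mask from lengths: each row is True up to its length
lengths = torch.tensor([2, 4])
print(commons.sequence_mask(lengths))  # shape [2, 4]: [[T, T, F, F], [T, T, T, T]]
```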

535
utils/tts/vits/models.py Normal file
View File

@@ -0,0 +1,535 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
# import commons
# import modules
# import attentions
# import monotonic_align
from utils.tts.vits import commons, modules, attentions, monotonic_align
from utils.tts.vits.commons import init_weights, get_padding
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
# from commons import init_weights, get_padding
class StochasticDurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
super().__init__()
        filter_channels = in_channels  # overrides the argument; slated for removal in a future version
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.n_flows = n_flows
self.gin_channels = gin_channels
self.log_flow = modules.Log()
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
x = torch.detach(x)
x = self.pre(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.convs(x, x_mask)
x = self.proj(x) * x_mask
if not reverse:
flows = self.flows
assert w is not None
logdet_tot_q = 0
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = torch.cat([z0, z1], 1)
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
logw = z0
return logw
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
def forward(self, x, x_mask, g=None):
x = torch.detach(x)
if g is not None:
g = torch.detach(g)
x = x + self.cond(g)
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class TextEncoder(nn.Module):
def __init__(self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return x, m, logs, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
n_flows=4,
gin_channels=0):
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.n_flows = n_flows
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for i in range(n_flows):
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
self.flows.append(modules.Flip())
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, reverse=reverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
gin_channels=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
return z, m, logs, x_mask
class Generator(torch.nn.Module):
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g=None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x)
else:
xs += self.resblocks[i*self.num_kernels+j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2,3,5,7,11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
y_d_gs.append(y_d_g)
fmap_rs.append(fmap_r)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
"""
def __init__(self,
n_vocab,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
**kwargs):
super().__init__()
self.n_vocab = n_vocab
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.use_sdp = use_sdp
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
if use_sdp:
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
else:
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
        if n_speakers > 1:  # note: forward/infer check n_speakers > 0, so n_speakers == 1 would fail here
self.emb_g = nn.Embedding(n_speakers, gin_channels)
def forward(self, x, x_lengths, y, y_lengths, sid=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
z_p = self.flow(z, y_mask, g=g)
with torch.no_grad():
# negative cross-entropy
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
w = attn.sum(2)
if self.use_sdp:
l_length = self.dp(x, x_mask, w, g=g)
l_length = l_length / torch.sum(x_mask)
else:
logw_ = torch.log(w + 1e-6) * x_mask
logw = self.dp(x, x_mask, g=g)
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
# expand prior
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
if self.n_speakers > 0:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = None
if self.use_sdp:
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
else:
logw = self.dp(x, x_mask, g=g)
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = commons.generate_path(w_ceil, attn_mask)
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
return o, attn, y_mask, (z, z_p, m_p, logs_p)
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
assert self.n_speakers > 0, "n_speakers must be larger than 0 for voice conversion."
g_src = self.emb_g(sid_src).unsqueeze(-1)
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
z_p = self.flow(z, y_mask, g=g_src)
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
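# Inference sketch (illustrative only; `hps` is assumed to come from a VITS
# config file and `seq` from text_to_sequence — see utils/tts/vits_utils.py
# for the wiring this repo actually uses):
#
# net_g = SynthesizerTrn(len(symbols), hps.data.filter_length // 2 + 1,
#                        hps.train.segment_size // hps.data.hop_length,
#                        n_speakers=hps.data.n_speakers, **hps.model).eval()
# x = torch.LongTensor(seq).unsqueeze(0)          # [1, t]
# x_lengths = torch.LongTensor([x.size(1)])
# audio = net_g.infer(x, x_lengths, sid=torch.LongTensor([0]), noise_scale=.667,
#                     noise_scale_w=.8, length_scale=1)[0][0, 0]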

390
utils/tts/vits/modules.py Normal file
View File

@ -0,0 +1,390 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm
# import commons
# from commons import init_weights, get_padding
# from transforms import piecewise_rational_quadratic_transform
from utils.tts.vits import commons
from utils.tts.vits.commons import init_weights, get_padding
from utils.tts.vits.transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
assert n_layers > 1, "Number of layers should be larger than 1."
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(
nn.ReLU(),
nn.Dropout(p_dropout))
for _ in range(n_layers-1):
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DDSConv(nn.Module):
"""
Dilated and Depth-Separable Convolution
"""
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.convs_sep = nn.ModuleList()
self.convs_1x1 = nn.ModuleList()
self.norms_1 = nn.ModuleList()
self.norms_2 = nn.ModuleList()
for i in range(n_layers):
dilation = kernel_size ** i
padding = (kernel_size * dilation - dilation) // 2
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
groups=channels, dilation=dilation, padding=padding
))
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
self.norms_1.append(LayerNorm(channels))
self.norms_2.append(LayerNorm(channels))
def forward(self, x, x_mask, g=None):
if g is not None:
x = x + g
for i in range(self.n_layers):
y = self.convs_sep[i](x * x_mask)
y = self.norms_1[i](y)
y = F.gelu(y)
y = self.convs_1x1[i](y)
y = self.norms_2[i](y)
y = F.gelu(y)
y = self.drop(y)
x = x + y
return x * x_mask
class WN(torch.nn.Module):
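# WaveNet-like stack of dilated convolutions with gated tanh/sigmoid
# activations; each layer's 1x1 conv produces both the residual and the skip
# contribution, and an optional global condition g (e.g. a speaker
# embedding) enters every layer through cond_layer.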
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.p_dropout = p_dropout
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
if g is not None:
g = self.cond_layer(g)
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(
x_in,
g_l,
n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts) * x_mask
output = output + res_skip_acts[:,self.hidden_channels:,:]
else:
output = output + res_skip_acts
return output * x_mask
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)
x = xt + x
if x_mask is not None:
x = x * x_mask
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Log(nn.Module):
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
logdet = torch.sum(-y, [1, 2])
return y, logdet
else:
x = torch.exp(x) * x_mask
return x
class Flip(nn.Module):
def forward(self, x, *args, reverse=False, **kwargs):
x = torch.flip(x, [1])
if not reverse:
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
return x, logdet
else:
return x
class ElementwiseAffine(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.m = nn.Parameter(torch.zeros(channels,1))
self.logs = nn.Parameter(torch.zeros(channels,1))
def forward(self, x, x_mask, reverse=False, **kwargs):
if not reverse:
y = self.m + torch.exp(self.logs) * x
y = y * x_mask
logdet = torch.sum(self.logs * x_mask, [1,2])
return y, logdet
else:
x = (x - self.m) * torch.exp(-self.logs) * x_mask
return x
class ResidualCouplingLayer(nn.Module):
def __init__(self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers,
p_dropout=0,
gin_channels=0,
mean_only=False):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.half_channels = channels // 2
self.mean_only = mean_only
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
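# Affine coupling: x0 passes through unchanged and parameterizes a
# per-element shift m and log-scale logs for x1; forward also returns the
# log-determinant (sum of logs), and reverse inverts the map exactly.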
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0) * x_mask
h = self.enc(h, x_mask, g=g)
stats = self.post(h) * x_mask
if not self.mean_only:
m, logs = torch.split(stats, [self.half_channels]*2, 1)
else:
m = stats
logs = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(logs) * x_mask
x = torch.cat([x0, x1], 1)
logdet = torch.sum(logs, [1,2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-logs) * x_mask
x = torch.cat([x0, x1], 1)
return x
class ConvFlow(nn.Module):
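# Coupling layer whose element-wise transform is a monotonic
# rational-quadratic spline (Durkan et al., "Neural Spline Flows"): the
# conditioner predicts num_bins widths, num_bins heights and num_bins - 1
# knot derivatives per element, hence the (num_bins * 3 - 1) factor below.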
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.num_bins = num_bins
self.tail_bound = tail_bound
self.half_channels = in_channels // 2
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask, g=None, reverse=False):
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
h = self.pre(x0)
h = self.convs(h, x_mask, g=g)
h = self.proj(h) * x_mask
b, c, t = x0.shape
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins:]
x1, logabsdet = piecewise_rational_quadratic_transform(x1,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=reverse,
tails='linear',
tail_bound=self.tail_bound
)
x = torch.cat([x0, x1], 1) * x_mask
logdet = torch.sum(logabsdet * x_mask, [1,2])
if not reverse:
return x, logdet
else:
return x

20
utils/tts/vits/monotonic_align/__init__.py Normal file
View File

@ -0,0 +1,20 @@
from numpy import zeros, int32, float32
from torch import from_numpy
from .core import maximum_path_jit
def maximum_path(neg_cent, mask):
""" numba optimized version.
neg_cent: [b, t_t, t_s]
mask: [b, t_t, t_s]
"""
device = neg_cent.device
dtype = neg_cent.dtype
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
path = zeros(neg_cent.shape, dtype=int32)
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
return from_numpy(path).to(device=device, dtype=dtype)

36
utils/tts/vits/monotonic_align/core.py Normal file
View File

@ -0,0 +1,36 @@
import numba
@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
nopython=True, nogil=True)
def maximum_path_jit(paths, values, t_ys, t_xs):
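# Viterbi-style dynamic program for monotonic alignment search: the first
# double loop accumulates, in-place in `values`, the best total score of any
# monotonic path reaching each (frame, token) cell; the second loop
# backtracks from the last cell and writes the chosen path into `paths`.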
b = paths.shape[0]
max_neg_val = -1e9
for i in range(int(b)):
path = paths[i]
value = values[i]
t_y = t_ys[i]
t_x = t_xs[i]
v_prev = v_cur = 0.0
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1

19
utils/tts/vits/text/LICENSE Normal file
View File

@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

57
utils/tts/vits/text/__init__.py Normal file
View File

@ -0,0 +1,57 @@
""" from https://github.com/keithito/tacotron """
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
def text_to_sequence(text, symbols, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
symbols: list of symbols used to build the symbol-to-ID mapping
cleaner_names: names of the cleaner functions to run the text through
Returns:
Tuple of (list of integer symbol IDs, cleaned text)
'''
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
sequence = []
clean_text = _clean_text(text, cleaner_names)
for symbol in clean_text:
if symbol not in _symbol_to_id.keys():
continue
symbol_id = _symbol_to_id[symbol]
sequence += [symbol_id]
return sequence, clean_text
def cleaned_text_to_sequence(cleaned_text):
'''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text.
Args:
cleaned_text: cleaned string to convert to a sequence
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
s = _id_to_symbol[symbol_id]
result += s
return result
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text

475
utils/tts/vits/text/cleaners.py Normal file
View File

@ -0,0 +1,475 @@
""" from https://github.com/keithito/tacotron """
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
# import pyopenjtalk
from jamo import h2j, j2hcj
from pypinyin import lazy_pinyin, BOPOMOFO
import jieba, cn2an
# This is a list of Korean classifiers preceded by pure Korean numerals.
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
# List of (hangul, hangul divided) pairs:
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
('ㄳ', 'ㄱㅅ'),
('ㄵ', 'ㄴㅈ'),
('ㄶ', 'ㄴㅎ'),
('ㄺ', 'ㄹㄱ'),
('ㄻ', 'ㄹㅁ'),
('ㄼ', 'ㄹㅂ'),
('ㄽ', 'ㄹㅅ'),
('ㄾ', 'ㄹㅌ'),
('ㄿ', 'ㄹㅍ'),
('ㅀ', 'ㄹㅎ'),
('ㅄ', 'ㅂㅅ'),
('ㅘ', 'ㅗㅏ'),
('ㅙ', 'ㅗㅐ'),
('ㅚ', 'ㅗㅣ'),
('ㅝ', 'ㅜㅓ'),
('ㅞ', 'ㅜㅔ'),
('ㅟ', 'ㅜㅣ'),
('ㅢ', 'ㅡㅣ'),
('ㅑ', 'ㅣㅏ'),
('ㅒ', 'ㅣㅐ'),
('ㅕ', 'ㅣㅓ'),
('ㅖ', 'ㅣㅔ'),
('ㅛ', 'ㅣㅗ'),
('ㅠ', 'ㅣㅜ')
]]
# List of (Latin alphabet, hangul) pairs:
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', '에이'),
('b', '비'),
('c', '시'),
('d', '디'),
('e', '이'),
('f', '에프'),
('g', '지'),
('h', '에이치'),
('i', '아이'),
('j', '제이'),
('k', '케이'),
('l', '엘'),
('m', '엠'),
('n', '엔'),
('o', '오'),
('p', '피'),
('q', '큐'),
('r', '아르'),
('s', '에스'),
('t', '티'),
('u', '유'),
('v', '브이'),
('w', '더블유'),
('x', '엑스'),
('y', '와이'),
('z', '제트')
]]
# List of (Latin alphabet, bopomofo) pairs:
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('a', 'ㄟˉ'),
('b', 'ㄅㄧˋ'),
('c', 'ㄙㄧˉ'),
('d', 'ㄉㄧˋ'),
('e', 'ㄧˋ'),
('f', 'ㄝˊㄈㄨˋ'),
('g', 'ㄐㄧˋ'),
('h', 'ㄝˇㄑㄩˋ'),
('i', 'ㄞˋ'),
('j', 'ㄐㄟˋ'),
('k', 'ㄎㄟˋ'),
('l', 'ㄝˊㄛˋ'),
('m', 'ㄝˊㄇㄨˋ'),
('n', 'ㄣˉ'),
('o', 'ㄡˉ'),
('p', 'ㄆㄧˉ'),
('q', 'ㄎㄧㄡˉ'),
('r', 'ㄚˋ'),
('s', 'ㄝˊㄙˋ'),
('t', 'ㄊㄧˋ'),
('u', 'ㄧㄡˉ'),
('v', 'ㄨㄧˉ'),
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
('x', 'ㄝˉㄎㄨˋㄙˋ'),
('y', 'ㄨㄞˋ'),
('z', 'ㄗㄟˋ')
]]
# List of (bopomofo, romaji) pairs:
_bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
('ㄅㄛ', 'p⁼wo'),
('ㄆㄛ', 'pʰwo'),
('ㄇㄛ', 'mwo'),
('ㄈㄛ', 'fwo'),
('ㄅ', 'p⁼'),
('ㄆ', 'pʰ'),
('ㄇ', 'm'),
('ㄈ', 'f'),
('ㄉ', 't⁼'),
('ㄊ', 'tʰ'),
('ㄋ', 'n'),
('ㄌ', 'l'),
('ㄍ', 'k⁼'),
('ㄎ', 'kʰ'),
('ㄏ', 'h'),
('ㄐ', 'ʧ⁼'),
('ㄑ', 'ʧʰ'),
('ㄒ', 'ʃ'),
('ㄓ', 'ʦ`⁼'),
('ㄔ', 'ʦ`ʰ'),
('ㄕ', 's`'),
('ㄖ', 'ɹ`'),
('ㄗ', 'ʦ⁼'),
('ㄘ', 'ʦʰ'),
('ㄙ', 's'),
('ㄚ', 'a'),
('ㄛ', 'o'),
('ㄜ', 'ə'),
('ㄝ', 'e'),
('ㄞ', 'ai'),
('ㄟ', 'ei'),
('ㄠ', 'au'),
('ㄡ', 'ou'),
('ㄧㄢ', 'yeNN'),
('ㄢ', 'aNN'),
('ㄧㄣ', 'iNN'),
('ㄣ', 'əNN'),
('ㄤ', 'aNg'),
('ㄧㄥ', 'iNg'),
('ㄨㄥ', 'uNg'),
('ㄩㄥ', 'yuNg'),
('ㄥ', 'əNg'),
('ㄦ', 'əɻ'),
('ㄧ', 'i'),
('ㄨ', 'u'),
('ㄩ', 'ɥ'),
('ˉ', '→'),
('ˊ', '↑'),
('ˇ', '↓↑'),
('ˋ', '↓'),
('˙', ''),
('，', ','),
('。', '.'),
('！', '!'),
('？', '?'),
('—', '-')
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def japanese_to_romaji_with_accent(text):
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
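# NOTE: this function requires pyopenjtalk, whose import is commented out at
# the top of this file; japanese_cleaners and zh_ja_mixture_cleaners will
# raise NameError on Japanese text unless that import is restored.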
sentences = re.split(_japanese_marks, text)
marks = re.findall(_japanese_marks, text)
text = ''
for i, sentence in enumerate(sentences):
if re.match(_japanese_characters, sentence):
if text!='':
text+=' '
labels = pyopenjtalk.extract_fullcontext(sentence)
for n, label in enumerate(labels):
phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
if phoneme not in ['sil','pau']:
text += phoneme.replace('ch','ʧ').replace('sh','ʃ').replace('cl','Q')
else:
continue
n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
a3 = int(re.search(r"\+(\d+)/", label).group(1))
if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil','pau']:
a2_next=-1
else:
a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
# Accent phrase boundary
if a3 == 1 and a2_next == 1:
text += ' '
# Falling
elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
text += '↓'
# Rising
elif a2 == 1 and a2_next == 2:
text += '↑'
if i<len(marks):
text += unidecode(marks[i]).replace(' ','')
return text
def latin_to_hangul(text):
for regex, replacement in _latin_to_hangul:
text = re.sub(regex, replacement, text)
return text
def divide_hangul(text):
for regex, replacement in _hangul_divided:
text = re.sub(regex, replacement, text)
return text
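# sino=True spells a number with Sino-Korean numerals (일, 이, 삼, ...);
# sino=False uses native Korean numerals (한, 두, 세, ...), the forms used
# before counters/classifiers.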
def hangul_number(num, sino=True):
'''Reference https://github.com/Kyubyong/g2pK'''
num = re.sub(',', '', num)
if num == '0':
return '영'
if not sino and num == '20':
return '스무'
digits = '123456789'
names = '일이삼사오육칠팔구'
digit2name = {d: n for d, n in zip(digits, names)}
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
spelledout = []
for i, digit in enumerate(num):
i = len(num) - i - 1
if sino:
if i == 0:
name = digit2name.get(digit, '')
elif i == 1:
name = digit2name.get(digit, '') + '십'
name = name.replace('일십', '십')
else:
if i == 0:
name = digit2mod.get(digit, '')
elif i == 1:
name = digit2dec.get(digit, '')
if digit == '0':
if i % 4 == 0:
last_three = spelledout[-min(3, len(spelledout)):]
if ''.join(last_three) == '':
spelledout.append('')
continue
else:
spelledout.append('')
continue
if i == 2:
name = digit2name.get(digit, '') + '백'
name = name.replace('일백', '백')
elif i == 3:
name = digit2name.get(digit, '') + '천'
name = name.replace('일천', '천')
elif i == 4:
name = digit2name.get(digit, '') + '만'
name = name.replace('일만', '만')
elif i == 5:
name = digit2name.get(digit, '') + '십'
name = name.replace('일십', '십')
elif i == 6:
name = digit2name.get(digit, '') + '백'
name = name.replace('일백', '백')
elif i == 7:
name = digit2name.get(digit, '') + '천'
name = name.replace('일천', '천')
elif i == 8:
name = digit2name.get(digit, '') + '억'
elif i == 9:
name = digit2name.get(digit, '') + '십'
elif i == 10:
name = digit2name.get(digit, '') + '백'
elif i == 11:
name = digit2name.get(digit, '') + '천'
elif i == 12:
name = digit2name.get(digit, '') + '조'
elif i == 13:
name = digit2name.get(digit, '') + '십'
elif i == 14:
name = digit2name.get(digit, '') + '백'
elif i == 15:
name = digit2name.get(digit, '') + '천'
spelledout.append(name)
return ''.join(elem for elem in spelledout)
def number_to_hangul(text):
'''Reference https://github.com/Kyubyong/g2pK'''
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
for token in tokens:
num, classifier = token
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
spelledout = hangul_number(num, sino=False)
else:
spelledout = hangul_number(num, sino=True)
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
# digit by digit for remaining digits
digits = '0123456789'
names = '영일이삼사오육칠팔구'
for d, n in zip(digits, names):
text = text.replace(d, n)
return text
def number_to_chinese(text):
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
for number in numbers:
text = text.replace(number, cn2an.an2cn(number),1)
return text
def chinese_to_bopomofo(text):
text=text.replace('、','，').replace('；','，').replace('：','，')
words=jieba.lcut(text,cut_all=False)
text=''
for word in words:
bopomofos=lazy_pinyin(word,BOPOMOFO)
if not re.search('[\u4e00-\u9fff]',word):
text+=word
continue
for i in range(len(bopomofos)):
if re.match('[\u3105-\u3129]',bopomofos[i][-1]):
bopomofos[i]+='ˉ'
if text!='':
text+=' '
text+=''.join(bopomofos)
return text
def latin_to_bopomofo(text):
for regex, replacement in _latin_to_bopomofo:
text = re.sub(regex, replacement, text)
return text
def bopomofo_to_romaji(text):
for regex, replacement in _bopomofo_to_romaji:
text = re.sub(regex, replacement, text)
return text
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def japanese_cleaners(text):
text=japanese_to_romaji_with_accent(text)
if re.match('[A-Za-z]',text[-1]):
text += '.'
return text
def japanese_cleaners2(text):
return japanese_cleaners(text).replace('ts','ʦ').replace('...','…')
def korean_cleaners(text):
'''Pipeline for Korean text'''
text = latin_to_hangul(text)
text = number_to_hangul(text)
text = j2hcj(h2j(text))
text = divide_hangul(text)
if re.match('[\u3131-\u3163]',text[-1]):
text += '.'
return text
def chinese_cleaners(text):
'''Pipeline for Chinese text'''
text=number_to_chinese(text)
text=chinese_to_bopomofo(text)
text=latin_to_bopomofo(text)
if re.match('[ˉˊˇˋ˙]',text[-1]):
text += '。'
return text
def zh_ja_mixture_cleaners(text):
chinese_texts=re.findall(r'\[ZH\].*?\[ZH\]',text)
japanese_texts=re.findall(r'\[JA\].*?\[JA\]',text)
for chinese_text in chinese_texts:
cleaned_text=number_to_chinese(chinese_text[4:-4])
cleaned_text=chinese_to_bopomofo(cleaned_text)
cleaned_text=latin_to_bopomofo(cleaned_text)
cleaned_text=bopomofo_to_romaji(cleaned_text)
cleaned_text=re.sub('i[aoe]',lambda x:'y'+x.group(0)[1:],cleaned_text)
cleaned_text=re.sub('u[aoəe]',lambda x:'w'+x.group(0)[1:],cleaned_text)
cleaned_text=re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ`'+x.group(2),cleaned_text).replace('ɻ','ɹ`')
cleaned_text=re.sub('([ʦs][⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ'+x.group(2),cleaned_text)
text = text.replace(chinese_text,cleaned_text+' ',1)
for japanese_text in japanese_texts:
cleaned_text=japanese_to_romaji_with_accent(japanese_text[4:-4]).replace('ts','ʦ').replace('u','ɯ').replace('...','…')
text = text.replace(japanese_text,cleaned_text+' ',1)
text=text[:-1]
if re.match('[A-Za-zɯɹəɥ→↓↑]',text[-1]):
text += '.'
return text

39
utils/tts/vits/text/symbols.py Normal file
View File

@ -0,0 +1,39 @@
'''
Defines the set of symbols used in text input to the model.
'''
'''# japanese_cleaners
_pad = '_'
_punctuation = ',.!?-'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
'''
'''# japanese_cleaners2
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
'''
'''# korean_cleaners
_pad = '_'
_punctuation = ',.!?…~'
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
'''
'''# chinese_cleaners
_pad = '_'
_punctuation = ',。!?—…'
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
'''
# zh_ja_mixture_cleaners
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters)
# Special symbol ids
SPACE_ID = symbols.index(" ")

193
utils/tts/vits/transforms.py Normal file
View File

@ -0,0 +1,193 @@
import torch
from torch.nn import functional as F
import numpy as np
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {
'tails': tails,
'tail_bound': tail_bound
}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
)
return outputs, logabsdet
def searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return torch.sum(
inputs[..., None] >= bin_locations,
dim=-1
) - 1
def unconstrained_rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails='linear',
tail_bound=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = torch.zeros_like(inputs)
logabsdet = torch.zeros_like(inputs)
if tails == 'linear':
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError('{} tails are not implemented.'.format(tails))
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative
)
return outputs, logabsdet
def rational_quadratic_spline(inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0., right=1., bottom=0., top=1.,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE):
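# Monotonic rational-quadratic spline (Durkan et al. 2019): softmax turns the
# unnormalized widths/heights into bin sizes of at least min_bin_width /
# min_bin_height, softplus keeps the knot derivatives positive, and the
# closed form below evaluates the spline (or, for inverse=True, its exact
# inverse via the quadratic formula) together with log|det J|.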
if torch.min(inputs) < left or torch.max(inputs) > right:
raise ValueError('Input to a transform is not within its domain')
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError('Minimal bin width too large for the number of bins')
if min_bin_height * num_bins > 1.0:
raise ValueError('Minimal bin height too large for the number of bins')
widths = F.softmax(unnormalized_widths, dim=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = torch.cumsum(widths, dim=-1)
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, dim=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = torch.cumsum(heights, dim=-1)
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
delta = heights / widths
input_delta = delta.gather(-1, bin_idx)[..., 0]
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
input_heights = heights.gather(-1, bin_idx)[..., 0]
if inverse:
a = (((inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta)
+ input_heights * (input_delta - input_derivatives)))
b = (input_heights * input_derivatives
- (inputs - input_cumheights) * (input_derivatives
+ input_derivatives_plus_one
- 2 * input_delta))
c = - input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - torch.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - root).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2)
+ input_derivatives * theta_one_minus_theta)
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
+ 2 * input_delta * theta_one_minus_theta
+ input_derivatives * (1 - theta).pow(2))
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
return outputs, logabsdet

225
utils/tts/vits/utils.py Normal file
View File

@ -0,0 +1,225 @@
import os
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
import librosa
import torch
MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
def load_checkpoint(checkpoint_path, model, optimizer=None):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
iteration = checkpoint_dict['iteration']
learning_rate = checkpoint_dict['learning_rate']
if optimizer is not None:
optimizer.load_state_dict(checkpoint_dict['optimizer'])
saved_state_dict = checkpoint_dict['model']
if hasattr(model, 'module'):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
new_state_dict = {}
for k, v in state_dict.items():
try:
new_state_dict[k] = saved_state_dict[k]
except KeyError:
logger.info("%s is not in the checkpoint" % k)
new_state_dict[k] = v
if hasattr(model, 'module'):
model.module.load_state_dict(new_state_dict)
else:
model.load_state_dict(new_state_dict)
logger.info("Loaded checkpoint '{}' (iteration {})" .format(
checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
def plot_spectrogram_to_numpy(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10,2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(alignment, info=None):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_audio_to_torch(full_path, target_sampling_rate):
audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
return torch.FloatTensor(audio.astype(np.float32))
def load_filepaths_and_text(filename, split="|"):
with open(filename, encoding='utf-8') as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_hparams(init=True):
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
help='JSON file for configuration')
parser.add_argument('-m', '--model', type=str, required=True,
help='Model name')
args = parser.parse_args()
model_dir = os.path.join("./logs", args.model)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
config_path = args.config
config_save_path = os.path.join(model_dir, "config.json")
if init:
with open(config_path, "r") as f:
data = f.read()
with open(config_save_path, "w") as f:
f.write(data)
else:
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_dir(model_dir):
config_save_path = os.path.join(model_dir, "config.json")
with open(config_save_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
hparams.model_dir = model_dir
return hparams
def get_hparams_from_file(config_path):
with open(config_path, "r") as f:
data = f.read()
config = json.loads(data)
hparams = HParams(**config)
return hparams
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir
))
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir, "githash")
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], cur_hash[:8]))
else:
open(path, "w").write(cur_hash)
def get_logger(model_dir, filename="train.log"):
global logger
logger = logging.getLogger(os.path.basename(model_dir))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir):
os.makedirs(model_dir)
h = logging.FileHandler(os.path.join(model_dir, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
class HParams():
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
v = HParams(**v)
self[k] = v
def keys(self):
return self.__dict__.keys()
def items(self):
return self.__dict__.items()
def values(self):
return self.__dict__.values()
def __len__(self):
return len(self.__dict__)
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
return setattr(self, key, value)
def __contains__(self, key):
return key in self.__dict__
def __repr__(self):
return self.__dict__.__repr__()

115
utils/tts/vits_utils.py Normal file
View File

@ -0,0 +1,115 @@
import os
import numpy as np
import torch
from torch import LongTensor
import soundfile as sf
# vits
from .vits import utils, commons
from .vits.models import SynthesizerTrn
from .vits.text import text_to_sequence
def tts_model_init(model_path='./vits_model', device='cuda'):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
return hps_ms, net_g_ms, speakers
class TextToSpeech:
def __init__(self,
model_path="./utils/tts/vits_model",
device='cuda',
RATE=22050,
debug=False,
):
self.debug = debug
self.RATE = RATE
self.device = torch.device(device)
self.limitation = os.getenv("SYSTEM") == "spaces" # limit text and audio length when running in Hugging Face Spaces
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
def _tts_model_init(self, model_path):
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
net_g_ms = SynthesizerTrn(
len(hps_ms.symbols),
hps_ms.data.filter_length // 2 + 1,
hps_ms.train.segment_size // hps_ms.data.hop_length,
n_speakers=hps_ms.data.n_speakers,
**hps_ms.model)
net_g_ms = net_g_ms.eval().to(self.device)
speakers = hps_ms.speakers
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
if self.debug:
print("Model loaded.")
return hps_ms, net_g_ms, speakers
def _get_text(self, text):
text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
if self.hps_ms.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = LongTensor(text_norm)
return text_norm, clean_text
def _preprocess_text(self, text, language):
if language == 0:
return f"[ZH]{text}[ZH]"
elif language == 1:
return f"[JA]{text}[JA]"
return text
def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
import time
start_time = time.time()
stn_tst, clean_text = self._get_text(text)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(self.device)
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
speaker_id = LongTensor([speaker_id]).to(self.device)
audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
if self.debug:
print(f"Synthesis time: {time.time() - start_time} s")
return audio
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
if not len(text):
return "输入文本不能为空!", None
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
if len(text) > 100 and self.limitation:
return f"输入文字过长!{len(text)}>100", None
text = self._preprocess_text(text, language)
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
if self.debug or save_audio:
self.save_audio(audio, self.RATE, 'output_file.wav')
if return_bytes:
audio = self.convert_numpy_to_bytes(audio)
return self.RATE, audio
def convert_numpy_to_bytes(self, audio_data):
if isinstance(audio_data, np.ndarray):
if audio_data.dtype == np.dtype('float32'):
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
audio_data = audio_data.tobytes()
return audio_data
else:
raise TypeError("audio_data must be a numpy array")
def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
sf.write(file_name, audio, samplerate=sample_rate)
print(f"VITS Audio saved to {file_name}")

37
utils/xf_asr_utils.py Normal file
View File

@ -0,0 +1,37 @@
import hashlib
import base64
import hmac
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
from time import mktime
from config import get_config
Config = get_config()
def generate_xf_asr_url():
# set parameters for the iFLYTEK streaming ASR (iat) API
APIKey = Config.XF_ASR.API_KEY
APISecret = Config.XF_ASR.API_SECRET
# authenticate and build the websocket URL
url = 'wss://ws-api.xfyun.cn/v2/iat'
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
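# Sign "host + RFC1123 date + request line" with HMAC-SHA256 and
# base64-encode the result, per iFLYTEK's WebSocket authentication scheme;
# the same fields go into the query string below.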
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
signature_sha = hmac.new(APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
APIKey, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": "ws-api.xfyun.cn"
}
url = url + '?' + urlencode(v)
return url