仓库初始化
This commit is contained in:
commit
83cbe007ba
|
@ -0,0 +1,9 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
/app.log
|
||||
app.log
|
||||
/utils/tts/vits_model/
|
||||
vits_model
|
|
@ -0,0 +1,58 @@
|
|||
# TakwayAI后端
|
||||
|
||||
### 项目架构
|
||||
|
||||
AI后端使用fastapi框架,采用分层架构
|
||||
|
||||
```
|
||||
TakwayAI/
|
||||
│
|
||||
├── app/
|
||||
│ ├── __init__.py # 应用初始化和配置
|
||||
│ ├── main.py # 应用启动入口
|
||||
│ ├── models/ # 数据模型定义
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── models.py # 数据库定义
|
||||
│ ├── schemas/ # 请求和响应模型
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── user.py # 用户相关schema
|
||||
│ │ └── ... # 其他schema
|
||||
│ ├── controllers/ # 业务逻辑控制器
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── user.py # 用户相关控制器
|
||||
│ │ └── ... # 其他控制器
|
||||
│ ├── routes/ # 路由和视图函数
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── user.py # 用户相关路由
|
||||
│ │ └── ... # 其他路由
|
||||
│ ├── dependencies/ # 依赖注入相关
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── database.py # 数据库依赖
|
||||
│ │ └── ... # 其他依赖
|
||||
│ └── exceptions/ # 自定义异常处理
|
||||
│ ├── __init__.py
|
||||
│ └── ... # 自定义异常类
|
||||
│
|
||||
├── tests/ # 测试代码
|
||||
│ ├── __init__.py
|
||||
| └── ...
|
||||
│
|
||||
├── config/ # 配置文件
|
||||
│ ├── __init__.py
|
||||
│ ├── production.py # 生产环境配置
|
||||
│ ├── development.py # 开发环境配置
|
||||
│ └── ... # 其他环境配置
|
||||
│
|
||||
├── utils/ # 工具函数和辅助类
|
||||
│ ├── __init__.py
|
||||
│ ├── stt # 语音转文本工具函数
|
||||
│ ├── tts # 语音合成工具函数
|
||||
│ └── ... # 其他工具函数
|
||||
|
|
||||
├── main.py #启动脚本
|
||||
├── app.log # 日志文件
|
||||
├── Dockerfile # 用于构建Docker镜像的Dockerfile
|
||||
├── requirements.txt # 项目依赖列表
|
||||
└── README.md # 项目说明文件
|
||||
```
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import create_engine
|
||||
from .models import Base
|
||||
from .routes.character import router as character_router
|
||||
from .routes.user import router as user_router
|
||||
from .routes.session import router as session_router
|
||||
from .routes.chat import router as chat_router
|
||||
from .dependencies.logger import get_logger
|
||||
from config import get_config
|
||||
|
||||
|
||||
#----------------------获取日志------------------------
|
||||
logger = get_logger()
|
||||
logger.info("日志初始化完成")
|
||||
#-----------------------------------------------------
|
||||
|
||||
#----------------------获取配置------------------------
|
||||
Config = get_config()
|
||||
logger.info("配置获取完成")
|
||||
#-----------------------------------------------------
|
||||
|
||||
|
||||
#--------------------初始化数据库-----------------------
|
||||
engine = create_engine(Config.SQLALCHEMY_DATABASE_URI)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("数据库初始化完成")
|
||||
#------------------------------------------------------
|
||||
|
||||
|
||||
#--------------------创建FastAPI实例--------------------
|
||||
app = FastAPI()
|
||||
logger.info("FastAPI实例创建完成")
|
||||
#------------------------------------------------------
|
||||
|
||||
|
||||
#---------------------初始化路由------------------------
|
||||
app.include_router(character_router)
|
||||
app.include_router(user_router)
|
||||
app.include_router(session_router)
|
||||
app.include_router(chat_router)
|
||||
logger.info("路由初始化完成")
|
||||
#-------------------------------------------------------
|
||||
|
||||
#-------------------设置跨域中间件-----------------------
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 允许所有源,也可以指定特定源
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # 允许所有方法
|
||||
allow_headers=["*"], # 允许所有头
|
||||
)
|
||||
#-------------------------------------------------------
|
|
@ -0,0 +1 @@
|
|||
from . import *
|
|
@ -0,0 +1,89 @@
|
|||
from ..schemas.character import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models import Character
|
||||
from datetime import datetime
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
#依赖注入获取logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
#创建新角色
|
||||
async def create_character_handler(character: CharacterCreateRequest, db: Session):
|
||||
new_character = Character(voice_id=character.voice_id, avatar_id=character.avatar_id, background_ids=character.background_ids, name=character.name,
|
||||
wakeup_words=character.wakeup_words, world_scenario = character.world_scenario, description=character.description, emojis=character.emojis, dialogues=character.dialogues)
|
||||
try:
|
||||
db.add(new_character)
|
||||
db.commit()
|
||||
db.refresh(new_character)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
character_create_data = CharacterCreateData(character_id=new_character.id, createdAt=datetime.now().isoformat())
|
||||
return CharacterCreateResponse(status="success",message="创建角色成功",data=character_create_data)
|
||||
|
||||
|
||||
|
||||
#更新角色
|
||||
async def update_character_handler(character_id: int, character: CharacterUpdateRequest, db: Session):
|
||||
existing_character = db.query(Character).filter(Character.id == character_id).first()
|
||||
if not existing_character:
|
||||
return HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
|
||||
existing_character.voice_id = character.voice_id
|
||||
existing_character.avatar_id = character.avatar_id
|
||||
existing_character.background_ids = character.background_ids
|
||||
existing_character.name = character.name
|
||||
existing_character.wakeup_words = character.wakeup_words
|
||||
existing_character.world_scenario = character.world_scenario
|
||||
existing_character.description = character.description
|
||||
existing_character.emojis = character.emojis
|
||||
existing_character.dialogues = character.dialogues
|
||||
|
||||
try:
|
||||
db.add(existing_character)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
character_update_data = CharacterUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return CharacterUpdateResponse(status="success",message="更新角色成功",data=character_update_data)
|
||||
|
||||
|
||||
#查询角色
|
||||
async def get_character_handler(character_id: int, db: Session):
|
||||
try:
|
||||
existing_character = db.query(Character).filter(Character.id == character_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
if not existing_character:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
|
||||
character_query_data = CharacterQueryData(voice_id=existing_character.voice_id,avatar_id=existing_character.avatar_id,background_ids=existing_character.background_ids,name=existing_character.name,
|
||||
wakeup_words=existing_character.wakeup_words,world_scenario=existing_character.world_scenario,description=existing_character.description,emojis=existing_character.emojis,dialogues=existing_character.dialogues)
|
||||
return CharacterQueryResponse(status="success",message="查询角色成功",data=character_query_data)
|
||||
|
||||
|
||||
#删除角色
|
||||
async def delete_character_handler(character_id: int, db: Session):
|
||||
|
||||
try:
|
||||
existing_character = db.query(Character).filter(Character.id == character_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
if not existing_character:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
|
||||
try:
|
||||
db.delete(existing_character)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
character_delete_data = CharacterDeleteData(id=character_id,deletedAt=datetime.now().isoformat())
|
||||
return CharacterDeleteResponse(status="success",message="删除角色成功",data=character_delete_data)
|
|
@ -0,0 +1,613 @@
|
|||
from ..schemas.chat import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from .controller_enum import *
|
||||
from ..models import UserCharacter, Session, Character, User
|
||||
from utils.audio_utils import VAD
|
||||
from fastapi import WebSocket, HTTPException, status
|
||||
from datetime import datetime
|
||||
from utils.xf_asr_utils import generate_xf_asr_url
|
||||
from config import get_config
|
||||
import uuid
|
||||
import json
|
||||
import requests
|
||||
import asyncio
|
||||
|
||||
# 依赖注入获取logger
|
||||
logger = get_logger()
|
||||
|
||||
# --------------------初始化本地ASR-----------------------
|
||||
from utils.stt.funasr_utils import FunAutoSpeechRecognizer
|
||||
|
||||
asr = FunAutoSpeechRecognizer()
|
||||
logger.info("本地ASR初始化成功")
|
||||
# -------------------------------------------------------
|
||||
|
||||
# --------------------初始化本地VITS----------------------
|
||||
from utils.tts.vits_utils import TextToSpeech
|
||||
|
||||
tts = TextToSpeech(device='cpu')
|
||||
logger.info("本地TTS初始化成功")
|
||||
# -------------------------------------------------------
|
||||
|
||||
|
||||
# 依赖注入获取Config
|
||||
Config = get_config()
|
||||
|
||||
# ----------------------工具函数-------------------------
|
||||
#获取session内容
|
||||
def get_session_content(session_id,redis,db):
|
||||
session_content_str = ""
|
||||
if redis.exists(session_id):
|
||||
session_content_str = redis.get(session_id)
|
||||
else:
|
||||
session_db = db.query(Session).filter(Session.id == session_id).first()
|
||||
if not session_db:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
session_content_str = session_db.content
|
||||
return json.loads(session_content_str)
|
||||
|
||||
#解析大模型流式返回内容
|
||||
def parseChunkDelta(chunk):
|
||||
decoded_data = chunk.decode('utf-8')
|
||||
parsed_data = json.loads(decoded_data[6:])
|
||||
if 'delta' in parsed_data['choices'][0]:
|
||||
delta_content = parsed_data['choices'][0]['delta']
|
||||
return delta_content['content']
|
||||
else:
|
||||
return ""
|
||||
|
||||
#断句函数
|
||||
def split_string_with_punctuation(current_sentence,text,is_first):
|
||||
result = []
|
||||
for char in text:
|
||||
current_sentence += char
|
||||
if is_first and char in ',.?!,。?!':
|
||||
result.append(current_sentence)
|
||||
current_sentence = ''
|
||||
is_first = False
|
||||
elif char in '。?!':
|
||||
result.append(current_sentence)
|
||||
current_sentence = ''
|
||||
return result, current_sentence, is_first
|
||||
#--------------------------------------------------------
|
||||
|
||||
# 创建新聊天
|
||||
async def create_chat_handler(chat: ChatCreateRequest, db, redis):
|
||||
# 创建新的UserCharacter记录
|
||||
new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id)
|
||||
|
||||
try:
|
||||
db.add(new_chat)
|
||||
db.commit()
|
||||
db.refresh(new_chat)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
# 查询所要创建聊天的角色信息,并创建SystemPrompt
|
||||
db_character = db.query(Character).filter(Character.id == chat.character_id).first()
|
||||
db_user = db.query(User).filter(User.id == chat.user_id).first()
|
||||
system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}。\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n
|
||||
{"角色的常用问候语: " + db_character.wakeup_words}。\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\n 与你聊天的对象信息如下:{db_user.persona}"""
|
||||
|
||||
|
||||
# 创建新的Session记录
|
||||
session_id = str(uuid.uuid4())
|
||||
user_id = chat.user_id
|
||||
messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False)
|
||||
|
||||
tts_info = {
|
||||
"language": 0,
|
||||
"speaker_id":db_character.voice_id,
|
||||
"noise_scale": 0.1,
|
||||
"noise_scale_w":0.668,
|
||||
"length_scale": 1.2
|
||||
}
|
||||
llm_info = {
|
||||
"model": "abab5.5-chat",
|
||||
"temperature": 1,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
|
||||
# 将tts和llm信息转化为json字符串
|
||||
tts_info_str = json.dumps(tts_info, ensure_ascii=False)
|
||||
llm_info_str = json.dumps(llm_info, ensure_ascii=False)
|
||||
user_info_str = db_user.persona
|
||||
|
||||
token = 0
|
||||
content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
|
||||
"llm_info": llm_info_str, "token": token}
|
||||
new_session = Session(id=session_id, user_character_id=new_chat.id, content=json.dumps(content, ensure_ascii=False),
|
||||
last_activity=datetime.now(), is_permanent=False)
|
||||
|
||||
# 将Session记录存入
|
||||
db.add(new_session)
|
||||
db.commit()
|
||||
db.refresh(new_session)
|
||||
redis.set(session_id, json.dumps(content, ensure_ascii=False))
|
||||
|
||||
chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
|
||||
return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data)
|
||||
|
||||
|
||||
#删除聊天
|
||||
async def delete_chat_handler(user_character_id, db, redis):
|
||||
# 查询该聊天记录
|
||||
user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first()
|
||||
if not user_character_record:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found")
|
||||
session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first()
|
||||
try:
|
||||
redis.delete(session_record.id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除Redis中Session记录时发生错误: {str(e)}")
|
||||
try:
|
||||
db.delete(session_record)
|
||||
db.delete(user_character_record)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat())
|
||||
return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data)
|
||||
|
||||
|
||||
# 非流式聊天
|
||||
async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
#---------------------------------------单次流式聊天接口---------------------------------------------
|
||||
#处理用户输入
|
||||
async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event):
|
||||
logger.debug("用户输入处理函数启动")
|
||||
is_future_done = False
|
||||
try:
|
||||
while not user_input_finish_event.is_set():
|
||||
sct_data_json = json.loads(await ws.receive_text())
|
||||
if not is_future_done:
|
||||
future_session_id.set_result(sct_data_json['meta_info']['session_id'])
|
||||
if sct_data_json['meta_info']['voice_synthesize']:
|
||||
future_response_type.set_result(RESPONSE_AUDIO)
|
||||
else:
|
||||
future_response_type.set_result(RESPONSE_TEXT)
|
||||
is_future_done = True
|
||||
if sct_data_json['text']:
|
||||
await llm_input_q.put(sct_data_json['text'])
|
||||
if not user_input_finish_event.is_set():
|
||||
user_input_finish_event.set()
|
||||
break
|
||||
if sct_data_json['meta_info']['is_end']:
|
||||
await user_input_q.put(sct_data_json['audio'])
|
||||
if not user_input_finish_event.is_set():
|
||||
user_input_finish_event.set()
|
||||
break
|
||||
await user_input_q.put(sct_data_json['audio'])
|
||||
except KeyError as ke:
|
||||
if sct_data_json['state'] == 1 and sct_data_json['method'] == 'heartbeat':
|
||||
logger.debug("收到心跳包")
|
||||
|
||||
#语音识别
|
||||
async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event):
|
||||
logger.debug("语音识别函数启动")
|
||||
current_message = ""
|
||||
while not (user_input_finish_event.is_set() and user_input_q.empty()):
|
||||
audio_data = await user_input_q.get()
|
||||
asr_result = asr.streaming_recognize(audio_data)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
asr_result = asr.streaming_recognize(b'',is_end=True)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
await llm_input_q.put(current_message)
|
||||
logger.debug(f"接收到用户消息: {current_message}")
|
||||
|
||||
#大模型调用
|
||||
async def sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event):
|
||||
logger.debug("llm调用函数启动")
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
current_message = await llm_input_q.get()
|
||||
messages.append({'role': 'user', "content": current_message})
|
||||
payload = json.dumps({
|
||||
"model": llm_info["model"],
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"max_tokens": 10000,
|
||||
"temperature": llm_info["temperature"],
|
||||
"top_p": llm_info["top_p"]
|
||||
})
|
||||
headers = {
|
||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
|
||||
if response.status_code == 200:
|
||||
for chunk in response.iter_lines():
|
||||
if chunk:
|
||||
chunk_data = parseChunkDelta(chunk)
|
||||
await llm_response_q.put(chunk_data)
|
||||
llm_response_finish_event.set()
|
||||
|
||||
#大模型返回断句
|
||||
async def sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event):
|
||||
logger.debug("llm返回处理函数启动")
|
||||
llm_response = ""
|
||||
current_sentence = ""
|
||||
is_first = True
|
||||
while not (llm_response_finish_event.is_set() and llm_response_q.empty()):
|
||||
llm_chunk = await llm_response_q.get()
|
||||
llm_response += llm_chunk
|
||||
sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_chunk, is_first)
|
||||
for sentence in sentences:
|
||||
await split_result_q.put(sentence)
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
|
||||
#文本返回及语音合成
|
||||
async def sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event):
|
||||
logger.debug("返回处理函数启动")
|
||||
while not (llm_response_finish_event.is_set() and split_result_q.empty()):
|
||||
sentence = await split_result_q.get()
|
||||
if response_type == RESPONSE_TEXT:
|
||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||
elif response_type == RESPONSE_AUDIO:
|
||||
sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
|
||||
response_message = {"type": "text", "code":200, "msg": sentence}
|
||||
await ws.send_bytes(audio)
|
||||
await ws.send_text(json.dumps(response_message, ensure_ascii=False))
|
||||
logger.debug(f"websocket返回: {sentence}")
|
||||
chat_finish_event.set()
|
||||
|
||||
async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
|
||||
logger.debug("streaming chat temporary websocket 连接建立")
|
||||
user_input_q = asyncio.Queue() # 用于存储用户输入
|
||||
llm_input_q = asyncio.Queue() # 用于存储llm输入
|
||||
llm_response_q = asyncio.Queue() # 用于存储llm输出
|
||||
split_result_q = asyncio.Queue() # 用于存储tts输出
|
||||
|
||||
user_input_finish_event = asyncio.Event()
|
||||
llm_response_finish_event = asyncio.Event()
|
||||
chat_finish_event = asyncio.Event()
|
||||
future_session_id = asyncio.Future()
|
||||
future_response_type = asyncio.Future()
|
||||
asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event))
|
||||
asyncio.create_task(sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event))
|
||||
|
||||
|
||||
session_id = await future_session_id #获取session_id
|
||||
response_type = await future_response_type #获取返回类型
|
||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||
|
||||
asyncio.create_task(sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event))
|
||||
asyncio.create_task(sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event))
|
||||
asyncio.create_task(sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event))
|
||||
|
||||
while not chat_finish_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False))
|
||||
await ws.close()
|
||||
logger.debug("streaming chat temporary websocket 连接断开")
|
||||
#---------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# 持续流式聊天
|
||||
async def streaming_chat_lasting_handler(ws, db, redis):
|
||||
print("Websocket连接成功")
|
||||
while True:
|
||||
try:
|
||||
print("等待接受")
|
||||
data = await asyncio.wait_for(ws.receive_text(), timeout=60)
|
||||
data_json = json.loads(data)
|
||||
if data_json["is_close"]:
|
||||
close_message = {"type": "close", "code": 200, "msg": ""}
|
||||
await ws.send_text(json.dumps(close_message, ensure_ascii=False))
|
||||
print("连接关闭")
|
||||
await asyncio.sleep(0.5)
|
||||
await ws.close()
|
||||
return;
|
||||
except asyncio.TimeoutError:
|
||||
print("连接超时")
|
||||
await ws.close()
|
||||
return;
|
||||
current_message = "" # 用于存储用户消息
|
||||
response_type = RESPONSE_TEXT # 用于获取返回类型
|
||||
session_id = ""
|
||||
if Config.STRAM_CHAT.ASR == "LOCAL":
|
||||
try:
|
||||
while True:
|
||||
if data_json["text"]: # 若文字输入不为空,则表示该输入为文字输入
|
||||
if data_json["meta_info"]["voice_synthesize"]:
|
||||
response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
|
||||
session_id = data_json["meta_info"]["session_id"]
|
||||
current_message = data_json['text']
|
||||
break
|
||||
|
||||
if not data_json['meta_info']['is_end']: # 还在发
|
||||
asr_result = asr.streaming_recognize(data_json["audio"])
|
||||
current_message += ''.join(asr_result['text'])
|
||||
else: # 发完了
|
||||
asr_result = asr.streaming_recognize(data_json["audio"], is_end=True)
|
||||
session_id = data_json["meta_info"]["session_id"]
|
||||
current_message += ''.join(asr_result['text'])
|
||||
if data_json["meta_info"]["voice_synthesize"]:
|
||||
response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型
|
||||
break
|
||||
data_json = json.loads(await ws.receive_text())
|
||||
|
||||
except Exception as e:
|
||||
error_info = f"接收用户消息错误: {str(e)}"
|
||||
error_message = {"type": "error", "code": "500", "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
elif Config.STRAM_CHAT.ASR == "REMOTE":
|
||||
error_info = f"远程ASR服务暂未开通"
|
||||
error_message = {"type": "error", "code": "500", "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
|
||||
print(f"接收到用户消息: {current_message}")
|
||||
# 查询Session记录
|
||||
session_content_str = ""
|
||||
if redis.exists(session_id):
|
||||
session_content_str = redis.get(session_id)
|
||||
else:
|
||||
session_db = db.query(Session).filter(Session.id == session_id).first()
|
||||
if not session_db:
|
||||
error_info = f"未找到session记录: {str(e)}"
|
||||
error_message = {"type": "error", "code": 500, "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
session_content_str = session_db.content
|
||||
|
||||
session_content = json.loads(session_content_str)
|
||||
llm_info = json.loads(session_content["llm_info"])
|
||||
tts_info = json.loads(session_content["tts_info"])
|
||||
user_info = json.loads(session_content["user_info"])
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'user', "content": current_message})
|
||||
token_count = session_content["token"]
|
||||
|
||||
try:
|
||||
payload = json.dumps({
|
||||
"model": llm_info["model"],
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"tool_choice": "auto",
|
||||
"max_tokens": 10000,
|
||||
"temperature": llm_info["temperature"],
|
||||
"top_p": llm_info["top_p"]
|
||||
})
|
||||
headers = {
|
||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
|
||||
except Exception as e:
|
||||
error_info = f"发送信息给大模型时发生错误: {str(e)}"
|
||||
error_message = {"type": "error", "code": 500, "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
|
||||
def split_string_with_punctuation(text):
|
||||
punctuations = "!?。"
|
||||
result = []
|
||||
current_sentence = ""
|
||||
for char in text:
|
||||
current_sentence += char
|
||||
if char in punctuations:
|
||||
result.append(current_sentence)
|
||||
current_sentence = ""
|
||||
# 判断最后一个字符是否为标点符号
|
||||
if current_sentence and current_sentence[-1] not in punctuations:
|
||||
# 如果最后一段不以标点符号结尾,则加入拆分数组
|
||||
result.append(current_sentence)
|
||||
return result
|
||||
|
||||
llm_response = ""
|
||||
response_buf = ""
|
||||
|
||||
try:
|
||||
if Config.STRAM_CHAT.TTS == "LOCAL":
|
||||
if response.status_code == 200:
|
||||
for chunk in response.iter_lines():
|
||||
if chunk:
|
||||
if response_type == RESPONSE_AUDIO:
|
||||
chunk_data = parseChunkDelta(chunk)
|
||||
llm_response += chunk_data
|
||||
response_buf += chunk_data
|
||||
split_buf = split_string_with_punctuation(response_buf)
|
||||
response_buf = ""
|
||||
if len(split_buf) != 0:
|
||||
for sentence in split_buf:
|
||||
sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
|
||||
text_response = {"type": "text", "code": 200, "msg": sentence}
|
||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据
|
||||
await ws.send_bytes(audio) # 返回音频二进制流数据
|
||||
if response_type == RESPONSE_TEXT:
|
||||
chunk_data = parseChunkDelta(chunk)
|
||||
llm_response += chunk_data
|
||||
text_response = {"type": "text", "code": 200, "msg": chunk_data}
|
||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) # 返回文本数据
|
||||
|
||||
elif Config.STRAM_CHAT.TTS == "REMOTE":
|
||||
error_info = f"暂不支持远程音频合成"
|
||||
error_message = {"type": "error", "code": 500, "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
end_response = {"type": "end", "code": 200, "msg": ""}
|
||||
await ws.send_text(json.dumps(end_response, ensure_ascii=False)) # 单次返回结束
|
||||
print(f"llm消息: {llm_response}")
|
||||
except Exception as e:
|
||||
error_info = f"音频合成与向前端返回时错误: {str(e)}"
|
||||
error_message = {"type": "error", "code": 500, "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
|
||||
try:
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
token_count += len(llm_response)
|
||||
session_content["messages"] = json.dumps(messages, ensure_ascii=False)
|
||||
session_content["token"] = token_count
|
||||
redis.set(session_id, json.dumps(session_content, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
error_info = f"更新session时错误: {str(e)}"
|
||||
error_message = {"type": "error", "code": 500, "msg": error_info}
|
||||
logger.error(error_info)
|
||||
await ws.send_text(json.dumps(error_message, ensure_ascii=False))
|
||||
await ws.close()
|
||||
return
|
||||
print("处理完毕")
|
||||
|
||||
|
||||
#--------------------------------语音通话接口--------------------------------------
|
||||
#音频数据生产函数
|
||||
async def voice_call_audio_producer(ws,audio_queue,future,shutdown_event):
|
||||
logger.debug("音频数据生产函数启动")
|
||||
is_future_done = False
|
||||
while not shutdown_event.is_set():
|
||||
voice_call_data_json = json.loads(await ws.receive_text())
|
||||
if not is_future_done: #在第一次循环中读取session_id
|
||||
future.set_result(voice_call_data_json['meta_info']['session_id'])
|
||||
is_future_done = True
|
||||
if voice_call_data_json["is_close"]:
|
||||
shutdown_event.set()
|
||||
break
|
||||
else:
|
||||
await audio_queue.put(voice_call_data_json["audio"]) #将音频数据存入audio_q
|
||||
|
||||
#音频数据消费函数
|
||||
async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event):
|
||||
logger.debug("音频数据消费者函数启动")
|
||||
vad = VAD()
|
||||
current_message = ""
|
||||
vad_count = 0
|
||||
while not (shutdown_event.is_set() and audio_q.empty()):
|
||||
audio_data = await audio_q.get()
|
||||
if vad.is_speech(audio_data):
|
||||
if vad_count > 0:
|
||||
vad_count -= 1
|
||||
asr_result = asr.streaming_recognize(audio_data)
|
||||
current_message += ''.join(asr_result['text'])
|
||||
else:
|
||||
vad_count += 1
|
||||
if vad_count >= 25: #连续25帧没有语音,则认为说完了
|
||||
asr_result = asr.streaming_recognize(audio_data, is_end=True)
|
||||
if current_message:
|
||||
logger.debug(f"检测到静默,用户输入为:{current_message}")
|
||||
await asr_result_q.put(current_message)
|
||||
current_message = ""
|
||||
vad_count = 0
|
||||
|
||||
#asr结果消费以及llm返回生产函数
|
||||
async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event):
|
||||
logger.debug("asr结果消费以及llm返回生产函数启动")
|
||||
while not (shutdown_event.is_set() and asr_result_q.empty()):
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
current_message = await asr_result_q.get()
|
||||
messages.append({'role': 'user', "content": current_message})
|
||||
payload = json.dumps({
|
||||
"model": llm_info["model"],
|
||||
"stream": True,
|
||||
"messages": messages,
|
||||
"max_tokens":10000,
|
||||
"temperature": llm_info["temperature"],
|
||||
"top_p": llm_info["top_p"]
|
||||
})
|
||||
|
||||
headers = {
|
||||
'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
|
||||
if response.status_code == 200:
|
||||
for chunk in response.iter_lines():
|
||||
if chunk:
|
||||
chunk_data =parseChunkDelta(chunk)
|
||||
llm_frame = {'message':chunk_data,'is_end':False}
|
||||
await llm_response_q.put(llm_frame)
|
||||
llm_frame = {'message':"",'is_end':True}
|
||||
await llm_response_q.put(llm_frame)
|
||||
|
||||
#llm结果返回函数
|
||||
async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event):
|
||||
logger.debug("llm结果返回函数启动")
|
||||
llm_response = ""
|
||||
current_sentence = ""
|
||||
is_first = True
|
||||
while not (shutdown_event.is_set() and llm_response_q.empty()):
|
||||
llm_frame = await llm_response_q.get()
|
||||
llm_response += llm_frame['message']
|
||||
sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
|
||||
for sentence in sentences:
|
||||
await split_result_q.put(sentence)
|
||||
if llm_frame['is_end']:
|
||||
is_first = True
|
||||
session_content = get_session_content(session_id,redis,db)
|
||||
messages = json.loads(session_content["messages"])
|
||||
messages.append({'role': 'assistant', "content": llm_response})
|
||||
session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话
|
||||
redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session
|
||||
logger.debug(f"llm返回结果: {llm_response}")
|
||||
llm_response = ""
|
||||
current_sentence = ""
|
||||
|
||||
#语音合成及返回函数
|
||||
async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
|
||||
logger.debug("语音合成及返回函数启动")
|
||||
while not (shutdown_event.is_set() and split_result_q.empty()):
|
||||
sentence = await split_result_q.get()
|
||||
sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
|
||||
text_response = {"type": "text", "code": 200, "msg": sentence}
|
||||
await ws.send_bytes(audio) #返回音频二进制流数据
|
||||
await ws.send_text(json.dumps(text_response, ensure_ascii=False)) #返回文本数据
|
||||
logger.debug(f"websocket返回:{sentence}")
|
||||
asyncio.sleep(0.5)
|
||||
await ws.close()
|
||||
|
||||
|
||||
async def voice_call_handler(ws, db, redis):
|
||||
logger.debug("voice_call websocket 连接建立")
|
||||
audio_q = asyncio.Queue()
|
||||
asr_result_q = asyncio.Queue()
|
||||
llm_response_q = asyncio.Queue()
|
||||
split_result_q = asyncio.Queue()
|
||||
|
||||
shutdown_event = asyncio.Event()
|
||||
future = asyncio.Future()
|
||||
asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,shutdown_event)) #创建音频数据生产者
|
||||
asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event)) #创建音频数据消费者
|
||||
|
||||
#获取session内容
|
||||
session_id = await future #获取session_id
|
||||
tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
|
||||
llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
|
||||
|
||||
asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event)) #创建llm处理者
|
||||
asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event)) #创建llm断句结果
|
||||
asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event)) #返回tts音频结果
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
await asyncio.sleep(5)
|
||||
await ws.close()
|
||||
logger.debug("voice_call websocket 连接断开")
|
||||
#------------------------------------------------------------------------------------------
|
|
@ -0,0 +1,13 @@
|
|||
#聊天类型:文本、语音、不确定
|
||||
CHAT_TEXT = 0
|
||||
CHAT_AUDIO = 1
|
||||
CHAT_UNCERTAIN = -1
|
||||
|
||||
#流式传输帧类型:第一帧、中间帧、最后帧
|
||||
FIRST_FRAME = 1
|
||||
CONTINUE_FRAME =2
|
||||
LAST_FRAME =3
|
||||
|
||||
#响应类型:文本、语音
|
||||
RESPONSE_TEXT = 0
|
||||
RESPONSE_AUDIO = 1
|
|
@ -0,0 +1,79 @@
|
|||
from ..schemas.session import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from fastapi import HTTPException, status
|
||||
from ..models import Session
|
||||
from ..models import UserCharacter
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
#依赖注入获取logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
#获取SessionID
|
||||
async def get_session_id_handler(user_id: int, character_id:int, db):
|
||||
try:
|
||||
user_character_record = db.query(UserCharacter).filter(UserCharacter.user_id == user_id, UserCharacter.character_id == character_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
if not user_character_record:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User Character not found")
|
||||
try:
|
||||
session_id = db.query(Session).filter(Session.user_character_id==user_character_record.id).first().id
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
session_id_query_data = SessionIdQueryData(session_id=session_id)
|
||||
return SessionIdQueryResponse(status="success",message="Session ID 获取成功",data=session_id_query_data)
|
||||
|
||||
#查询Session信息
|
||||
async def get_session_handler(session_id: int, db, redis):
|
||||
session_str = ""
|
||||
if redis.exists(session_id):
|
||||
session_str = redis.get(session_id)
|
||||
else:
|
||||
try:
|
||||
session_str = db.query(Session).filter(Session.id == session_id).first().content
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
redis.set(session_id, session_str)
|
||||
if not session_str:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
session = json.loads(session_str)
|
||||
session_query_data = SessionQueryData(user_id=session["user_id"], messages=session["messages"],user_info=session["user_info"],tts_info=session["tts_info"],llm_info=session["llm_info"],token=session["token"])
|
||||
return SessionQueryResponse(status="success",message="Session 查询成功",data=session_query_data)
|
||||
|
||||
#更新Sessino信息
|
||||
async def update_session_handler(session_id, session_data:SessionUpdateRequest, db, redis):
|
||||
existing_session = ""
|
||||
if redis.exists(session_id):
|
||||
existing_session = redis.get(session_id)
|
||||
else:
|
||||
existing_session = db.query(Session).filter(Session.id == session_id).first().content
|
||||
if not existing_session:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
|
||||
|
||||
#更新Session字段
|
||||
session = json.loads(existing_session)
|
||||
session["user_id"] = session_data.user_id
|
||||
session["messages"] = session_data.messages
|
||||
session["user_info"] = session_data.user_info
|
||||
session["tts_info"] = session_data.tts_info
|
||||
session["llm_info"] = session_data.llm_info
|
||||
session["token"] = session_data.token
|
||||
|
||||
#存储Session
|
||||
session_str = json.dumps(session,ensure_ascii=False)
|
||||
redis.set(session_id, session_str)
|
||||
try:
|
||||
db.query(Session).filter(Session.id == session_id).update({"content": session_str})
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
|
|
@ -0,0 +1,159 @@
|
|||
from ..schemas.user import *
|
||||
from ..dependencies.logger import get_logger
|
||||
from ..models import User, Hardware
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
#依赖注入获取logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
#创建用户
|
||||
async def create_user_handler(user:UserCrateRequest, db: Session):
|
||||
new_user = User(created_at=datetime.now(), open_id=user.open_id, username=user.username, avatar_id=user.avatar_id, tags=user.tags, persona=user.persona)
|
||||
try:
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
user_create_data = UserCrateData(user_id=new_user.id, createdAt=new_user.created_at.isoformat())
|
||||
return UserCrateResponse(status="success", message="创建用户成功", data=user_create_data)
|
||||
|
||||
|
||||
#更新用户信息
|
||||
async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
|
||||
existing_user = db.query(User).filter(User.id == user_id).first()
|
||||
if existing_user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
existing_user.open_id = user.open_id
|
||||
existing_user.username = user.username
|
||||
existing_user.avatar_id = user.avatar_id
|
||||
existing_user.tags = user.tags
|
||||
existing_user.persona = user.persona
|
||||
try:
|
||||
db.add(existing_user)
|
||||
db.commit()
|
||||
db.refresh(existing_user)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
user_update_data = UserUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return UserUpdateResponse(status="success", message="更新用户信息成功", data=user_update_data)
|
||||
|
||||
|
||||
#查询用户信息
|
||||
async def get_user_handler(user_id:int, db: Session):
|
||||
try:
|
||||
existing_user = db.query(User).filter(User.id == user_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
user_query_data = UserQueryData(open_id=existing_user.open_id, username=existing_user.username, avatar_id=existing_user.avatar_id, tags=existing_user.tags, persona=existing_user.persona)
|
||||
return UserQueryResponse(status="success", message="查询用户信息成功", data=user_query_data)
|
||||
|
||||
|
||||
#删除用户
|
||||
async def delete_user_handler(user_id:int, db: Session):
|
||||
try:
|
||||
existing_user = db.query(User).filter(User.id == user_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
|
||||
try:
|
||||
db.delete(existing_user)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
user_delete_data = UserDeleteData(deletedAt=datetime.now().isoformat())
|
||||
return UserDeleteResponse(status="success", message="删除用户成功", data=user_delete_data)
|
||||
|
||||
|
||||
#绑定硬件
|
||||
async def bind_hardware_handler(hardware, db: Session):
|
||||
new_hardware = Hardware(mac=hardware.mac, user_id=hardware.user_id, firmware=hardware.firmware, model=hardware.model)
|
||||
try:
|
||||
db.add(new_hardware)
|
||||
db.commit()
|
||||
db.refresh(new_hardware)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
hardware_bind_data = HardwareBindData(hardware_id=new_hardware.id, bindedAt=datetime.now().isoformat())
|
||||
return HardwareBindResponse(status="success", message="绑定硬件成功", data=hardware_bind_data)
|
||||
|
||||
|
||||
#解绑硬件
|
||||
async def unbind_hardware_handler(hardware_id:int, db: Session):
|
||||
try:
|
||||
existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_hardware is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||
try:
|
||||
db.delete(existing_hardware)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
hardware_unbind_data = HardwareUnbindData(unbindedAt=datetime.now().isoformat())
|
||||
return HardwareUnbindResponse(status="success", message="解绑硬件成功", data=hardware_unbind_data)
|
||||
|
||||
|
||||
#硬件换绑
|
||||
async def change_bind_hardware_handler(hardware_id, user, db):
|
||||
try:
|
||||
existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
|
||||
if existing_hardware is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||
existing_hardware.user_id = user.user_id
|
||||
db.add(existing_hardware)
|
||||
db.commit()
|
||||
db.refresh(existing_hardware)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
hardware_change_bind_data = HardwareChangeBindData(bindChangedAt=datetime.now().isoformat())
|
||||
return HardwareChangeBindResponse(status="success", message="硬件换绑成功", data=hardware_change_bind_data)
|
||||
|
||||
|
||||
#硬件信息更新
|
||||
async def update_hardware_handler(hardware_id, hardware, db):
|
||||
try:
|
||||
existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
|
||||
if existing_hardware is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||
existing_hardware.mac = hardware.mac
|
||||
existing_hardware.firmware = hardware.firmware
|
||||
existing_hardware.model = hardware.model
|
||||
db.add(existing_hardware)
|
||||
db.commit()
|
||||
db.refresh(existing_hardware)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
hardware_update_data = HardwareUpdateData(updatedAt=datetime.now().isoformat())
|
||||
return HardwareUpdateResponse(status="success", message="硬件信息更新成功", data=hardware_update_data)
|
||||
|
||||
|
||||
#查询硬件
|
||||
async def get_hardware_handler(hardware_id, db):
|
||||
try:
|
||||
existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
if existing_hardware is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
|
||||
hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
|
||||
return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)
|
|
@ -0,0 +1 @@
|
|||
from . import *
|
|
@ -0,0 +1,15 @@
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import create_engine
|
||||
from config import get_config
|
||||
|
||||
Config = get_config()
|
||||
|
||||
LocalSession = sessionmaker(autocommit=False, autoflush=False, bind=create_engine(Config.SQLALCHEMY_DATABASE_URI))
|
||||
|
||||
#返回一个数据库连接
|
||||
def get_db():
|
||||
db = LocalSession()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
|
@ -0,0 +1,27 @@
|
|||
import logging
|
||||
from config import get_config
|
||||
|
||||
#获取配置信息
|
||||
Config = get_config()
|
||||
|
||||
#日志类
|
||||
class Logger:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.setLevel(Config.LOG_LEVEL)
|
||||
self.logger.propagate = False
|
||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
if not self.logger.handlers: # 检查是否已经有处理器
|
||||
# 输出到控制台
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
# 输出到文件
|
||||
file_handler = logging.FileHandler('app.log')
|
||||
file_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
def get_logger():
|
||||
return Logger().logger
|
|
@ -0,0 +1,8 @@
|
|||
import redis
|
||||
from config import get_config
|
||||
|
||||
#获取配置信息
|
||||
Config = get_config()
|
||||
|
||||
def get_redis():
|
||||
return redis.Redis.from_url(Config.REDIS_URL)
|
|
@ -0,0 +1,6 @@
|
|||
from app import app
|
||||
import uvicorn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=7878)
|
|
@ -0,0 +1 @@
|
|||
from .models import *
|
|
@ -0,0 +1,82 @@
|
|||
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
#角色表定义
|
||||
class Character(Base):
|
||||
__tablename__ = 'character'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
voice_id = Column(Integer, nullable=False)
|
||||
avatar_id = Column(String(36), nullable=False)
|
||||
background_ids = Column(String(255), nullable=False)
|
||||
name = Column(String(36), nullable=False)
|
||||
wakeup_words = Column(String(255), nullable=False)
|
||||
world_scenario = Column(Text, nullable=False)
|
||||
description = Column(Text, nullable=False)
|
||||
emojis = Column(JSON, nullable=False)
|
||||
dialogues = Column(Text, nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Character(id={self.id}, name={self.name}, avatar_id={self.avatar_id})>"
|
||||
|
||||
|
||||
#用户表定义
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
created_at = Column(DateTime, nullable=True)
|
||||
updated_at = Column(DateTime, nullable=True)
|
||||
deleted_at = Column(DateTime, nullable=True)
|
||||
open_id = Column(String(255), nullable=True)
|
||||
username = Column(String(64), nullable=True)
|
||||
avatar_id = Column(String(36), nullable=True)
|
||||
tags = Column(JSON)
|
||||
persona = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, tags={self.tags})>"
|
||||
|
||||
|
||||
#硬件表定义
|
||||
class Hardware(Base):
|
||||
__tablename__ = 'hardware'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('user.id'))
|
||||
mac = Column(String(17))
|
||||
firmware = Column(String(16))
|
||||
model = Column(String(36))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Hardware( mac={self.mac})>"
|
||||
|
||||
|
||||
#用户角色表定义
|
||||
class UserCharacter(Base):
|
||||
__tablename__ = 'user_character'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('user.id'))
|
||||
character_id = Column(Integer, ForeignKey('character.id'))
|
||||
persona = Column(JSON)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserCharacter(id={self.id}, user_id={self.user_id}, character_id={self.character_id})>"
|
||||
|
||||
|
||||
#Session表定义
|
||||
class Session(Base):
|
||||
__tablename__ = 'session'
|
||||
|
||||
id = Column(CHAR(36), primary_key=True)
|
||||
user_character_id = Column(Integer, ForeignKey('user_character.id'))
|
||||
content = Column(Text)
|
||||
last_activity = Column(DateTime())
|
||||
is_permanent = Column(Boolean)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"
|
|
@ -0,0 +1 @@
|
|||
from . import *
|
|
@ -0,0 +1,35 @@
|
|||
from fastapi import APIRouter, HTTPException, status
|
||||
from ..controllers.character import *
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from ..dependencies.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
#角色创建接口
|
||||
@router.post("/characters",response_model=CharacterCreateResponse)
|
||||
async def create_character(character:CharacterCreateRequest, db: Session = Depends(get_db)):
|
||||
response = await create_character_handler(character, db)
|
||||
return response
|
||||
|
||||
|
||||
#用户更新接口
|
||||
@router.put("/characters/{character_id}",response_model=CharacterUpdateResponse)
|
||||
async def update_character(character_id:int,character:CharacterUpdateRequest, db: Session = Depends(get_db)):
|
||||
response = await update_character_handler(character_id, character, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色删除接口
|
||||
@router.delete("/characters/{character_id}", response_model=CharacterDeleteResponse)
|
||||
async def delete_character(character_id: int, db: Session = Depends(get_db)):
|
||||
response = await delete_character_handler(character_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色查询接口
|
||||
@router.get("/characters/{character_id}",response_model=CharacterQueryResponse)
|
||||
async def get_character(character_id: int, db: Session = Depends(get_db)):
|
||||
response = await get_character_handler(character_id, db)
|
||||
return response
|
|
@ -0,0 +1,48 @@
|
|||
from fastapi import APIRouter, HTTPException, status, WebSocket
|
||||
from ..controllers.chat import *
|
||||
from fastapi import Depends
|
||||
from ..dependencies.database import get_db
|
||||
from ..dependencies.redis import get_redis
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
#创建新聊天接口
|
||||
@router.post("/chats", response_model=ChatCreateResponse)
|
||||
async def create_chat(chat: ChatCreateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await create_chat_handler(chat, db, redis)
|
||||
return response
|
||||
|
||||
#删除聊天接口
|
||||
@router.delete("/chats/{user_character_id}", response_model=ChatDeleteResponse)
|
||||
async def delete_chat(user_character_id: int, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await delete_chat_handler(user_character_id, db, redis)
|
||||
return response
|
||||
|
||||
#非流式聊天
|
||||
@router.post("/chats/non-streaming", response_model=ChatNonStreamResponse)
|
||||
async def non_streaming_chat(chat: ChatNonStreamRequest,db=Depends(get_db), redis=Depends(get_redis)):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,detail="this api is not available")
|
||||
response = await non_streaming_chat_handler(chat, db, redis)
|
||||
return response
|
||||
|
||||
|
||||
#流式聊天_单次
|
||||
@router.websocket("/chat/streaming/temporary")
|
||||
async def streaming_chat(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
|
||||
await ws.accept()
|
||||
await streaming_chat_temporary_handler(ws,db,redis)
|
||||
|
||||
|
||||
#流式聊天_持续
|
||||
@router.websocket("/chat/streaming/lasting")
|
||||
async def streaming_chat(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
|
||||
await ws.accept()
|
||||
await streaming_chat_lasting_handler(ws,db,redis)
|
||||
|
||||
|
||||
#语音通话
|
||||
@router.websocket("/chat/voice_call")
|
||||
async def voice_chat(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
|
||||
await ws.accept()
|
||||
await voice_call_handler(ws,db,redis)
|
|
@ -0,0 +1,29 @@
|
|||
from fastapi import APIRouter, Query, HTTPException, status
|
||||
from ..controllers.session import *
|
||||
from fastapi import Depends
|
||||
from ..dependencies.database import get_db
|
||||
from ..dependencies.redis import get_redis
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
#session_id查询接口
|
||||
@router.get("/sessions", response_model=SessionIdQueryResponse)
|
||||
async def get_session_id(user_id: int=Query(..., description="用户id"), character_id: int=Query(..., description="角色id"),db=Depends(get_db)):
|
||||
response = await get_session_id_handler(user_id, character_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#session查询接口
|
||||
@router.get("/sessions/{session_id}", response_model=SessionQueryResponse)
|
||||
async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await get_session_handler(session_id, db, redis)
|
||||
return response
|
||||
|
||||
|
||||
#session更新接口
|
||||
@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
|
||||
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
|
||||
response = await update_session_handler(session_id, session_data, db, redis)
|
||||
return response
|
|
@ -0,0 +1,71 @@
|
|||
from fastapi import APIRouter, HTTPException, status
|
||||
from ..controllers.user import *
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from ..dependencies.database import get_db
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
#角色创建接口
|
||||
@router.post('/users', response_model=UserCrateResponse)
|
||||
async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
|
||||
response = await create_user_handler(user,db)
|
||||
return response
|
||||
|
||||
|
||||
#角色更新接口
|
||||
@router.put('/users/{user_id}', response_model=UserUpdateResponse)
|
||||
async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
|
||||
response = await update_user_handler(user_id, user, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色删除接口
|
||||
@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
|
||||
async def delete_user(user_id: int, db: Session = Depends(get_db)):
|
||||
response = await delete_user_handler(user_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#角色查询接口
|
||||
@router.get('/users/{user_id}', response_model=UserQueryResponse)
|
||||
async def get_user(user_id: int, db: Session = Depends(get_db)):
|
||||
response = await get_user_handler(user_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#硬件绑定接口
|
||||
@router.post('/users/hardware',response_model=HardwareBindResponse)
|
||||
async def bind_hardware(hardware: HardwareBindRequest, db: Session = Depends(get_db)):
|
||||
response = await bind_hardware_handler(hardware, db)
|
||||
return response
|
||||
|
||||
|
||||
#硬件解绑接口
|
||||
@router.delete('/users/hardware/{hardware_id}',response_model=HardwareUnbindResponse)
|
||||
async def unbind_hardware(hardware_id: int, db: Session = Depends(get_db)):
|
||||
response = await unbind_hardware_handler(hardware_id, db)
|
||||
return response
|
||||
|
||||
|
||||
#硬件换绑
|
||||
@router.put('/users/hardware/{hardware_id}/bindchange',response_model=HardwareChangeBindResponse)
|
||||
async def change_bind_hardware(hardware_id: int, user: HardwareChangeBindRequest, db: Session = Depends(get_db)):
|
||||
response = await change_bind_hardware_handler(hardware_id, user, db)
|
||||
return response
|
||||
|
||||
|
||||
#硬件信息更新
|
||||
@router.put('/users/hardware/{hardware_id}/info',response_model=HardwareUpdateResponse)
|
||||
async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest, db: Session = Depends(get_db)):
|
||||
response = await update_hardware_handler(hardware_id, hardware, db)
|
||||
return response
|
||||
|
||||
|
||||
#硬件查询
|
||||
@router.get('/users/hardware/{hardware_id}',response_model=HardwareQueryResponse)
|
||||
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
|
||||
response = await get_hardware_handler(hardware_id, db)
|
||||
return response
|
|
@ -0,0 +1 @@
|
|||
from . import *
|
|
@ -0,0 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
data: dict
|
|
@ -0,0 +1,79 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from .base import BaseResponse
|
||||
|
||||
#---------------------------角色创建-----------------------------
|
||||
#角色创建请求类
|
||||
class CharacterCreateRequest(BaseModel):
|
||||
voice_id:int
|
||||
avatar_id:str
|
||||
background_ids:str
|
||||
name:str
|
||||
wakeup_words:str
|
||||
world_scenario:str
|
||||
description:str
|
||||
emojis:str
|
||||
dialogues:str
|
||||
|
||||
#角色创建返回类
|
||||
class CharacterCreateData(BaseModel):
|
||||
character_id:int
|
||||
createdAt:str
|
||||
|
||||
#角色创建响应类
|
||||
class CharacterCreateResponse(BaseResponse):
|
||||
data: Optional[CharacterCreateData]
|
||||
#----------------------------------------------------------------
|
||||
|
||||
|
||||
#---------------------------角色更新------------------------------
|
||||
#角色更新请求类
|
||||
class CharacterUpdateRequest(BaseModel):
|
||||
voice_id:int
|
||||
avatar_id:str
|
||||
background_ids:str
|
||||
name:str
|
||||
wakeup_words:str
|
||||
world_scenario:str
|
||||
description:str
|
||||
emojis:str
|
||||
dialogues:str
|
||||
|
||||
#角色更新返回类
|
||||
class CharacterUpdateData(BaseModel):
|
||||
updatedAt:str
|
||||
|
||||
#角色更新响应类
|
||||
class CharacterUpdateResponse(BaseResponse):
|
||||
data: Optional[CharacterUpdateData]
|
||||
#------------------------------------------------------------------
|
||||
|
||||
|
||||
#---------------------------角色查询--------------------------------
|
||||
#角色查询返回类
|
||||
class CharacterQueryData(BaseModel):
|
||||
voice_id:int
|
||||
avatar_id:str
|
||||
background_ids:str
|
||||
name:str
|
||||
wakeup_words:str
|
||||
world_scenario:str
|
||||
description:str
|
||||
emojis:str
|
||||
dialogues:str
|
||||
|
||||
#角色查询响应类
|
||||
class CharacterQueryResponse(BaseResponse):
|
||||
data: Optional[CharacterQueryData]
|
||||
#------------------------------------------------------------------
|
||||
|
||||
|
||||
#---------------------------角色删除--------------------------------
|
||||
#角色删除返回类
|
||||
class CharacterDeleteData(BaseModel):
|
||||
deletedAt:str
|
||||
|
||||
#角色删除响应类
|
||||
class CharacterDeleteResponse(BaseResponse):
|
||||
data: Optional[CharacterDeleteData]
|
||||
#-------------------------------------------------------------------
|
|
@ -0,0 +1,49 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from .base import BaseResponse
|
||||
|
||||
#--------------------------------新聊天创建--------------------------------
|
||||
#创建新聊天请求类
|
||||
class ChatCreateRequest(BaseModel):
|
||||
user_id: int
|
||||
character_id: int
|
||||
|
||||
#创建新聊天返回类
|
||||
class ChatCreateData(BaseModel):
|
||||
user_character_id: int
|
||||
session_id: str
|
||||
createdAt: str
|
||||
|
||||
#创建新聊天相应类
|
||||
class ChatCreateResponse(BaseResponse):
|
||||
data: Optional[ChatCreateData]
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
|
||||
#----------------------------------聊天删除--------------------------------
|
||||
#删除聊天返回类
|
||||
class ChatDeleteData(BaseModel):
|
||||
deletedAt: str
|
||||
|
||||
#创建新聊天相应类
|
||||
class ChatDeleteResponse(BaseResponse):
|
||||
data: Optional[ChatDeleteData]
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-----------------------------------非流式聊天------------------------------
|
||||
#非流式聊天请求类
|
||||
class ChatNonStreamRequest(BaseModel):
|
||||
format: str
|
||||
rate: int
|
||||
session_id: str
|
||||
speech:str
|
||||
|
||||
#非流式聊天返回类
|
||||
class ChatNonStreamData(BaseModel):
|
||||
audio_url: str
|
||||
|
||||
#非流式聊天相应类
|
||||
class ChatNonStreamResponse(BaseResponse):
|
||||
data: Optional[ChatNonStreamData]
|
||||
#--------------------------------------------------------------------------
|
|
@ -0,0 +1,58 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from .base import BaseResponse
|
||||
|
||||
#----------------------------Session_id查询------------------------------
|
||||
#session_id查询请求类
|
||||
class SessionIdQueryRequest(BaseModel):
|
||||
user_id: int
|
||||
character_id: int
|
||||
|
||||
#session_id查询返回类
|
||||
class SessionIdQueryData(BaseModel):
|
||||
session_id: str
|
||||
|
||||
#session_id查询响应类
|
||||
class SessionIdQueryResponse(BaseResponse):
|
||||
data: Optional[SessionIdQueryData]
|
||||
#-------------------------------------------------------------------------
|
||||
|
||||
|
||||
#----------------------------Session会话查询-------------------------------
|
||||
#session会话查询请求类
|
||||
class SessionQueryRequest(BaseModel):
|
||||
user_id: int
|
||||
|
||||
class SessionQueryData(BaseModel):
|
||||
user_id: int
|
||||
messages: str
|
||||
user_info: str
|
||||
tts_info: str
|
||||
llm_info: str
|
||||
token: int
|
||||
|
||||
#session会话查询响应类
|
||||
class SessionQueryResponse(BaseResponse):
|
||||
data: Optional[SessionQueryData]
|
||||
#-------------------------------------------------------------------------
|
||||
|
||||
|
||||
#-------------------------------Session修改--------------------------------
|
||||
#session修改请求类
|
||||
class SessionUpdateRequest(BaseModel):
|
||||
user_id: int
|
||||
messages: str
|
||||
user_info: str
|
||||
tts_info: str
|
||||
llm_info: str
|
||||
token: int
|
||||
|
||||
#session修改返回类
|
||||
class SessionUpdateData(BaseModel):
|
||||
updatedAt:str
|
||||
|
||||
#session修改响应类
|
||||
class SessionUpdateResponse(BaseResponse):
|
||||
data: Optional[SessionUpdateData]
|
||||
#--------------------------------------------------------------------------
|
|
@ -0,0 +1,140 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from .base import BaseResponse
|
||||
|
||||
|
||||
|
||||
#---------------------------------用户创建----------------------------------
|
||||
#用户创建请求类
|
||||
class UserCrateRequest(BaseModel):
|
||||
open_id: str
|
||||
username: str
|
||||
avatar_id: str
|
||||
tags: str
|
||||
persona: str
|
||||
|
||||
#用户创建返回类
|
||||
class UserCrateData(BaseModel):
|
||||
user_id : int
|
||||
createdAt: str
|
||||
|
||||
#用户创建响应类
|
||||
class UserCrateResponse(BaseResponse):
|
||||
data: Optional[UserCrateData]
|
||||
#---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------用户更新-----------------------------------
|
||||
#用户更新请求类
|
||||
class UserUpdateRequest(BaseModel):
|
||||
open_id: str
|
||||
username: str
|
||||
avatar_id: str
|
||||
tags: str
|
||||
persona: str
|
||||
|
||||
#用户更新返回类
|
||||
class UserUpdateData(BaseModel):
|
||||
updatedAt: str
|
||||
|
||||
#用户更新响应类
|
||||
class UserUpdateResponse(BaseResponse):
|
||||
data: Optional[UserUpdateData]
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------用户查询------------------------------------
|
||||
#用户查询返回类
|
||||
class UserQueryData(BaseModel):
|
||||
open_id: str
|
||||
username: str
|
||||
avatar_id: str
|
||||
tags: str
|
||||
persona: str
|
||||
|
||||
#用户查询响应类
|
||||
class UserQueryResponse(BaseResponse):
|
||||
data: Optional[UserQueryData]
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------用户删除------------------------------------
|
||||
#用户删除返回类
|
||||
class UserDeleteData(BaseModel):
|
||||
deletedAt: str
|
||||
|
||||
#用户删除响应类
|
||||
class UserDeleteResponse(BaseResponse):
|
||||
data: Optional[UserDeleteData]
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------硬件绑定-------------------------------------
|
||||
class HardwareBindRequest(BaseModel):
|
||||
mac:str
|
||||
user_id:int
|
||||
firmware:str
|
||||
model:str
|
||||
|
||||
class HardwareBindData(BaseModel):
|
||||
hardware_id: int
|
||||
bindedAt: str
|
||||
|
||||
class HardwareBindResponse(BaseResponse):
|
||||
data: Optional[HardwareBindData]
|
||||
#-----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------硬件解绑-------------------------------------
|
||||
class HardwareUnbindData(BaseModel):
|
||||
unbindedAt: str
|
||||
|
||||
class HardwareUnbindResponse(BaseResponse):
|
||||
data: Optional[HardwareUnbindData]
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#---------------------------------硬件换绑--------------------------------------
|
||||
class HardwareChangeBindRequest(BaseModel):
|
||||
user_id:int
|
||||
|
||||
class HardwareChangeBindData(BaseModel):
|
||||
bindChangedAt: str
|
||||
|
||||
class HardwareChangeBindResponse(BaseResponse):
|
||||
data: Optional[HardwareChangeBindData]
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#-------------------------------硬件信息更新------------------------------------
|
||||
class HardwareUpdateRequest(BaseModel):
|
||||
mac:str
|
||||
firmware:str
|
||||
model:str
|
||||
|
||||
class HardwareUpdateData(BaseModel):
|
||||
updatedAt: str
|
||||
|
||||
class HardwareUpdateResponse(BaseResponse):
|
||||
data: Optional[HardwareUpdateData]
|
||||
#------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
#--------------------------------硬件查询---------------------------------------
|
||||
class HardwareQueryData(BaseModel):
|
||||
user_id:int
|
||||
mac:str
|
||||
firmware:str
|
||||
model:str
|
||||
|
||||
class HardwareQueryResponse(BaseResponse):
|
||||
data: Optional[HardwareQueryData]
|
||||
#------------------------------------------------------------------------------
|
|
@ -0,0 +1,13 @@
|
|||
import os
|
||||
from .development import DevelopmentConfig
|
||||
from .production import ProductionConfig
|
||||
|
||||
def get_config():
|
||||
mode = os.getenv('MODE','development').lower()
|
||||
if mode == 'development':
|
||||
return DevelopmentConfig
|
||||
elif mode == 'production':
|
||||
return ProductionConfig
|
||||
else:
|
||||
raise ValueError('Invalid MODE environment variable')
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
class DevelopmentConfig:
|
||||
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
|
||||
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
|
||||
LOG_LEVEL = "DEBUG" #日志级别
|
||||
class XF_ASR:
|
||||
APP_ID = "your_app_id" #讯飞语音识别APP_ID
|
||||
API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET
|
||||
API_KEY = "your_api_key" #讯飞语音识别API_KEY
|
||||
DOMAIN = "iat"
|
||||
LANGUAGE = "zh_cn"
|
||||
ACCENT = "mandarin"
|
||||
VAD_EOS = 10000
|
||||
class MINIMAX_LLM:
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
|
||||
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
class MINIMAX_TTA:
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA",
|
||||
URL = "https://api.minimax.chat/v1/t2a_pro",
|
||||
GROUP_ID ="1759482180095975904"
|
||||
class STRAM_CHAT:
|
||||
ASR = "LOCAL"
|
||||
TTS = "LOCAL"
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
class ProductionConfig:
|
||||
SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://root:takway@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置
|
||||
REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置
|
||||
LOG_LEVEL = "INFO" #日志级别
|
||||
class XF_ASR:
|
||||
APP_ID = "your_app_id" #讯飞语音识别APP_ID
|
||||
API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET
|
||||
API_KEY = "your_api_key" #讯飞语音识别API_KEY
|
|
@ -0,0 +1,9 @@
|
|||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
|
||||
|
||||
# 使用exec函数执行脚本
|
||||
with open(script_path, 'r') as file:
|
||||
exec(file.read())
|
|
@ -0,0 +1,5 @@
|
|||
uvicorn~=0.29.0
|
||||
fastapi~=0.110.1
|
||||
sqlalchemy~=2.0.25
|
||||
pydantic~=2.6.4
|
||||
redis~=5.0.3
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,31 @@
|
|||
from tests.unit_test.user_test import UserServiceTest
|
||||
from tests.unit_test.character_test import CharacterServiceTest
|
||||
from tests.unit_test.chat_test import ChatServiceTest
|
||||
import asyncio
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_service_test = UserServiceTest()
|
||||
character_service_test = CharacterServiceTest()
|
||||
chat_service_test = ChatServiceTest()
|
||||
|
||||
user_service_test.test_user_create()
|
||||
user_service_test.test_user_update()
|
||||
user_service_test.test_user_query()
|
||||
user_service_test.test_hardware_bind()
|
||||
user_service_test.test_hardware_unbind()
|
||||
user_service_test.test_user_delete()
|
||||
|
||||
character_service_test.test_character_create()
|
||||
character_service_test.test_character_update()
|
||||
character_service_test.test_character_query()
|
||||
character_service_test.test_character_delete()
|
||||
|
||||
chat_service_test.test_create_chat()
|
||||
chat_service_test.test_session_id_query()
|
||||
chat_service_test.test_session_content_query()
|
||||
chat_service_test.test_session_update()
|
||||
asyncio.run(chat_service_test.test_chat_temporary())
|
||||
asyncio.run(chat_service_test.test_chat_lasting())
|
||||
asyncio.run(chat_service_test.test_voice_call())
|
||||
chat_service_test.test_chat_delete()
|
||||
print("全部测试成功")
|
|
@ -0,0 +1,74 @@
|
|||
import requests
|
||||
import json
|
||||
|
||||
class CharacterServiceTest:
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
def test_character_create(self):
|
||||
url = f"{self.socket}/characters"
|
||||
payload = json.dumps({
|
||||
"voice_id": 97,
|
||||
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
|
||||
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
|
||||
"name": "测试角色",
|
||||
"wakeup_words": "你好啊,海绵宝宝",
|
||||
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
|
||||
"description": "厨师,做汉堡",
|
||||
"emojis": "大笑,微笑",
|
||||
"dialogues": "我准备好了"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("角色创建测试成功")
|
||||
self.id = response.json()['data']['character_id']
|
||||
else:
|
||||
raise Exception("角色创建测试失败")
|
||||
|
||||
def test_character_update(self):
|
||||
url = f"{self.socket}/characters/"+str(self.id)
|
||||
payload = json.dumps({
|
||||
"voice_id": 97,
|
||||
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
|
||||
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
|
||||
"name": "测试角色",
|
||||
"wakeup_words": "你好啊,海绵宝宝",
|
||||
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
|
||||
"description": "厨师,做汉堡",
|
||||
"emojis": "大笑,微笑",
|
||||
"dialogues": "我准备好了"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("角色更新测试成功")
|
||||
else:
|
||||
raise Exception("角色更新测试失败")
|
||||
|
||||
def test_character_query(self):
|
||||
url = f"{self.socket}/characters/{self.id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("角色查询测试成功")
|
||||
else:
|
||||
raise Exception("角色查询测试失败")
|
||||
|
||||
def test_character_delete(self):
|
||||
url = f"{self.socket}/characters/{self.id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code == 200:
|
||||
print("角色删除测试成功")
|
||||
else:
|
||||
raise Exception("角色删除测试失败")
|
||||
|
||||
if __name__ == '__main__':
|
||||
character_service_test = CharacterServiceTest()
|
||||
character_service_test.test_character_create()
|
||||
character_service_test.test_character_update()
|
||||
character_service_test.test_character_query()
|
||||
character_service_test.test_character_delete()
|
|
@ -0,0 +1,329 @@
|
|||
import requests
|
||||
import base64
|
||||
import wave
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import websockets
|
||||
|
||||
|
||||
|
||||
class ChatServiceTest:
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
|
||||
def test_create_chat(self):
|
||||
#创建一个用户
|
||||
url = f"{self.socket}/users"
|
||||
open_id = str(uuid.uuid4())
|
||||
payload = json.dumps({
|
||||
"open_id": open_id,
|
||||
"username": "test_user",
|
||||
"avatar_id": "0",
|
||||
"tags" : "[]",
|
||||
"persona" : "{}"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
self.user_id = response.json()['data']['user_id']
|
||||
else:
|
||||
raise Exception("创建聊天时,用户创建失败")
|
||||
|
||||
#创建一个角色
|
||||
url = f"{self.socket}/characters"
|
||||
payload = json.dumps({
|
||||
"voice_id": 97,
|
||||
"avatar_id": "49c838c5ffb211ee9de9a036bc278b4c",
|
||||
"background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c",
|
||||
"name": "test",
|
||||
"wakeup_words": "你好啊,海绵宝宝",
|
||||
"world_scenario": "海绵宝宝住在深海的大菠萝里面",
|
||||
"description": "厨师,做汉堡",
|
||||
"emojis": "大笑,微笑",
|
||||
"dialogues": "我准备好了"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("角色创建成功")
|
||||
self.character_id = response.json()['data']['character_id']
|
||||
else:
|
||||
raise Exception("创建聊天时,角色创建失败")
|
||||
|
||||
#创建一个对话
|
||||
url = f"{self.socket}/chats"
|
||||
payload = json.dumps({
|
||||
"user_id": self.user_id,
|
||||
"character_id": self.character_id
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("对话创建成功")
|
||||
self.session_id = response.json()['data']['session_id']
|
||||
self.user_character_id = response.json()['data']['user_character_id']
|
||||
else:
|
||||
raise Exception("对话创建测试失败")
|
||||
|
||||
|
||||
#测试查询session_id
|
||||
def test_session_id_query(self):
|
||||
url = f"{self.socket}/sessions?user_id={self.user_id}&character_id={self.character_id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("session_id查询测试成功")
|
||||
else:
|
||||
raise Exception("session_id查询测试失败")
|
||||
|
||||
|
||||
#测试查询session内容
|
||||
def test_session_content_query(self):
|
||||
url = f"{self.socket}/sessions/{self.session_id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("session内容查询测试成功")
|
||||
else:
|
||||
raise Exception("session内容查询测试失败")
|
||||
|
||||
|
||||
#测试修改session
|
||||
def test_session_update(self):
|
||||
url = f"{self.socket}/sessions/{self.session_id}"
|
||||
payload = json.dumps({
|
||||
"user_id": self.user_id,
|
||||
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
||||
"user_info": "{}",
|
||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
|
||||
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
||||
"token": 0}
|
||||
)
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("Session更新测试成功")
|
||||
else:
|
||||
raise Exception("Session更新测试失败")
|
||||
|
||||
|
||||
#测试单次聊天
|
||||
async def test_chat_temporary(self):
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav')
|
||||
def read_wav_file_in_chunks(chunk_size):
|
||||
with open(wav_file_path, 'rb') as pcm_file:
|
||||
while True:
|
||||
data = pcm_file.read(chunk_size)
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
data = {
|
||||
"text": "",
|
||||
"audio": "",
|
||||
"meta_info": {
|
||||
"session_id":self.session_id,
|
||||
"stream": True,
|
||||
"voice_synthesize": True,
|
||||
"is_end": False,
|
||||
"encoding": "raw"
|
||||
}
|
||||
}
|
||||
|
||||
#发送音频数据
|
||||
async def send_audio_chunk(websocket, chunk):
|
||||
encoded_data = base64.b64encode(chunk).decode('utf-8')
|
||||
data["audio"] = encoded_data
|
||||
message = json.dumps(data)
|
||||
await websocket.send(message)
|
||||
|
||||
|
||||
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
|
||||
chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块
|
||||
for chunk in chunks:
|
||||
await send_audio_chunk(websocket, chunk)
|
||||
await asyncio.sleep(0.01)
|
||||
# 设置data字典中的"is_end"键为True,表示音频流结束
|
||||
data["meta_info"]["is_end"] = True
|
||||
# 发送最后一个数据块和流结束信号
|
||||
await send_audio_chunk(websocket, b'') # 发送空数据块表示结束
|
||||
|
||||
audio_bytes = b''
|
||||
while True:
|
||||
data_ws = await websocket.recv()
|
||||
try:
|
||||
message_json = json.loads(data_ws)
|
||||
if message_json["type"] == "close":
|
||||
print("单次聊天测试成功")
|
||||
break # 如果没有接收到消息,则退出循环
|
||||
except Exception as e:
|
||||
audio_bytes += data_ws
|
||||
|
||||
await asyncio.sleep(0.04) # 等待0.04秒后断开连接
|
||||
await websocket.close()
|
||||
|
||||
|
||||
#测试持续聊天
|
||||
async def test_chat_lasting(self):
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav')
|
||||
def read_wav_file_in_chunks(chunk_size):
|
||||
with open(wav_file_path, 'rb') as pcm_file:
|
||||
while True:
|
||||
data = pcm_file.read(chunk_size)
|
||||
if not data:
|
||||
break
|
||||
yield data
|
||||
data = {
|
||||
"text": "",
|
||||
"audio": "",
|
||||
"meta_info": {
|
||||
"session_id":self.session_id,
|
||||
"stream": True,
|
||||
"voice_synthesize": True,
|
||||
"is_end": False,
|
||||
"encoding": "raw"
|
||||
},
|
||||
"is_close":False
|
||||
}
|
||||
async def send_audio_chunk(websocket, chunk):
|
||||
encoded_data = base64.b64encode(chunk).decode('utf-8')
|
||||
data["audio"] = encoded_data
|
||||
message = json.dumps(data)
|
||||
await websocket.send(message)
|
||||
|
||||
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
|
||||
#发送第一次
|
||||
chunks = read_wav_file_in_chunks(2048)
|
||||
for chunk in chunks:
|
||||
await send_audio_chunk(websocket, chunk)
|
||||
await asyncio.sleep(0.01)
|
||||
# 设置data字典中的"is_end"键为True,表示音频流结束
|
||||
data["meta_info"]["is_end"] = True
|
||||
# 发送最后一个数据块和流结束信号
|
||||
await send_audio_chunk(websocket, b'') # 发送空数据块表示结束
|
||||
|
||||
await asyncio.sleep(3) #模拟发送间隔
|
||||
|
||||
#发送第二次
|
||||
data["meta_info"]["is_end"] = False
|
||||
chunks = read_wav_file_in_chunks(2048)
|
||||
for chunk in chunks:
|
||||
await send_audio_chunk(websocket, chunk)
|
||||
await asyncio.sleep(0.01)
|
||||
# 设置data字典中的"is_end"键为True,表示音频流结束
|
||||
data["meta_info"]["is_end"] = True
|
||||
# 发送最后一个数据块和流结束信号
|
||||
await send_audio_chunk(websocket, b'') # 发送空数据块表示结束
|
||||
|
||||
data["is_close"] = True
|
||||
await send_audio_chunk(websocket, b'') # 发送空数据块表示结束
|
||||
|
||||
|
||||
audio_bytes = b''
|
||||
while True:
|
||||
data_ws = await websocket.recv()
|
||||
try:
|
||||
message_json = json.loads(data_ws)
|
||||
if message_json["type"] == "close":
|
||||
print("持续聊天测试成功")
|
||||
break # 如果没有接收到消息,则退出循环
|
||||
except Exception as e:
|
||||
audio_bytes += data_ws
|
||||
|
||||
await asyncio.sleep(0.5) # 等待0.04秒后断开连接
|
||||
await websocket.close()
|
||||
|
||||
|
||||
#语音电话测试
|
||||
async def test_voice_call(self):
|
||||
chunk_size = 480
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
|
||||
url = f"ws://114.214.236.207:7878/chat/voice_call"
|
||||
#发送格式
|
||||
ws_data = {
|
||||
"audio" : "",
|
||||
"meta_info":{
|
||||
"session_id":self.session_id,
|
||||
"encoding": 'raw'
|
||||
},
|
||||
"is_close" : False
|
||||
}
|
||||
|
||||
async def audio_stream(websocket):
|
||||
with wave.open(file_path, 'rb') as wf:
|
||||
frames_per_buffer = int(chunk_size / 2)
|
||||
data = wf.readframes(frames_per_buffer)
|
||||
while True:
|
||||
if len(data) != 960:
|
||||
break
|
||||
encoded_data = base64.b64encode(data).decode('utf-8')
|
||||
ws_data['audio'] = encoded_data
|
||||
await websocket.send(json.dumps(ws_data))
|
||||
data = wf.readframes(frames_per_buffer)
|
||||
await asyncio.sleep(3)
|
||||
ws_data['audio'] = ""
|
||||
ws_data['is_close'] = True
|
||||
await websocket.send(json.dumps(ws_data))
|
||||
while True:
|
||||
data_ws = await websocket.recv()
|
||||
if data_ws:
|
||||
print("语音电话测试成功")
|
||||
break
|
||||
await asyncio.sleep(3)
|
||||
await websocket.close()
|
||||
|
||||
async with websockets.connect(url) as websocket:
|
||||
await asyncio.gather(audio_stream(websocket))
|
||||
|
||||
|
||||
|
||||
#测试删除聊天
|
||||
def test_chat_delete(self):
|
||||
url = f"{self.socket}/chats/{self.user_character_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code == 200:
|
||||
print("聊天删除测试成功")
|
||||
else:
|
||||
raise Exception("聊天删除测试失败")
|
||||
|
||||
url = f"{self.socket}/users/{self.user_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
raise Exception("用户删除测试失败")
|
||||
|
||||
url = f"{self.socket}/characters/{self.character_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
raise Exception("角色删除测试失败")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
chat_service_test = ChatServiceTest()
|
||||
chat_service_test.test_create_chat()
|
||||
chat_service_test.test_session_id_query()
|
||||
chat_service_test.test_session_content_query()
|
||||
chat_service_test.test_session_update()
|
||||
asyncio.run(chat_service_test.test_chat_temporary())
|
||||
asyncio.run(chat_service_test.test_chat_lasting())
|
||||
asyncio.run(chat_service_test.test_voice_call())
|
||||
chat_service_test.test_chat_delete()
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import requests
|
||||
import json
|
||||
import uuid
|
||||
|
||||
|
||||
class UserServiceTest:
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
def test_user_create(self):
|
||||
url = f"{self.socket}/users"
|
||||
open_id = str(uuid.uuid4())
|
||||
payload = json.dumps({
|
||||
"open_id": open_id,
|
||||
"username": "test_user",
|
||||
"avatar_id": "0",
|
||||
"tags" : "[]",
|
||||
"persona" : "{}"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("用户创建测试成功")
|
||||
self.id = response.json()["data"]["user_id"]
|
||||
else:
|
||||
raise Exception("用户创建测试失败")
|
||||
|
||||
def test_user_update(self):
|
||||
url = f"{self.socket}/users/"+str(self.id)
|
||||
payload = json.dumps({
|
||||
"open_id": str(uuid.uuid4()),
|
||||
"username": "test_user",
|
||||
"avatar_id": "0",
|
||||
"tags": "[]",
|
||||
"persona": "{}"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("用户更新测试成功")
|
||||
else:
|
||||
raise Exception("用户更新测试失败")
|
||||
|
||||
def test_user_query(self):
|
||||
url = f"{self.socket}/users/{self.id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("用户查询测试成功")
|
||||
else:
|
||||
raise Exception("用户查询测试失败")
|
||||
|
||||
def test_user_delete(self):
|
||||
url = f"{self.socket}/users/{self.id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code == 200:
|
||||
print("用户删除测试成功")
|
||||
else:
|
||||
raise Exception("用户删除测试失败")
|
||||
|
||||
def test_hardware_bind(self):
|
||||
url = f"{self.socket}/users/hardware"
|
||||
mac = "08:00:20:0A:8C:6G"
|
||||
payload = json.dumps({
|
||||
"mac":mac,
|
||||
"user_id":1,
|
||||
"firmware":"v1.0",
|
||||
"model":"香橙派"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("硬件绑定测试成功")
|
||||
self.hd_id = response.json()["data"]['hardware_id']
|
||||
else:
|
||||
raise Exception("硬件绑定测试失败")
|
||||
|
||||
def test_hardware_unbind(self):
|
||||
url = f"{self.socket}/users/hardware/{self.hd_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code == 200:
|
||||
print("硬件解绑测试成功")
|
||||
else:
|
||||
raise Exception("硬件解绑测试失败")
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_service_test = UserServiceTest()
|
||||
user_service_test.test_user_create()
|
||||
user_service_test.test_user_update()
|
||||
user_service_test.test_user_query()
|
||||
user_service_test.test_hardware_bind()
|
||||
user_service_test.test_hardware_unbind()
|
||||
user_service_test.test_user_delete()
|
||||
|
|
@ -0,0 +1 @@
|
|||
from . import *
|
|
@ -0,0 +1,14 @@
|
|||
import webrtcvad
|
||||
import base64
|
||||
|
||||
class VAD():
|
||||
def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000,**kwargs):
|
||||
self.RATE = RATE
|
||||
self.vad = webrtcvad.Vad(vad_sensitivity)
|
||||
self.vad_buffer_size = vad_buffer_size
|
||||
self.vad_chunk_size = int(self.RATE * frame_duration / 1000)
|
||||
self.min_act_time = min_act_time # 最小活动时间,单位秒
|
||||
|
||||
def is_speech(self,data):
|
||||
byte_data = base64.b64decode(data)
|
||||
return self.vad.is_speech(byte_data, self.RATE)
|
|
@ -0,0 +1,66 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import wave
|
||||
import io
|
||||
import base64
|
||||
|
||||
|
||||
def decode_str2bytes(data):
|
||||
# 将Base64编码的字节串解码为字节串
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64decode(data.encode('utf-8'))
|
||||
|
||||
class STTBase:
|
||||
def __init__(self, RATE=16000, cfg_path=None, debug=False):
|
||||
self.RATE = RATE
|
||||
self.debug = debug
|
||||
self.asr_cfg = self.parse_json(cfg_path)
|
||||
|
||||
def parse_json(self, cfg_path):
|
||||
cfg = None
|
||||
self.hotwords = None
|
||||
if cfg_path is not None:
|
||||
with open(cfg_path, 'r', encoding='utf-8') as f:
|
||||
cfg = json.load(f)
|
||||
self.hotwords = cfg.get('hot_words', None)
|
||||
return cfg
|
||||
|
||||
def add_hotword(self, hotword):
|
||||
"""add hotword to list"""
|
||||
if self.hotwords is None:
|
||||
self.hotwords = ""
|
||||
if isinstance(hotword, str):
|
||||
self.hotwords = self.hotwords + " " + "hotword"
|
||||
elif isinstance(hotword, (list, tuple)):
|
||||
# 将hotwords转换为str,并用空格隔开
|
||||
self.hotwords = self.hotwords + " " + " ".join(hotword)
|
||||
else:
|
||||
raise TypeError("hotword must be str or list")
|
||||
|
||||
def check_audio_type(self, audio_data):
|
||||
"""check audio data type and convert it to bytes if necessary."""
|
||||
if isinstance(audio_data, bytes):
|
||||
pass
|
||||
elif isinstance(audio_data, list):
|
||||
audio_data = b''.join(audio_data)
|
||||
elif isinstance(audio_data, str):
|
||||
audio_data = decode_str2bytes(audio_data)
|
||||
elif isinstance(audio_data, io.BytesIO):
|
||||
wf = wave.open(audio_data, 'rb')
|
||||
audio_data = wf.readframes(wf.getnframes())
|
||||
else:
|
||||
raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
|
||||
return audio_data
|
||||
|
||||
def text_postprecess(self, result, data_id='text'):
|
||||
"""postprecess recognized result."""
|
||||
text = result[data_id]
|
||||
if isinstance(text, list):
|
||||
text = ''.join(text)
|
||||
return text.replace(' ', '')
|
||||
|
||||
def recognize(self, audio_data, queue=None):
|
||||
"""recognize audio data to text"""
|
||||
pass
|
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# ####################################################### #
|
||||
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
|
||||
# ####################################################### #
|
||||
import io
|
||||
import numpy as np
|
||||
import base64
|
||||
import wave
|
||||
from funasr import AutoModel
|
||||
from .base_stt import STTBase
|
||||
|
||||
def decode_str2bytes(data):
|
||||
# 将Base64编码的字节串解码为字节串
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64decode(data.encode('utf-8'))
|
||||
|
||||
class FunAutoSpeechRecognizer(STTBase):
|
||||
def __init__(self,
|
||||
model_path="paraformer-zh-streaming",
|
||||
device="cuda",
|
||||
RATE=16000,
|
||||
cfg_path=None,
|
||||
debug=False,
|
||||
chunk_ms=480,
|
||||
encoder_chunk_look_back=4,
|
||||
decoder_chunk_look_back=1,
|
||||
**kwargs):
|
||||
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
|
||||
|
||||
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
|
||||
|
||||
self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
|
||||
self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
|
||||
|
||||
#[0, 8, 4] 480ms, [0, 10, 5] 600ms
|
||||
if chunk_ms == 480:
|
||||
self.chunk_size = [0, 8, 4]
|
||||
elif chunk_ms == 600:
|
||||
self.chunk_size = [0, 10, 5]
|
||||
else:
|
||||
raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
|
||||
self.chunk_partial_size = self.chunk_size[1] * 960
|
||||
self.audio_cache = None
|
||||
self.asr_cache = {}
|
||||
|
||||
|
||||
|
||||
self._init_asr()
|
||||
|
||||
def check_audio_type(self, audio_data):
|
||||
"""check audio data type and convert it to bytes if necessary."""
|
||||
if isinstance(audio_data, bytes):
|
||||
pass
|
||||
elif isinstance(audio_data, list):
|
||||
audio_data = b''.join(audio_data)
|
||||
elif isinstance(audio_data, str):
|
||||
audio_data = decode_str2bytes(audio_data)
|
||||
elif isinstance(audio_data, io.BytesIO):
|
||||
wf = wave.open(audio_data, 'rb')
|
||||
audio_data = wf.readframes(wf.getnframes())
|
||||
elif isinstance(audio_data, np.ndarray):
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"audio_data must be bytes, list, str, \
|
||||
io.BytesIO or numpy array, but got {type(audio_data)}")
|
||||
|
||||
if isinstance(audio_data, bytes):
|
||||
audio_data = np.frombuffer(audio_data, dtype=np.int16)
|
||||
elif isinstance(audio_data, np.ndarray):
|
||||
if audio_data.dtype != np.int16:
|
||||
audio_data = audio_data.astype(np.int16)
|
||||
else:
|
||||
raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
|
||||
return audio_data
|
||||
|
||||
def _init_asr(self):
|
||||
# 随机初始化一段音频数据
|
||||
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
|
||||
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
self.audio_cache = None
|
||||
self.asr_cache = {}
|
||||
# print("init ASR model done.")
|
||||
|
||||
def recognize(self, audio_data):
|
||||
"""recognize audio data to text"""
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
result = self.asr_model.generate(input=audio_data,
|
||||
batch_size_s=300,
|
||||
hotword=self.hotwords)
|
||||
|
||||
# print(result)
|
||||
text = ''
|
||||
for res in result:
|
||||
text += res['text']
|
||||
return text
|
||||
|
||||
def streaming_recognize(self,
|
||||
audio_data,
|
||||
is_end=False,
|
||||
auto_det_end=False):
|
||||
"""recognize partial result
|
||||
|
||||
Args:
|
||||
audio_data: bytes or numpy array, partial audio data
|
||||
is_end: bool, whether the audio data is the end of a sentence
|
||||
auto_det_end: bool, whether to automatically detect the end of a audio data
|
||||
"""
|
||||
text_dict = dict(text=[], is_end=is_end)
|
||||
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
if self.audio_cache is None:
|
||||
self.audio_cache = audio_data
|
||||
else:
|
||||
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
|
||||
if self.audio_cache.shape[0] > 0:
|
||||
self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
|
||||
|
||||
if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
|
||||
return text_dict
|
||||
|
||||
total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
|
||||
|
||||
if is_end:
|
||||
# if the audio data is the end of a sentence, \
|
||||
# we need to add one more chunk to the end to \
|
||||
# ensure the end of the sentence is recognized correctly.
|
||||
auto_det_end = True
|
||||
|
||||
if auto_det_end:
|
||||
total_chunk_num += 1
|
||||
|
||||
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
|
||||
end_idx = None
|
||||
for i in range(total_chunk_num):
|
||||
if auto_det_end:
|
||||
is_end = i == total_chunk_num - 1
|
||||
start_idx = i*self.chunk_partial_size
|
||||
if auto_det_end:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
|
||||
else:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
|
||||
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
|
||||
# t_stamp = time.time()
|
||||
|
||||
speech_chunk = self.audio_cache[start_idx:end_idx]
|
||||
|
||||
# TODO: exceptions processes
|
||||
try:
|
||||
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
except ValueError as e:
|
||||
print(f"ValueError: {e}")
|
||||
continue
|
||||
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
|
||||
# print(f"each chunk time: {time.time()-t_stamp}")
|
||||
|
||||
if is_end:
|
||||
self.audio_cache = None
|
||||
self.asr_cache = {}
|
||||
else:
|
||||
if end_idx:
|
||||
self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
|
||||
text_dict['is_end'] = is_end
|
||||
|
||||
# print(f"text_dict: {text_dict}")
|
||||
return text_dict
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .text import *
|
||||
from .monotonic_align import *
|
|
@ -0,0 +1,302 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
# import commons
|
||||
# from modules import LayerNorm
|
||||
from utils.tts.vits import commons
|
||||
from utils.tts.vits.modules import LayerNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.window_size = window_size
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.attn_layers = nn.ModuleList()
|
||||
self.norm_layers_1 = nn.ModuleList()
|
||||
self.ffn_layers = nn.ModuleList()
|
||||
self.norm_layers_2 = nn.ModuleList()
|
||||
for i in range(self.n_layers):
|
||||
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.proximal_bias = proximal_bias
|
||||
self.proximal_init = proximal_init
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.self_attn_layers = nn.ModuleList()
|
||||
self.norm_layers_0 = nn.ModuleList()
|
||||
self.encdec_attn_layers = nn.ModuleList()
|
||||
self.norm_layers_1 = nn.ModuleList()
|
||||
self.ffn_layers = nn.ModuleList()
|
||||
self.norm_layers_2 = nn.ModuleList()
|
||||
for i in range(self.n_layers):
|
||||
self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
|
||||
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
||||
self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask, h, h_mask):
|
||||
"""
|
||||
x: decoder input
|
||||
h: encoder output
|
||||
"""
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_0[i](x + y)
|
||||
|
||||
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
|
||||
super().__init__()
|
||||
assert channels % n_heads == 0
|
||||
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels
|
||||
self.n_heads = n_heads
|
||||
self.p_dropout = p_dropout
|
||||
self.window_size = window_size
|
||||
self.heads_share = heads_share
|
||||
self.block_length = block_length
|
||||
self.proximal_bias = proximal_bias
|
||||
self.proximal_init = proximal_init
|
||||
self.attn = None
|
||||
|
||||
self.k_channels = channels // n_heads
|
||||
self.conv_q = nn.Conv1d(channels, channels, 1)
|
||||
self.conv_k = nn.Conv1d(channels, channels, 1)
|
||||
self.conv_v = nn.Conv1d(channels, channels, 1)
|
||||
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
nn.init.xavier_uniform_(self.conv_v.weight)
|
||||
if proximal_init:
|
||||
with torch.no_grad():
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, t_t = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
||||
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
|
||||
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
||||
if self.window_size is not None:
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
|
||||
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
||||
scores = scores + scores_local
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
if self.block_length is not None:
|
||||
assert t_s == t_t, "Local attention is only available for self-attention."
|
||||
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
||||
scores = scores.masked_fill(block_mask == 0, -1e4)
|
||||
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
||||
p_attn = self.drop(p_attn)
|
||||
output = torch.matmul(p_attn, value)
|
||||
if self.window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
"""
|
||||
x: [b, h, l, m]
|
||||
y: [h or 1, m, d]
|
||||
ret: [b, h, l, d]
|
||||
"""
|
||||
ret = torch.matmul(x, y.unsqueeze(0))
|
||||
return ret
|
||||
|
||||
def _matmul_with_relative_keys(self, x, y):
|
||||
"""
|
||||
x: [b, h, l, d]
|
||||
y: [h or 1, m, d]
|
||||
ret: [b, h, l, m]
|
||||
"""
|
||||
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
||||
return ret
|
||||
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
max_relative_position = 2 * self.window_size + 1
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_length = max(length - (self.window_size + 1), 0)
|
||||
slice_start_position = max((self.window_size + 1) - length, 0)
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
if pad_length > 0:
|
||||
padded_relative_embeddings = F.pad(
|
||||
relative_embeddings,
|
||||
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
"""
|
||||
x: [b, h, l, 2*l-1]
|
||||
ret: [b, h, l, l]
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# Concat columns of pad to shift from relative to absolute indexing.
|
||||
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
"""
|
||||
x: [b, h, l, l]
|
||||
ret: [b, h, l, 2*l-1]
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
|
||||
return x_final
|
||||
|
||||
def _attention_bias_proximal(self, length):
|
||||
"""Bias for self-attention to encourage attention to close positions.
|
||||
Args:
|
||||
length: an integer scalar.
|
||||
Returns:
|
||||
a Tensor with shape [1, 1, length, length]
|
||||
"""
|
||||
r = torch.arange(length, dtype=torch.float32)
|
||||
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
||||
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.activation = activation
|
||||
self.causal = causal
|
||||
|
||||
if causal:
|
||||
self.padding = self._causal_padding
|
||||
else:
|
||||
self.padding = self._same_padding
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(self.padding(x * x_mask))
|
||||
if self.activation == "gelu":
|
||||
x = x * torch.sigmoid(1.702 * x)
|
||||
else:
|
||||
x = torch.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(self.padding(x * x_mask))
|
||||
return x * x_mask
|
||||
|
||||
def _causal_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
return x
|
||||
pad_l = self.kernel_size - 1
|
||||
pad_r = 0
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
return x
|
||||
|
||||
def _same_padding(self, x):
|
||||
if self.kernel_size == 1:
|
||||
return x
|
||||
pad_l = (self.kernel_size - 1) // 2
|
||||
pad_r = self.kernel_size // 2
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, commons.convert_pad_shape(padding))
|
||||
return x
|
|
@ -0,0 +1,172 @@
|
|||
import math
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torch.jit
|
||||
|
||||
|
||||
def script_method(fn, _rcb=None):
|
||||
return fn
|
||||
|
||||
|
||||
def script(obj, optimize=True, _frames_up=0, _rcb=None):
|
||||
return obj
|
||||
|
||||
|
||||
torch.jit.script_method = script_method
|
||||
torch.jit.script = script
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size*dilation - dilation)/2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
return result
|
||||
|
||||
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
|
||||
return kl
|
||||
|
||||
|
||||
def rand_gumbel(shape):
|
||||
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||
return -torch.log(-torch.log(uniform_samples))
|
||||
|
||||
|
||||
def rand_gumbel_like(x):
|
||||
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||
return g
|
||||
|
||||
|
||||
def slice_segments(x, ids_str, segment_size=4):
|
||||
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||
for i in range(x.size(0)):
|
||||
idx_str = ids_str[i]
|
||||
idx_end = idx_str + segment_size
|
||||
ret[i] = x[i, :, idx_str:idx_end]
|
||||
return ret
|
||||
|
||||
|
||||
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
b, d, t = x.size()
|
||||
if x_lengths is None:
|
||||
x_lengths = t
|
||||
ids_str_max = x_lengths - segment_size + 1
|
||||
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||
ret = slice_segments(x, ids_str, segment_size)
|
||||
return ret, ids_str
|
||||
|
||||
|
||||
def get_timing_signal_1d(
|
||||
length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = (
|
||||
math.log(float(max_timescale) / float(min_timescale)) /
|
||||
(num_timescales - 1))
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
|
||||
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||
signal = signal.view(1, channels, length)
|
||||
return signal
|
||||
|
||||
|
||||
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
||||
b, channels, length = x.size()
|
||||
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||
|
||||
|
||||
def subsequent_mask(length):
|
||||
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||
return mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2,3) * mask
|
||||
return path
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||
norm_type = float(norm_type)
|
||||
if clip_value is not None:
|
||||
clip_value = float(clip_value)
|
||||
|
||||
total_norm = 0
|
||||
for p in parameters:
|
||||
param_norm = p.grad.data.norm(norm_type)
|
||||
total_norm += param_norm.item() ** norm_type
|
||||
if clip_value is not None:
|
||||
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||
total_norm = total_norm ** (1. / norm_type)
|
||||
return total_norm
|
|
@ -0,0 +1,535 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
# import commons
|
||||
# import modules
|
||||
# import attentions
|
||||
# import monotonic_align
|
||||
from utils.tts.vits import commons, modules, attentions, monotonic_align
|
||||
from utils.tts.vits.commons import init_weights, get_padding
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
# from commons import init_weights, get_padding
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
|
||||
super().__init__()
|
||||
filter_channels = in_channels # it needs to be removed from future version.
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.log_flow = modules.Log()
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
|
||||
x = torch.detach(x)
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.convs(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
assert w is not None
|
||||
|
||||
logdet_tot_q = 0
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
logdet_tot_q += logdet_q
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
|
||||
logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
logdet_tot += logdet
|
||||
z = torch.cat([z0, z1], 1)
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
logw = z0
|
||||
return logw
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
x = torch.detach(x)
|
||||
if g is not None:
|
||||
g = torch.detach(g)
|
||||
x = x + self.cond(g)
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
n_vocab,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout):
|
||||
super().__init__()
|
||||
self.n_vocab = n_vocab
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.emb = nn.Embedding(n_vocab, hidden_channels)
|
||||
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths):
|
||||
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
|
||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
return x, m, logs, x_mask
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
n_flows=4,
|
||||
gin_channels=0):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.n_flows = n_flows
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.flows = nn.ModuleList()
|
||||
for i in range(n_flows):
|
||||
self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
if not reverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||
return x
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
gin_channels=0):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||
return z, m, logs, x_mask
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(weight_norm(
|
||||
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(k-u)//2)))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel//(2**(i+1))
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g=None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i*self.num_kernels+j](x)
|
||||
else:
|
||||
xs += self.resblocks[i*self.num_kernels+j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||
])
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
])
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
periods = [2,3,5,7,11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_rs.append(fmap_r)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_vocab,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.n_vocab = n_vocab
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.use_sdp = use_sdp
|
||||
|
||||
self.enc_p = TextEncoder(n_vocab,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout)
|
||||
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
if use_sdp:
|
||||
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
|
||||
else:
|
||||
self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
||||
|
||||
if n_speakers > 1:
|
||||
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths, sid=None):
|
||||
|
||||
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
||||
if self.n_speakers > 0:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
g = None
|
||||
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
with torch.no_grad():
|
||||
# negative cross-entropy
|
||||
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
|
||||
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
|
||||
neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
||||
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
||||
neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
|
||||
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
|
||||
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
|
||||
w = attn.sum(2)
|
||||
if self.use_sdp:
|
||||
l_length = self.dp(x, x_mask, w, g=g)
|
||||
l_length = l_length / torch.sum(x_mask)
|
||||
else:
|
||||
logw_ = torch.log(w + 1e-6) * x_mask
|
||||
logw = self.dp(x, x_mask, g=g)
|
||||
l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
|
||||
|
||||
# expand prior
|
||||
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
|
||||
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
|
||||
o = self.dec(z_slice, g=g)
|
||||
return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||
|
||||
def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
|
||||
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
||||
if self.n_speakers > 0:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
g = None
|
||||
|
||||
if self.use_sdp:
|
||||
logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
|
||||
else:
|
||||
logw = self.dp(x, x_mask, g=g)
|
||||
w = torch.exp(logw) * x_mask * length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
|
||||
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
||||
attn = commons.generate_path(w_ceil, attn_mask)
|
||||
|
||||
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
o = self.dec((z * y_mask)[:,:,:max_len], g=g)
|
||||
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
|
||||
assert self.n_speakers > 0, "n_speakers have to be larger than 0."
|
||||
g_src = self.emb_g(sid_src).unsqueeze(-1)
|
||||
g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
|
||||
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
|
||||
z_p = self.flow(z, y_mask, g=g_src)
|
||||
z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
|
||||
o_hat = self.dec(z_hat * y_mask, g=g_tgt)
|
||||
return o_hat, y_mask, (z, z_p, z_hat)
|
||||
|
|
@ -0,0 +1,390 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
# import commons
|
||||
# from commons import init_weights, get_padding
|
||||
# from transforms import piecewise_rational_quadratic_transform
|
||||
from utils.tts.vits import commons
|
||||
from utils.tts.vits.commons import init_weights, get_padding
|
||||
from utils.tts.vits.transforms import piecewise_rational_quadratic_transform
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = nn.Parameter(torch.ones(channels))
|
||||
self.beta = nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, -1)
|
||||
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||
return x.transpose(1, -1)
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
assert n_layers > 1, "Number of layers should be larger than 0."
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p_dropout))
|
||||
for _ in range(n_layers-1):
|
||||
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x_org = x
|
||||
for i in range(self.n_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x)
|
||||
x = self.relu_drop(x)
|
||||
x = x_org + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class DDSConv(nn.Module):
|
||||
"""
|
||||
Dialted and Depth-Separable Convolution
|
||||
"""
|
||||
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.convs_sep = nn.ModuleList()
|
||||
self.convs_1x1 = nn.ModuleList()
|
||||
self.norms_1 = nn.ModuleList()
|
||||
self.norms_2 = nn.ModuleList()
|
||||
for i in range(n_layers):
|
||||
dilation = kernel_size ** i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
|
||||
groups=channels, dilation=dilation, padding=padding
|
||||
))
|
||||
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||
self.norms_1.append(LayerNorm(channels))
|
||||
self.norms_2.append(LayerNorm(channels))
|
||||
|
||||
def forward(self, x, x_mask, g=None):
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for i in range(self.n_layers):
|
||||
y = self.convs_sep[i](x * x_mask)
|
||||
y = self.norms_1[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.convs_1x1[i](y)
|
||||
y = self.norms_2[i](y)
|
||||
y = F.gelu(y)
|
||||
y = self.drop(y)
|
||||
x = x + y
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class WN(torch.nn.Module):
|
||||
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
||||
super(WN, self).__init__()
|
||||
assert(kernel_size % 2 == 1)
|
||||
self.hidden_channels =hidden_channels
|
||||
self.kernel_size = kernel_size,
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.in_layers = torch.nn.ModuleList()
|
||||
self.res_skip_layers = torch.nn.ModuleList()
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
|
||||
|
||||
for i in range(n_layers):
|
||||
dilation = dilation_rate ** i
|
||||
padding = int((kernel_size * dilation - dilation) / 2)
|
||||
in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
|
||||
dilation=dilation, padding=padding)
|
||||
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
# last one is not necessary
|
||||
if i < n_layers - 1:
|
||||
res_skip_channels = 2 * hidden_channels
|
||||
else:
|
||||
res_skip_channels = hidden_channels
|
||||
|
||||
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
def forward(self, x, x_mask, g=None, **kwargs):
|
||||
output = torch.zeros_like(x)
|
||||
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
for i in range(self.n_layers):
|
||||
x_in = self.in_layers[i](x)
|
||||
if g is not None:
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
|
||||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x_in,
|
||||
g_l,
|
||||
n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
if i < self.n_layers - 1:
|
||||
res_acts = res_skip_acts[:,:self.hidden_channels,:]
|
||||
x = (x + res_acts) * x_mask
|
||||
output = output + res_skip_acts[:,self.hidden_channels:,:]
|
||||
else:
|
||||
output = output + res_skip_acts
|
||||
return output * x_mask
|
||||
|
||||
def remove_weight_norm(self):
|
||||
if self.gin_channels != 0:
|
||||
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||
for l in self.in_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
for l in self.res_skip_layers:
|
||||
torch.nn.utils.remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
if x_mask is not None:
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super(ResBlock2, self).__init__()
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])))
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x, x_mask=None):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if x_mask is not None:
|
||||
xt = xt * x_mask
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
if x_mask is not None:
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class Log(nn.Module):
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
if not reverse:
|
||||
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
||||
logdet = torch.sum(-y, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = torch.exp(x) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class Flip(nn.Module):
|
||||
def forward(self, x, *args, reverse=False, **kwargs):
|
||||
x = torch.flip(x, [1])
|
||||
if not reverse:
|
||||
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class ElementwiseAffine(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.m = nn.Parameter(torch.zeros(channels,1))
|
||||
self.logs = nn.Parameter(torch.zeros(channels,1))
|
||||
|
||||
def forward(self, x, x_mask, reverse=False, **kwargs):
|
||||
if not reverse:
|
||||
y = self.m + torch.exp(self.logs) * x
|
||||
y = y * x_mask
|
||||
logdet = torch.sum(self.logs * x_mask, [1,2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCouplingLayer(nn.Module):
|
||||
def __init__(self,
|
||||
channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
dilation_rate,
|
||||
n_layers,
|
||||
p_dropout=0,
|
||||
gin_channels=0,
|
||||
mean_only=False):
|
||||
assert channels % 2 == 0, "channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.half_channels = channels // 2
|
||||
self.mean_only = mean_only
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
||||
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
|
||||
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
||||
self.post.weight.data.zero_()
|
||||
self.post.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
|
||||
h = self.pre(x0) * x_mask
|
||||
h = self.enc(h, x_mask, g=g)
|
||||
stats = self.post(h) * x_mask
|
||||
if not self.mean_only:
|
||||
m, logs = torch.split(stats, [self.half_channels]*2, 1)
|
||||
else:
|
||||
m = stats
|
||||
logs = torch.zeros_like(m)
|
||||
|
||||
if not reverse:
|
||||
x1 = m + x1 * torch.exp(logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
logdet = torch.sum(logs, [1,2])
|
||||
return x, logdet
|
||||
else:
|
||||
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
||||
x = torch.cat([x0, x1], 1)
|
||||
return x
|
||||
|
||||
|
||||
class ConvFlow(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.num_bins = num_bins
|
||||
self.tail_bound = tail_bound
|
||||
self.half_channels = in_channels // 2
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
|
||||
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
x0, x1 = torch.split(x, [self.half_channels]*2, 1)
|
||||
h = self.pre(x0)
|
||||
h = self.convs(h, x_mask, g=g)
|
||||
h = self.proj(h) * x_mask
|
||||
|
||||
b, c, t = x0.shape
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
|
||||
unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_derivatives = h[..., 2 * self.num_bins:]
|
||||
|
||||
x1, logabsdet = piecewise_rational_quadratic_transform(x1,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=reverse,
|
||||
tails='linear',
|
||||
tail_bound=self.tail_bound
|
||||
)
|
||||
|
||||
x = torch.cat([x0, x1], 1) * x_mask
|
||||
logdet = torch.sum(logabsdet * x_mask, [1,2])
|
||||
if not reverse:
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
|
@ -0,0 +1,20 @@
|
|||
from numpy import zeros, int32, float32
|
||||
from torch import from_numpy
|
||||
|
||||
from .core import maximum_path_jit
|
||||
|
||||
|
||||
def maximum_path(neg_cent, mask):
|
||||
""" numba optimized version.
|
||||
neg_cent: [b, t_t, t_s]
|
||||
mask: [b, t_t, t_s]
|
||||
"""
|
||||
device = neg_cent.device
|
||||
dtype = neg_cent.dtype
|
||||
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
|
||||
path = zeros(neg_cent.shape, dtype=int32)
|
||||
|
||||
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
|
||||
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
|
||||
maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
|
||||
return from_numpy(path).to(device=device, dtype=dtype)
|
|
@ -0,0 +1,36 @@
|
|||
import numba
|
||||
|
||||
|
||||
@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
|
||||
nopython=True, nogil=True)
|
||||
def maximum_path_jit(paths, values, t_ys, t_xs):
|
||||
b = paths.shape[0]
|
||||
max_neg_val = -1e9
|
||||
for i in range(int(b)):
|
||||
path = paths[i]
|
||||
value = values[i]
|
||||
t_y = t_ys[i]
|
||||
t_x = t_xs[i]
|
||||
|
||||
v_prev = v_cur = 0.0
|
||||
index = t_x - 1
|
||||
|
||||
for y in range(t_y):
|
||||
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||
if x == y:
|
||||
v_cur = max_neg_val
|
||||
else:
|
||||
v_cur = value[y - 1, x]
|
||||
if x == 0:
|
||||
if y == 0:
|
||||
v_prev = 0.
|
||||
else:
|
||||
v_prev = max_neg_val
|
||||
else:
|
||||
v_prev = value[y - 1, x - 1]
|
||||
value[y, x] += max(v_prev, v_cur)
|
||||
|
||||
for y in range(t_y - 1, -1, -1):
|
||||
path[y, index] = 1
|
||||
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
|
||||
index = index - 1
|
|
@ -0,0 +1,19 @@
|
|||
Copyright (c) 2017 Keith Ito
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
|
@ -0,0 +1,57 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
from . import cleaners
|
||||
from .symbols import symbols
|
||||
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
|
||||
def text_to_sequence(text, symbols, cleaner_names):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
cleaner_names: names of the cleaner functions to run the text through
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
sequence = []
|
||||
|
||||
clean_text = _clean_text(text, cleaner_names)
|
||||
for symbol in clean_text:
|
||||
if symbol not in _symbol_to_id.keys():
|
||||
continue
|
||||
symbol_id = _symbol_to_id[symbol]
|
||||
sequence += [symbol_id]
|
||||
return sequence, clean_text
|
||||
|
||||
|
||||
def cleaned_text_to_sequence(cleaned_text):
|
||||
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
||||
Args:
|
||||
text: string to convert to a sequence
|
||||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
|
||||
return sequence
|
||||
|
||||
|
||||
def sequence_to_text(sequence):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
result += s
|
||||
return result
|
||||
|
||||
|
||||
def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
|
@ -0,0 +1,475 @@
|
|||
""" from https://github.com/keithito/tacotron """
|
||||
|
||||
'''
|
||||
Cleaners are transformations that run over the input text at both training and eval time.
|
||||
|
||||
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
||||
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
||||
1. "english_cleaners" for English text
|
||||
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
||||
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
||||
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
||||
the symbols in symbols.py to match your data).
|
||||
'''
|
||||
|
||||
import re
|
||||
from unidecode import unidecode
|
||||
# import pyopenjtalk
|
||||
from jamo import h2j, j2hcj
|
||||
from pypinyin import lazy_pinyin, BOPOMOFO
|
||||
import jieba, cn2an
|
||||
|
||||
|
||||
# This is a list of Korean classifiers preceded by pure Korean numerals.
|
||||
_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r'\s+')
|
||||
|
||||
# Regular expression matching Japanese without punctuation marks:
|
||||
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
||||
|
||||
# Regular expression matching non-Japanese characters or punctuation marks:
|
||||
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'),
|
||||
]]
|
||||
|
||||
# List of (hangul, hangul divided) pairs:
|
||||
_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
|
||||
('ㄳ', 'ㄱㅅ'),
|
||||
('ㄵ', 'ㄴㅈ'),
|
||||
('ㄶ', 'ㄴㅎ'),
|
||||
('ㄺ', 'ㄹㄱ'),
|
||||
('ㄻ', 'ㄹㅁ'),
|
||||
('ㄼ', 'ㄹㅂ'),
|
||||
('ㄽ', 'ㄹㅅ'),
|
||||
('ㄾ', 'ㄹㅌ'),
|
||||
('ㄿ', 'ㄹㅍ'),
|
||||
('ㅀ', 'ㄹㅎ'),
|
||||
('ㅄ', 'ㅂㅅ'),
|
||||
('ㅘ', 'ㅗㅏ'),
|
||||
('ㅙ', 'ㅗㅐ'),
|
||||
('ㅚ', 'ㅗㅣ'),
|
||||
('ㅝ', 'ㅜㅓ'),
|
||||
('ㅞ', 'ㅜㅔ'),
|
||||
('ㅟ', 'ㅜㅣ'),
|
||||
('ㅢ', 'ㅡㅣ'),
|
||||
('ㅑ', 'ㅣㅏ'),
|
||||
('ㅒ', 'ㅣㅐ'),
|
||||
('ㅕ', 'ㅣㅓ'),
|
||||
('ㅖ', 'ㅣㅔ'),
|
||||
('ㅛ', 'ㅣㅗ'),
|
||||
('ㅠ', 'ㅣㅜ')
|
||||
]]
|
||||
|
||||
# List of (Latin alphabet, hangul) pairs:
|
||||
_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('a', '에이'),
|
||||
('b', '비'),
|
||||
('c', '시'),
|
||||
('d', '디'),
|
||||
('e', '이'),
|
||||
('f', '에프'),
|
||||
('g', '지'),
|
||||
('h', '에이치'),
|
||||
('i', '아이'),
|
||||
('j', '제이'),
|
||||
('k', '케이'),
|
||||
('l', '엘'),
|
||||
('m', '엠'),
|
||||
('n', '엔'),
|
||||
('o', '오'),
|
||||
('p', '피'),
|
||||
('q', '큐'),
|
||||
('r', '아르'),
|
||||
('s', '에스'),
|
||||
('t', '티'),
|
||||
('u', '유'),
|
||||
('v', '브이'),
|
||||
('w', '더블유'),
|
||||
('x', '엑스'),
|
||||
('y', '와이'),
|
||||
('z', '제트')
|
||||
]]
|
||||
|
||||
# List of (Latin alphabet, bopomofo) pairs:
|
||||
_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('a', 'ㄟˉ'),
|
||||
('b', 'ㄅㄧˋ'),
|
||||
('c', 'ㄙㄧˉ'),
|
||||
('d', 'ㄉㄧˋ'),
|
||||
('e', 'ㄧˋ'),
|
||||
('f', 'ㄝˊㄈㄨˋ'),
|
||||
('g', 'ㄐㄧˋ'),
|
||||
('h', 'ㄝˇㄑㄩˋ'),
|
||||
('i', 'ㄞˋ'),
|
||||
('j', 'ㄐㄟˋ'),
|
||||
('k', 'ㄎㄟˋ'),
|
||||
('l', 'ㄝˊㄛˋ'),
|
||||
('m', 'ㄝˊㄇㄨˋ'),
|
||||
('n', 'ㄣˉ'),
|
||||
('o', 'ㄡˉ'),
|
||||
('p', 'ㄆㄧˉ'),
|
||||
('q', 'ㄎㄧㄡˉ'),
|
||||
('r', 'ㄚˋ'),
|
||||
('s', 'ㄝˊㄙˋ'),
|
||||
('t', 'ㄊㄧˋ'),
|
||||
('u', 'ㄧㄡˉ'),
|
||||
('v', 'ㄨㄧˉ'),
|
||||
('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
|
||||
('x', 'ㄝˉㄎㄨˋㄙˋ'),
|
||||
('y', 'ㄨㄞˋ'),
|
||||
('z', 'ㄗㄟˋ')
|
||||
]]
|
||||
|
||||
|
||||
# List of (bopomofo, romaji) pairs:
|
||||
_bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
|
||||
('ㄅㄛ', 'p⁼wo'),
|
||||
('ㄆㄛ', 'pʰwo'),
|
||||
('ㄇㄛ', 'mwo'),
|
||||
('ㄈㄛ', 'fwo'),
|
||||
('ㄅ', 'p⁼'),
|
||||
('ㄆ', 'pʰ'),
|
||||
('ㄇ', 'm'),
|
||||
('ㄈ', 'f'),
|
||||
('ㄉ', 't⁼'),
|
||||
('ㄊ', 'tʰ'),
|
||||
('ㄋ', 'n'),
|
||||
('ㄌ', 'l'),
|
||||
('ㄍ', 'k⁼'),
|
||||
('ㄎ', 'kʰ'),
|
||||
('ㄏ', 'h'),
|
||||
('ㄐ', 'ʧ⁼'),
|
||||
('ㄑ', 'ʧʰ'),
|
||||
('ㄒ', 'ʃ'),
|
||||
('ㄓ', 'ʦ`⁼'),
|
||||
('ㄔ', 'ʦ`ʰ'),
|
||||
('ㄕ', 's`'),
|
||||
('ㄖ', 'ɹ`'),
|
||||
('ㄗ', 'ʦ⁼'),
|
||||
('ㄘ', 'ʦʰ'),
|
||||
('ㄙ', 's'),
|
||||
('ㄚ', 'a'),
|
||||
('ㄛ', 'o'),
|
||||
('ㄜ', 'ə'),
|
||||
('ㄝ', 'e'),
|
||||
('ㄞ', 'ai'),
|
||||
('ㄟ', 'ei'),
|
||||
('ㄠ', 'au'),
|
||||
('ㄡ', 'ou'),
|
||||
('ㄧㄢ', 'yeNN'),
|
||||
('ㄢ', 'aNN'),
|
||||
('ㄧㄣ', 'iNN'),
|
||||
('ㄣ', 'əNN'),
|
||||
('ㄤ', 'aNg'),
|
||||
('ㄧㄥ', 'iNg'),
|
||||
('ㄨㄥ', 'uNg'),
|
||||
('ㄩㄥ', 'yuNg'),
|
||||
('ㄥ', 'əNg'),
|
||||
('ㄦ', 'əɻ'),
|
||||
('ㄧ', 'i'),
|
||||
('ㄨ', 'u'),
|
||||
('ㄩ', 'ɥ'),
|
||||
('ˉ', '→'),
|
||||
('ˊ', '↑'),
|
||||
('ˇ', '↓↑'),
|
||||
('ˋ', '↓'),
|
||||
('˙', ''),
|
||||
(',', ','),
|
||||
('。', '.'),
|
||||
('!', '!'),
|
||||
('?', '?'),
|
||||
('—', '-')
|
||||
]]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, ' ', text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def japanese_to_romaji_with_accent(text):
|
||||
'''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
|
||||
sentences = re.split(_japanese_marks, text)
|
||||
marks = re.findall(_japanese_marks, text)
|
||||
text = ''
|
||||
for i, sentence in enumerate(sentences):
|
||||
if re.match(_japanese_characters, sentence):
|
||||
if text!='':
|
||||
text+=' '
|
||||
labels = pyopenjtalk.extract_fullcontext(sentence)
|
||||
for n, label in enumerate(labels):
|
||||
phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
|
||||
if phoneme not in ['sil','pau']:
|
||||
text += phoneme.replace('ch','ʧ').replace('sh','ʃ').replace('cl','Q')
|
||||
else:
|
||||
continue
|
||||
n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
|
||||
a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
|
||||
a2 = int(re.search(r"\+(\d+)\+", label).group(1))
|
||||
a3 = int(re.search(r"\+(\d+)/", label).group(1))
|
||||
if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil','pau']:
|
||||
a2_next=-1
|
||||
else:
|
||||
a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
|
||||
# Accent phrase boundary
|
||||
if a3 == 1 and a2_next == 1:
|
||||
text += ' '
|
||||
# Falling
|
||||
elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras:
|
||||
text += '↓'
|
||||
# Rising
|
||||
elif a2 == 1 and a2_next == 2:
|
||||
text += '↑'
|
||||
if i<len(marks):
|
||||
text += unidecode(marks[i]).replace(' ','')
|
||||
return text
|
||||
|
||||
|
||||
def latin_to_hangul(text):
|
||||
for regex, replacement in _latin_to_hangul:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def divide_hangul(text):
|
||||
for regex, replacement in _hangul_divided:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def hangul_number(num, sino=True):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
num = re.sub(',', '', num)
|
||||
|
||||
if num == '0':
|
||||
return '영'
|
||||
if not sino and num == '20':
|
||||
return '스무'
|
||||
|
||||
digits = '123456789'
|
||||
names = '일이삼사오육칠팔구'
|
||||
digit2name = {d: n for d, n in zip(digits, names)}
|
||||
|
||||
modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
|
||||
decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
|
||||
digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
|
||||
digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
|
||||
|
||||
spelledout = []
|
||||
for i, digit in enumerate(num):
|
||||
i = len(num) - i - 1
|
||||
if sino:
|
||||
if i == 0:
|
||||
name = digit2name.get(digit, '')
|
||||
elif i == 1:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
else:
|
||||
if i == 0:
|
||||
name = digit2mod.get(digit, '')
|
||||
elif i == 1:
|
||||
name = digit2dec.get(digit, '')
|
||||
if digit == '0':
|
||||
if i % 4 == 0:
|
||||
last_three = spelledout[-min(3, len(spelledout)):]
|
||||
if ''.join(last_three) == '':
|
||||
spelledout.append('')
|
||||
continue
|
||||
else:
|
||||
spelledout.append('')
|
||||
continue
|
||||
if i == 2:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
elif i == 3:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
elif i == 4:
|
||||
name = digit2name.get(digit, '') + '만'
|
||||
name = name.replace('일만', '만')
|
||||
elif i == 5:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
name = name.replace('일십', '십')
|
||||
elif i == 6:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
name = name.replace('일백', '백')
|
||||
elif i == 7:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
name = name.replace('일천', '천')
|
||||
elif i == 8:
|
||||
name = digit2name.get(digit, '') + '억'
|
||||
elif i == 9:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
elif i == 10:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
elif i == 11:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
elif i == 12:
|
||||
name = digit2name.get(digit, '') + '조'
|
||||
elif i == 13:
|
||||
name = digit2name.get(digit, '') + '십'
|
||||
elif i == 14:
|
||||
name = digit2name.get(digit, '') + '백'
|
||||
elif i == 15:
|
||||
name = digit2name.get(digit, '') + '천'
|
||||
spelledout.append(name)
|
||||
return ''.join(elem for elem in spelledout)
|
||||
|
||||
|
||||
def number_to_hangul(text):
|
||||
'''Reference https://github.com/Kyubyong/g2pK'''
|
||||
tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
|
||||
for token in tokens:
|
||||
num, classifier = token
|
||||
if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
|
||||
spelledout = hangul_number(num, sino=False)
|
||||
else:
|
||||
spelledout = hangul_number(num, sino=True)
|
||||
text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
|
||||
# digit by digit for remaining digits
|
||||
digits = '0123456789'
|
||||
names = '영일이삼사오육칠팔구'
|
||||
for d, n in zip(digits, names):
|
||||
text = text.replace(d, n)
|
||||
return text
|
||||
|
||||
|
||||
def number_to_chinese(text):
|
||||
numbers = re.findall(r'\d+(?:\.?\d+)?', text)
|
||||
for number in numbers:
|
||||
text = text.replace(number, cn2an.an2cn(number),1)
|
||||
return text
|
||||
|
||||
|
||||
def chinese_to_bopomofo(text):
|
||||
text=text.replace('、',',').replace(';',',').replace(':',',')
|
||||
words=jieba.lcut(text,cut_all=False)
|
||||
text=''
|
||||
for word in words:
|
||||
bopomofos=lazy_pinyin(word,BOPOMOFO)
|
||||
if not re.search('[\u4e00-\u9fff]',word):
|
||||
text+=word
|
||||
continue
|
||||
for i in range(len(bopomofos)):
|
||||
if re.match('[\u3105-\u3129]',bopomofos[i][-1]):
|
||||
bopomofos[i]+='ˉ'
|
||||
if text!='':
|
||||
text+=' '
|
||||
text+=''.join(bopomofos)
|
||||
return text
|
||||
|
||||
|
||||
def latin_to_bopomofo(text):
|
||||
for regex, replacement in _latin_to_bopomofo:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def bopomofo_to_romaji(text):
|
||||
for regex, replacement in _bopomofo_to_romaji:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def japanese_cleaners(text):
|
||||
text=japanese_to_romaji_with_accent(text)
|
||||
if re.match('[A-Za-z]',text[-1]):
|
||||
text += '.'
|
||||
return text
|
||||
|
||||
|
||||
def japanese_cleaners2(text):
|
||||
return japanese_cleaners(text).replace('ts','ʦ').replace('...','…')
|
||||
|
||||
|
||||
def korean_cleaners(text):
|
||||
'''Pipeline for Korean text'''
|
||||
text = latin_to_hangul(text)
|
||||
text = number_to_hangul(text)
|
||||
text = j2hcj(h2j(text))
|
||||
text = divide_hangul(text)
|
||||
if re.match('[\u3131-\u3163]',text[-1]):
|
||||
text += '.'
|
||||
return text
|
||||
|
||||
|
||||
def chinese_cleaners(text):
|
||||
'''Pipeline for Chinese text'''
|
||||
text=number_to_chinese(text)
|
||||
text=chinese_to_bopomofo(text)
|
||||
text=latin_to_bopomofo(text)
|
||||
if re.match('[ˉˊˇˋ˙]',text[-1]):
|
||||
text += '。'
|
||||
return text
|
||||
|
||||
|
||||
def zh_ja_mixture_cleaners(text):
|
||||
chinese_texts=re.findall(r'\[ZH\].*?\[ZH\]',text)
|
||||
japanese_texts=re.findall(r'\[JA\].*?\[JA\]',text)
|
||||
for chinese_text in chinese_texts:
|
||||
cleaned_text=number_to_chinese(chinese_text[4:-4])
|
||||
cleaned_text=chinese_to_bopomofo(cleaned_text)
|
||||
cleaned_text=latin_to_bopomofo(cleaned_text)
|
||||
cleaned_text=bopomofo_to_romaji(cleaned_text)
|
||||
cleaned_text=re.sub('i[aoe]',lambda x:'y'+x.group(0)[1:],cleaned_text)
|
||||
cleaned_text=re.sub('u[aoəe]',lambda x:'w'+x.group(0)[1:],cleaned_text)
|
||||
cleaned_text=re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ`'+x.group(2),cleaned_text).replace('ɻ','ɹ`')
|
||||
cleaned_text=re.sub('([ʦs][⁼ʰ]?)([→↓↑]+)',lambda x:x.group(1)+'ɹ'+x.group(2),cleaned_text)
|
||||
text = text.replace(chinese_text,cleaned_text+' ',1)
|
||||
for japanese_text in japanese_texts:
|
||||
cleaned_text=japanese_to_romaji_with_accent(japanese_text[4:-4]).replace('ts','ʦ').replace('u','ɯ').replace('...','…')
|
||||
text = text.replace(japanese_text,cleaned_text+' ',1)
|
||||
text=text[:-1]
|
||||
if re.match('[A-Za-zɯɹəɥ→↓↑]',text[-1]):
|
||||
text += '.'
|
||||
return text
|
|
@ -0,0 +1,39 @@
|
|||
'''
|
||||
Defines the set of symbols used in text input to the model.
|
||||
'''
|
||||
|
||||
'''# japanese_cleaners
|
||||
_pad = '_'
|
||||
_punctuation = ',.!?-'
|
||||
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
|
||||
'''
|
||||
|
||||
'''# japanese_cleaners2
|
||||
_pad = '_'
|
||||
_punctuation = ',.!?-~…'
|
||||
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
|
||||
'''
|
||||
|
||||
'''# korean_cleaners
|
||||
_pad = '_'
|
||||
_punctuation = ',.!?…~'
|
||||
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
||||
'''
|
||||
|
||||
'''# chinese_cleaners
|
||||
_pad = '_'
|
||||
_punctuation = ',。!?—…'
|
||||
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
|
||||
'''
|
||||
|
||||
# zh_ja_mixture_cleaners
|
||||
_pad = '_'
|
||||
_punctuation = ',.!?-~…'
|
||||
_letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
|
||||
|
||||
|
||||
# Export all symbols:
|
||||
symbols = [_pad] + list(_punctuation) + list(_letters)
|
||||
|
||||
# Special symbol ids
|
||||
SPACE_ID = symbols.index(" ")
|
|
@ -0,0 +1,193 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||
|
||||
|
||||
def piecewise_rational_quadratic_transform(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails=None,
|
||||
tail_bound=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
|
||||
if tails is None:
|
||||
spline_fn = rational_quadratic_spline
|
||||
spline_kwargs = {}
|
||||
else:
|
||||
spline_fn = unconstrained_rational_quadratic_spline
|
||||
spline_kwargs = {
|
||||
'tails': tails,
|
||||
'tail_bound': tail_bound
|
||||
}
|
||||
|
||||
outputs, logabsdet = spline_fn(
|
||||
inputs=inputs,
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs
|
||||
)
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def searchsorted(bin_locations, inputs, eps=1e-6):
|
||||
bin_locations[..., -1] += eps
|
||||
return torch.sum(
|
||||
inputs[..., None] >= bin_locations,
|
||||
dim=-1
|
||||
) - 1
|
||||
|
||||
|
||||
def unconstrained_rational_quadratic_spline(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails='linear',
|
||||
tail_bound=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||
outside_interval_mask = ~inside_interval_mask
|
||||
|
||||
outputs = torch.zeros_like(inputs)
|
||||
logabsdet = torch.zeros_like(inputs)
|
||||
|
||||
if tails == 'linear':
|
||||
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
||||
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||
unnormalized_derivatives[..., 0] = constant
|
||||
unnormalized_derivatives[..., -1] = constant
|
||||
|
||||
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||
logabsdet[outside_interval_mask] = 0
|
||||
else:
|
||||
raise RuntimeError('{} tails are not implemented.'.format(tails))
|
||||
|
||||
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
|
||||
inputs=inputs[inside_interval_mask],
|
||||
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
||||
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
||||
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
||||
inverse=inverse,
|
||||
left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative
|
||||
)
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
def rational_quadratic_spline(inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
left=0., right=1., bottom=0., top=1.,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE):
|
||||
if torch.min(inputs) < left or torch.max(inputs) > right:
|
||||
raise ValueError('Input to a transform is not within its domain')
|
||||
|
||||
num_bins = unnormalized_widths.shape[-1]
|
||||
|
||||
if min_bin_width * num_bins > 1.0:
|
||||
raise ValueError('Minimal bin width too large for the number of bins')
|
||||
if min_bin_height * num_bins > 1.0:
|
||||
raise ValueError('Minimal bin height too large for the number of bins')
|
||||
|
||||
widths = F.softmax(unnormalized_widths, dim=-1)
|
||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||
cumwidths = torch.cumsum(widths, dim=-1)
|
||||
cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
|
||||
cumwidths = (right - left) * cumwidths + left
|
||||
cumwidths[..., 0] = left
|
||||
cumwidths[..., -1] = right
|
||||
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
||||
|
||||
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
||||
|
||||
heights = F.softmax(unnormalized_heights, dim=-1)
|
||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||
cumheights = torch.cumsum(heights, dim=-1)
|
||||
cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
|
||||
cumheights = (top - bottom) * cumheights + bottom
|
||||
cumheights[..., 0] = bottom
|
||||
cumheights[..., -1] = top
|
||||
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
||||
|
||||
if inverse:
|
||||
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
||||
else:
|
||||
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
||||
|
||||
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
||||
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
||||
delta = heights / widths
|
||||
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
||||
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
||||
|
||||
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
||||
|
||||
if inverse:
|
||||
a = (((inputs - input_cumheights) * (input_derivatives
|
||||
+ input_derivatives_plus_one
|
||||
- 2 * input_delta)
|
||||
+ input_heights * (input_delta - input_derivatives)))
|
||||
b = (input_heights * input_derivatives
|
||||
- (inputs - input_cumheights) * (input_derivatives
|
||||
+ input_derivatives_plus_one
|
||||
- 2 * input_delta))
|
||||
c = - input_delta * (inputs - input_cumheights)
|
||||
|
||||
discriminant = b.pow(2) - 4 * a * c
|
||||
assert (discriminant >= 0).all()
|
||||
|
||||
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
||||
outputs = root * input_bin_widths + input_cumwidths
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta)
|
||||
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - root).pow(2))
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, -logabsdet
|
||||
else:
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (input_delta * theta.pow(2)
|
||||
+ input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
|
||||
+ 2 * input_delta * theta_one_minus_theta
|
||||
+ input_derivatives * (1 - theta).pow(2))
|
||||
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
||||
|
||||
return outputs, logabsdet
|
|
@ -0,0 +1,225 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import librosa
|
||||
import torch
|
||||
|
||||
MATPLOTLIB_FLAG = False
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
iteration = checkpoint_dict['iteration']
|
||||
learning_rate = checkpoint_dict['learning_rate']
|
||||
if optimizer is not None:
|
||||
optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
saved_state_dict = checkpoint_dict['model']
|
||||
if hasattr(model, 'module'):
|
||||
state_dict = model.module.state_dict()
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict= {}
|
||||
for k, v in state_dict.items():
|
||||
try:
|
||||
new_state_dict[k] = saved_state_dict[k]
|
||||
except:
|
||||
logger.info("%s is not in the checkpoint" % k)
|
||||
new_state_dict[k] = v
|
||||
if hasattr(model, 'module'):
|
||||
model.module.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(new_state_dict)
|
||||
logger.info("Loaded checkpoint '{}' (iteration {})" .format(
|
||||
checkpoint_path, iteration))
|
||||
return model, optimizer, learning_rate, iteration
|
||||
|
||||
|
||||
def plot_spectrogram_to_numpy(spectrogram):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger('matplotlib')
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10,2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
||||
interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.xlabel("Frames")
|
||||
plt.ylabel("Channels")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def plot_alignment_to_numpy(alignment, info=None):
|
||||
global MATPLOTLIB_FLAG
|
||||
if not MATPLOTLIB_FLAG:
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
MATPLOTLIB_FLAG = True
|
||||
mpl_logger = logging.getLogger('matplotlib')
|
||||
mpl_logger.setLevel(logging.WARNING)
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
|
||||
interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
if info is not None:
|
||||
xlabel += '\n\n' + info
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def load_audio_to_torch(full_path, target_sampling_rate):
|
||||
audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
|
||||
return torch.FloatTensor(audio.astype(np.float32))
|
||||
|
||||
|
||||
def load_filepaths_and_text(filename, split="|"):
|
||||
with open(filename, encoding='utf-8') as f:
|
||||
filepaths_and_text = [line.strip().split(split) for line in f]
|
||||
return filepaths_and_text
|
||||
|
||||
|
||||
def get_hparams(init=True):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
|
||||
help='JSON file for configuration')
|
||||
parser.add_argument('-m', '--model', type=str, required=True,
|
||||
help='Model name')
|
||||
|
||||
args = parser.parse_args()
|
||||
model_dir = os.path.join("./logs", args.model)
|
||||
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
|
||||
config_path = args.config
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
if init:
|
||||
with open(config_path, "r") as f:
|
||||
data = f.read()
|
||||
with open(config_save_path, "w") as f:
|
||||
f.write(data)
|
||||
else:
|
||||
with open(config_save_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams = HParams(**config)
|
||||
hparams.model_dir = model_dir
|
||||
return hparams
|
||||
|
||||
|
||||
def get_hparams_from_dir(model_dir):
|
||||
config_save_path = os.path.join(model_dir, "config.json")
|
||||
with open(config_save_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams =HParams(**config)
|
||||
hparams.model_dir = model_dir
|
||||
return hparams
|
||||
|
||||
|
||||
def get_hparams_from_file(config_path):
|
||||
with open(config_path, "r") as f:
|
||||
data = f.read()
|
||||
config = json.loads(data)
|
||||
|
||||
hparams =HParams(**config)
|
||||
return hparams
|
||||
|
||||
|
||||
def check_git_hash(model_dir):
|
||||
source_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
if not os.path.exists(os.path.join(source_dir, ".git")):
|
||||
logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
|
||||
source_dir
|
||||
))
|
||||
return
|
||||
|
||||
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
||||
|
||||
path = os.path.join(model_dir, "githash")
|
||||
if os.path.exists(path):
|
||||
saved_hash = open(path).read()
|
||||
if saved_hash != cur_hash:
|
||||
logger.warn("git hash values are different. {}(saved) != {}(current)".format(
|
||||
saved_hash[:8], cur_hash[:8]))
|
||||
else:
|
||||
open(path, "w").write(cur_hash)
|
||||
|
||||
|
||||
def get_logger(model_dir, filename="train.log"):
|
||||
global logger
|
||||
logger = logging.getLogger(os.path.basename(model_dir))
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
|
||||
if not os.path.exists(model_dir):
|
||||
os.makedirs(model_dir)
|
||||
h = logging.FileHandler(os.path.join(model_dir, filename))
|
||||
h.setLevel(logging.DEBUG)
|
||||
h.setFormatter(formatter)
|
||||
logger.addHandler(h)
|
||||
return logger
|
||||
|
||||
|
||||
class HParams():
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if type(v) == dict:
|
||||
v = HParams(**v)
|
||||
self[k] = v
|
||||
|
||||
def keys(self):
|
||||
return self.__dict__.keys()
|
||||
|
||||
def items(self):
|
||||
return self.__dict__.items()
|
||||
|
||||
def values(self):
|
||||
return self.__dict__.values()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.__dict__
|
||||
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
|
@ -0,0 +1,115 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import LongTensor
|
||||
import soundfile as sf
|
||||
# vits
|
||||
from .vits import utils, commons
|
||||
from .vits.models import SynthesizerTrn
|
||||
from .vits.text import text_to_sequence
|
||||
|
||||
def tts_model_init(model_path='./vits_model', device='cuda'):
|
||||
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
|
||||
# hps_ms = utils.get_hparams_from_file('vits_model/config.json')
|
||||
net_g_ms = SynthesizerTrn(
|
||||
len(hps_ms.symbols),
|
||||
hps_ms.data.filter_length // 2 + 1,
|
||||
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
||||
n_speakers=hps_ms.data.n_speakers,
|
||||
**hps_ms.model)
|
||||
net_g_ms = net_g_ms.eval().to(device)
|
||||
speakers = hps_ms.speakers
|
||||
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
|
||||
# utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None)
|
||||
return hps_ms, net_g_ms, speakers
|
||||
|
||||
class TextToSpeech:
|
||||
def __init__(self,
|
||||
model_path="./utils/tts/vits_model",
|
||||
device='cuda',
|
||||
RATE=22050,
|
||||
debug=False,
|
||||
):
|
||||
self.debug = debug
|
||||
self.RATE = RATE
|
||||
self.device = torch.device(device)
|
||||
self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度
|
||||
self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path)
|
||||
|
||||
def _tts_model_init(self, model_path):
|
||||
hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json'))
|
||||
net_g_ms = SynthesizerTrn(
|
||||
len(hps_ms.symbols),
|
||||
hps_ms.data.filter_length // 2 + 1,
|
||||
hps_ms.train.segment_size // hps_ms.data.hop_length,
|
||||
n_speakers=hps_ms.data.n_speakers,
|
||||
**hps_ms.model)
|
||||
net_g_ms = net_g_ms.eval().to(self.device)
|
||||
speakers = hps_ms.speakers
|
||||
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None)
|
||||
if self.debug:
|
||||
print("Model loaded.")
|
||||
return hps_ms, net_g_ms, speakers
|
||||
|
||||
def _get_text(self, text):
|
||||
text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
|
||||
if self.hps_ms.data.add_blank:
|
||||
text_norm = commons.intersperse(text_norm, 0)
|
||||
text_norm = LongTensor(text_norm)
|
||||
return text_norm, clean_text
|
||||
|
||||
def _preprocess_text(self, text, language):
|
||||
if language == 0:
|
||||
return f"[ZH]{text}[ZH]"
|
||||
elif language == 1:
|
||||
return f"[JA]{text}[JA]"
|
||||
return text
|
||||
|
||||
def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale):
|
||||
import time
|
||||
start_time = time.time()
|
||||
stn_tst, clean_text = self._get_text(text)
|
||||
with torch.no_grad():
|
||||
x_tst = stn_tst.unsqueeze(0).to(self.device)
|
||||
x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
|
||||
speaker_id = LongTensor([speaker_id]).to(self.device)
|
||||
audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
|
||||
length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
|
||||
if self.debug:
|
||||
print(f"Synthesis time: {time.time() - start_time} s")
|
||||
return audio
|
||||
|
||||
def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
|
||||
if not len(text):
|
||||
return "输入文本不能为空!", None
|
||||
text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
|
||||
if len(text) > 100 and self.limitation:
|
||||
return f"输入文字过长!{len(text)}>100", None
|
||||
text = self._preprocess_text(text, language)
|
||||
audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
|
||||
if self.debug or save_audio:
|
||||
self.save_audio(audio, self.RATE, 'output_file.wav')
|
||||
if return_bytes:
|
||||
audio = self.convert_numpy_to_bytes(audio)
|
||||
return self.RATE, audio
|
||||
|
||||
def convert_numpy_to_bytes(self, audio_data):
|
||||
if isinstance(audio_data, np.ndarray):
|
||||
if audio_data.dtype == np.dtype('float32'):
|
||||
audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
|
||||
audio_data = audio_data.tobytes()
|
||||
return audio_data
|
||||
else:
|
||||
raise TypeError("audio_data must be a numpy array")
|
||||
|
||||
def save_audio(self, audio, sample_rate, file_name='output_file.wav'):
|
||||
sf.write(file_name, audio, samplerate=sample_rate)
|
||||
print(f"VITS Audio saved to {file_name}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from config import get_config
|
||||
|
||||
Config = get_config()
|
||||
|
||||
def generate_xf_asr_url():
|
||||
#设置讯飞流式听写API相关参数
|
||||
APIKey = Config.XF_ASR.API_KEY
|
||||
APISecret = Config.XF_ASR.API_SECRET
|
||||
|
||||
#鉴权并创建websocket_url
|
||||
url = 'wss://ws-api.xfyun.cn/v2/iat'
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
|
||||
signature_sha = hmac.new(APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": "ws-api.xfyun.cn"
|
||||
}
|
||||
url = url + '?' + urlencode(v)
|
||||
return url
|
Loading…
Reference in New Issue