commit 83cbe007baa1e6d3b4067b568e29d8f606546db6
Author: Killua777 <1223086337@qq.com>
Date:   Wed May 1 17:18:30 2024 +0800

    Initialize repository

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a1186f4
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+/app.log
+app.log
+/utils/tts/vits_model/
+vits_model
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e69de29
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7915954
--- /dev/null
+++ b/README.md
@@ -0,0 +1,58 @@
+# TakwayAI Backend
+
+### Project architecture
+
+The AI backend uses the FastAPI framework with a layered architecture:
+
+```
+TakwayAI/
+│
+├── app/
+│   ├── __init__.py          # Application initialization and configuration
+│   ├── main.py              # Application entry point
+│   ├── models/              # Data model definitions
+│   │   ├── __init__.py
+│   │   ├── models.py        # Database table definitions
+│   ├── schemas/             # Request and response models
+│   │   ├── __init__.py
+│   │   ├── user.py          # User-related schemas
+│   │   └── ...              # Other schemas
+│   ├── controllers/         # Business-logic controllers
+│   │   ├── __init__.py
+│   │   ├── user.py          # User-related controllers
+│   │   └── ...              # Other controllers
+│   ├── routes/              # Routes and view functions
+│   │   ├── __init__.py
+│   │   ├── user.py          # User-related routes
+│   │   └── ...              # Other routes
+│   ├── dependencies/        # Dependency-injection helpers
+│   │   ├── __init__.py
+│   │   ├── database.py      # Database dependency
+│   │   └── ...              # Other dependencies
+│   └── exceptions/          # Custom exception handling
+│       ├── __init__.py
+│       └── ...              # Custom exception classes
+│
+├── tests/                   # Test code
+│   ├── __init__.py
+│   └── ...
+│
+├── config/                  # Configuration files
+│   ├── __init__.py
+│   ├── production.py        # Production configuration
+│   ├── development.py       # Development configuration
+│   └── ...                  # Other environment configurations
+│
+├── utils/                   # Utility functions and helper classes
+│   ├── __init__.py
+│   ├── stt                  # Speech-to-text utilities
+│   ├── tts                  # Text-to-speech utilities
+│   └── ...                  # Other utilities
+│
+├── main.py                  # Startup script
+├── app.log                  # Log file
+├── Dockerfile               # Dockerfile for building the Docker image
+├── requirements.txt         # Project dependency list
+└── README.md                # Project readme
+```
+
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..c08db0a
--- /dev/null
+++ b/app/__init__.py
@@ -0,0 +1,53 @@
+from fastapi import FastAPI, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from sqlalchemy import create_engine
+from .models import Base
+from .routes.character import router as character_router
+from .routes.user import router as user_router
+from .routes.session import router as session_router
+from .routes.chat import router as chat_router
+from .dependencies.logger import get_logger
+from config import get_config
+
+
+#---------------------- Set up logging ---------------------
+logger = get_logger()
+logger.info("日志初始化完成")
+#-----------------------------------------------------------
+
+#---------------------- Load configuration -----------------
+Config = get_config()
+logger.info("配置获取完成")
+#-----------------------------------------------------------
+
+
+#-------------------- Initialize the database --------------
+engine = create_engine(Config.SQLALCHEMY_DATABASE_URI)
+Base.metadata.create_all(bind=engine)
+logger.info("数据库初始化完成")
+#-----------------------------------------------------------
+
+
+#-------------------- Create the FastAPI instance ----------
+app = FastAPI()
+logger.info("FastAPI实例创建完成")
+#-----------------------------------------------------------
+
+
+#--------------------- Register the routers ----------------
+app.include_router(character_router)
+app.include_router(user_router)
+app.include_router(session_router)
+app.include_router(chat_router)
+logger.info("路由初始化完成")
+#-----------------------------------------------------------
+
+#------------------- Configure CORS middleware -------------
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # allow all origins; specific origins can be listed instead
+    allow_credentials=True,
+    allow_methods=["*"],  # allow all methods
+    allow_headers=["*"],  # allow all headers
+)
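For reference, the application is served by uvicorn (see `app/main.py` below), with the active configuration chosen by the `MODE` environment variable read in `config/__init__.py`. A minimal launch sketch, assuming the dependencies from `requirements.txt` are installed and the MySQL/Redis instances from the development config are reachable:

```python
# Launch sketch: pick the development config, then start the server.
# MODE is read by config.get_config(); "development" is its default.
import os
os.environ["MODE"] = "development"

import uvicorn
from app import app  # importing app/__init__.py runs the DB/router/CORS setup above

if __name__ == "__main__":
    # app/main.py binds to 0.0.0.0:7878; the same port is reused here.
    uvicorn.run(app, host="0.0.0.0", port=7878)
```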
+#-----------------------------------------------------------
\ No newline at end of file
diff --git a/app/controllers/__init__.py b/app/controllers/__init__.py
new file mode 100644
index 0000000..b6e690f
--- /dev/null
+++ b/app/controllers/__init__.py
@@ -0,0 +1 @@
+from . import *
diff --git a/app/controllers/character.py b/app/controllers/character.py
new file mode 100644
index 0000000..17de36c
--- /dev/null
+++ b/app/controllers/character.py
@@ -0,0 +1,89 @@
+from ..schemas.character import *
+from ..dependencies.logger import get_logger
+from sqlalchemy.orm import Session
+from ..models import Character
+from datetime import datetime
+from fastapi import HTTPException, status
+
+# Get the logger via dependency injection
+logger = get_logger()
+
+
+# Create a new character
+async def create_character_handler(character: CharacterCreateRequest, db: Session):
+    new_character = Character(voice_id=character.voice_id, avatar_id=character.avatar_id, background_ids=character.background_ids, name=character.name,
+                              wakeup_words=character.wakeup_words, world_scenario=character.world_scenario, description=character.description, emojis=character.emojis, dialogues=character.dialogues)
+    try:
+        db.add(new_character)
+        db.commit()
+        db.refresh(new_character)
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    character_create_data = CharacterCreateData(character_id=new_character.id, createdAt=datetime.now().isoformat())
+    return CharacterCreateResponse(status="success", message="创建角色成功", data=character_create_data)
+
+
+# Update a character
+async def update_character_handler(character_id: int, character: CharacterUpdateRequest, db: Session):
+    existing_character = db.query(Character).filter(Character.id == character_id).first()
+    if not existing_character:
+        # Raise (not return) the HTTPException so FastAPI turns it into a 404 response
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
+    existing_character.voice_id = character.voice_id
+    existing_character.avatar_id = character.avatar_id
+    existing_character.background_ids = character.background_ids
+    existing_character.name = character.name
+    existing_character.wakeup_words = character.wakeup_words
+    existing_character.world_scenario = character.world_scenario
+    existing_character.description = character.description
+    existing_character.emojis = character.emojis
+    existing_character.dialogues = character.dialogues
+
+    try:
+        db.add(existing_character)
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    character_update_data = CharacterUpdateData(updatedAt=datetime.now().isoformat())
+    return CharacterUpdateResponse(status="success", message="更新角色成功", data=character_update_data)
+
+
+# Query a character
+async def get_character_handler(character_id: int, db: Session):
+    try:
+        existing_character = db.query(Character).filter(Character.id == character_id).first()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    if not existing_character:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在")
+    character_query_data = CharacterQueryData(voice_id=existing_character.voice_id, avatar_id=existing_character.avatar_id, background_ids=existing_character.background_ids, name=existing_character.name,
+                                              wakeup_words=existing_character.wakeup_words, world_scenario=existing_character.world_scenario, description=existing_character.description, emojis=existing_character.emojis, dialogues=existing_character.dialogues)
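As an aside before the handler's final return below: these handlers back the `/characters` CRUD routes in `app/routes/character.py`. A minimal usage sketch against a locally running instance; the host/port are assumptions taken from `app/main.py`, and all field values are placeholders shaped after `CharacterCreateRequest`:

```python
# Usage sketch for the character endpoints (assumed local server).
import requests

BASE = "http://127.0.0.1:7878"  # host/port assumed from app/main.py

payload = {
    "voice_id": 0,
    "avatar_id": "avatar-001",
    "background_ids": "1,2",
    "name": "测试角色",
    "wakeup_words": "你好呀",
    "world_scenario": "...",
    "description": "...",
    "emojis": "[]",
    "dialogues": "[]",
}
resp = requests.post(f"{BASE}/characters", json=payload)
print(resp.json())  # {"status": "success", "message": "创建角色成功", "data": {...}}

character_id = resp.json()["data"]["character_id"]
print(requests.get(f"{BASE}/characters/{character_id}").json())
```
+    return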
CharacterQueryResponse(status="success",message="查询角色成功",data=character_query_data) + + +#删除角色 +async def delete_character_handler(character_id: int, db: Session): + + try: + existing_character = db.query(Character).filter(Character.id == character_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + if not existing_character: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="角色不存在") + try: + db.delete(existing_character) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + character_delete_data = CharacterDeleteData(id=character_id,deletedAt=datetime.now().isoformat()) + return CharacterDeleteResponse(status="success",message="删除角色成功",data=character_delete_data) \ No newline at end of file diff --git a/app/controllers/chat.py b/app/controllers/chat.py new file mode 100644 index 0000000..4aa2602 --- /dev/null +++ b/app/controllers/chat.py @@ -0,0 +1,613 @@ +from ..schemas.chat import * +from ..dependencies.logger import get_logger +from .controller_enum import * +from ..models import UserCharacter, Session, Character, User +from utils.audio_utils import VAD +from fastapi import WebSocket, HTTPException, status +from datetime import datetime +from utils.xf_asr_utils import generate_xf_asr_url +from config import get_config +import uuid +import json +import requests +import asyncio + +# 依赖注入获取logger +logger = get_logger() + +# --------------------初始化本地ASR----------------------- +from utils.stt.funasr_utils import FunAutoSpeechRecognizer + +asr = FunAutoSpeechRecognizer() +logger.info("本地ASR初始化成功") +# ------------------------------------------------------- + +# --------------------初始化本地VITS---------------------- +from utils.tts.vits_utils import TextToSpeech + +tts = TextToSpeech(device='cpu') +logger.info("本地TTS初始化成功") +# ------------------------------------------------------- + + +# 依赖注入获取Config +Config = get_config() + +# ----------------------工具函数------------------------- +#获取session内容 +def get_session_content(session_id,redis,db): + session_content_str = "" + if redis.exists(session_id): + session_content_str = redis.get(session_id) + else: + session_db = db.query(Session).filter(Session.id == session_id).first() + if not session_db: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found") + session_content_str = session_db.content + return json.loads(session_content_str) + +#解析大模型流式返回内容 +def parseChunkDelta(chunk): + decoded_data = chunk.decode('utf-8') + parsed_data = json.loads(decoded_data[6:]) + if 'delta' in parsed_data['choices'][0]: + delta_content = parsed_data['choices'][0]['delta'] + return delta_content['content'] + else: + return "" + +#断句函数 +def split_string_with_punctuation(current_sentence,text,is_first): + result = [] + for char in text: + current_sentence += char + if is_first and char in ',.?!,。?!': + result.append(current_sentence) + current_sentence = '' + is_first = False + elif char in '。?!': + result.append(current_sentence) + current_sentence = '' + return result, current_sentence, is_first +#-------------------------------------------------------- + +# 创建新聊天 +async def create_chat_handler(chat: ChatCreateRequest, db, redis): + # 创建新的UserCharacter记录 + new_chat = UserCharacter(user_id=chat.user_id, character_id=chat.character_id) + + try: + db.add(new_chat) + db.commit() + db.refresh(new_chat) + except Exception as e: + db.rollback() + raise 
HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + # 查询所要创建聊天的角色信息,并创建SystemPrompt + db_character = db.query(Character).filter(Character.id == chat.character_id).first() + db_user = db.query(User).filter(User.id == chat.user_id).first() + system_prompt = f"""我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\n{"角色名称: " + db_character.name}。\n{"角色背景: " + db_character.description}\n{"角色所处环境: " + db_character.world_scenario}\n +{"角色的常用问候语: " + db_character.wakeup_words}。\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\n 与你聊天的对象信息如下:{db_user.persona}""" + + + # 创建新的Session记录 + session_id = str(uuid.uuid4()) + user_id = chat.user_id + messages = json.dumps([{"role": "system", "content": system_prompt}], ensure_ascii=False) + + tts_info = { + "language": 0, + "speaker_id":db_character.voice_id, + "noise_scale": 0.1, + "noise_scale_w":0.668, + "length_scale": 1.2 + } + llm_info = { + "model": "abab5.5-chat", + "temperature": 1, + "top_p": 0.9, + } + + # 将tts和llm信息转化为json字符串 + tts_info_str = json.dumps(tts_info, ensure_ascii=False) + llm_info_str = json.dumps(llm_info, ensure_ascii=False) + user_info_str = db_user.persona + + token = 0 + content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str, + "llm_info": llm_info_str, "token": token} + new_session = Session(id=session_id, user_character_id=new_chat.id, content=json.dumps(content, ensure_ascii=False), + last_activity=datetime.now(), is_permanent=False) + + # 将Session记录存入 + db.add(new_session) + db.commit() + db.refresh(new_session) + redis.set(session_id, json.dumps(content, ensure_ascii=False)) + + chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat()) + return ChatCreateResponse(status="success", message="创建聊天成功", data=chat_create_data) + + +#删除聊天 +async def delete_chat_handler(user_character_id, db, redis): + # 查询该聊天记录 + user_character_record = db.query(UserCharacter).filter(UserCharacter.id == user_character_id).first() + if not user_character_record: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="UserCharacter not found") + session_record = db.query(Session).filter(Session.user_character_id == user_character_id).first() + try: + redis.delete(session_record.id) + except Exception as e: + logger.error(f"删除Redis中Session记录时发生错误: {str(e)}") + try: + db.delete(session_record) + db.delete(user_character_record) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + chat_delete_data = ChatDeleteData(deletedAt=datetime.now().isoformat()) + return ChatDeleteResponse(status="success", message="删除聊天成功", data=chat_delete_data) + + +# 非流式聊天 +async def non_streaming_chat_handler(chat: ChatNonStreamRequest, db, redis): + pass + + + + +#---------------------------------------单次流式聊天接口--------------------------------------------- +#处理用户输入 +async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event): + logger.debug("用户输入处理函数启动") + is_future_done = False + try: + while not user_input_finish_event.is_set(): + sct_data_json = json.loads(await ws.receive_text()) + if not is_future_done: + future_session_id.set_result(sct_data_json['meta_info']['session_id']) + if sct_data_json['meta_info']['voice_synthesize']: + future_response_type.set_result(RESPONSE_AUDIO) + else: + future_response_type.set_result(RESPONSE_TEXT) + is_future_done = True + if 
sct_data_json['text']: + await llm_input_q.put(sct_data_json['text']) + if not user_input_finish_event.is_set(): + user_input_finish_event.set() + break + if sct_data_json['meta_info']['is_end']: + await user_input_q.put(sct_data_json['audio']) + if not user_input_finish_event.is_set(): + user_input_finish_event.set() + break + await user_input_q.put(sct_data_json['audio']) + except KeyError as ke: + if sct_data_json['state'] == 1 and sct_data_json['method'] == 'heartbeat': + logger.debug("收到心跳包") + +#语音识别 +async def sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event): + logger.debug("语音识别函数启动") + current_message = "" + while not (user_input_finish_event.is_set() and user_input_q.empty()): + audio_data = await user_input_q.get() + asr_result = asr.streaming_recognize(audio_data) + current_message += ''.join(asr_result['text']) + asr_result = asr.streaming_recognize(b'',is_end=True) + current_message += ''.join(asr_result['text']) + await llm_input_q.put(current_message) + logger.debug(f"接收到用户消息: {current_message}") + +#大模型调用 +async def sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event): + logger.debug("llm调用函数启动") + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + current_message = await llm_input_q.get() + messages.append({'role': 'user', "content": current_message}) + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens": 10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) + if response.status_code == 200: + for chunk in response.iter_lines(): + if chunk: + chunk_data = parseChunkDelta(chunk) + await llm_response_q.put(chunk_data) + llm_response_finish_event.set() + +#大模型返回断句 +async def sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event): + logger.debug("llm返回处理函数启动") + llm_response = "" + current_sentence = "" + is_first = True + while not (llm_response_finish_event.is_set() and llm_response_q.empty()): + llm_chunk = await llm_response_q.get() + llm_response += llm_chunk + sentences, current_sentence, is_first = split_string_with_punctuation(current_sentence, llm_chunk, is_first) + for sentence in sentences: + await split_result_q.put(sentence) + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + messages.append({'role': 'assistant', "content": llm_response}) + session_content["messages"] = json.dumps(messages,ensure_ascii=False) #更新对话 + redis.set(session_id,json.dumps(session_content,ensure_ascii=False)) #更新session + logger.debug(f"llm返回结果: {llm_response}") + +#文本返回及语音合成 +async def sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event): + logger.debug("返回处理函数启动") + while not (llm_response_finish_event.is_set() and split_result_q.empty()): + sentence = await split_result_q.get() + if response_type == RESPONSE_TEXT: + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + elif response_type == RESPONSE_AUDIO: + sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], 
tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True) + response_message = {"type": "text", "code":200, "msg": sentence} + await ws.send_bytes(audio) + await ws.send_text(json.dumps(response_message, ensure_ascii=False)) + logger.debug(f"websocket返回: {sentence}") + chat_finish_event.set() + +async def streaming_chat_temporary_handler(ws: WebSocket, db, redis): + logger.debug("streaming chat temporary websocket 连接建立") + user_input_q = asyncio.Queue() # 用于存储用户输入 + llm_input_q = asyncio.Queue() # 用于存储llm输入 + llm_response_q = asyncio.Queue() # 用于存储llm输出 + split_result_q = asyncio.Queue() # 用于存储tts输出 + + user_input_finish_event = asyncio.Event() + llm_response_finish_event = asyncio.Event() + chat_finish_event = asyncio.Event() + future_session_id = asyncio.Future() + future_response_type = asyncio.Future() + asyncio.create_task(sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,future_response_type,user_input_finish_event)) + asyncio.create_task(sct_asr_handler(user_input_q,llm_input_q,user_input_finish_event)) + + + session_id = await future_session_id #获取session_id + response_type = await future_response_type #获取返回类型 + tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"]) + llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"]) + + asyncio.create_task(sct_llm_handler(session_id,llm_info,db,redis,llm_input_q,llm_response_q,llm_response_finish_event)) + asyncio.create_task(sct_llm_response_handler(session_id,redis,db,llm_response_q,split_result_q,llm_response_finish_event)) + asyncio.create_task(sct_response_handler(ws,tts_info,response_type,split_result_q,llm_response_finish_event,chat_finish_event)) + + while not chat_finish_event.is_set(): + await asyncio.sleep(1) + await ws.send_text(json.dumps({"type": "close", "code": 200, "msg": ""}, ensure_ascii=False)) + await ws.close() + logger.debug("streaming chat temporary websocket 连接断开") +#--------------------------------------------------------------------------------------------------- + + +# 持续流式聊天 +async def streaming_chat_lasting_handler(ws, db, redis): + print("Websocket连接成功") + while True: + try: + print("等待接受") + data = await asyncio.wait_for(ws.receive_text(), timeout=60) + data_json = json.loads(data) + if data_json["is_close"]: + close_message = {"type": "close", "code": 200, "msg": ""} + await ws.send_text(json.dumps(close_message, ensure_ascii=False)) + print("连接关闭") + await asyncio.sleep(0.5) + await ws.close() + return; + except asyncio.TimeoutError: + print("连接超时") + await ws.close() + return; + current_message = "" # 用于存储用户消息 + response_type = RESPONSE_TEXT # 用于获取返回类型 + session_id = "" + if Config.STRAM_CHAT.ASR == "LOCAL": + try: + while True: + if data_json["text"]: # 若文字输入不为空,则表示该输入为文字输入 + if data_json["meta_info"]["voice_synthesize"]: + response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型 + session_id = data_json["meta_info"]["session_id"] + current_message = data_json['text'] + break + + if not data_json['meta_info']['is_end']: # 还在发 + asr_result = asr.streaming_recognize(data_json["audio"]) + current_message += ''.join(asr_result['text']) + else: # 发完了 + asr_result = asr.streaming_recognize(data_json["audio"], is_end=True) + session_id = data_json["meta_info"]["session_id"] + current_message += ''.join(asr_result['text']) + if data_json["meta_info"]["voice_synthesize"]: + response_type = RESPONSE_AUDIO # 查看voice_synthesize判断返回类型 + break + data_json = json.loads(await ws.receive_text()) + + except Exception as e: + error_info = f"接收用户消息错误: 
{str(e)}"
+            error_message = {"type": "error", "code": 500, "msg": error_info}
+            logger.error(error_info)
+            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
+            await ws.close()
+            return
+    elif Config.STRAM_CHAT.ASR == "REMOTE":
+        error_info = "远程ASR服务暂未开通"
+        error_message = {"type": "error", "code": 500, "msg": error_info}
+        logger.error(error_info)
+        await ws.send_text(json.dumps(error_message, ensure_ascii=False))
+        await ws.close()
+        return
+
+    print(f"接收到用户消息: {current_message}")
+    # Look up the Session record
+    session_content_str = ""
+    if redis.exists(session_id):
+        session_content_str = redis.get(session_id)
+    else:
+        session_db = db.query(Session).filter(Session.id == session_id).first()
+        if not session_db:
+            error_info = "未找到session记录"  # no exception object is in scope here
+            error_message = {"type": "error", "code": 500, "msg": error_info}
+            logger.error(error_info)
+            await ws.send_text(json.dumps(error_message, ensure_ascii=False))
+            await ws.close()
+            return
+        session_content_str = session_db.content
+
+    session_content = json.loads(session_content_str)
+    llm_info = json.loads(session_content["llm_info"])
+    tts_info = json.loads(session_content["tts_info"])
+    user_info = json.loads(session_content["user_info"])
+    messages = json.loads(session_content["messages"])
+    messages.append({'role': 'user', "content": current_message})
+    token_count = session_content["token"]
+
+    try:
+        payload = json.dumps({
+            "model": llm_info["model"],
+            "stream": True,
+            "messages": messages,
+            "tool_choice": "auto",
+            "max_tokens": 10000,
+            "temperature": llm_info["temperature"],
+            "top_p": llm_info["top_p"]
+        })
+        headers = {
+            'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
+            'Content-Type': 'application/json'
+        }
+        response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True)
+    except Exception as e:
+        error_info = f"发送信息给大模型时发生错误: {str(e)}"
+        error_message = {"type": "error", "code": 500, "msg": error_info}
+        logger.error(error_info)
+        await ws.send_text(json.dumps(error_message, ensure_ascii=False))
+        await ws.close()
+        return
+
+    # Local variant of the splitter: unlike the module-level helper, it also
+    # flushes a trailing fragment that does not end with punctuation.
+    def split_string_with_punctuation(text):
+        punctuations = "!?。"
+        result = []
+        current_sentence = ""
+        for char in text:
+            current_sentence += char
+            if char in punctuations:
+                result.append(current_sentence)
+                current_sentence = ""
+        # If the last fragment does not end with punctuation, keep it as well
+        if current_sentence and current_sentence[-1] not in punctuations:
+            result.append(current_sentence)
+        return result
+
+    llm_response = ""
+    response_buf = ""
+
+    try:
+        if Config.STRAM_CHAT.TTS == "LOCAL":
+            if response.status_code == 200:
+                for chunk in response.iter_lines():
+                    if chunk:
+                        if response_type == RESPONSE_AUDIO:
+                            chunk_data = parseChunkDelta(chunk)
+                            llm_response += chunk_data
+                            response_buf += chunk_data
+                            split_buf = split_string_with_punctuation(response_buf)
+                            response_buf = ""
+                            if len(split_buf) != 0:
+                                for sentence in split_buf:
+                                    sr, audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
+                                    text_response = {"type": "text", "code": 200, "msg": sentence}
+                                    await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text frame
+                                    await ws.send_bytes(audio)  # send the synthesized audio as binary
+                        if response_type == RESPONSE_TEXT:
+                            chunk_data = parseChunkDelta(chunk)
+                            llm_response += chunk_data
+                            text_response = {"type": "text", "code": 200, "msg": chunk_data}
+                            await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text frame
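For reference, this is what `parseChunkDelta` (defined near the top of this file) extracts from a single SSE line of the LLM stream: it strips the 6-byte `data: ` prefix, parses the JSON, and returns the delta text. The sample chunk below is illustrative, not captured from the real Minimax API:

```python
import json

# Illustrative SSE chunk in the shape parseChunkDelta expects.
chunk = b'data: {"choices": [{"delta": {"content": "\u4f60\u597d\uff01"}}]}'

decoded = chunk.decode("utf-8")
parsed = json.loads(decoded[6:])          # drop the leading "data: "
print(parsed["choices"][0]["delta"]["content"])  # -> 你好!
# Chunks whose first choice carries no "delta" key yield "" in parseChunkDelta.
```
+
+        elif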
Config.STRAM_CHAT.TTS == "REMOTE": + error_info = f"暂不支持远程音频合成" + error_message = {"type": "error", "code": 500, "msg": error_info} + logger.error(error_info) + await ws.send_text(json.dumps(error_message, ensure_ascii=False)) + await ws.close() + return + end_response = {"type": "end", "code": 200, "msg": ""} + await ws.send_text(json.dumps(end_response, ensure_ascii=False)) # 单次返回结束 + print(f"llm消息: {llm_response}") + except Exception as e: + error_info = f"音频合成与向前端返回时错误: {str(e)}" + error_message = {"type": "error", "code": 500, "msg": error_info} + logger.error(error_info) + await ws.send_text(json.dumps(error_message, ensure_ascii=False)) + await ws.close() + return + + try: + messages.append({'role': 'assistant', "content": llm_response}) + token_count += len(llm_response) + session_content["messages"] = json.dumps(messages, ensure_ascii=False) + session_content["token"] = token_count + redis.set(session_id, json.dumps(session_content, ensure_ascii=False)) + except Exception as e: + error_info = f"更新session时错误: {str(e)}" + error_message = {"type": "error", "code": 500, "msg": error_info} + logger.error(error_info) + await ws.send_text(json.dumps(error_message, ensure_ascii=False)) + await ws.close() + return + print("处理完毕") + + +#--------------------------------语音通话接口-------------------------------------- +#音频数据生产函数 +async def voice_call_audio_producer(ws,audio_queue,future,shutdown_event): + logger.debug("音频数据生产函数启动") + is_future_done = False + while not shutdown_event.is_set(): + voice_call_data_json = json.loads(await ws.receive_text()) + if not is_future_done: #在第一次循环中读取session_id + future.set_result(voice_call_data_json['meta_info']['session_id']) + is_future_done = True + if voice_call_data_json["is_close"]: + shutdown_event.set() + break + else: + await audio_queue.put(voice_call_data_json["audio"]) #将音频数据存入audio_q + +#音频数据消费函数 +async def voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event): + logger.debug("音频数据消费者函数启动") + vad = VAD() + current_message = "" + vad_count = 0 + while not (shutdown_event.is_set() and audio_q.empty()): + audio_data = await audio_q.get() + if vad.is_speech(audio_data): + if vad_count > 0: + vad_count -= 1 + asr_result = asr.streaming_recognize(audio_data) + current_message += ''.join(asr_result['text']) + else: + vad_count += 1 + if vad_count >= 25: #连续25帧没有语音,则认为说完了 + asr_result = asr.streaming_recognize(audio_data, is_end=True) + if current_message: + logger.debug(f"检测到静默,用户输入为:{current_message}") + await asr_result_q.put(current_message) + current_message = "" + vad_count = 0 + +#asr结果消费以及llm返回生产函数 +async def voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event): + logger.debug("asr结果消费以及llm返回生产函数启动") + while not (shutdown_event.is_set() and asr_result_q.empty()): + session_content = get_session_content(session_id,redis,db) + messages = json.loads(session_content["messages"]) + current_message = await asr_result_q.get() + messages.append({'role': 'user', "content": current_message}) + payload = json.dumps({ + "model": llm_info["model"], + "stream": True, + "messages": messages, + "max_tokens":10000, + "temperature": llm_info["temperature"], + "top_p": llm_info["top_p"] + }) + + headers = { + 'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}", + 'Content-Type': 'application/json' + } + response = requests.request("POST", Config.MINIMAX_LLM.URL, headers=headers, data=payload, stream=True) + if response.status_code == 200: + for chunk in response.iter_lines(): + if chunk: + chunk_data 
=parseChunkDelta(chunk)
+                    llm_frame = {'message':chunk_data,'is_end':False}
+                    await llm_response_q.put(llm_frame)
+        llm_frame = {'message':"",'is_end':True}
+        await llm_response_q.put(llm_frame)
+
+# Consume LLM responses and split them into sentences
+async def voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event):
+    logger.debug("llm结果返回函数启动")
+    llm_response = ""
+    current_sentence = ""
+    is_first = True
+    while not (shutdown_event.is_set() and llm_response_q.empty()):
+        llm_frame = await llm_response_q.get()
+        llm_response += llm_frame['message']
+        sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,llm_frame['message'],is_first)
+        for sentence in sentences:
+            await split_result_q.put(sentence)
+        if llm_frame['is_end']:
+            is_first = True
+            session_content = get_session_content(session_id,redis,db)
+            messages = json.loads(session_content["messages"])
+            messages.append({'role': 'assistant', "content": llm_response})
+            session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # update the conversation
+            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
+            logger.debug(f"llm返回结果: {llm_response}")
+            llm_response = ""
+            current_sentence = ""
+
+# Synthesize speech and send it back
+async def voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event):
+    logger.debug("语音合成及返回函数启动")
+    while not (shutdown_event.is_set() and split_result_q.empty()):
+        sentence = await split_result_q.get()
+        sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
+        text_response = {"type": "text", "code": 200, "msg": sentence}
+        await ws.send_bytes(audio)  # send the synthesized audio as binary
+        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # send the text frame
+        logger.debug(f"websocket返回:{sentence}")
+    await asyncio.sleep(0.5)  # a bare asyncio.sleep(0.5) is a no-op; it must be awaited
+    await ws.close()
+
+
+async def voice_call_handler(ws, db, redis):
+    logger.debug("voice_call websocket 连接建立")
+    audio_q = asyncio.Queue()
+    asr_result_q = asyncio.Queue()
+    llm_response_q = asyncio.Queue()
+    split_result_q = asyncio.Queue()
+
+    shutdown_event = asyncio.Event()
+    future = asyncio.Future()
+    asyncio.create_task(voice_call_audio_producer(ws,audio_q,future,shutdown_event))  # start the audio producer
+    asyncio.create_task(voice_call_audio_consumer(audio_q,asr_result_q,shutdown_event))  # start the audio consumer
+
+    # Fetch the session content
+    session_id = await future  # get the session_id
+    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
+    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
+
+    asyncio.create_task(voice_call_llm_handler(session_id,llm_info,redis,db,asr_result_q,llm_response_q,shutdown_event))  # start the llm handler
+    asyncio.create_task(voice_call_llm_response_consumer(session_id,redis,db,llm_response_q,split_result_q,shutdown_event))  # start the sentence splitter
+    asyncio.create_task(voice_call_tts_handler(ws,tts_info,split_result_q,shutdown_event))  # start the tts sender
+
+    while not shutdown_event.is_set():
+        await asyncio.sleep(5)
+    await ws.close()
+    logger.debug("voice_call websocket 连接断开")
+#------------------------------------------------------------------------------------------
\ No newline at end of file
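With the chat controllers complete, here is a minimal client sketch for the persistent streaming-chat endpoint registered in `app/routes/chat.py`. Assumptions: the server from `app/main.py` on localhost:7878, the `websockets` package installed, and an existing `session_id` (placeholder below, normally obtained from `POST /chats`). The frame layout mirrors what `streaming_chat_lasting_handler` reads and writes:

```python
# Client sketch for ws://.../chat/streaming/lasting (assumed local server).
import asyncio
import json
import websockets

async def main():
    uri = "ws://127.0.0.1:7878/chat/streaming/lasting"
    async with websockets.connect(uri) as ws:
        # A pure-text turn: "text" is non-empty, so no audio frames are sent.
        await ws.send(json.dumps({
            "text": "你好",
            "audio": "",
            "meta_info": {
                "session_id": "00000000-0000-0000-0000-000000000000",  # placeholder
                "voice_synthesize": False,
                "is_end": True,
            },
            "is_close": False,
        }, ensure_ascii=False))
        # Read frames until the server signals the end of this turn.
        while True:
            frame = await ws.recv()
            if isinstance(frame, bytes):
                continue  # synthesized audio chunk (only when voice_synthesize=True)
            msg = json.loads(frame)
            if msg["type"] in ("end", "error"):
                break
            print(msg["msg"], end="")
        # Ask the server to close the connection, then wait for its close frame.
        await ws.send(json.dumps({"is_close": True, "text": "", "audio": "", "meta_info": {}}))
        await ws.recv()

asyncio.run(main())
```
diff --git a/app/controllers/controller_enum.py b/app/controllers/controller_enum.py
new file mode 100644
index 0000000..2046bc4
--- /dev/null
+++ b/app/controllers/controller_enum.py
@@ -0,0 +1,13 @@
+# Chat type: text, audio, undetermined
+CHAT_TEXT = 0
+CHAT_AUDIO = 1
+CHAT_UNCERTAIN = -1
+
+# Streaming frame type: first frame, middle frame, last frame
+FIRST_FRAME = 1
+CONTINUE_FRAME = 2
+LAST_FRAME = 3
+
+# Response type: text, audio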
+RESPONSE_TEXT = 0
+RESPONSE_AUDIO = 1
\ No newline at end of file
diff --git a/app/controllers/session.py b/app/controllers/session.py
new file mode 100644
index 0000000..fd402eb
--- /dev/null
+++ b/app/controllers/session.py
@@ -0,0 +1,79 @@
+from ..schemas.session import *
+from ..dependencies.logger import get_logger
+from fastapi import HTTPException, status
+from ..models import Session
+from ..models import UserCharacter
+from datetime import datetime
+import json
+
+# Get the logger via dependency injection
+logger = get_logger()
+
+
+# Get the session id for a user/character pair
+async def get_session_id_handler(user_id: int, character_id:int, db):
+    try:
+        user_character_record = db.query(UserCharacter).filter(UserCharacter.user_id == user_id, UserCharacter.character_id == character_id).first()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    if not user_character_record:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User Character not found")
+    try:
+        session_id = db.query(Session).filter(Session.user_character_id==user_character_record.id).first().id
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    session_id_query_data = SessionIdQueryData(session_id=session_id)
+    return SessionIdQueryResponse(status="success",message="Session ID 获取成功",data=session_id_query_data)
+
+# Query Session info
+async def get_session_handler(session_id: str, db, redis):  # session ids are CHAR(36) UUID strings, not ints
+    session_str = ""
+    if redis.exists(session_id):
+        session_str = redis.get(session_id)
+    else:
+        try:
+            session_str = db.query(Session).filter(Session.id == session_id).first().content
+        except Exception as e:
+            db.rollback()
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    if not session_str:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
+    redis.set(session_id, session_str)  # cache the copy only once it is known to exist
+    session = json.loads(session_str)
+    session_query_data = SessionQueryData(user_id=session["user_id"], messages=session["messages"],user_info=session["user_info"],tts_info=session["tts_info"],llm_info=session["llm_info"],token=session["token"])
+    return SessionQueryResponse(status="success",message="Session 查询成功",data=session_query_data)
+
+# Update Session info
+async def update_session_handler(session_id, session_data:SessionUpdateRequest, db, redis):
+    existing_session = ""
+    if redis.exists(session_id):
+        existing_session = redis.get(session_id)
+    else:
+        existing_session = db.query(Session).filter(Session.id == session_id).first().content
+    if not existing_session:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")
+
+    # Update the Session fields
+    session = json.loads(existing_session)
+    session["user_id"] = session_data.user_id
+    session["messages"] = session_data.messages
+    session["user_info"] = session_data.user_info
+    session["tts_info"] = session_data.tts_info
+    session["llm_info"] = session_data.llm_info
+    session["token"] = session_data.token
+
+    # Persist the Session to both stores
+    session_str = json.dumps(session,ensure_ascii=False)
+    redis.set(session_id, session_str)
+    try:
+        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
+        db.commit()
+    except Exception as e:
+        db.rollback()
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+    session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
+    return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)
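Session state lives in Redis with MySQL as the persistent fallback: reads go through the cache, updates are written to both stores. A usage sketch for the session routes in `app/routes/session.py`, with host/port assumed from `app/main.py` and ids as placeholders:

```python
# Usage sketch for the session endpoints (assumed local server).
import requests

BASE = "http://127.0.0.1:7878"

# Resolve the session id for a (user, character) pair...
r = requests.get(f"{BASE}/sessions", params={"user_id": 1, "character_id": 1})
session_id = r.json()["data"]["session_id"]

# ...then fetch the cached conversation state.
r = requests.get(f"{BASE}/sessions/{session_id}")
data = r.json()["data"]
print(data["llm_info"], data["token"])  # llm_info is itself a JSON string
```
diff --git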
a/app/controllers/user.py b/app/controllers/user.py new file mode 100644 index 0000000..155cd27 --- /dev/null +++ b/app/controllers/user.py @@ -0,0 +1,159 @@ +from ..schemas.user import * +from ..dependencies.logger import get_logger +from ..models import User, Hardware +from datetime import datetime +from sqlalchemy.orm import Session +from fastapi import HTTPException, status + + +#依赖注入获取logger +logger = get_logger() + + +#创建用户 +async def create_user_handler(user:UserCrateRequest, db: Session): + new_user = User(created_at=datetime.now(), open_id=user.open_id, username=user.username, avatar_id=user.avatar_id, tags=user.tags, persona=user.persona) + try: + db.add(new_user) + db.commit() + db.refresh(new_user) + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + user_create_data = UserCrateData(user_id=new_user.id, createdAt=new_user.created_at.isoformat()) + return UserCrateResponse(status="success", message="创建用户成功", data=user_create_data) + + +#更新用户信息 +async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session): + existing_user = db.query(User).filter(User.id == user_id).first() + if existing_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在") + existing_user.open_id = user.open_id + existing_user.username = user.username + existing_user.avatar_id = user.avatar_id + existing_user.tags = user.tags + existing_user.persona = user.persona + try: + db.add(existing_user) + db.commit() + db.refresh(existing_user) + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + user_update_data = UserUpdateData(updatedAt=datetime.now().isoformat()) + return UserUpdateResponse(status="success", message="更新用户信息成功", data=user_update_data) + + +#查询用户信息 +async def get_user_handler(user_id:int, db: Session): + try: + existing_user = db.query(User).filter(User.id == user_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if existing_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在") + user_query_data = UserQueryData(open_id=existing_user.open_id, username=existing_user.username, avatar_id=existing_user.avatar_id, tags=existing_user.tags, persona=existing_user.persona) + return UserQueryResponse(status="success", message="查询用户信息成功", data=user_query_data) + + +#删除用户 +async def delete_user_handler(user_id:int, db: Session): + try: + existing_user = db.query(User).filter(User.id == user_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if existing_user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在") + try: + db.delete(existing_user) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + user_delete_data = UserDeleteData(deletedAt=datetime.now().isoformat()) + return UserDeleteResponse(status="success", message="删除用户成功", data=user_delete_data) + + +#绑定硬件 +async def bind_hardware_handler(hardware, db: Session): + new_hardware = Hardware(mac=hardware.mac, user_id=hardware.user_id, firmware=hardware.firmware, model=hardware.model) + try: + db.add(new_hardware) + db.commit() + db.refresh(new_hardware) + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 
detail=str(e)) + hardware_bind_data = HardwareBindData(hardware_id=new_hardware.id, bindedAt=datetime.now().isoformat()) + return HardwareBindResponse(status="success", message="绑定硬件成功", data=hardware_bind_data) + + +#解绑硬件 +async def unbind_hardware_handler(hardware_id:int, db: Session): + try: + existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if existing_hardware is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") + try: + db.delete(existing_hardware) + db.commit() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + hardware_unbind_data = HardwareUnbindData(unbindedAt=datetime.now().isoformat()) + return HardwareUnbindResponse(status="success", message="解绑硬件成功", data=hardware_unbind_data) + + +#硬件换绑 +async def change_bind_hardware_handler(hardware_id, user, db): + try: + existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first() + if existing_hardware is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") + existing_hardware.user_id = user.user_id + db.add(existing_hardware) + db.commit() + db.refresh(existing_hardware) + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + hardware_change_bind_data = HardwareChangeBindData(bindChangedAt=datetime.now().isoformat()) + return HardwareChangeBindResponse(status="success", message="硬件换绑成功", data=hardware_change_bind_data) + + +#硬件信息更新 +async def update_hardware_handler(hardware_id, hardware, db): + try: + existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first() + if existing_hardware is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") + existing_hardware.mac = hardware.mac + existing_hardware.firmware = hardware.firmware + existing_hardware.model = hardware.model + db.add(existing_hardware) + db.commit() + db.refresh(existing_hardware) + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + hardware_update_data = HardwareUpdateData(updatedAt=datetime.now().isoformat()) + return HardwareUpdateResponse(status="success", message="硬件信息更新成功", data=hardware_update_data) + + +#查询硬件 +async def get_hardware_handler(hardware_id, db): + try: + existing_hardware = db.query(Hardware).filter(Hardware.id == hardware_id).first() + except Exception as e: + db.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + if existing_hardware is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在") + hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model) + return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data) \ No newline at end of file diff --git a/app/dependencies/__init__.py b/app/dependencies/__init__.py new file mode 100644 index 0000000..b974282 --- /dev/null +++ b/app/dependencies/__init__.py @@ -0,0 +1 @@ +from . 
import * \ No newline at end of file diff --git a/app/dependencies/database.py b/app/dependencies/database.py new file mode 100644 index 0000000..f857532 --- /dev/null +++ b/app/dependencies/database.py @@ -0,0 +1,15 @@ +from sqlalchemy.orm import sessionmaker +from sqlalchemy import create_engine +from config import get_config + +Config = get_config() + +LocalSession = sessionmaker(autocommit=False, autoflush=False, bind=create_engine(Config.SQLALCHEMY_DATABASE_URI)) + +#返回一个数据库连接 +def get_db(): + db = LocalSession() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/app/dependencies/logger.py b/app/dependencies/logger.py new file mode 100644 index 0000000..44e0b96 --- /dev/null +++ b/app/dependencies/logger.py @@ -0,0 +1,27 @@ +import logging +from config import get_config + +#获取配置信息 +Config = get_config() + +#日志类 +class Logger: + def __init__(self): + self.logger = logging.getLogger(__name__) + self.logger.setLevel(Config.LOG_LEVEL) + self.logger.propagate = False + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + + if not self.logger.handlers: # 检查是否已经有处理器 + # 输出到控制台 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + # 输出到文件 + file_handler = logging.FileHandler('app.log') + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + +def get_logger(): + return Logger().logger \ No newline at end of file diff --git a/app/dependencies/redis.py b/app/dependencies/redis.py new file mode 100644 index 0000000..a1efc9e --- /dev/null +++ b/app/dependencies/redis.py @@ -0,0 +1,8 @@ +import redis +from config import get_config + +#获取配置信息 +Config = get_config() + +def get_redis(): + return redis.Redis.from_url(Config.REDIS_URL) diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..cac1ead --- /dev/null +++ b/app/main.py @@ -0,0 +1,6 @@ +from app import app +import uvicorn + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=7878) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..cf4f59d --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1 @@ +from .models import * \ No newline at end of file diff --git a/app/models/models.py b/app/models/models.py new file mode 100644 index 0000000..baa390a --- /dev/null +++ b/app/models/models.py @@ -0,0 +1,82 @@ +from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +#角色表定义 +class Character(Base): + __tablename__ = 'character' + + id = Column(Integer, primary_key=True, autoincrement=True) + voice_id = Column(Integer, nullable=False) + avatar_id = Column(String(36), nullable=False) + background_ids = Column(String(255), nullable=False) + name = Column(String(36), nullable=False) + wakeup_words = Column(String(255), nullable=False) + world_scenario = Column(Text, nullable=False) + description = Column(Text, nullable=False) + emojis = Column(JSON, nullable=False) + dialogues = Column(Text, nullable=False) + + def __repr__(self): + return f"" + + +#用户表定义 +class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True, autoincrement=True) + created_at = Column(DateTime, nullable=True) + updated_at = Column(DateTime, nullable=True) + deleted_at = Column(DateTime, 
nullable=True) + open_id = Column(String(255), nullable=True) + username = Column(String(64), nullable=True) + avatar_id = Column(String(36), nullable=True) + tags = Column(JSON) + persona = Column(JSON) + + def __repr__(self): + return f"" + + +#硬件表定义 +class Hardware(Base): + __tablename__ = 'hardware' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('user.id')) + mac = Column(String(17)) + firmware = Column(String(16)) + model = Column(String(36)) + + def __repr__(self): + return f"" + + +#用户角色表定义 +class UserCharacter(Base): + __tablename__ = 'user_character' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('user.id')) + character_id = Column(Integer, ForeignKey('character.id')) + persona = Column(JSON) + + def __repr__(self): + return f"" + + +#Session表定义 +class Session(Base): + __tablename__ = 'session' + + id = Column(CHAR(36), primary_key=True) + user_character_id = Column(Integer, ForeignKey('user_character.id')) + content = Column(Text) + last_activity = Column(DateTime()) + is_permanent = Column(Boolean) + + def __repr__(self): + return f"" diff --git a/app/routes/__init__.py b/app/routes/__init__.py new file mode 100644 index 0000000..b974282 --- /dev/null +++ b/app/routes/__init__.py @@ -0,0 +1 @@ +from . import * \ No newline at end of file diff --git a/app/routes/character.py b/app/routes/character.py new file mode 100644 index 0000000..3578e2c --- /dev/null +++ b/app/routes/character.py @@ -0,0 +1,35 @@ +from fastapi import APIRouter, HTTPException, status +from ..controllers.character import * +from fastapi import Depends +from sqlalchemy.orm import Session +from ..dependencies.database import get_db + +router = APIRouter() + + +#角色创建接口 +@router.post("/characters",response_model=CharacterCreateResponse) +async def create_character(character:CharacterCreateRequest, db: Session = Depends(get_db)): + response = await create_character_handler(character, db) + return response + + +#用户更新接口 +@router.put("/characters/{character_id}",response_model=CharacterUpdateResponse) +async def update_character(character_id:int,character:CharacterUpdateRequest, db: Session = Depends(get_db)): + response = await update_character_handler(character_id, character, db) + return response + + +#角色删除接口 +@router.delete("/characters/{character_id}", response_model=CharacterDeleteResponse) +async def delete_character(character_id: int, db: Session = Depends(get_db)): + response = await delete_character_handler(character_id, db) + return response + + +#角色查询接口 +@router.get("/characters/{character_id}",response_model=CharacterQueryResponse) +async def get_character(character_id: int, db: Session = Depends(get_db)): + response = await get_character_handler(character_id, db) + return response diff --git a/app/routes/chat.py b/app/routes/chat.py new file mode 100644 index 0000000..b367b9d --- /dev/null +++ b/app/routes/chat.py @@ -0,0 +1,48 @@ +from fastapi import APIRouter, HTTPException, status, WebSocket +from ..controllers.chat import * +from fastapi import Depends +from ..dependencies.database import get_db +from ..dependencies.redis import get_redis + +router = APIRouter() + + +#创建新聊天接口 +@router.post("/chats", response_model=ChatCreateResponse) +async def create_chat(chat: ChatCreateRequest, db=Depends(get_db), redis=Depends(get_redis)): + response = await create_chat_handler(chat, db, redis) + return response + +#删除聊天接口 +@router.delete("/chats/{user_character_id}", response_model=ChatDeleteResponse) 
+async def delete_chat(user_character_id: int, db=Depends(get_db), redis=Depends(get_redis)):
+    response = await delete_chat_handler(user_character_id, db, redis)
+    return response
+
+# Non-streaming chat (currently disabled)
+@router.post("/chats/non-streaming", response_model=ChatNonStreamResponse)
+async def non_streaming_chat(chat: ChatNonStreamRequest,db=Depends(get_db), redis=Depends(get_redis)):
+    # Disabled for now; the handler call below is kept for when the endpoint is re-enabled
+    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,detail="this api is not available")
+    response = await non_streaming_chat_handler(chat, db, redis)
+    return response
+
+
+# Streaming chat (single turn)
+@router.websocket("/chat/streaming/temporary")
+async def streaming_chat_temporary(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
+    await ws.accept()
+    await streaming_chat_temporary_handler(ws,db,redis)
+
+
+# Streaming chat (persistent)
+@router.websocket("/chat/streaming/lasting")
+async def streaming_chat_lasting(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
+    await ws.accept()
+    await streaming_chat_lasting_handler(ws,db,redis)
+
+
+# Voice call
+@router.websocket("/chat/voice_call")
+async def voice_chat(ws: WebSocket,db=Depends(get_db), redis=Depends(get_redis)):
+    await ws.accept()
+    await voice_call_handler(ws,db,redis)
\ No newline at end of file
diff --git a/app/routes/session.py b/app/routes/session.py
new file mode 100644
index 0000000..30e0ca6
--- /dev/null
+++ b/app/routes/session.py
@@ -0,0 +1,29 @@
+from fastapi import APIRouter, Query, HTTPException, status
+from ..controllers.session import *
+from fastapi import Depends
+from ..dependencies.database import get_db
+from ..dependencies.redis import get_redis
+
+
+router = APIRouter()
+
+
+# session_id query endpoint
+@router.get("/sessions", response_model=SessionIdQueryResponse)
+async def get_session_id(user_id: int=Query(..., description="用户id"), character_id: int=Query(..., description="角色id"),db=Depends(get_db)):
+    response = await get_session_id_handler(user_id, character_id, db)
+    return response
+
+
+# session query endpoint
+@router.get("/sessions/{session_id}", response_model=SessionQueryResponse)
+async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_redis)):
+    response = await get_session_handler(session_id, db, redis)
+    return response
+
+
+# session update endpoint
+@router.put("/sessions/{session_id}", response_model=SessionUpdateResponse)
+async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
+    response = await update_session_handler(session_id, session_data, db, redis)
+    return response
\ No newline at end of file
diff --git a/app/routes/user.py b/app/routes/user.py
new file mode 100644
index 0000000..75ce47e
--- /dev/null
+++ b/app/routes/user.py
@@ -0,0 +1,71 @@
+from fastapi import APIRouter, HTTPException, status
+from ..controllers.user import *
+from fastapi import Depends
+from sqlalchemy.orm import Session
+from ..dependencies.database import get_db
+
+
+router = APIRouter()
+
+
+# User creation endpoint
+@router.post('/users', response_model=UserCrateResponse)
+async def create_user(user: UserCrateRequest,db: Session = Depends(get_db)):
+    response = await create_user_handler(user,db)
+    return response
+
+
+# User update endpoint
+@router.put('/users/{user_id}', response_model=UserUpdateResponse)
+async def update_user(user_id: int, user: UserUpdateRequest, db: Session = Depends(get_db)):
+    response = await update_user_handler(user_id, user, db)
+    return response
+
+
+# User deletion endpoint
+@router.delete('/users/{user_id}', response_model=UserDeleteResponse)
+async def delete_user(user_id: int, db: Session = Depends(get_db)):
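Before the remaining user routes: a usage sketch for the user endpoints, with host/port assumed from `app/main.py` and field values as placeholders shaped after `UserCrateRequest` in `app/schemas/user.py`:

```python
# Usage sketch for the user endpoints (assumed local server).
import requests

BASE = "http://127.0.0.1:7878"

resp = requests.post(f"{BASE}/users", json={
    "open_id": "wx-openid-placeholder",
    "username": "测试用户",
    "avatar_id": "avatar-001",
    "tags": "[]",
    "persona": "{}",
})
user_id = resp.json()["data"]["user_id"]
print(requests.get(f"{BASE}/users/{user_id}").json())
```
+    response = await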
delete_user_handler(user_id, db) + return response + + +#角色查询接口 +@router.get('/users/{user_id}', response_model=UserQueryResponse) +async def get_user(user_id: int, db: Session = Depends(get_db)): + response = await get_user_handler(user_id, db) + return response + + +#硬件绑定接口 +@router.post('/users/hardware',response_model=HardwareBindResponse) +async def bind_hardware(hardware: HardwareBindRequest, db: Session = Depends(get_db)): + response = await bind_hardware_handler(hardware, db) + return response + + +#硬件解绑接口 +@router.delete('/users/hardware/{hardware_id}',response_model=HardwareUnbindResponse) +async def unbind_hardware(hardware_id: int, db: Session = Depends(get_db)): + response = await unbind_hardware_handler(hardware_id, db) + return response + + +#硬件换绑 +@router.put('/users/hardware/{hardware_id}/bindchange',response_model=HardwareChangeBindResponse) +async def change_bind_hardware(hardware_id: int, user: HardwareChangeBindRequest, db: Session = Depends(get_db)): + response = await change_bind_hardware_handler(hardware_id, user, db) + return response + + +#硬件信息更新 +@router.put('/users/hardware/{hardware_id}/info',response_model=HardwareUpdateResponse) +async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest, db: Session = Depends(get_db)): + response = await update_hardware_handler(hardware_id, hardware, db) + return response + + +#硬件查询 +@router.get('/users/hardware/{hardware_id}',response_model=HardwareQueryResponse) +async def get_hardware(hardware_id: int, db: Session = Depends(get_db)): + response = await get_hardware_handler(hardware_id, db) + return response \ No newline at end of file diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..b974282 --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1 @@ +from . 
import * \ No newline at end of file diff --git a/app/schemas/base.py b/app/schemas/base.py new file mode 100644 index 0000000..58ff5d5 --- /dev/null +++ b/app/schemas/base.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + +class BaseResponse(BaseModel): + status: str + message: str + data: dict \ No newline at end of file diff --git a/app/schemas/character.py b/app/schemas/character.py new file mode 100644 index 0000000..d67fc12 --- /dev/null +++ b/app/schemas/character.py @@ -0,0 +1,79 @@ +from pydantic import BaseModel +from typing import Optional +from .base import BaseResponse + +#---------------------------角色创建----------------------------- +#角色创建请求类 +class CharacterCreateRequest(BaseModel): + voice_id:int + avatar_id:str + background_ids:str + name:str + wakeup_words:str + world_scenario:str + description:str + emojis:str + dialogues:str + +#角色创建返回类 +class CharacterCreateData(BaseModel): + character_id:int + createdAt:str + +#角色创建响应类 +class CharacterCreateResponse(BaseResponse): + data: Optional[CharacterCreateData] +#---------------------------------------------------------------- + + +#---------------------------角色更新------------------------------ +#角色更新请求类 +class CharacterUpdateRequest(BaseModel): + voice_id:int + avatar_id:str + background_ids:str + name:str + wakeup_words:str + world_scenario:str + description:str + emojis:str + dialogues:str + +#角色更新返回类 +class CharacterUpdateData(BaseModel): + updatedAt:str + +#角色更新响应类 +class CharacterUpdateResponse(BaseResponse): + data: Optional[CharacterUpdateData] +#------------------------------------------------------------------ + + +#---------------------------角色查询-------------------------------- +#角色查询返回类 +class CharacterQueryData(BaseModel): + voice_id:int + avatar_id:str + background_ids:str + name:str + wakeup_words:str + world_scenario:str + description:str + emojis:str + dialogues:str + +#角色查询响应类 +class CharacterQueryResponse(BaseResponse): + data: Optional[CharacterQueryData] +#------------------------------------------------------------------ + + +#---------------------------角色删除-------------------------------- +#角色删除返回类 +class CharacterDeleteData(BaseModel): + deletedAt:str + +#角色删除响应类 +class CharacterDeleteResponse(BaseResponse): + data: Optional[CharacterDeleteData] +#------------------------------------------------------------------- diff --git a/app/schemas/chat.py b/app/schemas/chat.py new file mode 100644 index 0000000..d1fe253 --- /dev/null +++ b/app/schemas/chat.py @@ -0,0 +1,49 @@ +from pydantic import BaseModel +from typing import Optional +from .base import BaseResponse + +#--------------------------------新聊天创建-------------------------------- +#创建新聊天请求类 +class ChatCreateRequest(BaseModel): + user_id: int + character_id: int + +#创建新聊天返回类 +class ChatCreateData(BaseModel): + user_character_id: int + session_id: str + createdAt: str + +#创建新聊天相应类 +class ChatCreateResponse(BaseResponse): + data: Optional[ChatCreateData] +#-------------------------------------------------------------------------- + + +#----------------------------------聊天删除-------------------------------- +#删除聊天返回类 +class ChatDeleteData(BaseModel): + deletedAt: str + +#创建新聊天相应类 +class ChatDeleteResponse(BaseResponse): + data: Optional[ChatDeleteData] +#-------------------------------------------------------------------------- + + +#-----------------------------------非流式聊天------------------------------ +#非流式聊天请求类 +class ChatNonStreamRequest(BaseModel): + format: str + rate: int + session_id: str + speech:str + +#非流式聊天返回类 +class 
ChatNonStreamData(BaseModel): + audio_url: str + +#非流式聊天相应类 +class ChatNonStreamResponse(BaseResponse): + data: Optional[ChatNonStreamData] +#-------------------------------------------------------------------------- \ No newline at end of file diff --git a/app/schemas/session.py b/app/schemas/session.py new file mode 100644 index 0000000..e725a97 --- /dev/null +++ b/app/schemas/session.py @@ -0,0 +1,58 @@ +from pydantic import BaseModel +from typing import List +from typing import Optional +from .base import BaseResponse + +#----------------------------Session_id查询------------------------------ +#session_id查询请求类 +class SessionIdQueryRequest(BaseModel): + user_id: int + character_id: int + +#session_id查询返回类 +class SessionIdQueryData(BaseModel): + session_id: str + +#session_id查询响应类 +class SessionIdQueryResponse(BaseResponse): + data: Optional[SessionIdQueryData] +#------------------------------------------------------------------------- + + +#----------------------------Session会话查询------------------------------- +#session会话查询请求类 +class SessionQueryRequest(BaseModel): + user_id: int + +class SessionQueryData(BaseModel): + user_id: int + messages: str + user_info: str + tts_info: str + llm_info: str + token: int + +#session会话查询响应类 +class SessionQueryResponse(BaseResponse): + data: Optional[SessionQueryData] +#------------------------------------------------------------------------- + + +#-------------------------------Session修改-------------------------------- +#session修改请求类 +class SessionUpdateRequest(BaseModel): + user_id: int + messages: str + user_info: str + tts_info: str + llm_info: str + token: int + +#session修改返回类 +class SessionUpdateData(BaseModel): + updatedAt:str + +#session修改响应类 +class SessionUpdateResponse(BaseResponse): + data: Optional[SessionUpdateData] +#-------------------------------------------------------------------------- \ No newline at end of file diff --git a/app/schemas/user.py b/app/schemas/user.py new file mode 100644 index 0000000..8caa3db --- /dev/null +++ b/app/schemas/user.py @@ -0,0 +1,140 @@ +from pydantic import BaseModel +from typing import Optional +from .base import BaseResponse + + + +#---------------------------------用户创建---------------------------------- +#用户创建请求类 +class UserCrateRequest(BaseModel): + open_id: str + username: str + avatar_id: str + tags: str + persona: str + +#用户创建返回类 +class UserCrateData(BaseModel): + user_id : int + createdAt: str + +#用户创建响应类 +class UserCrateResponse(BaseResponse): + data: Optional[UserCrateData] +#--------------------------------------------------------------------------- + + + +#---------------------------------用户更新----------------------------------- +#用户更新请求类 +class UserUpdateRequest(BaseModel): + open_id: str + username: str + avatar_id: str + tags: str + persona: str + +#用户更新返回类 +class UserUpdateData(BaseModel): + updatedAt: str + +#用户更新响应类 +class UserUpdateResponse(BaseResponse): + data: Optional[UserUpdateData] +#---------------------------------------------------------------------------- + + + +#---------------------------------用户查询------------------------------------ +#用户查询返回类 +class UserQueryData(BaseModel): + open_id: str + username: str + avatar_id: str + tags: str + persona: str + +#用户查询响应类 +class UserQueryResponse(BaseResponse): + data: Optional[UserQueryData] +#----------------------------------------------------------------------------- + + + +#---------------------------------用户删除------------------------------------ +#用户删除返回类 +class UserDeleteData(BaseModel): + deletedAt: str + +#用户删除响应类 +class 
UserDeleteResponse(BaseResponse): + data: Optional[UserDeleteData] +#----------------------------------------------------------------------------- + + + +#---------------------------------硬件绑定------------------------------------- +class HardwareBindRequest(BaseModel): + mac:str + user_id:int + firmware:str + model:str + +class HardwareBindData(BaseModel): + hardware_id: int + bindedAt: str + +class HardwareBindResponse(BaseResponse): + data: Optional[HardwareBindData] +#----------------------------------------------------------------------------- + + + +#---------------------------------硬件解绑------------------------------------- +class HardwareUnbindData(BaseModel): + unbindedAt: str + +class HardwareUnbindResponse(BaseResponse): + data: Optional[HardwareUnbindData] +#------------------------------------------------------------------------------ + + + +#---------------------------------硬件换绑-------------------------------------- +class HardwareChangeBindRequest(BaseModel): + user_id:int + +class HardwareChangeBindData(BaseModel): + bindChangedAt: str + +class HardwareChangeBindResponse(BaseResponse): + data: Optional[HardwareChangeBindData] +#------------------------------------------------------------------------------ + + + +#-------------------------------硬件信息更新------------------------------------ +class HardwareUpdateRequest(BaseModel): + mac:str + firmware:str + model:str + +class HardwareUpdateData(BaseModel): + updatedAt: str + +class HardwareUpdateResponse(BaseResponse): + data: Optional[HardwareUpdateData] +#------------------------------------------------------------------------------ + + + +#--------------------------------硬件查询--------------------------------------- +class HardwareQueryData(BaseModel): + user_id:int + mac:str + firmware:str + model:str + +class HardwareQueryResponse(BaseResponse): + data: Optional[HardwareQueryData] +#------------------------------------------------------------------------------ \ No newline at end of file diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..d5f7bdc --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,13 @@ +import os +from .development import DevelopmentConfig +from .production import ProductionConfig + +def get_config(): + mode = os.getenv('MODE','development').lower() + if mode == 'development': + return DevelopmentConfig + elif mode == 'production': + return ProductionConfig + else: + raise ValueError('Invalid MODE environment variable') + \ No newline at end of file diff --git a/config/development.py b/config/development.py new file mode 100644 index 0000000..2152925 --- /dev/null +++ b/config/development.py @@ -0,0 +1,23 @@ +class DevelopmentConfig: + SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置 + REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置 + LOG_LEVEL = "DEBUG" #日志级别 + class XF_ASR: + APP_ID = "your_app_id" #讯飞语音识别APP_ID + API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET + API_KEY = "your_api_key" #讯飞语音识别API_KEY + DOMAIN = "iat" + LANGUAGE = "zh_cn" + ACCENT = "mandarin" + VAD_EOS = 10000 + class MINIMAX_LLM: + API_KEY = 
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA" + URL = "https://api.minimax.chat/v1/text/chatcompletion_v2" + class MINIMAX_TTA: + API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA", + URL = "https://api.minimax.chat/v1/t2a_pro", + GROUP_ID ="1759482180095975904" + class STRAM_CHAT: + ASR = "LOCAL" + TTS = "LOCAL" + \ No newline at end of file diff --git a/config/production.py b/config/production.py new file mode 100644 index 0000000..7fb841c --- /dev/null +++ b/config/production.py @@ -0,0 +1,8 @@ +class ProductionConfig: + SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://root:takway@127.0.0.1/takway?charset=utf8mb4" #mysql数据库连接配置 + REDIS_URL = "redis://:takway@127.0.0.1:6379/0" #redis数据库连接配置 + LOG_LEVEL = "INFO" #日志级别 + class XF_ASR: + APP_ID = "your_app_id" #讯飞语音识别APP_ID + API_SECRET = "your_api_secret" #讯飞语音识别API_SECRET + API_KEY = "your_api_key" #讯飞语音识别API_KEY \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..93e804e --- /dev/null +++ b/main.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + + script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py') + + # 使用exec函数执行脚本 + with open(script_path, 'r') as file: + exec(file.read()) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..13d7b4f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +uvicorn~=0.29.0 +fastapi~=0.110.1 +sqlalchemy~=2.0.25 +pydantic~=2.6.4 +redis~=5.0.3 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/assets/example_recording.wav b/tests/assets/example_recording.wav new file mode 100644 index 0000000..2f16668 Binary files /dev/null and b/tests/assets/example_recording.wav differ diff --git a/tests/assets/voice_call.wav b/tests/assets/voice_call.wav new file mode 100644 index 0000000..6315fc5 Binary files /dev/null and b/tests/assets/voice_call.wav differ diff --git a/tests/integration_test/backend_test.py b/tests/integration_test/backend_test.py new file mode 100644 index 0000000..a85f79e --- /dev/null +++ b/tests/integration_test/backend_test.py @@ -0,0 +1,31 @@ +from tests.unit_test.user_test import UserServiceTest +from tests.unit_test.character_test import CharacterServiceTest +from 
tests.unit_test.chat_test import ChatServiceTest +import asyncio + +if __name__ == '__main__': + user_service_test = UserServiceTest() + character_service_test = CharacterServiceTest() + chat_service_test = ChatServiceTest() + + user_service_test.test_user_create() + user_service_test.test_user_update() + user_service_test.test_user_query() + user_service_test.test_hardware_bind() + user_service_test.test_hardware_unbind() + user_service_test.test_user_delete() + + character_service_test.test_character_create() + character_service_test.test_character_update() + character_service_test.test_character_query() + character_service_test.test_character_delete() + + chat_service_test.test_create_chat() + chat_service_test.test_session_id_query() + chat_service_test.test_session_content_query() + chat_service_test.test_session_update() + asyncio.run(chat_service_test.test_chat_temporary()) + asyncio.run(chat_service_test.test_chat_lasting()) + asyncio.run(chat_service_test.test_voice_call()) + chat_service_test.test_chat_delete() + print("全部测试成功") \ No newline at end of file diff --git a/tests/unit_test/character_test.py b/tests/unit_test/character_test.py new file mode 100644 index 0000000..740897e --- /dev/null +++ b/tests/unit_test/character_test.py @@ -0,0 +1,74 @@ +import requests +import json + +class CharacterServiceTest: + def __init__(self,socket="http://114.214.236.207:7878"): + self.socket = socket + + def test_character_create(self): + url = f"{self.socket}/characters" + payload = json.dumps({ + "voice_id": 97, + "avatar_id": "49c838c5ffb211ee9de9a036bc278b4c", + "background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c", + "name": "测试角色", + "wakeup_words": "你好啊,海绵宝宝", + "world_scenario": "海绵宝宝住在深海的大菠萝里面", + "description": "厨师,做汉堡", + "emojis": "大笑,微笑", + "dialogues": "我准备好了" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("角色创建测试成功") + self.id = response.json()['data']['character_id'] + else: + raise Exception("角色创建测试失败") + + def test_character_update(self): + url = f"{self.socket}/characters/"+str(self.id) + payload = json.dumps({ + "voice_id": 97, + "avatar_id": "49c838c5ffb211ee9de9a036bc278b4c", + "background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c", + "name": "测试角色", + "wakeup_words": "你好啊,海绵宝宝", + "world_scenario": "海绵宝宝住在深海的大菠萝里面", + "description": "厨师,做汉堡", + "emojis": "大笑,微笑", + "dialogues": "我准备好了" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("PUT", url, headers=headers, data=payload) + if response.status_code == 200: + print("角色更新测试成功") + else: + raise Exception("角色更新测试失败") + + def test_character_query(self): + url = f"{self.socket}/characters/{self.id}" + response = requests.request("GET", url) + if response.status_code == 200: + print("角色查询测试成功") + else: + raise Exception("角色查询测试失败") + + def test_character_delete(self): + url = f"{self.socket}/characters/{self.id}" + response = requests.request("DELETE", url) + if response.status_code == 200: + print("角色删除测试成功") + else: + raise Exception("角色删除测试失败") + +if __name__ == '__main__': + character_service_test = CharacterServiceTest() + 
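+ # 注意:以下用例存在顺序依赖:test_character_create() 会把新建角色的 id 保存到 self.id,
+ # 后续的更新、查询、删除用例都会读取该 id,因此创建用例必须最先执行。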
character_service_test.test_character_create() + character_service_test.test_character_update() + character_service_test.test_character_query() + character_service_test.test_character_delete() \ No newline at end of file diff --git a/tests/unit_test/chat_test.py b/tests/unit_test/chat_test.py new file mode 100644 index 0000000..3ffa262 --- /dev/null +++ b/tests/unit_test/chat_test.py @@ -0,0 +1,329 @@ +import requests +import base64 +import wave +import json +import os +import uuid +import asyncio +import websockets + + + +class ChatServiceTest: + def __init__(self,socket="http://114.214.236.207:7878"): + self.socket = socket + + + def test_create_chat(self): + #创建一个用户 + url = f"{self.socket}/users" + open_id = str(uuid.uuid4()) + payload = json.dumps({ + "open_id": open_id, + "username": "test_user", + "avatar_id": "0", + "tags" : "[]", + "persona" : "{}" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + self.user_id = response.json()['data']['user_id'] + else: + raise Exception("创建聊天时,用户创建失败") + + #创建一个角色 + url = f"{self.socket}/characters" + payload = json.dumps({ + "voice_id": 97, + "avatar_id": "49c838c5ffb211ee9de9a036bc278b4c", + "background_ids": "185c554affaf11eebd72a036bc278b4c,1b0e2d8bffaf11eebd72a036bc278b4c,20158587ffaf11eebd72a036bc278b4c,2834472affaf11eebd72a036bc278b4c,2c6ddb0affaf11eebd72a036bc278b4c,fd631ec4ffb011ee9b1aa036bc278b4c", + "name": "test", + "wakeup_words": "你好啊,海绵宝宝", + "world_scenario": "海绵宝宝住在深海的大菠萝里面", + "description": "厨师,做汉堡", + "emojis": "大笑,微笑", + "dialogues": "我准备好了" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("角色创建成功") + self.character_id = response.json()['data']['character_id'] + else: + raise Exception("创建聊天时,角色创建失败") + + #创建一个对话 + url = f"{self.socket}/chats" + payload = json.dumps({ + "user_id": self.user_id, + "character_id": self.character_id + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("对话创建成功") + self.session_id = response.json()['data']['session_id'] + self.user_character_id = response.json()['data']['user_character_id'] + else: + raise Exception("对话创建测试失败") + + + #测试查询session_id + def test_session_id_query(self): + url = f"{self.socket}/sessions?user_id={self.user_id}&character_id={self.character_id}" + response = requests.request("GET", url) + if response.status_code == 200: + print("session_id查询测试成功") + else: + raise Exception("session_id查询测试失败") + + + #测试查询session内容 + def test_session_content_query(self): + url = f"{self.socket}/sessions/{self.session_id}" + response = requests.request("GET", url) + if response.status_code == 200: + print("session内容查询测试成功") + else: + raise Exception("session内容查询测试失败") + + + #测试修改session + def test_session_update(self): + url = f"{self.socket}/sessions/{self.session_id}" + payload = json.dumps({ + "user_id": self.user_id, + "messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]", + "user_info": "{}", + "tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}", + "llm_info": 
"{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}", + "token": 0} + ) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("PUT", url, headers=headers, data=payload) + if response.status_code == 200: + print("Session更新测试成功") + else: + raise Exception("Session更新测试失败") + + + #测试单次聊天 + async def test_chat_temporary(self): + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + tests_dir = os.path.dirname(current_dir) + wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav') + def read_wav_file_in_chunks(chunk_size): + with open(wav_file_path, 'rb') as pcm_file: + while True: + data = pcm_file.read(chunk_size) + if not data: + break + yield data + data = { + "text": "", + "audio": "", + "meta_info": { + "session_id":self.session_id, + "stream": True, + "voice_synthesize": True, + "is_end": False, + "encoding": "raw" + } + } + + #发送音频数据 + async def send_audio_chunk(websocket, chunk): + encoded_data = base64.b64encode(chunk).decode('utf-8') + data["audio"] = encoded_data + message = json.dumps(data) + await websocket.send(message) + + + async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket: + chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块 + for chunk in chunks: + await send_audio_chunk(websocket, chunk) + await asyncio.sleep(0.01) + # 设置data字典中的"is_end"键为True,表示音频流结束 + data["meta_info"]["is_end"] = True + # 发送最后一个数据块和流结束信号 + await send_audio_chunk(websocket, b'') # 发送空数据块表示结束 + + audio_bytes = b'' + while True: + data_ws = await websocket.recv() + try: + message_json = json.loads(data_ws) + if message_json["type"] == "close": + print("单次聊天测试成功") + break # 如果没有接收到消息,则退出循环 + except Exception as e: + audio_bytes += data_ws + + await asyncio.sleep(0.04) # 等待0.04秒后断开连接 + await websocket.close() + + + #测试持续聊天 + async def test_chat_lasting(self): + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + tests_dir = os.path.dirname(current_dir) + wav_file_path = os.path.join(tests_dir, 'assets', 'example_recording.wav') + def read_wav_file_in_chunks(chunk_size): + with open(wav_file_path, 'rb') as pcm_file: + while True: + data = pcm_file.read(chunk_size) + if not data: + break + yield data + data = { + "text": "", + "audio": "", + "meta_info": { + "session_id":self.session_id, + "stream": True, + "voice_synthesize": True, + "is_end": False, + "encoding": "raw" + }, + "is_close":False + } + async def send_audio_chunk(websocket, chunk): + encoded_data = base64.b64encode(chunk).decode('utf-8') + data["audio"] = encoded_data + message = json.dumps(data) + await websocket.send(message) + + async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket: + #发送第一次 + chunks = read_wav_file_in_chunks(2048) + for chunk in chunks: + await send_audio_chunk(websocket, chunk) + await asyncio.sleep(0.01) + # 设置data字典中的"is_end"键为True,表示音频流结束 + data["meta_info"]["is_end"] = True + # 发送最后一个数据块和流结束信号 + await send_audio_chunk(websocket, b'') # 发送空数据块表示结束 + + await asyncio.sleep(3) #模拟发送间隔 + + #发送第二次 + data["meta_info"]["is_end"] = False + chunks = read_wav_file_in_chunks(2048) + for chunk in chunks: + await send_audio_chunk(websocket, chunk) + await asyncio.sleep(0.01) + # 设置data字典中的"is_end"键为True,表示音频流结束 + data["meta_info"]["is_end"] = True + # 发送最后一个数据块和流结束信号 + await send_audio_chunk(websocket, b'') # 发送空数据块表示结束 + + data["is_close"] = True + await send_audio_chunk(websocket, b'') # 
发送空数据块表示结束 + + + audio_bytes = b'' + while True: + data_ws = await websocket.recv() + try: + message_json = json.loads(data_ws) + if message_json["type"] == "close": + print("持续聊天测试成功") + break # 如果没有接收到消息,则退出循环 + except Exception as e: + audio_bytes += data_ws + + await asyncio.sleep(0.5) # 等待0.04秒后断开连接 + await websocket.close() + + + #语音电话测试 + async def test_voice_call(self): + chunk_size = 480 + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + tests_dir = os.path.dirname(current_dir) + file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav') + url = f"ws://114.214.236.207:7878/chat/voice_call" + #发送格式 + ws_data = { + "audio" : "", + "meta_info":{ + "session_id":self.session_id, + "encoding": 'raw' + }, + "is_close" : False + } + + async def audio_stream(websocket): + with wave.open(file_path, 'rb') as wf: + frames_per_buffer = int(chunk_size / 2) + data = wf.readframes(frames_per_buffer) + while True: + if len(data) != 960: + break + encoded_data = base64.b64encode(data).decode('utf-8') + ws_data['audio'] = encoded_data + await websocket.send(json.dumps(ws_data)) + data = wf.readframes(frames_per_buffer) + await asyncio.sleep(3) + ws_data['audio'] = "" + ws_data['is_close'] = True + await websocket.send(json.dumps(ws_data)) + while True: + data_ws = await websocket.recv() + if data_ws: + print("语音电话测试成功") + break + await asyncio.sleep(3) + await websocket.close() + + async with websockets.connect(url) as websocket: + await asyncio.gather(audio_stream(websocket)) + + + + #测试删除聊天 + def test_chat_delete(self): + url = f"{self.socket}/chats/{self.user_character_id}" + response = requests.request("DELETE", url) + if response.status_code == 200: + print("聊天删除测试成功") + else: + raise Exception("聊天删除测试失败") + + url = f"{self.socket}/users/{self.user_id}" + response = requests.request("DELETE", url) + if response.status_code != 200: + raise Exception("用户删除测试失败") + + url = f"{self.socket}/characters/{self.character_id}" + response = requests.request("DELETE", url) + if response.status_code != 200: + raise Exception("角色删除测试失败") + + +if __name__ == '__main__': + chat_service_test = ChatServiceTest() + chat_service_test.test_create_chat() + chat_service_test.test_session_id_query() + chat_service_test.test_session_content_query() + chat_service_test.test_session_update() + asyncio.run(chat_service_test.test_chat_temporary()) + asyncio.run(chat_service_test.test_chat_lasting()) + asyncio.run(chat_service_test.test_voice_call()) + chat_service_test.test_chat_delete() + + + diff --git a/tests/unit_test/user_test.py b/tests/unit_test/user_test.py new file mode 100644 index 0000000..2884f92 --- /dev/null +++ b/tests/unit_test/user_test.py @@ -0,0 +1,99 @@ +import requests +import json +import uuid + + +class UserServiceTest: + def __init__(self,socket="http://114.214.236.207:7878"): + self.socket = socket + + def test_user_create(self): + url = f"{self.socket}/users" + open_id = str(uuid.uuid4()) + payload = json.dumps({ + "open_id": open_id, + "username": "test_user", + "avatar_id": "0", + "tags" : "[]", + "persona" : "{}" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("用户创建测试成功") + self.id = response.json()["data"]["user_id"] + else: + raise Exception("用户创建测试失败") + + def test_user_update(self): + url = f"{self.socket}/users/"+str(self.id) + payload = json.dumps({ + "open_id": str(uuid.uuid4()), + "username": "test_user", + 
"avatar_id": "0", + "tags": "[]", + "persona": "{}" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("PUT", url, headers=headers, data=payload) + if response.status_code == 200: + print("用户更新测试成功") + else: + raise Exception("用户更新测试失败") + + def test_user_query(self): + url = f"{self.socket}/users/{self.id}" + response = requests.request("GET", url) + if response.status_code == 200: + print("用户查询测试成功") + else: + raise Exception("用户查询测试失败") + + def test_user_delete(self): + url = f"{self.socket}/users/{self.id}" + response = requests.request("DELETE", url) + if response.status_code == 200: + print("用户删除测试成功") + else: + raise Exception("用户删除测试失败") + + def test_hardware_bind(self): + url = f"{self.socket}/users/hardware" + mac = "08:00:20:0A:8C:6G" + payload = json.dumps({ + "mac":mac, + "user_id":1, + "firmware":"v1.0", + "model":"香橙派" + }) + headers = { + 'Content-Type': 'application/json' + } + response = requests.request("POST", url, headers=headers, data=payload) + if response.status_code == 200: + print("硬件绑定测试成功") + self.hd_id = response.json()["data"]['hardware_id'] + else: + raise Exception("硬件绑定测试失败") + + def test_hardware_unbind(self): + url = f"{self.socket}/users/hardware/{self.hd_id}" + response = requests.request("DELETE", url) + if response.status_code == 200: + print("硬件解绑测试成功") + else: + raise Exception("硬件解绑测试失败") + +if __name__ == '__main__': + user_service_test = UserServiceTest() + user_service_test.test_user_create() + user_service_test.test_user_update() + user_service_test.test_user_query() + user_service_test.test_hardware_bind() + user_service_test.test_hardware_unbind() + user_service_test.test_user_delete() + \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..b974282 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from . 
import * \ No newline at end of file diff --git a/utils/audio_utils.py b/utils/audio_utils.py new file mode 100644 index 0000000..48d0c72 --- /dev/null +++ b/utils/audio_utils.py @@ -0,0 +1,14 @@ +import webrtcvad +import base64 + +class VAD(): + def __init__(self, vad_sensitivity=1, frame_duration=30, vad_buffer_size=7, min_act_time=1, RATE=16000,**kwargs): + self.RATE = RATE + self.vad = webrtcvad.Vad(vad_sensitivity) + self.vad_buffer_size = vad_buffer_size + self.vad_chunk_size = int(self.RATE * frame_duration / 1000) + self.min_act_time = min_act_time # 最小活动时间,单位秒 + + def is_speech(self,data): + byte_data = base64.b64decode(data) + return self.vad.is_speech(byte_data, self.RATE) \ No newline at end of file diff --git a/utils/stt/base_stt.py b/utils/stt/base_stt.py new file mode 100644 index 0000000..1dee2e9 --- /dev/null +++ b/utils/stt/base_stt.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import json +import wave +import io +import base64 + + +def decode_str2bytes(data): + # 将Base64编码的字符串解码为字节串 + if data is None: + return None + return base64.b64decode(data.encode('utf-8')) + +class STTBase: + def __init__(self, RATE=16000, cfg_path=None, debug=False): + self.RATE = RATE + self.debug = debug + self.asr_cfg = self.parse_json(cfg_path) + + def parse_json(self, cfg_path): + cfg = None + self.hotwords = None + if cfg_path is not None: + with open(cfg_path, 'r', encoding='utf-8') as f: + cfg = json.load(f) + self.hotwords = cfg.get('hot_words', None) + return cfg + + def add_hotword(self, hotword): + """add hotword to list""" + if self.hotwords is None: + self.hotwords = "" + if isinstance(hotword, str): + self.hotwords = self.hotwords + " " + hotword + elif isinstance(hotword, (list, tuple)): + # 将hotwords转换为str,并用空格隔开 + self.hotwords = self.hotwords + " " + " ".join(hotword) + else: + raise TypeError("hotword must be str or list") + + def check_audio_type(self, audio_data): + """check audio data type and convert it to bytes if necessary.""" + if isinstance(audio_data, bytes): + pass + elif isinstance(audio_data, list): + audio_data = b''.join(audio_data) + elif isinstance(audio_data, str): + audio_data = decode_str2bytes(audio_data) + elif isinstance(audio_data, io.BytesIO): + wf = wave.open(audio_data, 'rb') + audio_data = wf.readframes(wf.getnframes()) + else: + raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}") + return audio_data + + def text_postprecess(self, result, data_id='text'): + """postprecess recognized result.""" + text = result[data_id] + if isinstance(text, list): + text = ''.join(text) + return text.replace(' ', '') + + def recognize(self, audio_data, queue=None): + """recognize audio data to text""" + pass diff --git a/utils/stt/funasr_utils.py b/utils/stt/funasr_utils.py new file mode 100644 index 0000000..84cc9e9 --- /dev/null +++ b/utils/stt/funasr_utils.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# ####################################################### # +# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR +# ####################################################### # +import io +import numpy as np +import base64 +import wave +from funasr import AutoModel +from .base_stt import STTBase + +def decode_str2bytes(data): + # 将Base64编码的字符串解码为字节串 + if data is None: + return None + return base64.b64decode(data.encode('utf-8')) + +class FunAutoSpeechRecognizer(STTBase): + def __init__(self, + model_path="paraformer-zh-streaming", + device="cuda", + RATE=16000, 
+ cfg_path=None, + debug=False, + chunk_ms=480, + encoder_chunk_look_back=4, + decoder_chunk_look_back=1, + **kwargs): + super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug) + + self.asr_model = AutoModel(model=model_path, device=device, **kwargs) + + self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention + self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention + + #[0, 8, 4] 480ms, [0, 10, 5] 600ms + if chunk_ms == 480: + self.chunk_size = [0, 8, 4] + elif chunk_ms == 600: + self.chunk_size = [0, 10, 5] + else: + raise ValueError("`chunk_ms` should be 480 or 600, and type is int.") + self.chunk_partial_size = self.chunk_size[1] * 960 + self.audio_cache = None + self.asr_cache = {} + + + + self._init_asr() + + def check_audio_type(self, audio_data): + """check audio data type and convert it to bytes if necessary.""" + if isinstance(audio_data, bytes): + pass + elif isinstance(audio_data, list): + audio_data = b''.join(audio_data) + elif isinstance(audio_data, str): + audio_data = decode_str2bytes(audio_data) + elif isinstance(audio_data, io.BytesIO): + wf = wave.open(audio_data, 'rb') + audio_data = wf.readframes(wf.getnframes()) + elif isinstance(audio_data, np.ndarray): + pass + else: + raise TypeError(f"audio_data must be bytes, list, str, \ + io.BytesIO or numpy array, but got {type(audio_data)}") + + if isinstance(audio_data, bytes): + audio_data = np.frombuffer(audio_data, dtype=np.int16) + elif isinstance(audio_data, np.ndarray): + if audio_data.dtype != np.int16: + audio_data = audio_data.astype(np.int16) + else: + raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}") + return audio_data + + def _init_asr(self): + # 随机初始化一段音频数据 + init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16) + self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back) + self.audio_cache = None + self.asr_cache = {} + # print("init ASR model done.") + + def recognize(self, audio_data): + """recognize audio data to text""" + audio_data = self.check_audio_type(audio_data) + result = self.asr_model.generate(input=audio_data, + batch_size_s=300, + hotword=self.hotwords) + + # print(result) + text = '' + for res in result: + text += res['text'] + return text + + def streaming_recognize(self, + audio_data, + is_end=False, + auto_det_end=False): + """recognize partial result + + Args: + audio_data: bytes or numpy array, partial audio data + is_end: bool, whether the audio data is the end of a sentence + auto_det_end: bool, whether to automatically detect the end of a audio data + """ + text_dict = dict(text=[], is_end=is_end) + + audio_data = self.check_audio_type(audio_data) + if self.audio_cache is None: + self.audio_cache = audio_data + else: + # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}") + if self.audio_cache.shape[0] > 0: + self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0) + + if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size: + return text_dict + + total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size) + + if is_end: + # if the audio data is the end of a sentence, \ + # we need to add one more chunk to the end to \ + # ensure the end of the sentence 
is recognized correctly. + auto_det_end = True + + if auto_det_end: + total_chunk_num += 1 + + # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}") + end_idx = None + for i in range(total_chunk_num): + if auto_det_end: + is_end = i == total_chunk_num - 1 + start_idx = i*self.chunk_partial_size + if auto_det_end: + end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1 + else: + end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1 + # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}") + # t_stamp = time.time() + + speech_chunk = self.audio_cache[start_idx:end_idx] + + # TODO: exceptions processes + try: + res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back) + except ValueError as e: + print(f"ValueError: {e}") + continue + text_dict['text'].append(self.text_postprecess(res[0], data_id='text')) + # print(f"each chunk time: {time.time()-t_stamp}") + + if is_end: + self.audio_cache = None + self.asr_cache = {} + else: + if end_idx: + self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache + text_dict['is_end'] = is_end + + # print(f"text_dict: {text_dict}") + return text_dict + + + \ No newline at end of file diff --git a/utils/tts/vits/__init__.py b/utils/tts/vits/__init__.py new file mode 100644 index 0000000..c96b491 --- /dev/null +++ b/utils/tts/vits/__init__.py @@ -0,0 +1,2 @@ +from .text import * +from .monotonic_align import * \ No newline at end of file diff --git a/utils/tts/vits/attentions.py b/utils/tts/vits/attentions.py new file mode 100644 index 0000000..13e1ed4 --- /dev/null +++ b/utils/tts/vits/attentions.py @@ -0,0 +1,302 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +# import commons +# from modules import LayerNorm +from utils.tts.vits import commons +from utils.tts.vits.modules import LayerNorm + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class 
Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + 
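+ # x supplies the queries and c the keys/values (for self-attention c is x itself);
+ # each projection is a 1x1 Conv1d applied to tensors of shape [b, channels, t].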
v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). 
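+ # e.g. with length=3 the [b, h, 3, 5] relative logits are padded to [b, h, 3, 6], flattened,
+ # padded again to (l+1)*(2l-1)=20 elements and re-viewed as [b, h, 4, 5]; the final slice
+ # [:, :, :length, length-1:] then reads out the scores indexed by absolute position.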
+ x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length -1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/utils/tts/vits/commons.py b/utils/tts/vits/commons.py new file mode 100644 index 0000000..40fcc05 --- /dev/null +++ b/utils/tts/vits/commons.py @@ -0,0 +1,172 @@ +import math +import torch +from torch.nn import functional as F +import torch.jit + + +def script_method(fn, _rcb=None): + return fn + + +def script(obj, optimize=True, _frames_up=0, _rcb=None): + return obj + + +torch.jit.script_method = script_method +torch.jit.script = script + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 
2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, 
convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2,3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/utils/tts/vits/models.py b/utils/tts/vits/models.py new file mode 100644 index 0000000..6e71054 --- /dev/null +++ b/utils/tts/vits/models.py @@ -0,0 +1,535 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +# import commons +# import modules +# import attentions +# import monotonic_align +from utils.tts.vits import commons, modules, attentions, monotonic_align +from utils.tts.vits.commons import init_weights, get_padding + +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +# from commons import init_weights, get_padding + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + 
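+ # Variational dequantization: u = sigmoid(z_u) lies in (0, 1) and is subtracted from the
+ # integer durations w, so z0 is continuous; the log-sigmoid terms accumulated below are
+ # log|du/dz_u|, the Jacobian of that sigmoid transform.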
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows 
= nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + 
self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2,3,5,7,11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + **kwargs): + + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = 
resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 
2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): + assert self.n_speakers > 0, "n_speakers has to be larger than 0." + g_src = self.emb_g(sid_src).unsqueeze(-1) + g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + diff --git a/utils/tts/vits/modules.py b/utils/tts/vits/modules.py new file mode 100644 index 0000000..f210891 --- /dev/null +++ b/utils/tts/vits/modules.py @@ -0,0 +1,390 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm + +# import commons +# from commons import init_weights, get_padding
# from transforms import piecewise_rational_quadratic_transform +from utils.tts.vits import commons +from utils.tts.vits.commons import init_weights, get_padding +from utils.tts.vits.transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 1."
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g)
+ + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:,:self.hidden_channels,:] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:,self.hidden_channels:,:] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = 
torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels,1)) + self.logs = nn.Parameter(torch.zeros(channels,1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1,2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
+ + unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins:] + + x1, logabsdet = piecewise_rational_quadratic_transform(x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails='linear', + tail_bound=self.tail_bound + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1,2]) + if not reverse: + return x, logdet + else: + return x diff --git a/utils/tts/vits/monotonic_align/__init__.py b/utils/tts/vits/monotonic_align/__init__.py new file mode 100644 index 0000000..e97eecc --- /dev/null +++ b/utils/tts/vits/monotonic_align/__init__.py @@ -0,0 +1,20 @@ +from numpy import zeros, int32, float32 +from torch import from_numpy + +from .core import maximum_path_jit + + +def maximum_path(neg_cent, mask): + """ numba optimized version. + neg_cent: [b, t_t, t_s] + mask: [b, t_t, t_s] + """ + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(float32) + path = zeros(neg_cent.shape, dtype=int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) + maximum_path_jit(path, neg_cent, t_t_max, t_s_max) + return from_numpy(path).to(device=device, dtype=dtype) diff --git a/utils/tts/vits/monotonic_align/core.py b/utils/tts/vits/monotonic_align/core.py new file mode 100644 index 0000000..1f94060 --- /dev/null +++ b/utils/tts/vits/monotonic_align/core.py @@ -0,0 +1,36 @@ +import numba + + +@numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]), + nopython=True, nogil=True) +def maximum_path_jit(paths, values, t_ys, t_xs): + b = paths.shape[0] + max_neg_val = -1e9 + for i in range(int(b)): + path = paths[i] + value = values[i] + t_y = t_ys[i] + t_x = t_xs[i] + + v_prev = v_cur = 0.0 + index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 diff --git a/utils/tts/vits/text/LICENSE b/utils/tts/vits/text/LICENSE new file mode 100644 index 0000000..4ad4ed1 --- /dev/null +++ b/utils/tts/vits/text/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2017 Keith Ito + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/utils/tts/vits/text/__init__.py b/utils/tts/vits/text/__init__.py new file mode 100644 index 0000000..edc98aa --- /dev/null +++ b/utils/tts/vits/text/__init__.py @@ -0,0 +1,57 @@ +""" from https://github.com/keithito/tacotron """ +from . import cleaners +from .symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + + +def text_to_sequence(text, symbols, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + Args: + text: string to convert to a sequence + symbols: list of symbols making up the vocabulary + cleaner_names: names of the cleaner functions to run the text through + Returns: + Tuple of (list of integer symbol IDs, cleaned text) + ''' + _symbol_to_id = {s: i for i, s in enumerate(symbols)} + sequence = [] + + clean_text = _clean_text(text, cleaner_names) + for symbol in clean_text: + if symbol not in _symbol_to_id.keys(): + continue + symbol_id = _symbol_to_id[symbol] + sequence += [symbol_id] + return sequence, clean_text + + +def cleaned_text_to_sequence(cleaned_text): + '''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text. + Args: + cleaned_text: string to convert to a sequence + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()] + return sequence + + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + s = _id_to_symbol[symbol_id] + result += s + return result + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name, None) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text diff --git a/utils/tts/vits/text/cleaners.py b/utils/tts/vits/text/cleaners.py new file mode 100644 index 0000000..347ca36 --- /dev/null +++ b/utils/tts/vits/text/cleaners.py @@ -0,0 +1,475 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re +from unidecode import unidecode +# import pyopenjtalk # NOTE: japanese_to_romaji_with_accent below requires pyopenjtalk + from jamo import h2j, j2hcj +from pypinyin import lazy_pinyin, BOPOMOFO +import jieba, cn2an + + +# This is a list of Korean classifiers preceded by pure Korean numerals.
+_korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# Regular expression matching Japanese without punctuation marks: +_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + +# Regular expression matching non-Japanese characters or punctuation marks: +_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + +# List of (hangul, hangul divided) pairs: +_hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ + ('ㄳ', 'ㄱㅅ'), + ('ㄵ', 'ㄴㅈ'), + ('ㄶ', 'ㄴㅎ'), + ('ㄺ', 'ㄹㄱ'), + ('ㄻ', 'ㄹㅁ'), + ('ㄼ', 'ㄹㅂ'), + ('ㄽ', 'ㄹㅅ'), + ('ㄾ', 'ㄹㅌ'), + ('ㄿ', 'ㄹㅍ'), + ('ㅀ', 'ㄹㅎ'), + ('ㅄ', 'ㅂㅅ'), + ('ㅘ', 'ㅗㅏ'), + ('ㅙ', 'ㅗㅐ'), + ('ㅚ', 'ㅗㅣ'), + ('ㅝ', 'ㅜㅓ'), + ('ㅞ', 'ㅜㅔ'), + ('ㅟ', 'ㅜㅣ'), + ('ㅢ', 'ㅡㅣ'), + ('ㅑ', 'ㅣㅏ'), + ('ㅒ', 'ㅣㅐ'), + ('ㅕ', 'ㅣㅓ'), + ('ㅖ', 'ㅣㅔ'), + ('ㅛ', 'ㅣㅗ'), + ('ㅠ', 'ㅣㅜ') +]] + +# List of (Latin alphabet, hangul) pairs: +_latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('a', '에이'), + ('b', '비'), + ('c', '시'), + ('d', '디'), + ('e', '이'), + ('f', '에프'), + ('g', '지'), + ('h', '에이치'), + ('i', '아이'), + ('j', '제이'), + ('k', '케이'), + ('l', '엘'), + ('m', '엠'), + ('n', '엔'), + ('o', '오'), + ('p', '피'), + ('q', '큐'), + ('r', '아르'), + ('s', '에스'), + ('t', '티'), + ('u', '유'), + ('v', '브이'), + ('w', '더블유'), + ('x', '엑스'), + ('y', '와이'), + ('z', '제트') +]] + +# List of (Latin alphabet, bopomofo) pairs: +_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('a', 'ㄟˉ'), + ('b', 'ㄅㄧˋ'), + ('c', 'ㄙㄧˉ'), + ('d', 'ㄉㄧˋ'), + ('e', 'ㄧˋ'), + ('f', 'ㄝˊㄈㄨˋ'), + ('g', 'ㄐㄧˋ'), + ('h', 'ㄝˇㄑㄩˋ'), + ('i', 'ㄞˋ'), + ('j', 'ㄐㄟˋ'), + ('k', 'ㄎㄟˋ'), + ('l', 'ㄝˊㄛˋ'), + ('m', 'ㄝˊㄇㄨˋ'), + ('n', 'ㄣˉ'), + ('o', 'ㄡˉ'), + ('p', 'ㄆㄧˉ'), + ('q', 'ㄎㄧㄡˉ'), + ('r', 'ㄚˋ'), + ('s', 'ㄝˊㄙˋ'), + ('t', 'ㄊㄧˋ'), + ('u', 'ㄧㄡˉ'), + ('v', 'ㄨㄧˉ'), + ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), + ('x', 'ㄝˉㄎㄨˋㄙˋ'), + ('y', 'ㄨㄞˋ'), + ('z', 'ㄗㄟˋ') +]] + + +# List of (bopomofo, romaji) pairs: +_bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ + ('ㄅㄛ', 'p⁼wo'), + ('ㄆㄛ', 'pʰwo'), + ('ㄇㄛ', 'mwo'), + ('ㄈㄛ', 'fwo'), + ('ㄅ', 'p⁼'), + ('ㄆ', 'pʰ'), + ('ㄇ', 'm'), + ('ㄈ', 'f'), + ('ㄉ', 't⁼'), + ('ㄊ', 'tʰ'), + ('ㄋ', 'n'), + ('ㄌ', 'l'), + ('ㄍ', 'k⁼'), + ('ㄎ', 'kʰ'), + ('ㄏ', 'h'), + ('ㄐ', 'ʧ⁼'), + ('ㄑ', 'ʧʰ'), + ('ㄒ', 'ʃ'), + ('ㄓ', 'ʦ`⁼'), + ('ㄔ', 'ʦ`ʰ'), + ('ㄕ', 's`'), + ('ㄖ', 'ɹ`'), + ('ㄗ', 'ʦ⁼'), + ('ㄘ', 'ʦʰ'), + ('ㄙ', 's'), + ('ㄚ', 'a'), + ('ㄛ', 'o'), + ('ㄜ', 'ə'), + ('ㄝ', 'e'), + ('ㄞ', 'ai'), + ('ㄟ', 'ei'), + ('ㄠ', 'au'), + ('ㄡ', 'ou'), + ('ㄧㄢ', 'yeNN'), + ('ㄢ', 'aNN'), + ('ㄧㄣ', 'iNN'), + ('ㄣ', 'əNN'), + ('ㄤ', 'aNg'), + ('ㄧㄥ', 'iNg'), + ('ㄨㄥ', 'uNg'), + ('ㄩㄥ', 'yuNg'), + ('ㄥ', 'əNg'), + ('ㄦ', 'əɻ'), + ('ㄧ', 'i'), + ('ㄨ', 'u'), + ('ㄩ', 'ɥ'), + ('ˉ', '→'), + ('ˊ', '↑'), + ('ˇ', '↓↑'), + ('ˋ', 
'↓'), + ('˙', ''), + (',', ','), + ('。', '.'), + ('!', '!'), + ('?', '?'), + ('—', '-') +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def japanese_to_romaji_with_accent(text): + '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' + sentences = re.split(_japanese_marks, text) + marks = re.findall(_japanese_marks, text) + text = '' + for i, sentence in enumerate(sentences): + if re.match(_japanese_characters, sentence): + if text!='': + text+=' ' + labels = pyopenjtalk.extract_fullcontext(sentence) + for n, label in enumerate(labels): + phoneme = re.search(r'\-([^\+]*)\+', label).group(1) + if phoneme not in ['sil','pau']: + text += phoneme.replace('ch','ʧ').replace('sh','ʃ').replace('cl','Q') + else: + continue + n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) + a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) + a2 = int(re.search(r"\+(\d+)\+", label).group(1)) + a3 = int(re.search(r"\+(\d+)/", label).group(1)) + if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil','pau']: + a2_next=-1 + else: + a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) + # Accent phrase boundary + if a3 == 1 and a2_next == 1: + text += ' ' + # Falling + elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras: + text += '↓' + # Rising + elif a2 == 1 and a2_next == 2: + text += '↑' + if i < len(marks): + text += unidecode(marks[i]).replace(' ', '') + return text diff --git a/utils/tts/vits/transforms.py b/utils/tts/vits/transforms.py new file mode 100644 --- /dev/null +++ b/utils/tts/vits/transforms.py +import torch +from torch.nn import functional as F +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = { + 'tails': tails, + 'tail_bound': tail_bound + } + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum( + inputs[..., None] >= bin_locations, + dim=-1 + ) - 1 + + +def unconstrained_rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails='linear', + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == 'linear': + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError('{} tails are not implemented.'.format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative + ) + + return outputs, logabsdet + +def rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0., right=1., bottom=0., top=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise 
ValueError('Input to a transform is not within its domain') + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError('Minimal bin width too large for the number of bins') + if min_bin_height * num_bins > 1.0: + raise ValueError('Minimal bin height too large for the number of bins') + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (((inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta) + + input_heights * (input_delta - input_derivatives))) + b = (input_heights * input_derivatives + - (inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta)) + c = - input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/utils/tts/vits/utils.py b/utils/tts/vits/utils.py new file 
mode 100644 index 0000000..ee4b01d --- /dev/null +++ b/utils/tts/vits/utils.py @@ -0,0 +1,225 @@ +import os +import sys +import argparse +import logging +import json +import subprocess +import numpy as np +import librosa +import torch + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +def load_checkpoint(checkpoint_path, model, optimizer=None): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + iteration = checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + if optimizer is not None: + optimizer.load_state_dict(checkpoint_dict['optimizer']) + saved_state_dict = checkpoint_dict['model'] + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict= {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, 'module'): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + logger.info("Loaded checkpoint '{}' (iteration {})" .format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10,2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_audio_to_torch(full_path, target_sampling_rate): + audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True) + return torch.FloatTensor(audio.astype(np.float32)) + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding='utf-8') as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True): + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', type=str, default="./configs/base.json", + help='JSON file for configuration') + parser.add_argument('-m', '--model', type=str, required=True, + help='Model name') + + args = 
parser.parse_args() + model_dir = os.path.join("./logs", args.model) + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + config_path = args.config + config_save_path = os.path.join(model_dir, "config.json") + if init: + with open(config_path, "r") as f: + data = f.read() + with open(config_save_path, "w") as f: + f.write(data) + else: + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams =HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams =HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + )) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn("git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8])) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams(): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() diff --git a/utils/tts/vits_utils.py b/utils/tts/vits_utils.py new file mode 100644 index 0000000..ddcc55c --- /dev/null +++ b/utils/tts/vits_utils.py @@ -0,0 +1,115 @@ +import os +import numpy as np +import torch +from torch import LongTensor +import soundfile as sf +# vits +from .vits import utils, commons +from .vits.models import SynthesizerTrn +from .vits.text import text_to_sequence + +def tts_model_init(model_path='./vits_model', device='cuda'): + hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json')) + # hps_ms = utils.get_hparams_from_file('vits_model/config.json') + net_g_ms = SynthesizerTrn( + len(hps_ms.symbols), + hps_ms.data.filter_length // 2 + 1, + hps_ms.train.segment_size // hps_ms.data.hop_length, + n_speakers=hps_ms.data.n_speakers, + **hps_ms.model) + net_g_ms = net_g_ms.eval().to(device) + speakers = hps_ms.speakers + 
utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None) + # utils.load_checkpoint('vits_model/G_953000.pth', net_g_ms, None) + return hps_ms, net_g_ms, speakers + +class TextToSpeech: + def __init__(self, + model_path="./utils/tts/vits_model", + device='cuda', + RATE=22050, + debug=False, + ): + self.debug = debug + self.RATE = RATE + self.device = torch.device(device) + self.limitation = os.getenv("SYSTEM") == "spaces" # 在huggingface spaces中限制文本和音频长度 + self.hps_ms, self.net_g_ms, self.speakers = self._tts_model_init(model_path) + + def _tts_model_init(self, model_path): + hps_ms = utils.get_hparams_from_file(os.path.join(model_path, 'config.json')) + net_g_ms = SynthesizerTrn( + len(hps_ms.symbols), + hps_ms.data.filter_length // 2 + 1, + hps_ms.train.segment_size // hps_ms.data.hop_length, + n_speakers=hps_ms.data.n_speakers, + **hps_ms.model) + net_g_ms = net_g_ms.eval().to(self.device) + speakers = hps_ms.speakers + utils.load_checkpoint(os.path.join(model_path, 'G_953000.pth'), net_g_ms, None) + if self.debug: + print("Model loaded.") + return hps_ms, net_g_ms, speakers + + def _get_text(self, text): + text_norm, clean_text = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners) + if self.hps_ms.data.add_blank: + text_norm = commons.intersperse(text_norm, 0) + text_norm = LongTensor(text_norm) + return text_norm, clean_text + + def _preprocess_text(self, text, language): + if language == 0: + return f"[ZH]{text}[ZH]" + elif language == 1: + return f"[JA]{text}[JA]" + return text + + def _generate_audio(self, text, speaker_id, noise_scale, noise_scale_w, length_scale): + import time + start_time = time.time() + stn_tst, clean_text = self._get_text(text) + with torch.no_grad(): + x_tst = stn_tst.unsqueeze(0).to(self.device) + x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device) + speaker_id = LongTensor([speaker_id]).to(self.device) + audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w, + length_scale=length_scale)[0][0, 0].data.cpu().float().numpy() + if self.debug: + print(f"Synthesis time: {time.time() - start_time} s") + return audio + + def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False): + if not len(text): + return "输入文本不能为空!", None + text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") + if len(text) > 100 and self.limitation: + return f"输入文字过长!{len(text)}>100", None + text = self._preprocess_text(text, language) + audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale) + if self.debug or save_audio: + self.save_audio(audio, self.RATE, 'output_file.wav') + if return_bytes: + audio = self.convert_numpy_to_bytes(audio) + return self.RATE, audio + + def convert_numpy_to_bytes(self, audio_data): + if isinstance(audio_data, np.ndarray): + if audio_data.dtype == np.dtype('float32'): + audio_data = np.int16(audio_data * np.iinfo(np.int16).max) + audio_data = audio_data.tobytes() + return audio_data + else: + raise TypeError("audio_data must be a numpy array") + + def save_audio(self, audio, sample_rate, file_name='output_file.wav'): + sf.write(file_name, audio, samplerate=sample_rate) + print(f"VITS Audio saved to {file_name}") + + + + + + + + diff --git a/utils/xf_asr_utils.py b/utils/xf_asr_utils.py new file mode 100644 index 0000000..d141a8e --- /dev/null +++ b/utils/xf_asr_utils.py @@ -0,0 +1,37 @@ +import datetime +import hashlib +import 
base64 +import hmac +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time +from datetime import datetime +from time import mktime +from config import get_config + +Config = get_config() + +def generate_xf_asr_url(): + # Set the parameters for the iFLYTEK streaming ASR (iat) API + APIKey = Config.XF_ASR.API_KEY + APISecret = Config.XF_ASR.API_SECRET + + # Sign the request and build the websocket URL + url = 'wss://ws-api.xfyun.cn/v2/iat' + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + signature_origin = "host: " + "ws-api.xfyun.cn" + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" + signature_sha = hmac.new(APISecret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + APIKey, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + v = { + "authorization": authorization, + "date": date, + "host": "ws-api.xfyun.cn" + } + url = url + '?' + urlencode(v) + return url
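
The modules in this commit are mostly independent, so they can be sanity-checked in isolation. The sketches below use illustrative, made-up hyperparameters, paths, and inputs; real values live in `vits_model/config.json` and the project config. First, pushing dummy token IDs through `TextEncoder` and predicting per-token log-durations with `DurationPredictor` (assuming the `utils.tts.vits` package layout above):

```python
import torch
from utils.tts.vits.models import TextEncoder, DurationPredictor

# Toy hyperparameters; the real ones come from vits_model/config.json.
enc = TextEncoder(n_vocab=100, out_channels=192, hidden_channels=192,
                  filter_channels=768, n_heads=2, n_layers=6,
                  kernel_size=3, p_dropout=0.1)
dp = DurationPredictor(in_channels=192, filter_channels=256,
                       kernel_size=3, p_dropout=0.5)

tokens = torch.randint(0, 100, (2, 17))   # [b, t] token IDs
lengths = torch.tensor([17, 11])          # true length of each batch item
x, m_p, logs_p, x_mask = enc(tokens, lengths)
print(x.shape, m_p.shape, x_mask.shape)   # [2, 192, 17] [2, 192, 17] [2, 1, 17]

logw = dp(x, x_mask)                      # per-token log-durations, [2, 1, 17]
w = torch.exp(logw) * x_mask              # durations in frames
```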
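
`ResidualCouplingBlock` and the rest of the flow stack are invertible by construction: running the block forward and then with `reverse=True` should give back the input up to floating-point error. A minimal sketch with the same shapes `SynthesizerTrn` uses:

```python
import torch
from utils.tts.vits.models import ResidualCouplingBlock

flow = ResidualCouplingBlock(channels=192, hidden_channels=192,
                             kernel_size=5, dilation_rate=1, n_layers=4)
flow.eval()

z = torch.randn(1, 192, 40)      # latent sequence [b, c, t]
mask = torch.ones(1, 1, 40)      # no padding in this toy example

with torch.no_grad():
    z_p = flow(z, mask)                    # forward direction
    z_rec = flow(z_p, mask, reverse=True)  # inverse direction
print(torch.allclose(z, z_rec, atol=1e-4)) # True: the flow is invertible
```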
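
`Generator` is the HiFi-GAN-style decoder; each transposed convolution upsamples by its stride, so the waveform length equals the number of latent frames times `prod(upsample_rates)`. A sketch assuming the common VITS base configuration (`upsample_rates=[8, 8, 2, 2]`, i.e. a hop length of 256):

```python
import torch
from utils.tts.vits.models import Generator

# Hyperparameters as in the usual VITS "base" config (assumed here).
dec = Generator(initial_channel=192, resblock='1',
                resblock_kernel_sizes=[3, 7, 11],
                resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                upsample_rates=[8, 8, 2, 2],
                upsample_initial_channel=512,
                upsample_kernel_sizes=[16, 16, 4, 4])

z = torch.randn(1, 192, 20)   # 20 latent frames
wav = dec(z)                  # [1, 1, 20 * 8*8*2*2] = [1, 1, 5120]
print(wav.shape)
```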
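
`MultiPeriodDiscriminator` wraps one `DiscriminatorS` plus one `DiscriminatorP` per period in `[2, 3, 5, 7, 11]`, so both the logit lists and the feature-map lists have six entries:

```python
import torch
from utils.tts.vits.models import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()
y = torch.randn(1, 1, 5120)      # "real" waveform
y_hat = torch.randn(1, 1, 5120)  # "generated" waveform

y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
print(len(y_d_rs), len(fmap_rs))  # 6 6: DiscriminatorS + periods [2, 3, 5, 7, 11]
```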
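
`monotonic_align.maximum_path` turns the soft `neg_cent` scores into a hard monotonic alignment: every spectrogram frame is assigned exactly one text position, positions never move backwards, and summing the path over frames yields the per-token durations that supervise the duration predictor. A toy example:

```python
import torch
from utils.tts.vits import monotonic_align

# Toy scores: 5 spectrogram frames aligned to 2 text tokens, shape [b, t_frames, t_text].
neg_cent = torch.randn(1, 5, 2)
mask = torch.ones(1, 5, 2)

attn = monotonic_align.maximum_path(neg_cent, mask)  # hard 0/1 alignment
print(attn[0])          # each frame row selects exactly one token, monotonically
print(attn.sum(1)[0])   # per-token durations, summing to the number of frames
```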
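
`ConvFlow` delegates to `piecewise_rational_quadratic_transform` from `transforms.py`; with `tails='linear'` this is an element-wise invertible map whose inverse also negates the log-determinant. A round-trip sketch with random spline parameters (note the `num_bins - 1` interior derivatives, matching `ConvFlow`'s `num_bins * 3 - 1` projection):

```python
import torch
from utils.tts.vits.transforms import piecewise_rational_quadratic_transform

num_bins = 10
x = torch.randn(2, 1, 7)                     # arbitrary inputs (linear tails are unconstrained)
widths = torch.randn(2, 1, 7, num_bins)
heights = torch.randn(2, 1, 7, num_bins)
derivs = torch.randn(2, 1, 7, num_bins - 1)  # interior knot derivatives

y, logdet = piecewise_rational_quadratic_transform(
    x, widths, heights, derivs, inverse=False, tails='linear', tail_bound=5.0)
x_rec, logdet_inv = piecewise_rational_quadratic_transform(
    y, widths, heights, derivs, inverse=True, tails='linear', tail_bound=5.0)

print(torch.allclose(x, x_rec, atol=1e-5))             # True: exact inverse
print(torch.allclose(logdet, -logdet_inv, atol=1e-5))  # True: log-dets cancel
```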
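
At the top level, `TextToSpeech` in `vits_utils.py` ties everything together: it reads `config.json`, loads the `G_953000.pth` generator weights, and exposes `synthesize`. A usage sketch (the model files must already sit in `model_path` per `.gitignore`'s `vits_model` entry; speaker ID and scale values are illustrative):

```python
from utils.tts.vits_utils import TextToSpeech

tts = TextToSpeech(model_path='./utils/tts/vits_model', device='cpu')

# language: 0 wraps the text as [ZH]...[ZH], 1 as [JA]...[JA]
rate, audio = tts.synthesize('你好,世界', language=0, speaker_id=0,
                             noise_scale=0.667, noise_scale_w=0.8,
                             length_scale=1.0)
tts.save_audio(audio, rate, 'hello.wav')
```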
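
Finally, `generate_xf_asr_url` implements iFLYTEK's request signing: the `host`, `date`, and request line are HMAC-SHA256-signed with `API_SECRET`, and the base64-encoded authorization header is appended to the websocket URL as a query parameter. A quick sketch that decodes the result back (assuming `XF_ASR` credentials are present in the config):

```python
import base64
from urllib.parse import urlparse, parse_qs
from utils.xf_asr_utils import generate_xf_asr_url

url = generate_xf_asr_url()
params = parse_qs(urlparse(url).query)
auth = base64.b64decode(params['authorization'][0]).decode('utf-8')
print(params['host'][0])  # ws-api.xfyun.cn
print(auth)  # api_key="...", algorithm="hmac-sha256", headers="host date request-line", signature="..."
```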