Compare commits
5 Commits
Author | SHA1 | Date
---|---|---
 | d0b4bd4b3c |
 | 2b870c2e7d |
 | 05ccd1c8c0 |
 | 42767b065f |
 | a776258f8b |
@@ -7,11 +7,3 @@ __pycache__/
app.log
/utils/tts/vits_model/
vits_model

nohup.out

/app/takway-ai.top.key
/app/takway-ai.top.pem

tests/assets/BarbieDollsVoice.mp3
utils/tts/openvoice_model/checkpoint.pth
README.md

@@ -15,15 +15,15 @@ TakwayAI/
│   │   ├── models.py            # database table definitions
│   ├── schemas/                 # request and response models
│   │   ├── __init__.py
│   │   ├── user_schemas.py      # user-related schemas
│   │   ├── user.py              # user-related schemas
│   │   └── ...                  # other schemas
│   ├── controllers/             # business-logic controllers
│   │   ├── __init__.py
│   │   ├── user_controllers.py  # user-related controllers
│   │   ├── user.py              # user-related controllers
│   │   └── ...                  # other controllers
│   ├── routes/                  # routes and view functions
│   │   ├── __init__.py
│   │   ├── user_routes.py       # user-related routes
│   │   ├── user.py              # user-related routes
│   │   └── ...                  # other routes
│   ├── dependencies/            # dependency-injection helpers
│   │   ├── __init__.py
@@ -64,129 +64,21 @@ TakwayAI/
git clone http://43.132.157.186:3000/killua/TakwayPlatform.git
```

#### (2) Create a virtual environment

Create a virtual environment
#### (2) Install dependencies

``` shell
conda create -n takway python=3.9
conda activate takway
```

#### (3) Install dependencies

If your machine has unrestricted internet access (GitHub, Hugging Face reachable directly), just run these two commands:

``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
python -m unidic download
```

If it does not, first run:

``` shell
pip install git+https://github.com/myshell-ai/MeloTTS.git
```

1. Install unidic

Manually download [unidic.zip](https://cotonoha-dic.s3-ap-northeast-1.amazonaws.com/unidic-3.1.0.zip) and rename it to unidic.zip.

The example uses miniconda; plain conda should work the same way.

Copy unidic.zip into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic.

cd into ~/miniconda3/envs/takway/lib/python3.9/site-packages/unidic.

vim download.py

In download_version(), comment out everything except the last line, and replace the two arguments of the final download_and_clean() call with arbitrary strings, e.g. "hello", "world".

Then, in the definition of download_and_clean(), comment out the download_process() line inside it.

Run `python -m unidic download` (a sketch of the patched download.py follows).
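After these edits, download.py should look roughly like the sketch below. This is a hedged reconstruction for illustration only; the real function bodies differ between unidic releases, and only the names download_version(), download_and_clean() and download_process() are taken from the steps above.

```python
# download.py -- sketch of the manual patch described in step 1 (illustrative only)

def download_version():
    # the original version lookup / download logic is commented out;
    # the placeholder arguments just have to be two arbitrary strings
    download_and_clean("hello", "world")


def download_and_clean(version, url):
    # the download_process() call is commented out so the manually copied
    # unidic.zip is used instead of re-downloading it
    # download_process(version, url)
    ...
```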
2. Configure huggingface

Run:

```shell
pip install -U huggingface_hub
export HF_ENDPOINT=https://hf-mirror.com
```

It is best to also write `export HF_ENDPOINT=https://hf-mirror.com` into ~/.bashrc, otherwise the setting is lost every time the terminal is restarted.

3. Download nltk_data

Create an nltk_data folder under /miniconda3/envs/takway/.

Inside nltk_data, create corpora and taggers folders.

Manually download [cmudict.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/cmudict.zip) and [averaged_perceptron_tagger.zip](https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/taggers/averaged_perceptron_tagger.zip).

Put cmudict.zip into the corpora folder.

Put averaged_perceptron_tagger.zip into the taggers folder (a quick verification sketch follows).
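To confirm the manually placed archives are picked up, a quick check like the following can help. It is only a sketch: the nltk_data path is an assumption based on the miniconda layout above, and nltk normally also searches the environment prefix on its own.

```python
import os
import nltk

# assumed location from the steps above; adjust if your env lives elsewhere
nltk.data.path.append(os.path.expanduser("~/miniconda3/envs/takway/nltk_data"))

# both lookups should succeed once cmudict.zip is in corpora/ and
# averaged_perceptron_tagger.zip is in taggers/
print(nltk.data.find("corpora/cmudict"))
print(nltk.data.find("taggers/averaged_perceptron_tagger"))
```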
4. Install the remaining dependencies

``` shell
cd TakwayPlatform/
cd TakwayAI/
pip install -r requirements.txt
```

5. Debugging

If an AttributeError: module 'botocore.exceptions' has no attribute 'HTTPClientError' exception appears, run the following:

``` shell
pip uninstall botocore
pip install botocore==1.34.88
```
#### (4) Install FunASR

The FunASR used by this project contains some modifications on top of the upstream FunASR on GitHub.

``` shell
git clone http://43.132.157.186:3000/gaohz/FunASR.git
cd FunASR/
pip install -v -e .
```

#### (5) Edit the configuration
#### (3) Edit the configuration

1. Install MySQL and create a database named takway in it.
2. Install Redis and set its password to takway.
3. Open config/development.py and update the MySQL and Redis connection strings.

#### (6) Import the vits model
#### (4) Import the vits model

Create a vits_model folder under utils/tts/.

Download the vits_model from [this link](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/tree/main/model) and put it in that folder; only config.json and G_953000.pth are needed.

#### (7) Import the openvoice model

Create an openvoice_model folder under utils/tts/.

Download the model file from [this link](https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip) and put it in that folder.
#### (8) Add the configuration environment variable

``` shell
vim ~/.bashrc
# append at the end:
export MODE=development
```
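The config package reads this MODE variable to decide which settings class get_config() returns. The actual selection code is not part of this diff; a plausible minimal sketch, assuming only the development profile shown later in config/development.py, would be:

```python
# config/__init__.py -- illustrative sketch only; the real module is not shown in this diff
import os

from .development import DevelopmentConfig


def get_config():
    mode = os.environ.get("MODE", "development")
    if mode == "development":
        return DevelopmentConfig
    raise ValueError(f"unsupported MODE: {mode}")
```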
#### (9) Start the program

``` shell
cd TakwayPlatform
python main.py
```
@@ -34,7 +34,7 @@ logger.info("数据库初始化完成")

#-------------------- set up the scheduled task -----------------------
scheduler = AsyncIOScheduler()
# scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
scheduler.add_job(updating_redis_cache, CronTrigger.from_crontab("0 4 * * *"))
@asynccontextmanager
async def lifespan(app:FastAPI):
    scheduler.start()  # start the scheduled task
@@ -1,38 +1,35 @@
from ..schemas.chat_schema import *
from ..dependencies.logger import get_logger
from ..dependencies.summarizer import get_summarizer
from ..dependencies.asr import get_asr
from ..dependencies.tts import get_tts
from .controller_enum import *
from ..models import UserCharacter, Session, Character, User, Audio
from ..models import UserCharacter, Session, Character, User
from utils.audio_utils import VAD
from fastapi import WebSocket, HTTPException, status
from datetime import datetime
from utils.xf_asr_utils import xf_asr_websocket_factory, make_first_frame, make_continue_frame, make_last_frame, parse_xfasr_recv
from utils.xf_asr_utils import generate_xf_asr_url
from config import get_config
import numpy as np
import websockets
import struct
import uuid
import json
import asyncio
import aiohttp
import io
import requests

# get the logger via dependency injection
logger = get_logger()

# get the context summarization service via dependency injection
summarizer = get_summarizer()
# -------------------- initialize the local ASR -----------------------
from utils.stt.funasr_utils import FunAutoSpeechRecognizer

# ----------------------- get the ASR -------------------------
asr = get_asr()
asr = FunAutoSpeechRecognizer()
logger.info("本地ASR初始化成功")
# -------------------------------------------------------

# ------------------------- TTS --------------------------
tts = get_tts()
# -------------------- initialize the local VITS ----------------------
from utils.tts.vits_utils import TextToSpeech

tts = TextToSpeech(device='cpu')
logger.info("本地TTS初始化成功")
# -------------------------------------------------------


# get Config via dependency injection
Config = get_config()
@@ -53,20 +50,16 @@ def get_session_content(session_id,redis,db):
def parseChunkDelta(chunk):
    try:
        if chunk == b"":
            return 1,""
            return ""
        decoded_data = chunk.decode('utf-8')
        parsed_data = json.loads(decoded_data[6:])
        if 'delta' in parsed_data['choices'][0]:
            delta_content = parsed_data['choices'][0]['delta']
            return -1, delta_content['content']
            return delta_content['content']
        else:
            return parsed_data['usage']['total_tokens'] , ""
            return "end"
    except KeyError:
        logger.error(f"error chunk: {decoded_data}")
        return 1,""
    except json.JSONDecodeError:
        logger.error(f"error chunk: {decoded_data}")
        return 1,""
        logger.error(f"error chunk: {chunk}")

# sentence-splitting function
def split_string_with_punctuation(current_sentence,text,is_first,is_end):
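For context, the chunks parsed by parseChunkDelta above appear to be Server-Sent-Events lines from the MiniMax streaming endpoint; the `decoded_data[6:]` slice strips a leading `data: ` prefix before the JSON is loaded. A minimal illustration with a made-up payload:

```python
import json

# made-up example of one streamed line; real payloads carry more fields
chunk = b'data: {"choices":[{"delta":{"content":"你好"}}]}'

parsed = json.loads(chunk.decode("utf-8")[6:])   # drop the 6-byte "data: " prefix
print(parsed["choices"][0]["delta"]["content"])  # -> 你好
```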
@@ -104,19 +97,6 @@ def update_session_activity(session_id,db):
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

# fetch target_se
def get_emb(session_id,db):
    try:
        session_record = db.query(Session).filter(Session.id == session_id).first()
        user_character_record = db.query(UserCharacter).filter(UserCharacter.id == session_record.user_character_id).first()
        user_record = db.query(User).filter(User.id == user_character_record.user_id).first()
        audio_record = db.query(Audio).filter(Audio.id == user_record.selected_audio_id).first()
        emb_npy = np.load(io.BytesIO(audio_record.emb_data))
        return emb_npy
    except Exception as e:
        logger.debug("未找到音频:"+str(e))
        return np.array([])
#--------------------------------------------------------

# create a new chat
@@ -126,6 +106,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    try:
        db.add(new_chat)
        db.commit()
        db.refresh(new_chat)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -151,23 +132,18 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
        "speaker_id":db_character.voice_id,
        "noise_scale": 0.1,
        "noise_scale_w":0.668,
        "length_scale": 1.2,
        "speed":1
        "length_scale": 1.2
    }
    llm_info = {
        "model": "abab5.5-chat",
        "temperature": 1,
        "top_p": 0.9,
    }
    user_info = {
        "character":"",
        "events":[]
    }

    # serialize the tts and llm info to JSON strings
    tts_info_str = json.dumps(tts_info, ensure_ascii=False)
    llm_info_str = json.dumps(llm_info, ensure_ascii=False)
    user_info_str = json.dumps(user_info, ensure_ascii=False)
    user_info_str = db_user.persona

    token = 0
    content = {"user_id": user_id, "messages": messages, "user_info": user_info_str, "tts_info": tts_info_str,
@@ -178,6 +154,7 @@ async def create_chat_handler(chat: ChatCreateRequest, db, redis):
    # store the Session record
    db.add(new_session)
    db.commit()
    db.refresh(new_session)
    redis.set(session_id, json.dumps(content, ensure_ascii=False))

    chat_create_data = ChatCreateData(user_character_id=new_chat.id, session_id=session_id, createdAt=datetime.now().isoformat())
@@ -246,11 +223,9 @@ async def sct_user_input_handler(ws,user_input_q,llm_input_q,future_session_id,f
        logger.error(f"用户输入处理函数发生错误: {str(e)}")

# speech recognition
async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event):
async def sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event):
    logger.debug("语音识别函数启动")
    if Config.STRAM_CHAT.ASR == "LOCAL":
        is_signup = False
        audio = ""
        try:
            current_message = ""
            while not (user_input_finish_event.is_set() and user_input_q.empty()):
@@ -258,63 +233,16 @@ async def sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_fini
                    asr.session_signup(session_id)
                    is_signup = True
                audio_data = await user_input_q.get()
                audio += audio_data
                asr_result = asr.streaming_recognize(session_id,audio_data)
                current_message += ''.join(asr_result['text'])
            asr_result = asr.streaming_recognize(session_id,b'',is_end=True)
            current_message += ''.join(asr_result['text'])
            slice_arr = ["嗯",""]
            if current_message in slice_arr:
                await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
                return
            current_message = asr.punctuation_correction(current_message)
            # emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
            # if not isinstance(emotion_dict, str):
            #     max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
            #     current_message = f"{current_message},当前说话人的情绪:{emotion_dict['labels'][max_index]}"
            await llm_input_q.put(current_message)
            asr.session_signout(session_id)
        except Exception as e:
            asr.session_signout(session_id)
            logger.error(f"语音识别函数发生错误: {str(e)}")
        logger.debug(f"接收到用户消息: {current_message}")
    elif Config.STRAM_CHAT.ASR == "XF":
        status = FIRST_FRAME
        xf_websocket = await xf_asr_websocket_factory()  # open a websocket connection to the iFlytek ASR API
        segment_duration_threshold = 25  # cap the segment length: the iFlytek API drops connections after 30 seconds, so 25 seconds is used
        segment_start_time = asyncio.get_event_loop().time()
        current_message = ""
        while not (user_input_finish_event.is_set() and user_input_q.empty()):
            try:
                audio_data = await user_input_q.get()
                current_time = asyncio.get_event_loop().time()
                if current_time - segment_start_time > segment_duration_threshold:
                    await xf_websocket.send(make_last_frame())
                    current_message += parse_xfasr_recv(await xf_websocket.recv())
                    await xf_websocket.close()
                    xf_websocket = await xf_asr_websocket_factory()  # re-create the websocket connection
                    status = FIRST_FRAME
                    segment_start_time = current_time
                if status == FIRST_FRAME:
                    await xf_websocket.send(make_first_frame(audio_data))
                    status = CONTINUE_FRAME
                elif status == CONTINUE_FRAME:
                    await xf_websocket.send(make_continue_frame(audio_data))
            except websockets.exceptions.ConnectionClosedOK:
                logger.debug("讯飞语音识别接口连接断开,重新创建连接")
                xf_websocket = await xf_asr_websocket_factory()  # re-create the websocket connection
                status = FIRST_FRAME
                segment_start_time = asyncio.get_event_loop().time()
        await xf_websocket.send(make_last_frame())
        current_message += parse_xfasr_recv(await xf_websocket.recv())
        await xf_websocket.close()
        if current_message in ["嗯", ""]:
            await ws.send_text(json.dumps({"type": "close", "code": 201, "msg": ""}, ensure_ascii=False))
            return
        await llm_input_q.put(current_message)
        logger.debug(f"接收到用户消息: {current_message}")


# LLM call
async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis,llm_input_q,chat_finished_event):
@@ -325,18 +253,13 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
    is_first = True
    is_end = False
    session_content = get_session_content(session_id,redis,db)
    user_info = json.loads(session_content["user_info"])
    messages = json.loads(session_content["messages"])
    current_message = await llm_input_q.get()
    if current_message == "":
        return
    messages.append({'role': 'user', "content": current_message})
    messages_send = messages  # make a copy of messages and prepend the user info to its last entry
    # messages_send[-1]['content'] = f"用户性格:{user_info['character']}\n事件摘要:{user_info['events']}" + messages_send[-1]['content']
    payload = json.dumps({
        "model": llm_info["model"],
        "stream": True,
        "messages": messages_send,
        "messages": messages,
        "max_tokens": 10000,
        "temperature": llm_info["temperature"],
        "top_p": llm_info["top_p"]
@@ -345,15 +268,13 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
        'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
        'Content-Type': 'application/json'
    }
    target_se = get_emb(session_id,db)
    response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)  # call the LLM
    except Exception as e:
        logger.error(f"编辑http请求时发生错误: {str(e)}")
        logger.error(f"llm调用发生错误: {str(e)}")
    try:
        async with aiohttp.ClientSession() as client:
            async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                async for chunk in response.content.iter_any():
                    token_count, chunk_data = parseChunkDelta(chunk)
                    is_end = token_count >0
        for chunk in response.iter_lines():
            chunk_data = parseChunkDelta(chunk)
            is_end = chunk_data == "end"
            if not is_end:
                llm_response += chunk_data
            sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)  # split into sentences
@@ -362,12 +283,10 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                response_message = {"type": "text", "code":200, "msg": sentence}
                await ws.send_text(json.dumps(response_message, ensure_ascii=False))  # return the text message
            elif response_type == RESPONSE_AUDIO:
                audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
                sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
                response_message = {"type": "text", "code":200, "msg": sentence}
                response_bytes = json.dumps(response_message, ensure_ascii=False).encode('utf-8')
                header = struct.pack('!II',len(response_bytes),len(audio))
                message_bytes = header + response_bytes + audio
                await ws.send_bytes(message_bytes)
                await ws.send_bytes(audio)  # return the audio data
                await ws.send_text(json.dumps(response_message, ensure_ascii=False))  # return the text message
            logger.debug(f"websocket返回: {sentence}")
        if is_end:
            logger.debug(f"llm返回结果: {llm_response}")
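One of the two variants above frames each websocket message as an 8-byte big-endian header carrying two uint32 lengths, followed by the UTF-8 JSON text and then the raw audio. Under that assumption, a receiving client would unpack a frame roughly like this sketch:

```python
import json
import struct


def parse_frame(message_bytes: bytes):
    # header: two big-endian uint32s = length of the JSON part, length of the audio part
    text_len, audio_len = struct.unpack('!II', message_bytes[:8])
    text = json.loads(message_bytes[8:8 + text_len].decode('utf-8'))
    audio = message_bytes[8 + text_len:8 + text_len + audio_len]
    return text, audio
```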
@@ -378,15 +297,6 @@ async def sct_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
            redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
            is_first = True
            llm_response = ""
            if token_count > summarizer.max_token * 0.7:  # if the LLM's token count exceeds 70% of the max token limit, summarize the conversation
                system_prompt = messages[0]['content']
                summary = await summarizer.summarize(messages)
                events = user_info['events']
                events.append(summary['event'])
                session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
                session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
                redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
                logger.debug(f"总结后session_content: {session_content}")
    except Exception as e:
        logger.error(f"处理llm返回结果发生错误: {str(e)}")
    chat_finished_event.set()
@@ -405,7 +315,7 @@ async def streaming_chat_temporary_handler(ws: WebSocket, db, redis):
    session_id = await future_session_id  # get the session_id
    update_session_activity(session_id,db)
    response_type = await future_response_type  # get the response type
    asyncio.create_task(sct_asr_handler(ws,session_id,user_input_q,llm_input_q,user_input_finish_event))
    asyncio.create_task(sct_asr_handler(session_id,user_input_q,llm_input_q,user_input_finish_event))
    tts_info = json.loads(get_session_content(session_id,redis,db)["tts_info"])
    llm_info = json.loads(get_session_content(session_id,redis,db)["llm_info"])
@@ -462,7 +372,6 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
    logger.debug("语音识别函数启动")
    is_signup = False
    current_message = ""
    audio = ""
    while not (input_finished_event.is_set() and user_input_q.empty()):
        try:
            aduio_frame = await asyncio.wait_for(user_input_q.get(),timeout=3)
@@ -472,23 +381,15 @@ async def scl_asr_handler(session_id,user_input_q,llm_input_q,input_finished_eve
            if aduio_frame['is_end']:
                asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'], is_end=True)
                current_message += ''.join(asr_result['text'])
                current_message = asr.punctuation_correction(current_message)
                audio += aduio_frame['audio']
                emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
                if not isinstance(emotion_dict, str):
                    max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
                    current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
                await llm_input_q.put(current_message)
                logger.debug(f"接收到用户消息: {current_message}")
                current_message = ""
                audio = ""
            else:
                asr_result = asr.streaming_recognize(session_id,aduio_frame['audio'])
                audio += aduio_frame['audio']
                current_message += ''.join(asr_result['text'])
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            asr.session_signout(session_id)
            logger.error(f"语音识别函数发生错误: {str(e)}")
            break
    asr.session_signout(session_id)
@@ -505,7 +406,6 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            user_info = json.loads(session_content["user_info"])
            current_message = await asyncio.wait_for(llm_input_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
@@ -520,12 +420,10 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            target_se = get_emb(session_id,db)
            async with aiohttp.ClientSession() as client:
                async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                    async for chunk in response.content.iter_any():
                        token_count, chunk_data = parseChunkDelta(chunk)
                        is_end = token_count >0
            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
            for chunk in response.iter_lines():
                chunk_data = parseChunkDelta(chunk)
                is_end = chunk_data == "end"
                if not is_end:
                    llm_response += chunk_data
                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
@@ -535,7 +433,7 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                        response_message = {"type": "text", "code":200, "msg": sentence}
                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))
                    elif response_type == RESPONSE_AUDIO:
                        audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
                        sr,audio = tts.synthesize(sentence, tts_info["speaker_id"], tts_info["language"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"],return_bytes=True)
                        response_message = {"type": "text", "code":200, "msg": sentence}
                        await ws.send_bytes(audio)
                        await ws.send_text(json.dumps(response_message, ensure_ascii=False))
@@ -544,25 +442,17 @@ async def scl_llm_handler(ws,session_id,response_type,llm_info,tts_info,db,redis
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False

                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # update the conversation
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
                    is_first = True
                    llm_response = ""
                    if token_count > summarizer.max_token * 0.7:  # if the LLM's token count exceeds 70% of the max token limit, summarize the conversation
                        system_prompt = messages[0]['content']
                        summary = await summarizer.summarize(messages)
                        events = user_info['events']
                        events.append(summary['event'])
                        session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
                        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
                        logger.debug(f"总结后session_content: {session_content}")
        except asyncio.TimeoutError:
            continue
        # except Exception as e:
        #     logger.error(f"处理llm返回结果发生错误: {str(e)}")
        #     break
        except Exception as e:
            logger.error(f"处理llm返回结果发生错误: {str(e)}")
            break
    chat_finished_event.set()

async def streaming_chat_lasting_handler(ws,db,redis):
@@ -633,7 +523,6 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
    current_message = ""
    vad_count = 0
    is_signup = False
    audio = ""
    while not (input_finished_event.is_set() and audio_q.empty()):
        try:
            if not is_signup:
@@ -644,22 +533,14 @@ async def voice_call_audio_consumer(ws,session_id,audio_q,asr_result_q,input_fin
                if vad_count > 0:
                    vad_count -= 1
                asr_result = asr.streaming_recognize(session_id, audio_data)
                audio += audio_data
                current_message += ''.join(asr_result['text'])
            else:
                vad_count += 1
                if vad_count >= 25:  # 25 consecutive frames without speech means the utterance is finished
                    asr_result = asr.streaming_recognize(session_id, audio_data, is_end=True)
                    if current_message:
                        current_message = asr.punctuation_correction(current_message)
                        audio += audio_data
                        emotion_dict = asr.emtion_recognition(audio)  # emotion recognition
                        if not isinstance(emotion_dict, str):
                            max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
                            current_message = f"{current_message}当前说话人的情绪:{emotion_dict['labels'][max_index]}"
                        logger.debug(f"检测到静默,用户输入为:{current_message}")
                        await asr_result_q.put(current_message)
                        audio = ""
                        text_response = {"type": "user_text", "code": 200, "msg": current_message}
                        await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # return the text data
                    current_message = ""
@@ -684,7 +565,6 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
        try:
            session_content = get_session_content(session_id,redis,db)
            messages = json.loads(session_content["messages"])
            user_info = json.loads(session_content["user_info"])
            current_message = await asyncio.wait_for(asr_result_q.get(),timeout=3)
            messages.append({'role': 'user', "content": current_message})
            payload = json.dumps({
@@ -695,43 +575,34 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
                "temperature": llm_info["temperature"],
                "top_p": llm_info["top_p"]
            })

            headers = {
                'Authorization': f"Bearer {Config.MINIMAX_LLM.API_KEY}",
                'Content-Type': 'application/json'
            }
            target_se = get_emb(session_id,db)
            async with aiohttp.ClientSession() as client:
                async with client.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload) as response:  # send the LLM request
                    async for chunk in response.content.iter_any():
                        token_count, chunk_data = parseChunkDelta(chunk)
                        is_end = token_count >0
            response = requests.post(Config.MINIMAX_LLM.URL, headers=headers, data=payload,stream=True)
            for chunk in response.iter_lines():
                chunk_data = parseChunkDelta(chunk)
                is_end = chunk_data == "end"
                if not is_end:
                    llm_response += chunk_data
                sentences,current_sentence,is_first = split_string_with_punctuation(current_sentence,chunk_data,is_first,is_end)
                for sentence in sentences:
                    audio,sr = tts.synthesize(text=sentence,tts_info=tts_info,target_se=target_se)
                    sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
                    text_response = {"type": "llm_text", "code": 200, "msg": sentence}
                    await ws.send_bytes(audio)  # return the raw audio bytes
                    await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # return the text data
                    logger.debug(f"websocket返回: {sentence}")
                    logger.debug(f"llm返回结果: {sentence}")
                if is_end:
                    logger.debug(f"llm返回结果: {llm_response}")
                    await ws.send_text(json.dumps({"type": "end", "code": 200, "msg": ""}, ensure_ascii=False))
                    is_end = False

                    messages.append({'role': 'assistant', "content": llm_response})
                    session_content["messages"] = json.dumps(messages,ensure_ascii=False)  # update the conversation
                    redis.set(session_id,json.dumps(session_content,ensure_ascii=False))  # update the session
                    is_first = True
                    llm_response = ""
                    if token_count > summarizer.max_token * 0.7:  # if the LLM's token count exceeds 70% of the max token limit, summarize the conversation
                        system_prompt = messages[0]['content']
                        summary = await summarizer.summarize(messages)
                        events = user_info['events']
                        events.append(summary['event'])
                        session_content['messages'] = json.dumps([{'role':'system','content':system_prompt}],ensure_ascii=False)
                        session_content['user_info'] = json.dumps({'character': summary['character'], 'events': events}, ensure_ascii=False)
                        redis.set(session_id,json.dumps(session_content,ensure_ascii=False))
                        logger.debug(f"总结后session_content: {session_content}")
        except asyncio.TimeoutError:
            continue
        except Exception as e:
@@ -739,6 +610,23 @@ async def voice_call_llm_handler(ws,session_id,llm_info,tts_info,db,redis,asr_re
            break
    voice_call_end_event.set()


# speech synthesis and response function
async def voice_call_tts_handler(ws,tts_info,split_result_q,split_finished_event,voice_call_end_event):
    logger.debug("语音合成及返回函数启动")
    while not (split_finished_event.is_set() and split_result_q.empty()):
        try:
            sentence = await asyncio.wait_for(split_result_q.get(),timeout=3)
            sr,audio = tts.synthesize(sentence, tts_info["language"], tts_info["speaker_id"], tts_info["noise_scale"], tts_info["noise_scale_w"], tts_info["length_scale"], return_bytes=True)
            text_response = {"type": "llm_text", "code": 200, "msg": sentence}
            await ws.send_bytes(audio)  # return the raw audio bytes
            await ws.send_text(json.dumps(text_response, ensure_ascii=False))  # return the text data
            logger.debug(f"websocket返回:{sentence}")
        except asyncio.TimeoutError:
            continue
    voice_call_end_event.set()


async def voice_call_handler(ws, db, redis):
    logger.debug("voice_call websocket 连接建立")
    audio_q = asyncio.Queue()  # audio queue
@@ -77,32 +77,3 @@ async def update_session_handler(session_id, session_data:SessionUpdateRequest,
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    session_update_data = SessionUpdateData(updatedAt=datetime.now().isoformat())
    return SessionUpdateResponse(status="success",message="Session 更新成功",data=session_update_data)

# update the Speaker Id stored in the Session
async def update_session_speaker_id_handler(session_id, session_data, db, redis):
    existing_session = ""
    if redis.exists(session_id):
        existing_session = redis.get(session_id)
    else:
        existing_session = db.query(Session).filter(Session.id == session_id).first().content
    if not existing_session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session not found")

    # update the Session fields
    session = json.loads(existing_session)
    session_llm_info = json.loads(session["llm_info"])
    session_llm_info["speaker_id"] = session_data.speaker_id
    session["llm_info"] = json.dumps(session_llm_info,ensure_ascii=False)

    # store the Session
    session_str = json.dumps(session,ensure_ascii=False)
    redis.set(session_id, session_str)
    try:
        db.query(Session).filter(Session.id == session_id).update({"content": session_str})
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    session_update_data = SessionSpeakerUpdateData(updatedAt=datetime.now().isoformat())
    return SessionSpeakerUpdateResponse(status="success",message="Session SpeakID更新成功",data=session_update_data)
@@ -1,20 +1,14 @@
from ..schemas.user_schema import *
from ..dependencies.logger import get_logger
from ..dependencies.tts import get_tts
from ..models import User, Hardware, Audio
from ..models import User, Hardware
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from pydub import AudioSegment
import numpy as np
import io


# get the logger via dependency injection
logger = get_logger()

# get the tts via dependency injection
tts = get_tts("OPENVOICE")

# create a user
async def create_user_handler(user:UserCrateRequest, db: Session):
@@ -42,6 +36,7 @@ async def update_user_handler(user_id:int, user:UserUpdateRequest, db: Session):
    existing_user.persona = user.persona
    try:
        db.commit()
        db.refresh(existing_user)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -122,6 +117,7 @@ async def change_bind_hardware_handler(hardware_id, user, db):
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
        existing_hardware.user_id = user.user_id
        db.commit()
        db.refresh(existing_hardware)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -139,6 +135,7 @@ async def update_hardware_handler(hardware_id, hardware, db):
        existing_hardware.firmware = hardware.firmware
        existing_hardware.model = hardware.model
        db.commit()
        db.refresh(existing_hardware)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -157,96 +154,3 @@ async def get_hardware_handler(hardware_id, db):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="硬件不存在")
    hardware_query_data = HardwareQueryData(mac=existing_hardware.mac, user_id=existing_hardware.user_id, firmware=existing_hardware.firmware, model=existing_hardware.model)
    return HardwareQueryResponse(status="success", message="查询硬件信息成功", data=hardware_query_data)


# user uploads an audio clip
async def upload_audio_handler(user_id, audio, db):
    try:
        audio_data = await audio.read()
        emb_data = tts.audio2emb(np.frombuffer(AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data, dtype=np.int32),rate=44100,vad=True)
        out = io.BytesIO()
        np.save(out, emb_data)
        out.seek(0)
        emb_binary = out.read()
        new_audio = Audio(user_id=user_id, audio_data=audio_data,emb_data=emb_binary)  # create the audio record
        db.add(new_audio)
        db.commit()
        db.refresh(new_audio)
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_upload_data = AudioUploadData(audio_id=new_audio.id, uploadedAt=datetime.now().isoformat())
    return AudioUploadResponse(status="success", message="用户上传音频成功", data=audio_upload_data)


# user updates an audio clip
async def update_audio_handler(audio_id, audio_file, db):
    try:
        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
        if existing_audio is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
        audio_data = await audio_file.read()
        raw_data = AudioSegment.from_file(io.BytesIO(audio_data), format="mp3").raw_data
        emb_data = tts.audio2emb(np.frombuffer(raw_data, dtype=np.int32),rate=44100,vad=True).tobytes()
        existing_audio.audio_data = audio_data
        existing_audio.emb_data = emb_data
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_update_data = AudioUpdateData(updatedAt=datetime.now().isoformat())
    return AudioUpdateResponse(status="success", message="用户更新音频成功", data=audio_update_data)


# user downloads an audio clip
async def download_audio_handler(audio_id, db):
    try:
        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_audio is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
    audio_data = existing_audio.audio_data
    return audio_data


# user deletes an audio clip
async def delete_audio_handler(audio_id, db):
    try:
        existing_audio = db.query(Audio).filter(Audio.id == audio_id).first()
        existing_user = db.query(User).filter(User.selected_audio_id == audio_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_audio is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="音频不存在")
    try:
        if existing_user.selected_audio_id == audio_id:
            existing_user.selected_audio_id = None
        db.delete(existing_audio)
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_delete_data = AudioDeleteData(deletedAt=datetime.now().isoformat())
    return AudioDeleteResponse(status="success", message="用户删除音频成功", data=audio_delete_data)


# user binds an audio clip
async def bind_audio_handler(bind_req, db):
    try:
        existing_user = db.query(User).filter(User.id == bind_req.user_id).first()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    if existing_user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
    try:
        existing_user.selected_audio_id = bind_req.audio_id
        db.commit()
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
    audio_bind_data = AudioBindData(bindedAt=datetime.now().isoformat())
    return AudioBindResponse(status="success", message="用户绑定音频成功", data=audio_bind_data)
@@ -1,11 +0,0 @@
from utils.stt.modified_funasr import ModifiedRecognizer
from app.dependencies.logger import get_logger

logger = get_logger()

# initialize the global asr object
asr = ModifiedRecognizer()
logger.info("ASR初始化成功")

def get_asr():
    return asr
@@ -1,61 +0,0 @@
import aiohttp
import json
from datetime import datetime
from config import get_config


# get Config via dependency injection
Config = get_config()

class Summarizer:
    def __init__(self):
        self.system_prompt = """你是一台对话总结机器,你的职责是整理用户与玩具之间的对话,最终提炼出对话中发生的事件,以及用户性格\n\n你的输出必须为一个json,里面有两个字段,一个是event,一个是character,将你总结出的事件写入event,将你总结出的用户性格写入character\nevent和character均为字符串\n返回示例:{"event":"在幼儿园看葫芦娃,老鹰抓小鸡","character":"活泼可爱"}"""
        self.model = "abab5.5-chat"
        self.max_token = 10000
        self.temperature = 0.9
        self.top_p = 1

    async def summarize(self,messages):
        context = ""
        for message in messages:
            if message['role'] == 'user':
                context += "用户:"+ message['content'] + '\n'
            elif message['role'] == 'assistant':
                context += '玩具:'+ message['content'] + '\n'
        payload = json.dumps({
            "model":self.model,
            "messages":[
                {
                    "role":"system",
                    "content":self.system_prompt
                },
                {
                    "role":"user",
                    "content":context
                }
            ],
            "max_tokens":self.max_token,
            "top_p":self.top_p
        })
        headers = {
            'Authorization': f'Bearer {Config.MINIMAX_LLM.API_KEY}',
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(Config.MINIMAX_LLM.URL, data=payload, headers=headers) as response:
                content = json.loads(json.loads(await response.text())['choices'][0]['message']['content'])
        try:
            summary = {
                'event':datetime.now().strftime("%Y-%m-%d")+":"+content['event'],
                'character':content['character']
            }
        except TypeError:
            summary = {
                'event':datetime.now().strftime("%Y-%m-%d")+":"+content['event'][0],
                'character':""
            }
        return summary

def get_summarizer():
    summarizer = Summarizer()
    return summarizer
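Per the system prompt above, summarize() is expected to return a dict with event and character string fields, the event prefixed with the current date. An illustrative call with made-up message content (a sketch only; in the app the messages come from the session stored in Redis):

```python
async def demo(summarizer):
    # made-up history for illustration
    messages = [
        {"role": "user", "content": "今天在幼儿园看了葫芦娃"},
        {"role": "assistant", "content": "听起来很好玩呀!"},
    ]
    summary = await summarizer.summarize(messages)
    # expected shape: {"event": "2024-05-18:...", "character": "..."}
    print(summary["event"], summary["character"])
```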
@@ -1,23 +0,0 @@

from app.dependencies.logger import get_logger
from config import get_config

logger = get_logger()
Config = get_config()

from utils.tts.openvoice_utils import TextToSpeech
openvoice_tts = TextToSpeech(use_tone_convert=True,device='cuda')
logger.info("TTS_OPENVOICE 初始化成功")

from utils.tts.vits_utils import TextToSpeech
vits_tts = TextToSpeech()
logger.info("TTS_VITS 初始化成功")


# initialize the global tts object
def get_tts(tts_type=Config.TTS_UTILS):
    if tts_type == "OPENVOICE":
        return openvoice_tts
    elif tts_type == "VITS":
        return vits_tts
@@ -0,0 +1,26 @@
from app import app, Config
import uvicorn


if __name__ == "__main__":
    uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
    #uvicorn.run("app.main:app", host=Config.UVICORN.HOST, port=Config.UVICORN.PORT, workers=Config.UVICORN.WORKERS)
# _ooOoo_ #
# o8888888o #
# 88" . "88 #
# (| -_- |) #
# O\ = /O #
# ____/`---'\____ #
# . ' \\| |// `. #
# / \\||| : |||// \ #
# / _||||| -:- |||||- \ #
# | | \\\ - /// | | #
# \ .-\__ `-` ___/-. / #
# ___`. .' /--.--\ `. . __ #
# ."" '< `.___\_<|>_/___.' >'"". #
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
# \ \ `-. \_ __\ /__ _/ .-` / / #
# ======`-.____`-.___\_____/___.-`____.-'====== #
# `=---=' #
# ............................................. #
# 佛祖保佑 永无BUG #
@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR, LargeBinary
from sqlalchemy import Column, Integer, String, JSON, Text, ForeignKey, DateTime, Boolean, CHAR
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

@@ -36,7 +36,6 @@ class User(Base):
    avatar_id = Column(String(36), nullable=True)
    tags = Column(JSON)
    persona = Column(JSON)
    selected_audio_id = Column(Integer, nullable=True)

    def __repr__(self):
        return f"<User(id={self.id}, tags={self.tags})>"

@@ -81,12 +80,3 @@ class Session(Base):

    def __repr__(self):
        return f"<Session(id={self.id}, user_character_id={self.user_character_id})>"

# audio table definition
class Audio(Base):
    __tablename__ = 'audio'
    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(Integer, ForeignKey('user.id'))
    audio_data = Column(LargeBinary(16777215))
    emb_data = Column(LargeBinary(65535))
@@ -27,10 +27,3 @@ async def get_session(session_id: str, db=Depends(get_db), redis=Depends(get_red
async def update_session(session_id: str, session_data: SessionUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
    response = await update_session_handler(session_id, session_data, db, redis)
    return response


# session voice info (speaker id) update endpoint
@router.put("/sessions/tts_info/speaker_id/{session_id}", response_model=SessionSpeakerUpdateResponse)
async def update_session_speaker_id(session_id: str, session_data: SessionSpeakerUpdateRequest, db=Depends(get_db), redis=Depends(get_redis)):
    response = await update_session_speaker_id_handler(session_id, session_data, db, redis)
    return response

@@ -1,4 +1,4 @@
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import APIRouter, HTTPException, status
from ..controllers.user_controller import *
from fastapi import Depends
from sqlalchemy.orm import Session

@@ -69,38 +69,3 @@ async def update_hardware_info(hardware_id: int, hardware: HardwareUpdateRequest
async def get_hardware(hardware_id: int, db: Session = Depends(get_db)):
    response = await get_hardware_handler(hardware_id, db)
    return response


# user audio upload
@router.post('/users/audio',response_model=AudioUploadResponse)
async def upload_audio(user_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
    response = await upload_audio_handler(user_id, audio_file, db)
    return response


# user audio update
@router.put('/users/audio/{audio_id}',response_model=AudioUpdateResponse)
async def update_audio(audio_id:int, audio_file:UploadFile=File(...), db: Session = Depends(get_db)):
    response = await update_audio_handler(audio_id, audio_file, db)
    return response


# user audio download
@router.get('/users/audio/{audio_id}')
async def download_audio(audio_id:int, db: Session = Depends(get_db)):
    audio_data = await download_audio_handler(audio_id, db)
    return Response(content=audio_data,media_type='application/octet-stream',headers={"Content-Disposition": "attachment"})


# user audio delete
@router.delete('/users/audio/{audio_id}',response_model=AudioDeleteResponse)
async def delete_audio(audio_id:int, db: Session = Depends(get_db)):
    response = await delete_audio_handler(audio_id, db)
    return response


# user audio bind
@router.post('/users/audio/bind',response_model=AudioBindResponse)
async def bind_audio(bind_req:AudioBindRequest, db: Session = Depends(get_db)):
    response = await bind_audio_handler(bind_req, db)
    return response
@@ -56,15 +56,3 @@ class SessionUpdateData(BaseModel):
class SessionUpdateResponse(BaseResponse):
    data: Optional[SessionUpdateData]
#--------------------------------------------------------------------------


#------------------------------ Session Speaker Id update ----------------------
class SessionSpeakerUpdateRequest(BaseModel):
    speaker_id: int

class SessionSpeakerUpdateData(BaseModel):
    updatedAt:str

class SessionSpeakerUpdateResponse(BaseResponse):
    data: Optional[SessionSpeakerUpdateData]
#--------------------------------------------------------------------------

@@ -3,6 +3,7 @@ from typing import Optional
from .base_schema import BaseResponse


#--------------------------------- user creation ----------------------------------
# user creation request class
class UserCrateRequest(BaseModel):

@@ -137,46 +138,3 @@ class HardwareQueryData(BaseModel):
class HardwareQueryResponse(BaseResponse):
    data: Optional[HardwareQueryData]
#------------------------------------------------------------------------------


#------------------------------- user audio upload -------------------------------------
class AudioUploadData(BaseModel):
    audio_id: int
    uploadedAt: str

class AudioUploadResponse(BaseResponse):
    data: Optional[AudioUploadData]
#-------------------------------------------------------------------------------


#------------------------------- user audio update -------------------------------------
class AudioUpdateData(BaseModel):
    updatedAt: str

class AudioUpdateResponse(BaseResponse):
    data: Optional[AudioUpdateData]
#-------------------------------------------------------------------------------


#------------------------------- user audio delete -------------------------------------
class AudioDeleteData(BaseModel):
    deletedAt: str

class AudioDeleteResponse(BaseResponse):
    data: Optional[AudioDeleteData]
#-------------------------------------------------------------------------------


#------------------------------- user audio bind -------------------------------------
class AudioBindRequest(BaseModel):
    audio_id: int
    user_id: int

class AudioBindData(BaseModel):
    bindedAt: str

class AudioBindResponse(BaseResponse):
    data: Optional[AudioBindData]
#-------------------------------------------------------------------------------
@@ -1,28 +1,27 @@
class DevelopmentConfig:
    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://takway:takway123456@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
    SQLALCHEMY_DATABASE_URI = f"mysql+pymysql://admin02:LabA100102@127.0.0.1/takway?charset=utf8mb4"  # MySQL connection string
    REDIS_URL = "redis://:takway@127.0.0.1:6379/0"  # Redis connection string
    LOG_LEVEL = "DEBUG"  # log level
    TTS_UTILS = "VITS"  # TTS engine, either OPENVOICE or VITS
    class UVICORN:
        HOST = "0.0.0.0"  # uvicorn bind address; 0.0.0.0 means all interfaces
        PORT = 8001  # uvicorn port
        PORT = 7878  # uvicorn port
        WORKERS = 12  # number of uvicorn workers (usually the CPU core count)
    class XF_ASR:
        APP_ID = "f1c121c1"  # iFlytek ASR APP_ID
        API_SECRET = "NjQwODA5MTA4OTc3YjIyODM2NmVlYWQ0"  # iFlytek ASR API_SECRET
        API_KEY = "36b316c7977fa534ae1e3bf52157bb92"  # iFlytek ASR API_KEY
        APP_ID = "your_app_id"  # iFlytek ASR APP_ID
        API_SECRET = "your_api_secret"  # iFlytek ASR API_SECRET
        API_KEY = "your_api_key"  # iFlytek ASR API_KEY
        DOMAIN = "iat"
        LANGUAGE = "zh_cn"
        ACCENT = "mandarin"
        VAD_EOS = 10000
class MINIMAX_LLM:
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLph5EiLCJVc2VyTmFtZSI6IumHkSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzY4NTM2NDM3MzE1MDgwODg2IiwiUGhvbmUiOiIxMzEzNjE0NzUyNyIsIkdyb3VwSUQiOiIxNzY4NTM2NDM3MzA2NjkyMjc4IiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjQtMDUtMTggMTY6MTQ6MDMiLCJpc3MiOiJtaW5pbWF4In0.LypYOkJXwKV6GzDM1dcNn4L0m19o8Q_Lvmn6SkMMb9WAfDJYxEnTc5odm-L4WAWfbur_gY0cQzgoHnI14t4XSaAvqfmcdCrKYpJbKoBmMse_RogJs7KOBt658je3wES4pBUKQll6NbogQB1f93lnA9IYv4aEVldfqglbCikd54XO8E9Ptn4gX9Mp8fUn3lCpZ6_OSlmgZsQySrmt1sDHHzi3DlkdXlFSI38TQSZIa5RhFpI8WSBLIbaKl84OhaDzo7v99k9DUCzb5JGh0eZOnUT0YswbKCPeV8rZ1XUiOVQrna1uiDLvqv54aIt3vsu-LypYmnHxtZ_z4u2gt87pZg"
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA"
|
||||
URL = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
class MINIMAX_TTA:
|
||||
API_KEY = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiIyMzQ1dm9yIiwiVXNlck5hbWUiOiIyMzQ1dm9yIiwiQWNjb3VudCI6IiIsIlN1YmplY3RJRCI6IjE3NTk0ODIxODAxMDAxNzAyMDgiLCJQaG9uZSI6IjE1MDcyNjQxNTYxIiwiR3JvdXBJRCI6IjE3NTk0ODIxODAwOTU5NzU5MDQiLCJQYWdlTmFtZSI6IiIsIk1haWwiOiIiLCJDcmVhdGVUaW1lIjoiMjAyNC0wNC0xMyAxOTowNDoxNyIsImlzcyI6Im1pbmltYXgifQ.RO_WJMz5T0XlL3F6xB9p015hL3PibCbsr5KqO3aMjBL5hKrf1uIjOICTDZWZoucyJV1suxvFPAd_2Ds2Rv01eCu6GFdai1hUByfp51mOOD0PtaZ5-JKRpRPpLSNpqrNoQteANZz0gdr2_GEGTgTzpbfGbXfRYKrQyeQSvq0zHwqumGPd9gJCre2RavPUmzKRrq9EAaQXtSNhBvVkf5lDlxr8fTAHgbj6MLAJZIvvf4uOZErNrbPylo1Vcy649KxEkc0HCWOZErOieeUQFRkKibnE5Q30CgywqxY2qMjrxGRZ_dtizan_0EZ62nXp-J6jarhcY9le1SqiMu1Cv61TuA",
|
||||
URL = "https://api.minimax.chat/v1/t2a_pro",
|
||||
GROUP_ID ="1759482180095975904"
|
||||
    class STRAM_CHAT:
        ASR = "XF"  # ASR engine, either XF or LOCAL
        ASR = "LOCAL"
        TTS = "LOCAL"
Binary file not shown.
|
@ -0,0 +1,284 @@
|
|||
import os
|
||||
import io
|
||||
import numpy as np
|
||||
import pyaudio
|
||||
import wave
|
||||
import base64
|
||||
"""
|
||||
audio utils for modified_funasr_demo.py
|
||||
"""
|
||||
|
||||
def decode_str2bytes(data):
|
||||
# 将Base64编码的字节串解码为字节串
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64decode(data.encode('utf-8'))
|
||||
|
||||
class BaseAudio:
|
||||
def __init__(self,
|
||||
filename=None,
|
||||
input=False,
|
||||
output=False,
|
||||
CHUNK=1024,
|
||||
FORMAT=pyaudio.paInt16,
|
||||
CHANNELS=1,
|
||||
RATE=16000,
|
||||
input_device_index=None,
|
||||
output_device_index=None,
|
||||
**kwargs):
|
||||
self.CHUNK = CHUNK
|
||||
self.FORMAT = FORMAT
|
||||
self.CHANNELS = CHANNELS
|
||||
self.RATE = RATE
|
||||
self.filename = filename
|
||||
class BaseAudio:
    def __init__(self,
                 filename=None,
                 input=False,
                 output=False,
                 CHUNK=1024,
                 FORMAT=pyaudio.paInt16,
                 CHANNELS=1,
                 RATE=16000,
                 input_device_index=None,
                 output_device_index=None,
                 **kwargs):
        self.CHUNK = CHUNK
        self.FORMAT = FORMAT
        self.CHANNELS = CHANNELS
        self.RATE = RATE
        self.filename = filename
        assert input != output, "input and output cannot be the same, \
            but got input={} and output={}.".format(input, output)
        print("------------------------------------------")
        print(f"{'Input' if input else 'Output'} Audio Initialization: ")
        print(f"CHUNK: {self.CHUNK} \nFORMAT: {self.FORMAT} \nCHANNELS: {self.CHANNELS} \nRATE: {self.RATE} \ninput_device_index: {input_device_index} \noutput_device_index: {output_device_index}")
        print("------------------------------------------")
        self.p = pyaudio.PyAudio()
        self.stream = self.p.open(format=FORMAT,
                                  channels=CHANNELS,
                                  rate=RATE,
                                  input=input,
                                  output=output,
                                  input_device_index=input_device_index,
                                  output_device_index=output_device_index,
                                  **kwargs)

    def load_audio_file(self, wav_file):
        with wave.open(wav_file, 'rb') as wf:
            params = wf.getparams()
            frames = wf.readframes(params.nframes)
            print("Audio file loaded.")
            # Audio Parameters
            # print("Channels:", params.nchannels)
            # print("Sample width:", params.sampwidth)
            # print("Frame rate:", params.framerate)
            # print("Number of frames:", params.nframes)
            # print("Compression type:", params.comptype)
        return frames

    def check_audio_type(self, audio_data, return_type=None):
        assert return_type in ['bytes', 'io', None], \
            "return_type should be 'bytes', 'io' or None."
        if isinstance(audio_data, str):
            if len(audio_data) > 50:
                audio_data = decode_str2bytes(audio_data)
            else:
                assert os.path.isfile(audio_data), \
                    "audio_data should be a file path or a bytes object."
                wf = wave.open(audio_data, 'rb')
                audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype == np.dtype('float32'):
                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
            audio_data = audio_data.tobytes()
        elif isinstance(audio_data, bytes):
            pass
        else:
            raise TypeError(f"audio_data must be bytes, numpy.ndarray or str, \
                but got {type(audio_data)}")

        if return_type is None:
            return audio_data
        return self.write_wave(None, [audio_data], return_type)

    def write_wave(self, filename, frames, return_type='io'):
        """Write audio data to a file."""
        if isinstance(frames, bytes):
            frames = [frames]
        if not isinstance(frames, list):
            raise TypeError("frames should be \
                a list of bytes or a bytes object, \
                but got {}.".format(type(frames)))

        if return_type == 'io':
            if filename is None:
                filename = io.BytesIO()
            if self.filename:
                filename = self.filename
            return self.write_wave_io(filename, frames)
        elif return_type == 'bytes':
            return self.write_wave_bytes(frames)

    def write_wave_io(self, filename, frames):
        """
        Write audio data to a file-like object.

        Args:
            filename: [string or file-like object], file path or file-like object to write
            frames: list of bytes, audio data to write
        """
        wf = wave.open(filename, 'wb')

        # Set the WAV file parameters
        wf.setnchannels(self.CHANNELS)
        wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
        wf.setframerate(self.RATE)
        wf.writeframes(b''.join(frames))
        wf.close()
        if isinstance(filename, io.BytesIO):
            filename.seek(0)  # reset file pointer to beginning
        return filename

    def write_wave_bytes(self, frames):
        """Write audio data to a bytes object."""
        return b''.join(frames)


class BaseRecorder(BaseAudio):
    def __init__(self,
                 input=True,
                 base_chunk_size=None,
                 RATE=16000,
                 **kwargs):
        super().__init__(input=input, RATE=RATE, **kwargs)
        self.base_chunk_size = base_chunk_size
        if base_chunk_size is None:
            self.base_chunk_size = self.CHUNK

    def record(self,
               filename,
               duration=5,
               return_type='io',
               logger=None):
        if logger is not None:
            logger.info("Recording started.")
        else:
            print("Recording started.")
        frames = []
        for i in range(0, int(self.RATE / self.CHUNK * duration)):
            data = self.stream.read(self.CHUNK, exception_on_overflow=False)
            frames.append(data)
        if logger is not None:
            logger.info("Recording stopped.")
        else:
            print("Recording stopped.")
        return self.write_wave(filename, frames, return_type)

    def record_chunk_voice(self,
                           return_type='bytes',
                           CHUNK=None,
                           exception_on_overflow=True,
                           queue=None):
        data = self.stream.read(self.CHUNK if CHUNK is None else CHUNK,
                                exception_on_overflow=exception_on_overflow)
        if return_type is not None:
            return self.write_wave(None, [data], return_type)
        return data
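For reference, a minimal usage sketch of the recorder above; this is not part of the diff, and the default device, rate and chunk size are simply whatever BaseAudio picks:

``` python
# Record five seconds from the default input device; with filename=None and
# return_type='io' the WAV ends up in an in-memory io.BytesIO object.
rec = BaseRecorder(RATE=16000)
wav_io = rec.record(None, duration=5, return_type='io')

# Or pull a single raw chunk (bytes without a WAV header), e.g. to feed a streaming ASR model.
chunk = rec.record_chunk_voice(return_type=None, exception_on_overflow=False)
```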
@ -0,0 +1,39 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from audio_utils import BaseRecorder
from utils.stt.modified_funasr import ModifiedRecognizer


def asr_file_stream(file_path=r'.\assets\example_recording.wav'):
    # Load the audio file
    rec = BaseRecorder()
    data = rec.load_audio_file(file_path)

    # Create the model
    asr = ModifiedRecognizer(use_punct=True, use_emotion=True, use_speaker_ver=True)
    asr.session_signup("test")

    # Register the target speaker
    asr.initialize_speaker(r".\assets\example_recording.wav")

    # Speech recognition
    print("===============================================")
    text_dict = asr.streaming_recognize("test", data, auto_det_end=True)
    print(f"text_dict: {text_dict}")

    if not isinstance(text_dict, str):
        print("".join(text_dict['text']))

    # Emotion recognition
    print("===============================================")
    emotion_dict = asr.recognize_emotion(data)
    print(f"emotion_dict: {emotion_dict}")
    if not isinstance(emotion_dict, str):
        max_index = emotion_dict['scores'].index(max(emotion_dict['scores']))
        print("emotion: " + emotion_dict['labels'][max_index])


asr_file_stream()
@ -0,0 +1 @@
|
|||
Stores the target speaker's voice feature (embedding). To change where it is saved, modify DEFALUT_SAVE_PATH in utils/stt/speaker_ver_utils.
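A minimal sketch of how that embedding is produced and checked, based on the speaker_verfication API shown in the utils/stt/speaker_ver_utils diff further below; the wav file names here are placeholders:

``` python
import os
from utils.stt.speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication

verifier = speaker_verfication(**ERES2NETV2)

# Extract and cache the target speaker's embedding under DEFALUT_SAVE_PATH.
os.makedirs(DEFALUT_SAVE_PATH, exist_ok=True)
emb_path = os.path.join(DEFALUT_SAVE_PATH, "target_speaker.npy")
verifier.wav2embeddings("target_speaker.wav", save_path=emb_path)

# Later, check whether another recording comes from the same speaker;
# verfication() returns the string 'yes' or 'no'.
same_speaker = verifier.verfication(base_emb=emb_path, speaker_2_wav="probe.wav") == 'yes'
print(same_speaker)
```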
|
Binary file not shown.
|
@ -0,0 +1,214 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Importing the dtw module. When using in academic works please cite:\n",
|
||||
" T. Giorgino. Computing and Visualizing Dynamic Time Warping Alignments in R: The dtw Package.\n",
|
||||
" J. Stat. Soft., doi:10.18637/jss.v031.i07.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"sys.path.append(\"../\")\n",
|
||||
"from utils.tts.openvoice_utils import TextToSpeech\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
|
||||
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n",
|
||||
"Building prefix dict from the default dictionary ...\n",
|
||||
"Loading model from cache C:\\Users\\bing\\AppData\\Local\\Temp\\jieba.cache\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"load base tts model successfully!\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Loading model cost 0.304 seconds.\n",
|
||||
"Prefix dict has been built successfully.\n",
|
||||
"Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
|
||||
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||||
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
|
||||
" return F.conv_transpose1d(\n",
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
|
||||
" return F.conv1d(input, weight, bias, self.stride,\n",
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\utils\\weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n",
|
||||
" warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"generate base speech!\n",
|
||||
"**********************,tts sr 44100\n",
|
||||
"audio segment length is [torch.Size([81565])]\n",
|
||||
"True\n",
|
||||
"Loaded checkpoint 'D:\\python\\OpenVoice\\checkpoints_v2\\converter/checkpoint.pth'\n",
|
||||
"missing/unexpected keys: [] []\n",
|
||||
"load tone color converter successfully!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = TextToSpeech(use_tone_convert=True, device=\"cuda\", debug=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 测试用,将mp3转为int32类型的numpy,对齐输入端\n",
|
||||
"from pydub import AudioSegment\n",
|
||||
"import numpy as np\n",
|
||||
"source_audio=r\"D:\\python\\OpenVoice\\resources\\demo_speaker0.mp3\"\n",
|
||||
"audio = AudioSegment.from_file(source_audio, format=\"mp3\")\n",
|
||||
"raw_data = audio.raw_data\n",
|
||||
"audio_array = np.frombuffer(raw_data, dtype=np.int32)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"OpenVoice version: v2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\functional.py:665: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.\n",
|
||||
"Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\SpectralOps.cpp:878.)\n",
|
||||
" return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]\n",
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
|
||||
" return F.conv2d(input, weight, bias, self.stride,\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 获取并设置目标说话人的speaker embedding\n",
|
||||
"# audio_array :输入的音频信号,类型为 np.ndarray\n",
|
||||
"# 获取speaker embedding\n",
|
||||
"target_se = model.audio2emb(audio_array, rate=44100, vad=True)\n",
|
||||
"# 将模型的默认目标说话人embedding设置为 target_se\n",
|
||||
"model.initialize_target_se(target_se)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:797: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
|
||||
" return F.conv_transpose1d(\n",
|
||||
"c:\\Users\\bing\\.conda\\envs\\openVoice\\lib\\site-packages\\torch\\nn\\modules\\conv.py:306: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
|
||||
" return F.conv1d(input, weight, bias, self.stride,\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"generate base speech!\n",
|
||||
"**********************,tts sr 44100\n",
|
||||
"audio segment length is [torch.Size([216378])]\n",
|
||||
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试base_tts,不含音色转换\n",
|
||||
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
|
||||
"audio, sr = model._base_tts(text, speed=1)\n",
|
||||
"audio = model.tensor2numpy(audio)\n",
|
||||
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo_tts.wav\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"generate base speech!\n",
|
||||
"**********************,tts sr 44100\n",
|
||||
"audio segment length is [torch.Size([216378])]\n",
|
||||
"torch.float32\n",
|
||||
"**********************************, convert sr 22050\n",
|
||||
"tone color has been converted!\n",
|
||||
"Audio saved to D:\\python\\OpenVoice\\outputs_v2\\demo.wav\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试整体pipeline,包含音色转换\n",
|
||||
"text = \"你好呀,我不知道该怎么告诉你这件事,但是我真的很需要你。\"\n",
|
||||
"audio_bytes, sr = model.tts(text, speed=1)\n",
|
||||
"audio = np.frombuffer(audio_bytes, dtype=np.int16).flatten()\n",
|
||||
"model.save_audio(audio, sr, r\"D:\\python\\OpenVoice\\outputs_v2\\demo.wav\" )"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "openVoice",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
31
main.py
31
main.py
|
@ -1,24 +1,9 @@
|
|||
from app import app, Config
|
||||
import uvicorn
|
||||
import os
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host=Config.UVICORN.HOST, port=Config.UVICORN.PORT)
|
||||
# _ooOoo_ #
|
||||
# o8888888o #
|
||||
# 88" . "88 #
|
||||
# (| -_- |) #
|
||||
# O\ = /O #
|
||||
# ____/`---'\____ #
|
||||
# . ' \\| |// `. #
|
||||
# / \\||| : |||// \ #
|
||||
# / _||||| -:- |||||- \ #
|
||||
# | | \\\ - /// | | #
|
||||
# \ .-\__ `-` ___/-. / #
|
||||
# ___`. .' /--.--\ `. . __ #
|
||||
# ."" '< `.___\_<|>_/___.' >'"". #
|
||||
# | | : `- \`.;`\ _ /`;.`/ - ` : | | #
|
||||
# \ \ `-. \_ __\ /__ _/ .-` / / #
|
||||
# ======`-.____`-.___\_____/___.-`____.-'====== #
|
||||
# `=---=' #
|
||||
# ............................................. #
|
||||
# 佛祖保佑 永无BUG #
|
||||
if __name__ == '__main__':
|
||||
|
||||
script_path = os.path.join(os.path.dirname(__file__), 'app', 'main.py')
|
||||
|
||||
# 使用exec函数执行脚本
|
||||
with open(script_path, 'r') as file:
|
||||
exec(file.read())
|
||||
|
|
|
@ -7,6 +7,7 @@ redis
requests
websockets
numpy
funasr
jieba
cn2an
unidecode

@ -18,8 +19,3 @@ numba
soundfile
webrtcvad
apscheduler
aiohttp
faster_whisper
whisper_timestamped
modelscope
wavmark
Binary file not shown.
Binary file not shown.
|
@ -1,9 +1,31 @@
|
|||
from tests.unit_test.user_test import user_test
|
||||
from tests.unit_test.character_test import character_test
|
||||
from tests.unit_test.chat_test import chat_test
|
||||
from tests.unit_test.user_test import UserServiceTest
|
||||
from tests.unit_test.character_test import CharacterServiceTest
|
||||
from tests.unit_test.chat_test import ChatServiceTest
|
||||
import asyncio
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_test()
|
||||
character_test()
|
||||
chat_test()
|
||||
user_service_test = UserServiceTest()
|
||||
character_service_test = CharacterServiceTest()
|
||||
chat_service_test = ChatServiceTest()
|
||||
|
||||
user_service_test.test_user_create()
|
||||
user_service_test.test_user_update()
|
||||
user_service_test.test_user_query()
|
||||
user_service_test.test_hardware_bind()
|
||||
user_service_test.test_hardware_unbind()
|
||||
user_service_test.test_user_delete()
|
||||
|
||||
character_service_test.test_character_create()
|
||||
character_service_test.test_character_update()
|
||||
character_service_test.test_character_query()
|
||||
character_service_test.test_character_delete()
|
||||
|
||||
chat_service_test.test_create_chat()
|
||||
chat_service_test.test_session_id_query()
|
||||
chat_service_test.test_session_content_query()
|
||||
chat_service_test.test_session_update()
|
||||
asyncio.run(chat_service_test.test_chat_temporary())
|
||||
asyncio.run(chat_service_test.test_chat_lasting())
|
||||
asyncio.run(chat_service_test.test_voice_call())
|
||||
chat_service_test.test_chat_delete()
|
||||
print("全部测试成功")
|
|
@ -2,7 +2,7 @@ import requests
|
|||
import json
|
||||
|
||||
class CharacterServiceTest:
|
||||
def __init__(self,socket="http://127.0.0.1:8001"):
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
def test_character_create(self):
|
||||
|
@ -66,14 +66,9 @@ class CharacterServiceTest:
|
|||
else:
|
||||
raise Exception("角色删除测试失败")
|
||||
|
||||
|
||||
def character_test():
|
||||
if __name__ == '__main__':
|
||||
character_service_test = CharacterServiceTest()
|
||||
character_service_test.test_character_create()
|
||||
character_service_test.test_character_update()
|
||||
character_service_test.test_character_query()
|
||||
character_service_test.test_character_delete()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
character_test()
|
|
@ -10,7 +10,7 @@ import websockets
|
|||
|
||||
|
||||
class ChatServiceTest:
|
||||
def __init__(self,socket="http://127.0.0.1:7878"):
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
|
||||
|
@ -30,7 +30,6 @@ class ChatServiceTest:
|
|||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("用户创建成功")
|
||||
self.user_id = response.json()['data']['user_id']
|
||||
else:
|
||||
raise Exception("创建聊天时,用户创建失败")
|
||||
|
@ -58,37 +57,6 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("创建聊天时,角色创建失败")
|
||||
|
||||
#上传音频用于音频克隆
|
||||
url = f"{self.socket}/users/audio?user_id={self.user_id}"
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_file_path = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_file_path)
|
||||
mp3_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(mp3_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(mp3_file_path, audio_file, 'audio/mpeg')}
|
||||
response = requests.post(url,files=files)
|
||||
if response.status_code == 200:
|
||||
self.audio_id = response.json()['data']['audio_id']
|
||||
print("音频上传成功")
|
||||
else:
|
||||
raise Exception("音频上传失败")
|
||||
|
||||
#绑定音频
|
||||
url = f"{self.socket}/users/audio/bind"
|
||||
payload = json.dumps({
|
||||
"user_id":self.user_id,
|
||||
"audio_id":self.audio_id
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("音频绑定测试成功")
|
||||
else:
|
||||
raise Exception("音频绑定测试失败")
|
||||
|
||||
|
||||
#创建一个对话
|
||||
url = f"{self.socket}/chats"
|
||||
payload = json.dumps({
|
||||
|
@ -98,7 +66,6 @@ class ChatServiceTest:
|
|||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("对话创建成功")
|
||||
|
@ -134,8 +101,8 @@ class ChatServiceTest:
|
|||
payload = json.dumps({
|
||||
"user_id": self.user_id,
|
||||
"messages": "[{\"role\": \"system\", \"content\": \"我们正在角色扮演对话游戏中,你需要始终保持角色扮演并待在角色设定的情景中,你扮演的角色信息如下:\\n角色名称: 海绵宝宝。\\n角色背景: 厨师,做汉堡\\n角色所处环境: 海绵宝宝住在深海的大菠萝里面\\n角色的常用问候语: 你好啊,海绵宝宝。\\n\\n你需要用简单、通俗易懂的口语化方式进行对话,在没有经过允许的情况下,你需要保持上述角色,不得擅自跳出角色设定。\\n\"}]",
|
||||
"user_info": "{\"character\": \"\", \"events\": [] }",
|
||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2, \"speed\": 1.0}",
|
||||
"user_info": "{}",
|
||||
"tts_info": "{\"language\": 0, \"speaker_id\": 97, \"noise_scale\": 0.1, \"noise_scale_w\": 0.668, \"length_scale\": 1.2}",
|
||||
"llm_info": "{\"model\": \"abab5.5-chat\", \"temperature\": 1, \"top_p\": 0.9}",
|
||||
"token": 0}
|
||||
)
|
||||
|
@ -148,19 +115,6 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("Session更新测试失败")
|
||||
|
||||
def test_session_speakerid_update(self):
|
||||
url = f"{self.socket}/sessions/tts_info/speaker_id/{self.session_id}"
|
||||
payload = json.dumps({
|
||||
"speaker_id" :37
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("Session SpeakerId更新测试成功")
|
||||
else:
|
||||
raise Exception("Session SpeakerId更新测试失败")
|
||||
|
||||
#测试单次聊天
|
||||
async def test_chat_temporary(self):
|
||||
|
@ -195,7 +149,7 @@ class ChatServiceTest:
|
|||
await websocket.send(message)
|
||||
|
||||
|
||||
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/temporary') as websocket:
|
||||
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/temporary') as websocket:
|
||||
chunks = read_wav_file_in_chunks(2048) # 读取PCM文件并生成数据块
|
||||
for chunk in chunks:
|
||||
await send_audio_chunk(websocket, chunk)
|
||||
|
@ -251,7 +205,7 @@ class ChatServiceTest:
|
|||
message = json.dumps(data)
|
||||
await websocket.send(message)
|
||||
|
||||
async with websockets.connect(f'ws://127.0.0.1:7878/chat/streaming/lasting') as websocket:
|
||||
async with websockets.connect(f'ws://114.214.236.207:7878/chat/streaming/lasting') as websocket:
|
||||
#发送第一次
|
||||
chunks = read_wav_file_in_chunks(2048)
|
||||
for chunk in chunks:
|
||||
|
@ -301,7 +255,7 @@ class ChatServiceTest:
|
|||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
file_path = os.path.join(tests_dir, 'assets', 'voice_call.wav')
|
||||
url = f"ws://127.0.0.1:7878/chat/voice_call"
|
||||
url = f"ws://114.214.236.207:7878/chat/voice_call"
|
||||
#发送格式
|
||||
ws_data = {
|
||||
"audio" : "",
|
||||
|
@ -339,6 +293,7 @@ class ChatServiceTest:
|
|||
await asyncio.gather(audio_stream(websocket))
|
||||
|
||||
|
||||
|
||||
#测试删除聊天
|
||||
def test_chat_delete(self):
|
||||
url = f"{self.socket}/chats/{self.user_character_id}"
|
||||
|
@ -348,11 +303,6 @@ class ChatServiceTest:
|
|||
else:
|
||||
raise Exception("聊天删除测试失败")
|
||||
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
raise Exception("音频删除测试失败")
|
||||
|
||||
url = f"{self.socket}/users/{self.user_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code != 200:
|
||||
|
@ -363,18 +313,17 @@ class ChatServiceTest:
|
|||
if response.status_code != 200:
|
||||
raise Exception("角色删除测试失败")
|
||||
|
||||
def chat_test():
|
||||
|
||||
if __name__ == '__main__':
|
||||
chat_service_test = ChatServiceTest()
|
||||
chat_service_test.test_create_chat()
|
||||
chat_service_test.test_session_id_query()
|
||||
chat_service_test.test_session_content_query()
|
||||
chat_service_test.test_session_update()
|
||||
chat_service_test.test_session_speakerid_update()
|
||||
asyncio.run(chat_service_test.test_chat_temporary())
|
||||
asyncio.run(chat_service_test.test_chat_lasting())
|
||||
asyncio.run(chat_service_test.test_voice_call())
|
||||
chat_service_test.test_chat_delete()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
chat_test()
|
||||
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
from chat_test import chat_test
|
||||
import multiprocessing
|
||||
|
||||
if __name__ == '__main__':
|
||||
processes = []
|
||||
for _ in range(2):
|
||||
p = multiprocessing.Process(target=chat_test)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
|
@ -1,11 +1,10 @@
|
|||
import requests
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
|
||||
|
||||
class UserServiceTest:
|
||||
def __init__(self,socket="http://127.0.0.1:7878"):
|
||||
def __init__(self,socket="http://114.214.236.207:7878"):
|
||||
self.socket = socket
|
||||
|
||||
def test_user_create(self):
|
||||
|
@ -67,7 +66,7 @@ class UserServiceTest:
|
|||
mac = "08:00:20:0A:8C:6G"
|
||||
payload = json.dumps({
|
||||
"mac":mac,
|
||||
"user_id":self.id,
|
||||
"user_id":1,
|
||||
"firmware":"v1.0",
|
||||
"model":"香橙派"
|
||||
})
|
||||
|
@ -89,122 +88,12 @@ class UserServiceTest:
|
|||
else:
|
||||
raise Exception("硬件解绑测试失败")
|
||||
|
||||
def test_hardware_bind_change(self):
|
||||
url = f"{self.socket}/users/hardware/{self.hd_id}/bindchange"
|
||||
payload = json.dumps({
|
||||
"user_id" : self.id
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("硬件换绑测试成功")
|
||||
else:
|
||||
raise Exception("硬件换绑测试失败")
|
||||
|
||||
def test_hardware_update(self):
|
||||
url = f"{self.socket}/users/hardware/{self.hd_id}/info"
|
||||
payload = json.dumps({
|
||||
"mac":"08:00:20:0A:8C:6G",
|
||||
"firmware":"v1.0",
|
||||
"model":"香橙派"
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("PUT", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("硬件信息更新测试成功")
|
||||
else:
|
||||
raise Exception("硬件信息更新测试失败")
|
||||
|
||||
def test_hardware_query(self):
|
||||
url = f"{self.socket}/users/hardware/{self.hd_id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("硬件查询测试成功")
|
||||
else:
|
||||
raise Exception("硬件查询测试失败")
|
||||
|
||||
def test_upload_audio(self):
|
||||
url = f"{self.socket}/users/audio?user_id={self.id}"
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(wav_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(wav_file_path,audio_file,'audio/mpeg')}
|
||||
response = requests.post(url, files=files)
|
||||
if response.status_code == 200:
|
||||
self.audio_id = response.json()["data"]['audio_id']
|
||||
print("音频上传测试成功")
|
||||
else:
|
||||
raise Exception("音频上传测试失败")
|
||||
|
||||
def test_update_audio(self):
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
current_dir = os.path.dirname(current_file_path)
|
||||
tests_dir = os.path.dirname(current_dir)
|
||||
wav_file_path = os.path.join(tests_dir, 'assets', 'demo_speaker0.mp3')
|
||||
with open(wav_file_path, 'rb') as audio_file:
|
||||
files = {'audio_file':(wav_file_path,audio_file,'audio/wav')}
|
||||
response = requests.put(url, files=files)
|
||||
if response.status_code == 200:
|
||||
print("音频上传测试成功")
|
||||
else:
|
||||
raise Exception("音频上传测试失败")
|
||||
|
||||
def test_bind_audio(self):
|
||||
url = f"{self.socket}/users/audio/bind"
|
||||
payload = json.dumps({
|
||||
"user_id":self.id,
|
||||
"audio_id":self.audio_id
|
||||
})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
print("音频绑定测试成功")
|
||||
else:
|
||||
raise Exception("音频绑定测试失败")
|
||||
|
||||
def test_audio_download(self):
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
response = requests.request("GET", url)
|
||||
if response.status_code == 200:
|
||||
print("音频下载测试成功")
|
||||
else:
|
||||
raise Exception("音频下载测试失败")
|
||||
|
||||
def test_audio_delete(self):
|
||||
url = f"{self.socket}/users/audio/{self.audio_id}"
|
||||
response = requests.request("DELETE", url)
|
||||
if response.status_code == 200:
|
||||
print("音频删除测试成功")
|
||||
else:
|
||||
raise Exception("音频删除测试失败")
|
||||
|
||||
|
||||
def user_test():
|
||||
if __name__ == '__main__':
|
||||
user_service_test = UserServiceTest()
|
||||
user_service_test.test_user_create()
|
||||
user_service_test.test_user_update()
|
||||
user_service_test.test_user_query()
|
||||
user_service_test.test_hardware_bind()
|
||||
user_service_test.test_hardware_bind_change()
|
||||
user_service_test.test_hardware_update()
|
||||
user_service_test.test_hardware_query()
|
||||
user_service_test.test_hardware_unbind()
|
||||
user_service_test.test_upload_audio()
|
||||
user_service_test.test_update_audio()
|
||||
user_service_test.test_bind_audio()
|
||||
user_service_test.test_audio_download()
|
||||
user_service_test.test_audio_delete()
|
||||
user_service_test.test_user_delete()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
user_test()
|
|
@ -0,0 +1 @@
|
|||
./ses stores the source speaker-embedding (source se) files used by OpenVoice; each embedding is saved as a *.pth file.
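A minimal sketch of how such an embedding would typically be read back, assuming the *.pth files are ordinary torch tensors saved with torch.save; the file name zh.pth is a placeholder:

``` python
import os
import torch

SOURCE_SE_DIR = "utils/assets/ses"  # the directory described above

# Load one base-speaker embedding; move it to the TTS device as needed.
source_se = torch.load(os.path.join(SOURCE_SE_DIR, "zh.pth"), map_location="cpu")
print(source_se.shape)
```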
|
|
@ -29,7 +29,6 @@ class FunAutoSpeechRecognizer(STTBase):
|
|||
**kwargs):
|
||||
super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
|
||||
|
||||
|
||||
self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
|
||||
|
||||
self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
|
||||
|
@ -80,9 +79,9 @@ class FunAutoSpeechRecognizer(STTBase):
|
|||
def _init_asr(self):
|
||||
# 随机初始化一段音频数据
|
||||
init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
|
||||
self.session_signup("init")
|
||||
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id="init")
|
||||
self.session_signout("init")
|
||||
self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
self.audio_cache = {}
|
||||
self.asr_cache = {}
|
||||
# print("init ASR model done.")
|
||||
|
||||
# when chat trying to use asr , sign up
|
||||
|
@ -109,79 +108,6 @@ class FunAutoSpeechRecognizer(STTBase):
|
|||
"""
|
||||
text_dict = dict(text=[], is_end=is_end)
|
||||
|
||||
audio_cache = self.audio_cache[session_id]
|
||||
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
if audio_cache is None:
|
||||
audio_cache = audio_data
|
||||
else:
|
||||
if audio_cache.shape[0] > 0:
|
||||
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
|
||||
|
||||
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
|
||||
self.audio_cache[session_id] = audio_cache
|
||||
return text_dict
|
||||
|
||||
total_chunk_num = int((len(audio_cache)-1)/self.chunk_partial_size)
|
||||
|
||||
if is_end:
|
||||
# if the audio data is the end of a sentence, \
|
||||
# we need to add one more chunk to the end to \
|
||||
# ensure the end of the sentence is recognized correctly.
|
||||
auto_det_end = True
|
||||
|
||||
if auto_det_end:
|
||||
total_chunk_num += 1
|
||||
|
||||
end_idx = None
|
||||
for i in range(total_chunk_num):
|
||||
if auto_det_end:
|
||||
is_end = i == total_chunk_num - 1
|
||||
start_idx = i*self.chunk_partial_size
|
||||
if auto_det_end:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
|
||||
else:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
|
||||
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
|
||||
# t_stamp = time.time()
|
||||
|
||||
speech_chunk = audio_cache[start_idx:end_idx]
|
||||
|
||||
# TODO: exceptions processes
|
||||
# print("i:", i)
|
||||
try:
|
||||
res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back, session_id=session_id)
|
||||
except ValueError as e:
|
||||
print(f"ValueError: {e}")
|
||||
continue
|
||||
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
|
||||
# print(f"each chunk time: {time.time()-t_stamp}")
|
||||
|
||||
if is_end:
|
||||
audio_cache = None
|
||||
else:
|
||||
if end_idx:
|
||||
audio_cache = audio_cache[end_idx:] # cut the processed part from audio_cache
|
||||
text_dict['is_end'] = is_end
|
||||
|
||||
|
||||
self.audio_cache[session_id] = audio_cache
|
||||
return text_dict
|
||||
|
||||
def streaming_recognize_origin(self,
|
||||
session_id,
|
||||
audio_data,
|
||||
is_end=False,
|
||||
auto_det_end=False):
|
||||
"""recognize partial result
|
||||
|
||||
Args:
|
||||
audio_data: bytes or numpy array, partial audio data
|
||||
is_end: bool, whether the audio data is the end of a sentence
|
||||
auto_det_end: bool, whether to automatically detect the end of a audio data
|
||||
"""
|
||||
text_dict = dict(text=[], is_end=is_end)
|
||||
|
||||
audio_cache = self.audio_cache[session_id]
|
||||
asr_cache = self.asr_cache[session_id]
|
||||
|
||||
|
@ -242,3 +168,175 @@ class FunAutoSpeechRecognizer(STTBase):
|
|||
self.audio_cache[session_id] = audio_cache
|
||||
self.asr_cache[session_id] = asr_cache
|
||||
return text_dict
|
||||
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# ####################################################### #
|
||||
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
|
||||
# ####################################################### #
|
||||
# import io
|
||||
# import numpy as np
|
||||
# import base64
|
||||
# import wave
|
||||
# from funasr import AutoModel
|
||||
# from .base_stt import STTBase
|
||||
|
||||
# def decode_str2bytes(data):
|
||||
# # 将Base64编码的字节串解码为字节串
|
||||
# if data is None:
|
||||
# return None
|
||||
# return base64.b64decode(data.encode('utf-8'))
|
||||
|
||||
# class FunAutoSpeechRecognizer(STTBase):
|
||||
# def __init__(self,
|
||||
# model_path="paraformer-zh-streaming",
|
||||
# device="cuda",
|
||||
# RATE=16000,
|
||||
# cfg_path=None,
|
||||
# debug=False,
|
||||
# chunk_ms=480,
|
||||
# encoder_chunk_look_back=4,
|
||||
# decoder_chunk_look_back=1,
|
||||
# **kwargs):
|
||||
# super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
|
||||
|
||||
# self.asr_model = AutoModel(model=model_path, device=device, **kwargs)
|
||||
|
||||
# self.encoder_chunk_look_back = encoder_chunk_look_back #number of chunks to lookback for encoder self-attention
|
||||
# self.decoder_chunk_look_back = decoder_chunk_look_back #number of encoder chunks to lookback for decoder cross-attention
|
||||
|
||||
# #[0, 8, 4] 480ms, [0, 10, 5] 600ms
|
||||
# if chunk_ms == 480:
|
||||
# self.chunk_size = [0, 8, 4]
|
||||
# elif chunk_ms == 600:
|
||||
# self.chunk_size = [0, 10, 5]
|
||||
# else:
|
||||
# raise ValueError("`chunk_ms` should be 480 or 600, and type is int.")
|
||||
# self.chunk_partial_size = self.chunk_size[1] * 960
|
||||
# self.audio_cache = None
|
||||
# self.asr_cache = {}
|
||||
|
||||
|
||||
|
||||
# self._init_asr()
|
||||
|
||||
# def check_audio_type(self, audio_data):
|
||||
# """check audio data type and convert it to bytes if necessary."""
|
||||
# if isinstance(audio_data, bytes):
|
||||
# pass
|
||||
# elif isinstance(audio_data, list):
|
||||
# audio_data = b''.join(audio_data)
|
||||
# elif isinstance(audio_data, str):
|
||||
# audio_data = decode_str2bytes(audio_data)
|
||||
# elif isinstance(audio_data, io.BytesIO):
|
||||
# wf = wave.open(audio_data, 'rb')
|
||||
# audio_data = wf.readframes(wf.getnframes())
|
||||
# elif isinstance(audio_data, np.ndarray):
|
||||
# pass
|
||||
# else:
|
||||
# raise TypeError(f"audio_data must be bytes, list, str, \
|
||||
# io.BytesIO or numpy array, but got {type(audio_data)}")
|
||||
|
||||
# if isinstance(audio_data, bytes):
|
||||
# audio_data = np.frombuffer(audio_data, dtype=np.int16)
|
||||
# elif isinstance(audio_data, np.ndarray):
|
||||
# if audio_data.dtype != np.int16:
|
||||
# audio_data = audio_data.astype(np.int16)
|
||||
# else:
|
||||
# raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
|
||||
# return audio_data
|
||||
|
||||
# def _init_asr(self):
|
||||
# # 随机初始化一段音频数据
|
||||
# init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
|
||||
# self.asr_model.generate(input=init_audio_data, cache=self.asr_cache, is_final=False, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
# self.audio_cache = None
|
||||
# self.asr_cache = {}
|
||||
# # print("init ASR model done.")
|
||||
|
||||
# def recognize(self, audio_data):
|
||||
# """recognize audio data to text"""
|
||||
# audio_data = self.check_audio_type(audio_data)
|
||||
# result = self.asr_model.generate(input=audio_data,
|
||||
# batch_size_s=300,
|
||||
# hotword=self.hotwords)
|
||||
|
||||
# # print(result)
|
||||
# text = ''
|
||||
# for res in result:
|
||||
# text += res['text']
|
||||
# return text
|
||||
|
||||
# def streaming_recognize(self,
|
||||
# audio_data,
|
||||
# is_end=False,
|
||||
# auto_det_end=False):
|
||||
# """recognize partial result
|
||||
|
||||
# Args:
|
||||
# audio_data: bytes or numpy array, partial audio data
|
||||
# is_end: bool, whether the audio data is the end of a sentence
|
||||
# auto_det_end: bool, whether to automatically detect the end of a audio data
|
||||
# """
|
||||
# text_dict = dict(text=[], is_end=is_end)
|
||||
|
||||
# audio_data = self.check_audio_type(audio_data)
|
||||
# if self.audio_cache is None:
|
||||
# self.audio_cache = audio_data
|
||||
# else:
|
||||
# # print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
|
||||
# if self.audio_cache.shape[0] > 0:
|
||||
# self.audio_cache = np.concatenate([self.audio_cache, audio_data], axis=0)
|
||||
|
||||
# if not is_end and self.audio_cache.shape[0] < self.chunk_partial_size:
|
||||
# return text_dict
|
||||
|
||||
# total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
|
||||
|
||||
# if is_end:
|
||||
# # if the audio data is the end of a sentence, \
|
||||
# # we need to add one more chunk to the end to \
|
||||
# # ensure the end of the sentence is recognized correctly.
|
||||
# auto_det_end = True
|
||||
|
||||
# if auto_det_end:
|
||||
# total_chunk_num += 1
|
||||
|
||||
# # print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
|
||||
# end_idx = None
|
||||
# for i in range(total_chunk_num):
|
||||
# if auto_det_end:
|
||||
# is_end = i == total_chunk_num - 1
|
||||
# start_idx = i*self.chunk_partial_size
|
||||
# if auto_det_end:
|
||||
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
|
||||
# else:
|
||||
# end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
|
||||
# # print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
|
||||
# # t_stamp = time.time()
|
||||
|
||||
# speech_chunk = self.audio_cache[start_idx:end_idx]
|
||||
|
||||
# # TODO: exceptions processes
|
||||
# try:
|
||||
# res = self.asr_model.generate(input=speech_chunk, cache=self.asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
# except ValueError as e:
|
||||
# print(f"ValueError: {e}")
|
||||
# continue
|
||||
# text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
|
||||
# # print(f"each chunk time: {time.time()-t_stamp}")
|
||||
|
||||
# if is_end:
|
||||
# self.audio_cache = None
|
||||
# self.asr_cache = {}
|
||||
# else:
|
||||
# if end_idx:
|
||||
# self.audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
|
||||
# text_dict['is_end'] = is_end
|
||||
|
||||
# # print(f"text_dict: {text_dict}")
|
||||
# return text_dict
|
||||
|
||||
|
||||
|
|
@ -1,29 +1,209 @@
|
|||
from .funasr_utils import FunAutoSpeechRecognizer
|
||||
from .punctuation_utils import FUNASR, Punctuation
|
||||
from .punctuation_utils import CTTRANSFORMER, Punctuation
|
||||
from .emotion_utils import FUNASRFINETUNE, Emotion
|
||||
from .speaker_ver_utils import ERES2NETV2, DEFALUT_SAVE_PATH, speaker_verfication
|
||||
import os
|
||||
|
||||
class ModifiedRecognizer():
|
||||
def __init__(self):
|
||||
#增加语音识别模型
|
||||
self.asr_model = FunAutoSpeechRecognizer()
|
||||
import numpy as np
|
||||
class ModifiedRecognizer(FunAutoSpeechRecognizer):
|
||||
def __init__(self,
|
||||
use_punct=True,
|
||||
use_emotion=False,
|
||||
use_speaker_ver=True):
|
||||
|
||||
# 创建基础的 funasr模型,用于语音识别,识别出不带标点的句子
|
||||
super().__init__(
|
||||
model_path="paraformer-zh-streaming",
|
||||
device="cuda",
|
||||
RATE=16000,
|
||||
cfg_path=None,
|
||||
debug=False,
|
||||
chunk_ms=480,
|
||||
encoder_chunk_look_back=4,
|
||||
decoder_chunk_look_back=1)
|
||||
|
||||
# 记录是否具备附加功能
|
||||
self.use_punct = use_punct
|
||||
self.use_emotion = use_emotion
|
||||
self.use_speaker_ver = use_speaker_ver
|
||||
|
||||
# 增加标点模型
|
||||
self.puctuation_model = Punctuation(**FUNASR)
|
||||
if use_punct:
|
||||
self.puctuation_model = Punctuation(**CTTRANSFORMER)
|
||||
|
||||
# 情绪识别模型
|
||||
if use_emotion:
|
||||
self.emotion_model = Emotion(**FUNASRFINETUNE)
|
||||
|
||||
def session_signup(self, session_id):
|
||||
self.asr_model.session_signup(session_id)
|
||||
# 说话人识别模型
|
||||
if use_speaker_ver:
|
||||
self.speaker_ver_model = speaker_verfication(**ERES2NETV2)
|
||||
|
||||
def session_signout(self, session_id):
|
||||
self.asr_model.session_signout(session_id)
|
||||
def initialize_speaker(self, speaker_1_wav):
|
||||
"""
|
||||
用于说话人识别,将输入的音频(speaker_1_wav)设立为目标说话人,并将其特征保存本地
|
||||
"""
|
||||
if not self.use_speaker_ver:
|
||||
raise NotImplementedError("no access")
|
||||
if speaker_1_wav.endswith(".npy"):
|
||||
self.save_speaker_path = speaker_1_wav
|
||||
elif speaker_1_wav.endswith('.wav'):
|
||||
self.save_speaker_path = os.path.join(DEFALUT_SAVE_PATH,
|
||||
os.path.basename(speaker_1_wav).replace(".wav", ".npy"))
|
||||
# self.save_speaker_path = DEFALUT_SAVE_PATH
|
||||
self.speaker_ver_model.wav2embeddings(speaker_1_wav, self.save_speaker_path)
|
||||
else:
|
||||
raise TypeError("only support [.npy] or [.wav].")
|
||||
|
||||
def streaming_recognize(self, session_id, audio_data,is_end=False):
|
||||
return self.asr_model.streaming_recognize(session_id, audio_data,is_end=is_end)
|
||||
|
||||
def punctuation_correction(self, sentence):
|
||||
return self.puctuation_model.process(sentence)
|
||||
def speaker_ver(self, speaker_2_wav):
|
||||
"""
|
||||
用于说话人识别,判断输入音频是否为目标说话人,
|
||||
是返回True,不是返回False
|
||||
"""
|
||||
if not self.use_speaker_ver:
|
||||
raise NotImplementedError("no access")
|
||||
if not hasattr(self, "save_speaker_path"):
|
||||
raise NotImplementedError("please initialize speaker first")
|
||||
|
||||
def emtion_recognition(self, audio):
|
||||
return self.emotion_model.process(audio)
|
||||
# self.speaker_ver_model.verfication 返回值为字符串 'yes' / 'no'
|
||||
return self.speaker_ver_model.verfication(base_emb=self.save_speaker_path,
|
||||
speaker_2_wav=speaker_2_wav) == 'yes'
|
||||
|
||||
|
||||
def recognize(self, audio_data):
|
||||
"""
|
||||
非流式语音识别,返回识别出的文本,返回值类型 str
|
||||
"""
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
|
||||
# 说话人识别
|
||||
if self.use_speaker_ver:
|
||||
if self.speaker_ver_model.verfication(self.save_speaker_path,
|
||||
speaker_2_wav=audio_data) == 'no':
|
||||
return "Other People"
|
||||
|
||||
# 语音识别
|
||||
result = self.asr_model.generate(input=audio_data,
|
||||
batch_size_s=300,
|
||||
hotword=self.hotwords)
|
||||
text = ''
|
||||
for res in result:
|
||||
text += res['text']
|
||||
|
||||
# 添加标点
|
||||
if self.use_punct:
|
||||
text = self.puctuation_model.process(text+'#', append_period=False).replace('#', '')
|
||||
|
||||
return text
|
||||
|
||||
def recognize_emotion(self, audio_data):
|
||||
"""
|
||||
情感识别,返回值为:
|
||||
1. 如果说话人非目标说话人,返回字符串 "Other People"
|
||||
2. 如果说话人为目标说话人,返回字典{"Labels": List[str], "scores": List[int]}
|
||||
"""
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
|
||||
if self.use_speaker_ver:
|
||||
if self.speaker_ver_model.verfication(self.save_speaker_path,
|
||||
speaker_2_wav=audio_data) == 'no':
|
||||
return "Other People"
|
||||
|
||||
if self.use_emotion:
|
||||
return self.emotion_model.process(audio_data)
|
||||
else:
|
||||
raise NotImplementedError("no access")
|
||||
|
||||
def streaming_recognize(self, session_id, audio_data, is_end=False, auto_det_end=False):
|
||||
"""recognize partial result
|
||||
|
||||
Args:
|
||||
audio_data: bytes or numpy array, partial audio data
|
||||
is_end: bool, whether the audio data is the end of a sentence
|
||||
auto_det_end: bool, whether to automatically detect the end of a audio data
|
||||
|
||||
流式语音识别,返回值为:
|
||||
1. 如果说话人非目标说话人,返回字符串 "Other People"
|
||||
2. 如果说话人为目标说话人,返回字典{"test": List[str], "is_end": boolean}
|
||||
"""
|
||||
audio_cache = self.audio_cache[session_id]
|
||||
asr_cache = self.asr_cache[session_id]
|
||||
text_dict = dict(text=[], is_end=is_end)
|
||||
audio_data = self.check_audio_type(audio_data)
|
||||
|
||||
# 说话人识别
|
||||
if self.use_speaker_ver:
|
||||
if self.speaker_ver_model.verfication(self.save_speaker_path,
|
||||
speaker_2_wav=audio_data) == 'no':
|
||||
return "Other People"
|
||||
|
||||
# 语音识别
|
||||
if audio_cache is None:
|
||||
audio_cache = audio_data
|
||||
else:
|
||||
# print(f"audio_data: {audio_data.shape}, audio_cache: {self.audio_cache.shape}")
|
||||
if audio_cache.shape[0] > 0:
|
||||
audio_cache = np.concatenate([audio_cache, audio_data], axis=0)
|
||||
|
||||
if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
|
||||
self.audio_cache[session_id] = audio_cache
|
||||
return text_dict
|
||||
|
||||
total_chunk_num = int((len(self.audio_cache)-1)/self.chunk_partial_size)
|
||||
|
||||
if is_end:
|
||||
# if the audio data is the end of a sentence, \
|
||||
# we need to add one more chunk to the end to \
|
||||
# ensure the end of the sentence is recognized correctly.
|
||||
auto_det_end = True
|
||||
|
||||
if auto_det_end:
|
||||
total_chunk_num += 1
|
||||
|
||||
# print(f"chunk_size: {self.chunk_size}, chunk_stride: {self.chunk_partial_size}, total_chunk_num: {total_chunk_num}, len: {len(self.audio_cache)}")
|
||||
end_idx = None
|
||||
for i in range(total_chunk_num):
|
||||
if auto_det_end:
|
||||
is_end = i == total_chunk_num - 1
|
||||
start_idx = i*self.chunk_partial_size
|
||||
if auto_det_end:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num-1 else -1
|
||||
else:
|
||||
end_idx = (i+1)*self.chunk_partial_size if i < total_chunk_num else -1
|
||||
# print(f"cut part: {start_idx}:{end_idx}, is_end: {is_end}, i: {i}, total_chunk_num: {total_chunk_num}")
|
||||
# t_stamp = time.time()
|
||||
|
||||
speech_chunk = audio_cache[start_idx:end_idx]
|
||||
|
||||
# TODO: exceptions processes
|
||||
try:
|
||||
res = self.asr_model.generate(input=speech_chunk, cache=asr_cache, is_final=is_end, chunk_size=self.chunk_size, encoder_chunk_look_back=self.encoder_chunk_look_back, decoder_chunk_look_back=self.decoder_chunk_look_back)
|
||||
except ValueError as e:
|
||||
print(f"ValueError: {e}")
|
||||
continue
|
||||
|
||||
# 增添标点
|
||||
if self.use_punct:
|
||||
text_dict['text'].append(self.puctuation_model.process(self.text_postprecess(res[0], data_id='text'), cache=text_dict))
|
||||
else:
|
||||
text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))
|
||||
|
||||
|
||||
# print(f"each chunk time: {time.time()-t_stamp}")
|
||||
|
||||
if is_end:
|
||||
audio_cache = None
|
||||
asr_cache = {}
|
||||
else:
|
||||
if end_idx:
|
||||
audio_cache = self.audio_cache[end_idx:] # cut the processed part from audio_cache
|
||||
text_dict['is_end'] = is_end
|
||||
|
||||
if self.use_punct and is_end:
|
||||
text_dict['text'].append(self.puctuation_model.process('#', cache=text_dict).replace('#', ''))
|
||||
|
||||
self.audio_cache[session_id] = audio_cache
|
||||
self.asr_cache[session_id] = asr_cache
|
||||
# print(f"text_dict: {text_dict}")
|
||||
return text_dict
|
|
@ -1,7 +1,7 @@
|
|||
from modelscope.pipelines import pipeline
|
||||
import numpy as np
|
||||
import os
|
||||
import pdb
|
||||
|
||||
ERES2NETV2 = {
|
||||
"task": 'speaker-verification',
|
||||
"model_name": 'damo/speech_eres2netv2_sv_zh-cn_16k-common',
|
||||
|
@ -10,7 +10,7 @@ ERES2NETV2 = {
|
|||
}
|
||||
|
||||
# 保存 embedding 的路径
|
||||
DEFALUT_SAVE_PATH = r".\takway\savePath"
|
||||
DEFALUT_SAVE_PATH = os.path.join(os.path.dirname(os.path.dirname(__name__)), "speaker_embedding")
|
||||
|
||||
class speaker_verfication:
|
||||
def __init__(self,
|
||||
|
@ -26,9 +26,11 @@ class speaker_verfication:
|
|||
device=device)
|
||||
self.save_embeddings = save_embeddings
|
||||
|
||||
def wav2embeddings(self, speaker_1_wav):
|
||||
def wav2embeddings(self, speaker_1_wav, save_path=None):
|
||||
result = self.pipeline([speaker_1_wav], output_emb=True)
|
||||
speaker_1_emb = result['embs'][0]
|
||||
if save_path is not None:
|
||||
np.save(save_path, speaker_1_emb)
|
||||
return speaker_1_emb
|
||||
|
||||
def _verifaction(self, speaker_1_wav, speaker_2_wav, threshold, save_path):
|
||||
|
@ -53,10 +55,19 @@ class speaker_verfication:
|
|||
return "no"
|
||||
|
||||
def verfication(self,
|
||||
base_emb,
|
||||
speaker_emb,
|
||||
threshold=0.333, ):
|
||||
return np.dot(base_emb, speaker_emb) / (np.linalg.norm(base_emb) * np.linalg.norm(speaker_emb)) > threshold
|
||||
base_emb=None,
|
||||
speaker_1_wav=None,
|
||||
speaker_2_wav=None,
|
||||
threshold=0.333,
|
||||
save_path=None):
|
||||
if base_emb is not None and speaker_1_wav is not None:
|
||||
raise ValueError("Only need one of them, base_emb or speaker_1_wav")
|
||||
if base_emb is not None and speaker_2_wav is not None:
|
||||
return self._verifaction_from_embedding(base_emb, speaker_2_wav, threshold)
|
||||
elif speaker_1_wav is not None and speaker_2_wav is not None:
|
||||
return self._verifaction(speaker_1_wav, speaker_2_wav, threshold, save_path)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if __name__ == '__main__':
|
||||
verifier = speaker_verfication(**ERES2NETV2)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
1. Install melo
Refer to https://github.com/myshell-ai/OpenVoice/blob/main/docs/USAGE.md#openvoice-v2
2. Modify the paths in openvoice_utils.py (including SOURCE_SE_DIR / CACHE_PATH / OPENVOICE_TONE_COLOR_CONVERTER.converter_path)
Specifically:
SOURCE_SE_DIR should be changed to utils/assets/ses
OPENVOICE_TONE_COLOR_CONVERTER.converter_path should be changed to the path of the downloaded model weights; download link: https://myshell-public-repo-hosting.s3.amazonaws.com/openvoice/checkpoints_v2_0417.zip
3. Follow examples/tts_demo.ipynb when migrating the code
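For illustration, a sketch of what those settings look like after step 2, mirroring the repo-side values that appear in the openvoice_utils.py diff below; the cache directory is only an example location:

``` python
# openvoice_utils.py (excerpt) -- path configuration after step 2
import os

current_file_path = os.path.abspath(__file__)
utils_dir = os.path.dirname(os.path.dirname(current_file_path))

# base-speaker embeddings shipped in the repo
SOURCE_SE_DIR = os.path.join(utils_dir, 'assets', 'ses')

# scratch directory for intermediate audio files
CACHE_PATH = r"/tmp/openvoice_cache"

# tone color converter weights unpacked from checkpoints_v2_0417.zip
OPENVOICE_TONE_COLOR_CONVERTER = {
    "model_type": "open_voice_converter",
    "converter_path": os.path.join(os.path.dirname(current_file_path), 'openvoice_model'),
}
```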
|
|
@ -140,7 +140,6 @@ def get_se(audio_path, vc_model, target_dir='processed', vad=True):
|
|||
# if os.path.isdir(audio_path):
|
||||
# wavs_folder = audio_path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if vad:
|
||||
wavs_folder = split_audio_vad(audio_path, target_dir=target_dir, audio_name=audio_name)
|
||||
else:
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
{
|
||||
"_version_": "v2",
|
||||
"data": {
|
||||
"sampling_rate": 22050,
|
||||
"filter_length": 1024,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"n_speakers": 0
|
||||
},
|
||||
"model": {
|
||||
"zero_g": true,
|
||||
"inter_channels": 192,
|
||||
"hidden_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": 0.1,
|
||||
"resblock": "1",
|
||||
"resblock_kernel_sizes": [
|
||||
3,
|
||||
7,
|
||||
11
|
||||
],
|
||||
"resblock_dilation_sizes": [
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
]
|
||||
],
|
||||
"upsample_rates": [
|
||||
8,
|
||||
8,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"upsample_initial_channel": 512,
|
||||
"upsample_kernel_sizes": [
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
4
|
||||
],
|
||||
"gin_channels": 256
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
import hashlib
|
||||
from tqdm.auto import tqdm
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
@@ -15,15 +16,11 @@ from .openvoice.api import ToneColorConverter
from .openvoice.mel_processing import spectrogram_torch
# torchaudio
import torchaudio.functional as F

# Directory holding the BASE SPEAKER embeddings (source_se)
current_file_path = os.path.abspath(__file__)
utils_dir = os.path.dirname(os.path.dirname(current_file_path))

SOURCE_SE_DIR = os.path.join(utils_dir,'assets','ses')
SOURCE_SE_DIR = r"D:\python\OpenVoice\checkpoints_v2\base_speakers\ses"

# Directory for cached intermediate files
CACHE_PATH = r"/tmp/openvoice_cache"
CACHE_PATH = r"D:\python\OpenVoice\processed"

OPENVOICE_BASE_TTS={
    "model_type": "open_voice_base_tts",
@@ -31,11 +28,10 @@ OPENVOICE_BASE_TTS={
    "language": "ZH",
}

converter_path = os.path.join(os.path.dirname(current_file_path),'openvoice_model')
OPENVOICE_TONE_COLOR_CONVERTER={
    "model_type": "open_voice_converter",
    # Path to the converter model weights
    "converter_path": converter_path,
    "converter_path": r"D:\python\OpenVoice\checkpoints_v2\converter",
}

class TextToSpeech:
@@ -122,7 +118,6 @@ class TextToSpeech:
        elif isinstance(se, torch.Tensor):
            self.target_se = se.float().to(self.device)

    # Convert audio to numpy
    def audio2numpy(self, audio_data: Union[bytes, np.ndarray]):
        """
        Convert an audio byte stream to a numpy array; a numpy array may also be passed in directly.
@@ -148,14 +143,30 @@ class TextToSpeech:
        return: np.ndarray
        """
        audio_data = self.audio2numpy(audio_data)
        if not os.path.exists(CACHE_PATH):
            os.makedirs(CACHE_PATH)

        from scipy.io import wavfile
        audio_path = os.path.join(CACHE_PATH, "tmp.wav")
        wavfile.write(audio_path, rate=rate, data=audio_data)

        se, _ = se_extractor.get_se(audio_path, self.tone_color_converter, target_dir=CACHE_PATH, vad=False)
        # device = self.tone_color_converter.device
        # version = self.tone_color_converter.version
        # if self.debug:
        #     print("OpenVoice version:", version)

        # audio_name = f"tmp_{version}_{hashlib.sha256(audio_data.tobytes()).hexdigest()[:16].replace('/','_^')}"

        # if vad:
        #     wavs_folder = se_extractor.split_audio_vad(audio_path, target_dir=CACHE_PATH, audio_name=audio_name)
        # else:
        #     wavs_folder = se_extractor.split_audio_whisper(audio_data, target_dir=CACHE_PATH, audio_name=audio_name)

        # audio_segs = glob(f'{wavs_folder}/*.wav')
        # if len(audio_segs) == 0:
        #     raise NotImplementedError('No audio segments found!')
        # # se, _ = se_extractor.get_se(audio_data, self.tone_color_converter, CACHE_PATH, vad=False)
        # se = self.tone_color_converter.extract_se(audio_segs)
        return se.cpu().detach().numpy()

    def tensor2numpy(self, audio_data: torch.Tensor):
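The hunk above replaces the commented-out manual VAD/whisper segmentation with a single call to se_extractor.get_se: the incoming audio is written to a temporary WAV under CACHE_PATH and the tone-color converter extracts the speaker embedding from that file. A standalone sketch of the same flow outside the class; the import path and the converter argument are assumptions, while the wavfile/get_se sequence mirrors the diff:

```python
import os
import numpy as np
from scipy.io import wavfile

def extract_speaker_embedding(audio_data: np.ndarray, rate: int,
                              tone_color_converter,
                              cache_path: str = "/tmp/openvoice_cache") -> np.ndarray:
    # Persist the audio, then let se_extractor read it back and compute the embedding.
    from utils.tts.openvoice import se_extractor  # import path is an assumption
    os.makedirs(cache_path, exist_ok=True)
    audio_path = os.path.join(cache_path, "tmp.wav")
    wavfile.write(audio_path, rate=rate, data=audio_data)
    se, _ = se_extractor.get_se(audio_path, tone_color_converter,
                                target_dir=cache_path, vad=False)
    return se.cpu().detach().numpy()
```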
@@ -164,14 +175,11 @@ class TextToSpeech:
        """
        return audio_data.cpu().detach().float().numpy()

    def numpy2bytes(self, audio_data):
        if isinstance(audio_data, np.ndarray):
            if audio_data.dtype == np.dtype('float32'):
                audio_data = np.int16(audio_data * np.iinfo(np.int16).max)
            audio_data = audio_data.tobytes()
            return audio_data
        else:
            raise TypeError("audio_data must be a numpy array")
    def numpy2bytes(self, audio_data: np.ndarray):
        """
        Convert a numpy array to bytes
        """
        return (audio_data*32768.0).astype(np.int32).tobytes()

    def _base_tts(self,
                  text: str,
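The rewritten numpy2bytes above scales float audio by 32768 and casts to int32 before serializing, whereas the removed branch produced int16 samples via np.iinfo(np.int16). Which width the downstream consumer expects is not visible in this diff; if it expects standard 16-bit PCM, a sketch of that variant would look like this:

```python
import numpy as np

def float_audio_to_pcm16_bytes(audio: np.ndarray) -> bytes:
    # Clip to [-1, 1] and scale to the int16 range before serializing,
    # matching the np.iinfo(np.int16) conversion in the earlier version.
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * np.iinfo(np.int16).max).astype(np.int16).tobytes()
```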
@@ -263,11 +271,6 @@ class TextToSpeech:
        audio: tensor
        sr: sampling rate of the generated audio
        """
        if source_se is not None:
            source_se = torch.tensor(source_se.astype(np.float32)).to(self.device)
        if target_se is not None:
            target_se = torch.tensor(target_se.astype(np.float32)).to(self.device)

        if source_se is None:
            source_se = self.source_se
        if target_se is None:
@@ -293,13 +296,16 @@ class TextToSpeech:
            print("tone color has been converted!")
        return audio, sr

    def synthesize(self,
    def tts(self,
            text: str,
            tts_info,
            sdp_ratio=0.2,
            noise_scale=0.6,
            noise_scale_w=0.8,
            speed=1.0,
            quite=True,

            source_se: Optional[np.ndarray]=None,
            target_se: Optional[np.ndarray]=None,
            sdp_ratio=0.2,
            quite=True,
            tau: float=0.3,
            message: str="default"):
        """
@@ -316,14 +322,15 @@ class TextToSpeech:
        """
        audio, sr = self._base_tts(text,
                                   sdp_ratio=sdp_ratio,
                                   noise_scale=tts_info['noise_scale'],
                                   noise_scale_w=tts_info['noise_scale_w'],
                                   speed=tts_info['speed'],
                                   noise_scale=noise_scale,
                                   noise_scale_w=noise_scale_w,
                                   speed=speed,
                                   quite=quite)
        if self.use_tone_convert and target_se.size>0:
        if self.use_tone_convert:
            tts_sr = self.base_tts_model.hps.data.sampling_rate
            converter_sr = self.tone_color_converter.hps.data.sampling_rate
            audio = F.resample(audio, tts_sr, converter_sr)
            print(audio.dtype)
            audio, sr = self._convert_tone(audio,
                                           source_se=source_se,
                                           target_se=target_se,
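The two hunks above rename synthesize to tts and move noise_scale, noise_scale_w and speed out of the tts_info dict into explicit keyword arguments. A hedged call-site sketch under those assumptions; the constructor arguments and embedding paths are illustrative, not taken from the diff:

```python
import numpy as np

# Assumes an already constructed OpenVoice TextToSpeech instance with tone
# conversion enabled, and precomputed speaker embeddings as numpy arrays.
tts_model = TextToSpeech(use_tone_convert=True)        # constructor args are assumptions
source_se = np.load("utils/assets/ses/zh.npy")         # base-speaker embedding (example path)
target_se = np.load("target_speaker_embedding.npy")    # extracted from reference audio beforehand

audio, sr = tts_model.tts(
    "你好,欢迎使用语音合成。",
    noise_scale=0.6,
    noise_scale_w=0.8,
    speed=1.0,
    source_se=source_se,
    target_se=target_se,
)
```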
@@ -343,4 +350,3 @@ class TextToSpeech:
        """
        sf.write(save_path, audio, sample_rate)
        print(f"Audio saved to {save_path}")
@@ -2,7 +2,6 @@ import os
import numpy as np
import torch
from torch import LongTensor
from typing import Optional
import soundfile as sf
# vits
from .vits import utils, commons
@@ -80,19 +79,19 @@ class TextToSpeech:
        print(f"Synthesis time: {time.time() - start_time} s")
        return audio

    def synthesize(self, text, tts_info, target_se: Optional[np.ndarray]=None, save_audio=False, return_bytes=True):
    def synthesize(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale, save_audio=False, return_bytes=False):
        if not len(text):
            return "输入文本不能为空!", None
        text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
        if len(text) > 100 and self.limitation:
            return f"输入文字过长!{len(text)}>100", None
        text = self._preprocess_text(text, tts_info['language'])
        audio = self._generate_audio(text, tts_info['speaker_id'], tts_info['noise_scale'], tts_info['noise_scale_w'], tts_info['length_scale'])
        text = self._preprocess_text(text, language)
        audio = self._generate_audio(text, speaker_id, noise_scale, noise_scale_w, length_scale)
        if self.debug or save_audio:
            self.save_audio(audio, self.RATE, 'output_file.wav')
        if return_bytes:
            audio = self.convert_numpy_to_bytes(audio)
        return audio, self.RATE
        return self.RATE, audio

    def convert_numpy_to_bytes(self, audio_data):
        if isinstance(audio_data, np.ndarray):
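For the VITS-based class, the diff replaces the tts_info dict with explicit language, speaker_id, noise_scale, noise_scale_w and length_scale parameters and swaps the return order to (rate, audio). A hedged call-site sketch under those assumptions; the instance name and all parameter values below are examples only:

```python
# Assumes a constructed VITS TextToSpeech instance named vits_tts.
rate, audio = vits_tts.synthesize(
    "今天天气真不错。",
    language="ZH",        # value is illustrative; accepted codes are not shown in this diff
    speaker_id=0,
    noise_scale=0.667,
    noise_scale_w=0.8,
    length_scale=1.0,
    return_bytes=False,
)
```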
@@ -1,9 +1,7 @@
import websockets
import datetime
import hashlib
import base64
import hmac
import json
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
from datetime import datetime
@@ -37,33 +35,3 @@ def generate_xf_asr_url():
    }
    url = url + '?' + urlencode(v)
    return url


def make_first_frame(buf):
    first_frame = {"common": {"app_id": Config.XF_ASR.APP_ID}, "business": {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos": 10000},
                   "data": {"status": 0, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(first_frame)

def make_continue_frame(buf):
    continue_frame = {"data": {"status": 1, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(continue_frame)

def make_last_frame(buf):
    last_frame = {"data": {"status": 2, "format": "audio/L16;rate=16000", "audio": buf, "encoding": "raw"}}
    return json.dumps(last_frame)

def parse_xfasr_recv(message):
    code = message['code']
    if code != 0:
        raise Exception("讯飞ASR错误码:" + str(code))
    else:
        data = message['data']['result']['ws']
        result = ""
        for i in data:
            for w in i['cw']:
                result += w['w']
        return result

async def xf_asr_websocket_factory():
    url = generate_xf_asr_url()
    return await websockets.connect(url)
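Taken together, these helpers cover iFlytek's streaming protocol: a signed websocket URL, a status-0 first frame, status-1 continuation frames, a status-2 last frame, and response parsing that concatenates the cw/w fields. A hedged end-to-end sketch using only the functions shown above; the chunk size, the base64 encoding of each chunk, the empty last frame, and the assumption that the server closes the connection after the final result are implementation guesses rather than facts from this diff:

```python
import base64
import json

async def transcribe(pcm_bytes: bytes, chunk_size: int = 1280) -> str:
    # Stream 16 kHz / 16-bit mono PCM to the ASR service and collect the text.
    ws = await xf_asr_websocket_factory()
    chunks = [pcm_bytes[i:i + chunk_size] for i in range(0, len(pcm_bytes), chunk_size)]
    result = ""
    try:
        for idx, chunk in enumerate(chunks):
            buf = base64.b64encode(chunk).decode()
            frame = make_first_frame(buf) if idx == 0 else make_continue_frame(buf)
            await ws.send(frame)
        await ws.send(make_last_frame(""))  # status-2 frame marks end of audio
        async for message in ws:
            result += parse_xfasr_recv(json.loads(message))
    finally:
        await ws.close()
    return result
```

It could be driven with, for example, asyncio.run(transcribe(raw_pcm)) in an environment where Config.XF_ASR credentials are configured; a production client would interleave sending and receiving rather than send all frames first.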