#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ####################################################### #
# TakwayPlatform/utils/stt/funasr_utils.py
# FunAutoSpeechRecognizer: https://github.com/alibaba-damo-academy/FunASR
# ####################################################### #
import io
import numpy as np
import base64
import wave
from funasr import AutoModel
from .base_stt import STTBase


def decode_str2bytes(data):
    """Decode a Base64-encoded string back into raw bytes."""
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))


class FunAutoSpeechRecognizer(STTBase):
    def __init__(self,
                 model_path="paraformer-zh-streaming",
                 device="cuda",
                 RATE=16000,
                 cfg_path=None,
                 debug=False,
                 chunk_ms=480,
                 encoder_chunk_look_back=4,
                 decoder_chunk_look_back=1,
                 **kwargs):
        super().__init__(RATE=RATE, cfg_path=cfg_path, debug=debug)
        self.asr_model = AutoModel(model=model_path, device=device, **kwargs)

        # number of chunks to look back for encoder self-attention
        self.encoder_chunk_look_back = encoder_chunk_look_back
        # number of encoder chunks to look back for decoder cross-attention
        self.decoder_chunk_look_back = decoder_chunk_look_back

        # FunASR streaming chunk_size convention: [0, 8, 4] -> 480 ms, [0, 10, 5] -> 600 ms
        if chunk_ms == 480:
            self.chunk_size = [0, 8, 4]
        elif chunk_ms == 600:
            self.chunk_size = [0, 10, 5]
        else:
            raise ValueError("`chunk_ms` must be an int equal to 480 or 600.")
        # one chunk frame covers 60 ms of 16 kHz audio, i.e. 960 samples
        self.chunk_partial_size = self.chunk_size[1] * 960

        # per-session caches, keyed by session_id
        self.audio_cache = {}
        self.asr_cache = {}

        self._init_asr()

    def check_audio_type(self, audio_data):
        """Check the audio data type and convert it to an int16 numpy array."""
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            wf = wave.open(audio_data, 'rb')
            audio_data = wf.readframes(wf.getnframes())
        elif isinstance(audio_data, np.ndarray):
            pass
        else:
            raise TypeError("audio_data must be bytes, list, str, io.BytesIO "
                            f"or numpy array, but got {type(audio_data)}")
        if isinstance(audio_data, bytes):
            audio_data = np.frombuffer(audio_data, dtype=np.int16)
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                audio_data = audio_data.astype(np.int16)
        else:
            raise TypeError(f"audio_data must be bytes or numpy array, but got {type(audio_data)}")
        return audio_data

    def _init_asr(self):
        # warm up the streaming model with a chunk of random audio data
        init_audio_data = np.random.randint(-32768, 32767, size=self.chunk_partial_size, dtype=np.int16)
        self.asr_model.generate(input=init_audio_data,
                                cache=self.asr_cache,
                                is_final=False,
                                chunk_size=self.chunk_size,
                                encoder_chunk_look_back=self.encoder_chunk_look_back,
                                decoder_chunk_look_back=self.decoder_chunk_look_back)
        self.audio_cache = {}
        self.asr_cache = {}
        # print("init ASR model done.")

    # called when a chat session starts using ASR ("sign up")
    def session_signup(self, session_id):
        self.audio_cache[session_id] = None
        self.asr_cache[session_id] = {}

    # called when a chat session is done using ASR ("sign out")
    def session_signout(self, session_id):
        del self.audio_cache[session_id]
        del self.asr_cache[session_id]

    def streaming_recognize(self,
                            session_id,
                            audio_data,
                            is_end=False,
                            auto_det_end=False):
        """Recognize a partial (streaming) result for one session.

        Args:
            session_id: key of the session whose audio/ASR caches are used
            audio_data: bytes or numpy array, partial audio data
            is_end: bool, whether the audio data is the end of an utterance
            auto_det_end: bool, whether to automatically detect the end of the audio data
        """
        text_dict = dict(text=[], is_end=is_end)
        audio_cache = self.audio_cache[session_id]
        asr_cache = self.asr_cache[session_id]

        audio_data = self.check_audio_type(audio_data)
        if audio_cache is None:
            audio_cache = audio_data
        else:
            if audio_cache.shape[0] > 0:
                audio_cache = np.concatenate([audio_cache, audio_data], axis=0)

        # not enough audio buffered for a full chunk yet
        if not is_end and audio_cache.shape[0] < self.chunk_partial_size:
            self.audio_cache[session_id] = audio_cache
            return text_dict

        total_chunk_num = int((len(audio_cache) - 1) / self.chunk_partial_size)

        if is_end:
            # if the audio data is the end of an utterance, add one more chunk
            # so that the tail of the utterance is recognized correctly
            auto_det_end = True
        if auto_det_end:
            total_chunk_num += 1

        end_idx = None
        for i in range(total_chunk_num):
            if auto_det_end:
                is_end = i == total_chunk_num - 1
            start_idx = i * self.chunk_partial_size
            if auto_det_end:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num - 1 else -1
            else:
                end_idx = (i + 1) * self.chunk_partial_size if i < total_chunk_num else -1
            speech_chunk = audio_cache[start_idx:end_idx]

            # TODO: exception handling
            try:
                res = self.asr_model.generate(input=speech_chunk,
                                              cache=asr_cache,
                                              is_final=is_end,
                                              chunk_size=self.chunk_size,
                                              encoder_chunk_look_back=self.encoder_chunk_look_back,
                                              decoder_chunk_look_back=self.decoder_chunk_look_back)
            except ValueError as e:
                print(f"ValueError: {e}")
                continue
            text_dict['text'].append(self.text_postprecess(res[0], data_id='text'))

        if is_end:
            audio_cache = None
            asr_cache = {}
        else:
            if end_idx:
                # drop the part of the buffer that has already been processed
                audio_cache = audio_cache[end_idx:]
        text_dict['is_end'] = is_end

        self.audio_cache[session_id] = audio_cache
        self.asr_cache[session_id] = asr_cache
        return text_dict
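

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): it assumes
    # FunASR's AutoModel can fetch the `paraformer-zh-streaming` model and that a
    # 16 kHz, 16-bit mono PCM WAV file (default name `example.wav`) is available.
    import sys

    recognizer = FunAutoSpeechRecognizer(model_path="paraformer-zh-streaming", device="cpu")
    session_id = "demo-session"
    recognizer.session_signup(session_id)

    wav_path = sys.argv[1] if len(sys.argv) > 1 else "example.wav"
    with wave.open(wav_path, 'rb') as wf:
        pcm = wf.readframes(wf.getnframes())

    # feed the audio in chunk-sized slices to emulate a streaming client
    step = recognizer.chunk_partial_size * 2  # bytes per chunk (int16 -> 2 bytes/sample)
    for offset in range(0, len(pcm), step):
        chunk = pcm[offset:offset + step]
        is_end = offset + step >= len(pcm)
        result = recognizer.streaming_recognize(session_id, chunk, is_end=is_end)
        if result['text']:
            print(result)

    recognizer.session_signout(session_id)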