pun_emo_speaker_utils/takway/stt/base_stt.py

66 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import wave
import io
import os
import logging
from ..common_utils import decode_str2bytes
class STTBase:
    """Common base for speech-to-text (STT) recognizers.

    Stores the sample rate, a debug flag, an optional JSON config and a
    space-separated hot-word string. Also provides helpers that normalise
    audio input to ``bytes`` and post-process recognition results.
    """

    def __init__(self, RATE=16000, cfg_path=None, debug=False):
        """Initialise the recognizer.

        Args:
            RATE: audio sample rate, stored as ``self.RATE`` (presumably Hz;
                the default is 16000).
            cfg_path: optional path to a UTF-8 JSON config file. ``None``
                skips loading and logs a warning.
            debug: debug flag, stored as ``self.debug``.
        """
        self.RATE = RATE
        self.debug = debug
        self.asr_cfg = self.parse_json(cfg_path)

    def parse_json(self, cfg_path):
        """Load the JSON config at *cfg_path* and pick up its hot words.

        Side effect: sets ``self.hotwords`` to the config's ``'hot_words'``
        entry. It is ``None`` when that key is missing or no path is given.

        Args:
            cfg_path: path to a UTF-8 JSON file, or ``None``.

        Returns:
            The parsed JSON value, or ``None`` when *cfg_path* is ``None``.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        self.hotwords = None
        if cfg_path is None:
            logging.warning("No STT config file provided, using default config.")
            return None
        with open(cfg_path, 'r', encoding='utf-8') as f:
            cfg = json.load(f)
        # NOTE(review): assumes the config's top-level JSON value is an object.
        self.hotwords = cfg.get('hot_words', None)
        logging.info("load STT config file: %s", cfg_path)
        logging.info("Hot words: %s", self.hotwords)
        return cfg

    def add_hotword(self, hotword):
        """Append one or more hot words to ``self.hotwords``.

        The hot words are kept as a single space-separated string.

        Args:
            hotword: a single ``str``, or a list/tuple of ``str`` whose items
                are joined with spaces.

        Raises:
            TypeError: if *hotword* is not a str, list or tuple. In that case
                ``self.hotwords`` is left unchanged.
        """
        if isinstance(hotword, str):
            new_words = hotword
        elif isinstance(hotword, (list, tuple)):
            # Convert the hot-word sequence to a str separated by spaces.
            new_words = " ".join(hotword)
        else:
            raise TypeError("hotword must be str or list")

        existing = self.hotwords
        if isinstance(existing, (list, tuple)):
            # NOTE(review): 'hot_words' loaded from JSON might be a list;
            # normalise it to the same space-separated string form.
            existing = " ".join(existing)
        # Join only non-empty pieces, so there is no leading or trailing
        # separator when starting from None/"" or when adding an empty list.
        self.hotwords = " ".join(part for part in (existing, new_words) if part)

    def check_audio_type(self, audio_data):
        """Normalise *audio_data* to raw ``bytes``.

        Accepted inputs:
            * ``bytes``: returned unchanged.
            * ``bytearray`` / ``memoryview``: copied into ``bytes``.
            * ``list`` of bytes-like chunks: concatenated.
            * ``str``: decoded with ``decode_str2bytes`` (encoding defined in
              ``common_utils``; presumably base64, so verify there).
            * ``io.BytesIO`` holding a WAV file: all PCM frames are read.
              The stream itself is not closed.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(audio_data, bytes):
            return audio_data
        if isinstance(audio_data, (bytearray, memoryview)):
            return bytes(audio_data)
        if isinstance(audio_data, list):
            return b''.join(audio_data)
        if isinstance(audio_data, str):
            return decode_str2bytes(audio_data)
        if isinstance(audio_data, io.BytesIO):
            # The context manager releases the Wave_read. wave only closes
            # file objects it opened itself, so the caller's BytesIO stays open.
            with wave.open(audio_data, 'rb') as wf:
                return wf.readframes(wf.getnframes())
        raise TypeError(
            f"audio_data must be bytes, bytearray, memoryview, list, str or "
            f"io.BytesIO, but got {type(audio_data)}"
        )

    def text_postprecess(self, result, data_id='text'):
        """Post-process a recognition result into a single string.

        (The method name's misspelling is kept for caller compatibility.)

        Args:
            result: mapping that holds the recognized text under *data_id*.
            data_id: key to read from *result*.

        Returns:
            The text with every ASCII space removed. A list value is
            concatenated first.

        Raises:
            KeyError: if *data_id* is not in *result*.
        """
        text = result[data_id]
        if isinstance(text, list):
            text = ''.join(text)
        return text.replace(' ', '')

    def recognize(self, audio_data, queue=None):
        """Recognize *audio_data* and return its text.

        The base implementation does nothing and returns ``None``. It is
        presumably meant to be overridden by concrete engines.

        Args:
            audio_data: audio in any form accepted by ``check_audio_type``
                (assumption; TODO confirm in subclasses).
            queue: optional output channel. Its use is up to subclasses and
                it is unused here.
        """
        return None