66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
import json
|
||
import wave
|
||
import io
|
||
import os
|
||
import logging
|
||
from ..common_utils import decode_str2bytes
|
||
|
||
class STTBase:
|
||
def __init__(self, RATE=16000, cfg_path=None, debug=False):
|
||
self.RATE = RATE
|
||
self.debug = debug
|
||
self.asr_cfg = self.parse_json(cfg_path)
|
||
|
||
def parse_json(self, cfg_path):
|
||
cfg = None
|
||
self.hotwords = None
|
||
if cfg_path is not None:
|
||
with open(cfg_path, 'r', encoding='utf-8') as f:
|
||
cfg = json.load(f)
|
||
self.hotwords = cfg.get('hot_words', None)
|
||
logging.info(f"load STT config file: {cfg_path}")
|
||
logging.info(f"Hot words: {self.hotwords}")
|
||
else:
|
||
logging.warning("No STT config file provided, using default config.")
|
||
return cfg
|
||
|
||
def add_hotword(self, hotword):
|
||
"""add hotword to list"""
|
||
if self.hotwords is None:
|
||
self.hotwords = ""
|
||
if isinstance(hotword, str):
|
||
self.hotwords = self.hotwords + " " + "hotword"
|
||
elif isinstance(hotword, (list, tuple)):
|
||
# 将hotwords转换为str,并用空格隔开
|
||
self.hotwords = self.hotwords + " " + " ".join(hotword)
|
||
else:
|
||
raise TypeError("hotword must be str or list")
|
||
|
||
def check_audio_type(self, audio_data):
|
||
"""check audio data type and convert it to bytes if necessary."""
|
||
if isinstance(audio_data, bytes):
|
||
pass
|
||
elif isinstance(audio_data, list):
|
||
audio_data = b''.join(audio_data)
|
||
elif isinstance(audio_data, str):
|
||
audio_data = decode_str2bytes(audio_data)
|
||
elif isinstance(audio_data, io.BytesIO):
|
||
wf = wave.open(audio_data, 'rb')
|
||
audio_data = wf.readframes(wf.getnframes())
|
||
else:
|
||
raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
|
||
return audio_data
|
||
|
||
def text_postprecess(self, result, data_id='text'):
|
||
"""postprecess recognized result."""
|
||
text = result[data_id]
|
||
if isinstance(text, list):
|
||
text = ''.join(text)
|
||
return text.replace(' ', '')
|
||
|
||
def recognize(self, audio_data, queue=None):
|
||
"""recognize audio data to text"""
|
||
pass
|