#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import wave
import io
import base64


def decode_str2bytes(data):
    """Decode a Base64-encoded string into raw bytes.

    Args:
        data: Base64 text, or None.

    Returns:
        The decoded bytes, or None when *data* is None.
    """
    if data is None:
        return None
    return base64.b64decode(data.encode('utf-8'))


class STTBase:
    """Base class for speech-to-text back ends.

    Holds the sample rate, a debug flag, the parsed JSON configuration and
    the hotword string; ``recognize`` is left for subclasses to implement.
    """

    def __init__(self, RATE=16000, cfg_path=None, debug=False):
        """Store the settings and load the optional JSON configuration.

        Args:
            RATE: Sample rate value stored on the instance.
            cfg_path: Path to a JSON configuration file, or None.
            debug: Debug flag stored on the instance.
        """
        self.RATE = RATE
        self.debug = debug
        self.asr_cfg = self.parse_json(cfg_path)

    def parse_json(self, cfg_path):
        """Load the JSON configuration and pick up its ``hot_words`` entry.

        Side effect: ``self.hotwords`` is reset to None and, when a file is
        given, replaced by the value stored under the ``hot_words`` key.

        Args:
            cfg_path: Path to a UTF-8 JSON file, or None.

        Returns:
            The parsed configuration, or None when *cfg_path* is None.
        """
        cfg = None
        self.hotwords = None
        if cfg_path is not None:
            with open(cfg_path, 'r', encoding='utf-8') as f:
                cfg = json.load(f)
            self.hotwords = cfg.get('hot_words', None)
        return cfg

    def add_hotword(self, hotword):
        """Append one or more hotwords to the space-separated hotword string.

        Args:
            hotword: A single hotword (str) or a list/tuple of hotwords.

        Raises:
            TypeError: If *hotword* is not a str, list or tuple.
        """
        # Validate first so a rejected argument leaves self.hotwords untouched.
        if isinstance(hotword, str):
            # Append the argument itself, not a fixed placeholder string.
            addition = hotword
        elif isinstance(hotword, (list, tuple)):
            # Flatten the sequence into one space-separated string.
            addition = " ".join(hotword)
        else:
            raise TypeError("hotword must be str or list")

        if self.hotwords:
            self.hotwords = self.hotwords + " " + addition
        else:
            # Nothing stored yet: avoid a leading separator.
            self.hotwords = addition

    def check_audio_type(self, audio_data):
        """Normalize the supported audio containers to a bytes object.

        Args:
            audio_data: bytes (returned unchanged), a list of bytes chunks
                (concatenated), a Base64 str (decoded), or an io.BytesIO
                holding WAV data (all frames are read).

        Returns:
            The audio payload as bytes.

        Raises:
            TypeError: If *audio_data* is none of the supported types.
        """
        if isinstance(audio_data, bytes):
            pass
        elif isinstance(audio_data, list):
            audio_data = b''.join(audio_data)
        elif isinstance(audio_data, str):
            audio_data = decode_str2bytes(audio_data)
        elif isinstance(audio_data, io.BytesIO):
            # Close the wave reader deterministically; the caller's BytesIO
            # stays open because the wave module did not open it.
            with wave.open(audio_data, 'rb') as wf:
                audio_data = wf.readframes(wf.getnframes())
        else:
            raise TypeError(f"audio_data must be bytes, str or io.BytesIO, but got {type(audio_data)}")
        return audio_data

    def text_postprecess(self, result, data_id='text'):
        """Extract the recognized text from *result* and remove all spaces.

        Args:
            result: Mapping holding the recognized text.
            data_id: Key under which the text (str or list of str) is stored.

        Returns:
            The text with list pieces concatenated and spaces removed.
        """
        text = result[data_id]
        if isinstance(text, list):
            text = ''.join(text)
        return text.replace(' ', '')

    def recognize(self, audio_data, queue=None):
        """Recognize audio data to text.

        Placeholder in the base class: does nothing and returns None.
        """
        pass