TakwayPlatform/utils/stt/base_stt.py

67 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import wave
import io
import base64
def decode_str2bytes(data):
    """Turn a Base64-encoded text string into the raw bytes it encodes.

    The string is UTF-8 encoded before being handed to ``base64.b64decode``.
    ``None`` is passed straight through so a missing payload stays missing.
    """
    return None if data is None else base64.b64decode(data.encode('utf-8'))
class STTBase:
    """Common base class for speech-to-text (STT) engines.

    Stores the sample rate, a debug flag, an optional JSON config and a
    space-separated hotword string. It also provides helpers to normalize
    audio input and post-process recognition results. Concrete engines
    override :meth:`recognize`.
    """

    def __init__(self, RATE=16000, cfg_path=None, debug=False):
        """Initialize the engine.

        Args:
            RATE: Sample rate, stored as ``self.RATE`` (presumably Hz —
                confirm against subclasses).
            cfg_path: Optional path to a UTF-8 JSON config file.
            debug: Debug flag, stored as ``self.debug``.
        """
        self.RATE = RATE
        self.debug = debug
        self.asr_cfg = self.parse_json(cfg_path)

    def parse_json(self, cfg_path):
        """Load the JSON config at *cfg_path* and (re)initialize hotwords.

        Side effect: ``self.hotwords`` is reset. It is set to the config's
        ``'hot_words'`` entry, or ``None`` when there is no config or no
        such key.

        Returns:
            The parsed config, or ``None`` when *cfg_path* is ``None``.
        """
        self.hotwords = None
        if cfg_path is None:
            return None
        with open(cfg_path, 'r', encoding='utf-8') as f:
            cfg = json.load(f)
        self.hotwords = cfg.get('hot_words', None)
        return cfg

    def add_hotword(self, hotword):
        """Append one or more hotwords to ``self.hotwords``.

        Hotwords are kept as a single space-separated string.

        Args:
            hotword: A single hotword (``str``) or a list/tuple of hotword
                strings.

        Raises:
            TypeError: If *hotword* is neither a str nor a list/tuple. In
                that case ``self.hotwords`` is left unchanged.
        """
        if isinstance(hotword, str):
            new_words = hotword
        elif isinstance(hotword, (list, tuple)):
            # Convert the sequence to one string, separated by spaces.
            new_words = " ".join(hotword)
        else:
            raise TypeError("hotword must be str or list")
        # Only join non-empty parts, so the first addition does not produce
        # a leading space and an empty addition does not add a trailing one.
        parts = [part for part in (self.hotwords, new_words) if part]
        self.hotwords = " ".join(parts)

    def check_audio_type(self, audio_data):
        """Normalize *audio_data* to raw ``bytes``.

        Accepted inputs:
            * ``bytes``: returned unchanged.
            * ``list`` of ``bytes`` chunks: concatenated.
            * ``str``: treated as Base64 text and decoded with
              ``decode_str2bytes``.
            * ``io.BytesIO`` holding a WAV file: all frames are read, so
              the WAV header is not included.

        Raises:
            TypeError: For any other input type.
        """
        if isinstance(audio_data, bytes):
            return audio_data
        if isinstance(audio_data, list):
            return b''.join(audio_data)
        if isinstance(audio_data, str):
            return decode_str2bytes(audio_data)
        if isinstance(audio_data, io.BytesIO):
            # Close the wave reader when done. wave never closes a file
            # object supplied by the caller, so the BytesIO itself stays open.
            with wave.open(audio_data, 'rb') as wf:
                return wf.readframes(wf.getnframes())
        raise TypeError(
            f"audio_data must be bytes, list, str or io.BytesIO, but got {type(audio_data)}")

    def text_postprecess(self, result, data_id='text'):
        """Extract and clean the recognized text from *result*.

        The misspelled method name is kept for backward compatibility.

        Args:
            result: Mapping that holds the recognition output under
                *data_id*.
            data_id: Key to read from *result*.

        Returns:
            The text with every ``' '`` removed. If the value is a list of
            fragments, they are concatenated first.
        """
        text = result[data_id]
        if isinstance(text, list):
            text = ''.join(text)
        # All spaces are stripped. This presumably suits CJK output, where
        # spaces between characters are noise; confirm for other languages.
        return text.replace(' ', '')

    def recognize(self, audio_data, queue=None):
        """Recognize *audio_data* and return the text.

        The base implementation does nothing and returns ``None``; concrete
        engines override it. *queue* is not used by the base class.
        """
        pass