# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
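
# Benchmark client for a WeNet ASR model served by NVIDIA Triton: reads a
# single wav file or a Kaldi-style wav.scp (with optional reference
# transcripts), sends the audio over gRPC from parallel worker processes,
# and reports the character error rate (CER) when transcripts are given.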

import argparse
import multiprocessing
import os
from multiprocessing import Pool

import numpy as np
import tritonclient.grpc as grpcclient

from speech_client import OfflineSpeechClient, StreamingSpeechClient
from utils import cal_cer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:10086",
        help="Inference server gRPC URL. Default is localhost:10086.",
    )
    parser.add_argument(
        "--model_name",
        required=False,
        default="attention_rescoring",
        choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
        help="name of the model to send requests to",
    )
    parser.add_argument(
        "--wavscp",
        type=str,
        required=False,
        default=None,
        help="wav scp file; each line is 'audio_id<TAB>wav_path'",
    )
    parser.add_argument(
        "--trans",
        type=str,
        required=False,
        default=None,
        help="transcript file; each line is 'audio_id<TAB>text'",
    )
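    # Illustrative wav.scp / trans line formats (tab-separated; the utterance
    # id and path below are made up for the example):
    #   utt001<TAB>wavs/utt001.wav
    #   utt001<TAB>the reference transcript for utt001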
    parser.add_argument(
        "--data_dir",
        type=str,
        required=False,
        default=None,
        help="path prefix for wav_path in wavscp/audio_file",
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=False,
        default=None,
        help="single wav file path",
    )
    # The arguments below apply to streaming recognition only and must match
    # the values used at training/export time; see onnx_config.yaml and
    # train.yaml.
    parser.add_argument(
        "--streaming",
        action="store_true",
        required=False,
        help="use the streaming client instead of the offline one",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=False,
        default=16000,
        help="sample rate used in training",
    )
    parser.add_argument(
        "--frame_length_ms",
        type=int,
        required=False,
        default=25,
        help="frame length in milliseconds",
    )
    parser.add_argument(
        "--frame_shift_ms",
        type=int,
        required=False,
        default=10,
        help="frame shift in milliseconds",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        required=False,
        default=16,
        help="decoding chunk size, in subsampled frames; default is 16",
    )
    parser.add_argument(
        "--context",
        type=int,
        required=False,
        default=7,
        help="subsampling context",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        required=False,
        default=4,
        help="subsampling rate",
    )
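
    # Back-of-the-envelope check (illustration, not part of the original
    # script): with the defaults above, one decoding chunk covers
    #   chunk_size * subsampling = 16 * 4 = 64 feature frames,
    # i.e. 64 * frame_shift_ms = 640 ms of audio per chunk, plus the
    # `context` frames the subsampling layers consume as lookahead.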

    FLAGS = parser.parse_args()
    print(FLAGS)

    # Load the audio list: either a single wav file, or a wav.scp file plus
    # optional reference transcripts for CER computation.
    filenames = []
    transcripts = []
    if FLAGS.audio_file is not None:
        path = FLAGS.audio_file
        if FLAGS.data_dir:
            path = os.path.join(FLAGS.data_dir, path)
        if os.path.exists(path):
            filenames = [path]
    elif FLAGS.wavscp is not None:
        audio_data = {}
        with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
            for line in f:
                aid, path = line.strip().split("\t")
                if FLAGS.data_dir:
                    path = os.path.join(FLAGS.data_dir, path)
                audio_data[aid] = {"path": path}
        # Transcripts are optional; without them only the recognition
        # results are printed and no CER is computed.
        if FLAGS.trans is not None:
            with open(FLAGS.trans, "r", encoding="utf-8") as f:
                for line in f:
                    aid, text = line.strip().split("\t")
                    audio_data[aid]["text"] = text
        for key, value in audio_data.items():
            filenames.append(value["path"])
            if "text" in value:
                transcripts.append(value["text"])

    # Use half of the available CPU cores, but always at least one worker.
    num_workers = max(1, multiprocessing.cpu_count() // 2)

    if FLAGS.streaming:
        speech_client_cls = StreamingSpeechClient
    else:
        speech_client_cls = OfflineSpeechClient

    def single_job(client_files):
        # Each worker opens its own gRPC connection to the Triton server and
        # recognizes its share of the audio files sequentially.
        idx, audio_files = client_files
        predictions = []
        with grpcclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose
        ) as triton_client:
            protocol_client = grpcclient
            speech_client = speech_client_cls(
                triton_client, FLAGS.model_name, protocol_client, FLAGS
            )
            for li in audio_files:
                result = speech_client.recognize(li, idx)
                print("Recognized {}: {}".format(li, result[0]))
                predictions += result
        return predictions
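
    # Note: each task runs in its own process, so every worker opens its own
    # gRPC connection rather than sharing one client across workers.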

    # Run inference: split the file list evenly across the workers and let
    # each worker process its share in parallel.
    predictions = []
    tasks = []
    splits = np.array_split(filenames, num_workers)

    for idx, per_split in enumerate(splits):
        cur_files = per_split.tolist()
        tasks.append((idx, cur_files))

    with Pool(processes=num_workers) as pool:
        predictions = pool.map(single_job, tasks)

    # pool.map preserves task order, so the flattened predictions line up
    # with the transcripts collected above.
    predictions = [item for sublist in predictions for item in sublist]
    if transcripts:
        cer = cal_cer(predictions, transcripts)
        print("CER is: {}".format(cer))
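
    # Example invocation (illustrative; the script name and data paths are
    # assumptions, not taken from the repository):
    #   python client.py --url=localhost:10086 \
    #       --model_name=attention_rescoring \
    #       --wavscp=test/wav.scp --trans=test/text --data_dir=test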