# FunASR/runtime/triton_gpu/client/client.py

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
from multiprocessing import Pool
import argparse
import os
import tritonclient.grpc as grpcclient
from utils import cal_cer
from speech_client import OfflineSpeechClient, StreamingSpeechClient
import numpy as np
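
# Benchmark client for the Triton ASR deployment: load a single wav or a
# Kaldi-style wav.scp list, shard the files across worker processes, send each
# shard to the server over gRPC, and report CER when reference transcripts are
# supplied. Example invocations (paths and server address are illustrative):
#   python client.py --audio_file test.wav --url localhost:10086
#   python client.py --wavscp wav.scp --trans trans.txt --data_dir data/ --streaming
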
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:10086",
        help="Inference server URL. Default is localhost:10086.",
    )
    parser.add_argument(
        "--model_name",
        required=False,
        default="attention_rescoring",
        choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
        help="the model to send requests to",
    )
    parser.add_argument(
        "--wavscp",
        type=str,
        required=False,
        default=None,
        help="audio_id \t wav_path",
    )
    parser.add_argument(
        "--trans",
        type=str,
        required=False,
        default=None,
        help="audio_id \t text",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=False,
        default=None,
        help="path prefix for wav_path in wavscp/audio_file",
    )
    parser.add_argument(
        "--audio_file",
        type=str,
        required=False,
        default=None,
        help="single wav file path",
    )
    # below arguments are for streaming
    # Please check onnx_config.yaml and train.yaml
    parser.add_argument("--streaming", action="store_true", required=False)
    parser.add_argument(
        "--sample_rate",
        type=int,
        required=False,
        default=16000,
        help="sample rate used in training",
    )
    parser.add_argument(
        "--frame_length_ms",
        type=int,
        required=False,
        default=25,
        help="frame length in ms",
    )
    parser.add_argument(
        "--frame_shift_ms",
        type=int,
        required=False,
        default=10,
        help="frame shift in ms",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        required=False,
        default=16,
        help="decoding chunk size",
    )
    parser.add_argument(
        "--context",
        type=int,
        required=False,
        default=7,
        help="subsampling context",
    )
    parser.add_argument(
        "--subsampling",
        type=int,
        required=False,
        default=4,
        help="subsampling rate",
    )
    FLAGS = parser.parse_args()
    print(FLAGS)
    # load data
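    # Two input modes: --audio_file decodes a single wav; --wavscp decodes a
    # Kaldi-style "id<TAB>path" list, optionally paired with --trans references.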
    filenames = []
    transcripts = []
    if FLAGS.audio_file is not None:
        path = FLAGS.audio_file
        if FLAGS.data_dir:
            path = os.path.join(FLAGS.data_dir, path)
        if os.path.exists(path):
            filenames = [path]
    elif FLAGS.wavscp is not None:
        audio_data = {}
        with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
            for line in f:
                aid, path = line.strip().split("\t")
                if FLAGS.data_dir:
                    path = os.path.join(FLAGS.data_dir, path)
                audio_data[aid] = {"path": path}
        if FLAGS.trans is not None:
            with open(FLAGS.trans, "r", encoding="utf-8") as f:
                for line in f:
                    aid, text = line.strip().split("\t")
                    if aid in audio_data:
                        audio_data[aid]["text"] = text
        for key, value in audio_data.items():
            filenames.append(value["path"])
            if "text" in value:
                transcripts.append(value["text"])
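
    # Use roughly half of the CPU cores as parallel client processes (at least
    # one), so several utterances are in flight against the server at once.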
    num_workers = max(1, multiprocessing.cpu_count() // 2)
    if FLAGS.streaming:
        speech_client_cls = StreamingSpeechClient
    else:
        speech_client_cls = OfflineSpeechClient
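
    # Worker entry point: each process opens its own gRPC connection and
    # decodes its shard of files; client_files is an (index, file_list) pair.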
    def single_job(client_files):
        with grpcclient.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose
        ) as triton_client:
            protocol_client = grpcclient
            speech_client = speech_client_cls(
                triton_client, FLAGS.model_name, protocol_client, FLAGS
            )
            idx, audio_files = client_files
            predictions = []
            for li in audio_files:
                result = speech_client.recognize(li, idx)
                print("Recognized {}:{}".format(li, result[0]))
                predictions += result
        return predictions

    # start to do inference
    # Group requests in batches
    predictions = []
    tasks = []
    splits = np.array_split(filenames, num_workers)
    for idx, per_split in enumerate(splits):
        cur_files = per_split.tolist()
        tasks.append((idx, cur_files))
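    # pool.map preserves task order, so the flattened results line up with
    # the shard order built above.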
    with Pool(processes=num_workers) as pool:
        predictions = pool.map(single_job, tasks)
    predictions = [item for sublist in predictions for item in sublist]
    if transcripts:
        cer = cal_cer(predictions, transcripts)
        print("CER is: {}".format(cer))