FunASR/runtime/grpc/paraformer-server.cc

266 lines
9.1 KiB
C++
Raw Normal View History

2024-05-18 15:50:56 +08:00
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
#include "paraformer-server.h"

#include <chrono>
#include <exception>
#include <mutex>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
// Creates an engine bound to one bidirectional Recognize stream.
// `stream` is a non-owning pointer supplied by gRPC; `asr_handler` is the
// shared model handle created by GrpcService. A fresh Request message is
// allocated and reused for every Read() on the stream.
GrpcEngine::GrpcEngine(
    grpc::ServerReaderWriter<Response, Request>* stream,
    std::shared_ptr<FUNASR_HANDLE> asr_handler)
    : stream_(stream),
      asr_handler_(std::move(asr_handler)) {
  auto fresh_request = std::make_shared<Request>();
  request_ = fresh_request;
}
void GrpcEngine::DecodeThreadFunc() {
FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
std::vector<std::vector<std::string>> punc_cache(2);
bool is_final = false;
std::string online_result = "";
std::string tpass_result = "";
LOG(INFO) << "Decoder init, start decoding loop with mode";
while (true) {
if (audio_buffer_.length() > step || is_end_) {
if (audio_buffer_.length() <= step && is_end_) {
is_final = true;
step = audio_buffer_.length();
}
FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
tpass_online_handler,
audio_buffer_.c_str(),
step,
punc_cache,
is_final,
sampling_rate_,
encoding_,
mode_);
p_mutex_->lock();
audio_buffer_ = audio_buffer_.substr(step);
p_mutex_->unlock();
if (result) {
std::string online_message = FunASRGetResult(result, 0);
online_result += online_message;
if(online_message != ""){
Response response;
response.set_mode(DecodeMode::online);
response.set_text(online_message);
response.set_is_final(is_final);
stream_->Write(response);
LOG(INFO) << "send online results: " << online_message;
}
std::string tpass_message = FunASRGetTpassResult(result, 0);
tpass_result += tpass_message;
if(tpass_message != ""){
Response response;
response.set_mode(DecodeMode::two_pass);
response.set_text(tpass_message);
response.set_is_final(is_final);
stream_->Write(response);
LOG(INFO) << "send offline results: " << tpass_message;
}
FunASRFreeResult(result);
}
if (is_final) {
FunTpassOnlineUninit(tpass_online_handler);
break;
}
}
sleep(0.001);
}
}
// Called on the first request of a stream. It applies the client-supplied
// decoding options (chunk_size, sampling_rate, wav_format, mode), logs them,
// and starts the decoder thread.
// Fields left unset (chunk_size with != 3 entries, sampling_rate == 0, an
// unrecognized format/mode) keep the engine's current values.
void GrpcEngine::OnSpeechStart() {
  if (request_->chunk_size_size() == 3) {
    for (int i = 0; i < 3; i++) {
      chunk_size_[i] = static_cast<int>(request_->chunk_size(i));
    }
  }
  // Build " a b c" for the log. The previous `" " + chunk_size_[i]` was
  // pointer arithmetic on a string literal (out-of-bounds read), and it used
  // `=` so only the last value was kept.
  std::string chunk_size_str;
  for (int i = 0; i < 3; i++) {
    chunk_size_str += " " + std::to_string(chunk_size_[i]);
  }
  LOG(INFO) << "chunk_size is" << chunk_size_str;
  if (request_->sampling_rate() != 0) {
    sampling_rate_ = request_->sampling_rate();
  }
  LOG(INFO) << "sampling_rate is " << sampling_rate_;
  switch (request_->wav_format()) {
    case WavFormat::pcm:
      encoding_ = "pcm";
      break;
    default:
      // Other formats: keep the current encoding_.
      break;
  }
  LOG(INFO) << "encoding is " << encoding_;
  std::string mode_str;
  switch (request_->mode()) {
    case DecodeMode::offline:
      mode_ = ASR_OFFLINE;
      mode_str = "offline";
      break;
    case DecodeMode::online:
      mode_ = ASR_ONLINE;
      mode_str = "online";
      break;
    case DecodeMode::two_pass:
      mode_ = ASR_TWO_PASS;
      mode_str = "two_pass";
      break;
    default:
      // Unknown mode: keep the current mode_.
      break;
  }
  LOG(INFO) << "decode mode is " << mode_str;
  decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
  is_start_ = true;
}
void GrpcEngine::OnSpeechData() {
p_mutex_->lock();
audio_buffer_ += request_->audio_data();
p_mutex_->unlock();
}
void GrpcEngine::OnSpeechEnd() {
is_end_ = true;
LOG(INFO) << "Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}
}
// Main loop for one Recognize stream.
// It reads requests until the client stops sending or sends is_final, and
// starts the decoder on the first request. It then always signals end of
// input and waits for the decoder.
void GrpcEngine::operator()() {
  try {
    LOG(INFO) << "start engine main loop";
    while (stream_->Read(request_.get())) {
      LOG(INFO) << "receive data";
      if (!is_start_) {
        OnSpeechStart();
      }
      OnSpeechData();
      if (request_->is_final()) {
        break;
      }
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
  // Run even if the read loop threw. Otherwise the decoder keeps polling
  // forever, and destroying its still-joinable std::thread calls
  // std::terminate.
  try {
    OnSpeechEnd();
    LOG(INFO) << "Connect finish";
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}
// Loads the two-pass models described by `config` (using `onnx_thread`
// onnxruntime threads) and runs one warm-up two-pass decode on a dummy
// buffer. The returned result is discarded.
GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
    : config_(config) {
  // No std::move: FunTpassInit already returns a prvalue.
  asr_handler_ = std::make_shared<FUNASR_HANDLE>(FunTpassInit(config_, onnx_thread));
  LOG(INFO) << "GrpcService model loaded";
  std::vector<int> chunk_size = {5, 10, 5};
  FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
  const int sampling_rate = 16000;
  const int buffer_len = sampling_rate * 1;
  std::string tmp_data(buffer_len, '0');
  std::vector<std::vector<std::string>> punc_cache(2);
  bool is_final = true;
  std::string encoding = "pcm";
  // Pass sampling_rate as the sampling-rate argument. It previously passed
  // buffer_len, which only worked because both happen to be 16000.
  FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
                                             tmp_online_handler,
                                             tmp_data.c_str(),
                                             buffer_len,
                                             punc_cache,
                                             is_final,
                                             sampling_rate,
                                             encoding,
                                             ASR_TWO_PASS);
  if (result) {
    FunASRFreeResult(result);
  }
  FunTpassOnlineUninit(tmp_online_handler);
  LOG(INFO) << "GrpcService model warmup";
}
// Handles one bidirectional Recognize RPC by running a GrpcEngine over the
// stream until the client finishes. It always returns OK.
grpc::Status GrpcService::Recognize(
    grpc::ServerContext* context,
    grpc::ServerReaderWriter<Response, Request>* stream) {
  LOG(INFO) << "Get Recognize request";
  GrpcEngine engine(stream, asr_handler_);
  // Run inline. Spawning a std::thread only to join it immediately just
  // blocked this RPC thread anyway. Decoding still runs on the engine's own
  // decoder thread.
  engine();
  return grpc::Status::OK;
}
// Copies a string command-line option into `config` under `key`, but only
// when the user actually set it. An existing entry for `key` is left
// unchanged (map insertion does not overwrite). Logs the key/value pair it
// recorded.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
  if (!value_arg.isSet()) {
    return;
  }
  const std::string& option_value = value_arg.getValue();
  config.emplace(key, option_value);
  LOG(INFO) << key << " : " << option_value;
}
// Entry point. It parses command-line options into a config map, loads and
// warms up the models, and serves the Recognize gRPC service on
// 0.0.0.0:<port> until shutdown. Returns non-zero if the port is missing or
// the server cannot be started.
int main(int argc, char* argv[]) {
  FLAGS_logtostderr = true;
  google::InitGoogleLogging(argv[0]);
  TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
  TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
  TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
  TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
  TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
  TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
  TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
  TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
  // NOTE(review): the flag is named "onnx-inter-thread" but the help text
  // says SetIntraOpNumThreads. Confirm which one it controls.
  TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
  TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
  cmd.add(model_dir);
  cmd.add(online_model_dir);
  cmd.add(quantize);
  cmd.add(vad_dir);
  cmd.add(vad_quant);
  cmd.add(punc_dir);
  cmd.add(punc_quant);
  cmd.add(onnx_thread);
  cmd.add(port_id);
  cmd.parse(argc, argv);
  std::map<std::string, std::string> config;
  GetValue(model_dir, MODEL_DIR, config);
  GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
  GetValue(quantize, QUANTIZE, config);
  GetValue(vad_dir, VAD_DIR, config);
  GetValue(vad_quant, VAD_QUANT, config);
  GetValue(punc_dir, PUNC_DIR, config);
  GetValue(punc_quant, PUNC_QUANT, config);
  GetValue(port_id, PORT_ID, config);
  std::string port;
  try {
    port = config.at(PORT_ID);
  } catch (std::exception const& e) {
    // A missing port is an error: log it as such and exit non-zero
    // (this used to exit(0)).
    LOG(ERROR) << ("Error when read port.");
    return 1;
  }
  std::string server_address = "0.0.0.0:" + port;
  GrpcService service(config, onnx_thread.getValue());
  grpc::ServerBuilder builder;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
  // BuildAndStart() yields null on failure (e.g. the port cannot be bound).
  // Calling Wait() on it would dereference null.
  if (!server) {
    LOG(ERROR) << "Failed to start server on " << server_address;
    return 1;
  }
  LOG(INFO) << "Server listening on " << server_address;
  server->Wait();
  return 0;
}