1143 lines
47 KiB
Python
1143 lines
47 KiB
Python
"""RNN decoder module."""
|
|
|
|
import logging
|
|
import math
|
|
import random
|
|
from argparse import Namespace
|
|
|
|
import numpy as np
|
|
import six
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from funasr.models.transformer.utils.scorers.ctc_prefix_score import CTCPrefixScore
|
|
from funasr.models.transformer.utils.scorers.ctc_prefix_score import CTCPrefixScoreTH
|
|
from funasr.models.transformer.utils.scorers.scorer_interface import ScorerInterface
|
|
from funasr.metrics import end_detect
|
|
from funasr.models.transformer.utils.nets_utils import mask_by_length
|
|
from funasr.models.transformer.utils.nets_utils import pad_list
|
|
from funasr.metrics.compute_acc import th_accuracy
|
|
from funasr.models.transformer.utils.nets_utils import to_device
|
|
from funasr.models.language_model.rnn.attentions import att_to_numpy
|
|
|
|
# How many utterances ``Decoder.forward`` prints (groundtruth vs. prediction)
# when ``verbose > 0`` and a ``char_list`` is available.
MAX_DECODER_OUTPUT: int = 5
# Pre-pruning width multiplier for CTC rescoring during beam search: only
# ``int(beam * CTC_SCORING_RATIO)`` candidate tokens (capped at the vocabulary
# size) are rescored with CTC prefix scores.
CTC_SCORING_RATIO: float = 1.5
|
|
|
|
|
|
class Decoder(torch.nn.Module, ScorerInterface):
|
|
"""Decoder module
|
|
|
|
:param int eprojs: encoder projection units
|
|
:param int odim: dimension of outputs
|
|
:param str dtype: gru or lstm
|
|
:param int dlayers: decoder layers
|
|
:param int dunits: decoder units
|
|
:param int sos: start of sequence symbol id
|
|
:param int eos: end of sequence symbol id
|
|
:param torch.nn.Module att: attention module
|
|
:param int verbose: verbose level
|
|
:param list char_list: list of character strings
|
|
:param ndarray labeldist: distribution of label smoothing
|
|
:param float lsm_weight: label smoothing weight
|
|
:param float sampling_probability: scheduled sampling probability
|
|
:param float dropout: dropout rate
|
|
:param bool context_residual: if True, use context vector for token generation
|
|
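:param bool replace_sos: if True, start each target with a language ID instead of <sos> (multilingual speech/text translation)
:param int num_encs: number of encoders; when > 1, the extra attention module att[num_encs] fuses the per-encoder contexts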
|
|
"""
|
|
|
|
def __init__(
    self,
    eprojs,
    odim,
    dtype,
    dlayers,
    dunits,
    sos,
    eos,
    att,
    verbose=0,
    char_list=None,
    labeldist=None,
    lsm_weight=0.0,
    sampling_probability=0.0,
    dropout=0.0,
    context_residual=False,
    replace_sos=False,
    num_encs=1,
):
    """Create the embedding, the stacked recurrent cells and the output layer.

    Parameter meanings are given in the class docstring. Sub-modules are
    registered in the order embed, dropout_emb, decoder, dropout_dec, output,
    att; this order determines the ordering of ``parameters()``.
    """
    torch.nn.Module.__init__(self)
    self.dtype = dtype
    self.dunits = dunits
    self.dlayers = dlayers
    self.context_residual = context_residual
    self.embed = torch.nn.Embedding(odim, dunits)
    self.dropout_emb = torch.nn.Dropout(p=dropout)

    # The bottom cell reads [token embedding; attention context]. Every cell
    # above it reads the hidden state of the cell below. At least one cell is
    # always built, even when dlayers < 1.
    cell_type = torch.nn.LSTMCell if self.dtype == "lstm" else torch.nn.GRUCell
    cell_input_sizes = [dunits + eprojs] + [dunits] * (self.dlayers - 1)
    self.decoder = torch.nn.ModuleList()
    self.dropout_dec = torch.nn.ModuleList()
    for input_size in cell_input_sizes:
        self.decoder.append(cell_type(input_size, dunits))
        self.dropout_dec.append(torch.nn.Dropout(p=dropout))
    # NOTE: dropout is applied only for the vertical connections
    # see https://arxiv.org/pdf/1409.2329.pdf
    self.ignore_id = -1

    # With context_residual the attention context is concatenated to the top
    # hidden state before the output projection.
    projection_in = dunits + eprojs if context_residual else dunits
    self.output = torch.nn.Linear(projection_in, odim)

    self.loss = None
    self.att = att
    self.sos = sos
    self.eos = eos
    self.odim = odim
    self.verbose = verbose
    self.char_list = char_list
    # for label smoothing
    self.labeldist = labeldist
    self.vlabeldist = None
    self.lsm_weight = lsm_weight
    self.sampling_probability = sampling_probability
    self.dropout = dropout
    self.num_encs = num_encs

    # for multilingual E2E-ST
    self.replace_sos = replace_sos

    # large negative value used to mask out scores during beam search
    self.logzero = -10000000000.0
|
|
|
|
def zero_state(self, hs_pad):
    """Return an all-zero recurrent state for each sequence in the batch.

    :param torch.Tensor hs_pad: tensor whose first dimension is the batch size;
        the result is created with ``new_zeros`` so it shares dtype and device
    :return: zeros of shape (hs_pad.size(0), self.dunits)
    :rtype: torch.Tensor
    """
    n_batch = hs_pad.size(0)
    return hs_pad.new_zeros(n_batch, self.dunits)
|
|
|
|
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
    """Advance every stacked decoder cell by one time step.

    The bottom cell consumes ``ey``. Each higher cell consumes the
    dropout-ed hidden state of the cell just below it, so dropout is only
    applied on vertical connections.

    :param torch.Tensor ey: input of the bottom cell
    :param list z_list: per-layer hidden states, overwritten in place
    :param list c_list: per-layer cell states; overwritten in place for LSTM,
        left untouched for GRU
    :param list z_prev: per-layer hidden states of the previous step
    :param list c_prev: per-layer cell states of the previous step (LSTM only)
    :return: the (mutated) ``z_list`` and ``c_list`` objects
    :rtype: tuple
    """
    is_lstm = self.dtype == "lstm"

    def _advance(layer, layer_input):
        # Run one cell and store its new state into the output lists.
        cell = self.decoder[layer]
        if is_lstm:
            z_list[layer], c_list[layer] = cell(layer_input, (z_prev[layer], c_prev[layer]))
        else:
            z_list[layer] = cell(layer_input, z_prev[layer])

    _advance(0, ey)
    for layer in range(1, self.dlayers):
        _advance(layer, self.dropout_dec[layer - 1](z_list[layer - 1]))
    return z_list, c_list
|
|
|
|
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
    """Decoder forward (teacher-forced pass with optional scheduled sampling).

    :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        [in multi-encoder case,
        list of torch.Tensor,
        [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
    :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        [in multi-encoder case, list of torch.Tensor,
        [(B), (B), ..., ]
    :param torch.Tensor ys_pad: batch of padded character id sequence tensor
        (B, Lmax), padded with ``self.ignore_id``
    :param int strm_idx: stream index indicates the index of decoding stream.
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1);
        only used when ``self.replace_sos`` is set
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy
    :rtype: float
    :return: perplexity, i.e. exp of the mean token cross-entropy
    :rtype: float
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlens = [hlens]

    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
    # attention index for the attention module
    # in SPA (speaker parallel attention),
    # att_idx is used to select attention module. In other cases, it is 0.
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlens should be list of list of integer
    hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding: inputs are padded with eos, targets with ignore_id (-1)
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get dim, length info
    batch = ys_out_pad.size(0)
    olength = ys_out_pad.size(1)
    for idx in range(self.num_encs):
        logging.info(
            self.__class__.__name__
            + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, hlens[idx]
            )
        )
    logging.info(
        self.__class__.__name__ + " output lengths: " + str([y.size(0) for y in ys_out])
    )

    # initialization
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    z_all = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
            )
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlens[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlens_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlens_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
        if i > 0 and random.random() < self.sampling_probability:
            logging.info(" scheduled sampling ")
            z_out = self.output(z_all[-1])
            # Greedy token from the previous step's output. torch.argmax keeps
            # the result on the model device and returns the first maximal
            # index, matching np.argmax, without depending on numpy's implicit
            # torch-tensor wrapping.
            z_out = torch.argmax(z_out.detach(), dim=1)
            z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
            ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
        else:
            ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
        if self.context_residual:
            z_all.append(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )  # utt x (zdim + hdim)
        else:
            z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

    z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
    # compute loss
    y_all = self.output(z_all)
    self.loss = F.cross_entropy(
        y_all,
        ys_out_pad.view(-1),
        ignore_index=self.ignore_id,
        reduction="mean",
    )
    # compute perplexity
    ppl = math.exp(self.loss.item())
    # -1: eos, which is removed in the loss computation
    self.loss *= np.mean([len(x) for x in ys_in]) - 1
    acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
    logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

    # show predicted character sequence for debug
    if self.verbose > 0 and self.char_list is not None:
        ys_hat = y_all.view(batch, olength, -1)
        ys_true = ys_out_pad
        for (i, y_hat), y_true in zip(
            enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
        ):
            if i == MAX_DECODER_OUTPUT:
                break
            idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
            idx_true = y_true[y_true != self.ignore_id]
            seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
            seq_true = [self.char_list[int(idx)] for idx in idx_true]
            seq_hat = "".join(seq_hat)
            seq_true = "".join(seq_true)
            logging.info("groundtruth[%d]: " % i + seq_true)
            logging.info("prediction [%d]: " % i + seq_hat)

    if self.labeldist is not None:
        # label smoothing towards the provided label distribution
        if self.vlabeldist is None:
            self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist))
        loss_reg = -torch.sum(
            (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
        ) / len(ys_in)
        self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg

    return self.loss, acc, ppl
|
|
|
|
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, eprojs), (T2, eprojs), ...] ]
    :param torch.Tensor lpz: ctc log softmax output (T, odim)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, odim), (T2, odim), ...] ]; ``None`` disables CTC scoring.
        It is converted with ``.numpy()``, so it must live on the CPU.
    :param Namespace recog_args: argument Namespace containing options
        (beam_size, penalty, maxlenratio, minlenratio, nbest, and, when used,
        ctc_weight, lm_weight, weights_ctc_dec, tgt_lang)
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx:
        stream index for speaker parallel attention in multi-speaker case
    :return: N-best decoding results; every hypothesis has "score" and
        "yseq" (which starts with the start token and ends with <eos>).
        An empty list is returned if no hypothesis ends even with minlenratio 0.
    :rtype: list of dicts
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    for idx in range(self.num_encs):
        # log each encoder's own length (previously always h[0])
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(0)
            )
        )
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization
    c_list = [self.zero_state(h[0].unsqueeze(0))]
    z_list = [self.zero_state(h[0].unsqueeze(0))]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(h[0].unsqueeze(0)))
        z_list.append(self.zero_state(h[0].unsqueeze(0)))
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # search parms
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", 0.0)  # for NMT

    if lpz[0] is not None and self.num_encs > 1:
        # weights-ctc,
        # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
    else:
        weights_ctc_dec = [1.0]

    # prepare sos
    if self.replace_sos and recog_args.tgt_lang:
        y = char_list.index(recog_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])
    vy = h[0].new_zeros(1).long()

    maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
    if recog_args.maxlenratio != 0:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * maxlen))
    minlen = int(recog_args.minlenratio * maxlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
            "rnnlm_prev": None,
        }
    else:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
        }
    if lpz[0] is not None:
        ctc_prefix_score = [
            CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
            for idx in range(self.num_encs)
        ]
        hyp["ctc_state_prev"] = [
            ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
        ]
        hyp["ctc_score_prev"] = [0.0] * self.num_encs
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz[0].shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            vy[0] = hyp["yseq"][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    h[0].unsqueeze(0),
                    [h[0].size(0)],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"],
                )
            else:
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        h[idx].unsqueeze(0),
                        [h[idx].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][idx],
                    )
                h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    h_han,
                    [self.num_encs],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"][self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"])

            # get nbest local scores and their ids
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_att_scores = F.log_softmax(logits, dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
            else:
                local_scores = local_att_scores

            if lpz[0] is not None:
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = (
                    [None] * self.num_encs,
                    [None] * self.num_encs,
                )
                for idx in range(self.num_encs):
                    ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                    )
                local_scores = (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
                if self.num_encs == 1:
                    local_scores += ctc_weight * torch.from_numpy(
                        ctc_scores[0] - hyp["ctc_score_prev"][0]
                    )
                else:
                    for idx in range(self.num_encs):
                        local_scores += (
                            ctc_weight
                            * weights_ctc_dec[idx]
                            * torch.from_numpy(ctc_scores[idx] - hyp["ctc_score_prev"][idx])
                        )
                if rnnlm:
                    local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)

            for j in six.moves.range(beam):
                new_hyp = {}
                # [:] is needed! z_list/c_list are reused as scratch for the next hyp
                new_hyp["z_prev"] = z_list[:]
                new_hyp["c_prev"] = c_list[:]
                if self.num_encs == 1:
                    new_hyp["a_prev"] = att_w[:]
                else:
                    new_hyp["a_prev"] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz[0] is not None:
                    new_hyp["ctc_state_prev"] = [
                        ctc_states[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                    ]
                    new_hyp["ctc_score_prev"] = [
                        ctc_scores[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                    ]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[
                :beam
            ]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        logging.debug("best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]))

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(hyp["rnnlm_prev"])
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            logging.debug("hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        if recog_args.minlenratio <= 0.0:
            # minlenratio is clamped at 0 below, so a retry would repeat the
            # identical search and recurse forever; give up instead.
            logging.warning(
                "there is no N-best results, and minlenratio cannot be reduced further."
            )
            return nbest_hyps
        logging.warning(
            "there is no N-best results, " "perform recognition again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        # keep strm_idx so the retry uses the same attention stream
        if self.num_encs == 1:
            return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm, strm_idx)
        else:
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm, strm_idx)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )

    # NOTE: each returned "yseq" still begins with the start token
    return nbest_hyps
|
|
|
|
def recognize_beam_batch(
    self,
    h,
    hlens,
    lpz,
    recog_args,
    char_list,
    rnnlm=None,
    normalize_score=True,
    strm_idx=0,
    lang_ids=None,
):
    """Batched beam search over a padded batch of encoder outputs.

    :param torch.Tensor h: padded encoder output (B, Tmax, D)
        [in multi-encoder case, list of such tensors]; the list is not modified
    :param torch.Tensor hlens: encoder output lengths (B)
        [in multi-encoder case, list of such tensors]
    :param torch.Tensor lpz: CTC log-softmax output handed to CTCPrefixScoreTH
        (presumably (B, Tmax, odim) -- confirm with the caller), or None to
        disable CTC scoring [in multi-encoder case, list]
    :param Namespace recog_args: decoding options (beam_size, penalty,
        maxlenratio, minlenratio, nbest, and, when used, ctc_weight,
        ctc_window_margin, lm_weight, weights_ctc_dec, tgt_lang)
    :param list char_list: list of character strings
    :param torch.nn.Module rnnlm: language model (optional)
    :param bool normalize_score: divide every final score by its yseq length
    :param int strm_idx: stream index for speaker parallel attention
    :param lang_ids: per-utterance start-token ids, used when the
        replace_sos/tgt_lang path is not taken (evaluation during training)
    :return: for every utterance, up to ``recog_args.nbest`` hypotheses sorted
        best-first; each is a dict with at least "yseq" and "score"
    :rtype: list of list of dict
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        hlens = [hlens]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs
    # copy so the masking below does not overwrite the caller's list
    h = list(h)

    att_idx = min(strm_idx, len(self.att) - 1)
    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(1)
            )
        )
        h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

    # search params
    batch = len(hlens[0])
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
    att_weight = 1.0 - ctc_weight
    ctc_margin = getattr(
        recog_args, "ctc_window_margin", 0
    )  # use getattr to keep compatibility
    # weights-ctc,
    # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
    if lpz[0] is not None and self.num_encs > 1:
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
    else:
        weights_ctc_dec = [1.0]

    n_bb = batch * beam
    # offset of each utterance's first beam slot in the flattened (B*beam) layout
    pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)

    max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialization
    c_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    c_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    z_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
    vscores = to_device(h[0], torch.zeros(batch, beam))

    rnnlm_state = None
    if self.num_encs == 1:
        a_prev = [None]
        att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a_prev = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    if self.replace_sos and recog_args.tgt_lang:
        logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
        logging.info("<sos> mark: " + recog_args.tgt_lang)
        yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
    elif lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
    else:
        logging.info("<sos> index: " + str(self.sos))
        logging.info("<sos> mark: " + char_list[self.sos])
        yseq = [[self.sos] for _ in six.moves.range(n_bb)]

    stop_search = [False for _ in six.moves.range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    # replicate encoder outputs/lengths beam times: (B, ...) -> (B*beam, ...)
    exp_hlens = [
        hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
    exp_h = [
        h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)
    ]
    exp_h = [
        exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
        for idx in range(self.num_encs)
    ]

    if lpz[0] is not None:
        scoring_num = min(
            int(beam * CTC_SCORING_RATIO) if att_weight > 0.0 and not lpz[0].is_cuda else 0,
            lpz[0].size(-1),
        )
        ctc_scorer = [
            CTCPrefixScoreTH(
                lpz[idx],
                hlens[idx],
                0,
                self.eos,
                margin=ctc_margin,
            )
            for idx in range(self.num_encs)
        ]

    for i in six.moves.range(maxlen):
        logging.debug("position " + str(i))

        vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
            )
            att_w_list = [att_w]
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    exp_h[idx],
                    exp_hlens[idx],
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[idx],
                )
            exp_h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                exp_h_han,
                [self.num_encs] * n_bb,
                self.dropout_dec[0](z_prev[0]),
                a_prev[self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores

        # ctc
        if ctc_scorer[0]:
            local_scores[:, 0] = self.logzero  # avoid choosing blank
            part_ids = (
                torch.topk(local_scores, scoring_num, dim=-1)[1] if scoring_num > 0 else None
            )
            for idx in range(self.num_encs):
                att_w = att_w_list[idx]
                att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
                    yseq, ctc_state[idx], part_ids, att_w_
                )
                local_scores = (
                    local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                )

        local_scores = local_scores.view(batch, beam, self.odim)
        if i == 0:
            # all beams are identical at the first step; keep only beam 0
            local_scores[:, 1:, :] = self.logzero

        # accumulate scores
        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, -1)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        accum_padded_beam_ids = (
            (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
        )

        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))

        a_prev = []
        num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
        for idx in range(num_atts):
            if isinstance(att_w_list[idx], torch.Tensor):
                _a_prev = torch.index_select(
                    att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                )
            elif isinstance(att_w_list[idx], list):
                # handle the case of multi-head attention
                _a_prev = [
                    torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                    for att_w_one in att_w_list[idx]
                ]
            else:
                # handle the case of location_recurrent when return is a tuple
                _a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
                _h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
                _c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
                _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
            a_prev.append(_a_prev)
        z_prev = [
            torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
        ]
        c_prev = [
            torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
        ]

        # pick ended hyps
        if i >= minlen:
            k = 0
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in six.moves.range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in six.moves.range(beam):
                    _vscore = None
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        if len(yk) <= min(hlens[idx][samp_i] for idx in range(self.num_encs)):
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                    elif i == maxlen - 1:
                        yk = yseq[k][:]
                        _vscore = vscores[samp_i][beam_j] + penalty_i
                    # explicit None check: a tensor score of exactly 0.0 is
                    # falsy and must not be discarded
                    if _vscore is not None:
                        yk.append(self.eos)
                        if rnnlm:
                            _vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
                        _score = _vscore.data.cpu().numpy()
                        ended_hyps[samp_i].append(
                            {"yseq": yk, "vscore": _vscore, "score": _score}
                        )
                    k = k + 1

        # end detection
        stop_search = [
            stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
            for samp_i in six.moves.range(batch)
        ]
        stop_search_summary = list(set(stop_search))
        if len(stop_search_summary) == 1 and stop_search_summary[0]:
            break

        if rnnlm:
            rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if ctc_scorer[0]:
            for idx in range(self.num_encs):
                ctc_state[idx] = ctc_scorer[idx].index_select_state(
                    ctc_state[idx], accum_best_ids
                )

    torch.cuda.empty_cache()

    # Fresh placeholder per utterance, so the in-place normalisation below
    # never mutates a dict shared between several utterances.
    ended_hyps = [
        ended_hyps[samp_i]
        if len(ended_hyps[samp_i]) != 0
        else [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
        for samp_i in six.moves.range(batch)
    ]
    if normalize_score:
        for samp_i in six.moves.range(batch):
            for x in ended_hyps[samp_i]:
                x["score"] /= len(x["yseq"])

    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps[samp_i]), recog_args.nbest)
        ]
        for samp_i in six.moves.range(batch)
    ]

    return nbest_hyps
|
|
|
|
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
    """Calculate all attention weights by teacher-forcing the targets.

    The decoder is run step by step over ``ys_pad`` (with <sos>/<eos>
    attached) and the attention weights of every step are collected.
    No output distribution or loss is computed.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences
        (B, Tmax, D); in the multi-encoder case, a list of torch.Tensor
        [(B, Tmax_1, D), (B, Tmax_2, D), ...]
    :param torch.Tensor hlen: batch of lengths of hidden state sequences (B);
        in the multi-encoder case, a list of torch.Tensor [(B), (B), ...]
    :param torch.Tensor ys_pad:
        batch of padded character id sequence tensor (B, Lmax)
    :param int strm_idx:
        stream index for parallel speaker attention in multi-speaker case
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1);
        only used when ``self.replace_sos`` is set
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) multi-encoder case =>
            [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
        3) other case => attention weights (B, Lmax, Tmax).
        The Lmax axis has one entry per decoding step, i.e. one per column
        of the <eos>-extended, padded target batch.
    :rtype: float ndarray (a list of ndarrays in the multi-encoder case)
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlen = [hlen]

    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # strip ignore_id padding
    # clamp the stream index so models with a single attention ignore strm_idx
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlen should be list of list of integer
    hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

    # reset self.loss; this method computes no loss
    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        # the per-utterance language id takes the place of <sos>
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # pad inputs with self.eos and outputs with self.ignore_id
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get length info: number of decoding steps
    olength = ys_out_pad.size(1)

    # initialization: one zero state in c_list and in z_list per decoder layer
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in six.moves.range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    att_ws = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in six.moves.range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
            )
            att_ws.append(att_w)
        else:
            # per-encoder attention first, then the last attention module
            # (han) attends over the stacked per-encoder context vectors
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlen[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlen_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlen_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
            # copy: the slots of att_w_list are reassigned on the next step
            att_ws.append(att_w_list.copy())
        ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        # the current states are passed both as state and as "previous" state
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

    if self.num_encs == 1:
        # convert to numpy array with the shape (B, Lmax, Tmax)
        att_ws = att_to_numpy(att_ws, self.att[att_idx])
    else:
        # transpose the step-major list [[w_0, ..., w_han], ...] into one
        # sequence per attention module, then convert each separately
        _att_ws = []
        for idx, ws in enumerate(zip(*att_ws)):
            ws = att_to_numpy(ws, self.att[idx])
            _att_ws.append(ws)
        att_ws = _att_ws
    return att_ws
|
|
|
|
@staticmethod
def _get_last_yseq(exp_yseq):
    """Return the final token of every hypothesis.

    :param exp_yseq: sequence of token-id sequences
    :return: a new list holding ``seq[-1]`` for each ``seq``, in order
    :rtype: list
    """
    return [seq[-1] for seq in exp_yseq]
|
|
|
|
@staticmethod
def _append_ids(yseq, ids):
    """Append new token ids to each hypothesis, in place.

    If ``ids`` is a list, ``ids[k]`` is appended to ``yseq[k]``.
    Otherwise the single value ``ids`` is appended to every hypothesis.

    :param list yseq: list of mutable token-id lists
    :param ids: per-hypothesis list of ids, or one id shared by all
    :return: the same ``yseq`` object, after mutation
    :rtype: list
    """
    if not isinstance(ids, list):
        for hyp in yseq:
            hyp.append(ids)
        return yseq
    for position, token in enumerate(ids):
        yseq[position].append(token)
    return yseq
|
|
|
|
@staticmethod
def _index_select_list(yseq, lst):
    """Gather shallow copies of the hypotheses selected by ``lst``.

    :param list yseq: list of sliceable token sequences
    :param lst: iterable of indices into ``yseq`` (repeats allowed)
    :return: a new list with ``yseq[i][:]`` for each ``i`` in ``lst``
    :rtype: list
    """
    return [yseq[index][:] for index in lst]
|
|
|
|
@staticmethod
def _index_select_lm_state(rnnlm_state, dim, vidx):
    """Reorder a language-model state so it follows the selected hypotheses.

    :param rnnlm_state: either a dict whose values are iterables of tensors
        (each tensor is indexed along ``dim``), or a list with one sliceable
        entry per hypothesis
    :param int dim: dimension along which dict-state tensors are indexed
    :param vidx: indices of the entries to keep; a tensor for the dict case,
        and any iterable of int-convertible values for the list case
    :return: a new state of the same container type (dict or list)
    :raises TypeError: if ``rnnlm_state`` is neither a dict nor a list.
        Previously this case crashed with an ``UnboundLocalError``.
    """
    if isinstance(rnnlm_state, dict):
        # batched state: select along ``dim`` in every tensor of every entry
        return {
            key: [torch.index_select(tensor, dim, vidx) for tensor in tensors]
            for key, tensors in rnnlm_state.items()
        }
    if isinstance(rnnlm_state, list):
        # per-hypothesis state: pick entries and shallow-copy them by slicing
        return [rnnlm_state[int(i)][:] for i in vidx]
    raise TypeError(
        "unsupported rnnlm_state type %s; expected dict or list"
        % type(rnnlm_state).__name__
    )
|
|
|
|
# scorer interface methods
|
|
def init_state(self, x):
    """Build the initial scorer state for one utterance (ScorerInterface).

    :param x: encoder output for a single utterance (a list of them when
        ``self.num_encs != 1``). A leading axis of size 1 is added before
        it is handed to ``self.zero_state``.
    :return: dict with keys ``c_prev`` and ``z_prev`` (shallow copies of the
        per-layer state lists), ``a_prev`` (``None``, or a list of
        ``num_encs + 1`` ``None`` entries for multi-encoder models), and
        ``workspace`` ``(att_idx, z_list, c_list)``
    :rtype: dict

    Side effect: calls ``reset()`` on the attention module(s) used for
    decoding.
    """
    src = (x if self.num_encs == 1 else x[0]).unsqueeze(0)

    # one zero state per layer in each list; at least one layer is always built
    c_list, z_list = [], []
    for _layer in range(max(self.dlayers, 1)):
        c_list.append(self.zero_state(src))
        z_list.append(self.zero_state(src))

    # TODO(karita): support strm_index for `asr_mix`
    strm_index = 0
    att_idx = min(strm_index, len(self.att) - 1)
    if self.num_encs != 1:
        a_prev = [None] * (self.num_encs + 1)  # atts + han
        for att_i in range(self.num_encs + 1):
            self.att[att_i].reset()  # reset pre-computation of h in atts and han
    else:
        a_prev = None
        self.att[att_idx].reset()  # reset pre-computation of h

    return {
        "c_prev": c_list[:],
        "z_prev": z_list[:],
        "a_prev": a_prev,
        "workspace": (att_idx, z_list, c_list),
    }
|
|
|
|
def score(self, yseq, state, x):
    """Compute next-token log-probabilities for one hypothesis (ScorerInterface).

    :param torch.Tensor yseq: prefix token ids; only ``yseq[-1]`` is consumed
    :param dict state: state from ``init_state`` or a previous ``score`` call
        (keys ``workspace``, ``z_prev``, ``c_prev``, ``a_prev``)
    :param x: encoder output for one utterance (a list of them when
        ``self.num_encs != 1``)
    :return: ``(logp, new_state)``. ``logp`` is the log-softmax of the output
        logits with the leading size-1 axis squeezed away. ``new_state`` has
        the same keys as ``state``.
    :rtype: tuple
    """
    encs = [x] if self.num_encs == 1 else x
    att_idx, z_list, c_list = state["workspace"]
    z_prev = state["z_prev"]
    a_prev = state["a_prev"]

    # embed the most recent token as a batch of one
    token = yseq[-1].unsqueeze(0)
    ey = self.dropout_emb(self.embed(token))  # utt list (1) x zdim

    if self.num_encs != 1:
        # one attention per encoder, then the extra (han) attention over them
        att_w = [None] * (self.num_encs + 1)  # atts + han
        contexts = [None] * self.num_encs  # atts
        for enc_i in range(self.num_encs):
            enc = encs[enc_i]
            contexts[enc_i], att_w[enc_i] = self.att[enc_i](
                enc.unsqueeze(0),
                [enc.size(0)],
                self.dropout_dec[0](z_prev[0]),
                a_prev[enc_i],
            )
        att_c, att_w[self.num_encs] = self.att[self.num_encs](
            torch.stack(contexts, dim=1),
            [self.num_encs],
            self.dropout_dec[0](z_prev[0]),
            a_prev[self.num_encs],
        )
    else:
        enc = encs[0]
        att_c, att_w = self.att[att_idx](
            enc.unsqueeze(0),
            [enc.size(0)],
            self.dropout_dec[0](z_prev[0]),
            a_prev,
        )

    ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
    z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, state["c_prev"])

    top = self.dropout_dec[-1](z_list[-1])
    if self.context_residual:
        # feed the attention context to the output layer as well
        top = torch.cat((top, att_c), dim=-1)
    logp = F.log_softmax(self.output(top), dim=1).squeeze(0)

    new_state = {
        "c_prev": c_list[:],
        "z_prev": z_list[:],
        "a_prev": att_w,
        "workspace": (att_idx, z_list, c_list),
    }
    return logp, new_state
|
|
|
|
|
|
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Construct a ``Decoder`` from a parsed argument namespace.

    :param args: namespace providing ``eprojs``, ``dtype``, ``dlayers``,
        ``dunits``, ``verbose``, ``char_list``, ``lsm_weight``,
        ``sampling_probability`` and ``dropout_rate_decoder``. It may also
        provide ``context_residual``, ``replace_sos`` and ``num_encs``.
    :param int odim: dimension of outputs
    :param int sos: start of sequence symbol id
    :param int eos: end of sequence symbol id
    :param att: attention module(s) passed through to ``Decoder``
    :param labeldist: label-smoothing distribution passed through to ``Decoder``
    :return: the constructed ``Decoder``
    """
    # Later options are read with defaults so older ``args`` namespaces
    # (without these attributes) keep working.
    context_residual = getattr(args, "context_residual", False)
    replace_sos = getattr(args, "replace_sos", False)
    num_encs = getattr(args, "num_encs", 1)
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
        args.dropout_rate_decoder,
        context_residual,
        replace_sos,
        num_encs,
    )
|