kaldi中ivector的提取程序在ivector-extract
和ivector-extract-online2
中,分别提取离线和在线的ivector。考虑到后续需要分析online的解码逻辑,所以在第二篇笔记中会仔细介绍在线情况下统计信息的累计和ivector估计方法。本篇主要介绍离线方法以及在线方法的整体框架。
离线方法
离线方法是指,在事先获取了句子的特征和高斯后验的情况下,估计出一个ivector向量。
对于句子 $\mathbf{U}_{T \times F} = \{\mathbf{x}_1, \mathbf{x}_2, \cdots, \mathbf{x}_T\}^T$和对应的高斯后验 $\mathbf{P}_{T \times C}= \{\mathbf{p}_1, \mathbf{p}_2, \cdots, \mathbf{p}_T\}^T$,首先得到提取ivector所需要的零阶和一阶统计量:
存储在$\boldsymbol{\gamma}_{C \times 1}$和$\mathbf{F}_{C \times F}$中。
之后利用加载好的ivector提取器,调用GetIvectorDistribution
函数并传入上述统计量将得到的均值作为ivector。ivector提取器IvectorExtractor
中维护ivector提取过程中说话人无关的一些量,比如$\mathbf{T}$矩阵和UBM方差等等,由于kaldi代码中变量命名和我们常见的公式中不一致,这里以论文中常见的表示为标准。比较重要的量如下:1
2
3
4std::vector<Matrix<double> > M_;
std::vector<SpMatrix<double> > Sigma_inv_;
Matrix<double> U_; //上三角作为一行
std::vector<Matrix<double> > Sigma_inv_M_;
其中M_
表示$\mathbf{T}_{C \times F \times R}$,表示为:
Sigma_inv_
表示$\mathbf{\Sigma}^{-1}_{C \times F \times F}$,表示为:
U_
和Sigma_inv_M_
由上面两个变量计算得到,分别定义为$\mathbf{U}_{C \times R \times R}, \mathbf{B}_{C \times F \times R}$:
令$\mathbf{w} = \mathbf{Q}^{-1}\mathbf{L}$,那么$\mathbf{Q}, \mathbf{L}$的计算过程如下:
观察到$\mathbf{T}_i^T \mathbf{\Sigma}^{-1}_i\mathbf{T}_i$为对称矩阵,所以为了存储高效,保留上三角即可,因而$\mathbf{U}_{C \times R \times R}$可以退化为$\mathbf{U}_{C \times (R + 1) / 2}$,写成:
那么$\mathbf{Q}_{R \times R}$可由$\mathbf{U}^T_{(R + 1) / 2 \times C} \boldsymbol{\gamma}_{C \times 1}$计算得到。
上述过程对应的代码在下面的函数中实现,参数mean
即为传入的ivector向量。1
2
3
4void IvectorExtractor::GetIvectorDistribution (
const IvectorExtractorUtteranceStats &utt_stats,
VectorBase<double> *mean,
SpMatrix<double> *var) const
在线方法
在线方法需要提供一个配置文件,这个文件使用脚本prepare_online_decoding.sh
生成,一个典型的配置文件如下:1
2
3
4
5
6
7
8
9
10
11--splice-config=$PREFIX/exp/nnet2_online/nnet2_ms_online/conf/splice.conf
--cmvn-config=$PREFIX/exp/nnet2_online/nnet2_ms_online/conf/online_cmvn.conf
--lda-matrix=$PREFIX/exp/nnet2_online/nnet2_ms_online/ivector_extractor/final.mat
--global-cmvn-stats=$PREFIX/exp/nnet2_online/nnet2_ms_online/ivector_extractor/global_cmvn.stats
--diag-ubm=$PREFIX/exp/nnet2_online/nnet2_ms_online/ivector_extractor/final.dubm
--ivector-extractor=$PREFIX/exp/nnet2_online/nnet2_ms_online/ivector_extractor/final.ie
--num-gselect=5
--min-post=0.025
--posterior-scale=0.1
--max-remembered-frames=1000
--max-count=100
其中前六个比较熟悉,分别是拼帧配置,在线cmvn配置,LDA变换矩阵,全局cmvn统计量,UBM和训练好的ivector提取器。还有两个配置参数比较重要,分别是--use_most_recent_ivector
和--ivector_period
。前者默认为true
,表示每次使用最估计的ivector,否则计算出的ivector需要缓存下来,以便获取到设定时间估计出的ivector,后者设置每多少帧估计一个ivector。
这些配置文件用来初始化OnlineIvectorExtractionConfig
,具体的对象载入在OnlineIvectorExtractionInfo
中完成,前者作为后者初始化的参数。核心代码流程如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
std::string spk = spk2utt_reader.Key();
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
// adaptation_state是针对一个同一个说话人的,结构体成员:
// OnlineCmvnState cmvn_state;
// OnlineIvectorEstimationStats ivector_stats;
OnlineIvectorExtractorAdaptationState adaptation_state(ivector_info);
// 对于同一个说话人的所有句子
for (size_t i = 0; i < uttlist.size(); i++) {
std::string utt = uttlist[i];
// 得到该句子的特征
const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
OnlineMatrixFeature matrix_feature(feats);
// 根据ivector_info初始化一系列具体的online特征
// 比如OnlineSpliceFrames,OnlineTransform,OnlineCmvn等等
OnlineIvectorFeature ivector_feature(ivector_info, &matrix_feature);
// ivector_stats_和cmvn_初始化
ivector_feature.SetAdaptationState(adaptation_state);
// repeat 默认false,ivector_period表示每多少帧取一个ivector,默认为10
int32 T = feats.NumRows(),
n = (repeat ? 1 : ivector_config.ivector_period),
num_ivectors = (T + n - 1) / n;
// num_ivectors 决定一句话提取多少个ivector
Matrix<BaseFloat> ivectors(num_ivectors, ivector_feature.Dim());
for (int32 i = 0; i < num_ivectors; i++) {
int32 t = i * n;
// 对应第i个ivector
SubVector<BaseFloat> ivector(ivectors, i);
// 核心过程,调用函数UpdateStatsUntilFrame
ivector_feature.GetFrame(t, &ivector);
}
// Update diagnostics.
// ...
// 更新adaptation_state
ivector_feature.GetAdaptationState(&adaptation_state);
// 完成提取
ivector_writer.Write(utt, ivectors);
num_done++;
}
}
上述过程的核心:GetFrame
方法的逻辑会在kaldi中ivector的提取【二】中详细分析。