解码是指在构建好的解码网络(图)中,根据输入的音频,生成最优序列的过程。在kaidi中,解码之前,解码网络(实际上是一个巨大的WFST)是已经构建好了的(一般称之为静态解码器),也就是熟知的HCLG。在这张图中,状态节点并无太大意义(毕竟不是自动机),信息存储在状态和状态的转移边之间。输入为tid(tid为0表示没有输入/输入为空$\epsilon$,因而可以连续跳转),输出为词,权值为语言模型的权值。解码过程中,声学模型的后验概率需要实时计算得出(特征+声学模型)。理解解码过程需要弄清以下几个方面:
- 解码过程中使用到的数据结构及其拓扑关系
- 解码过程中的剪枝策略(个人觉得这部分比较抽象)
- 解码过程中的特征供给
下面是以LatticeFasterOnlineDecoder
为分析对象的一些记录。
数据结构
解码过程是一帧一帧进行的,在约束图(解码图)中,$t$时刻可以到达的状态必然是由$t - 1$时刻的状态出发的,也就是说,信息流从$t - 1$时刻向$t$时刻传递。如果不考虑效率问题,将每个时刻可能到达的状态按照时间展开,那么就可以得到一颗巨大的树。每一个从根节点到叶子节点的路径都可以理解为一条可能路径,只不过有的出现概率高,有的低而已。
用token表示上述过程中描述的信息,那么每一时刻,每一状态维护一个token,在$t \in [0, T]$时刻,新的一轮token由$t - 1$时刻的token和解码图的路径约束共同形成。kaldi中token用结构体Token
来描述,同一时刻的token通过next
指针相连,从该时刻可以传递到的状态用ForwardLink
描述,前一时刻的token由HashList<StateId, Token*>
描述(这是一个kaldi自己实现的模板,其中的Hash元素本质上是使用链表连接的,但是也可以像Hash表一样随机获取),而整个解码过程中每一时刻产生的token(相当于上面说的树)用std::vector<TokenList>
描述,下标为当前帧数,TokenList
相当于一个Token
的链表头,通过它可以依次获取到相应时刻的所有活跃token,具体的成员及其含义如下:
Token
1
2
3
4
5
6
7
8struct Token {
BaseFloat tot_cost; // 到该状态的累计cost
BaseFloat extra_cost;
ForwardLink *links; // 也是一个链表,因为由该token可以到达下一时刻的token不止一个
Token *next; // 指向同一时刻的下一个token
Token *backpointer; // 指向上一时刻的最佳token,相当于一个回溯指针,
// 到达该状态的token可能会有很多,但只取最优的一个
};ForwardLink
1
2
3
4
5
6
7
8struct ForwardLink {
Token *next_tok; // 这条链接指向的token
Label ilabel; // 这下面的四个量取自解码图中的跳转/弧/边,因为每一个状态
Label olabel; // 维护一个token,那么token到token之间的连接信息和状态到状态之间的信息
BaseFloat graph_cost; // 应该保持一致,所以会有输入(tid),输出,权值(就是graph_cost)
BaseFloat acoustic_cost; // acoustic_cost就是tid对应的pdf_id的在声学模型中的后验
ForwardLink *next; // 链表结构,指向下一个
};TokenList
1
2
3
4
5struct TokenList {
Token *toks; // 同一时刻的token链表头
bool must_prune_forward_links; // 这两个是Lattice剪枝标记
bool must_prune_tokens;
};
这几部分的关系可以用下图描述:
另一个重要的数据结构是kaldi中自己实现的HashList
,一个类似于跳表的数据结构,里面维护的元素本质上以链表的形式相连,同时通过一个Hash表建立索引。元素和哈希通过Elem
和HashBucket
实现:1
2
3
4
5
6
7
8
9
10struct Elem {
I key; // State
T val; // Token
Elem *tail; // 链表,指向下一个元素
};
struct HashBucket {
size_t prev_bucket; // 指向前一个桶的下标,类似静态链表的索引方法
Elem *last_elem; // 指向挂在该桶上的最后一个元素,找到他就可以索引该桶上所有元素了
}
整个哈希结结构存在容器std::vector<HashBucket> buckets_
中,通过SetSize()
可以分配buckets_
的大小。需要注意的是,前一个HashBucket
的last_elem.tail
指向当前HashBucket
的第一个元素。其他重要的变量如下:1
2
3
4Elem *list_head_; // Elem链表会不断释放,分配,该变量记录链表头部
Elem *freed_head_; // 记录空闲链表的头部,分配新的Elem就是从该头部取出一个空闲Elem
size_t bucket_list_tail_; // 当前活跃的最后一个桶的下标
size_t hash_size_; // 当前活跃的桶的个数
该数据结构的操作逻辑如下:
获取当前活跃的所有元素
Elem
只需要返回list_head
即可,通过链表的遍历即可访问到所有活跃元素(即上一时刻的State和对应的Token)。如何遍历Hash
通过bucket_list_tail_
获取到最后一个活跃桶下标,通过HashBucket.prev_bucket
访问到前一个,直到访问到-1标记结束位为止。如何清空Hash
清空Hash只需要将相应的标记置为“空”即可,不需要实际释放元素内存。对于Hash而言,遍历一遍,将HashBucket.last_elem
置为NULL
,bucket_list_tail_
置为-1,对于Elem
而言,将链表头部list_head_
置为NULL
即可。如何删除元素
将需要删除的元素从list_head_
中插入到freed_head_
中,采用头插法(插入freed_head_
头部)如何查找元素
首先通过哈希函数对Key
哈希到下标,定位到具体的bucket,之后,根据之前提到的关系:“前一个HashBucket
的last_elem.tail
指向当前HashBucket
的第一个元素*first_elem
”,就可以在链表头尾之间遍历查询了。如何增加元素
如果freed_head_
不为空,那么意味着存在可用的空闲元素,将free_head_
指向的元素返回,并后移一位即可,如何已经耗尽,那么新分配一批(默认是1024个)Elem
赋给free_head_
,重复之前的过程即可。如何插入元素
首先拿出一个空闲的元素,赋值为相应的Key/Value
,使用Key
哈希到bucket的下标:
- 如果该桶不为空,将该元素插入到当前桶的尾部即可
- 若该桶为空,需要将
bucket_list_tail_
指向该桶,还要修正该桶的prev_bucket
(指向修改前的bucket_list_tail_
)以及last_elem
(指向该元素)。
如果调用toks = HashList.Clear()
,那么拿到的toks
实际上是被置为NULL
之前的HashList.list_head_
的值,也就是清空前哈希的链表结构,这时候由于HashList
已经重置过了,所以toks
成为了前一状态产生的哈希/Elem
链表的唯一引用。之后再对HashList
操作都不会干扰到toks
。因此,可以用toks
来替代prev_toks
,清空之后的HashList
代替cur_toks
。由于我们只需要遍历toks
,用一下token中的信息,使用完之后,通过Delete
函数,就可以将toks
中的元素插入到HashList.free_head_
中,实现内存的回收。
解码逻辑
解码的逻辑非常简单,在解码正式开始之前调用InitDecoding()
做一些初始化的工作,之后就是不断接受新的音频数据,调用AdvanceDecoding()
(每一次DecodableInterface
中的新的可用特征一次解码完毕),直至音频输入终止,调用FinalizeDecoding()
结束解码。
在AdvanceDecoding()
中,对于每一帧调用ProcessEmitting()
和ProcessNonemitting()
,后者处理ilabel == 0
的token传递(考虑到解码图中每一个状态均有自环,所以可以肯定的是,自环的输入label必然不能为空,否则就会陷入死循环)。每隔几帧剪一次枝。主要逻辑代码如下:1
2
3
4
5
6
7
8
9
10
11// 解码到没有为止
// NumFramesDecoded(): active_toks_.size() - 1
while (NumFramesDecoded() < target_frames_decoded) {
// 剪枝:默认25
if (NumFramesDecoded() % config_.prune_interval == 0) {
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
}
// 这里active_toks_扩充一个空间
BaseFloat cost_cutoff = ProcessEmitting(decodable);
ProcessNonemitting(cost_cutoff);
}
下面对每个函数块逐一分析:
ProcessEmitting
这部分代码主要实现两个功能,一个是token信息沿着输入不为空的弧上的传递,另一个就是若干剪枝阈值的评估。这部分的剪枝有两层,第一层决前一时刻的token能否继续存在,也就是说,如果前一时刻某一个token的tot_cost
相比最好的token的tok_cost
差的太多,那么它就没有必要继续传递下去了(这条路径概率过低)。第二层决定前一时刻的token是否能将信息沿着解码图中的路径约束传递到当前时刻,这一步需要预估一个当前时刻的最优权值,如果token沿着某条弧转移过来过低于我们这个预估的值,那么也不允许它传递(因为就算传递了,也会在下一时刻的第一层剪枝中被剪掉)。
解释一下所谓的beam,beam的含义是,允许过滤项和最优项之间的差距在beam之内。解码过程中的cost越低越优,因此,当得出所谓的best_cost
之后,往往将best_cost + beam
最为剪枝的阈值。
代码中将上述第一层剪枝的阈值命名为cur_cutoff
,第二层剪枝阈值命名为next_cutoff
。第一层剪枝阈值的计算在函数GetCutoff()
中实现,这个函数不只计算出前一时刻的剪枝阈值,同时得出前一时刻token链表中token的个数,最佳token,以及估计next_cutoff
需要的adaptive_beam
。
考虑到解码中允许设置--max-active(default = MAX_INT)
和--min-active(default = 200)
,不同的设置在GetCutoff()
中的计算逻辑不同,因为beam
本身就是一个容差估计(只是这种估计并没有什么参照标准),但是现在有新的约束条件(token数目)约束了,因而需要结合在这种约束下的估计值选一个。分如下两种情况:
假设没有约束,即某一时刻的token数目没有限制,那么使用
beam
作为容差(默认为16),即:1
2cur_cutoff = best_weight + config_.beam; // default config_.beam = 16.0
adaptive_beam = config_.beam;存在$[N_{min}, N_{max}]$约束:
那么令$C_{min}$(max_active_cutoff
)和$C_{max}$(min_active_cutoff
)分别为token链中第max-active和min-active小的cost。显然,前者小于后者(约束更紧)。用$W_{best}$表示best_weight
,$b$表示config_.beam
:
那么cur_cutoff
取$\min_{2nd}(W_{best} + b, C_{min}, C_{max})$。如果结果是$W_{best} + b$,那么和情况1结果保持一致,否则adaptive_beam
计算如下:1
2// config_.beam_delta = 0.5
adaptive_beam = cur_cutoff - best_weight + config_.beam_delta;
因此,adaptive_beam
的作用是作为next_cutoff
的beam
替代,在--min-active/--max-active
形成约束力的时候。
对于next_cutoff
,使用上一时刻最优token前向传递产生的cost(加上adaptive_beam
)作为初始值。之后在处理token第二层剪枝过程中,如果传递形成的新的cost高于next_cutoff
,则不予传递,否则产生新的token。若tot_cost + adaptive_beam < next_cutoff
,则更新next_cutoff
。这部分代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) {
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
!aiter.Done();
aiter.Next()) {
// 遍历状态的连接弧
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) {
BaseFloat ac_cost = cost_offset -
decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value(),
cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
// 剪枝
if (tot_cost > next_cutoff) continue;
// 继续更新,这是一个不断估计的过程
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff = tot_cost + adaptive_beam;
// 一开始调用时tok_已经是空的了。这一步产生的是新的tok
Token *next_tok = FindOrAddToken(arc.nextstate,
frame + 1, tot_cost, tok, NULL);
// 把当前时刻的新token加入上一时刻的前向链表中
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
}
}
// 下一个element
e_tail = e->tail;
// 从链表中拿掉/回收
toks_.Delete(e);
}
函数FindOrAddToken()
的作用是创建新的token(当前时刻)或者更新已有token的参数(tot_cost
或者回溯指针等等)。因为同一个状态可能会有不同的token带着不同的tot_cost
到达,但是只保留最优的一个。之前也说过,新建的token以链表形式相连,头部挂在std::vector<TokenList> active_toks_
中:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20FindOrAddToken(StateId state, int32 frame_plus_one, BaseFloat tot_cost,
Token *backpointer, bool *changed) {
// ...
Token *&toks = active_toks_[frame_plus_one].toks; // 注意引用,实际修改了
Elem *e_found = toks_.Find(state);
if (e_found == NULL) {
const BaseFloat extra_cost = 0.0;
// NULL表示暂时没有forwardlinks,t时刻只能形成t-1时刻的前向链表
// 头插法:这里new_tok.next = toks
// 这里的new_tok暂时不释放的
Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer);
toks = new_tok;
num_toks_++; // 整个解码过程中的token数目
// 插入hash中维护,维护的是Elem,删除它不会释放token
toks_.Insert(state, new_tok);
if (changed) *changed = true;
return new_tok;
}
// ...
}active_toks_
也是一个十分重要的变量,它的大小是逐渐增加的,每一次调用ProcessEmitting()
都会扩充一次。整个解码器用它的大小来追溯已经处理了的总帧数:1
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
假设现在处理的是第frame帧(frame从0开始计数),那么要获得该帧产生的token链表,active_toks_
的index应该是frame + 1。
ProcessNonemitting
从函数名可以看出,这一步处理的当前帧下输入为$\epsilon$的跳转/弧,也就是说没有声学的观测概率。这部分代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48ProcessNonemitting(BaseFloat cutoff) {
KALDI_ASSERT(!active_toks_.empty());
// active_toks_在ProcessEmitting()中已经扩容了/+1,所以要访问
// 当前帧需要-2
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
KALDI_ASSERT(queue_.empty());
// push当前时刻到达的状态
// toks_.GetList()获取的是在ProcessEmitting()中新产生的token
// 也就是当前时刻的token
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail)
queue_.push_back(e->key);
// queue_中存储的是由当前时刻token对应的状态出发,可以通过空跳转
// 到达的所有状态
while (!queue_.empty()) {
StateId state = queue_.back();
queue_.pop_back();
Token *tok = toks_.Find(state)->val;
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost > cutoff)
continue;
// 一般而言,当前时刻的token没有前向链表
tok->DeleteForwardLinks();
tok->links = NULL;
for (fst::ArcIterator<fst::Fst<Arc> > aiter(fst_, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
// ilabel == 0表示non-emitting
if (arc.ilabel == 0) { // propagate nonemitting only...
BaseFloat graph_cost = arc.weight.Value(),
tot_cost = cur_cost + graph_cost; // acoustic_cost = 0
if (tot_cost < cutoff) {
bool changed;
Token *new_tok = FindOrAddToken(arc.nextstate,
frame + 1, tot_cost,
tok, &changed);
// ilabel == 0则acoustic_cost = 0
// tok指的是当前时刻的token
tok->links = new ForwardLink(new_tok, 0, arc.olabel,
graph_cost, 0, tok->links);
// 形成新的token或者更新了权值,说明新状态可到达
if (changed) queue_.push_back(arc.nextstate);
}
}
}
}
}
总体而言,解码过程就是不断的对一帧处理可观测跳转和空跳转,在拓展token的过程中执行剪枝策略,并将每一时刻的token链保存在active_tok_
中,用于最终形成Lattice。
PruneActiveTokens
除了在token传递过程中执行剪枝,kaldi还会每隔几帧执行一次PruneActiveTokens
操作。该操作虽然也会删掉一些不必要的token,但是阈值的计算是在active_toks_
中进行的(存在时间跨度)。首先执行PruneForwardLinks()
,剪去一些token的前向指针,之后会调用PruneTokensForFrame()
,将前向链接为空的token删去。
在PruneForwardLinks()
中,对于当前时刻的每一个token,程序会遍历其ForwardLinks,算出每一条link和最优路径的cost差/距离link_extra_cost
,如果差距过大(大于lattice_beam
)就剪掉该link。token->extra_cost
置为所有前向链接的link_extra_cost
中最小的一个(如果前向链接被删完了,token->extra_cost
会被置为无穷)。在下一步的PruneTokensForFrame()
就是通过token->extra_cost
来断定该token是否有前向链接的。整个搜索过程是一个时间轴上的回溯过程,即从当前时刻向初始时刻0开始。这里需要注意一个顺序问题,比如在$t - 1$时刻执行PruneForwardLinks()
,需要用到$t$时刻的token信息,因此这里的执行顺序应该是先执行$t$时刻的Links剪枝,再执行$t + 1$时刻的Token剪枝。PruneForwardLinks()
核心操作代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
bool *links_pruned, BaseFloat delta) {
// ...
bool changed = true; // 新的tok_extra_cost - 旧的tok_extra_cost > 1.0 ?
while (changed) {
changed = false;
// 当前时刻的每一个token
for (Token *tok = active_toks_[frame_plus_one].toks;
tok != NULL; tok = tok->next) {
ForwardLink *link, *prev_link = NULL;
BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// 对于每一个token的每一个前向链接
// tok_extra_cost 取最小的link_extra_cost
for (link = tok->links; link != NULL; ) {
Token *next_tok = link->next_tok;
// extra_cost 初始化的时候为0
// 和最优路径的差距
// (tok->tot_cost + link->acoustic_cost + link->graph_cost)
// - next_tok->tot_cost >= 0
BaseFloat link_extra_cost = next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost)
- next_tok->tot_cost);
// 超过了阈值,删掉link
if (link_extra_cost > config_.lattice_beam) {
ForwardLink *next_link = link->next;
if (prev_link != NULL) prev_link->next = next_link;
else tok->links = next_link;
delete link;
link = next_link;
*links_pruned = true; // 表示有link删除,可能产生没有link的token了
} else { // 更新tok_extra_cost,保留link
if (link_extra_cost < 0.0) { // 正常不会这样
if (link_extra_cost < -0.01)
KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
} // keep最小的
if (link_extra_cost < tok_extra_cost)
tok_extra_cost = link_extra_cost;
// 指向没有被删除的最后一个
prev_link = link;
link = link->next; // 下一个link
}
} // ForwardLink循环结束
// delta默认1.0
// 新的tok_extra_cost - 旧的tok_extra_cost > 1.0
if (fabs(tok_extra_cost - tok->extra_cost) > delta)
changed = true;
tok->extra_cost = tok_extra_cost;
// 要么+infinity 要么 <= lattice_beam_
// +infinity就会被剪掉
}
// 曾经有一次是true,它就是true
// extra_costs_changed表示tok->extra_cost更新了
// 因此前一时刻的tokens的extra_cost也需要重新计算了
if (changed) *extra_costs_changed = true;
}
}PruneForwardLinks()
会在两中情况下被执行:
- 在
active_toks_
中有新的TokenList
被拓展,因为初始化的时候must_prune_forward_links
和must_prune_tokens
为true
。 - 下一时刻的
TokenList
中,有token
的extra_cost
变化了,所以之前时刻都需要重新计算。PruneActiveTokens()
的核心代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25PruneActiveTokens(BaseFloat delta) {
// ...
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
// extra_costs_changed表示f时刻的token->extra_cost发生了变化
// 因此之前时刻的token->extra_cost的需要相应的重新计算
if (extra_costs_changed && f > 0)
active_toks_[f-1].must_prune_forward_links = true;
// links_pruned 表示是否有link被剪掉了
if (links_pruned)
active_toks_[f].must_prune_tokens = true;
// 剪完了
active_toks_[f].must_prune_forward_links = false;
}
// f + 1 != cur_frame_plus_one - 1 还没有ForwordLink
// f + 1时刻
if (f+1 < cur_frame_plus_one &&
active_toks_[f+1].must_prune_tokens) {
PruneTokensForFrame(f+1);
active_toks_[f+1].must_prune_tokens = false;
}
}
}
特征供给
这部分会和ivector的在线提取放在一起,说明kaldi对Online-Feature的包装。