Kaldi中的决策树

三音素状态绑定这部分通过以下四个步骤完成,分别对应acc-tree-statscluster-phonescompile-questionbuild-tree

  1. 统计量累计
  2. 音素聚类
  3. 问题集生成
  4. 决策树构建

统计量累计

累计统计量在acc-tree-stats.cc中的AccumulateTreeStats函数中完成,接受一个transition model,特征序列和对应的对齐信息,输出统计量表示。

统计量的结构表示:

1
2
3
4
5
6
// 如果EventKeyType表示position信息,那么EventValueType表示该位置上的phone_id
// 如果EventKeyType为kPdfClass,那么EventValueType为对应的pdf_class
// pair<EventKeyType, EventValueType>组合为一个vector表示成一个EventType
// 一般单因素,1+1,三因素,3+1
typedef std::vector<std::pair<EventKeyType, EventValueType> > EventType;
std::map<EventType, GaussClusterable*> tree_stats;

统计量的具体信息在GaussClusterable里面,里面维护的是特征计数,特征向量和其平方和:
1
2
3
4
5
// n
double count_;
// stats_(0) => X1 + X2 + ... + Xn
// stats_(1) => X1^2 + X2^2 + ... + Xn^2
Matrix<double> stats_;

在后续聚类操作的时候,用到了一个Objf函数计算似然值,函数逻辑是:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
BaseFloat GaussClusterable::Objf() const {
size_t dim = stats_.NumCols();
Vector<double> vars(dim);
double objf_per_frame = 0.0;
for (size_t d = 0; d < dim; d++) {
// default var_floor_ = 0.01
double mean(stats_(0, d) / count_), var = stats_(1, d) / count_ - mean * mean,
floored_var = std::max(var, var_floor_);
vars(d) = floored_var;
objf_per_frame += -0.5 * var / floored_var;
}
objf_per_frame += -0.5 * (vars.SumLog() + M_LOG_2PI * dim);
return objf_per_frame * count_;
}

和推出来的似然函数保持一致:

对$\gamma_s^t$的理解:

  1. 为什么似然函数用上式表示
  2. $t$时刻的观测是否来自状态$s$,那么$\gamma^t$最多是一个one-hot的向量
  3. $t$时刻的观测是否来自状态$s$的概率

该过程通过如下几步进行。

  • 把对齐信息按照音素划分,一行tid对应一个音素,存在std::vector<std::vector<int32> >里面。
    注意一下,原始的对齐信息默认应该是进行过重新排序的,就是将自环放在出环之后。

比如对于Transition Model如下(部分):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
Transition-state 1: phone = SIL hmm-state = 0 pdf = 0
Transition-id = 1 p = 0.952097 [self-loop]
Transition-id = 2 p = 0.01 [0 -> 1]
Transition-id = 3 p = 0.01 [0 -> 2]
Transition-id = 4 p = 0.0279074 [0 -> 3]
Transition-state 2: phone = SIL hmm-state = 1 pdf = 1
Transition-id = 5 p = 0.921613 [self-loop]
Transition-id = 6 p = 0.0531309 [1 -> 2]
Transition-id = 7 p = 0.01 [1 -> 3]
Transition-id = 8 p = 0.0152566 [1 -> 4]
Transition-state 3: phone = SIL hmm-state = 2 pdf = 2
Transition-id = 9 p = 0.0146723 [2 -> 1]
Transition-id = 10 p = 0.96533 [self-loop]
Transition-id = 11 p = 0.01 [2 -> 3]
Transition-id = 12 p = 0.01 [2 -> 4]
Transition-state 4: phone = SIL hmm-state = 3 pdf = 3
Transition-id = 13 p = 0.01 [3 -> 1]
Transition-id = 14 p = 0.01 [3 -> 2]
Transition-id = 15 p = 0.928052 [self-loop]
Transition-id = 16 p = 0.0519551 [3 -> 4]
Transition-state 5: phone = SIL hmm-state = 4 pdf = 4
Transition-id = 17 p = 0.957834 [self-loop]
Transition-id = 18 p = 0.0421665 [4 -> 5]
...
Transition-state 9: phone = ONE hmm-state = 0 pdf = 8
Transition-id = 25 p = 0.865902 [self-loop]
Transition-id = 26 p = 0.134098 [0 -> 1]
Transition-state 10: phone = ONE hmm-state = 1 pdf = 9
Transition-id = 27 p = 0.921862 [self-loop]
Transition-id = 28 p = 0.078138 [1 -> 2]
Transition-state 11: phone = ONE hmm-state = 2 pdf = 10
Transition-id = 29 p = 0.936872 [self-loop]
Transition-id = 30 p = 0.0631278 [2 -> 3]
...
Transition-state 24: phone = SIX hmm-state = 0 pdf = 23
Transition-id = 55 p = 0.90631 [self-loop]
Transition-id = 56 p = 0.0936895 [0 -> 1]
Transition-state 25: phone = SIX hmm-state = 1 pdf = 24
Transition-id = 57 p = 0.783409 [self-loop]
Transition-id = 58 p = 0.216591 [1 -> 2]
Transition-state 26: phone = SIX hmm-state = 2 pdf = 25
Transition-id = 59 p = 0.931359 [self-loop]
Transition-id = 60 p = 0.0686414 [2 -> 3]
Transition-state 27: phone = SEVEN hmm-state = 0 pdf = 26
Transition-id = 61 p = 0.916657 [self-loop]
Transition-id = 62 p = 0.0833432 [0 -> 1]
Transition-state 28: phone = SEVEN hmm-state = 1 pdf = 27
Transition-id = 63 p = 0.886809 [self-loop]
Transition-id = 64 p = 0.113191 [1 -> 2]
Transition-state 29: phone = SEVEN hmm-state = 2 pdf = 28
Transition-id = 65 p = 0.898122 [self-loop]
Transition-id = 66 p = 0.101878 [2 -> 3]
...
Transition-state 33: phone = NINE hmm-state = 0 pdf = 32
Transition-id = 73 p = 0.842715 [self-loop]
Transition-id = 74 p = 0.157285 [0 -> 1]
Transition-state 34: phone = NINE hmm-state = 1 pdf = 33
Transition-id = 75 p = 0.78074 [self-loop]
Transition-id = 76 p = 0.21926 [1 -> 2]
Transition-state 35: phone = NINE hmm-state = 2 pdf = 34
Transition-id = 77 p = 0.902379 [self-loop]
Transition-id = 78 p = 0.0976208 [2 -> 3]

那么tid序列可以为
1
74 73 73 73 73 76 75 75 75 75 75 75 78 77 77 77 77 77 77 77 77 77 77 77 77 77 77 62 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 64 63 63 63 63 63 63 63 63 63 63 63 63 63 63 66 65 65 26 28 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 30 4 1 1 1 1 1 16 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 18 56 58 57 57 60 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 26 25 25 25 28 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 30 29 29

按照音素划分之后,得到:

1
2
3
4
5
6
11 => [74 73 73 73 73 76 75 75 75 75 75 75 78 77 77 77 77 77 77 77 77 77 77 77 77 77 77 ]
9 => [62 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 61 64 63 63 63 63 63 63 63 63 63 63 63 63 63 63 66 65 65 ]
3 => [26 28 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 30 ]
1 => [4 1 1 1 1 1 16 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 18 ]
8 => [56 58 57 57 60 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 59 ]
3 => [26 25 25 25 28 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 27 30 29 29 ]

  • 对于每一条对齐信息,获取三音素和pdf_classEventType事件类型,并累加其对应的GaussClusterable统计量。比如说对于以上对齐信息,可以得到音素序列11-9-3-1-8-3三音素的划分和EventType表示如下:
    1
    2
    3
    4
    5
    6
    0-11-9  => (0:0  1:11 2:9)
    11-9-3 => (0:11 1:9 2:3)
    9-3-1 => (0:9 1:3 2:1)
    3-1-8 => (0:3 1:1 2:8)
    1-8-3 => (0:1 1:8 2:3)
    8-3-0 => (0:8 1:3 2:0)
    对于每一种三音素划分,添加pdf_class即状态信息,并累计统计量,比如对于三音素0-11-9对应的统计量累计过程如下:
    1
    2
    3
    tree_stats(0:0 1:11 2:9 kPdfClass:0) += AddStats([74 73 73 73 73])
    tree_stats(0:0 1:11 2:9 kPdfClass:1) += AddStats([76 75 75 75 75 75 75])
    tree_stats(0:0 1:11 2:9 kPdfClass:2) += AddStats([78 77 77 77 77 77 77 77 77 77 77 77 77 77 77])
    其他三音素类推。

该过程执行完毕之后,将tree_stats信息写到磁盘中,实际就是EventTypeGaussClusterable信息。

音素聚类

音素聚类在cluster-phones.cc中的AutomaticallyObtainQuestions中完成,接受统计信息和set.int,输出音素聚类情况,也就是所谓的问题集。

聚类过程

音素聚类主要进行如下几步操作

  • 状态过滤
    默认只保留中间状态(pdf_class = 1)的统计量,这部分代码在FilterStatsByKey中,操作完毕之后,stats转为如下形式(vector存储):
    1
    2
    3
    4
    5
    (0:0  1:11 2:9 kPdfClass:1) => C1
    (0:11 1:9 2:3 kPdfClass:1) => C2
    (0:9 1:3 2:1 kPdfClass:1) => C3
    ...
    (0:8 1:3 2:0 kPdfClass:1) => C6
  • 音素划分
    按中间位置(P == 1)音素对统计量进行划分,累加,划分之后的统计量以phone_id为索引,这部分代码在SplitStatsByKeySumStatsVec中实现,完成之后,统计量转换为如下形式:
    1
    2
    3
    4
    5
    6
    phone_id    GaussClusterable
    1 C4
    3 C3 + C6
    8 C5
    9 C2
    11 C1
  • 按音素集合累加
    由于在set.int中,可能是多个音素共享一个HMM的,但是他们的phone_id是不同的,所以,需要把这些共享HMM的音素对应的统计量做一个合并,存到std::vector<Clusterable*>之中。完成之后,vector的长度和set.int文件的行数相同。

  • 决策树聚类
    这部分在函数TreeCluster中完成,输入按音素集合累加之后的std::vector<Clusterable*>。聚类过程使用决策树+KMeans,KMeans算法主要目的是生成获取似然提升的划分方案。内在逻辑在“TreeClusterer的聚类逻辑”中介绍,最终目的是要获取如下信息,以生成问题集(即聚类结果):

    1
    2
    3
    4
    5
    6
    std::vector<int32> assignments;  
    // assignment of phones to clusters. dim == summed_stats.size().
    std::vector<int32> clust_assignments;
    // Parent of each cluster. Dim == #clusters.
    int32 num_leaves;
    // number of leaf-level clusters. == leaf_node_.size()
  • 获取问题集合
    这部分在函数ObtainSetsOfPhones中实现。函数接受上面决策树聚类得到的三个统计量以及phone_sets(即从set.int读出来的std::vector<std::vector<int32> >),输出一个std::vector<std::vector<int32> >类型的问题集/音素聚类结果。该部分依次完成如下操作:

    1. 根据assignments,对phone_sets进行重新组合,每一类(cluster)对应一组phone_id,因为决策树可能把不同行的音素集划分为同一类了。
    2. 根据clust_assignments,组合出所有非叶子节点对应的phone_id,实际上就是其子节点phone_id之和,也就是要得到决策树上每一个节点音素id集合
    3. 补上原始phone_sets内的向量(暂时不理解),去空

最后将该聚类结果输入到文本question.int中,注意,实际看到的question.int可能还加上了一部分extra_question.int,比如在hkust中就是。展示question.int如下(不是完整的,加工自hkust,把set.int中每一行用对应的行号表示)

TreeClusterer的聚类逻辑

聚类过程通过TreeClusterer完成,下面主要分析TreeClusterer的聚类逻辑。

根据决策树理论,每一次节点分裂都是需要找到最大似然提升的方案,但是由于事先的这些音素统计量并没有标记信息,即无法根据标记信息制定划分方案,所以,一般采用无监督的方法进行二分类,取其最大的似然提升方案,kaldi中用到了KMeans聚类算法给出一次划分方案。

  • 决策树初始化
    初始化给决策树建立根节点,并执行一次最优划分FindBestSplit。决策树由一系列Node构成,其中维护了其自身和叶子节点的统计信息。节点信息表示如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    struct Node {
    bool is_leaf;
    int32 index; // index into leaf_nodes or nonleaf_nodes as applicable.
    Node *parent;
    Clusterable *node_total; // sum of all data with this node.
    struct {
    std::vector<Clusterable*> points;
    std::vector<int32> point_indices; // index of points
    BaseFloat best_split;
    std::vector<Clusterable*> clusters; // [branch_factor]... if we do split.
    std::vector<int32> assignments; // assignments of points to clusters.
    } leaf;
    std::vector<Node*> children; // vector of size branch_factor. if non-leaf.
    // pointers not owned here but in vectors leaf_nodes_, nonleaf_nodes_.
    };

    FindBestSplit总是在节点的points数据上执行最优划分,划分结果存储在clustersassignments中。assignments表示point_id属于哪一个cluster。划分成功的话,将获得的最大似然提升和节点信息放在一个优先队列中。

  • 最优划分逻辑
    一次KMeans分类操作定义在ClusterKMeansOnce中,迭代cfg.num_iters次返回结果(了解KMeans原理就知道为什么这么做了)。
    一次最优节点分裂FindBestSplit执行一次ClusterKMeansClusterKMeans具体操作是执行cfg.num_triesClusterKMeansOnce,找到最大获得最大似然提升的方案。

  • 聚类逻辑
    聚类采取BFS逻辑,不断的取出获取最大似然提升的节点,扩展子节点并在其子节点进行最优划分(这部分操作定义在DoSplit之中,主要是将父节点的leaf信息转移到新建子节点之中,并对子节点执行FindBestSplit),直至队列为空或者叶子节点达到上界,注意,这里的上界一般设为phone_sets的行数,假设set.int中有48组音素,那么聚类结果一定不大于48。

    在函数DoSplit中,需要完成以下操作:

    1. 根据父节点的leaf.assignments划分结果,初始化子节点的leaf.pointsleaf.point_indices
    2. 根据父节点的leaf.clusters,初始化子节点的node_total
    3. 给子节点的index编号,左节点继承父节点编号,在leaf_nodes_中替换父节点,右节点赋值为leaf_nodes_.size(),加入leaf_nodes_
    4. 将父节点标记为非叶子节点,index赋值为非叶子节点的编号nonleaf_nodes_.size()并加入nonleaf_nodes_
    5. 清空父节点的leaf信息。
  • 生成聚类信息
    num_leaves_out就是叶子节点的个数,其次主要就是要获取assignmentsclust_assignments两类信息。
    assignments是音素集合id到决策树中叶子节点id的映射关系,只需要遍历所有叶子节点,每一次将对应叶子节点的point_indices集合中元素下标处赋值为叶子节点id即可。
    clust_assignments用来维护节点之间的父子关系,长度为叶子节点和非叶子节点个数之和。由于节点index都是从0开始的,所以会有冲突。kaldi中把非叶子节点通过clust_assignments.size() - 1 - nonleaf_index映射到区间[leaf_nodes_.size(), clust_assignments.size()]

问题编译

该部分接受topo文件和question.int输出编译好的问题集,实际就是Questions这个类结构。定义在compile-question.cc中。Questions维护了EventKeyType和其对应的QuestionsForKey集合,其中QuestionsForKey表示对于特定EventKeyType的查询问题集。事实上,输出的问题集包含以下信息:

1
2
3
4
5
EventKeyType          QuestionsForKey
0 PhoneQuestions(question.int)
1 PhoneQuestions(question.int)
2 PhoneQuestions(question.int)
kPdfClass PdfClassQuestion

其中pdfClassQuestion根据每个HMM建模的状态数而不同:
1
2
3
MaxNumPdfclasses        PdfClassQuestion
3 [[0] [0 1]]
5 [[0] [0 1] [0 1 2] [0 1 2 3]]

决策树构建

未完待续……