上篇《xgboost: A Scalable Tree Boosting System论文及源码导读》介绍了xgboost的框架和代码结构。本篇将继续讨论代码细节,可能比较枯燥,坚持一下哈。
接下来按这个顺序整理笔记,介绍xgboost的核心代码:

  1. 树的结构如何?
  2. 树的操作有哪些?
    主要包括:怎么生成一棵树? 何时,且如何进行剪枝?什么时候进行数据采样,怎么采样? 分布式并行怎么完成?

上面这2个问题是寻求实现具体算法的思考过程,笔者也在一边读一边记录,有理解错谬的地方请指出。

树的结构如何?

从代码的UML图可以看到RegTree派生自TreeModel类,TreeModel类中定义了子类Node,结构如下:
树的定义
从代码角度介绍两个重要的数据结构TreeModel

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
//TreeModel主要成员有:
std::vector<Node> nodes;// 节点
std::vector<int> deleted_nodes;// 记录已删除节点
std::vector<TNodeStat> stats; //节点的统计量
std::vector<bst_float> leaf_vector; //存储节点的增量信息
//介绍上图TreeModel类的两个模板参数template<typename TSplitCond, typename TNodeStat>
typedef TSplitCond SplitCond; //对于回归树的话,分裂值类型为float。
//节点的统计量特征
typedef TNodeStat Nodestat;
//Nodestat根据不同的树类型进行定义,如下RTreeNodeStat结构体。
struct RTreeNodeStat{
loss_chg; //当前节点的loss change
sum_hess, //hessian值的和
base_weight, //当前节点的权值
leaf_child_cnt, //当前孩子节点的数目
}

而Node具体代码为:

1
2
3
4
5
6
7
8
9
10
11
class Node{
int parent_; //父类
int cleft_, cright_; //左右叶子
unsigned sindex_; //选作分割节点的特征,最高位代表了方向左右,1表示左边。
Info info_; //union结构,存储叶子的值,或者分裂条件。结构如下:
//union Info{
//float leaf_value;
//TSplitCond split_cond;
//};
//此外包含许多相关的节点操作,例如set_split(), set_leaf(), set_parent()
}

对于基本数据结构稍事了解之后,我们开始进入主题,也是我着手写这篇文章的目的:

怎么生成一棵树?

在前一篇博客里面比较详细涉及到xgboost的一些算法,这些不再重复叙述,这里参考了杨军[2]的介绍,按单机版结合前面介绍的代码结构图,先进行一遍代码走读:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
Cli_main.cc:
main()
-> CLIRunTask()
-> CLITrain()
-> DMatrix::Load()
-> learner = Learner::Create()
-> learner->Configure()
-> learner->InitModel()
-> for (i = 0; i < param.num_round; ++i)//train
-> learner->UpdateOneIter()
-> learner->Save()
learner.cc:
Create()
-> new LearnerImpl() //派生自learner
Configure() //配置obj, updater和模型的相关参数
InitModel()
-> LazyInitModel()
-> obj_ = ObjFunction::Create()
// objective.cc
// Create():
// SoftmaxMultiClassObj(multiclass_obj.cc)
// LambdaRankObj(rank_obj.cc)
// RegLossObj(regression_obj.cc)
// PoissonRegression(regression_obj.cc)
-> gbm_ = GradientBooster::Create()
// gbm.cc
// Create()
// GBTree(gbtree.cc)
// GBLinear(gblinear.cc)
-> obj_->Configure()
-> gbm_->Configure()
UpdateOneIter()
-> PredictRaw() //预测样本标签
-> obj_->GetGradient() //计算样本的一阶导,二阶导
-> gbm_->DoBoost() //进行boost(tree model/linear model)
/*当gbm为gbtree*/
gbtree.cc:
Configure() //根据配置初始化树相关操作
-> for (up in updaters)
-> up->Init()
DoBoost() //每一级生成一颗RegTree
-> BoostNewTrees()
-> new_tree = new RegTree()
-> for (up in updaters)
-> up->Update(new_tree)
/*updaters的定义*/
tree_updater.cc:
Create()
-> ColMaker/DistColMaker(updater_colmaker.cc)
SketchMaker(updater_skmaker.cc)
TreeRefresher(updater_refresh.cc)
TreePruner(updater_prune.cc)
HistMaker / CQHistMaker /
GlobalProposalHistMaker/
QuantileHistMaker(updater_histmaker.cc)
TreeSyncher(updater_sync.cc)

从代码走读中可以看到,gbtree最核心处定义了ColMaker, DistColMaker, LocalHistMaker, GlobalHistMaker, HistMaker, TreePruner, TreeRefesher, SketchMaker, TreeSyncher 等多种树操作,在论文里有介绍的仅有部分。
下面逐一进行分解,看看这不同的操作实质上代表了什么动作,如何实现。
xgboost在树的操作定义使用装饰(decorator)模式[1],基类TreeUpdater为抽象构件,定义了基本接口:
TreeUpdater
初始化init(),树的更新操作update(), 取得更新操作后叶子位置GetLeafPosition(), 产生具体构建updater的creat()函数。

ColMaker

从Updater类中派生ColMaker为具体构建类(为论文实现的单机多线程版本),主要用于实现单棵树生成,类的结构如下图。
Comaker
ColMaker类中最重要为树的Builder 结构体,在Builder中实现ColMaker主要算法,主要实现在Builder->Update()函数中。

下面具体介绍下Builder->Update()函数的核心代码(updater_colmaker.cc):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
//核心代码:
//初始化数据和节点的映射关系,根据配置进行样本的降采样(伯努利采样)
//qexpand_用于存储每次探索出候选树节点,初始化为root节点。
this->InitData(gpair, *p_fmat, *p_tree);
//计算qexpand_队列中所有候选节点的损失函数和权重
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
//根据树的最大深度进行生长
for (int depth = 0; depth < param.max_depth; ++depth) {
//查找最佳分裂点
this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
this->UpdateSolution()//找到分裂点,有两种并行方式
//特征间并行方式
//每个线程处理一维特征,遍历数据累计统计量(grad/hess)得到最佳分裂点split_point
this->EnumerateSplit()
//特征内并行方式
//在每个线程里汇总各个线程内分配到的数据样本的统计量(grad/hess);
//每个线程输出对应样本整体的统计量,得到一个线程级别统计量数组
//在这个组内进行枚举选出最佳分裂点,进一步定位到对应线程的最优分割点
this->ParallelFindSplit()
this->SyncBestSolution(qexpand);//找到最优解,设为当前分裂节点
//根据分裂结果,将数据重新映射到子节点。
this->ResetPosition(qexpand_, p_fmat, *p_tree);
//将待扩展分割的叶子结点用于替换qexpand_,作为下一轮split的候选节点
this->UpdateQueueExpand(*p_tree, &qexpand_);
//重新初始化,计算qexpand_队列中所有候选节点的损失函数和权重
this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
//若无需继续分裂,则停止
if (qexpand_.size() == 0) break;
}

怎样剪枝

TreePruner (prune a tree given the statistics)

剪枝操作由TreePruner进行,对于一棵树,通过比对节点所带来的loss收益决定是否剪枝。
核心代码(updater_prune.cc)如下:

1
2
3
4
5
6
7
8
TreePruner->updater()
for (each tree)
//对每一棵树进行剪枝
->DoPrune();
->for (each node)
//对当前节点进行剪枝
//如果loss_chg<threlshold,剪枝。并递归判断新节点是否需要剪枝。
TryPruneLeaf();

如何进行分布式?

HistMaker (use histogram counting to construct a tree)

Histgram使用直方图法近似加速建树过程,HistMaker和Colmaker一样,用于建树,不同是HistMaker并不直接由基类派生,而是HistMaker->BaseMaker->Updater,关系如图:HistMaker
HistMaker类定义了基于直方图法的方法接口,具体的实现在CQHistMaker/QuantileHistMaker,QuantileHistMaker中调用了SketchMaker进行采样.
BaseMaker(updater_basemaker-inl.h)在Updater基类上增加了一些公共操作,后面介绍的SketchMaker也是派生于BaseMaker。按row based进行分布式建树,每台机器各自找到各自候选分割点,每一部分算出自己的统计量,用allreduce合并起来后再根据全局统计量计算最终的分割点,最终层次遍历的构建树[1]
由核心代码(Updater_histmaker.cc)走读如下:

1
2
3
4
5
6
7
8
9
10
11
/初始化样本和树节点的映射关系,根据配置进行样本的降采样(伯努利采样)
//队列qexpand_用于存储每次探索出候选树节点,初始化为root节点。
->InitData()
->UpdateNode2WorkIndex() //node2workindex大小为树的节点个数。初始化为qexpand中待扩展节点。
->InitWorkSet() //初始化fwork_set,大小为每颗树的节点,表示用于建树的特征。
->for (depth < param.max_depth)
->ResetPosAndPropose() //重置并设置候选分裂特征
->CreateHist() //生成近似直方图
->FindSplit() //通过直方图统计量找到合适的分裂节点
->ResetPositionAfterSplit() //对新加入的节点重新分配样本映射关系
->UpdateQueueExpand() //更新待探索的候选节点列表qexpand_

xgboost的论文中未有提及直方图法,是树模型常用的近似方法,对统计量进行聚合统计,存储为一个个桶,然后在这些桶之间寻找最佳分裂点,为了进一步了解这种方式的建树过程,这里需要重点探索下CreatHist() 以及FindSplit(),两处函数具体的实现在派生类CQHistMaker->HistMaker,CQHistMaker实现了具体的分布式计算,在InitWorkSet()中进行机器节点的数据分片和信息同步。其类UML图如下
CQHistMaker
CreatHist()作用为将统计量存储于待分割节点的直方图分桶中,对CreatHist()函数进行代码走读,其分布式计算是基于RABIT实现的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
fset: feature_index
//遍历所有特征
while (p_fmat->ColIterator(fset)->Next()) {
const ColBatch &batch = iter->Value();//取得每个特征对应数据
->UpdateHistCol() //多线程更新特征直方图
//对_qexpand每一个待分裂节点对应,特征为nid,初始化直方图
//统计直方图,简介无cache优化的过程如下
for (bst_uint j = 0; j < ColBatch.length; ++j)
->hbuilder[nid].Add(fvalue, gpair)//特征值fvalue用于确定分桶位置,gpair为统计量。
//对_qexpand中节点计算其对应统计量的和
->GetNodeStats()
}
//rabit::allreduce。借助mpi的allreduce进行理解,思考2
//从各个计算结点汇总统计量
->histred.Allreduce()

FindSplit()通过处理每个桶的特征量进行分裂点查找,代码走读如下:

1
2
3
4
5
6
7
//对于每一个待分裂节点
for (wid = 0; wid < qexpand.size(); ++wid)
//对于每个特征
for (size_t i = 0; i < fset.size(); ++i)
//枚举查找最佳分裂点,直方图上从左到右,从右到左两个方向分别查找
//根据统计量找到最好的桶,将桶的分界点作为特征的split_value
->EnumerateSplit()

近似算法怎么实现?

SketchMaker(use approximation sketch to construct a tree)

介绍SketchMaker之前,需要对此前论文导读中的加权分位数略图做少许补充。在传统的GK 算法[4]和其他扩展[5]之上,作者提出加权分位数略图算法,并引入支持merge和prune操作的数据结构进行实现。先介绍quantile summary,定义为以相对错误率$\varepsilon $返回分位数查询结果。

  1. merge操作:若两个summary的的相对正确率分别为$\varepsilon_1 $和$\varepsilon_2 $,将二者merge得到的新summary的相对错误率为$max \{\varepsilon_1, \varepsilon_2 \}$
  2. prune操作:将summary中元素个数缩减到$b+1$那么相对错误率率将从$\varepsilon$增加到$\varepsilon+{1 \over b}$。

二者如何实现,详细介绍见引文[3]附录部分《WEIGHTED QUANTILE SKETCH》。
xgboost通过不同的summary操作(WQSummary/WXQSummary/GKSummary)定义QuantileSketchTemplate类型,从而派生出不同的具体的sketch类(WQuantileSketch/WXQuantileSketch/GKQuantileSketch)。其中WXQSummary派生自WQSummary,在prune操作上效率要高些。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
//初始化数据和节点的映射关系,根据配置进行样本的降采样(伯努利采样)
//qexpand_用于存储每次探索出候选树节点,初始化为root节点。
->InitData()
for (depth = 0; depth < param.max_depth; ++depth) {
->GetNodeStats() //对_qexpand中节点计算其对应统计量的和
->BuildSketch() //初始化WXQuantileSketch,用于分位图策略(采样策略,需后期进一步了解)
->UpdateSketchCol() //更新每个特征的WXQuantileSketch统计量
//(pos_grad/neg_grad/sum_hess)
->SetPrune() //根据sketchs_size进行剪枝
//sketchs_size=sketch_ratio / sketch_eps(参数可配置)
->SyncNodeStats() //同步所有节点的信息SKStats
->FindSplit() //查找最佳分裂点
//枚举sketchs找到最佳特征的分裂点
->EnumerateSplit()
->ResetPositionCol() ///对新加入的节点重新分配样本映射关系
->UpdateQueueExpand() //更新待探索的候选节点列表_qexpand

自此从代码角度交代了前文《xgboost: A Scalable Tree Boosting System论文及源码》的作者的具体实现思路,希望能够进一步加深对boosting方法的理解。目前最新的进展为XGBoost4J-Spark发布进一步融入Spark的应用场景中(2016/10/26)见附录[3],后续的发展会继续跟进。
PS. TreeRefresher 代码中未见引用这里不做进一步介绍。

[1]设计模式——装饰模式(Decorator), @shu_lin
[2]机器学习算法中GBDT和XGBOOST的区别有哪些?, @杨军
[3]Chen, Tianqi, and C. Guestrin. “XGBoost: A Scalable Tree Boosting System.” (2016).
[4]L. Breiman. “Random forests. Maching Learning”, 2001.
[5]Q. Zhang and W. Wang. A fast algorithm for approximate quantiles in high speed data streams. In Proceedings of the 19th International Conference on Scientic and Statistical Database Management, 2007.

附录:
[1]xgoost调参技巧:
使用TrainValidationSplit和RegressionEvaluator自动进行高度max_depth和树权重$\eta$调参。
[2]MPI Allreduce操作介绍,可结合reduce理解,reduce产生的作用示意如下图:
Reduce
计算中需要归集所有结果,再分发到各个process中,Allreduce起这个作用,可以结合最小均方差的例子进行理解(点击链接),具体在流程中的作用如下图示意:
Allreduce
[3]XGboost与Spark的完全集成 @Nan Zhu

彩蛋:
[1]《PRML中文版》百度网盘地址, @马春鹏

写这篇博客有个痛点,代码的过度抽象和大量模板类的使用对阅读带来比较大的伤害,也是这篇文章历时许久的部分原因。如果您有好的代码阅读工具,不吝一并推荐下。
介绍下我的两款辅助软件:
Source Insight: 主司代码阅读
visual paradigm for UML 主司UML类图解析
另外,推荐百度学术的引文功能,写博客很方便

ShawnXiao@baidu