GradientBoostingDecisionTree
gradient boosting decision tree
Install / Use
/learn @MegrezZhu/GradientBoostingDecisionTreeREADME
Gradient Boosting Decision Tree
1. 构建与使用
1.1 构建
- Windows: 使用 Visual Studio 2017 打开解决方案并生成即可。
- Linux: 根目录提供了
makefile文件,使用make编译即可,需要gcc >= 5.4.0
1.2 使用
-
用法:
boost <config_file> <train_file> <test_file> <predict_dest> -
接受 LibSVM 格式的训练数据输入,如下每行代表一个训练样本:
<label> <feature-index>:<feature-value> <feature-index>:<feature-value> <feature-index>:<feature-value> -
用于预测的数据输入和训练数据类似:
<id> <feature-index>:<feature-value> <feature-index>:<feature-value> <feature-index>:<feature-value> -
目前只支持二分类问题
-
<config_file>指定训练参数:eta = 1. # shrinkage rate gamma = 0. # minimum gain required to split a node maxDepth = 6 # max depth allowed minChildWeight = 1 # minimum allowed size for a node to be splitted rounds = 1 # REQUIRED. number of subtrees subsample = 1. # subsampling ratio for each tree colsampleByTree = 1. # tree-wise feature subsampling ratio maxThreads = 1; # max running threads features; # REQUIRED. number of features validateSize = .2 # if greater than 0, input data will be split into two sets and used for training and validation repectively
2. 算法原理
GBDT 的核心可以分成两部分,分别是 Gradient Boosting 和 Decision Tree:
- Decision Tree : GBDT 的基分类器,通过划分输入样本的特征使得落在相同特征的样本拥有大致相同的 label。由于在 GBDT 中需要对若干不同的 Decision Tree 的结果进行综合,因此一般采用的是 Regression Tree (回归树)而不是 Classification Tree (分类树)。
- Gradient Boosting: 迭代式的集成算法,每一棵决策树的学习目标 y 都是之前所有树的结论和的残差(即梯度方向),也即
。
3. 实现与优化历程
各个部分的实现均经过若干次“初版实现 - 性能 profiling - 优化得到下一版代码”的迭代。其中,性能 profiling 部分,使用的是 Visual Studio 2017 的“性能探查器”功能,在进行性能 profile 之前均使用 release 模式编译(打开/O2 /Oi优化选项)。
3.1 数据处理
选择的输入文件数据格式是 Libsvm 的格式,格式如下:
<label> <feature-index>:<feature-value> <feature-index>:<feature-value>
可以看到这种格式天然适合用来表示稀疏的数据集,但在实现过程中,为了简单起见以及 cache 性能,我通过将空值填充为 0 转化为密集矩阵形式存储。代价是内存占用会相对高许多。
3.1.1 初版
最初并没有做什么优化,采用的是如下的简单流程:
- 文件按行读取
- 对于每一行内容,先转成
std::stringstream,再从中解析出相应的数据。
核心代码如下:
ifstream in(path);
string line;
while (getline(in, line)) {
auto item = parseLibSVMLine(move(line), featureCount); // { label, vector }
x.push_back(move(item.first));
y.push_back(item.second);
}
/* in parseLibSVMLine */
stringstream ss(line);
ss >> label;
while (ss) {
char _;
ss >> index >> _ >> value;
values[index - 1] = value;
}
profile 结果:

可以看到,主要的耗时在于将一行字符串解析成我们需要的 label + vector 数据这一过程中,进一步分析:

因此得知主要问题在于字符串解析部分。此时怀疑是 std::stringstream 的实现为了线程安全、错误检查等功能牺牲了性能,因此考虑使用 cstdio 中的实现。
3.1.2 改进
将 parseLibSVMLine 的实现重写,使用cstdio 中的sscanf 代替了 std::stringstream:
int lastp = -1;
for (size_t p = 0; p < line.length(); p++) {
if (isspace(line[p]) || p == line.length() - 1) {
if (lastp == -1) {
sscanf(line.c_str(), "%zu", &label);
}
else {
sscanf(line.c_str() + lastp, "%zu:%lf", &index, &value);
values[index - 1] = value;
}
lastp = int(p + 1);
}
}
profile 结果:

可以看到,虽然 parse 部分仍然是计算的热点,但这部分的计算量显著下降(53823 -> 23181),读取完整个数据集的是时间减少了 50% 以上。
3.1.3 最终版
显然,在数据集中,每一行之间的解析任务都是相互独立的,因此可以在一次性读入整个文件并按行划分数据后,对数据的解析进行并行化:
string content;
getline(ifstream(path), content, '\0');
stringstream in(move(content));
vector<string> lines;
string line;
while (getline(in, line)) lines.push_back(move(line));
#pragma omp parallel for
for (int i = 0; i < lines.size(); i++) {
auto item = parseLibSVMLine(move(lines[i]), featureCount);
#pragma omp critical
{
x.push_back(move(item.first));
y.push_back(item.second);
}
}
根据 profile 结果,进行并行化后,性能提升了约 25%。CPU 峰值占用率从 15% 上升到了 70%。可以发现性能的提升并没有 CPU 占用率的提升高,原因根据推测有以下两点:
- 读取文件的IO时间,在测试时使用的是 672MB 的数据集,因此光是读取全部内容就占了 50% 以上的时间
- 多线程同步的代价
3.2 决策树生成
决策树生成的过程采用的是 depth-first 深度优先的方式,即不断向下划分子树直到遇到下面的终止条件之一:
- 达到限定的最大深度
- 划分收益小于阈值
- 该节点中的样本数小于阈值
大致代码如下:
auto p = new RegressionTree();
// calculate value for prediction
p->average = calculateAverageY();
if (x.size() > nodeThres) {
// try to split
auto ret = findSplitPoint(x, y, index);
if (ret.gain > 0 && maxDepth > 1) { // check splitablity
// split points
// ...
// ...
p->left = createNode(x, y, leftIndex, maxDepth - 1);
p->right = createNode(x, y, rightIndex, maxDepth - 1);
}
}
3.2.1 计算划分点
在哪个特征的哪个值上做划分是决策树生成过程中最核心(也是最耗时)的部分。
问题描述如下:
对于数据集 ,我们要找到特征
以及该特征上的划分点
,满足 MSE (mean-square-error 均方误差) 最小:
其中:
,即划分后的样本 label 均值。
,
为划分后的子数据集。
等价地,如果用 表示划分收益:
其中,
为划分前的 MSE:
,
。
寻找最佳划分点等价于寻找收益最高的划分方案:
3.2.1.1 基于排序的实现
分析:
显然, 与
都只与分割点左边(右边)的部分和有关,因此可以先排序、再从小到大枚举分割点计算出所有分割情况的收益,对于每个特征,时间复杂度均为
。
代码如下:
for (size_t featureIndex = 0; featureIndex < x.front().size(); featureIndex++) {
vector<pair<size_t, double>> v(index.size());
for (size_t i = 0; i < index.size(); i++) {
auto ind = index[i];
v[i].first = ind;
v[i].second = x[ind][featureIndex];
}
// sorting
tuple<size_t, double, double> tup;
sort(v.begin(), v.end(), [](const auto &l, const auto &r) {
return l.second < r.second;
});
// maintaining sums of y_i and y_i^2 in both left and right part
double wholeErr, leftErr, rightErr;
double wholeSum = 0, leftSum, rightSum;
double wholePowSum = 0, leftPowSum, rightPowSum;
for (const auto &t : v) {
wholeSum += y[t.first];
wholePowSum += pow(y[t.first], 2);
}
wholeErr = calculateError(index.size(), wholeSum, wholePowSum);
leftSum = leftPowSum = 0;
rightSum = wholeSum;
rightPowSum = wholePowSum;
for (size_t i = 0; i + 1 < index.size(); i++) {
auto label = y[v[i].first];
leftSum += label;
rightSum -= label;
leftPowSum += pow(label, 2);
rightPowSum -= pow(label, 2);
if (y[v[i].first] == y[v[i + 1].first]) continue; // same label with next, not splitable
if (v[i].second == v[i + 1].second) continue; // same value, not splitable
leftErr = calculateError(i + 1, leftSum, leftPowSum);
rightErr = calculateError(index.size() - i - 1, rightSum, rightPowSum);
// calculate error gain
double gain = wholeErr - ((i + 1) * leftErr / index.size() + (index.size() - i - 1) * rightErr / index.size());
if (gain > bestGain) {
bestGain = gain;
bestSplit = (v[i].second + v[i + 1].second) / 2;
bestFeature = featureIndex;
}
}
}
profile 结果:

可以看到, sorting 以及 sorting 之前的数据准备部分占了很大一部分时间。
3.2.1.2 基于采样分桶的实现
由于之前基于排序的实现耗时较大,因此考虑换一种方法。后来翻 LightGBM 的优化方案,在参考文献[^1]里看到一个叫做 Sampling the Splitting points (SS) 的方法,比起 LightGBM 的方案, SS 方法更加容易实现。
SS 方法描述如下:
对于 个乱序的数值,我们先从中随机
Related Skills
node-connect
349.0kDiagnose OpenClaw node connection and pairing failures for Android, iOS, and macOS companion apps
frontend-design
109.4kCreate distinctive, production-grade frontend interfaces with high design quality. Use this skill when the user asks to build web components, pages, or applications. Generates creative, polished code that avoids generic AI aesthetics.
openai-whisper-api
349.0kTranscribe audio via OpenAI Audio Transcriptions API (Whisper).
qqbot-media
349.0kQQBot 富媒体收发能力。使用 <qqmedia> 标签,系统根据文件扩展名自动识别类型(图片/语音/视频/文件)。
