Point Cloud Library (PCL) 1.14.0
Loading...
Searching...
No Matches
fern_trainer.hpp
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40namespace pcl {
41
42template <class FeatureType,
43 class DataSet,
44 class LabelType,
45 class ExampleIndex,
46 class NodeType>
48: fern_depth_(10)
49, num_of_features_(1000)
50, num_of_thresholds_(10)
51, feature_handler_(nullptr)
52, stats_estimator_(nullptr)
53, data_set_()
54, label_data_()
55, examples_()
56{}
57
58template <class FeatureType,
59 class DataSet,
60 class LabelType,
61 class ExampleIndex,
62 class NodeType>
63void
66{
67 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
68 const std::size_t num_of_examples = examples_.size();
69
70 // create random features
71 std::vector<FeatureType> features;
72 feature_handler_->createRandomFeatures(num_of_features_, features);
73
74 // setup fern
75 fern.initialize(fern_depth_);
76
77 // evaluate all features
78 std::vector<std::vector<float>> feature_results(num_of_features_);
79 std::vector<std::vector<unsigned char>> flags(num_of_features_);
80
81 for (std::size_t feature_index = 0; feature_index < num_of_features_;
82 ++feature_index) {
83 feature_results[feature_index].reserve(num_of_examples);
84 flags[feature_index].reserve(num_of_examples);
85
86 feature_handler_->evaluateFeature(features[feature_index],
87 data_set_,
88 examples_,
89 feature_results[feature_index],
90 flags[feature_index]);
91 }
92
93 // iteratively select features and thresholds
94 std::vector<std::vector<std::vector<float>>> branch_feature_results(
95 num_of_features_); // [feature_index][branch_index][result_index]
96 std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
97 num_of_features_); // [feature_index][branch_index][flag_index]
98 std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
99 num_of_features_); // [feature_index][branch_index][result_index]
100 std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
101 num_of_features_); // [feature_index][branch_index][flag_index]
102
103 // - initialize branch feature results and flags
104 for (std::size_t feature_index = 0; feature_index < num_of_features_;
105 ++feature_index) {
106 branch_feature_results[feature_index].resize(1);
107 branch_flags[feature_index].resize(1);
108 branch_examples[feature_index].resize(1);
109 branch_label_data[feature_index].resize(1);
110
111 branch_feature_results[feature_index][0] = feature_results[feature_index];
112 branch_flags[feature_index][0] = flags[feature_index];
113 branch_examples[feature_index][0] = examples_;
114 branch_label_data[feature_index][0] = label_data_;
115 }
116
117 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
118 // get thresholds
119 std::vector<std::vector<float>> thresholds(num_of_features_);
120
121 for (std::size_t feature_index = 0; feature_index < num_of_features_;
122 ++feature_index) {
123 thresholds.reserve(num_of_thresholds_);
124 createThresholdsUniform(num_of_thresholds_,
125 feature_results[feature_index],
126 thresholds[feature_index]);
127 }
128
129 // compute information gain
130 int best_feature_index = -1;
131 float best_feature_threshold = 0.0f;
132 float best_feature_information_gain = 0.0f;
133
134 for (std::size_t feature_index = 0; feature_index < num_of_features_;
135 ++feature_index) {
136 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
137 ++threshold_index) {
138 float information_gain = 0.0f;
139 for (std::size_t branch_index = 0;
140 branch_index < branch_feature_results[feature_index].size();
141 ++branch_index) {
142 const float branch_information_gain =
143 stats_estimator_->computeInformationGain(
144 data_set_,
145 branch_examples[feature_index][branch_index],
146 branch_label_data[feature_index][branch_index],
147 branch_feature_results[feature_index][branch_index],
148 branch_flags[feature_index][branch_index],
149 thresholds[feature_index][threshold_index]);
150
151 information_gain +=
152 branch_information_gain *
153 branch_feature_results[feature_index][branch_index].size();
154 }
155
156 if (information_gain > best_feature_information_gain) {
157 best_feature_information_gain = information_gain;
158 best_feature_index = static_cast<int>(feature_index);
159 best_feature_threshold = thresholds[feature_index][threshold_index];
160 }
161 }
162 }
163
164 // add feature to the feature list of the fern
165 fern.accessFeature(depth_index) = features[best_feature_index];
166 fern.accessThreshold(depth_index) = best_feature_threshold;
167
168 // update branch feature results and flags
169 for (std::size_t feature_index = 0; feature_index < num_of_features_;
170 ++feature_index) {
171 std::vector<std::vector<float>>& cur_branch_feature_results =
172 branch_feature_results[feature_index];
173 std::vector<std::vector<unsigned char>>& cur_branch_flags =
174 branch_flags[feature_index];
175 std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
176 branch_examples[feature_index];
177 std::vector<std::vector<LabelType>>& cur_branch_label_data =
178 branch_label_data[feature_index];
179
180 const std::size_t total_num_of_new_branches =
181 num_of_branches * cur_branch_feature_results.size();
182
183 std::vector<std::vector<float>> new_branch_feature_results(
184 total_num_of_new_branches); // [branch_index][example_index]
185 std::vector<std::vector<unsigned char>> new_branch_flags(
186 total_num_of_new_branches); // [branch_index][example_index]
187 std::vector<std::vector<ExampleIndex>> new_branch_examples(
188 total_num_of_new_branches); // [branch_index][example_index]
189 std::vector<std::vector<LabelType>> new_branch_label_data(
190 total_num_of_new_branches); // [branch_index][example_index]
191
192 for (std::size_t branch_index = 0;
193 branch_index < cur_branch_feature_results.size();
194 ++branch_index) {
195 const std::size_t num_of_examples_in_this_branch =
196 cur_branch_feature_results[branch_index].size();
197
198 std::vector<unsigned char> branch_indices;
199 branch_indices.reserve(num_of_examples_in_this_branch);
200
201 stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
202 cur_branch_flags[branch_index],
203 best_feature_threshold,
204 branch_indices);
205
206 // split results into different branches
207 const std::size_t base_branch_index = branch_index * num_of_branches;
208 for (std::size_t example_index = 0;
209 example_index < num_of_examples_in_this_branch;
210 ++example_index) {
211 const std::size_t combined_branch_index =
212 base_branch_index + branch_indices[example_index];
213
214 new_branch_feature_results[combined_branch_index].push_back(
215 cur_branch_feature_results[branch_index][example_index]);
216 new_branch_flags[combined_branch_index].push_back(
217 cur_branch_flags[branch_index][example_index]);
218 new_branch_examples[combined_branch_index].push_back(
219 cur_branch_examples[branch_index][example_index]);
220 new_branch_label_data[combined_branch_index].push_back(
221 cur_branch_label_data[branch_index][example_index]);
222 }
223 }
224
225 branch_feature_results[feature_index] = new_branch_feature_results;
226 branch_flags[feature_index] = new_branch_flags;
227 branch_examples[feature_index] = new_branch_examples;
228 branch_label_data[feature_index] = new_branch_label_data;
229 }
230 }
231
232 // set node statistics
233 // - re-evaluate selected features
234 std::vector<std::vector<float>> final_feature_results(
235 fern_depth_); // [feature_index][example_index]
236 std::vector<std::vector<unsigned char>> final_flags(
237 fern_depth_); // [feature_index][example_index]
238 std::vector<std::vector<unsigned char>> final_branch_indices(
239 fern_depth_); // [feature_index][example_index]
240 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
241 final_feature_results[depth_index].reserve(num_of_examples);
242 final_flags[depth_index].reserve(num_of_examples);
243 final_branch_indices[depth_index].reserve(num_of_examples);
244
245 feature_handler_->evaluateFeature(fern.accessFeature(depth_index),
246 data_set_,
247 examples_,
248 final_feature_results[depth_index],
249 final_flags[depth_index]);
250
251 stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
252 final_flags[depth_index],
253 fern.accessThreshold(depth_index),
254 final_branch_indices[depth_index]);
255 }
256
257 // - distribute examples to nodes
258 std::vector<std::vector<LabelType>> node_labels(
259 0x1 << fern_depth_); // [node_index][example_index]
260 std::vector<std::vector<ExampleIndex>> node_examples(
261 0x1 << fern_depth_); // [node_index][example_index]
262
263 for (std::size_t example_index = 0; example_index < num_of_examples;
264 ++example_index) {
265 std::size_t node_index = 0;
266 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
267 node_index *= num_of_branches;
268 node_index += final_branch_indices[depth_index][example_index];
269 }
270
271 node_labels[node_index].push_back(label_data_[example_index]);
272 node_examples[node_index].push_back(examples_[example_index]);
273 }
274
275 // - compute and set statistics for every node
276 const std::size_t num_of_nodes = 0x1 << fern_depth_;
277 for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
278 stats_estimator_->computeAndSetNodeStats(data_set_,
279 node_examples[node_index],
280 node_labels[node_index],
281 fern[node_index]);
282 }
283}
284
285template <class FeatureType,
286 class DataSet,
287 class LabelType,
288 class ExampleIndex,
289 class NodeType>
290void
292 createThresholdsUniform(const std::size_t num_of_thresholds,
293 std::vector<float>& values,
294 std::vector<float>& thresholds)
295{
296 // estimate range of values
297 float min_value = ::std::numeric_limits<float>::max();
298 float max_value = -::std::numeric_limits<float>::max();
299
300 const std::size_t num_of_values = values.size();
301 for (int value_index = 0; value_index < num_of_values; ++value_index) {
302 const float value = values[value_index];
303
304 if (value < min_value)
305 min_value = value;
306 if (value > max_value)
307 max_value = value;
308 }
309
310 const float range = max_value - min_value;
311 const float step = range / (num_of_thresholds + 2);
312
313 // compute thresholds
314 thresholds.resize(num_of_thresholds);
315
316 for (int threshold_index = 0; threshold_index < num_of_thresholds;
317 ++threshold_index) {
318 thresholds[threshold_index] = min_value + step * (threshold_index + 1);
319 }
320}
321
322} // namespace pcl
Class representing a Fern.
Definition fern.h:49
float & accessThreshold(const std::size_t threshold_index)
Access operator for thresholds.
Definition fern.h:177
FeatureType & accessFeature(const std::size_t feature_index)
Access operator for features.
Definition fern.h:157
void initialize(const std::size_t num_of_decisions)
Initializes the fern.
Definition fern.h:59
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
void train(Fern< FeatureType, NodeType > &fern)
Trains a decision tree using the set training data and settings.
FernTrainer()
Constructor.