// TreeSurvival.h
/*-------------------------------------------------------------------------------
This file is part of Ranger.
Ranger is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Ranger is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Ranger. If not, see <http://www.gnu.org/licenses/>.
Written by:
Marvin N. Wright
Institut für Medizinische Biometrie und Statistik
Universität zu Lübeck
Ratzeburger Allee 160
23562 Lübeck
Germany
http://www.imbs-luebeck.de
#-------------------------------------------------------------------------------*/
#ifndef TREESURVIVAL_H_
#define TREESURVIVAL_H_
#include <ctime>
#ifndef OLD_WIN_R_BUILD
#include <thread>
#include <chrono>
#include <mutex>
#include <condition_variable>
#endif
#include "globals.h"
#include "Tree.h"
#include "utility.h"
#include "findBestSplitValueLR.h"
/**
 * Survival tree.
 *
 * Derives from Tree and adds a cumulative hazard function (CHF) per node
 * together with the split-search routines declared below (log-rank, AUC,
 * maxstat, extra-trees, and a threaded variant).
 *
 * Ownership, as far as this header shows:
 *  - unique_timepoints and response_timepointIDs are raw pointers handed in
 *    through the constructors; nothing in this header deletes them
 *    (presumably the caller keeps ownership - confirm in the .cpp).
 *  - num_deaths and num_samples_at_risk are raw arrays that are released by
 *    cleanUpInternal().
 *
 * The class is declared non-copyable via DISALLOW_COPY_AND_ASSIGN.
 */
class TreeSurvival: public Tree {
public:
  /**
   * Create an empty tree.
   *
   * @param unique_timepoints      Pointer to the time points (see the member
   *                               of the same name).
   * @param status_varID           ID of the status variable.
   * @param response_timepointIDs  Pointer to per-sample time point IDs.
   */
  TreeSurvival(std::vector<double>* unique_timepoints, size_t status_varID, std::vector<size_t>* response_timepointIDs);

  // Create from loaded forest
  TreeSurvival(std::vector<std::vector<size_t>>& child_nodeIDs, std::vector<size_t>& split_varIDs,
      std::vector<double>& split_values, std::vector<std::vector<double>> chf, std::vector<double>* unique_timepoints,
      std::vector<size_t>* response_timepointIDs);

  virtual ~TreeSurvival();

  void initInternal();
  void appendToFileInternal(std::ofstream& file);
  void computePermutationImportanceInternal(std::vector<std::vector<size_t>>* permutations);

  /** @return Read-only view of the per-node CHF table (see member `chf`). */
  const std::vector<std::vector<double> >& getChf() const {
    return chf;
  }

  /**
   * CHF of the terminal node recorded for a prediction sample.
   *
   * Looks up prediction_terminal_nodeIDs[sampleID] (declared outside this
   * class) and returns the matching row of `chf`. Neither index is
   * range-checked.
   */
  const std::vector<double>& getPrediction(size_t sampleID) const {
    size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID];
    return chf[terminal_nodeID];
  }

  /** @return prediction_terminal_nodeIDs[sampleID]; the index is not range-checked. */
  size_t getPredictionTerminalNodeID(size_t sampleID) const {
    return prediction_terminal_nodeIDs[sampleID];
  }

  // Original author's note: maybe should belong to protected
protected:
  // Show progress every few seconds
#ifdef OLD_WIN_R_BUILD
  void showSProgress(std::string operation, clock_t start_time, clock_t& lap_time);
#else
  void showSProgress(std::string operation);
#endif

  // Multithreading
  uint num_threads;
  std::vector<uint> thread_ranges1;
#ifndef OLD_WIN_R_BUILD
  std::mutex mutex;
  std::condition_variable condition_variable;
#endif

  // Raw pointers to split-search helper objects. Who allocates and frees them
  // is not visible in this header - TODO confirm in the .cpp.
  std::vector<findBestSplitValueLR*> findBestSplitValueLRs;

  // Computation progress (finished trees)
  size_t Sprogress;
#ifdef R_BUILD
  size_t aborted_threads;
  bool aborted;
#endif

  // NOTE: a "private:" label was commented out here in the original, so
  // everything down to the final "private:" is protected.
  void createEmptyNodeInternal();
  void computeSurvival(size_t nodeID);
  double computePredictionAccuracyInternal();

  bool splitNodeInternal(size_t nodeID, std::vector<size_t>& possible_split_varIDs);

  bool findBestSplit(size_t nodeID, std::vector<size_t>& possible_split_varIDs);

  // Threaded split search: results are written through the best_value,
  // best_varID and best_decrease pointers.
  void findBestSplitValueLRInthread(uint thread_idx, size_t nodeID, std::vector<size_t>* possible_split_varIDs, double* best_value, size_t* best_varID,double* best_decrease);
  // Alternative two-argument signature, left commented out in the original:
  //void findBestSplitValueLRInthread(uint thread_idx, size_t nodeID);

  bool findBestSplitMaxstat(size_t nodeID, std::vector<size_t>& possible_split_varIDs);

  // Overloads that take the candidate split values / factor levels explicitly.
  void findBestSplitValueLogRank(size_t nodeID, size_t varID, std::vector<double>& possible_split_values,
      double& best_value, size_t& best_varID, double& best_logrank);
  void findBestSplitValueLogRankUnordered(size_t nodeID, size_t varID, std::vector<double>& factor_levels,
      double& best_value, size_t& best_varID, double& best_logrank);
  void findBestSplitValueAUC(size_t nodeID, size_t varID, double& best_value, size_t& best_varID, double& best_auc);

  void computeDeathCounts(size_t nodeID);
  void computeChildDeathCounts(size_t nodeID, size_t varID, std::vector<double>& possible_split_values,
      size_t* num_samples_right_child, size_t* num_samples_at_risk_right_child, size_t* num_deaths_right_child,
      size_t num_splits);

  void computeAucSplit(double time_k, double time_l, double status_k, double status_l, double value_k, double value_l,
      size_t num_splits, std::vector<double>& possible_split_values, double* num_count, double* num_total);

  // Overloads without an explicit candidate list.
  void findBestSplitValueLogRank(size_t nodeID, size_t varID, double& best_value, size_t& best_varID,
      double& best_logrank);
  void findBestSplitValueLogRankUnordered(size_t nodeID, size_t varID, double& best_value, size_t& best_varID,
      double& best_logrank);

  bool findBestSplitExtraTrees(size_t nodeID, std::vector<size_t>& possible_split_varIDs);
  void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double& best_value, size_t& best_varID,
      double& best_logrank);
  void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double& best_value, size_t& best_varID,
      double& best_logrank);

  void addImpurityImportance(size_t nodeID, size_t varID, double decrease);

  /**
   * Release the two scratch arrays used while growing the tree.
   *
   * Both pointers are reset to nullptr after being freed, so calling this
   * more than once is harmless (delete[] on a null pointer is a no-op) and
   * no dangling pointer is left behind for later code to free or read.
   */
  void cleanUpInternal() {
    delete[] num_deaths;
    num_deaths = nullptr;
    delete[] num_samples_at_risk;
    num_samples_at_risk = nullptr;
  }

  size_t status_varID;

  // Unique time points for all individuals (not only this bootstrap), sorted
  std::vector<double>* unique_timepoints;
  size_t num_timepoints;
  size_t num_split_varIDs;
  std::vector<size_t>* response_timepointIDs;

  // For all terminal nodes CHF for all unique timepoints. For other nodes empty vector.
  std::vector<std::vector<double>> chf;

  // Fields to save to while tree growing.
  // Raw arrays freed in cleanUpInternal(); the std::vector variants below
  // were left commented out in the original.
  //std::vector<size_t>* num_deaths;
  //std::vector<size_t>* num_samples_at_risk;
  size_t* num_deaths;
  size_t* num_samples_at_risk;

private:
  DISALLOW_COPY_AND_ASSIGN(TreeSurvival);
};
#endif /* TREESURVIVAL_H_ */