-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfindBestSplitValueLR.cpp
176 lines (145 loc) · 6.9 KB
/
findBestSplitValueLR.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
//
// findBestSplitValueLR.cpp
//
//
// Created by yichen zhou on 12/10/17.
//
//
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <numeric>
#include <typeinfo>
#include <vector>
#include "utility.h"
#include "TreeSurvival.h"
#include "Data.h"
#include "findBestSplitValueLR.h"
// TODO: provide a constructor that initializes the members.
// Original author's note: initializing sampleIDs with 0 in the constructor is wrong,
// which is why the draft below is kept disabled.
//findBestSplitValueLR::findBestSplitValueLR(): data(0), sampleIDs(0), nodeID(0) , varID(0), unique_timepoints(0), status_varID(0),response_timepointIDs(0), min_node_size(0) {
//}

/// Default constructor.
///
/// Performs no initialization of its own; the members read by the split search
/// are assigned in init(), so init() has to be called before any other member
/// function is used.
findBestSplitValueLR::findBestSplitValueLR() = default;
/// Stores the inputs of the log-rank split search on this object.
///
/// @param data                   Dataset queried through get() / getAllValues().
/// @param sampleIDs              Per-node lists of sample IDs, indexed by node ID. Copy-assigned
///                               into the member of the same name.
/// @param nodeID                 Node whose samples are to be split.
/// @param varID                  Variable (column) to search for a split value.
/// @param unique_timepoints      Pointer to the vector of time points; must not be null, its size
///                               becomes num_timepoints.
/// @param status_varID           Column holding the event status; a value of 1 is counted as a death
///                               by computeChildDeathCounts1().
/// @param response_timepointIDs  Pointer to a vector giving, per sample ID, the index of its time point.
/// @param min_node_size          Minimum number of samples each child must keep.
/// @param num_deaths             Array indexed by time point: deaths in the node.
/// @param num_samples_at_risk    Array indexed by time point: samples at risk in the node.
///
/// NOTE(review): the std::cout output below looks like leftover debugging; it is kept
/// byte-for-byte because callers/log consumers are not visible from this file.
void findBestSplitValueLR::init(Data* data, std::vector<std::vector<size_t>>& sampleIDs, size_t nodeID, size_t varID,
    std::vector<double>* unique_timepoints, size_t status_varID, std::vector<size_t>* response_timepointIDs,
    size_t min_node_size, size_t* num_deaths, size_t* num_samples_at_risk)
{
  std::cout << "print some1 in findBestSplitValueLR " << std::endl;

  this->data = data;
  this->sampleIDs = sampleIDs;
  this->nodeID = nodeID;
  this->varID = varID;
  this->unique_timepoints = unique_timepoints;
  this->status_varID = status_varID;
  this->response_timepointIDs = response_timepointIDs;
  this->min_node_size = min_node_size;
  // findBestSplitValueLogRank1() and computeChildDeathCounts1() size and index their
  // per-timepoint arrays with num_timepoints, so it must be set before they run.
  this->num_timepoints = unique_timepoints->size();
  this->num_deaths = num_deaths;
  this->num_samples_at_risk = num_samples_at_risk;

  // NOTE(review): nodeID is passed where the other call sites pass a sample ID as the
  // first argument of get(); the value is only printed, but confirm the index is in range.
  double value = data->get(nodeID, varID);

  std::cout << "inside findBestSplitValueLR value is " << value << "\n" << std::endl;
  std::cout << "sampleIDs is " << sampleIDs[nodeID].size() << "\n" << std::endl;
  std::cout << "unique_timepoints is " << unique_timepoints->size() << "\n" << std::endl;
  std::cout << "response_timepointIDs is " << response_timepointIDs->size() << "\n" << std::endl;
  std::cout << "status_varID is " << status_varID << "\n" << std::endl;
  // Only the first entry of each per-timepoint array is printed.
  std::cout << "num_deaths is " << num_deaths[0] << "\n" << std::endl;
  std::cout << "num_samples_at_risk is " << num_samples_at_risk[0] << "\n" << std::endl;
  std::cout << "min_node_size is " << min_node_size << "\n" << std::endl;
  std::cout << typeid(varID).name() << std::endl;
  std::cout << "inside findBestSplitValueLR varID is " << varID << "\n" << std::endl;
}
/// Destructor.
///
/// Does nothing beyond the implicit destruction of the members.
/// TODO: decide whether the sampleIDs storage should be released explicitly to
/// save memory; the original draft for that was:
///   sampleIDs.clear();
///   sampleIDs.shrink_to_fit();
findBestSplitValueLR::~findBestSplitValueLR() = default;
/// Writes a fixed marker line to standard output and flushes the stream.
void findBestSplitValueLR::printSome() {
  std::cout << "print some2 in findBestSplitValueLR " << '\n' << std::flush;
}
/// Searches the candidate split values of one variable in one node and reports the
/// one with the largest log-rank test statistic.
///
/// Candidates come from data->getAllValues() (presumably sorted, distinct values —
/// confirm against Data::getAllValues); the split point reported for candidate i is the
/// midpoint between value i and value i + 1. Candidates that would leave either child
/// with fewer than min_node_size samples are skipped.
///
/// @param nodeID        Node whose samples (sampleIDs[nodeID]) are split.
/// @param varID         Variable to split on.
/// @param best_value    In/out: overwritten with the split point of a candidate only when
///                      that candidate's statistic exceeds best_logrank.
/// @param best_logrank  In/out: best statistic so far. The caller must initialize it
///                      (e.g. to a negative value) before the call; it is updated only by a
///                      strictly larger statistic.
///
/// If fewer than two candidate values exist, both outputs are left untouched.
void findBestSplitValueLR::findBestSplitValueLogRank1(size_t nodeID, size_t varID, double& best_value, double& best_logrank) {
  // Create possible split values
  std::vector<double> possible_split_values;
  data->getAllValues(possible_split_values, sampleIDs[nodeID], varID);

  // Nothing to split if all values are equal for this variable
  if (possible_split_values.size() < 2) {
    return;
  }

  // -1 because no split is possible at the largest value
  const size_t num_splits = possible_split_values.size() - 1;

  // Zero-initialized counters, laid out as [split * num_timepoints + timepoint].
  // Vectors (instead of new[]/delete[]) release the storage on every exit path.
  std::vector<size_t> num_deaths_right_child(num_splits * num_timepoints);
  std::vector<size_t> delta_samples_at_risk_right_child(num_splits * num_timepoints);
  std::vector<size_t> num_samples_right_child(num_splits);

  computeChildDeathCounts1(nodeID, varID, possible_split_values, num_samples_right_child.data(),
      delta_samples_at_risk_right_child.data(), num_deaths_right_child.data(), num_splits);

  const size_t num_samples_node = sampleIDs[nodeID].size();

  // Compute the log-rank test for all splits and keep the best
  for (size_t i = 0; i < num_splits; ++i) {
    // Skip splits that violate the minimal node size
    const size_t num_samples_left_child = num_samples_node - num_samples_right_child[i];
    if (num_samples_right_child[i] < min_node_size || num_samples_left_child < min_node_size) {
      continue;
    }

    // Accumulate the log-rank test statistic for this split
    double numerator = 0;
    double denominator_squared = 0;
    size_t num_samples_at_risk_right_child = num_samples_right_child[i];
    for (size_t t = 0; t < num_timepoints; ++t) {
      // Later time points cannot contribute once fewer than two samples are at risk
      // in the node or none are left in the right child.
      if (num_samples_at_risk[t] < 2 || num_samples_at_risk_right_child < 1) {
        break;
      }

      if (num_deaths[t] > 0) {
        // Numerator and denominator for the log-rank test, notation from Ishwaran et al.
        const double di = static_cast<double>(num_deaths[t]);
        const double di1 = static_cast<double>(num_deaths_right_child[i * num_timepoints + t]);
        const double Yi = static_cast<double>(num_samples_at_risk[t]);
        const double Yi1 = static_cast<double>(num_samples_at_risk_right_child);
        numerator += di1 - Yi1 * (di / Yi);
        // Yi >= 2 is guaranteed by the check above, so (Yi - 1) is never zero.
        denominator_squared += (Yi1 / Yi) * (1.0 - Yi1 / Yi) * ((Yi - di) / (Yi - 1)) * di;
      }

      // Reduce number of samples at risk for the next time point
      num_samples_at_risk_right_child -= delta_samples_at_risk_right_child[i * num_timepoints + t];
    }

    // -1 marks a split for which the statistic is undefined (zero variance)
    double logrank = -1;
    if (denominator_squared != 0) {
      logrank = std::fabs(numerator / std::sqrt(denominator_squared));
    }

    // Only a strictly better split may replace the current best; without this guard
    // the last admissible candidate would always win.
    if (logrank > best_logrank) {
      best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2;
      best_logrank = logrank;

      // Use the smaller value if the average is numerically the same as the larger value
      if (best_value == possible_split_values[i + 1]) {
        best_value = possible_split_values[i];
      }
    }
  }
}
void findBestSplitValueLR::computeChildDeathCounts1(size_t nodeID, size_t varID, std::vector<double>& possible_split_values,
size_t* num_samples_right_child, size_t* delta_samples_at_risk_right_child, size_t* num_deaths_right_child,
size_t num_splits) {
// Count deaths in right child per timepoint and possbile split
for (auto& sampleID : sampleIDs[nodeID]) {
double value = data->get(sampleID, varID);
size_t survival_timeID = (*response_timepointIDs)[sampleID];
// Count deaths until split_value reached
for (size_t i = 0; i < num_splits; ++i) {
if (value > possible_split_values[i]) {
++num_samples_right_child[i];
++delta_samples_at_risk_right_child[i * num_timepoints + survival_timeID];
if (data->get(sampleID, status_varID) == 1) {
++num_deaths_right_child[i * num_timepoints + survival_timeID];
}
} else {
break;
}
}
}
}