diff --git a/lifelines/fitters/cox_time_varying_fitter.py b/lifelines/fitters/cox_time_varying_fitter.py
index 29cfda683..5de45b07c 100644
--- a/lifelines/fitters/cox_time_varying_fitter.py
+++ b/lifelines/fitters/cox_time_varying_fitter.py
@@ -80,8 +80,12 @@ class CoxTimeVaryingFitter(SemiParametricRegressionFitter, ProportionalHazardMix
         the strata provided
     standard_errors_: Series
         the standard errors of the estimates
+    baseline_hazard_: DataFrame
+        baseline hazard with confidence bounds
     baseline_cumulative_hazard_: DataFrame
+        baseline cumulative hazard with confidence bounds
     baseline_survival_: DataFrame
+        baseline survival with confidence bounds
     """
 
     _KNOWN_MODEL = True
@@ -225,7 +229,8 @@ def fit(
             normalize(X, self._norm_mean, self._norm_std), events, start, stop, weights
         )
         self.confidence_intervals_ = self._compute_confidence_intervals()
-        self.baseline_cumulative_hazard_ = self._compute_cumulative_baseline_hazard(df, events, start, stop, weights)
+        self.baseline_hazard_ = self._compute_baseline_hazard(df, events, start, stop, weights)
+        self.baseline_cumulative_hazard_ = self._compute_baseline_cumulative_hazard()
         self.baseline_survival_ = self._compute_baseline_survival()
         self.event_observed = events
         self.start_stop_and_events = pd.DataFrame({"event": events, "start": start, "stop": stop})
@@ -788,14 +793,16 @@ def plot(self, columns=None, ax=None, **errorbar_kwargs):
 
         return ax
 
-    def _compute_cumulative_baseline_hazard(self, tv_data, events, start, stop, weights):  # pylint: disable=too-many-locals
+    def _compute_baseline_hazard(self, tv_data, events, start, stop, weights):  # pylint: disable=too-many-locals
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             hazards = self.predict_partial_hazard(tv_data).values
 
         unique_death_times = np.unique(stop[events.values])
-        baseline_hazard_ = pd.DataFrame(np.zeros_like(unique_death_times), index=unique_death_times, columns=["baseline hazard"])
+        baseline_hazard_ = pd.DataFrame(np.zeros_like(unique_death_times), index=unique_death_times, columns=["baseline_hazard"])
+        variance_baseline_hazard_ = pd.DataFrame(
+            np.zeros_like(unique_death_times), index=unique_death_times, columns=["variance_baseline_hazard"])
 
         for t in unique_death_times:
             ix = (start.values < t) & (t <= stop.values)
@@ -807,15 +814,34 @@ def _compute_cumulative_baseline_hazard(self, tv_data, events, start, stop, weig
 
             deaths = events_at_t & (stops_at_t == t)
 
-            death_counts = (weights_at_t.squeeze() * deaths).sum()  # should always be atleast 1.
+            death_counts = (weights_at_t.squeeze() * deaths).sum()  # should always be at least 1.
+            variance_baseline_hazard_.loc[t] = death_counts / hazards_at_t.sum() ** 2
             baseline_hazard_.loc[t] = death_counts / hazards_at_t.sum()
+        # Klein and Moeschberger (2013), p. 283: log-transformed confidence bounds
+        z = inv_normal_cdf(1 - self.alpha / 2)
+        ci_labels = ["%s_upper_%g" % ("baseline_hazard", 1 - self.alpha), "%s_lower_%g" % ("baseline_hazard", self.alpha)]
+        upper = baseline_hazard_.values * np.exp(+z * np.sqrt(variance_baseline_hazard_.values) / baseline_hazard_.values)
+        lower = baseline_hazard_.values * np.exp(-z * np.sqrt(variance_baseline_hazard_.values) / baseline_hazard_.values)
+        baseline_hazard_[ci_labels[0]] = upper
+        baseline_hazard_[ci_labels[1]] = lower
+
+        baseline_hazard_.loc[0] = [0, 0, 0]
+        baseline_hazard_ = baseline_hazard_.sort_index()
+
+        return baseline_hazard_
 
-        return baseline_hazard_.cumsum()
+    def _compute_baseline_cumulative_hazard(self):
+        baseline_cumulative_hazard = self.baseline_hazard_.cumsum()
+        baseline_cumulative_hazard.columns = baseline_cumulative_hazard.columns.str.replace(
+            "baseline_hazard", "baseline_cumulative_hazard")
+
+        return baseline_cumulative_hazard
 
     def _compute_baseline_survival(self):
-        survival_df = np.exp(-self.baseline_cumulative_hazard_)
-        survival_df.columns = ["baseline survival"]
-        return survival_df
+        survival = np.exp(-self.baseline_cumulative_hazard_.baseline_cumulative_hazard)
+        survival.name = "survival"
+
+        return survival
 
     def __repr__(self):
         classname = self._class_name
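For reviewers, a minimal usage sketch of the attributes this diff adds or changes. It is not part of the patch; the dataset helper and the exact column names are assumptions based on the code above and on lifelines' usual conventions, with the default alpha of 0.05.

from lifelines import CoxTimeVaryingFitter
from lifelines.datasets import load_stanford_heart_transplants

# Long-format data with id/start/stop/event columns.
df = load_stanford_heart_transplants()

ctv = CoxTimeVaryingFitter()
ctv.fit(df, id_col="id", event_col="event", start_col="start", stop_col="stop")

# New attribute from this diff: the point estimate plus log-transformed
# confidence bounds, with columns named by ci_labels above (e.g.
# "baseline_hazard", "baseline_hazard_upper_0.95", "baseline_hazard_lower_0.05").
print(ctv.baseline_hazard_.head())

# Running sum of the same columns, renamed "baseline_cumulative_hazard*".
print(ctv.baseline_cumulative_hazard_.head())

# Behaviour change to note: baseline_survival_ is now a Series named "survival",
# computed from the point estimate only, rather than a one-column DataFrame.
print(ctv.baseline_survival_.head())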