Problem with median_survival_time_ in custom fitter #1213
-
Hi, I am very impressed by how easy it was to define custom fitters in lifelines! I am particularly interested in extreme value distributions, specifically EV-I min, EV-I max, and the generalized extreme value (GEV) distribution. Unfortunately, I encountered a couple of issues with the GEV custom fitter I defined. Convergence was spotty, but I found that a judicious choice of `initial_point` helped. The remaining problem concerns `median_survival_time_`, which I compare against the analytical medians in the code below.

Thank you,

import autograd.numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import *
from lifelines.fitters import ParametricUnivariateFitter T1 = np.array([1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1667, 1.1833,
1.1833, 1.1833, 1.1833, 1.1833, 1.2 , 1.2 , 1.2 , 1.2 ,
1.2 , 1.2 , 1.2 , 1.2 , 1.2 , 1.2167, 1.2167, 1.2167,
1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2167,
1.2167, 1.2167, 1.2167, 1.2167, 1.2167, 1.2333, 1.2333, 1.2333,
1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333,
1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333, 1.2333,
1.2333, 1.2333, 1.25 , 1.25 , 1.25 , 1.25 , 1.25 , 1.25 ,
1.25 , 1.25 , 1.25 , 1.25 , 1.2667, 1.2667, 1.2667, 1.2667,
1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667, 1.2667,
1.2667, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833, 1.2833,
1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 ,
1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 ,
1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 ,
1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 ,
1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 , 1.3 ,
1.3 , 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167,
1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3167, 1.3333, 1.3333,
1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333,
1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333, 1.3333,
1.3333, 1.3333, 1.3333, 1.35 , 1.35 , 1.35 , 1.35 , 1.35 ,
1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 ,
1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 ,
1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 ,
1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 , 1.35 ,
1.35 , 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667,
1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3667, 1.3833, 1.3833,
1.3833, 1.3833, 1.3833, 1.3833, 1.4 , 1.4167, 1.4167, 1.4333,
1.4833, 1.4833]) T2 = np.array([1.6 , 1.6 , 1.6167, 1.6333, 1.6333, 1.6333, 1.6333, 1.6333,
1.6333, 1.6333, 1.65 , 1.65 , 1.65 , 1.65 , 1.65 , 1.65 ,
1.65 , 1.65 , 1.65 , 1.65 , 1.65 , 1.65 , 1.65 , 1.65 ,
1.65 , 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667,
1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667, 1.6667,
1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833,
1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833, 1.6833,
1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 ,
1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 , 1.7 ,
1.7 , 1.7 , 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167,
1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167, 1.7167,
1.7167, 1.7333, 1.7333, 1.7333, 1.7333, 1.7333, 1.75 , 1.75 ,
1.75 , 1.75 , 1.75 , 1.75 , 1.75 , 1.75 , 1.75 , 1.75 ,
1.75 , 1.75 , 1.75 , 1.75 , 1.75 , 1.75 , 1.7667, 1.7667,
1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667, 1.7667,
1.7667, 1.7667, 1.7667, 1.7667, 1.7833, 1.7833, 1.7833, 1.7833,
1.7833, 1.7833, 1.7833, 1.7833, 1.7833, 1.7833, 1.8 , 1.8 ,
1.8 , 1.8 , 1.8 , 1.8 , 1.8 , 1.8 , 1.8 , 1.8 ,
1.8 , 1.8 , 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167,
1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167, 1.8167,
1.8167, 1.8167, 1.8333, 1.8333, 1.8333, 1.8333, 1.8333, 1.8333,
1.8333, 1.85 , 1.85 , 1.85 , 1.85 , 1.85 , 1.85 , 1.8667,
1.8667, 1.8667, 1.8833, 1.8833, 1.8833, 1.8833, 1.9 , 1.9167,
1.9167, 1.9333, 1.95 , 1.9833, 2.0333, 2.0667, 2.1 ]) class GEVDistFitter(ParametricUnivariateFitter):
_fitted_parameter_names = ["J_", "mu_", "k_"]
_bounds = [(0, None), (None, None), (None, None)]
def _cumulative_hazard(self, params, times):
J_, mu_, k_ = params
z = J_*(times - mu_)
return -np.log1p(-np.exp(-np.power(1+k_*z,-1/k_))) class EV1maxDistFitter(ParametricUnivariateFitter):
_fitted_parameter_names = ["J_", "mu_"]
#_bounds = [(0, None), (0, None), (0, T.min()-0.001)]
def _cumulative_hazard(self, params, times):
J_, mu_ = params
z = J_*(times - mu_)
return -np.log1p(-np.exp(-np.exp(-z))) class EV1minDistFitter(ParametricUnivariateFitter):
_fitted_parameter_names = ["J_", "mu_"]
#_bounds = [(0, None), (0, None), (0, T.min()-0.001)]
def _cumulative_hazard(self, params, times):
J_, mu_ = params
z = J_*(times - mu_)
return np.exp(z) kmf1 = KaplanMeierFitter().fit(T1,label='KM')
# Parametric fits for the first sample (`kmf1`, its Kaplan-Meier fit, is
# created before this point).  The GEV fit is given an explicit start point.
gev1 = GEVDistFitter().fit(T1, initial_point=np.array([10, 1.3, -0.1]), label='GEV')
ev1max1 = EV1maxDistFitter().fit(T1, label='EV1_max')
ev1min1 = EV1minDistFitter().fit(T1, label='EV1_min')

# Kaplan-Meier and parametric fits for the second sample.
kmf2 = KaplanMeierFitter().fit(T2, label='KM')
gev2 = GEVDistFitter().fit(T2, initial_point=np.array([10, 1.3, -0.1]), label='GEV')
ev1max2 = EV1maxDistFitter().fit(T2, label='EV1_max')
ev1min2 = EV1minDistFitter().fit(T2, label='EV1_min')

# 3 x 2 grid: one row per parametric model, one column per sample.
fig, axes = plt.subplots(3, 2, sharey=True, figsize=(8, 10))

# Every row gets the Kaplan-Meier curve: T1 on the left, T2 on the right.
for ax_left, ax_right in axes:
    kmf1.plot_survival_function(ax=ax_left)
    ax_left.set_xlim(0, 3)
    kmf2.plot_survival_function(ax=ax_right)
    ax_right.set_xlim(0, 3)

# Overlay one parametric fit per row, first sample's column first.
for col, fitted_models in enumerate(((gev1, ev1max1, ev1min1),
                                     (gev2, ev1max2, ev1min2))):
    for row, model in enumerate(fitted_models):
        model.plot_survival_function(ax=axes[row, col])


def _analytical_medians(gev, ev1max, ev1min):
    """Return the closed-form medians of the three fitted models, in order."""
    return [
        (np.power(np.log(2), -gev.k_) - 1) / (gev.J_ * gev.k_) + gev.mu_,
        ev1max.mu_ - 1 / ev1max.J_ * np.log(np.log(2)),
        ev1min.mu_ + 1 / ev1min.J_ * np.log(np.log(2)),
    ]


# Row 0: medians reported by lifelines; row 1: closed-form medians.
tau1 = np.array([
    [model.median_survival_time_ for model in (gev1, ev1max1, ev1min1)],
    _analytical_medians(gev1, ev1max1, ev1min1),
])
tau2 = np.array([
    [model.median_survival_time_ for model in (gev2, ev1max2, ev1min2)],
    _analytical_medians(gev2, ev1max2, ev1min2),
])

pd.DataFrame(tau1,columns=['GEV1','EV1max1','EV1min1'],index=['lifelines','analytical'])
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
pd.DataFrame(tau2,columns=['GEV2','EV1max2','EV1min2'],index=['lifelines','analytical'])
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
|
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 3 replies
-
Hi @DerkJoester, thanks for the detailed issue. I'll have to dig in further as to what might be happening, but for your case: if you know the analytical formula, you can add it to the class, see example here. Note that |
Beta Was this translation helpful? Give feedback.
Hi @DerkJoester, thanks for the detailed issue. I'll have to dig in further as to what might be happening, but for your case: if you know the analytical formula, you can add it to the class, see example here. Note that `percentile` is called by `median_survival_time_`, and if the former is not available, it's numerically computed - which is where I think the problem you are seeing lies.