diff --git a/doc/api.rst b/doc/api.rst index 25af1040..06892f5f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -106,6 +106,8 @@ Metrics concordance_index_censored concordance_index_ipcw cumulative_dynamic_auc + brier_score + integrated_brier_score Pre-Processing diff --git a/sksurv/metrics.py b/sksurv/metrics.py index 78b3d716..11ccdbb0 100644 --- a/sksurv/metrics.py +++ b/sksurv/metrics.py @@ -21,6 +21,8 @@ 'concordance_index_censored', 'concordance_index_ipcw', 'cumulative_dynamic_auc', + 'brier_score', + 'integrated_brier_score', ] @@ -118,6 +120,29 @@ def _estimate_concordance_index(event_indicator, event_time, estimate, weights, return cindex, concordant, discordant, tied_risk, tied_time +def _interp_pred_surv(y_pred, times, fu_time): + """Interpolated survival probability at time fu_time + + Parameters + ---------- + y_pred : array + Rectangular array, each individual's conditional probability of surviving each time interval + times : array + times for which survival probability is calculated. + fu_time: array + Follow-up time point at which predictions are needed + + Returns + ------- + pred_surv_prob : array + predicted survival probability for each individual at specified follow-up time + """ + pred_surv = [] + for i in range(y_pred.shape[0]): + pred_surv.append(numpy.interp(fu_time, times, y_pred[i, :])) + return numpy.array(pred_surv) + + def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1e-8): """Concordance index for right-censored data @@ -447,3 +472,222 @@ def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_ mean_auc = integral / (1.0 - s_times[-1]) return scores, mean_auc + + +def brier_score(survival_train, survival_test, estimate, times, + t_max=None, + use_mean_point=False, + internal_validation=True, + **kwargs): + """ + Modification of the implementation in PySurvival by Stephane Fotso et al. + TODO: NEED TO SHIP WITH AN APACHE LICENSE + Computing the Brier score at all times t such that t <= t_max; + it represents the average squared distances between + the observed survival status and the predicted + survival probability. + In the case of right censoring, it is necessary to adjust + the score by weighting the squared distances to + avoid bias. It can be achieved by using + the inverse probability of censoring weights method (IPCW), + (proposed by Graf et al. 1999; Gerds and Schumacher 2006) + by using the estimator of the conditional survival function + of the censoring times calculated using the Kaplan-Meier method, + such that:: + + BS(t) = 1/N*( W_1(t)*(Y_1(t) - S_1(t))^2 + ... + W_N(t)*(Y_N(t) - S_N(t))^2) + + In terms of benchmarks, a useful model will have a Brier score below + 0.25. Indeed, it is easy to see that if for all i in [1,N], + if `S(t, xi) = 0.5`, then `BS(t) = 0.25`. + + Parameters + ---------- + survival_train : structured array, shape = (n_train_samples,) + Survival times for training data to estimate if training + and testing data are drawn from same sample. + Set internal_validation to True in this case. + Otherwise, use surival_test again as input. + A structured array containing the binary event indicator + as first field, and time of event or time of censoring as + second field. + + survival_test : structured array, shape = (n_samples,) + Survival times of test data. + A structured array containing the binary event indicator + as first field, and time of event or time of censoring as + second field. + + estimate : array-like, shape = (n_samples,n_times) + Estimated risk of experiencing an event for test data at `times`. + + times : array-like, shape = (n_times,) + The time points for which the predicted Survival function + is calculated and interpolation for a specific follow-up-time + will be calculated from. Values must be + within the range of follow-up times of the test data + `survival_test`. + + t_max : float + Maximal time for estimating the prediction error curves. + If missing the largest value of the response variable is used. + + use_mean_point : bool + not necessary at the moment. + Predicted survival will be calculated at the mean of a time bucket (between 2 breaks) + + Returns + ------- + times : array, shape = (n_times*) + represents the time axis (length `n_times* = n_times[times <= t_max]` at which the brier scores were + + brier_scores : array , shape = (n_times*) + values of the brier scores + + Examples + -------- + """ + # check inputs + times = check_array(numpy.atleast_1d(times), ensure_2d=False, dtype=test_time.dtype) + times = numpy.unique(times) + + # if times.max() >= test_time.max() or times.min() < test_time.min(): + # raise ValueError( + # 'all times must be within follow-up time of test data: [{}; {}['.format( + # test_time.min(), test_time.max())) + # + + # Checking the format of the data + E, T = check_y_survival(survival_test) + + # computing the Survival function at times + Survival = estimate + + # Ordering Survival, T and E in descending order according to T + order = numpy.argsort(-T) + Survival = Survival[order, :] + T = T[order] + E = E[order] + survival_test = survival_test[order] + + # fit IPCW estimator for estimation of IPCW at time t* + cens = CensoringDistributionEstimator() + if internal_validation: + cens.fit(survival_train) + else: + cens.fit(survival_test) + + # calculate inverse probability of censoring weights at observation T[i] from survival_train + struct_event_times = numpy.zeros((T.shape[0],), dtype=[('event', 'bool'), ('time', 'int64')]) + struct_event_times['time'][:] = T + struct_event_times['event'][:] = E + ipcw = cens.predict_ipcw(struct_event_times) + + # setting time to last time observed, if not t_max set + if t_max is None or t_max <= 0.: + t_max = max(T) + + # Calculating the brier scores at each t <= t_max + brierlist = [] + for t in times[times <= t_max]: + # init bs + bs = numpy.zeros((T.shape[0])) + if use_mean_point: # in case of time buckets (breaks), use mean probability in the bucket + Survival = (numpy.add(Survival, numpy.roll(Survival, 1, axis=-1))) / 2. + + is_case = (T <= t) & E + is_control = (T > t) + + # get survival function S(t) by interpolating the Survival function + S = _interp_pred_surv(Survival, times, t) + S2 = numpy.multiply(S, S) + omS2 = numpy.multiply(1 - S, 1 - S) + + # calculate inverse probability of censoring weight at current timepoint t. + struct_arr = numpy.zeros((T.shape[0],), dtype=[('event', 'bool'), ('time', 'int64')]) + struct_arr['time'][:] = t + struct_arr['event'][:] = numpy.ones((E.shape[0],)) + ipcw_t = cens.predict_ipcw(struct_arr) + + bs[is_case] = numpy.multiply(S2[is_case], ipcw[is_case]) # multiplicative IPCW at T[i] + bs[is_control] = numpy.multiply(omS2[is_control], ipcw_t[is_control]) # multiplicative IPCW at current t + brierlist.append(numpy.mean(bs)) + + return times[times <= t_max], numpy.array(brierlist) + + +def integrated_brier_score(survival_train, survival_test, estimate, times, + t_max=None, + use_mean_point=False, + internal_validation=True, + **kwargs): + """The Integrated Brier Score (IBS) provides an overall calculation of + the model performance at all available times `t<=t_max`. + If `t_max` is `None` overall model performance will be integrated over + all available times. + + Parameters + ---------- + survival_train : structured array, shape = (n_train_samples,) + Survival times for training data to estimate if training + and testing data are drawn from same sample. + Set internal_validation to True in this case. + Otherwise, use surival_test again as input. + A structured array containing the binary event indicator + as first field, and time of event or time of censoring as + second field. + + survival_test : structured array, shape = (n_samples,) + Survival times of test data. + A structured array containing the binary event indicator + as first field, and time of event or time of censoring as + second field. + + estimate : array-like, shape = (n_samples,n_times) + Estimated risk of experiencing an event for test data at `times`. + + times : array-like, shape = (n_times,) + The time points for which the predicted Survival function + is calculated and interpolation for a specific follow-up-time + will be calculated from. Values must be + within the range of follow-up times of the test data + `survival_test`. + + t_max : float + Maximal time for estimating the prediction error curves. + If missing the largest value of the response variable is used. + + use_mean_point : bool + not necessary at the moment. + Predicted survival will be calculated at the mean of a time bucket (between 2 breaks) + + Returns + ------- + times : array, shape = (n_times*) + represents the time axis (length `n_times* = n_times[times <= t_max]` at which the brier scores were + computed + + brier_scores : array , shape = (n_times*) + values of the brier scores + + Examples + -------- + + """ + # Computing the brier scores + times, brier_scores = brier_score(survival_train, survival_test, estimate, times, + t_max=t_max, + use_mean_point=False, + internal_validation=True, + ) + + # Getting the proper value of t_max + if t_max is None: + t_max = max(times) + else: + t_max = min(t_max, max(times)) + + # Computing the IBS + ibs_value = numpy.trapz(brier_scores, times) / t_max + + return ibs_value