Comparing SHAP values between high and low thrombolysing hospitals

Plain English summary

It has become apparent that hospitals are making different clinical decisions, and that differences in thrombolysis rates between hospitals are not explained by their patient populations alone. For example, hospitals presented with the same information can make opposing decisions on whether or not to give a patient thrombolysis. We would like to pick apart which factors different hospitals use to reach their different decisions. Here we create two groups of hospitals: those with a high propensity to thrombolyse and those with a low propensity to thrombolyse. Comparing the SHAP values for a specific feature between these two groups shows how that feature contributes to their different decisions on whether to give thrombolysis.

We see that low and high thrombolysing hospitals have the same general pattern: a patient is more likely to receive thrombolysis if they have a mid-level stroke severity. However, the more cautious hospitals are even less likely to thrombolyse mild and severe strokes, but more likely to thrombolyse mid-level strokes. The wider the net used to select hospitals for the two groups (30 hospitals each rather than 9), the less pronounced the difference.

Model and data

The data needed for this analysis (feature values and SHAP values) were created in previous notebooks, which trained the models.

These XGBoost models were trained on stratified k-fold cross-validation data. The 8 features in the model are:

  • Arrival-to-scan time: Time from arrival at hospital to scan (mins)

  • Infarction: Stroke type (1 = infarction, 0 = haemorrhage)

  • Stroke severity: Stroke severity (NIHSS) on arrival

  • Precise onset time: Onset time type (1 = precise, 0 = best estimate)

  • Prior disability level: Disability level (modified Rankin Scale) before stroke

  • Stroke team: Stroke team attended

  • Use of AF anticoagulents: Use of atrial fibrillation anticoagulant (1 = Yes, 0 = No)

  • Onset-to-arrival time: Time from onset of stroke to arrival at hospital (mins)

And one target feature:

  • Thrombolysis: Received thrombolysis (1 = Yes, 0 = No)

The 8 features included in the model (to predict whether a patient will receive thrombolysis) were chosen sequentially: at each step, the feature giving the single largest improvement in model performance (ROC AUC) was added. The stroke team feature is included as a one-hot encoded feature.
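The sequential selection described above is a greedy forward selection. Below is a minimal sketch of that loop. It assumes a caller-supplied score function, which is hypothetical here. In this project it would correspond to training the model on the chosen features and measuring ROC AUC, but those details are not shown in this notebook. The toy weights are invented for illustration.

```python
from typing import Callable, List, Sequence, Tuple


def forward_select(candidates: Sequence[str],
                   score: Callable[[List[str]], float],
                   n_features: int) -> Tuple[List[str], List[float]]:
    """Greedy forward selection: at each step add the remaining candidate
    whose inclusion gives the highest score (e.g. ROC AUC)."""
    selected: List[str] = []
    history: List[float] = []
    remaining = list(candidates)
    for _ in range(min(n_features, len(remaining))):
        # Score every remaining feature when added to the current selection
        best_feature = max(remaining, key=lambda f: score(selected + [f]))
        selected.append(best_feature)
        remaining.remove(best_feature)
        # Record the score achieved after adding this feature
        history.append(score(selected))
    return selected, history


# Hypothetical toy score standing in for a cross-validated ROC AUC
weights = {'Arrival-to-scan time': 0.20, 'Stroke severity': 0.10,
           'Infarction': 0.15, 'Precise onset time': 0.02}


def toy_score(features: List[str]) -> float:
    return 0.5 + sum(weights[f] for f in features) / 2


selected, history = forward_select(list(weights), toy_score, n_features=3)
print(selected)
```

Each step only considers adding one feature to the current set, so the procedure is cheap but not guaranteed to find the best possible subset of a given size.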

The Python library SHAP was applied to the first k-fold model to obtain a SHAP value for each feature, for each instance. SHAP values are in the same units as the model output, so for XGBoost they are in log odds (each SHAP value is a shift in the log odds of receiving thrombolysis).
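The sketch below shows what "same units as the model output" means in practice. SHAP explainers expose a base value (expected_value). For a tree model explained on its raw (log odds) output, the base value plus the sum of one instance's SHAP values equals the model's log odds for that instance. The numbers below are invented for illustration and are not taken from the models in this notebook.

```python
import numpy as np

# Hypothetical values for one patient, for illustration only:
# base_value plays the role of the explainer's expected_value (log odds)
base_value = -1.0
shap_row = np.array([1.5, -0.5, 0.8])  # one SHAP value per feature

# SHAP values are additive in the model's output units (log odds)
log_odds = base_value + shap_row.sum()

# The logistic (sigmoid) function converts log odds to a probability
probability = 1 / (1 + np.exp(-log_odds))
print(f'log odds = {log_odds:.2f}, probability = {probability:.3f}')
```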

A single SHAP value per feature was obtained by taking the mean of the absolute values across all instances.
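Here is a minimal sketch of that summary step. It uses a small invented table of SHAP values; the name example_shap and its numbers are hypothetical. Taking absolute values first stops positive and negative contributions from cancelling, and averaging over rows then gives one value per feature.

```python
import pandas as pd

# Hypothetical SHAP values: one row per instance (patient), one column per feature
example_shap = pd.DataFrame({
    'Stroke severity': [1.5, -1.2, 0.3],
    'Infarction': [1.7, 1.6, -7.8],
    'Onset-to-arrival time': [-0.4, 0.1, 0.5],
})

# Absolute values stop positive and negative contributions cancelling out;
# the mean over rows then gives one importance value per feature
mean_abs_shap = example_shap.abs().mean().sort_values(ascending=False)
print(mean_abs_shap)
```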

A note on SHAP values

SHAP values are usually reported as log odds shifts in model predictions. For a description of the relationships between probability, odds, and SHAP values (log odds shifts) see here.
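The relationship can be sketched in a few lines. Adding a log odds shift multiplies the odds by exp(shift), so the same SHAP value moves the probability by different amounts depending on where the patient started. The helper name shift_probability is hypothetical.

```python
import math


def shift_probability(p_before: float, shap_shift: float) -> float:
    """Apply a SHAP log odds shift to a starting probability."""
    odds_before = p_before / (1 - p_before)
    # Adding a log odds shift multiplies the odds by exp(shift)
    odds_after = odds_before * math.exp(shap_shift)
    # Convert odds back to a probability
    return odds_after / (1 + odds_after)


# The same +1.0 log odds shift (odds multiplied by e, about 2.72) moves
# the probability by different amounts depending on the starting point
for p in (0.1, 0.5, 0.9):
    print(p, round(shift_probability(p, 1.0), 3))
```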

Aims:

  • Identify the top 30 and bottom 30 hospitals, ranked by their thrombolysis rate for the common 10k patient cohort

  • Plot SHAP values for the feature stroke severity (NIHSS) for the top and bottom hospitals

  • Represent the data for the patients that did, and did not, receive thrombolysis

  • Repeat for top and bottom 9 hospitals

  • Repeat for feature prior disability (mRS)

Observations

  • Low and high thrombolysing hospitals have the same general pattern: a patient is more likely to receive thrombolysis if they have a mid-level stroke severity.

  • The more cautious hospitals are even less likely to thrombolyse mild and severe strokes, but more likely to thrombolyse mid-level strokes.

  • The wider the net used to select hospitals for the two groups (30 hospitals each rather than 9), the less pronounced the difference.

  • Dividing the patients into did/did not receive thrombolysis shows for those

Note on shap version 0.40:

Installed using pip install shap.

There is a bug in the waterfall plot where show=False (required to save the plot) fails. To correct this, find the _waterfall.py file in the shap library (e.g. in anaconda/envs/samuel2/lib/python3.8/site-packages/shap/plots), search/replace all pl. with plt., and replace the initial import, import matplotlib.pyplot as pl, with import matplotlib.pyplot as plt.

Load packages

# Turn warnings off to keep notebook tidy
import warnings
warnings.filterwarnings("ignore")

import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import shap

from xgboost import XGBClassifier

import json

Create output folders if needed

path = './saved_models'
# exist_ok=True makes this a no-op if the folder already exists
os.makedirs(path, exist_ok=True)

Load data on predicted 10k cohort thrombolysis use at each hospital

Use each hospital's thrombolysis rate for the same set of 10k patients to rank the hospitals, so that the 30 hospitals with the highest (and the 30 with the lowest) thrombolysis rates can be selected.

thrombolysis_by_hosp = pd.read_csv(
    './output/10k_thrombolysis_rate_by_hosp_key_features.csv', index_col='stroke_team')
thrombolysis_by_hosp.sort_values(
    'Thrombolysis rate', ascending=False, inplace=True)
thrombolysis_by_hosp.head()
Thrombolysis rate
stroke_team
VKKDD9172T 0.4610
GKONI0110I 0.4356
CNBGF2713O 0.4207
HPWIF9956L 0.4191
MHMYL4920B 0.3981

Load SHAP data for first k-fold model

Use explainer to access the feature names (explainer.data_feature_names), and shap_values_extended to access the data values (shap_values_extended.data) and SHAP values (shap_values_extended.values).

k = 0
filename = f'./output/shap_values_extended_xgb_key_features_{k}.p'
with open(filename, 'rb') as filehandler:
    shap_values_extended = pickle.load(filehandler)
    
filename = f'./output/shap_values_explainer_xgb_key_features_{k}.p'
with open(filename, 'rb') as filehandler:
    explainer = pickle.load(filehandler)

Load target feature data for first k-fold model

data_loc = '../data/kfold_5fold/'

feature = 'S2Thrombolysis'

test = pd.read_csv(data_loc + 'test_0.csv')
target_feature = test[feature]

Prepare the datasets

Create a DataFrame containing the feature values and the target value (with feature names as column titles).

data_values_df = pd.DataFrame(shap_values_extended.data)
data_values_df.columns = explainer.data_feature_names
data_values_df = data_values_df.join(target_feature)
data_values_df
Arrival-to-scan time Infarction Stroke severity Precise onset time Prior disability level Use of AF anticoagulents Onset-to-arrival time team_AGNOF1041H team_AKCGO9726K team_AOBTM3098N ... team_XPABC1435F team_XQAGA4299B team_XWUBX0795L team_YEXCH8391J team_YPKYH1768F team_YQMZV4284N team_ZBVSO0975W team_ZHCLE1578P team_ZRRCV7012C S2Thrombolysis
0 17.0 1.0 14.0 1.0 0.0 0.0 186.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
1 25.0 1.0 6.0 1.0 0.0 0.0 71.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
2 138.0 1.0 2.0 1.0 0.0 0.0 67.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
3 21.0 0.0 11.0 1.0 0.0 0.0 86.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
4 8.0 1.0 16.0 1.0 0.0 0.0 83.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
17754 8.0 1.0 4.0 1.0 0.0 0.0 105.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
17755 35.0 1.0 25.0 0.0 3.0 0.0 208.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
17756 80.0 1.0 2.0 0.0 1.0 0.0 236.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
17757 16.0 1.0 10.0 1.0 0.0 0.0 34.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
17758 23.0 1.0 10.0 1.0 0.0 0.0 125.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0

17759 rows × 140 columns

Create a DataFrame containing the SHAP values (with feature names as column titles).

shap_values_df = pd.DataFrame(shap_values_extended.values)
shap_values_df.columns = explainer.data_feature_names
shap_values_df
Arrival-to-scan time Infarction Stroke severity Precise onset time Prior disability level Use of AF anticoagulents Onset-to-arrival time team_AGNOF1041H team_AKCGO9726K team_AOBTM3098N ... team_XKAWN3771U team_XPABC1435F team_XQAGA4299B team_XWUBX0795L team_YEXCH8391J team_YPKYH1768F team_YQMZV4284N team_ZBVSO0975W team_ZHCLE1578P team_ZRRCV7012C
0 1.665572 1.764335 1.519817 0.533487 0.516871 0.261582 -0.351841 0.006454 -0.000906 0.0 ... 0.0 0.0 0.012956 0.0 0.0 0.0 0.0 0.0 -0.005460 0.007542
1 1.042294 1.674223 0.856954 0.567241 0.615244 0.261544 -0.050369 -0.002464 -0.002205 0.0 ... 0.0 0.0 0.014900 0.0 0.0 0.0 0.0 0.0 -0.006770 0.008916
2 -0.826964 1.314690 -1.217614 0.609297 0.652299 0.212339 0.475219 -0.002121 -0.005571 0.0 ... 0.0 0.0 0.014900 0.0 0.0 0.0 0.0 0.0 0.000129 0.008393
3 0.516566 -7.852724 0.652316 0.334562 0.286924 0.083673 -0.004301 0.000162 0.000863 0.0 ... 0.0 0.0 0.017112 0.0 0.0 0.0 0.0 0.0 -0.006360 0.008215
4 1.597652 1.797345 1.430198 0.551891 0.520666 0.206812 0.149001 0.014654 0.000492 0.0 ... 0.0 0.0 0.017112 0.0 0.0 0.0 0.0 0.0 -0.003191 0.008034
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
17754 1.450873 1.646119 0.159537 0.893681 0.592578 0.227030 -0.014525 -0.000794 -0.001883 0.0 ... 0.0 0.0 0.011911 0.0 0.0 0.0 0.0 0.0 -0.003646 0.005996
17755 0.327505 1.435369 0.858219 -0.925355 -0.411963 0.139809 -0.969236 0.003032 -0.000350 0.0 ... 0.0 0.0 0.012956 0.0 0.0 0.0 0.0 0.0 -0.008049 0.013558
17756 -0.835297 1.215644 -0.853170 -0.755055 -0.054575 0.254649 -2.016223 0.004402 -0.008321 0.0 ... 0.0 0.0 0.015529 0.0 0.0 0.0 0.0 0.0 0.001583 0.007954
17757 1.188463 1.760321 1.185057 0.486597 0.581963 0.170620 0.500511 0.007738 -0.000324 0.0 ... 0.0 0.0 0.008683 0.0 0.0 0.0 0.0 0.0 -0.013498 0.007206
17758 1.234641 1.749588 0.820623 0.555529 0.406468 0.227709 -0.020579 0.002986 -0.000084 0.0 ... 0.0 0.0 0.012956 0.0 0.0 0.0 0.0 0.0 -0.006360 0.007542

17759 rows × 139 columns

Plot SHAP values for Stroke Severity for the top and bottom 30 hospitals

Identify the 30 hospitals with the highest, and the 30 with the lowest, thrombolysis rates for the common set of 10k patients, and store their stroke team names.

top_30_hospitals = list(thrombolysis_by_hosp.head(30).index)
bottom_30_hospitals = list(thrombolysis_by_hosp.tail(30).index)
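Because thrombolysis_by_hosp was sorted in descending order, head(30) gives the highest and tail(30) the lowest thrombolysing hospitals. Below is a sketch of an order-independent alternative using pandas nlargest and nsmallest. The team codes and rates are invented for illustration.

```python
import pandas as pd

# Hypothetical thrombolysis rates indexed by invented stroke team codes
example_rates = pd.Series(
    {'TEAM_A': 0.46, 'TEAM_B': 0.12, 'TEAM_C': 0.33,
     'TEAM_D': 0.05, 'TEAM_E': 0.41, 'TEAM_F': 0.20},
    name='Thrombolysis rate')

n = 2
# nlargest/nsmallest do not rely on the Series having been sorted first
top_n = list(example_rates.nlargest(n).index)
bottom_n = list(example_rates.nsmallest(n).index)

# The two groups must not overlap (requires 2 * n <= number of hospitals)
assert not set(top_n) & set(bottom_n)
print(top_n, bottom_n)
```

nsmallest returns the lowest rate first, whereas tail of a descending sort returns it last. The order does not matter here, because each group is only used as a set of hospitals.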

Define a function to plot, for a given feature, the mean SHAP value at each feature value for the two hospital groups.

def plot_shap_per_hospital_group(feature, top_hospital_list, 
                                 bottom_hospital_list, data_values_df, 
                                 shap_values_df,
                                 thrombolysis_given=-1):

    """
    Plot the mean SHAP value of `feature` at each feature value, for patients
    attending the top hospitals and for patients attending the bottom hospitals.
    If thrombolysis_given is passed (0 or 1) then use a mask to keep only
    patients with that S2Thrombolysis value. If not given, then -1 is a flag
    to not split.
    """

    def mean_shap_by_category(hospital_list):
        # Collect SHAP and feature values for each hospital in the list
        hospital_dfs = []
        for hospital in hospital_list:
            column_name = f"team_{hospital}"
            mask = data_values_df[column_name] != 0
            if thrombolysis_given != -1:
                # split based on thrombolysis_given
                mask = mask & (data_values_df['S2Thrombolysis'] == thrombolysis_given)
            s1 = pd.Series(shap_values_df[feature][mask], name="shap")
            s2 = pd.Series(data_values_df[feature][mask], name="data")
            hospital_dfs.append(pd.concat([s1, s2], axis=1))
        # DataFrame.append was removed in pandas 2.0, so combine with pd.concat
        shap_df = pd.concat(hospital_dfs)
        # calculate mean per category
        return shap_df.groupby('data').mean()

    # Calculate data for top n and bottom n hospitals
    mean_by_category_top = mean_shap_by_category(top_hospital_list)
    mean_by_category_bottom = mean_shap_by_category(bottom_hospital_list)

    # plot means
    fig = plt.figure(figsize=(6,6))
    ax1 = fig.add_subplot(111)
    n_hospitals = len(top_hospital_list)
    ax1.scatter(mean_by_category_top.index,
                mean_by_category_top['shap'],
                label=f"top {n_hospitals} hospitals")
    n_hospitals = len(bottom_hospital_list)
    ax1.scatter(mean_by_category_bottom.index,
                mean_by_category_bottom['shap'],
                label=f"bottom {n_hospitals} hospitals")
    ax1.set_xlabel(f'{feature} feature value')
    ax1.set_ylabel(f'{feature} SHAP value (mean of patients)')
    ax1.grid()
    title = ""
    if thrombolysis_given == 0:
        title = "(not receiving thrombolysis) "
    elif thrombolysis_given == 1:
        title = "(receiving thrombolysis) "
    ax1.set_title(f'{feature} SHAP values for patients {title}\nattending a low or ' +
                  f'high thrombolysing hospital')
    ax1.legend(loc='upper right', bbox_to_anchor=(1.5, 1))
    plt.show()
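The core calculation inside the function can also be written without a per-hospital loop. The sketch below selects every patient whose one-hot team column is set for any hospital in the group, optionally filters on the target, and averages SHAP values per feature value. The helper group_mean_shap and the tiny dataset are hypothetical and are not part of the original notebook.

```python
import pandas as pd

# Hypothetical mini dataset: one-hot team columns, a feature, and the target
example_data = pd.DataFrame({
    'team_A': [1, 1, 0, 0, 1, 0],
    'team_B': [0, 0, 1, 1, 0, 1],
    'Stroke severity': [2, 10, 2, 10, 10, 2],
    'S2Thrombolysis': [0, 1, 0, 1, 0, 1],
})
# Matching hypothetical SHAP values for the same patients (same index)
example_shap = pd.DataFrame(
    {'Stroke severity': [-1.0, 1.0, -0.5, 0.5, 0.6, -0.3]})


def group_mean_shap(feature, hospitals, data_df, shap_df, thrombolysis_given=-1):
    """Mean SHAP value at each feature value for patients of the listed hospitals."""
    team_cols = [f'team_{h}' for h in hospitals]
    # A patient belongs to the group if any of the group's team columns is set
    mask = (data_df[team_cols] != 0).any(axis=1)
    if thrombolysis_given != -1:
        # Keep only patients who did (1) or did not (0) receive thrombolysis
        mask &= data_df['S2Thrombolysis'] == thrombolysis_given
    grouped = pd.DataFrame({'data': data_df.loc[mask, feature],
                            'shap': shap_df.loc[mask, feature]})
    return grouped.groupby('data')['shap'].mean()


result = group_mean_shap('Stroke severity', ['A'], example_data, example_shap)
print(result)
```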

Plot the SHAP values for stroke severity for the two hospital groups.

plot_shap_per_hospital_group('Stroke severity', top_30_hospitals, 
                             bottom_30_hospitals, data_values_df, 
                             shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_23_0.png

Plot the SHAP values for stroke severity for the two hospital groups, splitting the patients by whether they received thrombolysis.

thrombolysed = [0,1]

for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital_group('Stroke severity', top_30_hospitals, 
                                 bottom_30_hospitals, data_values_df, 
                                 shap_values_df, 
                                 thrombolysis_given=thrombolysis_given)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_25_0.png ../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_25_1.png

Plot the 30 hospitals individually.

def plot_shap_per_hospital(feature, top_hospital_list, 
                           bottom_hospital_list, data_values_df, 
                           shap_values_df,
                           thrombolysis_given=-1):
    """
    Plot one line per hospital showing the mean SHAP value of `feature` at
    each feature value: bottom hospitals on the left, top on the right.
    If thrombolysis_given is passed (0 or 1) then use a mask to keep only
    patients with that S2Thrombolysis value; -1 is a flag to not split.
    """
    title = ""
    if thrombolysis_given == 0:
        title = "(not receiving thrombolysis) "
    elif thrombolysis_given == 1:
        title = "(receiving thrombolysis) "

    fig = plt.figure(figsize=(10,5))

    # Left panel: bottom hospitals; right panel: top hospitals
    for position, group_name, hospital_list in [
            (121, 'bottom', bottom_hospital_list),
            (122, 'top', top_hospital_list)]:
        ax = fig.add_subplot(position)
        n_hospitals = len(hospital_list)
        ax.set_title(f'{feature} SHAP values for patients \n{title}\nattending ' +
                     f'{group_name} {n_hospitals} thrombolysing hospitals')

        # Calculate and plot data for each hospital in the group
        for hospital in hospital_list:
            mask = data_values_df[f"team_{hospital}"] != 0
            if thrombolysis_given != -1:
                # split based on thrombolysis_given
                mask = mask & (data_values_df['S2Thrombolysis'] == thrombolysis_given)
            df = pd.concat([shap_values_df[feature][mask].rename("shap"),
                            data_values_df[feature][mask].rename("data")], axis=1)
            plot_hospital_data = df.groupby('data').mean()
            ax.plot(plot_hospital_data.index,
                    plot_hospital_data['shap'],
                    label=hospital)
        ax.grid()
        ax.set_xlabel(f'{feature}')
        ax.set_ylabel('SHAP (mean of patients)')
        ax.set_ylim(-3, 2.5)
    plt.show()

plot_shap_per_hospital('Stroke severity', top_30_hospitals, 
                       bottom_30_hospitals, data_values_df, 
                       shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_28_0.png
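The per-hospital loop can also be expressed as a single table with one column per hospital. The sketch below uses invented data. It assumes every patient has exactly one team column set; for a patient with none set, idxmax would return the first team column.

```python
import pandas as pd

# Hypothetical one-hot team columns plus a feature and matching SHAP values
example_data = pd.DataFrame({
    'team_A': [1, 1, 0, 0],
    'team_B': [0, 0, 1, 1],
    'Stroke severity': [2, 10, 2, 10],
})
example_shap = pd.Series([-1.0, 1.0, -0.5, 0.5], name='Stroke severity')

team_cols = [c for c in example_data.columns if c.startswith('team_')]
# Recover each patient's team label from the one-hot columns
team = example_data[team_cols].idxmax(axis=1).str.replace('team_', '', regex=False)

# Rows: feature value; columns: hospital; values: mean SHAP
per_hospital = (pd.DataFrame({'team': team,
                              'data': example_data['Stroke severity'],
                              'shap': example_shap})
                .groupby(['data', 'team'])['shap'].mean()
                .unstack('team'))
print(per_hospital)
```

Calling per_hospital.plot() would then draw one line per hospital column.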

Plot the 30 hospitals individually, splitting the patients by whether they received thrombolysis.

thrombolysed = [0,1]

for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital('Stroke severity', top_30_hospitals, 
                           bottom_30_hospitals, data_values_df, 
                           shap_values_df,
                           thrombolysis_given=thrombolysis_given)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_30_0.png ../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_30_1.png

Repeat for groups with 9 hospitals

top_9_hospitals = list(thrombolysis_by_hosp.head(9).index)
bottom_9_hospitals = list(thrombolysis_by_hosp.tail(9).index)

Plot the SHAP values for stroke severity for the two hospital groups.

plot_shap_per_hospital_group('Stroke severity', top_9_hospitals, 
                             bottom_9_hospitals, data_values_df, 
                             shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_34_0.png

Plot the SHAP values for stroke severity for the two hospital groups, splitting the patients by whether they received thrombolysis.

thrombolysed = [0,1]

for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital_group('Stroke severity', top_9_hospitals, 
                                 bottom_9_hospitals, data_values_df, 
                                 shap_values_df, 
                                 thrombolysis_given=thrombolysis_given)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_36_0.png ../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_36_1.png

Plot the 9 hospitals individually.

plot_shap_per_hospital('Stroke severity', top_9_hospitals, 
                       bottom_9_hospitals, data_values_df, 
                       shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_38_0.png

Plot the 9 hospitals individually, splitting the patients by whether they received thrombolysis.

thrombolysed = [0,1]

for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital('Stroke severity', top_9_hospitals, 
                           bottom_9_hospitals, data_values_df, 
                           shap_values_df,
                           thrombolysis_given=thrombolysis_given)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_40_0.png ../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_40_1.png

Plot for Prior Disability

plot_shap_per_hospital_group('Prior disability level', top_9_hospitals, 
                             bottom_9_hospitals, data_values_df, 
                             shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_42_0.png

Plot for Use of AF anticoagulants

plot_shap_per_hospital_group('Use of AF anticoagulents', top_9_hospitals, 
                             bottom_9_hospitals, data_values_df, 
                             shap_values_df)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_44_0.png

Plot for Arrival-to-scan time

(limit the graph to patients with an arrival-to-scan time under 400 minutes)

mask = data_values_df['Arrival-to-scan time']<400
data_to_plot = data_values_df[mask]
shap_to_plot = shap_values_df[mask]

plot_shap_per_hospital_group('Arrival-to-scan time', top_9_hospitals, 
                             bottom_9_hospitals, data_to_plot, 
                             shap_to_plot)
../../_images/08_xgb_combined_fit_shap_high_vs_low_thrombolysing_units_46_0.png