Comparing Shap values between high and low thrombolysing hospitals
Plain English summary
It has become apparent that hospitals are making different clinical decisions, and that differences in the patient population alone do not account for the differences in thrombolysis rates between hospitals. For example, hospitals presented with the same information can make opposing decisions on whether or not to give a patient thrombolysis. We would like to pick apart which factors different hospitals use to make their different decisions. Here we create two groups of hospitals: those with a high propensity to thrombolyse, and those with a low propensity to thrombolyse. Comparing the SHAP values for a specific feature between these two hospital groups shows how that feature contributes to their different decisions on whether to give thrombolysis.
We see that low and high thrombolysing hospitals have the same general pattern: a patient is more likely to receive thrombolysis if they have a mid-level stroke severity. However, the more cautious (low thrombolysing) hospitals are even less likely to thrombolyse mild and severe strokes, and are more likely to thrombolyse mid-level strokes. The wider the net for including hospitals in the two groups (30 rather than 9 hospitals per group), the less pronounced the difference.
Model and data
The data needed for this analysis (feature values and SHAP values) were created in previous notebooks, which trained the models.
These XGBoost models were trained on stratified k-fold cross-validation data. The 8 features in the model are:
Arrival-to-scan time: Time from arrival at hospital to scan (mins)
Infarction: Stroke type (1 = infarction, 0 = haemorrhage)
Stroke severity: Stroke severity (NIHSS) on arrival
Precise onset time: Onset time type (1 = precise, 0 = best estimate)
Prior disability level: Disability level (modified Rankin Scale) before stroke
Stroke team: Stroke team attended
Use of AF anticoagulents: Use of atrial fibrillation anticoagulant (1 = Yes, 0 = No)
Onset-to-arrival time: Time from onset of stroke to arrival at hospital (mins)
And one target feature:
Thrombolysis: Receive thrombolysis (1 = Yes, 0 = No)
The 8 features included in the model (to predict whether a patient will receive thrombolysis) were chosen sequentially, at each step adding the feature that gave the single best improvement in model performance (measured by ROC AUC). The stroke team feature is included as a one-hot encoded feature.
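To illustrate what "one-hot encoded" means for the stroke team feature, here is a minimal sketch assuming pandas `get_dummies` with a `team` prefix. The three-patient DataFrame is invented for illustration (the team codes are borrowed from the hospital table), and the original notebooks may have encoded the data differently. Each stroke team becomes its own 0/1 column named `team_<code>`, which matches the column names seen in the data tables (e.g. team_AGNOF1041H).

```python
import pandas as pd

# Hypothetical toy data: three patients and the stroke team that attended them
patients = pd.DataFrame({
    'Stroke severity': [14, 6, 2],
    'Stroke team': ['VKKDD9172T', 'GKONI0110I', 'VKKDD9172T'],
})

# One-hot encode the stroke team: one 0/1 column per team, prefixed "team_"
encoded = pd.get_dummies(patients, columns=['Stroke team'],
                         prefix='team', prefix_sep='_', dtype=float)
print(encoded)
```

Each patient has exactly one team column set to 1, so a hospital's patients can later be selected with a mask such as `encoded['team_VKKDD9172T'] != 0`.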
The Python library SHAP was applied to the first k-fold model to obtain a SHAP value for each feature, for each instance. SHAP values are in the same units as the model output, which for XGBoost is log odds.
A single SHAP value per feature was obtained by taking the mean of the absolute values across all instances.
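The summary statistic described above is simple to compute. A minimal sketch, using a small invented SHAP array in place of `shap_values_extended.values` (one row per instance, one column per feature):

```python
import numpy as np
import pandas as pd

# Hypothetical SHAP values array (like shap_values_extended.values):
# one row per patient (instance), one column per feature
shap_array = np.array([[1.5, -0.4],
                       [-0.5, 0.2],
                       [1.0, 0.0]])
feature_names = ['Stroke severity', 'Onset-to-arrival time']

# Mean of the absolute SHAP values across all instances, one value per feature
mean_abs_shap = pd.Series(np.abs(shap_array).mean(axis=0), index=feature_names)
print(mean_abs_shap.sort_values(ascending=False))
```

Taking absolute values first stops positive and negative contributions from cancelling out, so the result measures how strongly a feature moves predictions, not in which direction.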
A note on SHAP values
SHAP values are usually reported as log odds shifts in model predictions. For a description of the relationships between probability, odds, and SHAP values (log odds shifts) see here.
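As a worked example of these relationships, here is a minimal sketch with invented numbers. It assumes the explainer's base value and the SHAP values are both in log odds, as described above, so that the base value plus the sum of a patient's SHAP values gives the model's log odds output for that patient; the logistic function then converts log odds to a probability.

```python
import math

def log_odds_to_probability(log_odds):
    """Logistic (sigmoid) function: convert log odds to a probability."""
    return 1 / (1 + math.exp(-log_odds))

def probability_to_log_odds(p):
    """Inverse of the logistic function (the logit)."""
    return math.log(p / (1 - p))

# Hypothetical numbers: the explainer's base value (in log odds) and the
# SHAP values of one patient's features
base_value = -1.0
patient_shap_values = [1.5, 0.75, -0.25]

# SHAP values are additive in log odds: base value + sum of SHAP values
patient_log_odds = base_value + sum(patient_shap_values)
print(patient_log_odds)                                     # 1.0
print(round(log_odds_to_probability(patient_log_odds), 3))  # 0.731

# A SHAP value of +1 multiplies the odds by e (about 2.72)
print(round(math.exp(1), 2))  # 2.72
```

Because the conversion is non-linear, the same SHAP value shifts the probability by different amounts depending on the patient's starting log odds.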
Aims:
Identify the top 30 and bottom 30 hospitals, judged by their thrombolysis rate in the 10k cohort dataset
Plot SHAP values for the feature stroke severity (NIHSS) for the top and bottom hospitals
Represent the data separately for patients who did, and did not, receive thrombolysis
Repeat for top and bottom 9 hospitals
Repeat for feature prior disability (mRS)
Observations
Low and high thrombolysing hospitals have the same general pattern: a patient is more likely to receive thrombolysis if they have a mid-level stroke severity.
The more cautious (low thrombolysing) hospitals are even less likely to thrombolyse mild and severe strokes, and are more likely to thrombolyse mid-level strokes.
The wider the net for including hospitals in the two groups (30 rather than 9 hospitals per group), the less pronounced the difference.
Dividing the patients into did/did not receive thrombolysis shows for those
Note on shap version 0.40:
Installed using pip install shap. There is a bug in the waterfall plot where setting show=False (required to save the plot) fails. To correct this, find the _waterfall.py file in the shap library (e.g. in anaconda/envs/samuel2/lib/python3.8/site-packages/shap/plots), search/replace all pl. with plt., and replace the initial import statement, import matplotlib.pyplot as pl, with import matplotlib.pyplot as plt.
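The manual search/replace in the note above could also be scripted. This is a minimal sketch using only the standard library; the function name patch_waterfall_imports is hypothetical, and because it edits an installed library file it writes a .bak backup first. The regular expression \bpl\. only matches pl. as a whole word, so existing plt. references and identifiers that merely end in "pl" are left alone, and running it twice makes no further changes.

```python
import re
from pathlib import Path

def patch_waterfall_imports(waterfall_path):
    """
    Rename the matplotlib alias 'pl' to 'plt' in shap's _waterfall.py.
    Returns the number of 'pl.' references replaced.
    """
    path = Path(waterfall_path)
    source = path.read_text()
    # Keep a backup copy of the original file before editing it
    path.with_name(path.name + '.bak').write_text(source)
    # Replace the import alias (\b stops an existing 'plt' from matching)
    source = re.sub(r'import matplotlib\.pyplot as pl\b',
                    'import matplotlib.pyplot as plt', source)
    # Replace every 'pl.' used as a whole word, e.g. pl.gcf() -> plt.gcf()
    source, n_replaced = re.subn(r'\bpl\.', 'plt.', source)
    path.write_text(source)
    return n_replaced
```

It would be called with the full path to _waterfall.py in your own environment (for example inside the shap/plots folder mentioned above).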
Load packages
# Turn warnings off to keep notebook tidy
import warnings
warnings.filterwarnings("ignore")
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import shap
from xgboost import XGBClassifier
import json
Create output folders if needed
path = './saved_models'
if not os.path.exists(path):
    os.makedirs(path)
Load data on predicted 10k cohort thrombolysis use at each hospital
Use each hospital's predicted thrombolysis rate for the same set of 10k patients to rank the hospitals, so that the 30 hospitals with the highest (and the 30 with the lowest) thrombolysis rates can be selected.
thrombolysis_by_hosp = pd.read_csv(
'./output/10k_thrombolysis_rate_by_hosp_key_features.csv', index_col='stroke_team')
thrombolysis_by_hosp.sort_values(
'Thrombolysis rate', ascending=False, inplace=True)
thrombolysis_by_hosp.head()
stroke_team | Thrombolysis rate |
---|---|
VKKDD9172T | 0.4610 |
GKONI0110I | 0.4356 |
CNBGF2713O | 0.4207 |
HPWIF9956L | 0.4191 |
MHMYL4920B | 0.3981 |
Load SHAP data for first k-fold model
Use explainer to access the feature names (explainer.data_feature_names), and shap_values_extended to access the data values and SHAP values.
k = 0
filename = f'./output/shap_values_extended_xgb_key_features_{k}.p'
with open(filename, 'rb') as filehandler:
    shap_values_extended = pickle.load(filehandler)
filename = f'./output/shap_values_explainer_xgb_key_features_{k}.p'
with open(filename, 'rb') as filehandler:
    explainer = pickle.load(filehandler)
Load target feature data for first k-fold model
data_loc = '../data/kfold_5fold/'
feature = 'S2Thrombolysis'
test = pd.read_csv(data_loc + 'test_0.csv')
target_feature = test[feature]
Prepare the datasets
Create a DataFrame containing the feature values and target value (with feature names as column titles).
data_values_df = pd.DataFrame(shap_values_extended.data)
data_values_df.columns = explainer.data_feature_names
data_values_df = data_values_df.join(target_feature)
data_values_df
 | Arrival-to-scan time | Infarction | Stroke severity | Precise onset time | Prior disability level | Use of AF anticoagulents | Onset-to-arrival time | team_AGNOF1041H | team_AKCGO9726K | team_AOBTM3098N | ... | team_XPABC1435F | team_XQAGA4299B | team_XWUBX0795L | team_YEXCH8391J | team_YPKYH1768F | team_YQMZV4284N | team_ZBVSO0975W | team_ZHCLE1578P | team_ZRRCV7012C | S2Thrombolysis |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 17.0 | 1.0 | 14.0 | 1.0 | 0.0 | 0.0 | 186.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 |
1 | 25.0 | 1.0 | 6.0 | 1.0 | 0.0 | 0.0 | 71.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 |
2 | 138.0 | 1.0 | 2.0 | 1.0 | 0.0 | 0.0 | 67.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
3 | 21.0 | 0.0 | 11.0 | 1.0 | 0.0 | 0.0 | 86.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
4 | 8.0 | 1.0 | 16.0 | 1.0 | 0.0 | 0.0 | 83.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
17754 | 8.0 | 1.0 | 4.0 | 1.0 | 0.0 | 0.0 | 105.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
17755 | 35.0 | 1.0 | 25.0 | 0.0 | 3.0 | 0.0 | 208.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
17756 | 80.0 | 1.0 | 2.0 | 0.0 | 1.0 | 0.0 | 236.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
17757 | 16.0 | 1.0 | 10.0 | 1.0 | 0.0 | 0.0 | 34.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 |
17758 | 23.0 | 1.0 | 10.0 | 1.0 | 0.0 | 0.0 | 125.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 |
17759 rows × 140 columns
Create a DataFrame containing the SHAP values (with feature names as column titles).
#explainer.data_feature_names
shap_values_df = pd.DataFrame(shap_values_extended.values)
shap_values_df.columns = explainer.data_feature_names
shap_values_df
 | Arrival-to-scan time | Infarction | Stroke severity | Precise onset time | Prior disability level | Use of AF anticoagulents | Onset-to-arrival time | team_AGNOF1041H | team_AKCGO9726K | team_AOBTM3098N | ... | team_XKAWN3771U | team_XPABC1435F | team_XQAGA4299B | team_XWUBX0795L | team_YEXCH8391J | team_YPKYH1768F | team_YQMZV4284N | team_ZBVSO0975W | team_ZHCLE1578P | team_ZRRCV7012C |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.665572 | 1.764335 | 1.519817 | 0.533487 | 0.516871 | 0.261582 | -0.351841 | 0.006454 | -0.000906 | 0.0 | ... | 0.0 | 0.0 | 0.012956 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.005460 | 0.007542 |
1 | 1.042294 | 1.674223 | 0.856954 | 0.567241 | 0.615244 | 0.261544 | -0.050369 | -0.002464 | -0.002205 | 0.0 | ... | 0.0 | 0.0 | 0.014900 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.006770 | 0.008916 |
2 | -0.826964 | 1.314690 | -1.217614 | 0.609297 | 0.652299 | 0.212339 | 0.475219 | -0.002121 | -0.005571 | 0.0 | ... | 0.0 | 0.0 | 0.014900 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.000129 | 0.008393 |
3 | 0.516566 | -7.852724 | 0.652316 | 0.334562 | 0.286924 | 0.083673 | -0.004301 | 0.000162 | 0.000863 | 0.0 | ... | 0.0 | 0.0 | 0.017112 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.006360 | 0.008215 |
4 | 1.597652 | 1.797345 | 1.430198 | 0.551891 | 0.520666 | 0.206812 | 0.149001 | 0.014654 | 0.000492 | 0.0 | ... | 0.0 | 0.0 | 0.017112 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.003191 | 0.008034 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
17754 | 1.450873 | 1.646119 | 0.159537 | 0.893681 | 0.592578 | 0.227030 | -0.014525 | -0.000794 | -0.001883 | 0.0 | ... | 0.0 | 0.0 | 0.011911 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.003646 | 0.005996 |
17755 | 0.327505 | 1.435369 | 0.858219 | -0.925355 | -0.411963 | 0.139809 | -0.969236 | 0.003032 | -0.000350 | 0.0 | ... | 0.0 | 0.0 | 0.012956 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.008049 | 0.013558 |
17756 | -0.835297 | 1.215644 | -0.853170 | -0.755055 | -0.054575 | 0.254649 | -2.016223 | 0.004402 | -0.008321 | 0.0 | ... | 0.0 | 0.0 | 0.015529 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.001583 | 0.007954 |
17757 | 1.188463 | 1.760321 | 1.185057 | 0.486597 | 0.581963 | 0.170620 | 0.500511 | 0.007738 | -0.000324 | 0.0 | ... | 0.0 | 0.0 | 0.008683 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.013498 | 0.007206 |
17758 | 1.234641 | 1.749588 | 0.820623 | 0.555529 | 0.406468 | 0.227709 | -0.020579 | 0.002986 | -0.000084 | 0.0 | ... | 0.0 | 0.0 | 0.012956 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.006360 | 0.007542 |
17759 rows × 139 columns
Plot SHAP values for Stroke Severity for the top and bottom 30 hospitals
Identify the 30 hospitals with the highest, and the 30 with the lowest, thrombolysis rates for the common set of 10k patients. Store their stroke team names.
top_30_hospitals = list(thrombolysis_by_hosp.head(30).index)
bottom_30_hospitals = list(thrombolysis_by_hosp.tail(30).index)
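The head/tail selection above depends on thrombolysis_by_hosp having been sorted in descending order first. As an alternative, pandas nlargest/nsmallest select by value directly. A minimal sketch with five hypothetical stroke teams (A to E) and invented rates:

```python
import pandas as pd

# Hypothetical thrombolysis rates for five stroke teams
rates = pd.DataFrame(
    {'Thrombolysis rate': [0.30, 0.46, 0.12, 0.42, 0.21]},
    index=pd.Index(['A', 'B', 'C', 'D', 'E'], name='stroke_team'))

n = 2
# nlargest/nsmallest select by value, so they do not rely on a prior sort
top_n = list(rates['Thrombolysis rate'].nlargest(n).index)
bottom_n = list(rates['Thrombolysis rate'].nsmallest(n).index)
print(top_n, bottom_n)  # ['B', 'D'] ['C', 'E']
```

This selects the same hospitals as head/tail on the sorted table, although the bottom group is listed lowest-first rather than highest-first.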
Define a function that plots, for the two hospital groups, the mean SHAP value of a feature (here stroke severity) at each value of that feature.
def plot_shap_per_hospital_group(feature, top_hospital_list,
                                 bottom_hospital_list, data_values_df,
                                 shap_values_df,
                                 thrombolysis_given=-1):
    """
    Plot the mean SHAP value of `feature` for each feature value, for
    patients attending the top and bottom hospital groups.

    If thrombolysis_given is passed (0 or 1) then use a mask to keep only
    patients with that thrombolysis outcome. If not given, then -1 is the
    flag to not split patients.
    """

    def mean_shap_by_value(hospital_list):
        # Collect SHAP and feature values for patients at each hospital
        frames = []
        for hospital in hospital_list:
            column_name = f"team_{hospital}"
            mask = data_values_df[column_name] != 0
            if thrombolysis_given != -1:
                # Keep only patients with the requested thrombolysis outcome
                mask = mask & (data_values_df['S2Thrombolysis'] == thrombolysis_given)
            s1 = pd.Series(shap_values_df[feature][mask], name="shap")
            s2 = pd.Series(data_values_df[feature][mask], name="data")
            frames.append(pd.concat([s1, s2], axis=1))
        # DataFrame.append was removed in pandas 2.0, so concatenate once
        shap_df = pd.concat(frames)
        # Calculate the mean SHAP value per feature value (category)
        return shap_df.groupby('data').mean()

    mean_by_category_top = mean_shap_by_value(top_hospital_list)
    mean_by_category_bottom = mean_shap_by_value(bottom_hospital_list)

    # Plot means
    fig = plt.figure(figsize=(6, 6))
    ax1 = fig.add_subplot(111)
    ax1.scatter(mean_by_category_top.index,
                mean_by_category_top['shap'],
                label=f"top {len(top_hospital_list)} hospitals")
    ax1.scatter(mean_by_category_bottom.index,
                mean_by_category_bottom['shap'],
                label=f"bottom {len(bottom_hospital_list)} hospitals")
    ax1.set_xlabel(f'{feature} feature value')
    ax1.set_ylabel(f'{feature} SHAP value (mean of patients)')
    ax1.grid()
    title = ""
    if thrombolysis_given == 0:
        title = "(not receiving thrombolysis) "
    elif thrombolysis_given == 1:
        title = "(receiving thrombolysis) "
    ax1.set_title(f'{feature} SHAP values for patients {title}\nattending a low or '
                  'high thrombolysing hospital')
    ax1.legend(loc='upper right', bbox_to_anchor=(1.5, 1))
    plt.show()
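The calculation inside plot_shap_per_hospital_group can be checked without plotting. Below is a minimal sketch on an invented four-patient dataset with two hypothetical teams, A and B. It follows the same mask-then-groupby('data').mean() logic, but (as a design choice of this sketch only) selects all of a group's patients with a single .any(axis=1) over the team columns instead of looping over hospitals, and uses None rather than -1 as the "do not split" flag.

```python
import pandas as pd

# Hypothetical toy data for four patients: feature values (including the
# one-hot stroke team columns and the target) and matching SHAP values
data_values = pd.DataFrame({
    'Stroke severity': [2, 10, 10, 2],
    'team_A': [1, 1, 0, 0],
    'team_B': [0, 0, 1, 1],
    'S2Thrombolysis': [0, 1, 1, 0],
})
shap_values = pd.DataFrame({'Stroke severity': [-1.0, 1.0, 0.5, -0.5]})

def mean_shap_by_value(feature, hospitals, data_values, shap_values,
                       thrombolysis_given=None):
    """Mean SHAP value of `feature` for each feature value, pooled over
    patients attending any hospital in `hospitals`."""
    team_columns = [f'team_{hospital}' for hospital in hospitals]
    # A patient belongs to the group if any of the group's team columns is set
    mask = data_values[team_columns].any(axis=1)
    if thrombolysis_given is not None:
        # Parentheses matter: & binds more tightly than ==
        mask = mask & (data_values['S2Thrombolysis'] == thrombolysis_given)
    grouped = pd.DataFrame({'data': data_values.loc[mask, feature],
                            'shap': shap_values.loc[mask, feature]})
    return grouped.groupby('data')['shap'].mean()

print(mean_shap_by_value('Stroke severity', ['A', 'B'], data_values, shap_values))
```

Because patients are pooled before averaging, hospitals that contribute more patients at a given feature value carry more weight in that value's mean.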
Plot the SHAP values for stroke severity for the two hospital groups.
plot_shap_per_hospital_group('Stroke severity', top_30_hospitals,
bottom_30_hospitals, data_values_df,
shap_values_df)
Plot the SHAP values for stroke severity for the two hospital groups, splitting the patients by whether they received thrombolysis.
thrombolysed = [0, 1]
for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital_group('Stroke severity', top_30_hospitals,
                                 bottom_30_hospitals, data_values_df,
                                 shap_values_df,
                                 thrombolysis_given=thrombolysis_given)
Plot the 30 hospitals individually.
def plot_shap_per_hospital(feature, top_hospital_list,
                           bottom_hospital_list, data_values_df,
                           shap_values_df,
                           thrombolysis_given=-1):
    """
    Plot one line per hospital showing the mean SHAP value of `feature` for
    each feature value: bottom hospitals on the left, top hospitals on the
    right. thrombolysis_given (0 or 1) keeps only patients with that
    outcome; -1 (default) is the flag to not split patients.
    """
    title = ""
    if thrombolysis_given == 0:
        title = "(not receiving thrombolysis) "
    elif thrombolysis_given == 1:
        title = "(receiving thrombolysis) "

    fig = plt.figure(figsize=(10, 5))
    groups = [('bottom', bottom_hospital_list), ('top', top_hospital_list)]
    for subplot, (group_name, hospital_list) in enumerate(groups, start=1):
        ax = fig.add_subplot(1, 2, subplot)
        ax.set_title(f'{feature} SHAP values for patients \n{title}\nattending '
                     f'{group_name} {len(hospital_list)} thrombolysing hospitals')
        for hospital in hospital_list:
            # Select patients attending this hospital
            mask = data_values_df[f"team_{hospital}"] != 0
            if thrombolysis_given != -1:
                # Keep only patients with the requested thrombolysis outcome
                mask = mask & (data_values_df['S2Thrombolysis'] == thrombolysis_given)
            s1 = pd.Series(shap_values_df[feature][mask], name="shap")
            s2 = pd.Series(data_values_df[feature][mask], name="data")
            df = pd.concat([s1, s2], axis=1)
            # Mean SHAP value for each feature value at this hospital
            plot_hospital_data = df.groupby('data').mean()
            ax.plot(plot_hospital_data.index,
                    plot_hospital_data['shap'],
                    label=hospital)
        ax.grid()
        ax.set_xlabel(f'{feature}')
        ax.set_ylabel('SHAP (mean of patients)')
        ax.set_ylim(-3, 2.5)
    plt.show()
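plot_shap_per_hospital builds one masked DataFrame per hospital. A vectorised alternative (not the approach used in this notebook) is to recover each patient's team label from the one-hot columns with idxmax and compute every hospital's means in a single groupby. A minimal sketch on invented data with hypothetical teams A and B; it assumes every patient has exactly one team column set to 1, since idxmax would return the first column for an all-zero row.

```python
import pandas as pd

# Hypothetical toy data: one-hot team columns, a feature, and its SHAP values
data_values = pd.DataFrame({
    'Stroke severity': [2, 10, 2, 10, 10],
    'team_A': [1, 1, 0, 0, 0],
    'team_B': [0, 0, 1, 1, 1],
})
shap_values = pd.DataFrame({'Stroke severity': [-1.0, 1.0, -0.5, 0.5, 1.5]})

team_columns = [c for c in data_values.columns if c.startswith('team_')]
# Recover each patient's team from the one-hot columns (exactly one is 1)
team = data_values[team_columns].idxmax(axis=1).str.replace('team_', '', regex=False)

per_hospital = (pd.DataFrame({'team': team,
                              'data': data_values['Stroke severity'],
                              'shap': shap_values['Stroke severity']})
                .groupby(['team', 'data'])['shap'].mean()
                .unstack('team'))  # one column (line) per hospital
print(per_hospital)
```

Selecting a subset of columns, for example per_hospital[hospital_list], and calling its .plot(ax=ax) would then draw one line per hospital on a given axis.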
plot_shap_per_hospital('Stroke severity', top_30_hospitals,
bottom_30_hospitals, data_values_df,
shap_values_df)
Plot the 30 hospitals individually, splitting the patients by whether they received thrombolysis.
thrombolysed = [0, 1]
for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital('Stroke severity', top_30_hospitals,
                           bottom_30_hospitals, data_values_df,
                           shap_values_df,
                           thrombolysis_given=thrombolysis_given)
Repeat for groups with 9 hospitals
top_9_hospitals = list(thrombolysis_by_hosp.head(9).index)
bottom_9_hospitals = list(thrombolysis_by_hosp.tail(9).index)
Plot the SHAP values for stroke severity for the two hospital groups.
plot_shap_per_hospital_group('Stroke severity', top_9_hospitals,
bottom_9_hospitals, data_values_df,
shap_values_df)
Plot the SHAP values for stroke severity for the two hospital groups, splitting the patients by whether they received thrombolysis.
thrombolysed = [0, 1]
for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital_group('Stroke severity', top_9_hospitals,
                                 bottom_9_hospitals, data_values_df,
                                 shap_values_df,
                                 thrombolysis_given=thrombolysis_given)
Plot the 9 hospitals individually.
plot_shap_per_hospital('Stroke severity', top_9_hospitals,
bottom_9_hospitals, data_values_df,
shap_values_df)
Plot the 9 hospitals individually, splitting the patients by whether they received thrombolysis.
thrombolysed = [0, 1]
for thrombolysis_given in thrombolysed:
    plot_shap_per_hospital('Stroke severity', top_9_hospitals,
                           bottom_9_hospitals, data_values_df,
                           shap_values_df,
                           thrombolysis_given=thrombolysis_given)
Plot for Prior Disability
plot_shap_per_hospital_group('Prior disability level', top_9_hospitals,
bottom_9_hospitals, data_values_df,
shap_values_df)
Plot for Use of AF anticoagulents
plot_shap_per_hospital_group('Use of AF anticoagulents', top_9_hospitals,
bottom_9_hospitals, data_values_df,
shap_values_df)
Plot for Arrival-to-scan time
Limit the graph to patients with an arrival-to-scan time of less than 400 minutes.
mask = data_values_df['Arrival-to-scan time']<400
data_to_plot = data_values_df[mask]
shap_to_plot = shap_values_df[mask]
plot_shap_per_hospital_group('Arrival-to-scan time', top_9_hospitals,
bottom_9_hospitals, data_to_plot,
shap_to_plot)
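Because arrival-to-scan time is recorded in minutes, the groupby('data') step in plot_shap_per_hospital_group produces one mean per distinct minute, which can look noisy. One possible alternative, not used in this notebook, is to bin the feature first with pandas.cut. A minimal sketch with invented values and hypothetical 30-minute bins:

```python
import numpy as np
import pandas as pd

# Hypothetical arrival-to-scan times (minutes) and matching SHAP values
times = pd.Series([5, 12, 25, 40, 55, 380, 450], name='data')
shap_vals = pd.Series([1.2, 1.0, 0.8, 0.1, -0.2, -2.0, -3.0], name='shap')

# Keep patients scanned within 400 minutes, as in the code above
mask = times < 400
# Assign each time to a 30-minute bin: [0, 30), [30, 60), ...
bins = np.arange(0, 400 + 30, 30)
binned = pd.cut(times[mask], bins=bins, right=False)

# Mean SHAP value per bin; observed=True drops bins with no patients
mean_by_bin = shap_vals[mask].groupby(binned, observed=True).mean()
print(mean_by_bin)
```

Wider bins smooth the curve at the cost of hiding detail, so the bin width is a judgement call rather than a fixed choice.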