{ "cells": [ { "cell_type": "markdown", "id": "98972347-26b3-4869-b544-6f50258f8b01", "metadata": {}, "source": [ "# A simple worked example of Shap\n", "\n", "This notebook shows a very simple example of Shap. We examine scores in a pub quiz. Those scores depend on which players are present (Tim, Mark, and Carrie). The pub quiz team can have any number and combination of players, including none of the team turning up!\n", "\n", "Data has been generated according to a known algorithm:\n", "\n", "1) Add 3 marks if Tim is present\n", "2) Add 6 marks if Mark is present\n", "3) Add 9 marks if Carrie is present\n", "4) Add a random 0-20% to the total (giving an integer score)\n", "\n", "We then fit an XGBoost regressor model to predict the score given any number and combination of players, and then fit a Shap explainer to explore the XGBoost model.\n", "\n", "Note: For classification models, Shap values are usually reported as *log odds shifts* in model predictions. For a description of the relationships between probability, odds, and Shap values (log odds shifts) see [here](./odds_prob.md). The model in this notebook is a regressor, so the Shap values here are in the same units as the score (marks)." 
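, "\n", "\n", "As a compact restatement of the scoring rule listed above (a sketch; the symbols $T$, $M$, $C$, $B$ and $r$ are introduced here for illustration and do not appear in the data file), let $T$, $M$ and $C$ be 1 if Tim, Mark and Carrie are present and 0 otherwise:\n", "\n", "$$B = 3T + 6M + 9C, \\qquad \\text{Score} \\approx B\\,(1 + r), \\quad 0 \\le r \\le 0.2$$\n", "\n", "The score is then stored as an integer; exactly how that conversion is done (rounding or truncation) is not stated here, hence the $\\approx$."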
] }, { "cell_type": "code", "execution_count": 1, "id": "a9a8c4ea-785a-47c1-9e8f-45e539bc549e", "metadata": {}, "outputs": [], "source": [ "# Turn warnings off to keep notebook tidy\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "f77f1e1c-5003-47bb-8503-b3a78aa0f34d", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import shap\n", "\n", "from sklearn import metrics\n", "from sklearn.linear_model import LinearRegression\n", "from xgboost import XGBRegressor" ] }, { "cell_type": "markdown", "id": "8dc5d273-2172-4cb7-b4b3-122715cab721", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 3, "id": "1e8ddec0-bc23-4bae-948c-331c8ae06a6c", "metadata": {}, "outputs": [], "source": [ "scores = pd.read_csv('shap_example.csv')" ] }, { "cell_type": "markdown", "id": "f00f9456-11c2-4ba8-ad4a-0a3916300086", "metadata": {}, "source": [ "Show first eight scores (a single set of all combinations)." ] }, { "cell_type": "code", "execution_count": 4, "id": "70e54890-6904-4cbf-91d9-919e19a2a6cf", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n",
"    <tr><th></th><th>Tim</th><th>Mark</th><th>Carrie</th><th>Score</th></tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>0</td><td>0</td><td>0</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>1</td><td>0</td><td>0</td><td>4</td></tr>\n",
"    <tr><th>2</th><td>0</td><td>1</td><td>0</td><td>6</td></tr>\n",
"    <tr><th>3</th><td>0</td><td>0</td><td>1</td><td>11</td></tr>\n",
"    <tr><th>4</th><td>1</td><td>1</td><td>0</td><td>11</td></tr>\n",
"    <tr><th>5</th><td>1</td><td>0</td><td>1</td><td>14</td></tr>\n",
"    <tr><th>6</th><td>0</td><td>1</td><td>1</td><td>15</td></tr>\n",
"    <tr><th>7</th><td>1</td><td>1</td><td>1</td><td>19</td></tr>\n",
"  </tbody>\n", "</table>
" ], "text/plain": [ " Tim Mark Carrie Score\n", "0 0 0 0 0\n", "1 1 0 0 4\n", "2 0 1 0 6\n", "3 0 0 1 11\n", "4 1 1 0 11\n", "5 1 0 1 14\n", "6 0 1 1 15\n", "7 1 1 1 19" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores.head(8)" ] }, { "cell_type": "markdown", "id": "df10804c-9f54-4ec2-943e-449c1b8b3bdc", "metadata": {}, "source": [ "## Calculate average of all games\n", "\n", "Shap values show change from the global average of all scores." ] }, { "cell_type": "code", "execution_count": 5, "id": "1b96ea2d-9c5f-46d5-87d0-fcf2b5125258", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Global average: 9.8\n" ] } ], "source": [ "global_average = scores['Score'].mean()\n", "print(f'Global average: {global_average:0.1f}')" ] }, { "cell_type": "markdown", "id": "dab156ce-a88a-4ff5-bd74-a4eaf68925b4", "metadata": {}, "source": [ "## Show averages by whether player is present or not\n", "\n", "Show averages by whether player is present or not, and show difference from average score." 
] }, { "cell_type": "code", "execution_count": 6, "id": "8239745d-5480-425a-a8da-889c91d4c934", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average scores wrt Tim\n", "Tim\n", "0 8.1\n", "1 11.6\n", "Name: Score, dtype: float64\n", "\n", "Difference from average:\n", "Tim\n", "0 -1.8\n", "1 1.8\n", "Name: Score, dtype: float64\n", "\n", "\n", "Average scores wrt Mark\n", "Mark\n", "0 6.8\n", "1 12.8\n", "Name: Score, dtype: float64\n", "\n", "Difference from average:\n", "Mark\n", "0 -3.0\n", "1 3.0\n", "Name: Score, dtype: float64\n", "\n", "\n", "Average scores wrt Carrie\n", "Carrie\n", "0 5.2\n", "1 14.4\n", "Name: Score, dtype: float64\n", "\n", "Difference from average:\n", "Carrie\n", "0 -4.6\n", "1 4.6\n", "Name: Score, dtype: float64\n", "\n", "\n" ] } ], "source": [ "players = ['Tim', 'Mark', 'Carrie']\n", "for player in players:\n", " print(f'Average scores wrt {player}')\n", " average_scores = scores.groupby(player).mean()['Score']\n", " print(average_scores.round(1))\n", " print('\\nDifference from average:')\n", " difference = average_scores - global_average\n", " print(difference.round(1))\n", " print('\\n')" ] }, { "cell_type": "markdown", "id": "954c3646-621c-4e03-aa53-47043bb63854", "metadata": {}, "source": [ "## Split into X and y and fit XGBoost regressor model" ] }, { "cell_type": "code", "execution_count": 7, "id": "b15c99c3-e59a-4773-8bbf-651151fd718e", "metadata": {}, "outputs": [], "source": [ "X = scores.drop('Score', axis=1)\n", "y = scores['Score']" ] }, { "cell_type": "code", "execution_count": 8, "id": "e6c765d0-8746-4b40-8240-0f077340a63f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,\n", " colsample_bynode=1, colsample_bytree=1, enable_categorical=False,\n", " gamma=0, gpu_id=-1, importance_type=None,\n", " interaction_constraints='', learning_rate=0.300000012,\n", " max_delta_step=0, max_depth=6, 
 min_child_weight=1, missing=nan,\n", " monotone_constraints='()', n_estimators=100, n_jobs=36,\n", " num_parallel_tree=1, predictor='auto', random_state=0, reg_alpha=0,\n", " reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method='exact',\n", " validate_parameters=1, verbosity=None)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define model\n", "model = XGBRegressor()\n", "\n", "# Fit model\n", "model.fit(X, y)" ] }, { "cell_type": "markdown", "id": "736264e1-0269-4726-b7f1-54f079955091", "metadata": {}, "source": [ "## Get predictions and plot observed vs. predicted\n", "\n", "(Observed scores include a random component, so no model can be exact)." ] }, { "cell_type": "code", "execution_count": 9, "id": "caf5fc49-1627-4e1e-9e22-97b7ac2b5846", "metadata": {}, "outputs": [], "source": [ "y_pred = model.predict(X)" ] }, { "cell_type": "code", "execution_count": 10, "id": "8ea34f76-14de-4016-890e-2465f4238cf2", "metadata": {}, "outputs": [], "source": [ "# Regress predicted scores on observed scores to get a best-fit line and R squared\n", "y_array = np.array(y).reshape(-1,1)\n", "y_pred_array = np.array(y_pred).reshape(-1,1)\n", "slr = LinearRegression()\n", "slr.fit(y_array, y_pred_array)\n", "y_pred_best_fit = slr.predict(y_array)\n", "# R squared compares the fitted line with the values it was fitted to (y_pred)\n", "r_square = metrics.r2_score(y_pred_array, y_pred_best_fit)" ] }, { "cell_type": "code", "execution_count": 11, "id": "cedcc020-7055-4f6e-a5a6-b9196650f663", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = plt.figure(figsize=(5,5))\n", "ax = fig.add_subplot()\n", "ax.scatter(y, y_pred)\n", "ax.set_xlabel('Observed')\n", "ax.set_ylabel('Predicted')\n", "\n", "ax.plot (y, slr.predict(y_array), color = 'red')\n", "text = f'R squared: {r_square:.3f}'\n", "ax.text(16, 12, text, \n", " bbox=dict(facecolor='white', edgecolor='black'))\n", "ax.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "596684d4-684e-4c20-8692-8663e4b23055", "metadata": {}, "source": [ "## Train Shap model" ] }, { "cell_type": "code", "execution_count": 12, "id": "f0c74945-58db-45bb-a5e1-bb1852d28216", "metadata": {}, "outputs": [], "source": [ "# Train explainer on Training set\n", "explainer = shap.TreeExplainer(model, X)\n", "\n", "# Get Shapley values along with base and features\n", "shap_values_extended = explainer(X)\n", "shap_values = shap_values_extended.values" ] }, { "cell_type": "markdown", "id": "7bc14c53-00e5-4b27-bbb1-1743bece2ef6", "metadata": {}, "source": [ "## Show beeswarm of Shap\n", "\n", "The Beeswarm plot shows the Shap values for each instance predicted. Each player has a Shap value for their presence or absence in a team, which shows the effect of their presence/absence compared with the global average of all scores." ] }, { "cell_type": "code", "execution_count": 13, "id": "01e5b22e-90cb-40a0-b1e5-9e4212ea1876", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "features = list(X)\n", "\n", "shap.summary_plot(shap_values=shap_values, \n", " features=X,\n", " feature_names=features,\n", " show=False)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "4844f484-0366-4926-b598-ffbee272f3d4", "metadata": {}, "source": [ "## Calculate Shap for each player when present or not\n", "\n", "Here we calculate the average Shap values for the absence and presence of a player in the team." ] }, { "cell_type": "code", "execution_count": 14, "id": "0d828ac3-21ff-45f1-b9b6-8d1ca8dde7a3", "metadata": {}, "outputs": [], "source": [ "shap_summary_by_player = dict()\n", "for player in list(X):\n", " player_shap_values = shap_values_extended[:, player]\n", " df = pd.DataFrame()\n", " df['player_present'] = player_shap_values.data\n", " df['shap'] = player_shap_values.values\n", " shap_summary = df.groupby('player_present').mean()\n", " shap_summary_by_player[player] = shap_summary " ] }, { "cell_type": "code", "execution_count": 15, "id": "ac3601fa-5a88-44b5-8782-6dfd5efdfd27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shap for Tim:\n", " shap\n", "player_present \n", "0 -1.7\n", "1 1.7\n", "\n", "Shap for Mark:\n", " shap\n", "player_present \n", "0 -3.0\n", "1 3.0\n", "\n", "Shap for Carrie:\n", " shap\n", "player_present \n", "0 -4.6\n", "1 4.6\n", "\n" ] } ], "source": [ "for player in list(X):\n", " print (f'Shap for {player}:')\n", " print(shap_summary_by_player[player].round(1))\n", " print()" ] }, { "cell_type": "markdown", "id": "3d652de8-683e-490e-bba7-36315c4b6721", "metadata": {}, "source": [ "## Show waterfall plots for lowest and highest scores\n", "\n", "Waterfall plots show the influence of features on the predicted outcome starting from a baseline model prediction.\n", "\n", "Here we show the lowest and highest score." 
] }, { "cell_type": "code", "execution_count": 16, "id": "599e379e-01d6-4864-b189-4d822921642e", "metadata": {}, "outputs": [], "source": [ "# Get the location of the first example with the lowest predicted score\n", "# and the first example with the highest predicted score\n", "\n", "location_low_score = np.where(y_pred == np.min(y_pred))[0][0]\n", "location_high_score = np.where(y_pred == np.max(y_pred))[0][0]" ] }, { "cell_type": "code", "execution_count": 17, "id": "72f27a73-1625-455c-a2a6-b109199adb19", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = shap.plots.waterfall(shap_values_extended[location_low_score])" ] }, { "cell_type": "code", "execution_count": 18, "id": "1c2661cf-8f87-4f7f-bbba-cf8b694180b4", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = shap.plots.waterfall(shap_values_extended[location_high_score])" ] }, { "cell_type": "markdown", "id": "91728d31-58c3-425f-a024-5898ecd6eaee", "metadata": {}, "source": [ "Pick a random example." ] }, { "cell_type": "code", "execution_count": 19, "id": "51ac3a0b-61fe-4103-9568-b168574d9c4c", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "random_location = np.random.randint(0, len(scores))\n", "fig = shap.plots.waterfall(shap_values_extended[random_location])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }