{ "cells": [ { "cell_type": "markdown", "id": "98972347-26b3-4869-b544-6f50258f8b01", "metadata": {}, "source": [ "# A simple worked example of Shap\n", "\n", "This notebook shows a very simple example of Shap. We examine scores in a pub quiz. Those scores depend on the players present (Tim, Mark, and Carrie). The pub quiz team can have any number and combination of players - including none of the team turning up!\n", "\n", "Data has been generated according to a known algorithm:\n", "\n", "1) Add 3 marks if Tim is present\n", "2) Add 6 marks if Mark is present\n", "3) Add 9 marks if Carrie is present\n", "4) Add a further 0-20% of that total (the final score is recorded as an integer)\n", "\n", "We then fit an XGBoost regressor model to predict the score given any number and combination of players, and then fit a Shap explainer to explore the XGBoost model.\n", "\n", "Note: For classification models, Shap values are usually reported as *log odds shifts* in model predictions. For a description of the relationships between probability, odds, and Shap values (log odds shifts) see [here](./odds_prob.md). In this regression example, the Shap values are in the same units as the predicted score (marks)." 
] }, { "cell_type": "code", "execution_count": 1, "id": "a9a8c4ea-785a-47c1-9e8f-45e539bc549e", "metadata": {}, "outputs": [], "source": [ "# Turn warnings off to keep notebook tidy\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "f77f1e1c-5003-47bb-8503-b3a78aa0f34d", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import shap\n", "\n", "from sklearn import metrics\n", "from sklearn.linear_model import LinearRegression\n", "from xgboost import XGBRegressor" ] }, { "cell_type": "markdown", "id": "8dc5d273-2172-4cb7-b4b3-122715cab721", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 3, "id": "1e8ddec0-bc23-4bae-948c-331c8ae06a6c", "metadata": {}, "outputs": [], "source": [ "scores = pd.read_csv('shap_example.csv')" ] }, { "cell_type": "markdown", "id": "f00f9456-11c2-4ba8-ad4a-0a3916300086", "metadata": {}, "source": [ "Show first eight scores (a single set of all combinations)." ] }, { "cell_type": "code", "execution_count": 4, "id": "70e54890-6904-4cbf-91d9-919e19a2a6cf", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | Tim | Mark | Carrie | Score |
|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 |
| 1 | 1 | 0 | 0 | 4 |
| 2 | 0 | 1 | 0 | 6 |
| 3 | 0 | 0 | 1 | 11 |
| 4 | 1 | 1 | 0 | 11 |
| 5 | 1 | 0 | 1 | 14 |
| 6 | 0 | 1 | 1 | 15 |
| 7 | 1 | 1 | 1 | 19 |