{ "cells": [ { "cell_type": "markdown", "id": "1b43c8f1", "metadata": {}, "source": [ "# Find similar patients who are treated differently within the same hospital" ] }, { "cell_type": "markdown", "id": "97e262a5", "metadata": {}, "source": [ "## Aims\n", "\n", "- Investigate the number of misclassifications within a hospital\n", "- Define a similarity measure using the decision tree structure of a random forest classifier\n", "- For patients whose predicted outcome differs from their true outcome, find similar patients who were treated differently" ] }, { "cell_type": "markdown", "id": "324b703c", "metadata": {}, "source": [ "## Code " ] }, { "cell_type": "markdown", "id": "27d42f11", "metadata": {}, "source": [ "### Import libraries " ] }, { "cell_type": "code", "execution_count": 1, "id": "3ef6869b", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "import pickle as pkl\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as ticker\n", "%matplotlib inline\n", "\n", "import sklearn as sk\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.tree import DecisionTreeClassifier\n", "\n", "# Turn warnings off to keep notebook tidy\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "fce0441a", "metadata": {}, "source": [ "### Load pre-trained hospital models into dictionary `hospital2model` \n", "\n", "- Keys: hospital identifiers (matched against the cohort's `StrokeTeam` column)\n", "- Values: `trained_classifier`, `threshold`, `patients`, `outcomes`\n", "\n", "Note: `patients` is a NumPy array. 
" ] }, { "cell_type": "code", "execution_count": 2, "id": "f6e16911", "metadata": {}, "outputs": [], "source": [ "with open ('./models/trained_hospital_models_for _cohort.pkl', 'rb') as f:\n", " \n", " hospital2model = pkl.load(f)" ] }, { "cell_type": "markdown", "id": "5d976937", "metadata": {}, "source": [ "### Select a hospital for investigation " ] }, { "cell_type": "code", "execution_count": 3, "id": "4f4bb860", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'TPFFP4410O'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hospital = list(hospital2model.keys())[1]\n", "hospital" ] }, { "cell_type": "markdown", "id": "0d8bc208", "metadata": {}, "source": [ "### Load test cohort and extract hospital patients " ] }, { "cell_type": "code", "execution_count": 4, "id": "e13c8f74", "metadata": {}, "outputs": [], "source": [ "cohort = pd.read_csv('../data/10k_training_test/cohort_10000_test.csv')\n", "\n", "test_patients_df = cohort.loc[cohort['StrokeTeam']==hospital]" ] }, { "cell_type": "code", "execution_count": 5, "id": "76cced3b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | StrokeTeam | \n", "S1AgeOnArrival | \n", "S1OnsetToArrival_min | \n", "S2RankinBeforeStroke | \n", "Loc | \n", "LocQuestions | \n", "LocCommands | \n", "BestGaze | \n", "Visual | \n", "FacialPalsy | \n", "... | \n", "S2NewAFDiagnosis_Yes | \n", "S2NewAFDiagnosis_missing | \n", "S2StrokeType_Infarction | \n", "S2StrokeType_Primary Intracerebral Haemorrhage | \n", "S2StrokeType_missing | \n", "S2TIAInLastMonth_No | \n", "S2TIAInLastMonth_No but | \n", "S2TIAInLastMonth_Yes | \n", "S2TIAInLastMonth_missing | \n", "S2Thrombolysis | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
105 | \n", "TPFFP4410O | \n", "72.5 | \n", "68.0 | \n", "2 | \n", "2 | \n", "2.0 | \n", "2.0 | \n", "0.0 | \n", "2.0 | \n", "2.0 | \n", "... | \n", "0 | \n", "1 | \n", "0 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "0 | \n", "
720 | \n", "TPFFP4410O | \n", "87.5 | \n", "108.0 | \n", "0 | \n", "3 | \n", "2.0 | \n", "2.0 | \n", "2.0 | \n", "3.0 | \n", "3.0 | \n", "... | \n", "0 | \n", "1 | \n", "0 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "0 | \n", "
759 | \n", "TPFFP4410O | \n", "82.5 | \n", "131.0 | \n", "4 | \n", "1 | \n", "2.0 | \n", "2.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "... | \n", "0 | \n", "0 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "0 | \n", "
779 | \n", "TPFFP4410O | \n", "57.5 | \n", "145.0 | \n", "0 | \n", "0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "2.0 | \n", "... | \n", "0 | \n", "1 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "1 | \n", "
898 | \n", "TPFFP4410O | \n", "67.5 | \n", "162.0 | \n", "1 | \n", "0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "2.0 | \n", "3.0 | \n", "... | \n", "0 | \n", "0 | \n", "1 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "0 | \n", "1 | \n", "1 | \n", "
5 rows × 101 columns
\n", "\n", " | Patient | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "
---|---|---|---|---|---|---|
S2BrainImagingTime_min | \n", "23.0 | \n", "19.0 | \n", "11.0 | \n", "33.0 | \n", "30.0 | \n", "12.0 | \n", "
S2NihssArrival | \n", "4.0 | \n", "5.0 | \n", "6.0 | \n", "9.0 | \n", "6.0 | \n", "8.0 | \n", "
S1OnsetToArrival_min | \n", "83.0 | \n", "149.0 | \n", "134.0 | \n", "121.0 | \n", "79.0 | \n", "82.0 | \n", "
S2StrokeType_Primary Intracerebral Haemorrhage | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S2RankinBeforeStroke | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
AFAnticoagulentHeparin_Yes | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetInHospital_Yes | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetInHospital_No | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "
S1Ethnicity_Mixed | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetTimeType_Not known | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
99 rows × 6 columns
\n", "\n", " | Patient | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "
---|---|---|---|---|---|---|
S2BrainImagingTime_min | \n", "40.0 | \n", "88.0 | \n", "40.0 | \n", "27.0 | \n", "51.0 | \n", "35.0 | \n", "
S2NihssArrival | \n", "9.0 | \n", "11.0 | \n", "10.0 | \n", "6.0 | \n", "6.0 | \n", "6.0 | \n", "
S1OnsetToArrival_min | \n", "105.0 | \n", "67.0 | \n", "122.0 | \n", "83.0 | \n", "75.0 | \n", "160.0 | \n", "
S2StrokeType_Primary Intracerebral Haemorrhage | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S2RankinBeforeStroke | \n", "0.0 | \n", "0.0 | \n", "4.0 | \n", "0.0 | \n", "3.0 | \n", "0.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
AFAnticoagulentHeparin_Yes | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetInHospital_Yes | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetInHospital_No | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "
S1Ethnicity_Mixed | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
S1OnsetTimeType_Not known | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "
99 rows × 6 columns
\n", "