Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 307 additions & 0 deletions meta_data/short_term_outcomes/identification_of_true_positives.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "01c8fdbb",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import pickle\n",
"import torch as ch\n",
"from preprocessing.geneva_stroke_unit_preprocessing.utils import create_ehr_case_identification_column"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ca15e7e",
"metadata": {},
"outputs": [],
"source": [
"predictions_path = '/Users/jk1/temp/opsum_end/testing/with_imaging/xgb_test_results/test_predictions.pkl'\n",
"test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n",
"eds_data = '/Users/jk1/stroke_datasets/Extraction_20220815/eds_j1.csv'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00eb5aab",
"metadata": {},
"outputs": [],
"source": [
"n_timesteps = 72\n",
"threshold = 0.239"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65e49be4",
"metadata": {},
"outputs": [],
"source": [
"with open(predictions_path, 'rb') as f:\n",
" predictions = pickle.load(f)\n",
"y_test, y_prob = predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b6f882e",
"metadata": {},
"outputs": [],
"source": [
"X_test_raw, y_test_raw = ch.load(test_data_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4caa0ea6",
"metadata": {},
"outputs": [],
"source": [
"eds_df = pd.read_csv(eds_data, delimiter=';', encoding='utf-8',\n",
" dtype=str)\n",
"eds_df['case_admission_id'] = create_ehr_case_identification_column(eds_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8795f6b3",
"metadata": {},
"outputs": [],
"source": [
"cids = X_test_raw[:,0,0,0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98b5e94d",
"metadata": {},
"outputs": [],
"source": [
"y_prob_matrix = y_prob.reshape(-1, n_timesteps)\n",
"y_test_matrix = y_test.reshape(-1, n_timesteps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a8046a7",
"metadata": {},
"outputs": [],
"source": [
"y_pred_matrix = (y_prob_matrix >= threshold).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa7565f6",
"metadata": {},
"outputs": [],
"source": [
"true_positives = ((y_test_matrix == 1) & (y_pred_matrix == 1))\n",
"false_positives = ((y_test_matrix == 0) & (y_pred_matrix == 1))\n",
"true_negatives = ((y_test_matrix == 0) & (y_pred_matrix == 0))\n",
"false_negatives = ((y_test_matrix == 1) & (y_pred_matrix == 0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cfe3eb02",
"metadata": {},
"outputs": [],
"source": [
"n_true_positives_per_patient = true_positives.sum(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44a18041",
"metadata": {},
"outputs": [],
"source": [
"true_positives_per_patient_df = pd.DataFrame({\n",
" 'case_admission_id': cids,\n",
" 'n_true_positives': n_true_positives_per_patient\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca5ea996",
"metadata": {},
"outputs": [],
"source": [
"cids_with_true_positives = true_positives_per_patient_df[true_positives_per_patient_df['n_true_positives'] > 0]['case_admission_id'].tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43b58644",
"metadata": {},
"outputs": [],
"source": [
"len(cids_with_true_positives)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f970269",
"metadata": {},
"outputs": [],
"source": [
"# ensure that cids_with_true_positives is a subset of the cids in the y_test_raw set\n",
"set(cids_with_true_positives).issubset(set(y_test_raw.case_admission_id))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f30bdcbc",
"metadata": {},
"outputs": [],
"source": [
"true_positive_df = y_test_raw[y_test_raw['case_admission_id'].isin(cids_with_true_positives)].copy()\n",
"true_positive_df = true_positive_df[['case_admission_id', 'patient_id', 'sample_date', 'relative_sample_date', 'value', 'min_nihss', 'delta_to_min',]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56e64d7c",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "89b475d8",
"metadata": {},
"outputs": [],
"source": [
"true_positive_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7c05817",
"metadata": {},
"outputs": [],
"source": [
"eds_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "843ee230",
"metadata": {},
"outputs": [],
"source": [
"true_positive_df = true_positive_df.merge(eds_df[['case_admission_id', 'DOB', 'eds_final_id']], on='case_admission_id', how='left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "420c05ed",
"metadata": {},
"outputs": [],
"source": [
"true_positive_df.columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce55920b",
"metadata": {},
"outputs": [],
"source": [
"true_positive_df = true_positive_df[['case_admission_id', 'eds_final_id', 'patient_id', 'DOB', 'sample_date',\n",
" 'relative_sample_date', 'value', 'min_nihss', 'delta_to_min',]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "530b4216",
"metadata": {},
"outputs": [],
"source": [
"# rename eds_final_id to EDS\n",
"true_positive_df.rename(columns={'eds_final_id': 'EDS'}, inplace=True)\n",
"# rename sample_date to END_date\n",
"true_positive_df.rename(columns={'sample_date': 'END_date'}, inplace=True)\n",
"# rename value to END_value\n",
"true_positive_df.rename(columns={'value': 'END_value'}, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e582869",
"metadata": {},
"outputs": [],
"source": [
"# sort by case_admission_id \n",
"true_positive_df.sort_values(by='case_admission_id', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21d9a452",
"metadata": {},
"outputs": [],
"source": [
"# true_positive_df.to_csv('/Users/jk1/Downloads/end_true_positives_test_set_25022026.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "542b99cd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "opsum",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pandas as pd
import pickle
import numpy as np
import torch as ch
import os
import seaborn as sns
import matplotlib.pyplot as plt
from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette
from colormath.color_objects import LabColor

shap_values_path = '/Users/jk1/temp/opsum_end/testing/with_imaging/xgb_test_results/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'
test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/logs_30012026_154047/categorical_variable_encoding.csv'

# load the shap values
with open(os.path.join(shap_values_path), 'rb') as handle:
original_shap_values = pickle.load(handle)

shap_values = [np.array([original_shap_values[i] for i in range(len(original_shap_values))]).swapaxes(0, 1)][0]

X_test, y_test= ch.load(test_data_path)

features = X_test[0, 0, :, 2]

# Toggle these to match the model that produced the SHAP values
add_lag_features = True
add_rolling_features = True

# Build aggregated feature names matching aggregate_features_over_time output order:
# [features, avg_features, min_features, max_features, std_features, diff_features, timestep_feature] [lag2, lag3] [roll_mean, roll_std, roll_trend]
# features, avg_, min_, max_, std_, diff_, timestep_idx, [lag2_, lag3_], [rolling_mean_, rolling_std_, rolling_trend_]
aggregated_feature_names = list(features)
for prefix in ['avg_', 'min_', 'max_', 'std_', 'diff_']:
aggregated_feature_names += [f'{prefix}{f}' for f in features]
aggregated_feature_names += ['timestep_idx']

if add_lag_features:
for prefix in ['lag2_', 'lag3_']:
aggregated_feature_names += [f'{prefix}{f}' for f in features]

if add_rolling_features:
for prefix in ['rolling_mean_', 'rolling_std_', 'rolling_trend_']:
aggregated_feature_names += [f'{prefix}{f}' for f in features]

aggregated_feature_names += ['base_value']
print(f'{len(aggregated_feature_names)} feature names (including base_value), SHAP columns: {shap_values.shape[2]}')

sum_over_all_shap_values = np.abs(shap_values).sum(axis=(0,1))


temp_df = pd.DataFrame({'feature': aggregated_feature_names, 'shap_value': sum_over_all_shap_values})
# remove timestep_idx and base_value from the features
temp_df = temp_df[~temp_df.feature.isin(['timestep_idx', 'base_value'])]
# remove avg_, min_, max_, std_, diff_, timestep_idx, [lag2_, lag3_], [rolling_mean_, rolling_std_, rolling_trend_] from the feature names to get the original feature names
prefixes = ['rolling_mean_', 'rolling_std_', 'rolling_trend_', 'avg_', 'min_', 'max_', 'std_', 'diff_', 'lag2_', 'lag3_',]
for prefix in prefixes:
temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, ''))
hourly_pool_prefixes = ['median_', 'min_', 'max_']
for prefix in hourly_pool_prefixes:
temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, ''))
blood_pressure_prefixes = ['systolic_', 'diastolic_', 'mean_']
for prefix in blood_pressure_prefixes:
temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, ''))

# transform to absolute shap values
temp_df['absolute_shap_value'] = np.abs(temp_df['shap_value'])
# drop shap value
temp_df = temp_df.drop(columns=['shap_value'])
# sum the shap values for the same original feature names
temp_df = temp_df.groupby('feature').sum().reset_index()
temp_df.sort_values(by='absolute_shap_value', ascending=False).head(10)
top_10_features_by_mean_abs_summed_shap = temp_df.sort_values(by='absolute_shap_value', ascending=False).head(10).feature.values

print(f'Top 10 features by mean absolute summed SHAP values: {top_10_features_by_mean_abs_summed_shap}')
Loading