From 4067d14f230603bb0eac0c01bde61508c756434e Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Wed, 18 Feb 2026 16:22:09 +0100 Subject: [PATCH 1/5] updated MIMIC preprocessing to be compatible with short term outcomes --- .../patient_selection/patient_selection.ipynb | 8 +- .../admission_preprocessing.py | 18 ++- .../database_assembly/database_assembly.py | 12 +- .../further_exclusion_criteria.py | 11 +- .../lab_preprocessing/lab_preprocessing.py | 3 +- .../selected_lab_values.xlsx | Bin 10307 -> 10326 bytes .../mimic_nihss_exploration.ipynb | 15 +- .../monitoring_preprocessing.py | 5 +- .../outcome_extraction.py | 0 .../outcome_preprocessing.ipynb | 0 .../outcome_preprocessing.py | 0 .../early_neurological_deterioration.py | 133 ++++++++++++++++++ .../short_term_outcomes_preprocessing.py | 42 ++++++ .../preprocessing_pipeline.py | 116 +++++++++++---- .../impute_missing_values.py | 26 ++-- 15 files changed, 323 insertions(+), 66 deletions(-) rename preprocessing/mimic_preprocessing/outcome_preprocessing/{ => long_term_outcomes}/outcome_extraction.py (100%) rename preprocessing/mimic_preprocessing/outcome_preprocessing/{ => long_term_outcomes}/outcome_preprocessing.ipynb (100%) rename preprocessing/mimic_preprocessing/outcome_preprocessing/{ => long_term_outcomes}/outcome_preprocessing.py (100%) create mode 100644 preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/early_neurological_deterioration.py create mode 100644 preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/short_term_outcomes_preprocessing.py diff --git a/preprocessing/geneva_stroke_unit_preprocessing/patient_selection/patient_selection.ipynb b/preprocessing/geneva_stroke_unit_preprocessing/patient_selection/patient_selection.ipynb index 6a6e3d3..078fd91 100644 --- a/preprocessing/geneva_stroke_unit_preprocessing/patient_selection/patient_selection.ipynb +++ b/preprocessing/geneva_stroke_unit_preprocessing/patient_selection/patient_selection.ipynb @@ -41,7 +41,7 @@ }, "outputs": [], "source": [ - "stroke_registry_data_path = '/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx'" + "stroke_registry_data_path = '/Users/jk1/stroke_datasets/stroke_registry_post_hoc_modified.xlsx'" ] }, { @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "manual_eds_completion_folder = '/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/manuel_eds_completion'" + "manual_eds_completion_folder = '/Users/jk1/stroke_datasets/manuel_eds_completion'" ] }, { @@ -70,7 +70,7 @@ "outputs": [], "source": [ "# general consent is present for the extraction of 20221117\n", - "general_consent_eds_path = '/Users/jk1/stroke_datasets/stroke_unit_dataset/per_value/Extraction_20221117/eds_j1.csv'" + "general_consent_eds_path = '/Users/jk1/stroke_datasets/Extraction20221117/eds_j1.csv'" ] }, { @@ -84,7 +84,7 @@ }, "outputs": [], "source": [ - "output_path = '/Users/jk1/temp/opsum_end'" + "output_path = '/Users/jk1/Downloads'" ] }, { diff --git a/preprocessing/mimic_preprocessing/admission_preprocessing/admission_preprocessing.py b/preprocessing/mimic_preprocessing/admission_preprocessing/admission_preprocessing.py index 4dd1faa..a474bb1 100644 --- a/preprocessing/mimic_preprocessing/admission_preprocessing/admission_preprocessing.py +++ b/preprocessing/mimic_preprocessing/admission_preprocessing/admission_preprocessing.py @@ -34,14 +34,24 @@ def preprocess_admission(admission_notes_data_path:str, admission_table_path:str """ - possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), + # possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), + possible_value_ranges_file = os.path.join(os.path.abspath(''), 'preprocessing', 'geneva_stroke_unit_preprocessing/possible_ranges_for_variables.xlsx') possible_value_ranges = pd.read_excel(possible_value_ranges_file) + # load data + admission_data_df = pd.read_excel(admission_notes_data_path) + admission_table_df = pd.read_csv(admission_table_path) # Preprocessing admission table data - admission_table_df = pd.read_csv(admission_table_path) - admission_table_df = admission_table_df[['subject_id', 'hadm_id', 'icustay_id', 'dob', 'admittime', 'age', 'gender', 'admission_location']] + # if admisttime not a column in admission_table_df, get it from admission_notes data + if 'admittime' not in admission_table_df.columns: + if 'admittime' not in admission_data_df.columns: + raise ValueError('admittime not found in either admission_table_df or admission_data_df') + else: + admission_table_df = admission_table_df.merge(admission_data_df[['hadm_id', 'icustay_id', 'admittime']], on=['hadm_id', 'icustay_id'], how='left') + + admission_table_df = admission_table_df[['subject_id', 'hadm_id', 'icustay_id', 'admittime', 'age', 'gender', 'admission_location']] admission_table_df.drop_duplicates(inplace=True) if verbose: @@ -76,8 +86,6 @@ def preprocess_admission(admission_notes_data_path:str, admission_table_path:str var_name='sample_label') # Preprocessing admission notes data - admission_data_df = pd.read_excel(admission_notes_data_path) - # restrict to patients admitted to ICU with stroke as primary reason and with onset to admission < 7 d admission_data_df = admission_data_df[admission_data_df['admitted to ICU for stroke'] == 'y'] admission_data_df = admission_data_df[admission_data_df['onset to ICU admission > 7d'] == 'n'] diff --git a/preprocessing/mimic_preprocessing/database_assembly/database_assembly.py b/preprocessing/mimic_preprocessing/database_assembly/database_assembly.py index ac3611d..41418b3 100644 --- a/preprocessing/mimic_preprocessing/database_assembly/database_assembly.py +++ b/preprocessing/mimic_preprocessing/database_assembly/database_assembly.py @@ -36,7 +36,7 @@ def assemble_variable_database(extracted_tables_path: str, admission_notes_data_ admission_data_df = admission_data_df[target_columns] # Apply further exclusion criteria - hadm_ids_to_exclude = apply_further_exclusion_criteria(admission_data_df['case_admission_id'].unique(), admission_table_path, log_dir=log_dir) + hadm_ids_to_exclude = apply_further_exclusion_criteria(admission_data_df['case_admission_id'].unique(), admission_table_path, admission_notes_data_path, log_dir=log_dir) admission_data_df = admission_data_df[~admission_data_df['case_admission_id'].isin(hadm_ids_to_exclude)] # Deduce patient selection from admission data @@ -55,12 +55,20 @@ def assemble_variable_database(extracted_tables_path: str, admission_notes_data_ lab_data_df = lab_data_df[target_columns] # Preprocess monitoring data - if preproccessed_monitoring_data_path != '': + if preproccessed_monitoring_data_path != '' and not preproccessed_monitoring_data_path is None: monitoring_data_df = pd.read_csv(preproccessed_monitoring_data_path) else: if mimic_admission_nihss_db_path == '': raise ValueError('Please provide a path to the MIMIC admission nihss database.') monitoring_df = pd.read_csv(os.path.join(extracted_tables_path, 'monitoring_df.csv')) + + if 'admittime' not in monitoring_df.columns: + admission_notes_df = pd.read_excel(admission_notes_data_path) + if 'admittime' not in admission_notes_df.columns: + raise ValueError('admittime not found in either admission_table_df or admission_data_df') + else: + monitoring_df = monitoring_df.merge(admission_notes_df[['hadm_id', 'icustay_id', 'admittime']], on=['hadm_id', 'icustay_id'], how='left') + monitoring_data_df = preprocess_monitoring(monitoring_df, mimic_admission_nihss_db_path, verbose) monitoring_data_df['case_admission_id'] = monitoring_data_df['hadm_id'].astype(int).astype(str) + '_' + monitoring_data_df['icustay_id'].astype(int).astype(str) diff --git a/preprocessing/mimic_preprocessing/database_assembly/further_exclusion_criteria.py b/preprocessing/mimic_preprocessing/database_assembly/further_exclusion_criteria.py index b2019f7..8d2d30d 100644 --- a/preprocessing/mimic_preprocessing/database_assembly/further_exclusion_criteria.py +++ b/preprocessing/mimic_preprocessing/database_assembly/further_exclusion_criteria.py @@ -1,7 +1,7 @@ import pandas as pd import os -def apply_further_exclusion_criteria(patient_selection: list, admission_table_path:str, log_dir:str) -> set: +def apply_further_exclusion_criteria(patient_selection: list, admission_table_path:str, admission_notes_path:str, log_dir:str) -> set: """ Applies further exclusion criteria: - Exclude patients with time of death during surveillance period @@ -16,6 +16,9 @@ def apply_further_exclusion_criteria(patient_selection: list, admission_table_pa admission_table_path : str path to admission table + admission_notes_path : str + path to admission notes + log_dir : str path to log directory @@ -27,6 +30,12 @@ def apply_further_exclusion_criteria(patient_selection: list, admission_table_pa # Exclude patients with time of death during surveillance period # (i.e. death in the ICU within 72 hours of admission) admission_table = pd.read_csv(admission_table_path) + + # if admisttime not a column in admission_table_df, get it from admission_notes data + if 'admittime' not in admission_table.columns: + admission_data_df = pd.read_excel(admission_notes_path) + admission_table = admission_table.merge(admission_data_df[['hadm_id', 'icustay_id', 'admittime']], on=['hadm_id', 'icustay_id'], how='left') + admission_table['case_admission_id'] = admission_table['hadm_id'].astype(str) + '_' + admission_table[ 'icustay_id'].astype(str) admission_table = admission_table[admission_table['case_admission_id'].isin(patient_selection)] diff --git a/preprocessing/mimic_preprocessing/lab_preprocessing/lab_preprocessing.py b/preprocessing/mimic_preprocessing/lab_preprocessing/lab_preprocessing.py index 9a400f0..a7a4ffb 100644 --- a/preprocessing/mimic_preprocessing/lab_preprocessing/lab_preprocessing.py +++ b/preprocessing/mimic_preprocessing/lab_preprocessing/lab_preprocessing.py @@ -132,7 +132,8 @@ def preprocess_labs(lab_df: pd.DataFrame, log_dir:str = '', verbose: bool = True ## RESTRICT TO PLAUSSIBLE RANGE plausible_restricted_lab_df = restricted_lab_df.copy() - possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), + # possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), + possible_value_ranges_file = os.path.join(os.path.abspath(''), 'preprocessing', 'geneva_stroke_unit_preprocessing/possible_ranges_for_variables.xlsx') possible_value_ranges = pd.read_excel(possible_value_ranges_file) plausible_restricted_lab_df['out_of_range'] = False diff --git a/preprocessing/mimic_preprocessing/lab_preprocessing/selected_lab_values.xlsx b/preprocessing/mimic_preprocessing/lab_preprocessing/selected_lab_values.xlsx index faf8ec0d360956251e9fdf06f3cb4ec144d18684..79d2e595899f331fb47a53ce84482a2cb716f29f 100644 GIT binary patch delta 6516 zcmZ9RWmKHawyhg?_r_g=OCSWN@c=VY|HE4>$1uoNSrHaZ3bf|Q zpQo#y@AuN1jfxqUrDj#)fqrg|4JX6oL#;9<69NeAN_e-dojdB;ol4sYv<-%neV@ zS-PE7VHOIkyB;?@g0j$O75yiqApouW$~oU6UcWS%F!5SN_!vGSz&C18mV5o`FSwwh zQl*4jFAXAF0Jc^icd)s#_$_vkEi3H4ScU9SeK;Ew!He(OP6UA%6mTMFfY8Oxko;o; z8cp9|{yaHtZ6Zab*0;Ic{>B}p&E&pGU$mM8@}E4zUc}sR zLDUBKP`wY@@X_DAVdnnk_a|!O$>sxwG$Z9&I466|;I-WNnEWJ~9*v{M?}mj=kJ+k& zxV$}HVDZey3F-i5LRfzr!fjMMpi@&xkPF!@amUT|N1qRIno8C^#GSXw_-?P{ z@4ISrMY-GD39m+O1a=IWDfYH5`+vW@a`V3!Vt$aW7;8}MIw~)Cx2^EAnc*)OA)Q~E zrc8+7WmvsWaX3J$UP6gHLe=DNVI_RUR%zvaQ>tldj%1fpq#!a@ z_N#Ki{b&F0Wgq?uk1+r2gf+nDl)()ZzA!C=XRzYsiKp=ZC z^gRVV@a}zFsVHGUCE_{Ii@==3RVOhY&jE3+RkC+=rv9x~7DoW>_)z!pn3fvZi-F_7 z&NF$I{>%BWU&A}#UYTSgW=nCHj~Z$iYg?t(6L-4hZTWQDA5g;|BNv)#nRbG+kd{-% z-{A;6**=b*r>+mD@ym1-&Zb#2)+9SpjHed@78a+`Cg$5c=fb2XB0xhFkM)}N>m)W& z_pJ2nu+v^1*0M8StCLC2SVMUJOp=KLCyTWT2mEQB#=2L}2)xA1i+7$n<3M~Wj6xLt zGNQMb*bZ!4clYGC=Ix~nKPC{wW8Pp6H^T;w^kTrrHmVyOJ805aI4lc=BPQ3WR}%n! zzR}y6=a!!MlI+?>Ltb;{$bb=+^{E;g=HbV@s&_Bc!??}b`ubyT`A&=7nipQHlaGE^ z|0#6I*7E2ek$yaYvh;^8h4?oY>x`?ZW+ z-6U#;ER-CKMU3a?@Y^WVfN(|(1#)bFK`|`h+e+G>L@P`5BV5-9`a;D3bTW_f{Ks3E zbEWHDPd`#?E$sR)R>dbGL>=eXW`+-b^v<7F7~2ISrWu~zPC z0yarw7!gZpRioxwNaNlM6886aG_hPg6;ZigXxBX%kL1EpO$JCgeGY=u-8t`htsr*a z@Q!8#KygR0K|QdaQvxzRo~+2kO%Sck+!eBQUTs{C$jVRrm>mn|5mR_YyhioJf;KL` zuJfVVTZ*gra}u-`r9FEM#E}lb3&srb0a|*7k@ei2kdhiQ?-RqA(HHA=fYT<9?(EW9 zs^JnVfwwMG-R_3fNf|0u>kB+c_^OA8|%5^g^4CT_exd^!L(OTPbt(0Jjw|RS}BogBIV% zjqndqtTfo##jH$WZcQ}YBy*BswW|7wG~xAi-*RqWPyCmm;Uu*^^X+xO8Pr_&7403j zfTNOe=xO1^JPE-FI)CXoEJfXl^nEHS=Ci#!2epUlav$PAv3xu=tfTo5 zz;HB%T($3W1#&*A9A_0!${&)#-x18O6(WHFZn=qgeh|RwtBng~ZWkXwHAQqRA}~ z34ReJcx>5zRGZooVM=6_RMDd!$gEy=EC@?<}IUL)>rmv(*3Dj%ClJdv(@ml*4EGZIcR?n9Eg=5oTk_9 z0k68LDbX7{R@-f}UQ*2PeWrb}B_(j0LG>=J`hoKaiBHNr9DYveR6fO>5|bW8>rII> z`Lg?l0G~u9zV|HHq+N;FMHfh=_$r~$P& zYBAz=i~)e*j(2~mydPfkqnQ0F+R@93&vjEdDTB^KhoOWDCnyLmzD3L+<58W8cGVEh zpd5&5$ej8>i6UAvM_;Yzp{5c0)5{@_;tjtLJUZhJW?04*P{wQ8am!+{n}n(2J&y&O z;X-2i8>(q5ag0jPpZ}Iy=IxqWBtWI9!fr@Jwb*)ksx))@JyJ!EaG=1JbI{?{ex;=K=ELbj9;i~N>xXOOz%M%ajji*1GLTKS@q4eQ!K}JY1gK3^J&H7WvWl= z412+N6Fsp4hI=vLM&B$Hbqe%x1iMKFE~GwP7;3r6>JE!qT;DFT_+!I#m4 z+ZVv3$)G6k>rXO4z`zo1UJfJv%O$EKrKFc@0k4T9QC;gbxQG>iTGH=^%S?+)aonQa zJ^g|v2`>-k0Ww4;ZQItO${~szrBt_$+*r4!B>BZn*QvL~#UZ8P$DAADeNm4udeX%T z#8!4$j**hN6Do7ef;+Dc#Er9=zmvTo;qkC-$It z8cU*VlkRB3N4jrxt}v%i+RUIUVny+a@?fan`4^vT7|>dcOL+g)U#M?5Jy$tRAPl)% zsgYqlw1aSb2S)VXXY12Z(TgsSVxrD5i?F}r>6eOwh$rcfaUFrpp~3x?HzumtCP_6w z1$^%7JxYoQN~`)Y$xD?6@qHMq8DLOr8)#9nmd(8?sygz)c!Zh}ivnCyVw1zh0HEZ#0>R~O}>91Jzd%alaqviWmKkhSW+4h8td| zp}TrX-Kplw{;jL57B@zmW2#u<%{!eobVr!Lid6R4A%bMVjp% z3;|vTzfZ$5bql6)i(9j-30zFL!^5o=H=a~a8W4S0;2)Ax7J&NK|rkg~^2E#wmpf2GFrjHCL%vm%->WXWX%x|K(TUWrRjTrVS z?V+FPdnbk5hEvv&!bqbs^bX;i$)6*%Mu1lF!q4&zj5SQpqqOSchw}!yY~eL7!{h>^ znxvV-LQRb0G2+{{8z}pnW@rSpcm)zgw8~KhhjjRv587k&v1PJ;A;jJI_ zT)s)@e?&MgSAOhtKNOjf&GXTWn92G@@sf~dGV(XgWL@uqaGCfkdUuVMo?!Tyf+Y~X zB5EBvL*ndv7P5F-Higmit4}z>xnlM((>dUQa=;Sg8KG8GfW^A_5eHY%%tI)B9Ni&- zovV^)jl>9h9!siJVX4Y2>aI|g8=TVRW5`*b=hTLX$@6ii)Z-2@J>;a};d%qp_Gpt1 z-p6ei?)CB?$$ya5fzQTF>hOzX_LET zEA{NIsoN%U$DQ#ZApr=xxbNt@h`?^AHhR&p^0)c($EjTqMZUr*UXgC2a%$0*a*J92 zj8x&(uhFWm9f74WF=9^`|59Vv*cKb182>mZbXd?1Tr$AXd7hu}Ny_%Q^Xyufj|EpN z(Jl!BL%e~`tK8cWW0c*ox~fx%9(EjOp<;3^;q0RmSM%-cj0VuH5SfVv=d;7?i>!hl zaR`{XzUc8ctDY2x52u^!QZwOnB^O{VPk_yeoL744=Uja{iVzHe@$SJheY!$3<>w+~ z1?^g7$-qgGt`Q-MPMhn#s?J-4N5ZPps<)*9A)8A$KgsQJAZ8xDSTdAbo2WY_TBz2s znxFXg;1g>yn;L$JXn?XS;>MCT7jhh2H*M{LI#dtSv*iWN=zaX{dMCM5^2%K$Kn8j9 zJpN(eGC03Y>`uL4Z0aK7-Q8z!s%i#!%Z^zi7+_XL&~+YY!}tViD+AY-xnuS#cOyPm zq0VY-7`qfEAnmPg`q?KajE!zYer`YvgVmTQS3I^oOolv;*(tNyE?phzOhj#4QM;cb z-ML~-Uyx*dG`;s8i8&0Z@1LSw?H~{}FC`#XKfLy`_SAKCPptZ$d8Lh_6n)cD;FUr_bf-^Xis-T=mc>Db#yonshwI>NTe*795c5Hif-AG2g;OM|d`vC! z5}pO~V^#M)8#(Z8+JeBGS^4&lg|^PoDPU$qD-%!*ov(xxz;33O|2{Rehi-K+P4ALb zcL`M)Nd*QWjVQ4k|CMX-V7&=(d`$y02R3b*Ef)E`gU#VQ86j4ZrG}OZ%G^c%q*`xe z)oRwZEuIBC!%t~n%tAmw+ljud`myZ?R^PBnU z7QwD;p2mliV51isY}*`%qJ2XI6BPnnXXLx;ChKgSo4TY9pOt6r1O{Q_St35jnF$kJOJgEs)4E1 zn07&j(xl`Kr4kdo!rw8*(DzsH=!6J2Tv+?~;45)1d!1O9s%C1OUZ$dCa@ULtctVKk z-4T|7sScS`qir;hXCwO~QjXoC24gRu*+m2O3qzNKjuhUH;{M`(Km6ltW5)BkIBZd! zR%`kpFyd>YaSN;AQDF8b==%G#8hV5=<)cgHHaIs_@M6$-=tZnV%EecyaBHhLM^}IELnAvTT@X0BZDh8JN+RSmh;1^aWE}$ zE0=U_@$pCsZ`Qw9I~Fn-B-x{P?OPmm(rk0vRQF>m^w5!~g{7{m)Jr0tJF}<$VrBM3 z4$rV%Tx|eNF0m1O>FYpVE3Y(l13r+n5=-CW4S+A-xjWg@ei+a$wk3Kd-L(EgF9sbc z`+IDY_Qd+)?)74xlG-RKzDAIBrn7U}6&Y4v(Zs@&4mo&XzNw+_vqYlikQ7UA2ku74 zzil=&+F3yz)mk4079_wg(M+wIpR{)ctwa8|Z|0MQj@&}fARQuTRpImni>FDIihSbS z;X-D^`@^)GG~YJ?;ya0c)v4|$czux&GLDE3q=6ssC}3P;3qDlKd2k8X%W`ToG>JgP zsIxmR>x`;Cc$BR$B|9AbdjXHdFPZl{!kx18q&FL0 zv8*o#OY)mg!hBfrm=kM>N<52>kg_a($_BfVvug(fZTQi3faz{>gbudnSo@^FQPRj9 zL+nu-S$?(S!q1UAN+}>%TK`X_|CH8GCj(rq^e>P!q&={X)RCR6-ihupj7D>N3E$ps zX|U3lIdFw`v>cLJQ+&mWE_?#-XOfSAo_$}EE=&EXz9EbU6^^cT7i6BpOJ~hl+PA)L zN$ZxO6ME3h3a<_y_4xpq$izBD<@n-vDl$k9s$FB!N1^%ue;odium3w*(V>}ajL<=J ze8T@T^{?oK4FW;u(ZNt>HbR>J*q$H|<^M~8N^zlP6kw2bVUCJ Vc|#Z2rI7BRU=C(Jn3m`>v1NDHcCYN^FXhq_y#x*do=7K&B$ z#(-64!l}ecrTOvv{=w_3GuTaz^LV#Q#Z*AP9RE8_wII7Jv?_s|hI>z|l`&ms8w#By zaaLxaX0Jb}cnFE1++RUt$0Z@m&(FJt&gnP=c+*qtNwxpP4|`ztc%~K>AY; z**>^ag61kk0(zNBDut?E#SYSZQM_A}_il2H>c8p<_4y=`Uz$;E1u@2dsa7^A7sCKx zQRwokv=TYZD3%|K^7C0@-<=zW;oXEhmj4z%91Z`_3n71Ow+)xG4D0%4;GjbCsi=d# zoYz=eL(PY^=lvy|g0T$6d>VC5m!TdhxE|h_hB?wEQ}hw4x15M^`VivImy9&a&3!1|6BneQBfsi z_bt(;NRfP-?Xh;nh{!im|3NBQzr6k}j*HfK^m=K6BFfTX;EB#B-;?UKe2Mjkm5IG- z_$Jila_6!(nv`Kd zDL7u#^<2{-$}X&+&Kj~5W^D$v)j}d zl~P>2;YPJ9y!?}f#h^VVIA>+Xj5p4A>82g93&37UyHwb`d{XEetpQEnT`DXd*mMOb zYYO3&XjA35(T>}#f4$fO11R_K(0uOowj-O^zzowr$j~HYA|RRpGJ1}LfDrW8ueks6 ztGCa47hCUtxV1yy%sq>jDEMpRv&CID1%{0Z%G)=(g{q5!W!a4e=xRD$w$ z^>IV(?*m;T{s-bs=rZ5wZ^q+9X{53^jXtac1A&|g^Od?^c0z8z-yzf=eE0TW-o6O; z1-`q`Wtm17H0MDHaEy?mzVR1$rRQ1(HPpx|pzbu07F;|;+zd_x;|~^~7Gh)EJL<{- zqHj|{xhpBm;d&&$go3;ea`fS+-Ltd3L1TWN5Y7;h3IBfx|odG=3)Z zO{uU(1KZd}8`#bY z?NuWrXump2*7E@!j-A=YYHr?CZkb?;(RN6mXE#$T{9W*`rua}6*JJ!p9i{}JXf_H- z4y55BhPk zsE5tJy0Fuz2i`A>p)6P*UwG_yLy%WTSwq$ob;U32rd1BNJ@KoHC(rrEHwkWTGe~1K z!*jbcsqO+SVK(6vAN0{92KDjqF^gs9QN-9}HNN7`zkqsT5o7g`;&*AFBOnaYLC**& zfw4GB5u$F=dl?*W|35;TiOef&KV6{9Og7mmR87r?+WG?g$sLq0k(kAB<4YvxQyIpjM9m%PukDq@Qq zBm)?cNv+@PM8j5D;)j%b8NZSYQBi+bNhbjn=~R(1Cf~jbx6%R!^V2gP;-;zs;IaW4db#L`U7cfU`h|1L96kw}4e-Dbr=e`AMBWZEk))cX8WZ9)n`=YY#O6(8dxkTnB zeegF&@3dMjowXEHyjml}Mud(hM+*FOx_1%PEd^Ze-u0Q+!^{qYXU8FJ`MD395cAjse zJ%ceNWYj3tp)Og+`;qyJ18e1(qo=1HIcm$$3s_8h%tVs&F}G?Z!p)9AuEB=RJfzn4 zH*PFpOtEOslx}yOd^~kUyL|LET#A1~tv2TJ7{BI2^fa@)mj$y9ccZ*j^k}aw-N%LV zu@%>vzEbBIsURCh%fAQV4T&z>9vpSx);`n`Q{WS#?uxV7;`ai8p4TkZJWL9KvFsQ3 z)|?;T{1%V8qLA&z!r8=opfT$gPJZLHI|Yj)6G)Qsn}cXcdM~IY5BG<#;L$54`F~I$ z4N9DsP>_xj?jq$TwUuZ^=^kZhi78jBH- z5Wgr(U1%rM^@$$@$iOYq-i8yF0;4;yNdpFpRO&<3O3?7Q-qjZiMBC(Y=UH~IGz^gK z%~^_Tt-ZTvcmN4o48dyQ9fiWOs+s8am$!l033%#jLP*=RTXECuaMWGiZl8aJjjWh?z-iA3}4$hT+$MV#K6KKT@H^Xelk_at}u2v zPRW|f5`KY9q@>^Z61oK$_!(2r`p18S#vjoe4K>l^>>))wi_LaNA*=!OhL+Pg&>?5( zn@4h`_w0K;?{_n|&kyBx)9<8fY{#85E{q4SuCHXBGCG37Mp#bHcXKp5>b3>wQ;tVb z2ioQQgHbww$#`_$(nakyR+tmwIKyqGP&(&Q3M1F1CyNX#|3cyfj9yBfj5+TzX_OQyf%rb1$qJHsf!mMDFUIP1&7X;3V46LnTEG)Vc{St0>O->js9E^g_cQ1Fz z`rrWo8jSIm32DA=s~@>aj@zxjZF)_Q&|bvTBKww(tBmEJgqact`A%SV{t;c$}7Be!5LguTV_&k4a4s zEt%$gW>4M$w`6)b@^z=?vIyK8`Cs)MMQ)E8Twc}A z+%P=CnfIz4>%W5Ikb8K!zW(4>H|o&ni*0DoUThTL#K6(~C4B4Zt0XFQR?R9ioc#kZ z_I<^9pqhPqKeE&@V~6e^CK>1fh?eN=YRS2$F|yrHQ4XgFTf{9vjUX7S9b&9p&mrtC z=*y~eTqzWm-S}@m=;RSJS48-c>5o3vn(q5@yjZhH(l*7LKDb9TwVY7&bF3MHb02SR zbB#q1#KozZ{~&Ezzu)LGoe%OnA)wREerOuWHFEWHim+>vVngZFD3~~$6_0pVWpaWW zxLi9)C3)J(h>N!{wb&0@O(iI@XpM`Bmb5Z^nErxl(_;A3 z?K>7-)gXLV=0?}NT%NVP&5g~dyUDg}cLXF^xdZHlj}zPR@8?4aLM$X4&fd6Y6992qV5`ode@&zU!0>OG`Hv6d2BrZ=}e$Wl|cLj>>yk z((v#+yl0^nv-DswrdV_Xule&gKs`VI{ZKqVr=(9NEYjgC0?g-N>JJcTWwLTp_G2O& z$YAY#zcAMRgdl1=*!mxGi(nx%t)={z^xNk_=xPKB-c9>wo@`b zh=s)+;$u)@ch{oBQ~ZjsK_?keF#&&m$ch(+GF=@EBTJ943U#1LinC*4+9Lg>7F^EL z!9%;(cSESy37Y8P1c!$wUIV5ZHms3gP%kMUxtD0>f(2q$V3CLoJBag8ZwOZ;@GIm<}j#dzYn|V`+oB4l6*N5-EwhI@Y zvv@i7d!)R-K(DuOJxCxx359BAW~vxhImaPlYv{DkaxvPjLiNjhFaO0{nq;!3WTyZV z&Cw_QVdk#>c0kW%MSI2i#;xGuW-1Q1ji!@{)IH{P#ayA#LWcc$yPxAU8ZJbC7}eCO zn`6B@!0)WUtV)IiH!Kb+p?jm~}_i*6+rJ zBSmXGyq0k|WMPNb`ntVnzF+9tOuEr4aDv&+#gZMJ8u-D;@Hj#8_Z}23nSxnwfJq^d zu6zzjsomS+#oiF|3`|68|H66GZ_Lu=m%j=6)r-%& zG$5H=b5aDko*Q?JunCTf>h(|Uj!j7~OE`g1fJ^M_A0slfc;+M=g$;O^-O8;+6ur-t zr5`u=0jhz6B_hqC##*VRU#hhFnn`W_&RjT0Qv08ImZINhe|p*#DZ9*0r>QqDavQ4_ z<9w>EZC>9Q`j^fA={3njZKr--9EBPypU^AeqK}bGwJyLz>dV#ys<^?Vq#3{YX}Hnw zlcN&vMKKSjc$uzXxMk>@X|qg{A6!*{$(~!qk1+&oA^%wtT0WSU>Q)3cy3~7(&+>GM z%{H)xX8fM7W-EhCt1?8v`*Vih%*YjMrdp3Dh>BJxKJf!u= z@|bB_yh2*tExHU-r}M0;UeC{f(!W(kl~RL`Ki)&t{m50*=&3tz^5S>$ z90K@}a->t@%kn0lnFf3vDV60N%mQ&@CF@Y@Q&hd?{yb2#F54IsY?YSWP5l5D^emq! z)d?B0LjLlUw8xKOB*iIh6f7qLn**-BUv}EIMjXgkC_TO}?>@AIdFSNS`1KWDhg^Zw zvrF^VkEbeMZH`UNMS>-P1rA>jc88BpsQcUdhShEjN^alYkB&{WEG=+M=cYd{^s9&> zk2ZoHUV^6uQS$-S?~ytYp6Xa*JN%4C+UcPf$f0FfzuL0wRS4!Km`fEVe5QrEq9%z1?y`I`w_xub zy($xnSV?t?@a#uE9E8tCLrNGOxTqVDa9%czudpGC%YhjyFMV#)n~eJ9*@LJ{xT&E@ z+c_TFQhMv&S%YjqED2FLn-2*ZIA2(N1Fpb}&JZuWJ&hko-W`7-7S!wo`7tfcF{gHx z5X|YUMY`g3^w|Q2?3be+OKUk?t?g}j33tnM4ka9!;P3{y)y>r9x(n_;nkQWvPR4v> zg7|ehrlx_IQTOd}FPJspa}z=*Brmv;Voy@10c#yMY7vxQ&2w zn-oPP5J&eEC2?splF?hte(LgY6KT1cWZ}ZmdDKEk6DSVirY{|?S7UVbaL7&b!l2RWR0pykt&Zx0KXid~Dl`+{qewe)Q|SR+Ji zC)uS=&-WB}a71Y*T7eJi)rFF>+4gRgX`~2SZGvOZzLI=`pT%T;YsYZFoA=gh1A<&+ ztUYKsgrq$Z3c$|W4u@^r`&0Ch zKQ&K>E~XE(gi_j8iP7KMgzLJ6!Ul<6WEt~QOU`C-^g-Wu_hlO09Kt)3>D>GkN@l(1 zb)yKug3;fG@*fCaZ%cKy^Q4t3Z!lr15Dq%3{1)si_CBBQ4)N+vgwdADM?T}*GiOXP z6*FbvCx$Zg*_#*K{v$m$U@*!d|LrB3A}bNq{}CVrgqQ!f@$anv1(c7Knd;xO>;K;&Px?2>a!~$nmHLZS siT=NikPNM4Wu^S*lR!j3u>V`h2>wmChaR&^Bkn-C*cj2OnExUD7dg{iuK)l5 diff --git a/preprocessing/mimic_preprocessing/mimic_nihss_preprocessing/mimic_nihss_exploration.ipynb b/preprocessing/mimic_preprocessing/mimic_nihss_preprocessing/mimic_nihss_exploration.ipynb index 7f159c1..f405720 100644 --- a/preprocessing/mimic_preprocessing/mimic_nihss_preprocessing/mimic_nihss_exploration.ipynb +++ b/preprocessing/mimic_preprocessing/mimic_nihss_preprocessing/mimic_nihss_exploration.ipynb @@ -27,9 +27,8 @@ "metadata": {}, "outputs": [], "source": [ - "mimic_nihss_train_path = '/Users/jk1/stroke_datasets/national-institutes-of-health-stroke-scale-nihss-annotations-for-the-mimic-iii-database-1.0.0/NER_Train.txt'\n", - "\n", - "mimic_nihss_test_path = '/Users/jk1/stroke_datasets/national-institutes-of-health-stroke-scale-nihss-annotations-for-the-mimic-iii-database-1.0.0/NER_Test.txt'" + "mimic_nihss_train_path = '/Users/jk1/temp/opsum_end/preprocessing/national-institutes-of-health-stroke-scale-nihss-annotations-for-the-mimic-iii-database-1.0.0/NER_Train.txt'\n", + "mimic_nihss_test_path = '/Users/jk1/temp/opsum_end/preprocessing/national-institutes-of-health-stroke-scale-nihss-annotations-for-the-mimic-iii-database-1.0.0/NER_Test.txt'" ] }, { @@ -211,7 +210,7 @@ "outputs": [], "source": [ "import os\n", - "# overall_mimic_df.to_csv(os.path.join(output_dir, 'mimic_nihss_database.csv'))" + "overall_mimic_df.to_csv(os.path.join(output_dir, 'mimic_nihss_database.csv'))" ] }, { @@ -261,21 +260,21 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "opsum", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.11" } }, "nbformat": 4, diff --git a/preprocessing/mimic_preprocessing/monitoring_preprocessing/monitoring_preprocessing.py b/preprocessing/mimic_preprocessing/monitoring_preprocessing/monitoring_preprocessing.py index 762c9bc..b9d0648 100644 --- a/preprocessing/mimic_preprocessing/monitoring_preprocessing/monitoring_preprocessing.py +++ b/preprocessing/mimic_preprocessing/monitoring_preprocessing/monitoring_preprocessing.py @@ -11,8 +11,9 @@ def preprocess_monitoring(monitoring_df: pd.DataFrame, mimic_admission_nihss_db_path:str, verbose:bool = False): - possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), - 'geneva_stroke_unit_preprocessing/possible_ranges_for_variables.xlsx') + # possible_value_ranges_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(''))), + possible_value_ranges_file = os.path.join(os.path.abspath(''), 'preprocessing', + 'geneva_stroke_unit_preprocessing/possible_ranges_for_variables.xlsx') possible_value_ranges = pd.read_excel(possible_value_ranges_file) ## FIO2 PROCESSING diff --git a/preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_extraction.py b/preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_extraction.py similarity index 100% rename from preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_extraction.py rename to preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_extraction.py diff --git a/preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_preprocessing.ipynb b/preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_preprocessing.ipynb similarity index 100% rename from preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_preprocessing.ipynb rename to preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_preprocessing.ipynb diff --git a/preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_preprocessing.py b/preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_preprocessing.py similarity index 100% rename from preprocessing/mimic_preprocessing/outcome_preprocessing/outcome_preprocessing.py rename to preprocessing/mimic_preprocessing/outcome_preprocessing/long_term_outcomes/outcome_preprocessing.py diff --git a/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/early_neurological_deterioration.py b/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/early_neurological_deterioration.py new file mode 100644 index 0000000..67efc8f --- /dev/null +++ b/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/early_neurological_deterioration.py @@ -0,0 +1,133 @@ +import numpy as np +import pandas as pd + + +def early_neurological_deterioration(df, require_min_repeats=False, min_delta=4, keep_multiple_events=True): + """ + Detects early neurological deterioration based on NIHSS scores from an input DataFrame. + + Args: + df: A pandas DataFrame containing the necessary data. Must contain the following columns: + - sample_label: The label of the sample. + - source: The source of the sample. + - sample_date: The date of the sample. + - value: The value of the sample. + Example: restricted_feature_df + require_min_repeats (bool): Whether to require a minimum number of repeated measurements for detection. + min_delta (int): The minimum difference in NIHSS scores to consider as deterioration. + keep_multiple_events (bool): If True, allows detection of multiple deterioration events by resetting the baseline after each event. + + Returns: + pandas DataFrame: A subset of the input DataFrame with early neurological deterioration detected. + + Raises: + ValueError: If the input DataFrame is empty or does not contain the required columns. + """ + + nihss_df = df[(df['sample_label'] == 'NIHSS') & (df['source'] == 'EHR')] + + end_nihss_df = nihss_df.groupby('case_admission_id').apply( + detect_end_events, + require_min_repeats=require_min_repeats, + min_delta=min_delta, + keep_multiple_events=keep_multiple_events, + ) + # undo groupby + end_nihss_df.reset_index(drop=True, inplace=True) + + end_nihss_df = end_nihss_df[end_nihss_df.end] + end_nihss_df['relative_sample_date_hourly_cat'] = np.floor( + end_nihss_df['relative_sample_date'] + ) + + return end_nihss_df + + +def detect_end_events(temp, require_min_repeats=False, min_delta=4, keep_multiple_events=True): + """ + Detect neurological deterioration events based on NIHSS increases. + + Parameters: + - temp: DataFrame with NIHSS measurements + - require_min_repeats: If True, only consider NIHSS values confirmed by at least 2 consecutive measurements + - min_delta: Minimum increase in NIHSS score to be considered deterioration + - keep_multiple_events: If True, detect multiple deterioration events by resetting baseline after each event + + Returns: + - DataFrame with detected events + """ + temp = temp.copy() + temp['sample_date'] = pd.to_datetime(temp['sample_date'], format='%d.%m.%Y %H:%M') + temp['value'] = temp['value'].astype(float) + temp.sort_values('sample_date', inplace=True) + + # Initialize columns for tracking + temp['min_nihss'] = np.nan + temp['delta_to_min'] = np.nan + temp['end'] = False + + if not keep_multiple_events: + # Original behavior - just use expanding min and mark first event only + if require_min_repeats: + # For a given patient, compute minimum NIHSS confirmed by at least 2 consecutive measurements + temp['same_as_previous'] = (temp['value'].shift(1) == temp['value']).astype(int) + temp['score_with_min_1_repeat'] = temp['value'] + temp.loc[temp['same_as_previous'] == 0, 'score_with_min_1_repeat'] = np.nan + temp['min_nihss'] = temp['score_with_min_1_repeat'].expanding().min() + else: + temp['min_nihss'] = temp['value'].expanding().min() + + temp['delta_to_min'] = temp['value'] - temp['min_nihss'] + temp['end'] = temp['delta_to_min'] >= min_delta + + # Only retain first end event + temp['n_end'] = temp['end'].cumsum() + temp.loc[temp['n_end'] > 1, 'end'] = False + drop_cols = ['n_end'] + + else: + # New behavior - reset minimum after each event + current_min = np.inf + last_event_idx = -1 + + if require_min_repeats: + # Mark scores that are repeated at least once + temp['same_as_previous'] = (temp['value'].shift(1) == temp['value']).astype(int) + temp['valid_score'] = ( + (temp['same_as_previous'] == 1) | (temp['same_as_previous'].shift(-1) == 1) + ) + else: + temp['valid_score'] = True + + # Process rows sequentially to detect multiple events + for i, row in temp.iterrows(): + if not pd.isna(row['value']): + if require_min_repeats and not row['valid_score']: + # Skip this measurement as it's not confirmed + temp.at[i, 'min_nihss'] = current_min + continue + + # Update minimum if this is a new minimum since last event + if row['value'] < current_min: + current_min = row['value'] + + # Calculate delta and check for event + temp.at[i, 'min_nihss'] = current_min + temp.at[i, 'delta_to_min'] = row['value'] - current_min + + if temp.at[i, 'delta_to_min'] >= min_delta: + # This is a deterioration event + temp.at[i, 'end'] = True + # Reset the minimum to this new value for subsequent measurements + current_min = row['value'] + + if require_min_repeats: + drop_cols = ['same_as_previous', 'valid_score'] + else: + drop_cols = [] + + # Drop temporary columns + if 'drop_cols' in locals() and len(drop_cols) > 0: + temp.drop(drop_cols, axis=1, inplace=True) + + return temp diff --git a/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/short_term_outcomes_preprocessing.py b/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/short_term_outcomes_preprocessing.py new file mode 100644 index 0000000..8ef502b --- /dev/null +++ b/preprocessing/mimic_preprocessing/outcome_preprocessing/short_term_outcomes/short_term_outcomes_preprocessing.py @@ -0,0 +1,42 @@ +from preprocessing.mimic_preprocessing.outcome_preprocessing.short_term_outcomes.early_neurological_deterioration import \ + early_neurological_deterioration + + +def preprocess_short_term_outcomes(df, end_require_min_repeats=False, end_min_delta=4, end_keep_multiple_events=True): + """ + Preprocess short term outcomes: + - Early neurological deterioration + + Args: + df: A pandas DataFrame containing the necessary data. Must contain the following columns: + - sample_label: The label of the sample. + - source: The source of the sample. + - sample_date: The date of the sample. + - value: The value of the sample. + Example: restricted_feature_df + end_require_min_repeats (bool): Whether to require a minimum number of repeated measurements for detection of END + end_min_delta (int): The minimum difference in NIHSS scores to consider as END. + keep_multiple_events (bool): If True, allows detection of multiple deterioration events by resetting the baseline after each event. + + Returns: + pandas DataFrame: A subset of the input DataFrame with short term outcomes detected. + + Raises: + ValueError: If the input DataFrame is empty or does not contain the required columns. + """ + + end_df = early_neurological_deterioration( + df, + require_min_repeats=end_require_min_repeats, + min_delta=end_min_delta, + keep_multiple_events=end_keep_multiple_events, + ) + end_df['outcome_label'] = 'early_neurological_deterioration' + # store arguments for each outcome + end_df['outcome_args'] = ( + f'require_min_repeats={end_require_min_repeats}, ' + f'min_delta={end_min_delta}, ' + f'keep_multiple_events={end_keep_multiple_events}' + ) + + return end_df diff --git a/preprocessing/mimic_preprocessing/preprocessing_pipeline/preprocessing_pipeline.py b/preprocessing/mimic_preprocessing/preprocessing_pipeline/preprocessing_pipeline.py index 3aaff1b..b4c8756 100644 --- a/preprocessing/mimic_preprocessing/preprocessing_pipeline/preprocessing_pipeline.py +++ b/preprocessing/mimic_preprocessing/preprocessing_pipeline/preprocessing_pipeline.py @@ -8,7 +8,9 @@ from preprocessing.mimic_preprocessing.database_assembly.database_assembly import assemble_variable_database from preprocessing.mimic_preprocessing.database_assembly.relative_timestamps import transform_to_relative_timestamps -from preprocessing.mimic_preprocessing.outcome_preprocessing.outcome_preprocessing import preprocess_outcomes +from preprocessing.mimic_preprocessing.outcome_preprocessing.long_term_outcomes.outcome_preprocessing import preprocess_outcomes +from preprocessing.mimic_preprocessing.outcome_preprocessing.short_term_outcomes.short_term_outcomes_preprocessing import \ + preprocess_short_term_outcomes from prediction.utils.utils import ensure_dir from preprocessing.preprocessing_tools.resample_to_time_bins.resample_to_hourly_features import resample_to_hourly_features from preprocessing.preprocessing_tools.encoding_categorical_variables.encode_categorical_variables import encode_categorical_variables @@ -18,13 +20,20 @@ from preprocessing.preprocessing_tools.preprocessing_verification.variable_presence_verification import variable_presence_verification -def preprocess(extracted_tables_path: str, admission_notes_data_path: str, - reference_population_imputation_path: str, - reference_population_normalisation_parameters_path: str, - reference_categorical_encoding_path: str, - preproccessed_monitoring_data_path: str = '', - mimic_admission_nihss_db_path: str = '', - log_dir: str = '', verbose:bool=True, desired_time_range:int=72) -> pd.DataFrame: +def preprocess( + extracted_tables_path: str, + admission_notes_data_path: str, + reference_population_imputation_path: str, + reference_population_normalisation_parameters_path: str, + reference_categorical_encoding_path: str, + preproccessed_monitoring_data_path: str = '', + mimic_admission_nihss_db_path: str = '', + log_dir: str = '', + include_short_term_outcomes: bool = True, + short_term_outcomes_config: dict = {}, + verbose: bool = True, + desired_time_range: int = 72, +) -> pd.DataFrame: """ Apply geneva_stroke_unit_preprocessing pipeline detailed in ./geneva_stroke_unit_preprocessing/readme.md to the MIMIC-III dataset. @@ -90,22 +99,35 @@ def preprocess(extracted_tables_path: str, admission_notes_data_path: str, # 11. geneva_stroke_unit_preprocessing outcomes outcome_table_path = os.path.join(extracted_tables_path, 'outcome_df.csv') - preprocessed_outcomes_df = preprocess_outcomes(outcome_table_path, verbose=verbose) + preprocessed_long_term_outcomes_df = preprocess_outcomes(outcome_table_path, verbose=verbose) + if include_short_term_outcomes: + preprocessed_short_term_outcomes_df = preprocess_short_term_outcomes( + restricted_feature_df, + **short_term_outcomes_config, + ) + preprocessed_outcomes = (preprocessed_long_term_outcomes_df, preprocessed_short_term_outcomes_df) + else: + preprocessed_outcomes = (preprocessed_long_term_outcomes_df,) - return normalised_df, preprocessed_outcomes_df + return normalised_df, preprocessed_outcomes def preprocess_and_save( - extracted_tables_path: str, admission_notes_data_path: str, - reference_population_imputation_path: str, - reference_population_normalisation_parameters_path: str, - reference_categorical_encoding_path: str, - output_dir: str, - preproccessed_monitoring_data_path: str = '', - mimic_admission_nihss_db_path: str = '', - feature_file_prefix:str = 'preprocessed_features', outcome_file_prefix:str = 'preprocessed_outcomes', - verbose:bool=True): + extracted_tables_path: str, + admission_notes_data_path: str, + reference_population_imputation_path: str, + reference_population_normalisation_parameters_path: str, + reference_categorical_encoding_path: str, + output_dir: str, + preproccessed_monitoring_data_path: str = '', + mimic_admission_nihss_db_path: str = '', + include_short_term_outcomes: bool = True, + short_term_outcomes_config: dict = {}, + feature_file_prefix: str = 'preprocessed_features', + outcome_file_prefix: str = 'preprocessed_outcomes', + verbose: bool = True, +): timestamp = time.strftime("%d%m%Y_%H%M%S") desired_time_range = 72 @@ -118,26 +140,49 @@ def preprocess_and_save( with open(os.path.join(log_dir, 'preprocessing_arguments.json'), 'w') as fp: json.dump(saved_args, fp, indent=4) - preprocessed_feature_df, preprocessed_outcome_df = preprocess(extracted_tables_path, admission_notes_data_path, - reference_population_imputation_path, reference_population_normalisation_parameters_path, - reference_categorical_encoding_path, - preproccessed_monitoring_data_path, mimic_admission_nihss_db_path, - log_dir=log_dir, verbose=verbose, - desired_time_range=desired_time_range) + preprocessed_feature_df, preprocessed_outcomes = preprocess( + extracted_tables_path, + admission_notes_data_path, + reference_population_imputation_path, + reference_population_normalisation_parameters_path, + reference_categorical_encoding_path, + preproccessed_monitoring_data_path, + mimic_admission_nihss_db_path, + log_dir=log_dir, + include_short_term_outcomes=include_short_term_outcomes, + short_term_outcomes_config=short_term_outcomes_config, + verbose=verbose, + desired_time_range=desired_time_range, + ) features_save_path = os.path.join(output_dir, f'{feature_file_prefix}_{timestamp}.csv') outcomes_save_path = os.path.join(output_dir, f'{outcome_file_prefix}_{timestamp}.csv') preprocessed_feature_df.to_csv(features_save_path, index=False) - preprocessed_outcome_df.to_csv(outcomes_save_path, index=False) - - if preproccessed_monitoring_data_path != '': + if include_short_term_outcomes: + preprocessed_long_term_outcome_df, preprocessed_short_term_outcome_df = preprocessed_outcomes + short_term_outcomes_path = os.path.join( + output_dir, + f'{outcome_file_prefix}_short_term_{timestamp}.csv', + ) + preprocessed_short_term_outcome_df.to_csv(short_term_outcomes_path, index=False) + else: + preprocessed_long_term_outcome_df = preprocessed_outcomes[0] + + preprocessed_long_term_outcome_df.to_csv(outcomes_save_path, index=False) + + if preproccessed_monitoring_data_path != '' and not preproccessed_monitoring_data_path is None: # copy file to log dir shutil.copy(preproccessed_monitoring_data_path, log_dir) # verification of geneva_stroke_unit_preprocessing variable_presence_verification(preprocessed_feature_df, desired_time_range=desired_time_range) - outcome_presence_verification(preprocessed_outcome_df, preprocessed_feature_df, log_dir=log_dir, outcomes=['Death in hospital', '3M Death']) + outcome_presence_verification( + preprocessed_long_term_outcome_df, + preprocessed_feature_df, + log_dir=log_dir, + outcomes=['Death in hospital', '3M Death'], + ) if __name__ == '__main__': @@ -159,10 +204,21 @@ def preprocess_and_save( parser.add_argument('-o', '--output_dir', type=str, help='Output directory') parser.add_argument('-m', '--preprocessed_monitoring_path', type=str, help='Path to preprocessed monitoring data') parser.add_argument('-mi', '--mimic_admission_nihss_db_path', type=str, help='Path to the mimic nihss data') + parser.add_argument('-s', '--include_short_term_outcomes', action='store_true', help='Include short term outcomes', default=False) + parser.add_argument('-end_min_repeats', '--end_min_repeats', type=bool, help='Whether to require a minimum number of repeated measurements for detection of END', default=False) + parser.add_argument('-end_min_delta', '--end_min_delta', type=int, help='The minimum difference in NIHSS scores to consider as END', default=4) + parser.add_argument('-end_keep_multiple_events', '--end_keep_multiple_events', type=bool, help='If True, allows detection of multiple deterioration events by resetting the baseline after each event', default=True) args = parser.parse_args() preprocess_and_save(args.ehr_tables_path, args.admission_notes_path, args.reference_population_imputation_path, args.reference_population_normalisation_parameters_path, args.reference_categorical_encoding_path, args.output_dir, - args.preprocessed_monitoring_path, args.mimic_admission_nihss_db_path, verbose=True) \ No newline at end of file + args.preprocessed_monitoring_path, args.mimic_admission_nihss_db_path, + include_short_term_outcomes=args.include_short_term_outcomes, + short_term_outcomes_config={ + 'end_require_min_repeats': args.end_min_repeats, + 'end_min_delta': args.end_min_delta, + 'end_keep_multiple_events': args.end_keep_multiple_events, + }, + verbose=True) \ No newline at end of file diff --git a/preprocessing/preprocessing_tools/handling_missing_values/impute_missing_values.py b/preprocessing/preprocessing_tools/handling_missing_values/impute_missing_values.py index 86479f8..17815a4 100644 --- a/preprocessing/preprocessing_tools/handling_missing_values/impute_missing_values.py +++ b/preprocessing/preprocessing_tools/handling_missing_values/impute_missing_values.py @@ -111,7 +111,10 @@ def impute_missing_values(df:pd.DataFrame, reference_population_imputation_path: imputation_parameters_columns = ['variable', 'imputed_value', 'imputation_method', 'imputation_range'] imputation_parameters_df = pd.DataFrame(columns=imputation_parameters_columns) - for sample_label in tqdm(imputed_missing_df.sample_label.unique()): + # parameter space should be the union of parameters in imputation_parameters_df and imputed_missing_df + parameter_space = set(imputed_missing_df.sample_label.unique()).union(set(reference_population_imputation_df.variable.unique()) if reference_population_imputation_path != '' else set()) + + for sample_label in tqdm(parameter_space): # find case_admission_ids with no value for sample_label in first timebin patients_with_no_sample_label_tp0 = set(imputed_missing_df.case_admission_id.unique()).difference(set( imputed_missing_df[(imputed_missing_df.sample_label == sample_label) & ( @@ -128,17 +131,11 @@ def impute_missing_values(df:pd.DataFrame, reference_population_imputation_path: elif (n_missing_cids_overall > 2/3 * df.case_admission_id.nunique()) & (reference_population_imputation_path != ''): # if sample label has a lot of missing values (~50%), then use mean/median of the reference population - if sample_label in categorical_vars: - # not implemented - raise NotImplementedError('Imputation from reference population of categorical variables is not implemented.') - else: - # use median - imputed_tp0_value = reference_population_imputation_df[ - (reference_population_imputation_df.variable == sample_label) - & (reference_population_imputation_df.imputation_method == 'median')]\ - ['imputed_value'].iloc[0] + imputed_tp0_value = reference_population_imputation_df[ + (reference_population_imputation_df.variable == sample_label)]\ + ['imputed_value'].iloc[0] + imputation_method = reference_population_imputation_df[reference_population_imputation_df.variable == sample_label]['imputation_method'].iloc[0] labels_imputed_from_reference_population.append([sample_label, imputed_tp0_value, len(patients_with_no_sample_label_tp0)]) - imputation_method = 'reference_population_median' imputation_range = 'reference_population' elif sample_label in categorical_vars: # for categorical vars, impute with mode @@ -168,8 +165,11 @@ def impute_missing_values(df:pd.DataFrame, reference_population_imputation_path: print( f'{len(patients_with_no_sample_label_tp0)} patients with no {sample_label} in first timebin for which {imputed_tp0_value} was imputed') - sample_label_original_source = \ - imputed_missing_df[imputed_missing_df.sample_label == sample_label].source.mode(dropna=True)[0] + if len(imputed_missing_df[imputed_missing_df.sample_label == sample_label].source.mode(dropna=True)) > 0: + sample_label_original_source = \ + imputed_missing_df[imputed_missing_df.sample_label == sample_label].source.mode(dropna=True)[0] + else: + sample_label_original_source = 'missing' imputed_sample_label = pd.DataFrame({'case_admission_id': list(patients_with_no_sample_label_tp0), 'sample_label': sample_label, From ccebf7b3ccb55dde818e4fd5239648ff37f4ba94 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Wed, 18 Feb 2026 16:26:53 +0100 Subject: [PATCH 2/5] updated meta data processing for END --- meta_data/other/ICH_meta.ipynb | 4 +- .../end_population_table.ipynb | 329 +++++++++++++++++- 2 files changed, 319 insertions(+), 14 deletions(-) diff --git a/meta_data/other/ICH_meta.ipynb b/meta_data/other/ICH_meta.ipynb index 591cf20..6dc0ae0 100644 --- a/meta_data/other/ICH_meta.ipynb +++ b/meta_data/other/ICH_meta.ipynb @@ -323,7 +323,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "opsum", "language": "python", "name": "python3" }, @@ -337,7 +337,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", - "version": "2.7.6" + "version": "3.8.11" } }, "nbformat": 4, diff --git a/meta_data/short_term_outcomes/end_population_table.ipynb b/meta_data/short_term_outcomes/end_population_table.ipynb index 5bdc41d..289bedf 100644 --- a/meta_data/short_term_outcomes/end_population_table.ipynb +++ b/meta_data/short_term_outcomes/end_population_table.ipynb @@ -31,9 +31,12 @@ }, "outputs": [], "source": [ - "cids_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/case_admission_ids.csv'\n", - "outcomes_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500/preprocessed_outcomes_short_term_08062024_083500.csv'\n", - "registry_path = '/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx'" + "outcomes_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/preprocessed_outcomes_short_term_30012026_154047.csv'\n", + "features_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/preprocessed_features_30012026_154047.csv'\n", + "registry_path = '/Users/jk1/stroke_datasets/stroke_registry_post_hoc_modified.xlsx'\n", + "\n", + "train_pids_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/pid_train.csv'\n", + "test_pids_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/pid_test.csv'\n" ] }, { @@ -48,11 +51,42 @@ }, "outputs": [], "source": [ - "cids_df = pd.read_csv(cids_path)\n", + "features_df = pd.read_csv(features_path)\n", "outcomes_df = pd.read_csv(outcomes_path)\n", "registry_df = pd.read_excel(registry_path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f155845a", + "metadata": {}, + "outputs": [], + "source": [ + "train_pids = pd.read_csv(train_pids_path)\n", + "test_pids = pd.read_csv(test_pids_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ff2cd21", + "metadata": {}, + "outputs": [], + "source": [ + "cids_df = features_df[['case_admission_id']].drop_duplicates()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "274c8884", + "metadata": {}, + "outputs": [], + "source": [ + "print('Number of unique case_admission_ids in features_df:', len(cids_df))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -102,7 +136,32 @@ }, "outputs": [], "source": [ - "registry_df['Etiology TOAST'].value_counts()" + "def outcome_preprocessing(df: pd.DataFrame):\n", + " # if death in hospital, set mRs to 6\n", + " df.loc[df['Death in hospital'] == 'yes', '3M mRS'] = 6\n", + " # if 3M Death and 3M mRS nan, set mrs to 6\n", + " df.loc[(df['3M Death'] == 'yes') & (df['3M mRS'].isna()), '3M mRS'] = 6\n", + "\n", + " # if death in hospital set 3M Death to yes\n", + " df.loc[df['Death in hospital'] == 'yes', '3M Death'] = 'yes'\n", + " # if 3M mRs == 6, set 3M Death to yes\n", + " df.loc[df['3M mRS'] == 6, '3M Death'] = 'yes'\n", + " # if 3M mRs not nan and not 6, set 3M Death to no\n", + " df.loc[(df['3M mRS'] != 6) &\n", + " (~df['3M mRS'].isna())\n", + " & (df['3M Death'].isna()), '3M Death'] = 'no'\n", + "\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5247e33", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df = outcome_preprocessing(registry_df)" ] }, { @@ -139,6 +198,7 @@ " 'Prestroke disability (Rankin)',\n", " 'NIH on admission',\n", " 'BMI',\n", + " '3M mRS'\n", " ]\n", "\n", "CATEGORICAL_CHARACTERISTICS = [\n", @@ -152,6 +212,7 @@ " 'Etiology - Cardiac embolism',\n", " 'Etiology - Large artery atherosclerosis',\n", " 'Etiology - Small vessel disease',\n", + " '3M Death'\n", "]" ] }, @@ -286,35 +347,279 @@ }, "outputs": [], "source": [ - "comparison_table_df.to_csv('/Users/jk1/Downloads/end_table1.csv')" + "# comparison_table_df.to_csv('/Users/jk1/Downloads/end_table1.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a438d3df", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df['patient_id'] = registry_df['Case ID'].apply(lambda x: x[8:-4]).astype(str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca846318", + "metadata": {}, + "outputs": [], + "source": [ + "train_population_df, train_population_str_df = create_population_table(registry_df[registry_df.patient_id.isin(train_pids.patient_id.astype(str))], CONTINUOUS_CHARACTERISTICS, CATEGORICAL_CHARACTERISTICS)\n", + "test_population_df, test_population_str_df = create_population_table(registry_df[registry_df.patient_id.isin(test_pids.patient_id.astype(str))], CONTINUOUS_CHARACTERISTICS, CATEGORICAL_CHARACTERISTICS)\n", + "combined_table_df = pd.concat([train_population_str_df.T, test_population_str_df.T], axis=1)\n", + "combined_table_df.columns = ['Train', 'Test']\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64f12e9f", + "metadata": {}, + "outputs": [], + "source": [ + "combined_table_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6390c96a", + "metadata": {}, + "outputs": [], + "source": [ + "# combined_table_df.to_csv('/Users/jk1/Downloads/end_train_test_table1.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "685cec7e", + "metadata": {}, + "outputs": [], + "source": [ + "train_cids = registry_df[registry_df.patient_id.isin(train_pids.patient_id.astype(str))]['case_admission_id'].unique()\n", + "test_cids = registry_df[registry_df.patient_id.isin(test_pids.patient_id.astype(str))]['case_admission_id'].unique()\n", + "\n", + "print('Number of unique case_admission_ids in train set:', len(train_cids))\n", + "print('Number of unique case_admission_ids in test set:', len(test_cids))\n", + "\n", + "# number of unique patients in train and test set\n", + "print('Number of unique patients in train set:', len(train_pids))\n", + "print('Number of unique patients in test set:', len(test_pids))\n" + ] + }, + { + "cell_type": "markdown", + "id": "185b7df5", + "metadata": {}, + "source": [ + "## END characteristics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c720071f", + "metadata": {}, + "outputs": [], + "source": [ + "# number of events in train and test set\n", + "n_events_train = registry_df[registry_df.case_admission_id.isin(train_cids)]['END'].sum()\n", + "n_events_test = registry_df[registry_df.case_admission_id.isin(test_cids)]['END'].sum()\n", + "print('Number of END events in train set:', n_events_train, 'Number of admissions without END:', len(train_cids) - n_events_train)\n", + "print('Number of END events in test set:', n_events_test, 'Number of admissions without END:', len(test_cids) - n_events_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "453aba1a", + "metadata": {}, + "outputs": [], + "source": [ + "# get 95% confidence interval for proportion of END events in train and test set (bootstrap)\n", + "def bootstrap_proportion_ci(values, n_boot=10000, alpha=0.05, seed=42):\n", + " rng = np.random.default_rng(seed)\n", + " values = np.asarray(values, dtype=float)\n", + " n = len(values)\n", + " if n == 0:\n", + " return (np.nan, np.nan)\n", + " boot_props = rng.choice(values, size=(n_boot, n), replace=True).mean(axis=1)\n", + " lower = np.percentile(boot_props, 100 * (alpha / 2))\n", + " upper = np.percentile(boot_props, 100 * (1 - alpha / 2))\n", + " return lower, upper\n", + "\n", + "train_end_labels = registry_df[registry_df.case_admission_id.isin(train_cids)]['END'].astype(int).values\n", + "test_end_labels = registry_df[registry_df.case_admission_id.isin(test_cids)]['END'].astype(int).values\n", + "\n", + "ci_train = bootstrap_proportion_ci(train_end_labels)\n", + "ci_test = bootstrap_proportion_ci(test_end_labels)\n", + "\n", + "# print in this format: in X patients (X%; 95% CI, X-X)\n", + "print(f'In train set: {n_events_train} patients ({n_events_train/len(train_cids)*100:.1f}%; 95% CI, {ci_train[0]*100:.1f}-{ci_train[1]*100:.1f}%) had END events.')\n", + "print(f'In test set: {n_events_test} patients ({n_events_test/len(test_cids)*100:.1f}%; 95% CI, {ci_test[0]*100:.1f}-{ci_test[1]*100:.1f}%) had END events.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "219788fa", + "metadata": {}, + "outputs": [], + "source": [ + "# get delta_to_min median (IQR), print as X (X-X)\n", + "delta_to_min_median = outcomes_df['delta_to_min'].median()\n", + "delta_to_min_q25 = outcomes_df['delta_to_min'].quantile(0.25)\n", + "delta_to_min_q75 = outcomes_df['delta_to_min'].quantile(0.75)\n", + "print(f'{delta_to_min_median:.1f} ({delta_to_min_q25:.1f}-{delta_to_min_q75:.1f})')" ] }, { "cell_type": "code", "execution_count": null, - "id": "fff1af9d4943b770", + "id": "b8aadade", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# from relative_sample_date get distribution of onsent of end\n", + "# express as within 6 hours in X patients (X%), 6 to 12 hours in X (X%), 12 to 24 hours X (X%) and 24 to 72 hours in X (X%)\n", + "outcomes_df['onset_category'] = pd.cut(outcomes_df['relative_sample_date'], bins=[-np.inf, 6, 12, 24, 72, np.inf], labels=['<6h', '6-12h', '12-24h', '24-72h', '>72h'])\n", + "onset_distribution = outcomes_df['onset_category'].value_counts().sort_index()\n", + "onset_distribution_percent = outcomes_df['onset_category'].value_counts(normalize=True).sort_index() * 100\n", + "for category in onset_distribution.index:\n", + " print(f'Onset {category}: {onset_distribution[category]} patients ({onset_distribution_percent[category]:.1f}%)')" + ] + }, + { + "cell_type": "markdown", + "id": "dbe9008e", + "metadata": {}, + "source": [ + "## Effect of END on outcome" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfe37dbe", + "metadata": {}, + "outputs": [], + "source": [ + "# from registry_df get END (1/0) and 3M Death - get univariable OR through logistic regression\n", + "import statsmodels.api as sm\n", + "\n", + "df_or = registry_df[['END', '3M Death']].copy()\n", + "df_or = df_or[df_or['3M Death'].isin(['yes', 'no'])]\n", + "df_or['death_3m'] = df_or['3M Death'].map({'yes': 1, 'no': 0})\n", + "df_or = df_or.dropna(subset=['END', 'death_3m'])\n", + "\n", + "X = sm.add_constant(df_or['END'].astype(float))\n", + "y = df_or['death_3m'].astype(int)\n", + "\n", + "model = sm.Logit(y, X).fit(disp=False)\n", + "or_end = float(np.exp(model.params['END']))\n", + "ci_low, ci_high = np.exp(model.conf_int().loc['END']).tolist()\n", + "p_value = float(model.pvalues['END'])\n", + "\n", + "n_total = len(df_or)\n", + "_ = y.sum()\n", + "print(\n", + " f\"Univariable OR (END -> 3M Death): {or_end:.2f} \"\n", + " f\"(95% CI {ci_low:.2f}-{ci_high:.2f}), p={p_value:.3g}; \"\n", + " f\"N={n_total}, events={int(_)}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7db72f43", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df['mrs0_2'] = registry_df['3M mRS'].apply(lambda x: 1 if x <= 2 else 0 if not pd.isna(x) else np.nan)\n", + "registry_df['mrs>2'] = registry_df['3M mRS'].apply(lambda x: 1 if x > 2 else 0 if not pd.isna(x) else np.nan)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d89109f9", + "metadata": {}, + "outputs": [], + "source": [ + "# OR for favorable outcome (3M mRS 0-2)\n", + "df_or_mrs = registry_df[['END', '3M mRS']].copy()\n", + "df_or_mrs = df_or_mrs.dropna(subset=['END', '3M mRS'])\n", + "df_or_mrs['mrs02'] = (df_or_mrs['3M mRS'] <= 2).astype(int)\n", + "\n", + "X_mrs = sm.add_constant(df_or_mrs['END'].astype(float))\n", + "y_mrs = df_or_mrs['mrs02'].astype(int)\n", + "\n", + "model_mrs = sm.Logit(y_mrs, X_mrs).fit(disp=False)\n", + "or_end_mrs = float(np.exp(model_mrs.params['END']))\n", + "ci_low_mrs, ci_high_mrs = np.exp(model_mrs.conf_int().loc['END']).tolist()\n", + "p_value_mrs = float(model_mrs.pvalues['END'])\n", + "\n", + "n_total_mrs = len(df_or_mrs)\n", + "_ = y_mrs.sum()\n", + "print(\n", + " f\"Univariable OR (END -> mRS 0-2): {or_end_mrs:.2f} \"\n", + " f\"(95% CI {ci_low_mrs:.2f}-{ci_high_mrs:.2f}), p={p_value_mrs:.3g}; \"\n", + " f\"N={n_total_mrs}, mRS0-2={int(_)}\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "007a564c", + "metadata": {}, + "outputs": [], + "source": [ + "# OR for unfavorable outcome (3M mRS >2)\n", + "df_or_mrs_gt2 = registry_df[['END', '3M mRS']].copy()\n", + "df_or_mrs_gt2 = df_or_mrs_gt2.dropna(subset=['END', '3M mRS'])\n", + "df_or_mrs_gt2['mrs_gt2'] = (df_or_mrs_gt2['3M mRS'] > 2).astype(int)\n", + "\n", + "X_mrs_gt2 = sm.add_constant(df_or_mrs_gt2['END'].astype(float))\n", + "y_mrs_gt2 = df_or_mrs_gt2['mrs_gt2'].astype(int)\n", + "\n", + "model_mrs_gt2 = sm.Logit(y_mrs_gt2, X_mrs_gt2).fit(disp=False)\n", + "or_end_mrs_gt2 = float(np.exp(model_mrs_gt2.params['END']))\n", + "ci_low_mrs_gt2, ci_high_mrs_gt2 = np.exp(model_mrs_gt2.conf_int().loc['END']).tolist()\n", + "p_value_mrs_gt2 = float(model_mrs_gt2.pvalues['END'])\n", + "\n", + "n_total_mrs_gt2 = len(df_or_mrs_gt2)\n", + "_ = y_mrs_gt2.sum()\n", + "print(\n", + " f\"Univariable OR (END -> mRS >2): {or_end_mrs_gt2:.2f} \"\n", + " f\"(95% CI {ci_low_mrs_gt2:.2f}-{ci_high_mrs_gt2:.2f}), p={p_value_mrs_gt2:.3g}; \"\n", + " f\"N={n_total_mrs_gt2}, mRS>2={int(_)}\"\n", + ")" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "opsum", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.8.11" } }, "nbformat": 4, From 17b2007d3f670cba47aefe4979a8a5d18275d119 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Tue, 24 Feb 2026 10:51:43 +0100 Subject: [PATCH 3/5] updated imaging missingness computation and table 1 --- .../end_population_table.ipynb | 39 ++- .../imaging_missingness.ipynb | 262 ++++++++++++++++++ 2 files changed, 300 insertions(+), 1 deletion(-) create mode 100644 meta_data/short_term_outcomes/imaging_missingness.ipynb diff --git a/meta_data/short_term_outcomes/end_population_table.ipynb b/meta_data/short_term_outcomes/end_population_table.ipynb index 289bedf..8e6df07 100644 --- a/meta_data/short_term_outcomes/end_population_table.ipynb +++ b/meta_data/short_term_outcomes/end_population_table.ipynb @@ -164,6 +164,24 @@ "registry_df = outcome_preprocessing(registry_df)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "73c1a875", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df['Etiology TOAST'].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d17a67d", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -178,7 +196,23 @@ "source": [ "registry_df['Etiology - Cardiac embolism'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x == 'Cardiac embolism' else 0)\n", "registry_df['Etiology - Large artery atherosclerosis'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x == 'Large artery atherosclerosis' else 0)\n", - "registry_df['Etiology - Small vessel disease'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x == 'Small vessel disease' else 0)" + "registry_df['Etiology - Small vessel disease'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x == 'Small vessel disease' else 0)\n", + "# etiology is Other if == 'Other determined etiology', 'PFO', 'Cervical artery dissection', 'More than one possible etiology', 'Stroke or TIA mimic'\n", + "registry_df['Etiology - Other'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x in ['Other determined etiology', 'PFO', 'Cervical artery dissection', 'More than one possible etiology', 'Stroke or TIA mimic'] else 0)\n", + "# etiology is unknown if == 'Unknown etiology with incomplete evaluation', 'Unknown etiology despite complete evaluation', or nan\n", + "registry_df['Etiology - Unknown'] = registry_df['Etiology TOAST'].apply(lambda x: 1 if x in ['Unknown etiology with incomplete evaluation', 'Unknown etiology despite complete evaluation'] or pd.isna(x) else 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7e54fbc", + "metadata": {}, + "outputs": [], + "source": [ + "# ensure that the sum of the etiology columns is 1 for each patient (i.e. each patient has only one etiology)\n", + "registry_df['Etiology_sum'] = registry_df['Etiology - Cardiac embolism'] + registry_df['Etiology - Large artery atherosclerosis'] + registry_df['Etiology - Small vessel disease'] + registry_df['Etiology - Other'] + registry_df['Etiology - Unknown']\n", + "assert registry_df['Etiology_sum'].sum() == len(registry_df), 'Error: Each patient should have only one etiology, but some patients have more than one etiology.'" ] }, { @@ -209,9 +243,12 @@ " 'MedHist Diabetes',\n", " 'MedHist Hyperlipidemia',\n", " 'MedHist Atrial Fibr.',\n", + " 'MedHist Smoking',\n", " 'Etiology - Cardiac embolism',\n", " 'Etiology - Large artery atherosclerosis',\n", " 'Etiology - Small vessel disease',\n", + " 'Etiology - Other',\n", + " 'Etiology - Unknown',\n", " '3M Death'\n", "]" ] diff --git a/meta_data/short_term_outcomes/imaging_missingness.ipynb b/meta_data/short_term_outcomes/imaging_missingness.ipynb new file mode 100644 index 0000000..83f97fd --- /dev/null +++ b/meta_data/short_term_outcomes/imaging_missingness.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b2136a32", + "metadata": {}, + "source": [ + "# Verify missingness of imaging data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ff3514d", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from preprocessing.geneva_stroke_unit_preprocessing.utils import create_registry_case_identification_column" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d45af44", + "metadata": {}, + "outputs": [], + "source": [ + "features_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/preprocessed_features_30012026_154047.csv'\n", + "registry_path = '/Users/jk1/stroke_datasets/stroke_registry_post_hoc_modified.xlsx'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bffac50", + "metadata": {}, + "outputs": [], + "source": [ + "features_df = pd.read_csv(features_path)\n", + "registry_df = pd.read_excel(registry_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "795b50c5", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df['case_admission_id'] = create_registry_case_identification_column(registry_df)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57084091", + "metadata": {}, + "outputs": [], + "source": [ + "features_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fa6da80", + "metadata": {}, + "outputs": [], + "source": [ + "features_df.sample_label.unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df668ccb", + "metadata": {}, + "outputs": [], + "source": [ + "imaging_labels = ['cbf_lt_20',\n", + " 'cbf_lt_30', 'cbf_lt_34', 'cbf_lt_38', 'cbv_lt_34', 'cbv_lt_38',\n", + " 'cbv_lt_42', 'tmax_gt_10',\n", + " 'tmax_gt_4', 'tmax_gt_6', 'tmax_gt_8']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e252a8d3", + "metadata": {}, + "outputs": [], + "source": [ + "imaging_df = features_df[features_df.sample_label.isin(imaging_labels)]\n", + "hypoperf_df = features_df[features_df.sample_label == 'hypoperfusion_with_mismatch']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd766684", + "metadata": {}, + "outputs": [], + "source": [ + "cids_with_imaging_according_to_hypoperfusion = hypoperf_df[hypoperf_df.source == 'stroke_registry'].case_admission_id.unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05b23751", + "metadata": {}, + "outputs": [], + "source": [ + "cids_with_imaging = imaging_df[imaging_df.source == 'EHR'].case_admission_id.unique()\n", + "cids_without_imaging = set(features_df.case_admission_id.unique()) - set(cids_with_imaging)\n", + "\n", + "print(f'Number of cases with imaging: {len(cids_with_imaging)}')\n", + "print(f'Number of cases without imaging: {len(cids_without_imaging)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0add517", + "metadata": {}, + "outputs": [], + "source": [ + "cids_with_missing_imaging_according_to_hypoperfusion = set(cids_with_imaging_according_to_hypoperfusion) - set(cids_with_imaging)\n", + "print(f'Number of cases with missing imaging according to hypoperfusion: {len(cids_with_missing_imaging_according_to_hypoperfusion)}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4c5c8ac", + "metadata": {}, + "outputs": [], + "source": [ + "# create a dataframe with columsn case_admission_id and imaging_missing\n", + "missing_imaging_df = pd.DataFrame({'case_admission_id': list(features_df.case_admission_id.unique())})\n", + "missing_imaging_df['imaging_missing'] = missing_imaging_df.case_admission_id.apply(lambda x: 1 if x in cids_without_imaging else 0)\n", + "missing_imaging_df['imaging_missing_according_to_registry'] = missing_imaging_df.case_admission_id.apply(lambda x: 1 if x in cids_with_missing_imaging_according_to_hypoperfusion else 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9846239", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df['year'] = registry_df['Entry date'].apply(lambda x: str(x)[0:4] if x else None).astype('Int64')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de6485b1", + "metadata": {}, + "outputs": [], + "source": [ + "missing_imaging_df = missing_imaging_df.merge(registry_df[['case_admission_id', 'year']], on='case_admission_id', how='left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7ca2c8b", + "metadata": {}, + "outputs": [], + "source": [ + "# make a histogram of the years of the cases with missing imaging (overall and according to hypoperfusion\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "\n", + "plt.figure(figsize=(12, 6))\n", + "sns.histplot(missing_imaging_df[missing_imaging_df.imaging_missing == 1]['year'], bins=range(2017, 2024), color='blue', label='Missing Imaging (EHR)', kde=False)\n", + "sns.histplot(missing_imaging_df[missing_imaging_df.imaging_missing_according_to_registry == 1]['year'], bins=range(2017, 2024), color='red', label='Missing Imaging (Hypoperfusion)', kde=False)\n", + "plt.xlabel('Year')\n", + "plt.ylabel('Count')\n", + "plt.title('Distribution of Cases with Missing Imaging by Year')\n", + "\n", + "# add legend\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cabf5d32", + "metadata": {}, + "outputs": [], + "source": [ + "missing_imaging_df[missing_imaging_df.imaging_missing == 1].year.value_counts().sort_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49750620", + "metadata": {}, + "outputs": [], + "source": [ + "registry_df[registry_df.case_admission_id.isin(cids_with_missing_imaging_according_to_hypoperfusion)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae7c218a", + "metadata": {}, + "outputs": [], + "source": [ + "extraction_target_df = registry_df[registry_df.case_admission_id.isin(cids_with_missing_imaging_according_to_hypoperfusion)]\n", + "extraction_target_df['patient_id'] = extraction_target_df['Case ID'].apply(lambda x: x[8:-4]).astype(str)\n", + "# columns Case ID, case_admissions_id, patient_id\n", + "extraction_target_df = extraction_target_df[['Case ID', 'case_admission_id', 'patient_id']]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27609b7a", + "metadata": {}, + "outputs": [], + "source": [ + "# extraction_target_df.to_csv('/Users/jk1/Downloads/refined_extraction_target_20022026.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a0961fc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "opsum", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 96c8d2bbca3fb44c3de48e4daddcb91e03dbedd3 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Sun, 1 Mar 2026 08:28:05 +0100 Subject: [PATCH 4/5] rerun shap figures --- .../identification_of_true_positives.ipynb | 307 ++++++++++++++++++ .../exploration/top_predictors_selection.py | 74 +++++ .../figures/single_subj_inference.ipynb | 10 +- .../figures/xgb_top_shap_values.ipynb | 6 +- .../hyperopt/xgb_gridsearch_evaluation.ipynb | 33 +- 5 files changed, 416 insertions(+), 14 deletions(-) create mode 100644 meta_data/short_term_outcomes/identification_of_true_positives.ipynb create mode 100644 prediction/short_term_outcome_prediction/figures/exploration/top_predictors_selection.py diff --git a/meta_data/short_term_outcomes/identification_of_true_positives.ipynb b/meta_data/short_term_outcomes/identification_of_true_positives.ipynb new file mode 100644 index 0000000..31f3944 --- /dev/null +++ b/meta_data/short_term_outcomes/identification_of_true_positives.ipynb @@ -0,0 +1,307 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "01c8fdbb", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import pickle\n", + "import torch as ch\n", + "from preprocessing.geneva_stroke_unit_preprocessing.utils import create_ehr_case_identification_column" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ca15e7e", + "metadata": {}, + "outputs": [], + "source": [ + "predictions_path = '/Users/jk1/temp/opsum_end/testing/with_imaging/xgb_test_results/test_predictions.pkl'\n", + "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "eds_data = '/Users/jk1/stroke_datasets/Extraction_20220815/eds_j1.csv'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00eb5aab", + "metadata": {}, + "outputs": [], + "source": [ + "n_timesteps = 72\n", + "threshold = 0.239" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65e49be4", + "metadata": {}, + "outputs": [], + "source": [ + "with open(predictions_path, 'rb') as f:\n", + " predictions = pickle.load(f)\n", + "y_test, y_prob = predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b6f882e", + "metadata": {}, + "outputs": [], + "source": [ + "X_test_raw, y_test_raw = ch.load(test_data_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4caa0ea6", + "metadata": {}, + "outputs": [], + "source": [ + "eds_df = pd.read_csv(eds_data, delimiter=';', encoding='utf-8',\n", + " dtype=str)\n", + "eds_df['case_admission_id'] = create_ehr_case_identification_column(eds_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8795f6b3", + "metadata": {}, + "outputs": [], + "source": [ + "cids = X_test_raw[:,0,0,0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98b5e94d", + "metadata": {}, + "outputs": [], + "source": [ + "y_prob_matrix = y_prob.reshape(-1, n_timesteps)\n", + "y_test_matrix = y_test.reshape(-1, n_timesteps)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a8046a7", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred_matrix = (y_prob_matrix >= threshold).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa7565f6", + "metadata": {}, + "outputs": [], + "source": [ + "true_positives = ((y_test_matrix == 1) & (y_pred_matrix == 1))\n", + "false_positives = ((y_test_matrix == 0) & (y_pred_matrix == 1))\n", + "true_negatives = ((y_test_matrix == 0) & (y_pred_matrix == 0))\n", + "false_negatives = ((y_test_matrix == 1) & (y_pred_matrix == 0))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfe3eb02", + "metadata": {}, + "outputs": [], + "source": [ + "n_true_positives_per_patient = true_positives.sum(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44a18041", + "metadata": {}, + "outputs": [], + "source": [ + "true_positives_per_patient_df = pd.DataFrame({\n", + " 'case_admission_id': cids,\n", + " 'n_true_positives': n_true_positives_per_patient\n", + "})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca5ea996", + "metadata": {}, + "outputs": [], + "source": [ + "cids_with_true_positives = true_positives_per_patient_df[true_positives_per_patient_df['n_true_positives'] > 0]['case_admission_id'].tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43b58644", + "metadata": {}, + "outputs": [], + "source": [ + "len(cids_with_true_positives)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f970269", + "metadata": {}, + "outputs": [], + "source": [ + "# ensure that cids_with_true_positives is a subset of the cids in the y_test_raw set\n", + "set(cids_with_true_positives).issubset(set(y_test_raw.case_admission_id))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f30bdcbc", + "metadata": {}, + "outputs": [], + "source": [ + "true_positive_df = y_test_raw[y_test_raw['case_admission_id'].isin(cids_with_true_positives)].copy()\n", + "true_positive_df = true_positive_df[['case_admission_id', 'patient_id', 'sample_date', 'relative_sample_date', 'value', 'min_nihss', 'delta_to_min',]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56e64d7c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89b475d8", + "metadata": {}, + "outputs": [], + "source": [ + "true_positive_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c05817", + "metadata": {}, + "outputs": [], + "source": [ + "eds_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "843ee230", + "metadata": {}, + "outputs": [], + "source": [ + "true_positive_df = true_positive_df.merge(eds_df[['case_admission_id', 'DOB', 'eds_final_id']], on='case_admission_id', how='left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "420c05ed", + "metadata": {}, + "outputs": [], + "source": [ + "true_positive_df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ce55920b", + "metadata": {}, + "outputs": [], + "source": [ + "true_positive_df = true_positive_df[['case_admission_id', 'eds_final_id', 'patient_id', 'DOB', 'sample_date',\n", + " 'relative_sample_date', 'value', 'min_nihss', 'delta_to_min',]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "530b4216", + "metadata": {}, + "outputs": [], + "source": [ + "# rename eds_final_id to EDS\n", + "true_positive_df.rename(columns={'eds_final_id': 'EDS'}, inplace=True)\n", + "# rename sample_date to END_date\n", + "true_positive_df.rename(columns={'sample_date': 'END_date'}, inplace=True)\n", + "# rename value to END_value\n", + "true_positive_df.rename(columns={'value': 'END_value'}, inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e582869", + "metadata": {}, + "outputs": [], + "source": [ + "# sort by case_admission_id \n", + "true_positive_df.sort_values(by='case_admission_id', inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21d9a452", + "metadata": {}, + "outputs": [], + "source": [ + "# true_positive_df.to_csv('/Users/jk1/Downloads/end_true_positives_test_set_25022026.csv', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "542b99cd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "opsum", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/prediction/short_term_outcome_prediction/figures/exploration/top_predictors_selection.py b/prediction/short_term_outcome_prediction/figures/exploration/top_predictors_selection.py new file mode 100644 index 0000000..9841495 --- /dev/null +++ b/prediction/short_term_outcome_prediction/figures/exploration/top_predictors_selection.py @@ -0,0 +1,74 @@ +import pandas as pd +import pickle +import numpy as np +import torch as ch +import os +import seaborn as sns +import matplotlib.pyplot as plt +from prediction.utils.visualisation_helper_functions import hex_to_rgb_color, create_palette +from colormath.color_objects import LabColor + +shap_values_path = '/Users/jk1/temp/opsum_end/testing/with_imaging/xgb_test_results/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl' +test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth' +cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/with_imaging/gsu_Extraction_20220815_prepro_30012026_154047/logs_30012026_154047/categorical_variable_encoding.csv' + +# load the shap values +with open(os.path.join(shap_values_path), 'rb') as handle: + original_shap_values = pickle.load(handle) + +shap_values = [np.array([original_shap_values[i] for i in range(len(original_shap_values))]).swapaxes(0, 1)][0] + +X_test, y_test= ch.load(test_data_path) + +features = X_test[0, 0, :, 2] + +# Toggle these to match the model that produced the SHAP values +add_lag_features = True +add_rolling_features = True + +# Build aggregated feature names matching aggregate_features_over_time output order: +# [features, avg_features, min_features, max_features, std_features, diff_features, timestep_feature] [lag2, lag3] [roll_mean, roll_std, roll_trend] +# features, avg_, min_, max_, std_, diff_, timestep_idx, [lag2_, lag3_], [rolling_mean_, rolling_std_, rolling_trend_] +aggregated_feature_names = list(features) +for prefix in ['avg_', 'min_', 'max_', 'std_', 'diff_']: + aggregated_feature_names += [f'{prefix}{f}' for f in features] +aggregated_feature_names += ['timestep_idx'] + +if add_lag_features: + for prefix in ['lag2_', 'lag3_']: + aggregated_feature_names += [f'{prefix}{f}' for f in features] + +if add_rolling_features: + for prefix in ['rolling_mean_', 'rolling_std_', 'rolling_trend_']: + aggregated_feature_names += [f'{prefix}{f}' for f in features] + +aggregated_feature_names += ['base_value'] +print(f'{len(aggregated_feature_names)} feature names (including base_value), SHAP columns: {shap_values.shape[2]}') + +sum_over_all_shap_values = np.abs(shap_values).sum(axis=(0,1)) + + +temp_df = pd.DataFrame({'feature': aggregated_feature_names, 'shap_value': sum_over_all_shap_values}) +# remove timestep_idx and base_value from the features +temp_df = temp_df[~temp_df.feature.isin(['timestep_idx', 'base_value'])] +# remove avg_, min_, max_, std_, diff_, timestep_idx, [lag2_, lag3_], [rolling_mean_, rolling_std_, rolling_trend_] from the feature names to get the original feature names +prefixes = ['rolling_mean_', 'rolling_std_', 'rolling_trend_', 'avg_', 'min_', 'max_', 'std_', 'diff_', 'lag2_', 'lag3_',] +for prefix in prefixes: + temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, '')) +hourly_pool_prefixes = ['median_', 'min_', 'max_'] +for prefix in hourly_pool_prefixes: + temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, '')) +blood_pressure_prefixes = ['systolic_', 'diastolic_', 'mean_'] +for prefix in blood_pressure_prefixes: + temp_df.loc[temp_df.feature.str.contains(prefix), 'feature'] = temp_df[temp_df.feature.str.contains(prefix)].feature.apply(lambda x: x.replace(prefix, '')) + +# transform to absolute shap values +temp_df['absolute_shap_value'] = np.abs(temp_df['shap_value']) +# drop shap value +temp_df = temp_df.drop(columns=['shap_value']) +# sum the shap values for the same original feature names +temp_df = temp_df.groupby('feature').sum().reset_index() +temp_df.sort_values(by='absolute_shap_value', ascending=False).head(10) +top_10_features_by_mean_abs_summed_shap = temp_df.sort_values(by='absolute_shap_value', ascending=False).head(10).feature.values + +print(f'Top 10 features by mean absolute summed SHAP values: {top_10_features_by_mean_abs_summed_shap}') \ No newline at end of file diff --git a/prediction/short_term_outcome_prediction/figures/single_subj_inference.ipynb b/prediction/short_term_outcome_prediction/figures/single_subj_inference.ipynb index a6c31e4..8b9aa1b 100644 --- a/prediction/short_term_outcome_prediction/figures/single_subj_inference.ipynb +++ b/prediction/short_term_outcome_prediction/figures/single_subj_inference.ipynb @@ -36,12 +36,12 @@ "metadata": {}, "outputs": [], "source": [ - "shap_values_path = '/Users/jk1/temp/opsum_end/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'\n", - "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", - "cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/categorical_variable_encoding.csv'\n", + "shap_values_path = '/Users/jk1/temp/opsum_end/training/without_imaging/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'\n", + "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/categorical_variable_encoding.csv'\n", "\n", - "normalisation_parameters_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/normalisation_parameters.csv'\n", - "predictions_path = '/Users/jk1/temp/opsum_end/testing/test_gt_and_pred_cv_1.pkl'" + "normalisation_parameters_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/normalisation_parameters.csv'\n", + "predictions_path = '/Users/jk1/temp/opsum_end/testing/without_imaging/test_gt_and_pred_cv_1.pkl'" ] }, { diff --git a/prediction/short_term_outcome_prediction/figures/xgb_top_shap_values.ipynb b/prediction/short_term_outcome_prediction/figures/xgb_top_shap_values.ipynb index d7e074e..d2f33a6 100644 --- a/prediction/short_term_outcome_prediction/figures/xgb_top_shap_values.ipynb +++ b/prediction/short_term_outcome_prediction/figures/xgb_top_shap_values.ipynb @@ -34,9 +34,9 @@ "metadata": {}, "outputs": [], "source": [ - "shap_values_path = '/Users/jk1/temp/opsum_end/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'\n", - "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", - "cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/categorical_variable_encoding.csv'" + "shap_values_path = '/Users/jk1/temp/opsum_end/training/without_imaging/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'\n", + "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/categorical_variable_encoding.csv'\n" ] }, { diff --git a/prediction/short_term_outcome_prediction/hyperopt/xgb_gridsearch_evaluation.ipynb b/prediction/short_term_outcome_prediction/hyperopt/xgb_gridsearch_evaluation.ipynb index 0ad1ca7..f87e020 100644 --- a/prediction/short_term_outcome_prediction/hyperopt/xgb_gridsearch_evaluation.ipynb +++ b/prediction/short_term_outcome_prediction/hyperopt/xgb_gridsearch_evaluation.ipynb @@ -30,8 +30,8 @@ }, "outputs": [], "source": [ - "log_folder_path = '/Users/jk1/temp/opsum_end/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517'\n", - "output_dir = '/Users/jk1/temp/opsum_end/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517'" + "log_folder_path = '/Users/jk1/temp/opsum_end/training/with_imaging/xgb_hyperopt'\n", + "output_dir = '/Users/jk1/temp/opsum_end/training/with_imaging/xgb_hyperopt'" ] }, { @@ -84,7 +84,7 @@ }, "outputs": [], "source": [ - "best_df = gs_df.sort_values('median_val_auc', ascending=False).head(1)\n", + "best_df = gs_df.sort_values('median_val_auprc', ascending=False).head(1)\n", "best_df" ] }, @@ -112,7 +112,7 @@ "outputs": [], "source": [ "# plot histogram of median_val_scores for all split_files\n", - "ax = sns.histplot(x=\"median_val_scores\", data=gs_df, hue=\"split_file\")\n", + "ax = sns.histplot(x=\"median_val_auc\", data=gs_df, hue=\"split_file\")\n", "\n", "labels = ['All events / With interval', 'First event / No interval']\n", "# set legend labels\n", @@ -146,7 +146,7 @@ }, "outputs": [], "source": [ - "full_results_dir = '/Users/jk1/temp/opsum_end/training/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517'" + "full_results_dir = '/Users/jk1/temp/opsum_end/training/with_imaging/xgb_hyperopt'" ] }, { @@ -378,7 +378,28 @@ "id": "812d89a4", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "config_used = {\n", + " \"n_trials\": 1000,\n", + " \"target_interval\": 1,\n", + " \"restrict_to_first_event\": 0,\n", + " \"max_depth\": [2, 8, 10, 12],\n", + " \"n_estimators\": [1000, 2000, 4000],\n", + " \"learning_rate\": [0.02, 0.1],\n", + " \"reg_lambda\": [1, 10, 75],\n", + " \"alpha\": [5, 10, 15],\n", + " \"early_stopping_rounds\": [50, 100, 150],\n", + " \"scale_pos_weight\": [25, 55, 10],\n", + " \"min_child_weight\": [1, 3],\n", + " \"subsample\": [0.5, 0.8, 1],\n", + " \"colsample_bytree\": [0.8],\n", + " \"colsample_bylevel\": [1],\n", + " \"booster\": [\"dart\"],\n", + " \"grow_policy\": [\"depthwise\", \"lossguide\"],\n", + " \"num_boost_round\": [200, 300, 500],\n", + " \"gamma\": [0.24, 0.5, 0.75]\n", + "}" + ] } ], "metadata": { From 0c09518456cec0aed5cab2fd825cb25c13f283f9 Mon Sep 17 00:00:00 2001 From: Julian Klug Date: Wed, 4 Mar 2026 08:31:12 +0100 Subject: [PATCH 5/5] refactored inference plotting into a reusable function --- .../figures/inference_plotting.py | 994 ++++++++++++++++++ .../figures/test_inference_plotting.ipynb | 105 ++ 2 files changed, 1099 insertions(+) create mode 100644 prediction/short_term_outcome_prediction/figures/inference_plotting.py create mode 100644 prediction/short_term_outcome_prediction/figures/test_inference_plotting.ipynb diff --git a/prediction/short_term_outcome_prediction/figures/inference_plotting.py b/prediction/short_term_outcome_prediction/figures/inference_plotting.py new file mode 100644 index 0000000..a4422e0 --- /dev/null +++ b/prediction/short_term_outcome_prediction/figures/inference_plotting.py @@ -0,0 +1,994 @@ +import argparse +import os +import pickle + +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +import torch as ch +from tqdm import tqdm +from matplotlib.lines import Line2D +from matplotlib.legend_handler import HandlerTuple + +from prediction.utils.utils import filter_consecutive_numbers, smooth +from prediction.utils.visualisation_helper_functions import ( + LegendTitle, + reverse_normalisation_for_subj, +) + + +def load_inference_raw_inputs( + shap_values_path, + test_data_path, + cat_encoding_path, + normalisation_parameters_path, + predictions_path, + n_time_steps=72, + only_last_timestep=False, +): + with open(os.path.join(shap_values_path), "rb") as handle: + original_shap_values = pickle.load(handle) + + if only_last_timestep: + shap_values = [original_shap_values[-1]] + else: + shap_values = [ + np.array( + [original_shap_values[i] for i in range(len(original_shap_values))] + ).swapaxes(0, 1) + ][0] + + normalisation_parameters_df = pd.read_csv(normalisation_parameters_path) + + with open(predictions_path, "rb") as handle: + gt_over_time, predictions_over_time = pickle.load(handle) + + gt_over_time = gt_over_time.reshape(-1, n_time_steps) + predictions_over_time = predictions_over_time.reshape(-1, n_time_steps) + + X_test, _ = ch.load(test_data_path) + test_X_np = X_test[:, :, :, -1].astype("float32") + + features = X_test[0, 0, :, 2] + avg_features = [f"avg_{i}" for i in features] + min_features = [f"min_{i}" for i in features] + max_features = [f"max_{i}" for i in features] + aggregated_feature_names = ( + features.tolist() + avg_features + min_features + max_features + ["base_value"] + ) + + return { + "shap_values": shap_values, + "X_test": X_test, + "test_X_np": test_X_np, + "features": features, + "aggregated_feature_names": aggregated_feature_names, + "gt_over_time": gt_over_time, + "predictions_over_time": predictions_over_time, + "normalisation_parameters_df": normalisation_parameters_df, + "cat_encoding_path": cat_encoding_path, + } + + +def build_inference_plot_inputs( + raw_inputs, + n_time_steps=72, + reverse_categorical_encoding=True, + pool_hourly_split_values=True, + only_keep_current_value_shap=True, + shap_aggregation_func="sum", + use_simplified_shap_values=True, + smoothing_window=15, + feature_to_english_name_correspondence_path=None, +): + shap_values = raw_inputs["shap_values"] + features = raw_inputs["features"] + aggregated_feature_names = raw_inputs["aggregated_feature_names"] + test_X_np = raw_inputs["test_X_np"] + normalisation_parameters_df = raw_inputs["normalisation_parameters_df"] + + shap_values_df = pd.DataFrame() + for ts in tqdm(range(n_time_steps), desc="Build SHAP table"): + ts_shap_values_df = pd.DataFrame( + data=shap_values[:, ts], columns=np.array(aggregated_feature_names) + ) + ts_shap_values_df = ts_shap_values_df.reset_index() + ts_shap_values_df.rename(columns={"index": "case_admission_id_idx"}, inplace=True) + ts_shap_values_df = ts_shap_values_df.melt( + id_vars="case_admission_id_idx", var_name="feature", value_name="shap_value" + ) + ts_shap_values_df["time_step"] = ts + shap_values_df = pd.concat((shap_values_df, ts_shap_values_df), ignore_index=True) + + if only_keep_current_value_shap: + shap_values_df = shap_values_df[shap_values_df["feature"].isin(features)] + + feature_values_df = pd.DataFrame() + for subj_idx in tqdm(range(test_X_np.shape[0]), desc="Build feature table"): + subj_feature_values_df = pd.DataFrame( + data=test_X_np[subj_idx, :, :], columns=np.array(features) + ) + subj_feature_values_df = reverse_normalisation_for_subj( + subj_feature_values_df, normalisation_parameters_df + ) + subj_feature_values_df = subj_feature_values_df.reset_index() + subj_feature_values_df.rename(columns={"index": "time_step"}, inplace=True) + subj_feature_values_df["case_admission_id_idx"] = subj_idx + subj_feature_values_df = subj_feature_values_df.melt( + id_vars=["case_admission_id_idx", "time_step"], + var_name="feature", + value_name="feature_value", + ) + feature_values_df = pd.concat( + (feature_values_df, subj_feature_values_df), ignore_index=True + ) + + if reverse_categorical_encoding: + cat_encoding_df = pd.read_csv(raw_inputs["cat_encoding_path"]) + for i in tqdm(range(len(cat_encoding_df)), desc="Decode categoricals"): + cat_basename = cat_encoding_df.sample_label[i].lower().replace(" ", "_") + cat_item_list = ( + cat_encoding_df.other_categories[i] + .replace("[", "") + .replace("]", "") + .replace("'", "") + .split(", ") + ) + cat_item_list = [ + cat_basename + "_" + item.replace(" ", "_").lower() + for item in cat_item_list + ] + for cat_item_idx, cat_item in enumerate(cat_item_list): + feature_values_df.loc[ + feature_values_df.feature == cat_item, "feature_value" + ] *= cat_item_idx + 1 + feature_values_df.loc[ + feature_values_df.feature == cat_item, "feature" + ] = cat_encoding_df.sample_label[i] + feature_values_df = ( + feature_values_df.groupby( + ["case_admission_id_idx", "feature", "time_step"] + ) + .sum() + .reset_index() + ) + + shap_values_df.loc[ + shap_values_df.feature == cat_item, "feature" + ] = cat_encoding_df.sample_label[i] + if shap_aggregation_func == "sum": + shap_values_df = ( + shap_values_df.groupby( + ["case_admission_id_idx", "feature", "time_step"] + ) + .sum() + .reset_index() + ) + else: + shap_values_df = ( + shap_values_df.groupby( + ["case_admission_id_idx", "feature", "time_step"] + ) + .median() + .reset_index() + ) + + cat_to_numerical_encoding = { + "Prestroke disability (Rankin)": {0: 0, 1: 3, 2: 4, 3: 2, 4: 1, 5: 5}, + "categorical_onset_to_admission_time": {0: 2, 1: 1, 2: 0, 3: 3, 4: 5, 5: 4}, + "categorical_IVT": {0: 2, 1: 3, 2: 4, 3: 1, 4: 0}, + "categorical_IAT": {0: 1, 1: 2, 2: 3, 3: 0}, + } + for cat_feature, cat_encoding in cat_to_numerical_encoding.items(): + mask = feature_values_df.feature == cat_feature + feature_values_df.loc[mask, "feature_value"] = feature_values_df.loc[ + mask, "feature_value" + ].map(cat_encoding) + + if pool_hourly_split_values: + hourly_split_features = [ + "NIHSS", + "systolic_blood_pressure", + "diastolic_blood_pressure", + "mean_blood_pressure", + "heart_rate", + "respiratory_rate", + "temperature", + "oxygen_saturation", + ] + for feature in tqdm(hourly_split_features, desc="Pool hourly feature splits"): + shap_values_df.loc[shap_values_df.feature.str.contains(feature), "feature"] = ( + feature[0].upper() + feature[1:] + ).replace("_", " ") + if shap_aggregation_func == "median": + shap_values_df = ( + shap_values_df.groupby( + ["case_admission_id_idx", "feature", "time_step"] + ) + .median() + .reset_index() + ) + else: + shap_values_df = ( + shap_values_df.groupby( + ["case_admission_id_idx", "feature", "time_step"] + ) + .sum() + .reset_index() + ) + + feature_values_df.loc[ + feature_values_df.feature.str.contains(feature), "feature" + ] = (feature[0].upper() + feature[1:]).replace("_", " ") + feature_values_df = ( + feature_values_df.groupby(["case_admission_id_idx", "feature", "time_step"]) + .median() + .reset_index() + ) + + if feature_to_english_name_correspondence_path: + correspondence = pd.read_excel(feature_to_english_name_correspondence_path) + for feature in shap_values_df.feature.unique(): + if feature in correspondence.feature_name.values: + shap_values_df.loc[shap_values_df.feature == feature, "feature"] = correspondence[ + correspondence.feature_name == feature + ].english_name.values[0] + for feature in feature_values_df.feature.unique(): + if feature in correspondence.feature_name.values: + feature_values_df.loc[ + feature_values_df.feature == feature, "feature" + ] = correspondence[correspondence.feature_name == feature].english_name.values[0] + + if use_simplified_shap_values: + shap_values_over_time = [] + for ts in tqdm(range(n_time_steps), desc="Create simplified SHAP tensor"): + subj_values_over_time = [] + for subj in range(len(test_X_np)): + values = shap_values_df[ + (shap_values_df.case_admission_id_idx == subj) + & (shap_values_df.time_step == ts) + ].shap_value.values + subj_values_over_time.append(values) + shap_values_over_time.append(np.array(subj_values_over_time)) + shap_values_over_time = np.array(shap_values_over_time) + else: + shap_values_over_time = np.moveaxis(shap_values, 1, 0) + + reduced_feature_names = shap_values_df.feature.unique() + + smoothed_shap_values_over_time = [] + for subj_idx in range(shap_values_over_time.shape[1]): + subj_smoothed = [] + for feature_idx in range(shap_values_over_time.shape[2]): + subj_smoothed.append( + smooth(shap_values_over_time[:, subj_idx, feature_idx], smoothing_window) + ) + smoothed_shap_values_over_time.append(np.moveaxis(np.array(subj_smoothed), 0, -1)) + smoothed_shap_values_over_time = np.moveaxis( + np.array(smoothed_shap_values_over_time), 0, 1 + ) + + return { + "predictions_over_time": raw_inputs["predictions_over_time"], + "gt_over_time": raw_inputs["gt_over_time"], + "feature_values_df": feature_values_df, + "smoothed_shap_values_over_time": smoothed_shap_values_over_time, + "shap_values_over_time": shap_values_over_time, + "reduced_feature_names": reduced_feature_names, + "raw_inputs": raw_inputs, + "shap_values_df": shap_values_df, + } + + +def load_preprocess_and_plot_subjects( + subjects, + shap_values_path, + test_data_path, + cat_encoding_path, + normalisation_parameters_path, + predictions_path, + n_time_steps=72, + only_last_timestep=False, + reverse_categorical_encoding=True, + pool_hourly_split_values=True, + only_keep_current_value_shap=True, + shap_aggregation_func="sum", + use_simplified_shap_values=True, + smoothing_window=15, + feature_to_english_name_correspondence_path=None, + plot_kwargs=None, +): + if plot_kwargs is None: + plot_kwargs = {} + + raw_inputs = load_inference_raw_inputs( + shap_values_path=shap_values_path, + test_data_path=test_data_path, + cat_encoding_path=cat_encoding_path, + normalisation_parameters_path=normalisation_parameters_path, + predictions_path=predictions_path, + n_time_steps=n_time_steps, + only_last_timestep=only_last_timestep, + ) + + prepared = build_inference_plot_inputs( + raw_inputs=raw_inputs, + n_time_steps=n_time_steps, + reverse_categorical_encoding=reverse_categorical_encoding, + pool_hourly_split_values=pool_hourly_split_values, + only_keep_current_value_shap=only_keep_current_value_shap, + shap_aggregation_func=shap_aggregation_func, + use_simplified_shap_values=use_simplified_shap_values, + smoothing_window=smoothing_window, + feature_to_english_name_correspondence_path=feature_to_english_name_correspondence_path, + ) + + figures_by_subject = {} + for subj in subjects: + fig = plot_joint_subject_explanation( + subj=subj, + predictions_over_time=prepared["predictions_over_time"], + gt_over_time=prepared["gt_over_time"], + feature_values_df=prepared["feature_values_df"], + smoothed_shap_values_over_time=prepared["smoothed_shap_values_over_time"], + shap_values_over_time=prepared["shap_values_over_time"], + reduced_feature_names=prepared["reduced_feature_names"], + use_simplified_shap_values=use_simplified_shap_values, + **plot_kwargs, + ) + figures_by_subject[subj] = fig + + return { + "figures_by_subject": figures_by_subject, + "prepared_inputs": prepared, + } + + +def plot_joint_subject_explanation( + subj, + predictions_over_time, + gt_over_time, + feature_values_df, + smoothed_shap_values_over_time, + shap_values_over_time, + reduced_feature_names, + use_simplified_shap_values=True, + threshold=0.04, + n_features_selection=0, + n_features=1, + k=0.25, + alpha=0.3, + only_non_static_features=True, + use_smoothed_shap_values=True, + plot_ground_truth=True, + display_significant_slopes=True, + n_slope_steps=5, + slope_threshold_multiplier=1.5, + skip_label_at_zero=True, + display_text_labels=True, + display_legend=True, + display_title=False, + plot_NIHSS_continuously=True, + ts_marker_level="shap", + tick_label_size=13, + label_font_size=16, +): + subj_pred_over_ts = predictions_over_time[subj] + subj_gt_over_ts = gt_over_time[subj] + n_time_steps = len(subj_pred_over_ts) + + fig_joint, (ax_main, ax_features) = plt.subplots( + nrows=2, + ncols=1, + figsize=(15, 12), + gridspec_kw=dict(height_ratios=[2, 1], hspace=0.3), + ) + + if use_smoothed_shap_values: + working_shap_values = smoothed_shap_values_over_time + else: + working_shap_values = shap_values_over_time + + significant_positive_timesteps = filter_consecutive_numbers( + np.where(np.diff(subj_pred_over_ts) > threshold)[0] + ) + significant_negative_timesteps = filter_consecutive_numbers( + np.where(np.diff(subj_pred_over_ts) < -threshold)[0] + ) + significant_timesteps = np.concatenate( + (significant_positive_timesteps, significant_negative_timesteps) + ) + + non_norm_subj_df = ( + feature_values_df[feature_values_df.case_admission_id_idx == subj] + .drop(columns=["case_admission_id_idx"]) + .pivot(index="time_step", columns="feature", values="feature_value") + ) + + if only_non_static_features: + non_static_features = np.where(non_norm_subj_df.std() > 0.01)[0] + if use_simplified_shap_values: + non_static_features = np.where( + np.isin( + reduced_feature_names, + np.array(non_norm_subj_df.std()[non_norm_subj_df.std() > 0.01].index), + ) + )[0] + selected_positive_features_by_impact = np.diff( + working_shap_values[:, subj, non_static_features], axis=0 + )[significant_positive_timesteps].argmax(axis=1) + selected_positive_features_by_impact = non_static_features[ + selected_positive_features_by_impact + ] + selected_negative_features_by_impact = np.diff( + working_shap_values[:, subj, non_static_features], axis=0 + )[significant_negative_timesteps].argmin(axis=1) + selected_negative_features_by_impact = non_static_features[ + selected_negative_features_by_impact + ] + else: + non_static_features = np.arange(working_shap_values.shape[-1]) + selected_positive_features_by_impact = np.diff( + working_shap_values[:, subj], axis=0 + )[significant_positive_timesteps].argmax(axis=1) + selected_negative_features_by_impact = np.diff( + working_shap_values[:, subj], axis=0 + )[significant_negative_timesteps].argmin(axis=1) + + selected_features_by_impact = np.concatenate( + (selected_positive_features_by_impact, selected_negative_features_by_impact) + ) + + if display_significant_slopes: + slope_threshold = slope_threshold_multiplier * threshold + significant_positive_slope = filter_consecutive_numbers( + set( + np.where( + ( + np.concatenate( + ( + subj_pred_over_ts[n_slope_steps:], + np.zeros(n_slope_steps), + ) + ) + - subj_pred_over_ts + )[:-n_slope_steps] + > slope_threshold + )[0] + ).difference(set(significant_positive_timesteps)) + ) + + significant_negative_slope = filter_consecutive_numbers( + set( + np.where( + ( + np.concatenate( + ( + subj_pred_over_ts[n_slope_steps:], + np.zeros(n_slope_steps), + ) + ) + - subj_pred_over_ts + )[:-n_slope_steps] + < -slope_threshold + )[0] + ).difference(set(significant_negative_timesteps)) + ) + + delta_shap_by_features = np.concatenate( + ( + working_shap_values[n_slope_steps:, subj, non_static_features], + np.zeros((n_slope_steps, len(non_static_features))), + ) + ) - working_shap_values[:, subj, non_static_features] + + selected_positive_features_by_slope = delta_shap_by_features[:-n_slope_steps][ + significant_positive_slope + ].argmax(axis=1) + selected_positive_features_by_slope = non_static_features[ + selected_positive_features_by_slope + ] + selected_negative_features_by_slope = delta_shap_by_features[:-n_slope_steps][ + significant_negative_slope + ].argmin(axis=1) + selected_negative_features_by_slope = non_static_features[ + selected_negative_features_by_slope + ] + + selected_features_by_impact = np.concatenate( + ( + selected_features_by_impact, + selected_positive_features_by_slope, + selected_negative_features_by_slope, + ) + ) + significant_timesteps = np.concatenate( + ( + significant_timesteps, + significant_positive_slope, + significant_negative_slope, + ) + ) + selected_positive_features_by_impact = np.concatenate( + ( + selected_positive_features_by_impact, + selected_positive_features_by_slope, + ) + ) + selected_negative_features_by_impact = np.concatenate( + ( + selected_negative_features_by_impact, + selected_negative_features_by_slope, + ) + ) + + if n_features_selection == 0: + selected_positive_features = np.array([]) + selected_negative_features = np.array([]) + else: + selected_positive_features = working_shap_values[-1, subj].argsort()[-n_features:][ + ::-1 + ] + selected_negative_features = working_shap_values[-1, subj].argsort()[:n_features][ + ::-1 + ] + + selected_features = np.concatenate( + ( + selected_positive_features, + selected_positive_features_by_impact, + selected_negative_features, + selected_negative_features_by_impact, + ) + ).astype(int) + + positive_color_palette = sns.color_palette( + "mako", + n_colors=len( + set(np.concatenate((selected_positive_features, selected_positive_features_by_impact))) + ), + ) + negative_color_palette = sns.color_palette( + "flare_r", + n_colors=len( + set(np.concatenate((selected_negative_features, selected_negative_features_by_impact))) + ), + ) + + timestep_axis = np.array(range(n_time_steps)) + sns.lineplot( + x=timestep_axis, + y=subj_pred_over_ts, + label="Predicted probability", + linewidth=2, + ax=ax_main, + ) + + if plot_ground_truth: + changes_in_gt = np.diff(subj_gt_over_ts, prepend=0) + change_pairs = list(zip(np.where(changes_in_gt == 1)[0], np.where(changes_in_gt == -1)[0])) + for change_pair in change_pairs: + ax_main.plot([change_pair[0], change_pair[1]], [0, 0], color="#7b002c", linewidth=10, alpha=0.8) + ax_main.text( + np.mean(change_pair), + 0 + 0.02, + "6h to END", + horizontalalignment="center", + verticalalignment="center", + fontsize=tick_label_size, + ) + + pos_baseline = subj_pred_over_ts + neg_baseline = subj_pred_over_ts + pos_count = 0 + neg_count = 0 + feature_color_dict = {} + + for feature in set(selected_features): + subj_feature_shap_value_over_time = working_shap_values[:, subj, feature] + positive_portion = subj_feature_shap_value_over_time > 0 + negative_portion = subj_feature_shap_value_over_time < 0 + + pos_function = subj_feature_shap_value_over_time.copy() + neg_function = subj_feature_shap_value_over_time.copy() + pos_function[negative_portion] = 0 + neg_function[positive_portion] = 0 + + if feature in selected_features_by_impact: + important_ts_idx = np.where(selected_features_by_impact == feature)[0] + if not np.logical_and( + plot_NIHSS_continuously, reduced_feature_names[feature] == "NIHSS" + ): + pos_function[: significant_timesteps[important_ts_idx][0] + 1] = 0 + neg_function[: significant_timesteps[important_ts_idx][0] + 1] = 0 + + if feature in selected_positive_features: + feature_color = positive_color_palette[pos_count] + pos_count += 1 + elif feature in selected_negative_features: + feature_color = negative_color_palette[neg_count] + neg_count += 1 + elif feature in selected_negative_features_by_impact: + feature_color = negative_color_palette[neg_count] + neg_count += 1 + elif feature in selected_positive_features_by_impact: + feature_color = positive_color_palette[pos_count] + pos_count += 1 + else: + feature_color = "grey" + feature_color_dict[feature] = feature_color + + if np.any(pos_function): + positive_feature = pos_baseline + k * pos_function + ax_main.fill_between( + timestep_axis, pos_baseline, positive_feature, color=feature_color, alpha=alpha + ) + pos_baseline = positive_feature + + if np.any(neg_function): + negative_feature = neg_baseline + k * neg_function + ax_main.fill_between( + timestep_axis, negative_feature, neg_baseline, color=feature_color, alpha=alpha + ) + neg_baseline = negative_feature + + ax_main.scatter( + [], + [], + color=feature_color, + alpha=alpha, + label=reduced_feature_names[feature], + marker="s", + s=200, + ) + + for feature in set(selected_features_by_impact): + important_ts_idx = np.where(selected_features_by_impact == feature)[0] + for ts_idx in important_ts_idx: + if skip_label_at_zero and significant_timesteps[ts_idx] == 0: + continue + if subj_pred_over_ts[significant_timesteps[ts_idx]] > subj_pred_over_ts[ + significant_timesteps[ts_idx] + 1 + ]: + marker = "v" + if ts_marker_level == "shap": + marker_y_level = pos_baseline[significant_timesteps[ts_idx]] + 0.005 + else: + marker_y_level = subj_pred_over_ts[significant_timesteps[ts_idx]] + 0.005 + text_y_level = marker_y_level + 0.01 + else: + marker = "^" + if ts_marker_level == "shap": + marker_y_level = neg_baseline[significant_timesteps[ts_idx]] - 0.005 + else: + marker_y_level = subj_pred_over_ts[significant_timesteps[ts_idx]] - 0.005 + text_y_level = marker_y_level - 0.015 + + ax_main.scatter( + significant_timesteps[ts_idx], + marker_y_level, + color=feature_color_dict[feature], + s=100, + marker=marker, + alpha=1, + edgecolors="white", + ) + + if display_text_labels: + if marker == "v": + ax_main.text( + significant_timesteps[ts_idx] + 0.01, + text_y_level, + reduced_feature_names[feature], + fontsize=12, + color="black", + rotation=45, + ha="left", + va="bottom", + ) + else: + ax_main.text( + significant_timesteps[ts_idx] - 0.01, + text_y_level, + reduced_feature_names[feature], + fontsize=12, + color="black", + ) + + if display_title: + ax_main.set_title(f"Predictions for subject {subj} of test set along time", fontsize=20) + + ax_main.set_xlabel("Time from admission (hours)", fontsize=label_font_size) + ax_main.set_ylabel("Probability of END", fontsize=label_font_size) + ax_main.tick_params(axis="both", labelsize=tick_label_size) + + if display_legend: + legend_markers, legend_labels = ax_main.get_legend_handles_labels() + + shap_shades_markers = legend_markers[1:] + shap_shades_labels = legend_labels[1:] + legend_markers = [legend_markers[0]] + legend_labels = [legend_labels[0]] + + ts_marker_down = Line2D( + [0], [0], marker="v", linestyle="", color="grey", markersize=7, alpha=0.8 + ) + ts_marker_up = Line2D( + [0], [0], marker="^", linestyle="", color="grey", markersize=7, alpha=0.8 + ) + ts_label = "Positive / Negative impact on inflection of prediction" + legend_markers.append((ts_marker_up, ts_marker_down)) + legend_labels.append(ts_label) + + legend_markers.append("") + legend_labels.append("") + legend_markers.append("Weight & direction of influence on model prediction") + legend_labels.append("") + + legend_markers += shap_shades_markers + legend_labels += shap_shades_labels + + ax_main.legend( + legend_markers, + legend_labels, + fontsize=label_font_size, + title="Influence on model prediction", + title_fontsize=label_font_size, + handler_map={ + tuple: HandlerTuple(ndivide=1), + str: LegendTitle({"fontsize": label_font_size}), + }, + bbox_to_anchor=(1.05, 1), + loc="upper left", + ) + + n_features_small = len(set(selected_features_by_impact)) + + if n_features_small > 0: + cols = min(4, n_features_small) + rows = (n_features_small + cols - 1) // cols + + gs_nested = ax_features.figure.add_gridspec( + rows, + cols, + left=ax_features.get_position().x0, + right=ax_features.get_position().x1, + bottom=ax_features.get_position().y0, + top=ax_features.get_position().y1, + hspace=0.4, + wspace=0.3, + ) + + ax_features.remove() + + for idx, feature in enumerate(set(selected_features_by_impact)): + row = idx // cols + col = idx % cols + ax_small = fig_joint.add_subplot(gs_nested[row, col]) + + feature_name = reduced_feature_names[feature] + feature_color = feature_color_dict[feature] + feature_data = non_norm_subj_df[feature_name] + + ax_small.plot(timestep_axis, feature_data, color=feature_color, linewidth=2) + ax_small.fill_between(timestep_axis, feature_data, alpha=0.3, color=feature_color) + + important_ts_idx = np.where(selected_features_by_impact == feature)[0] + for ts_idx in important_ts_idx: + timestep = significant_timesteps[ts_idx] + ax_small.scatter( + timestep, + feature_data.iloc[timestep], + color=feature_color, + s=60, + zorder=5, + edgecolors="white", + linewidth=1, + ) + + ax_small.set_title( + feature_name, + fontsize=tick_label_size, + color=feature_color, + weight="bold", + ) + ax_small.set_xlim(0, n_time_steps) + ax_small.spines["top"].set_visible(False) + ax_small.spines["right"].set_visible(False) + + y_min, y_max = feature_data.min(), feature_data.max() + if y_min == y_max: + y_ticks = [y_min] + else: + y_ticks = [y_min, y_max] + ax_small.set_yticks(y_ticks) + ax_small.tick_params(labelsize=tick_label_size - 2) + ax_small.set_ylim(y_min - 0.2 * (y_max - y_min), y_max + 0.2 * (y_max - y_min)) + + if row == rows - 1: + ax_small.set_xlabel("Time (h)", fontsize=tick_label_size - 1) + else: + ax_small.set_xticklabels([]) + else: + ax_features.text( + 0.5, + 0.5, + "No significant feature changes detected", + transform=ax_features.transAxes, + ha="center", + va="center", + fontsize=label_font_size, + style="italic", + ) + ax_features.set_xlim(0, 1) + ax_features.set_ylim(0, 1) + ax_features.axis("off") + + ax_main.spines["top"].set_visible(False) + ax_main.spines["right"].set_visible(False) + + return fig_joint + + +def _build_arg_parser(): + parser = argparse.ArgumentParser( + description="Load inference artifacts and generate explanation plots for selected subjects." + ) + parser.add_argument("--shap-values-path", required=True, help="Path to SHAP values pickle.") + parser.add_argument("--test-data-path", required=True, help="Path to test data .pth file.") + parser.add_argument( + "--cat-encoding-path", + required=True, + help="Path to categorical encoding CSV.", + ) + parser.add_argument( + "--normalisation-parameters-path", + required=True, + help="Path to normalization parameters CSV.", + ) + parser.add_argument( + "--predictions-path", + required=True, + help="Path to predictions pickle (gt, pred).", + ) + parser.add_argument( + "--subjects", + nargs="+", + type=int, + required=True, + help="Subject indices to plot (e.g. --subjects 3 10 42).", + ) + parser.add_argument("--n-time-steps", type=int, default=72) + parser.add_argument("--only-last-timestep", action="store_true") + + parser.add_argument( + "--reverse-categorical-encoding", + dest="reverse_categorical_encoding", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-reverse-categorical-encoding", + dest="reverse_categorical_encoding", + action="store_false", + ) + parser.add_argument( + "--pool-hourly-split-values", + dest="pool_hourly_split_values", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-pool-hourly-split-values", + dest="pool_hourly_split_values", + action="store_false", + ) + parser.add_argument( + "--only-keep-current-value-shap", + dest="only_keep_current_value_shap", + action="store_true", + default=True, + ) + parser.add_argument( + "--keep-all-aggregated-shap", + dest="only_keep_current_value_shap", + action="store_false", + ) + parser.add_argument( + "--shap-aggregation-func", + choices=["sum", "median"], + default="sum", + ) + parser.add_argument( + "--use-simplified-shap-values", + dest="use_simplified_shap_values", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-simplified-shap-values", + dest="use_simplified_shap_values", + action="store_false", + ) + parser.add_argument("--smoothing-window", type=int, default=15) + parser.add_argument( + "--feature-to-english-name-correspondence-path", + default=None, + help="Optional path to feature name mapping Excel file.", + ) + + parser.add_argument("--threshold", type=float, default=0.04) + parser.add_argument("--n-features-selection", type=int, default=0) + parser.add_argument("--n-features", type=int, default=1) + parser.add_argument("--display-legend", action="store_true", default=False) + parser.add_argument("--display-text-labels", action="store_true", default=False) + parser.add_argument("--display-title", action="store_true", default=False) + parser.add_argument("--plot-ground-truth", action="store_true", default=True) + parser.add_argument("--no-plot-ground-truth", dest="plot_ground_truth", action="store_false") + + parser.add_argument( + "--output-dir", + default=None, + help="Optional output directory; if provided, each subject plot is saved as PNG.", + ) + parser.add_argument("--dpi", type=int, default=300) + parser.add_argument( + "--show", + action="store_true", + help="Display generated figures interactively.", + ) + return parser + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + plot_kwargs = { + "threshold": args.threshold, + "n_features_selection": args.n_features_selection, + "n_features": args.n_features, + "display_legend": args.display_legend, + "display_text_labels": args.display_text_labels, + "display_title": args.display_title, + "plot_ground_truth": args.plot_ground_truth, + } + + result = load_preprocess_and_plot_subjects( + subjects=args.subjects, + shap_values_path=args.shap_values_path, + test_data_path=args.test_data_path, + cat_encoding_path=args.cat_encoding_path, + normalisation_parameters_path=args.normalisation_parameters_path, + predictions_path=args.predictions_path, + n_time_steps=args.n_time_steps, + only_last_timestep=args.only_last_timestep, + reverse_categorical_encoding=args.reverse_categorical_encoding, + pool_hourly_split_values=args.pool_hourly_split_values, + only_keep_current_value_shap=args.only_keep_current_value_shap, + shap_aggregation_func=args.shap_aggregation_func, + use_simplified_shap_values=args.use_simplified_shap_values, + smoothing_window=args.smoothing_window, + feature_to_english_name_correspondence_path=args.feature_to_english_name_correspondence_path, + plot_kwargs=plot_kwargs, + ) + + figures_by_subject = result["figures_by_subject"] + + if args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + for subj, fig in figures_by_subject.items(): + out_path = os.path.join(args.output_dir, f"subject_{subj}_inference_plot.png") + fig.savefig(out_path, bbox_inches="tight", dpi=args.dpi) + print(f"Saved: {out_path}") + + if args.show: + plt.show() + else: + for fig in figures_by_subject.values(): + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/prediction/short_term_outcome_prediction/figures/test_inference_plotting.ipynb b/prediction/short_term_outcome_prediction/figures/test_inference_plotting.ipynb new file mode 100644 index 0000000..98d133d --- /dev/null +++ b/prediction/short_term_outcome_prediction/figures/test_inference_plotting.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "47090ced", + "metadata": {}, + "source": [ + "# Test inference plotting\n", + "Run the new `inference_plotting` wrapper for three subjects and render the generated figures." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2950beb", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from prediction.short_term_outcome_prediction.figures.inference_plotting import load_preprocess_and_plot_subjects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "603f0dcb", + "metadata": {}, + "outputs": [], + "source": [ + "# Reuse paths from the inference notebook setup\n", + "shap_values_path = '/Users/jk1/temp/opsum_end/training/without_imaging/hyperopt/xgb_gridsearch/xgb_gs_20250513_154517/checkpoints_short_opsum_xgb_20250518_001112_cv_1/shap_explanations_over_time/tree_explainer_shap_values_over_ts.pkl'\n", + "test_data_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/early_neurological_deterioration_train_data_splits/test_data_early_neurological_deterioration_ts0.8_rs42_ns5.pth'\n", + "cat_encoding_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/categorical_variable_encoding.csv'\n", + "normalisation_parameters_path = '/Users/jk1/temp/opsum_end/preprocessing/without_imaging/gsu_Extraction_20220815_prepro_09052025_220520/logs_09052025_220520/normalisation_parameters.csv'\n", + "predictions_path = '/Users/jk1/temp/opsum_end/testing/without_imaging/test_gt_and_pred_cv_1.pkl'\n", + "\n", + "subjects = [9, 96, 133]\n", + "output_dir = '/Users/jk1/temp/opsum_end/testing/inference_plotting_notebook'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0b232f3", + "metadata": {}, + "outputs": [], + "source": [ + "result = load_preprocess_and_plot_subjects(\n", + " subjects=subjects,\n", + " shap_values_path=shap_values_path,\n", + " test_data_path=test_data_path,\n", + " cat_encoding_path=cat_encoding_path,\n", + " normalisation_parameters_path=normalisation_parameters_path,\n", + " predictions_path=predictions_path,\n", + " n_time_steps=72,\n", + " use_simplified_shap_values=True,\n", + " plot_kwargs={\n", + " 'threshold': 0.04,\n", + " 'n_features_selection': 0,\n", + " 'n_features': 1,\n", + " 'display_legend': True,\n", + " 'display_text_labels': True,\n", + " 'display_title': False,\n", + " 'plot_ground_truth': True,\n", + " },\n", + ")\n", + "\n", + "figures_by_subject = result['figures_by_subject']\n", + "\n", + "# Save and display all three figures\n", + "import os\n", + "from IPython.display import display\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "for subj, fig in figures_by_subject.items():\n", + " out_path = os.path.join(output_dir, f'subject_{subj}_inference_plot.png')\n", + " fig.savefig(out_path, bbox_inches='tight', dpi=300)\n", + " print(f'Saved: {out_path}')\n", + " display(fig)\n", + "\n", + "len(figures_by_subject)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "opsum", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}