Skip to content

Commit d458c09

Browse files
authored
Merge pull request #89 from henrypinkard/main
Neurips poster + video making code
2 parents 94ea89e + 9e5bc51 commit d458c09

6 files changed

Lines changed: 3701 additions & 174 deletions

File tree

figure_making/IDEAL_e2e_remake_final_figure.ipynb

Lines changed: 160 additions & 0 deletions
Large diffs are not rendered by default.

figure_making/IDEAL_figure.ipynb

Lines changed: 8 additions & 22 deletions
Large diffs are not rendered by default.

figure_making/animations/tree_figure.ipynb

Lines changed: 805 additions & 97 deletions
Large diffs are not rendered by default.

figure_making/animations/tree_figure_new.ipynb

Lines changed: 2598 additions & 0 deletions
Large diffs are not rendered by default.

mi_estimator_experiments/pixel_cnn_mi_estimation.ipynb

Lines changed: 76 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": null,
66
"metadata": {},
77
"outputs": [
88
{
9-
"ename": "ModuleNotFoundError",
10-
"evalue": "No module named 'tensorflow_datasets'",
11-
"output_type": "error",
12-
"traceback": [
13-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
14-
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
15-
"Cell \u001b[0;32mIn[1], line 13\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mencoding_information\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgpu_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m limit_gpu_memory_growth\n\u001b[1;32m 11\u001b[0m limit_gpu_memory_growth()\n\u001b[0;32m---> 13\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow_datasets\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtfds\u001b[39;00m \u001b[38;5;66;03m# TFDS for MNIST\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtensorflow\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mtf\u001b[39;00m \u001b[38;5;66;03m# TensorFlow operations\u001b[39;00m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# from image_distribution_models import PixelCNN\u001b[39;00m\n",
16-
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'tensorflow_datasets'"
9+
"name": "stderr",
10+
"output_type": "stream",
11+
"text": [
12+
"2025-11-04 08:40:21.862103: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
13+
"2025-11-04 08:40:22.694951: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
14+
"2025-11-04 08:40:22.695049: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
15+
"2025-11-04 08:40:22.695059: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
1716
]
1817
}
1918
],
@@ -26,7 +25,7 @@
2625
"config.update(\"jax_enable_x64\", True)\n",
2726
"\n",
2827
"os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" \n",
29-
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1'\n",
28+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = '1' \n",
3029
"from encoding_information.gpu_utils import limit_gpu_memory_growth\n",
3130
"limit_gpu_memory_growth()\n",
3231
"\n",
@@ -57,63 +56,106 @@
5756
},
5857
{
5958
"cell_type": "code",
60-
"execution_count": 4,
59+
"execution_count": null,
6160
"metadata": {},
6261
"outputs": [
62+
{
63+
"name": "stderr",
64+
"output_type": "stream",
65+
"text": [
66+
"/2tb_nvme/hpinkard_waller/GitRepos/EncodingInformation/src/encoding_information/information_estimation.py:478: UserWarning: This function is deprecated. Use estimate_information() instead.\n",
67+
" warnings.warn(\"This function is deprecated. Use estimate_information() instead.\")\n"
68+
]
69+
},
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"Initial validation NLL: 13.15\n"
75+
]
76+
},
77+
{
78+
"name": "stderr",
79+
"output_type": "stream",
80+
"text": [
81+
"Epoch 1: 100%|██████████| 100/100 [00:17<00:00, 5.79it/s]\n"
82+
]
83+
},
84+
{
85+
"name": "stdout",
86+
"output_type": "stream",
87+
"text": [
88+
"Epoch 1: validation NLL: 3.49\n"
89+
]
90+
},
91+
{
92+
"name": "stderr",
93+
"output_type": "stream",
94+
"text": [
95+
"Epoch 2: 100%|██████████| 100/100 [00:07<00:00, 14.18it/s]\n"
96+
]
97+
},
6398
{
6499
"name": "stdout",
65100
"output_type": "stream",
66101
"text": [
67-
"Setting up PixelCNN\n",
68-
"Setting up PixelCNN\n",
69-
"Setting up PixelCNN\n",
70-
"Initial validation NLL: 10.56\n"
102+
"Epoch 2: validation NLL: 3.48\n"
71103
]
72104
},
73105
{
74106
"name": "stderr",
75107
"output_type": "stream",
76108
"text": [
77-
"Epoch 1: 0%| | 0/100 [00:00<?, ?it/s]"
109+
"Epoch 3: 100%|██████████| 100/100 [00:06<00:00, 15.57it/s]\n"
78110
]
79111
},
80112
{
81113
"name": "stdout",
82114
"output_type": "stream",
83115
"text": [
84-
"Setting up PixelCNN\n"
116+
"Epoch 3: validation NLL: 3.47\n"
85117
]
86118
},
87119
{
88120
"name": "stderr",
89121
"output_type": "stream",
90122
"text": [
91-
"Epoch 1: 69%|██████▉ | 69/100 [00:10<00:04, 6.88it/s]\n"
123+
"Epoch 4: 100%|██████████| 100/100 [00:06<00:00, 15.52it/s]\n"
124+
]
125+
},
126+
{
127+
"name": "stdout",
128+
"output_type": "stream",
129+
"text": [
130+
"Epoch 4: validation NLL: 3.46\n"
92131
]
93132
},
94133
{
95-
"ename": "KeyboardInterrupt",
96-
"evalue": "",
97-
"output_type": "error",
98-
"traceback": [
99-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
100-
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
101-
"\u001b[1;32m/home/hpinkard_waller/GitRepos/EncodingInformation/mi_estimator_experiments/pixel_cnn_mi_estimation.ipynb Cell 3\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bwaller-fuoco.eecs.berkeley.edu/home/hpinkard_waller/GitRepos/EncodingInformation/mi_estimator_experiments/pixel_cnn_mi_estimation.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mencoding_information\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39minformation_estimation\u001b[39;00m \u001b[39mimport\u001b[39;00m estimate_mutual_information\n\u001b[1;32m <a href='vscode-notebook-cell://ssh-remote%2Bwaller-fuoco.eecs.berkeley.edu/home/hpinkard_waller/GitRepos/EncodingInformation/mi_estimator_experiments/pixel_cnn_mi_estimation.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39m# mi_gaussian, stationary_gaussian = estimate_mutual_information(patches, eigenvalue_floor=1e0, entropy_model='gaussian', max_epochs=10, verbose=True, return_entropy_model=True)\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bwaller-fuoco.eecs.berkeley.edu/home/hpinkard_waller/GitRepos/EncodingInformation/mi_estimator_experiments/pixel_cnn_mi_estimation.ipynb#W2sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m mi_pixelcnn, pixel_cnn \u001b[39m=\u001b[39m estimate_mutual_information(patches, entropy_model\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mpixelcnn\u001b[39;49m\u001b[39m'\u001b[39;49m, max_epochs\u001b[39m=\u001b[39;49m\u001b[39m20\u001b[39;49m, verbose\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, return_entropy_model\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
102-
"File \u001b[0;32m~/GitRepos/EncodingInformation/encoding_information/information_estimation.py:248\u001b[0m, in \u001b[0;36mestimate_mutual_information\u001b[0;34m(noisy_images, clean_images, entropy_model, test_set_fraction, gaussian_noise_sigma, estimate_conditional_from_model_samples, patience, num_val_samples, batch_size, max_epochs, learning_rate, use_iterative_optimization, eigenvalue_floor, gradient_clip, momentum, analytic_marginal_entropy, steps_per_epoch, num_hidden_channels, num_mixture_components, return_entropy_model, verbose)\u001b[0m\n\u001b[1;32m 246\u001b[0m \u001b[39mif\u001b[39;00m v \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 247\u001b[0m hyperparams[k] \u001b[39m=\u001b[39m v\n\u001b[0;32m--> 248\u001b[0m noisy_image_model\u001b[39m.\u001b[39;49mfit(training_set, verbose\u001b[39m=\u001b[39;49mverbose, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mhyperparams)\n\u001b[1;32m 249\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 250\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mUnrecognized entropy model \u001b[39m\u001b[39m{\u001b[39;00mentropy_model\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
103-
"File \u001b[0;32m~/GitRepos/EncodingInformation/encoding_information/models/pixel_cnn.py:294\u001b[0m, in \u001b[0;36mPixelCNN.fit\u001b[0;34m(self, train_images, learning_rate, max_epochs, steps_per_epoch, patience, sigma_min, batch_size, num_val_samples, seed, verbose)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_flax_model\u001b[39m.\u001b[39mcompute_loss(\u001b[39m*\u001b[39moutput, x)\n\u001b[1;32m 292\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_state \u001b[39m=\u001b[39m TrainState\u001b[39m.\u001b[39mcreate(apply_fn\u001b[39m=\u001b[39mapply_fn, params\u001b[39m=\u001b[39minitial_params, tx\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_optimizer) \n\u001b[0;32m--> 294\u001b[0m best_params, val_loss_history \u001b[39m=\u001b[39m train_model(train_images\u001b[39m=\u001b[39;49mtrain_images, state\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_state, batch_size\u001b[39m=\u001b[39;49mbatch_size, num_val_samples\u001b[39m=\u001b[39;49m\u001b[39mint\u001b[39;49m(num_val_samples),\n\u001b[1;32m 295\u001b[0m steps_per_epoch\u001b[39m=\u001b[39;49msteps_per_epoch, num_epochs\u001b[39m=\u001b[39;49mmax_epochs, patience\u001b[39m=\u001b[39;49mpatience, verbose\u001b[39m=\u001b[39;49mverbose)\n\u001b[1;32m 296\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_state \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_state\u001b[39m.\u001b[39mreplace(params\u001b[39m=\u001b[39mbest_params)\n\u001b[1;32m 297\u001b[0m \u001b[39mreturn\u001b[39;00m val_loss_history\n",
104-
"File \u001b[0;32m~/GitRepos/EncodingInformation/encoding_information/models/image_distribution_models.py:193\u001b[0m, in \u001b[0;36mtrain_model\u001b[0;34m(train_images, state, batch_size, num_val_samples, steps_per_epoch, num_epochs, patience, train_step, verbose)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m \u001b[39miter\u001b[39m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m verbose \u001b[39melse\u001b[39;00m tqdm(\u001b[39miter\u001b[39m, desc\u001b[39m=\u001b[39m\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mEpoch \u001b[39m\u001b[39m{\u001b[39;00mepoch_idx\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m):\n\u001b[1;32m 192\u001b[0m batch \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39m(train_ds_iterator)\n\u001b[0;32m--> 193\u001b[0m state, loss \u001b[39m=\u001b[39m train_step(state, batch)\n\u001b[1;32m 195\u001b[0m avg_loss \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m loss \u001b[39m/\u001b[39m steps_per_epoch\n\u001b[1;32m 197\u001b[0m \u001b[39m# uniform noise already added in the dataset generators\u001b[39;00m\n",
105-
"File \u001b[0;32m<string>:1\u001b[0m, in \u001b[0;36m<lambda>\u001b[0;34m(_cls, count, mu, nu)\u001b[0m\n",
106-
"File \u001b[0;32m_pydevd_bundle/pydevd_cython.pyx:1457\u001b[0m, in \u001b[0;36m_pydevd_bundle.pydevd_cython.SafeCallWrapper.__call__\u001b[0;34m()\u001b[0m\n",
107-
"File \u001b[0;32m_pydevd_bundle/pydevd_cython.pyx:1758\u001b[0m, in \u001b[0;36m_pydevd_bundle.pydevd_cython.ThreadTracer.__call__\u001b[0;34m()\u001b[0m\n",
108-
"File \u001b[0;32m~/mambaforge/envs/phenotypes/lib/python3.10/site-packages/debugpy/_vendored/pydevd/_pydev_bundle/pydev_is_thread_alive.py:9\u001b[0m, in \u001b[0;36mis_thread_alive\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 6\u001b[0m _temp \u001b[39m=\u001b[39m threading\u001b[39m.\u001b[39mThread()\n\u001b[1;32m 7\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(_temp, \u001b[39m'\u001b[39m\u001b[39m_is_stopped\u001b[39m\u001b[39m'\u001b[39m): \u001b[39m# Python 3.x has this\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mis_thread_alive\u001b[39m(t):\n\u001b[1;32m 10\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mnot\u001b[39;00m t\u001b[39m.\u001b[39m_is_stopped\n\u001b[1;32m 12\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mhasattr\u001b[39m(_temp, \u001b[39m'\u001b[39m\u001b[39m_Thread__stopped\u001b[39m\u001b[39m'\u001b[39m): \u001b[39m# Python 2.x has this\u001b[39;00m\n",
109-
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
134+
"name": "stderr",
135+
"output_type": "stream",
136+
"text": [
137+
"Epoch 5: 100%|██████████| 100/100 [00:06<00:00, 15.55it/s]\n"
138+
]
139+
},
140+
{
141+
"name": "stdout",
142+
"output_type": "stream",
143+
"text": [
144+
"Epoch 5: validation NLL: 3.46\n"
145+
]
146+
},
147+
{
148+
"name": "stderr",
149+
"output_type": "stream",
150+
"text": [
151+
"Epoch 6: 66%|██████▌ | 66/100 [00:04<00:02, 15.55it/s]"
110152
]
111153
}
112154
],
113155
"source": [
114156
"from encoding_information.information_estimation import estimate_mutual_information\n",
115157
"\n",
116-
"mi_gaussian, stationary_gaussian = estimate_mutual_information(patches, eigenvalue_floor=1e0, entropy_model='gaussian', max_epochs=10, verbose=True, return_entropy_model=True)\n",
158+
"# mi_gaussian, stationary_gaussian = estimate_mutual_information(patches, eigenvalue_floor=1e0, entropy_model='gaussian', max_epochs=10, verbose=True, return_entropy_model=True)\n",
117159
"mi_pixelcnn, pixel_cnn = estimate_mutual_information(patches, entropy_model='pixelcnn', max_epochs=20, verbose=True, return_entropy_model=True)"
118160
]
119161
},

0 commit comments

Comments
 (0)