Skip to content

Av softmax#322

Merged
VincentAuriau merged 15 commits into
mainfrom
av-softmax
May 27, 2026
Merged

Av softmax#322
VincentAuriau merged 15 commits into
mainfrom
av-softmax

Conversation

@VincentAuriau

Copy link
Copy Markdown
Collaborator

fixes limit case of unavailable large utility

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@VincentAuriau

Copy link
Copy Markdown
Collaborator Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the tied_embeddings option to the AleaCarta basket model, allowing items to have distinct embeddings depending on whether they are in the basket. It also updates the softmax utility calculation and adds corresponding unit and integration tests. The review identified a critical bug in softmax_with_availabilities where multiplying logits by availability can corrupt the maximum utility calculation when all available items have negative utilities. Additionally, setting a global TensorFlow seed inside the model's instantiate method was flagged as a bad practice due to global side-effects.

Comment thread choice_learn/tf_ops.py Outdated
Comment on lines +36 to +38
numerator = tf.multiply(items_logit_by_choice, available_items_by_choice)
normalizer = tf.reduce_max(numerator, axis=axis, keepdims=True)
numerator = tf.exp(numerator - normalizer)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Bug Description

Multiplying items_logit_by_choice by available_items_by_choice (which is 0.0 for unavailable items) sets the logits of unavailable items to 0.0.

If all available items have negative utilities (e.g., [-10.0, -20.0]), the maximum utility (normalizer) will be incorrectly computed as 0.0 (from the unavailable items) instead of -10.0. This defeats the overflow/underflow prevention and can cause the probabilities of available items to underflow to 0.0 or result in NaN values.

Solution

Instead of multiplying by the availability mask before taking the maximum, use tf.where to mask the unavailable items with a large negative value (e.g., -1e9). This ensures that:

  1. Unavailable items do not affect the maximum utility calculation.
  2. Exponentiating the masked logits safely underflows to 0.0 for unavailable items, avoiding any NaN or overflow issues.
Suggested change
numerator = tf.multiply(items_logit_by_choice, available_items_by_choice)
normalizer = tf.reduce_max(numerator, axis=axis, keepdims=True)
numerator = tf.exp(numerator - normalizer)
large_negative = tf.constant(-1e9, dtype=items_logit_by_choice.dtype)
masked_logits = tf.where(available_items_by_choice > 0.5, items_logit_by_choice, large_negative)
normalizer = tf.reduce_max(masked_logits, axis=axis, keepdims=True)
numerator = tf.exp(masked_logits - normalizer)

n_stores = 1
self.n_stores = n_stores

tf.random.set_seed(42)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Issue

Calling tf.random.set_seed(42) inside the instantiate method of a model class introduces a global side-effect. It resets the global random seed of the entire TensorFlow runtime, which can silently break the reproducibility of other models, training loops, or random processes in the user's application.

Recommendation

Avoid setting global seeds inside library code. Instead, allow the user to control the global seed externally, or accept an optional seed parameter in the model's constructor and use it locally (e.g., passing it to the initializers or using a tf.random.Generator). If reproducibility is desired by default, you can pass a seed directly to the initializers or document that the user should set the global seed.

@github-actions

github-actions Bot commented May 27, 2026

Copy link
Copy Markdown
Contributor

Coverage

Coverage Report for Python 3.10
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py64198%286
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py155398%111, 136, 326
   base_basket_model.py2482988%115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 967–968
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1956268%74–77, 223–231, 299–301, 411, 544–580, 608–648, 671, 678–683, 753–764, 812
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   badminton.py81693%62, 194–199, 247
   bakery.py38392%47, 51, 61
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3353689%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 908, 935, 939–953
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py116794%73, 151–153, 219–221, 287–289
TOTAL570483385% 

Tests Skipped Failures Errors Time
228 0 💤 0 ❌ 0 🔥 5m 22s ⏱️

@github-actions

github-actions Bot commented May 27, 2026

Copy link
Copy Markdown
Contributor

Coverage

Coverage Report for Python 3.9
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py64198%286
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py155398%111, 136, 326
   base_basket_model.py2482988%115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 967–968
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1956268%74–77, 223–231, 299–301, 411, 544–580, 608–648, 671, 678–683, 753–764, 812
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   badminton.py81693%62, 194–199, 247
   bakery.py38392%47, 51, 61
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3353590%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 734–735, 908, 935, 939–953
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py124298%186, 374
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132795%122, 285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2362360%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py1167337%71–105, 115, 149–173, 183, 217–240, 250, 285–309, 319
TOTAL570289784% 

Tests Skipped Failures Errors Time
228 0 💤 0 ❌ 0 🔥 5m 32s ⏱️

@github-actions

github-actions Bot commented May 27, 2026

Copy link
Copy Markdown
Contributor

Coverage

Coverage Report for Python 3.12
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py64198%286
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py155398%111, 136, 326
   base_basket_model.py2482988%115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 967–968
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1956268%74–77, 223–231, 299–301, 411, 544–580, 608–648, 671, 678–683, 753–764, 812
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   badminton.py81693%62, 194–199, 247
   bakery.py38392%47, 53, 61
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3353590%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py1241885%186, 341, 360, 364–380
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py116794%73, 151–153, 219–221, 287–289
TOTAL570484885% 

Tests Skipped Failures Errors Time
228 0 💤 1 ❌ 0 🔥 7m 34s ⏱️

@github-actions

github-actions Bot commented May 27, 2026

Copy link
Copy Markdown
Contributor

Coverage

Coverage Report for Python 3.11
FileStmtsMissCoverMissing
choice_learn
   __init__.py20100% 
   tf_ops.py64198%286
choice_learn/basket_models
   __init__.py40100% 
   alea_carta.py155398%111, 136, 326
   base_basket_model.py2482988%115–116, 127, 145, 189, 259, 381, 489, 589–591, 680, 785, 793, 803, 851, 854–864, 925–928, 967–968
   basic_attention_model.py89496%424, 427, 433, 440
   self_attention_model.py133993%71, 73, 75, 450–454, 651
   shopper.py184995%130, 159, 325, 345, 360, 363, 377, 489, 618
choice_learn/basket_models/data
   __init__.py20100% 
   basket_dataset.py1956268%74–77, 223–231, 299–301, 411, 544–580, 608–648, 671, 678–683, 753–764, 812
   preprocessing.py947817%43–45, 128–364
choice_learn/basket_models/datasets
   __init__.py30100% 
   badminton.py81693%62, 194–199, 247
   bakery.py38392%47, 51, 61
choice_learn/basket_models/utils
   __init__.py00100% 
   permutation.py22195%37
choice_learn/data
   __init__.py30100% 
   choice_dataset.py6493395%198, 250, 283, 421, 463–464, 589, 724, 738, 840, 842, 937, 957–961, 1140, 1159–1161, 1179–1181, 1209, 1214, 1223, 1240, 1281, 1293, 1307, 1346, 1361, 1366, 1395, 1408, 1443–1444
   indexer.py2412390%20, 31, 45, 60–67, 202–204, 219–230, 265, 291, 582
   storage.py161696%22, 33, 51, 56, 61, 71
   store.py72720%3–275
choice_learn/datasets
   __init__.py40100% 
   base.py400599%42–43, 153–154, 714
   expedia.py1028319%37–301
   tafeng.py490100% 
choice_learn/datasets/data
   __init__.py00100% 
choice_learn/models
   __init__.py14286%15–16
   base_model.py3353590%145, 187, 289, 297, 303, 312, 352, 356–357, 362, 391, 395–396, 413, 426, 434, 475–476, 485–486, 587, 589, 605, 609, 611, 734–735, 935, 939–953
   baseline_models.py490100% 
   conditional_logit.py2692690%49, 52, 54, 85, 88, 91–95, 98–102, 136, 206, 212–216, 351, 388, 445, 520–526, 651, 685, 822, 826
   halo_mnl.py1241885%186, 341, 360, 364–380
   latent_class_base_model.py2863986%55–61, 273–279, 288, 325–330, 497–500, 605, 624, 665–701, 715, 720, 751–752, 774–775, 869–870, 974
   latent_class_mnl.py62690%257–261, 296
   learning_mnl.py67396%157, 182, 188
   nested_logit.py2911296%55, 77, 160, 269, 351, 484, 530, 600, 679, 848, 900, 904
   reslogit.py132695%285, 360, 369, 374, 382, 432
   rumnet.py236399%748–751, 982
   simple_mnl.py139696%167, 275, 347, 355, 357, 359
   tastenet.py94397%142, 180, 188
choice_learn/toolbox
   __init__.py00100% 
   assortment_optimizer.py27678%28–30, 93–95, 160–162
   gurobi_opt.py2382380%3–675
   or_tools_opt.py2301195%103, 107, 296–305, 315, 319, 607, 611
choice_learn/utils
   metrics.py116794%73, 151–153, 219–221, 287–289
TOTAL570484885% 

Tests Skipped Failures Errors Time
228 0 💤 1 ❌ 0 🔥 5m 38s ⏱️

@VincentAuriau VincentAuriau merged commit 40cd22b into main May 27, 2026
8 checks passed
@VincentAuriau VincentAuriau deleted the av-softmax branch May 27, 2026 17:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants