Av softmax #322
Code Review
This pull request introduces the tied_embeddings option to the AleaCarta basket model, allowing items to have distinct embeddings depending on whether they are in the basket. It also updates the softmax utility calculation and adds corresponding unit and integration tests. The review identified a critical bug in softmax_with_availabilities where multiplying logits by availability can corrupt the maximum utility calculation when all available items have negative utilities. Additionally, setting a global TensorFlow seed inside the model's instantiate method was flagged as a bad practice due to global side-effects.
    numerator = tf.multiply(items_logit_by_choice, available_items_by_choice)
    normalizer = tf.reduce_max(numerator, axis=axis, keepdims=True)
    numerator = tf.exp(numerator - normalizer)
Bug Description
Multiplying items_logit_by_choice by available_items_by_choice (which is 0.0 for unavailable items) sets the logits of unavailable items to 0.0.
If all available items have negative utilities (e.g., [-10.0, -20.0]), the maximum utility (normalizer) will be incorrectly computed as 0.0 (from the unavailable items) instead of -10.0. This defeats the overflow/underflow prevention: when the available utilities are very negative (e.g., around -1000.0), exp(utility - 0.0) underflows to 0.0 for every available item, so their probabilities collapse to 0.0 or the normalization produces NaN values.
Solution
Instead of multiplying by the availability mask before taking the maximum, use tf.where to mask the unavailable items with a large negative value (e.g., -1e9). This ensures that:
- Unavailable items do not affect the maximum utility calculation.
- Exponentiating the masked logits safely underflows to 0.0 for unavailable items, avoiding any NaN or overflow issues.
    - numerator = tf.multiply(items_logit_by_choice, available_items_by_choice)
    - normalizer = tf.reduce_max(numerator, axis=axis, keepdims=True)
    - numerator = tf.exp(numerator - normalizer)
    + large_negative = tf.constant(-1e9, dtype=items_logit_by_choice.dtype)
    + masked_logits = tf.where(available_items_by_choice > 0.5, items_logit_by_choice, large_negative)
    + normalizer = tf.reduce_max(masked_logits, axis=axis, keepdims=True)
    + numerator = tf.exp(masked_logits - normalizer)
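The failure mode can be reproduced without TensorFlow. The sketch below is a plain-Python stand-in, not the library's code: both helper names are hypothetical, and the first helper assumes that the surrounding function multiplies the exponentials by the availability mask again, which the quoted lines do not show.

```python
import math


def softmax_product_mask(logits, available):
    """Mask by multiplication, as in the flagged lines (hypothetical helper)."""
    # Unavailable logits become 0.0 and can win the max.
    masked = [u * a for u, a in zip(logits, available)]
    normalizer = max(masked)
    # Assumption: unavailable items are zeroed again after exponentiation.
    weights = [math.exp(m - normalizer) * a for m, a in zip(masked, available)]
    total = sum(weights)
    # 0.0 / 0.0 would be NaN in TensorFlow; plain Python raises, so emulate it.
    return [w / total if total > 0.0 else float("nan") for w in weights]


def softmax_where_mask(logits, available, large_negative=-1e9):
    """Mask by substitution, mirroring the suggested tf.where (hypothetical helper)."""
    masked = [u if a > 0.5 else large_negative for u, a in zip(logits, available)]
    normalizer = max(masked)  # now the largest *available* utility
    weights = [math.exp(m - normalizer) for m in masked]
    total = sum(weights)
    return [w / total for w in weights]


# Two available items with very negative utilities, one unavailable item.
logits = [-1000.0, -1002.0, 5.0]
available = [1.0, 1.0, 0.0]

broken = softmax_product_mask(logits, available)
fixed = softmax_where_mask(logits, available)
print("product mask:", broken)
print("where mask:  ", fixed)
```

With these inputs the product-masked version returns NaN for every item, because the normalizer is 0.0 and both available exponentials underflow. The substitution-masked version returns roughly 0.88 and 0.12 for the available items and 0.0 for the unavailable one. One case this sketch does not settle: with a finite constant such as -1e9, a row with no available item at all ends up with a uniform distribution, so that case may need separate handling.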
    n_stores = 1
    self.n_stores = n_stores

    tf.random.set_seed(42)
Issue
Calling tf.random.set_seed(42) inside the instantiate method of a model class introduces a global side-effect. It resets the global random seed of the entire TensorFlow runtime, which can silently break the reproducibility of other models, training loops, or random processes in the user's application.
Recommendation
Avoid setting global seeds inside library code. Instead, allow the user to control the global seed externally, or accept an optional seed parameter in the model's constructor and use it locally (e.g., passing it to the initializers or using a tf.random.Generator). If reproducibility is desired by default, you can pass a seed directly to the initializers or document that the user should set the global seed.
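The principle behind this recommendation can be illustrated with Python's standard `random` module as a dependency-free stand-in for TensorFlow's random state. This is only an analogy: `ToyModel` and its `seed` parameter are hypothetical and are not part of the AleaCarta API. In TensorFlow the private generator would be replaced by seeded initializers or a `tf.random.Generator`, as suggested above.

```python
import random


class ToyModel:
    """Hypothetical stand-in for a model class; not the AleaCarta API."""

    def __init__(self, n_weights, seed=None):
        self.n_weights = n_weights
        self.seed = seed
        self.weights = None

    def instantiate(self):
        # A private generator: seeding it leaves the global `random` state untouched.
        rng = random.Random(self.seed)
        self.weights = [rng.gauss(0.0, 1.0) for _ in range(self.n_weights)]
        return self.weights


# What the caller's global stream yields when nothing interferes with it.
random.seed(123)
expected_next = random.random()

# Same global seed, but a model is instantiated before the caller's draw.
random.seed(123)
first = ToyModel(3, seed=42).instantiate()
caller_draw = random.random()

# A second model with the same local seed reproduces the same weights.
second = ToyModel(3, seed=42).instantiate()
```

Here `caller_draw` equals `expected_next`, so instantiating the model did not disturb the caller's random stream, while `first` and `second` are identical because both models used the same local seed.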
Coverage Report for Python 3.9
Coverage Report for Python 3.12
Coverage Report for Python 3.11
fixes limit case of unavailable large utility