From 708c06019abfdfa9febe526554db8ab0dc4c6490 Mon Sep 17 00:00:00 2001 From: Arad Reder Date: Sun, 29 Mar 2026 13:08:54 +0300 Subject: [PATCH] refactor: extract common delayed field logic --- src/staking/staking.cairo | 61 +++++++++++++++------------------------ src/staking/utils.cairo | 36 +++++++++++++++++++++++ 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/src/staking/staking.cairo b/src/staking/staking.cairo index ad0b1a73..4bdfcc3e 100644 --- a/src/staking/staking.cairo +++ b/src/staking/staking.cairo @@ -36,8 +36,8 @@ pub mod Staking { use staking::staking::utils::{ assert_caller_is_not_zero, balance_at_epoch, calculate_staker_total_staking_power, claim_from_reward_supplier, compute_new_delegated_stake, deploy_delegation_pool_contract, - get_undelegate_intent_token, is_btc_active, split_rewards_with_commission, - strk_token_dispatcher, + get_undelegate_intent_token, is_btc_active, resolve_delayed_field, + split_rewards_with_commission, strk_token_dispatcher, validate_delayed_field_update, }; use staking::types::{ Amount, BlockNumber, Commission, Epoch, InternalStakerInfoLatest, PeerId, PublicKey, @@ -838,16 +838,19 @@ pub mod Staking { let staker_info = self.internal_staker_info(:staker_address); assert!(staker_info.unstake_time.is_none(), "{}", Error::UNSTAKE_IN_PROGRESS); - let (curr_activation_epoch, _, prev_public_key) = self.public_key.read(staker_address); - let curr_epoch = self.get_current_epoch(); - // TODO: Confirm with product this set period is ok. - assert!(curr_epoch >= curr_activation_epoch, "{}", Error::PUBLIC_KEY_SET_IN_PROGRESS); - assert!(prev_public_key != public_key, "{}", Error::PUBLIC_KEY_MUST_DIFFER); + let delayed = self.public_key.read(staker_address); + validate_delayed_field_update( + delayed_field: delayed, + new_value: public_key, + curr_epoch: self.get_current_epoch(), + err_set_in_progress: Error::PUBLIC_KEY_SET_IN_PROGRESS, + err_must_differ: Error::PUBLIC_KEY_MUST_DIFFER, + ); - let new_activation_epoch = self.get_epoch_plus_k(); + let (_, _, prev_public_key) = delayed; self .public_key - .write(staker_address, (new_activation_epoch, prev_public_key, public_key)); + .write(staker_address, (self.get_epoch_plus_k(), prev_public_key, public_key)); self.emit(Events::PublicKeySet { staker_address, public_key }); } @@ -868,13 +871,17 @@ pub mod Staking { let staker_info = self.internal_staker_info(:staker_address); assert!(staker_info.unstake_time.is_none(), "{}", Error::UNSTAKE_IN_PROGRESS); - let (curr_activation_epoch, _, prev_peer_id) = self.peer_id.read(staker_address); - let curr_epoch = self.get_current_epoch(); - assert!(curr_epoch >= curr_activation_epoch, "{}", Error::PEER_ID_SET_IN_PROGRESS); - assert!(prev_peer_id != peer_id, "{}", Error::PEER_ID_MUST_DIFFER); + let delayed = self.peer_id.read(staker_address); + validate_delayed_field_update( + delayed_field: delayed, + new_value: peer_id, + curr_epoch: self.get_current_epoch(), + err_set_in_progress: Error::PEER_ID_SET_IN_PROGRESS, + err_must_differ: Error::PEER_ID_MUST_DIFFER, + ); - let new_activation_epoch = self.get_epoch_plus_k(); - self.peer_id.write(staker_address, (new_activation_epoch, prev_peer_id, peer_id)); + let (_, _, prev_peer_id) = delayed; + self.peer_id.write(staker_address, (self.get_epoch_plus_k(), prev_peer_id, peer_id)); self.emit(Events::PeerIdSet { staker_address, peer_id }); } @@ -2254,17 +2261,7 @@ pub mod Staking { fn get_public_key_at_epoch( self: @ContractState, staker_address: ContractAddress, epoch_id: Epoch, ) -> Option { - let (activation_epoch, old_pk, new_pk) = self.public_key.read(staker_address); - let current_pk = if epoch_id >= activation_epoch { - new_pk - } else { - old_pk - }; - if current_pk.is_non_zero() { - Some(current_pk) - } else { - None - } + resolve_delayed_field(delayed_field: self.public_key.read(staker_address), :epoch_id) } /// Returns the peer ID for `staker_address` at `epoch_id`, @@ -2276,17 +2273,7 @@ pub mod Staking { fn get_peer_id_at_epoch( self: @ContractState, staker_address: ContractAddress, epoch_id: Epoch, ) -> Option { - let (activation_epoch, old_pid, new_pid) = self.peer_id.read(staker_address); - let current_pid = if epoch_id >= activation_epoch { - new_pid - } else { - old_pid - }; - if current_pid.is_non_zero() { - Some(current_pid) - } else { - None - } + resolve_delayed_field(delayed_field: self.peer_id.read(staker_address), :epoch_id) } /// Calculates rewards for the given staker and his pools, updates the staker's diff --git a/src/staking/utils.cairo b/src/staking/utils.cairo index 6bc38de2..e0e96530 100644 --- a/src/staking/utils.cairo +++ b/src/staking/utils.cairo @@ -117,6 +117,42 @@ pub(crate) fn is_btc_active(active_status: (Epoch, bool), epoch_id: Epoch) -> bo (epoch_id >= epoch) == is_active } +/// Resolves a delayed field value at a given epoch. +/// `delayed_field` is a tuple of (activation_epoch, old_value, new_value). +/// Returns `Some(value)` if the resolved value is non-zero, `None` otherwise. +/// +/// Precondition: `get_current_epoch() <= epoch_id < get_current_epoch() + K`. +pub(crate) fn resolve_delayed_field, +Copy, +Drop>( + delayed_field: (Epoch, T, T), epoch_id: Epoch, +) -> Option { + let (activation_epoch, old_value, new_value) = delayed_field; + let current_value = if epoch_id >= activation_epoch { + new_value + } else { + old_value + }; + if current_value.is_non_zero() { + Option::Some(current_value) + } else { + Option::None + } +} + +/// Validates that a delayed field can be updated. +/// Checks that the current epoch is past the activation epoch and the new value differs +/// from the previous value. +pub(crate) fn validate_delayed_field_update, +Copy, +Drop>( + delayed_field: (Epoch, T, T), + new_value: T, + curr_epoch: Epoch, + err_set_in_progress: Error, + err_must_differ: Error, +) { + let (curr_activation_epoch, _, prev_value) = delayed_field; + assert!(curr_epoch >= curr_activation_epoch, "{}", err_set_in_progress); + assert!(prev_value != new_value, "{}", err_must_differ); +} + /// Returns the staking power for the given staker. /// The staking power is calculated by: /// ((staker_strk_total_amount / strk_total_amount) * (1 - ALPHA) +