Skip to content

Commit 5310bd5

Browse files
Merge pull request #105 from rainlanguage/2025-08-15-minimize-denominator-div
full precision div
2 parents 87c8c65 + b87ef83 commit 5310bd5

11 files changed

Lines changed: 405 additions & 237 deletions

.gas-snapshot

Lines changed: 199 additions & 199 deletions
Large diffs are not rendered by default.

crates/float/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,10 @@ mod tests {
13681368
let zero = Float::parse("0".to_string()).unwrap();
13691369
let err = (one / zero).unwrap_err();
13701370

1371-
assert!(matches!(err, FloatError::Revert(_)));
1371+
assert!(matches!(
1372+
err,
1373+
FloatError::DecimalFloat(DecimalFloatErrors::MulDivOverflow(_))
1374+
));
13721375
}
13731376

13741377
#[test]

src/error/ErrDecimalFloat.sol

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ error LossyConversionFromFloat(int256 signedCoefficient, int256 exponent);
3030

3131
/// @dev Thrown when attempting to exponentiate 0^b where b is negative.
3232
error ZeroNegativePower(Float b);
33+
34+
/// @dev Thrown when mulDiv internal to division overflows.
35+
error MulDivOverflow(uint256 x, uint256 y, uint256 denominator);

src/lib/implementation/LibDecimalFloatImplementation.sol

Lines changed: 143 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: CAL
22
pragma solidity ^0.8.25;
33

4-
import {ExponentOverflow, Log10Negative, Log10Zero} from "../../error/ErrDecimalFloat.sol";
4+
import {ExponentOverflow, Log10Negative, Log10Zero, MulDivOverflow} from "../../error/ErrDecimalFloat.sol";
55
import {
66
LOG_TABLES,
77
LOG_TABLES_SMALL,
@@ -188,13 +188,149 @@ library LibDecimalFloatImplementation {
188188
pure
189189
returns (int256, int256)
190190
{
191+
uint256 scale = 1e76;
192+
int256 adjustExponent = 76;
193+
int256 signedCoefficient;
194+
191195
unchecked {
196+
// Move both coefficients into the e75/e76 range, so that the result
197+
// of division will not cause a mulDiv overflow.
192198
(signedCoefficientA, exponentA) = maximize(signedCoefficientA, exponentA);
193-
(signedCoefficientB, exponentB) = normalize(signedCoefficientB, exponentB);
199+
(signedCoefficientB, exponentB) = maximize(signedCoefficientB, exponentB);
200+
201+
// mulDiv only works with unsigned integers, so get the absolute
202+
// values of the coefficients.
203+
uint256 signedCoefficientAAbs;
204+
if (signedCoefficientA > 0) {
205+
signedCoefficientAAbs = uint256(signedCoefficientA);
206+
} else if (signedCoefficientA < 0) {
207+
if (signedCoefficientA == type(int256).min) {
208+
signedCoefficientAAbs = uint256(type(int256).max) + 1;
209+
} else {
210+
signedCoefficientAAbs = uint256(-signedCoefficientA);
211+
}
212+
} else {
213+
return (MAXIMIZED_ZERO_SIGNED_COEFFICIENT, MAXIMIZED_ZERO_EXPONENT);
214+
}
215+
uint256 signedCoefficientBAbs;
216+
if (signedCoefficientB < 0) {
217+
if (signedCoefficientB == type(int256).min) {
218+
signedCoefficientBAbs = uint256(type(int256).max) + 1;
219+
} else {
220+
signedCoefficientBAbs = uint256(-signedCoefficientB);
221+
}
222+
} else {
223+
signedCoefficientBAbs = uint256(signedCoefficientB);
224+
}
194225

195-
int256 signedCoefficient = signedCoefficientA / signedCoefficientB;
196-
int256 exponent = exponentA - exponentB;
197-
return (signedCoefficient, exponent);
226+
// We are going to scale the numerator up by the largest power of ten
227+
// that does not exceed the denominator. This will always overflow
228+
// internally to the mulDiv during the initial multiplication, in
229+
// 512 bits, but will subsequently always be reduced back down to
230+
// fit in 256 bits by the division of a denominator that is larger
231+
// than the scale up.
232+
if (signedCoefficientBAbs < scale) {
233+
scale = 1e75;
234+
adjustExponent = 75;
235+
}
236+
uint256 signedCoefficientAbs = mulDiv(signedCoefficientAAbs, scale, signedCoefficientBAbs);
237+
signedCoefficient = (signedCoefficientA ^ signedCoefficientB) < 0
238+
? -int256(signedCoefficientAbs)
239+
: int256(signedCoefficientAbs);
240+
}
241+
242+
// Keep the exponent calculation outside the unchecked block so that we
243+
// don't silently under/overflow.
244+
int256 exponent = exponentA - exponentB - adjustExponent;
245+
return (signedCoefficient, exponent);
246+
}
247+
248+
/// mulDiv as seen in OpenZeppelin, PRB Math, Solady, and other libraries.
249+
/// Credit to Remco Bloemen under MIT license: https://2π.com/21/muldiv
250+
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
251+
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2^256 and mod 2^256 - 1, then use
252+
// the Chinese Remainder Theorem to reconstruct the 512-bit result. The result is stored in two 256-bit
253+
// variables such that product = prod1 * 2^256 + prod0.
254+
uint256 prod0; // Least significant 256 bits of the product
255+
uint256 prod1; // Most significant 256 bits of the product
256+
assembly ("memory-safe") {
257+
let mm := mulmod(x, y, not(0))
258+
prod0 := mul(x, y)
259+
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
260+
}
261+
262+
// Handle non-overflow cases, 256 by 256 division.
263+
if (prod1 == 0) {
264+
unchecked {
265+
return prod0 / denominator;
266+
}
267+
}
268+
269+
// Make sure the result is less than 2^256. Also prevents denominator == 0.
270+
if (prod1 >= denominator) {
271+
revert MulDivOverflow(x, y, denominator);
272+
}
273+
274+
////////////////////////////////////////////////////////////////////////////
275+
// 512 by 256 division
276+
////////////////////////////////////////////////////////////////////////////
277+
278+
// Make division exact by subtracting the remainder from [prod1 prod0].
279+
uint256 remainder;
280+
assembly ("memory-safe") {
281+
// Compute remainder using the mulmod Yul instruction.
282+
remainder := mulmod(x, y, denominator)
283+
284+
// Subtract 256 bit number from 512-bit number.
285+
prod1 := sub(prod1, gt(remainder, prod0))
286+
prod0 := sub(prod0, remainder)
287+
}
288+
289+
unchecked {
290+
// Calculate the largest power of two divisor of the denominator using the unary operator ~. This operation cannot overflow
291+
// because the denominator cannot be zero at this point in the function execution. The result is always >= 1.
292+
// For more detail, see https://cs.stackexchange.com/q/138556/92363.
293+
uint256 lpotdod = denominator & (~denominator + 1);
294+
uint256 flippedLpotdod;
295+
296+
assembly ("memory-safe") {
297+
// Factor powers of two out of denominator.
298+
// slither-disable-next-line divide-before-multiply
299+
denominator := div(denominator, lpotdod)
300+
301+
// Divide [prod1 prod0] by lpotdod.
302+
// slither-disable-next-line divide-before-multiply
303+
prod0 := div(prod0, lpotdod)
304+
305+
// Get the flipped value `2^256 / lpotdod`. If the `lpotdod` is zero, the flipped value is one.
306+
// `sub(0, lpotdod)` produces the two's complement of `lpotdod`, which is equivalent to flipping all the bits and adding one.
307+
// However, `div` interprets this value as an unsigned value: https://ethereum.stackexchange.com/q/147168/24693
308+
flippedLpotdod := add(div(sub(0, lpotdod), lpotdod), 1)
309+
}
310+
311+
// Shift in bits from prod1 into prod0.
312+
prod0 |= prod1 * flippedLpotdod;
313+
314+
// Invert denominator mod 2^256. Now that denominator is an odd number, it has an inverse modulo 2^256 such
315+
// that denominator * inv = 1 mod 2^256. Compute the inverse by starting with a seed that is correct for
316+
// four bits. That is, denominator * inv = 1 mod 2^4.
317+
// slither-disable-next-line incorrect-exp
318+
uint256 inverse = (3 * denominator) ^ 2;
319+
320+
// Use the Newton-Raphson iteration to improve the precision. Thanks to Hensel's lifting lemma, this also works
321+
// in modular arithmetic, doubling the correct bits in each step.
322+
inverse *= 2 - denominator * inverse; // inverse mod 2^8
323+
inverse *= 2 - denominator * inverse; // inverse mod 2^16
324+
inverse *= 2 - denominator * inverse; // inverse mod 2^32
325+
inverse *= 2 - denominator * inverse; // inverse mod 2^64
326+
inverse *= 2 - denominator * inverse; // inverse mod 2^128
327+
inverse *= 2 - denominator * inverse; // inverse mod 2^256
328+
329+
// Because the division is now exact we can divide by multiplying with the modular inverse of denominator.
330+
// This will give us the correct result modulo 2^256. Since the preconditions guarantee that the outcome is
331+
// less than 2^256, this is the final result. We don't need to compute the high bits of the result and prod1
332+
// is no longer required.
333+
result = prod0 * inverse;
198334
}
199335
}
200336

@@ -367,14 +503,9 @@ library LibDecimalFloatImplementation {
367503
return signedCoefficientA == signedCoefficientB;
368504
}
369505

370-
/// Inverts a float. Equivalent to `1 / x` with modest gas optimizations.
506+
/// Inverts a float. Equivalent to `1 / x`.
371507
function inv(int256 signedCoefficient, int256 exponent) internal pure returns (int256, int256) {
372-
(signedCoefficient, exponent) = normalize(signedCoefficient, exponent);
373-
374-
signedCoefficient = 1e76 / signedCoefficient;
375-
exponent = -exponent - 76;
376-
377-
return (signedCoefficient, exponent);
508+
return div(1e76, -76, signedCoefficient, exponent);
378509
}
379510

380511
/// log10(x) for a float x.

test/lib/LibCommonResults.sol

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// SPDX-License-Identifier: CAL
22
pragma solidity ^0.8.25;
33

4-
int256 constant ONES = 111111111111111111111111111111111111111;
5-
int256 constant THREES = 333333333333333333333333333333333333333;
4+
int256 constant ONES = 1111111111111111111111111111111111111111111111111111111111111111111111111111;
5+
int256 constant THREES_PACKED = 3333333333333333333333333333333333333333333333333333333333333333333;
6+
int256 constant THREES = 3333333333333333333333333333333333333333333333333333333333333333333333333333;

test/src/lib/LibDecimalFloat.mixed.t.sol

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,26 @@
22
pragma solidity =0.8.25;
33

44
import {LibDecimalFloat, Float} from "src/lib/LibDecimalFloat.sol";
5-
import {THREES, ONES} from "../../lib/LibCommonResults.sol";
5+
import {THREES_PACKED, ONES} from "../../lib/LibCommonResults.sol";
66

77
import {Test} from "forge-std/Test.sol";
88

99
contract LibDecimalFloatMixedTest is Test {
1010
using LibDecimalFloat for Float;
1111

1212
/// (1 / 3) * 555e18
13-
function testDiv1Over3() external pure {
13+
function testDiv1Over3Mixed() external pure {
1414
Float a = LibDecimalFloat.packLossless(1, 0);
1515
Float b = LibDecimalFloat.packLossless(3, 0);
1616
Float c = a.div(b);
1717
(int256 signedCoefficientDiv, int256 exponentDiv) = LibDecimalFloat.unpack(c);
18-
assertEq(signedCoefficientDiv, THREES, "coefficient");
19-
assertEq(exponentDiv, -39, "exponent");
18+
assertEq(signedCoefficientDiv, THREES_PACKED, "coefficient");
19+
assertEq(exponentDiv, -67, "exponent");
2020

2121
Float d = c.mul(LibDecimalFloat.packLossless(555, 18));
2222
(int256 signedCoefficientMul, int256 exponentMul) = LibDecimalFloat.unpack(d);
2323

24-
assertEq(signedCoefficientMul, 184999999999999999999999999999999999999815);
25-
assertEq(exponentMul, -21);
24+
assertEq(signedCoefficientMul, 1849999999999999999999999999999999999999999999999999999999999999999);
25+
assertEq(exponentMul, -46);
2626
}
2727
}

test/src/lib/LibDecimalFloat.pow.t.sol

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,15 @@ contract LibDecimalFloatPowTest is LogTest {
3838
}
3939

4040
function testPows() external {
41-
checkPow(5e37, -38, 3e37, -36, 9.32835820895522388059701492537313432835e38, -48);
42-
checkPow(5e37, -38, 6e37, -36, 8.71080139372822299651567944250871080139e38, -57);
43-
// // Issues found in fuzzing from here.
41+
// 0.5 ^ 30 = 9.3132257e-10
42+
checkPow(
43+
5e37, -38, 3e37, -36, 9.328358208955223880597014925373134328358208955223880597014925373134e66, -66 - 10
44+
);
45+
// 0.5 ^ 60 = 8.6736174e-19
46+
checkPow(
47+
5e37, -38, 6e37, -36, 8.710801393728222996515679442508710801393728222996515679442508710801e66, -66 - 19
48+
);
49+
// Issues found in fuzzing from here.
4450
checkPow(99999, 0, 12182, 0, 1000, 60907);
4551
checkPow(1785215562, 0, 18, 0, 3388, 163);
4652
}

test/src/lib/LibDecimalFloat.pow10.t.sol

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ contract LibDecimalFloatPow10Test is LogTest {
2727
} else {
2828
Float floatPower10 = this.pow10External(float);
2929
(int256 signedCoefficientUnpacked, int256 exponentUnpacked) = floatPower10.unpack();
30+
31+
// Compensate for the implied pack and unpack.
32+
(Float resultPacked, bool lossless) = LibDecimalFloat.packLossy(signedCoefficient, exponent);
33+
(lossless);
34+
(signedCoefficient, exponent) = resultPacked.unpack();
35+
3036
assertEq(signedCoefficient, signedCoefficientUnpacked);
3137
assertEq(exponent, exponentUnpacked);
3238
}

test/src/lib/implementation/LibDecimalFloatImplementation.div.t.sol

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ contract LibDecimalFloatImplementationDivTest is Test {
4040

4141
/// 1 / 3
4242
function testDiv1Over3() external pure {
43-
checkDiv(1, 0, 3, 0, THREES, -39);
43+
checkDiv(1, 0, 3, 0, THREES, -76);
4444
}
4545

4646
/// - 1 / 3
4747
function testDivNegative1Over3() external pure {
48-
checkDiv(-1, 0, 3, 0, -THREES, -39);
48+
checkDiv(-1, 0, 3, 0, -THREES, -76);
4949
}
5050

5151
/// 1 / 3 gas
@@ -56,41 +56,47 @@ contract LibDecimalFloatImplementationDivTest is Test {
5656

5757
/// 1e18 / 3
5858
function testDiv1e18Over3() external pure {
59-
checkDiv(1e18, 0, 3, 0, THREES, -21);
59+
checkDiv(1e18, 0, 3, 0, THREES, -58);
6060
}
6161

6262
/// 10,0 / 1e38,-37 == 1
6363
function testDivTenOverOOMs() external pure {
64-
checkDiv(10, 0, 1e38, -37, 1e39, -39);
64+
checkDiv(10, 0, 1e38, -37, 1e76, -76);
6565
}
6666

6767
/// 1e38,-37 / 2,0 == 5
6868
function testDivOOMsOverTen() external pure {
69-
checkDiv(1e38, -37, 2, 0, 5e38, -38);
69+
checkDiv(1e38, -37, 2, 0, 5e75, -75);
7070
}
7171

7272
/// 5e37,-37 / 2e37,-37 == 2.5
7373
function testDivOOMs5and2() external pure {
74-
checkDiv(5e37, -37, 2e37, -37, 25e38, -39);
74+
checkDiv(5e37, -37, 2e37, -37, 2.5e76, -76);
7575
}
7676

7777
/// (1 / 9) / (1 / 3) == 0.333..
7878
function testDiv1Over9Over1Over3() external pure {
7979
// 1 / 9
8080
(int256 signedCoefficientA, int256 exponentA) = LibDecimalFloatImplementation.div(1, 0, 9, 0);
8181
assertEq(signedCoefficientA, ONES);
82-
assertEq(exponentA, -39);
82+
assertEq(exponentA, -76);
8383

8484
// 1 / 3
8585
(int256 signedCoefficientB, int256 exponentB) = LibDecimalFloatImplementation.div(1, 0, 3, 0);
8686
assertEq(signedCoefficientB, THREES);
87-
assertEq(exponentB, -39);
87+
assertEq(exponentB, -76);
8888

8989
// (1 / 9) / (1 / 3)
9090
(int256 signedCoefficient, int256 exponent) =
9191
LibDecimalFloatImplementation.div(signedCoefficientA, exponentA, signedCoefficientB, exponentB);
92-
assertEq(signedCoefficient, 333333333333333333333333333333333333336);
93-
assertEq(exponent, -39);
92+
assertEq(signedCoefficient, THREES);
93+
assertEq(exponent, -76);
94+
95+
// (1 / 3) / (1 / 9) == 3
96+
(signedCoefficient, exponent) =
97+
LibDecimalFloatImplementation.div(signedCoefficientB, exponentB, signedCoefficientA, exponentA);
98+
assertEq(signedCoefficient, 3e76);
99+
assertEq(exponent, -76);
94100
}
95101

96102
/// forge-config: default.fuzz.runs = 100
@@ -102,7 +108,7 @@ contract LibDecimalFloatImplementationDivTest is Test {
102108
int256 di = 0;
103109
while (true) {
104110
int256 i = 1;
105-
int256 j = -39 - di;
111+
int256 j = -76 - di;
106112
while (true) {
107113
// want to see full precision on the THREES regardless of the
108114
// scale of the numerator and denominator.

test/src/lib/implementation/LibDecimalFloatImplementation.inv.t.sol

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@ import {LibDecimalFloatSlow} from "test/lib/LibDecimalFloatSlow.sol";
66
import {
77
LibDecimalFloatImplementation,
88
EXPONENT_MIN,
9-
EXPONENT_MAX
9+
EXPONENT_MAX,
10+
MulDivOverflow
1011
} from "src/lib/implementation/LibDecimalFloatImplementation.sol";
1112

1213
contract LibDecimalFloatImplementationInvTest is Test {
14+
function invExternal(int256 signedCoefficient, int256 exponent) external pure returns (int256, int256) {
15+
(signedCoefficient, exponent) = LibDecimalFloatImplementation.inv(signedCoefficient, exponent);
16+
return (signedCoefficient, exponent);
17+
}
18+
1319
/// Compare reference.
1420
function testInvReference(int256 signedCoefficient, int256 exponent) external pure {
1521
vm.assume(signedCoefficient != 0);
@@ -33,4 +39,9 @@ contract LibDecimalFloatImplementationInvTest is Test {
3339
(int256 outputSignedCoefficient, int256 outputExponent) = LibDecimalFloatSlow.invSlow(3e37, -37);
3440
(outputSignedCoefficient, outputExponent);
3541
}
42+
43+
function testInv0() external {
44+
vm.expectRevert(abi.encodeWithSelector(MulDivOverflow.selector, 1e76, 1e75, 0));
45+
this.invExternal(0, 0);
46+
}
3647
}

0 commit comments

Comments
 (0)