Skip to content

Commit a0e9996

Browse files
Merge pull request #137 from alexanderthclark/feature/mr-affine-transition
affine refactor for mr
2 parents 46e9b8d + f0b27e4 commit a0e9996

4 files changed

Lines changed: 211 additions & 27 deletions

File tree

freeride/affine.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def ppf_sum(*curves, comparative_advantage=True):
541541

542542
class BaseAffine:
543543

544-
def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
544+
def __init__(self, intercept=None, slope=None, elements=None, inverse=True, sum_elements=True):
545545
"""
546546
Initialize the BaseAffine object.
547547
@@ -555,6 +555,8 @@ def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
555555
List of AffineElement objects. If provided, it will override `intercept` and `slope`. Default is None.
556556
inverse : bool, optional
557557
Indicates if the transformation should be inverted. Default is True.
558+
sum_elements : bool, optional
559+
Whether to sum elements together (True) or keep them as separate pieces (False). Default is True.
558560
559561
Raises
560562
------
@@ -572,6 +574,7 @@ def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
572574
zipped = zip(slope, intercept)
573575
elements = [AffineElement(slope=m, intercept=b, inverse=inverse) for m, b in zipped]
574576
self.elements = elements
577+
self.sum_elements = sum_elements
575578

576579
if intercept is None:
577580
intercept = [c.intercept for c in elements]
@@ -666,13 +669,13 @@ class Affine(BaseAffine):
666669
A class to represent a piecewise affine function.
667670
"""
668671

669-
def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
672+
def __init__(self, intercept=None, slope=None, elements=None, inverse=True, sum_elements=True):
670673
"""
671674
Initializes an Affine object with given slopes and intercepts or elements.
672675
The slopes correspond to elements, which are differentiated from pieces.
673676
674-
The elements represent the individual curves which are horizontally summed.
675-
The pieces are the resulting functions for the piecewise expression describing the aggregate.
677+
When sum_elements=True: elements are horizontally summed to create aggregate pieces.
678+
When sum_elements=False: elements are kept as separate pieces (for discontinuous functions).
676679
677680
Parameters
678681
----------
@@ -684,14 +687,16 @@ def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
684687
A list of AffineElements whose horizontal sum defines the Affine object.
685688
inverse : bool, optional
686689
When inverse is True, it is assumed that equations are in the form P(Q).
690+
sum_elements : bool, optional
691+
Whether to sum elements together (True) or keep them as separate pieces (False). Default is True.
687692
688693
Raises
689694
------
690695
ValueError
691696
If the lengths of `slope` and `intercept` do not match.
692697
"""
693698

694-
super().__init__(intercept, slope, elements, inverse)
699+
super().__init__(intercept, slope, elements, inverse, sum_elements)
695700

696701
# Special handling for perfectly elastic curves - they can't be summed
697702
# so they'll only have one element and don't need horizontal_sum
@@ -703,6 +708,23 @@ def __init__(self, intercept=None, slope=None, elements=None, inverse=True):
703708
mids = [self.elements[0].intercept]
704709
sections = [(0, np.inf)]
705710
qsections = [(0, np.inf)]
711+
elif not self.sum_elements:
712+
# Non-summing mode: keep elements as separate pieces
713+
pieces = self.elements
714+
self.pieces = pieces
715+
# Create sections based on element domains
716+
sections = []
717+
qsections = []
718+
cuts = []
719+
for element in self.elements:
720+
if hasattr(element, '_domain') and element._domain:
721+
domain = element._domain
722+
sections.append((min(domain), max(domain)))
723+
qsections.append((min(domain), max(domain)))
724+
cuts.extend([min(domain), max(domain)])
725+
# Remove duplicates and sort cuts
726+
cuts = sorted(list(set(cuts))) if cuts else [0, np.inf]
727+
mids = [(a + b) / 2 for a, b in sections] if sections else [0]
706728
else:
707729
# Normal processing for non-horizontal curves
708730
pieces, cuts, mids = horizontal_sum(*self.elements)
@@ -755,10 +777,14 @@ def _get_active_piece(self, q):
755777
for piece in [piece for piece in self.pieces if piece]:
756778
q0, q1 = np.min(piece._domain), np.max(piece._domain)
757779

758-
# this has to be closed on the right and open on the left
759-
# there has to be area to the left when calc surplus
760-
if q0 < q <= q1:
761-
return piece
780+
if not self.sum_elements:
781+
# Non-summing mode: use right-continuous convention [a,b)
782+
if q0 <= q < q1:
783+
return piece
784+
else:
785+
# Summing mode: original logic (left-open, right-closed)
786+
if q0 < q <= q1:
787+
return piece
762788
return None
763789

764790
def __call__(self, x):
@@ -785,7 +811,23 @@ def __call__(self, x):
785811

786812
def q(self, p):
787813
# returns q given p
788-
return np.sum([np.max([0,c.q(p)]) for c in self.elements])
814+
if not self.sum_elements:
815+
# Non-summing mode: find the piece that contains this price using [a,b) convention
816+
for piece in self.pieces:
817+
if piece and hasattr(piece, '_domain') and piece._domain:
818+
# Use right-continuous convention: [a,b) - left inclusive, right exclusive
819+
domain = piece._domain
820+
a, b = np.min(domain), np.max(domain)
821+
if a <= p < b:
822+
try:
823+
q_val = piece.q(p)
824+
if np.isfinite(q_val) and q_val >= 0:
825+
return q_val
826+
except:
827+
continue
828+
return 0
829+
else:
830+
return np.sum([np.max([0,c.q(p)]) for c in self.elements])
789831

790832
def p(self, q):
791833
# returns p given q

freeride/monopoly.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .curves import Demand
88
from .costs import Cost
9+
from .revenue import MarginalRevenue
910

1011

1112
class Monopoly:
@@ -15,35 +16,35 @@ def __init__(self, demand: Demand, total_cost: Cost):
1516
self.demand = demand
1617
self.total_cost = total_cost
1718
self._mc = total_cost.marginal_cost()
19+
self._mr = MarginalRevenue.from_demand(demand)
1820

1921
self.q = 0.0
2022
self.p = 0.0
2123
self.profit = 0.0
2224

2325
self._solve()
2426

25-
@staticmethod
26-
def _mr_for_piece(piece, q: float) -> float:
27-
"""Return marginal revenue for ``piece`` at quantity ``q``."""
28-
return piece.intercept + 2 * piece.slope * q
29-
3027
def _solve(self):
3128
candidates = []
32-
for piece in [p for p in self.demand.pieces if p]:
29+
30+
# Find interior solutions where MR = MC
31+
# For each MR piece, solve MR(q) = MC(q)
32+
for mr_piece in [p for p in self._mr.pieces if p]:
3333
mc_coef = list(self._mc.coef)
3434
if len(mc_coef) < 2:
3535
mc_coef += [0] * (2 - len(mc_coef))
3636
diff = mc_coef.copy()
37-
diff[0] -= piece.intercept
38-
diff[1] -= 2 * piece.slope
37+
diff[0] -= mr_piece.intercept
38+
diff[1] -= mr_piece.slope
3939
poly = np.polynomial.Polynomial(diff)
4040
for r in poly.roots():
4141
if np.isreal(r):
4242
q = float(np.real(r))
4343
if q <= 0:
4444
continue
45-
dom = piece._domain
46-
if dom and not (min(dom) < q <= max(dom)):
45+
# Check if q is in the domain of this MR piece
46+
dom = mr_piece._domain
47+
if dom and not (min(dom) <= q < max(dom)):
4748
continue
4849
candidates.append(q)
4950

freeride/revenue.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Revenue curve utilities."""
44

55
from .quadratic import QuadraticElement, BaseQuadratic
6-
from .affine import AffineElement
6+
from .affine import AffineElement, Affine
77

88

99
class Revenue(BaseQuadratic):
@@ -27,20 +27,26 @@ def from_demand(cls, demand) -> "Revenue":
2727
return cls(elements=elements)
2828

2929

30-
class MarginalRevenue(BaseQuadratic):
30+
class MarginalRevenue(Affine):
3131
"""Piecewise linear marginal revenue curve.
3232
3333
This class represents marginal revenue as a function of quantity.
3434
These curves will not necessarily be continuous for piecewise Demand.
35+
Uses Affine with sum_elements=False to keep pieces separate for discontinuities.
3536
"""
3637

3738
@classmethod
3839
def from_demand(cls, demand) -> "MarginalRevenue":
40+
"""Construct a MarginalRevenue curve from a :class:`Demand` instance."""
3941
elements = []
4042
pieces = [p for p in demand.pieces if p]
4143
for piece in pieces:
42-
coef = piece.intercept, 2*piece.slope, 0
43-
revenue_element = QuadraticElement(*coef)
44-
revenue_element._domain = sorted(piece._domain)
45-
elements.append(revenue_element)
46-
return cls(elements=elements)
44+
# For linear demand P = a + bQ, MR = a + 2bQ
45+
mr_intercept = piece.intercept
46+
mr_slope = 2 * piece.slope # Slope becomes twice as steep
47+
mr_element = AffineElement(mr_intercept, mr_slope)
48+
mr_element._domain = piece._domain # Use same domain as demand piece
49+
elements.append(mr_element)
50+
51+
# Create Affine object with sum_elements=False to keep pieces separate
52+
return cls(elements=elements, sum_elements=False)

tests/test_monopoly.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import numpy as np
23
from freeride.curves import Demand
34
from freeride.costs import Cost
45
from freeride.monopoly import Monopoly
@@ -22,6 +23,140 @@ def test_piecewise_demand_zero_cost(self):
2223
self.assertAlmostEqual(m.p, 3.75)
2324
self.assertAlmostEqual(m.profit, 28.125)
2425

26+
def test_kinked_demand_monopoly(self):
27+
"""Test monopoly with kinked demand curve (discontinuous MR)."""
28+
# Create a kinked demand: steep segment + flat segment
29+
d1 = Demand(20, -1) # P = 20 - Q, steep segment
30+
d2 = Demand(10, -0.5) # P = 10 - 0.5*Q, flat segment
31+
kinked_demand = d1 + d2
32+
33+
# Use constant marginal cost
34+
cost = Cost(0, 3) # MC = 3
35+
m = Monopoly(kinked_demand, cost)
36+
37+
# Verify this is truly profit-maximizing by checking grid
38+
q_grid = np.linspace(0.1, 25, 1000)
39+
profits = []
40+
for q in q_grid:
41+
p = kinked_demand.p(q)
42+
if p > 0: # Only consider positive prices
43+
profit = p * q - cost.cost(q)
44+
profits.append(profit)
45+
else:
46+
profits.append(-np.inf)
47+
48+
max_profit_grid = max(profits)
49+
50+
# Our solution should be within 1% of the grid maximum
51+
self.assertGreater(m.profit, 0.99 * max_profit_grid)
52+
53+
def test_kinked_demand_with_quadratic_cost(self):
54+
"""Test kinked demand with quadratic cost function."""
55+
# Create kinked demand
56+
d1 = Demand(15, -0.8) # P = 15 - 0.8*Q
57+
d2 = Demand(8, -0.3) # P = 8 - 0.3*Q
58+
kinked_demand = d1 + d2
59+
60+
# Quadratic cost: TC = 2 + Q + 0.1*Q^2, so MC = 1 + 0.2*Q
61+
cost = Cost([2, 1, 0.1])
62+
m = Monopoly(kinked_demand, cost)
63+
64+
# Verify this is profit-maximizing using grid search
65+
q_grid = np.linspace(0.1, 30, 2000)
66+
profits = []
67+
for q in q_grid:
68+
p = kinked_demand.p(q)
69+
if p > 0:
70+
profit = p * q - cost.cost(q)
71+
profits.append(profit)
72+
else:
73+
profits.append(-np.inf)
74+
75+
max_profit_grid = max(profits)
76+
best_q_idx = np.argmax(profits)
77+
best_q_grid = q_grid[best_q_idx]
78+
79+
# Our solution should be very close to grid optimum
80+
self.assertGreater(m.profit, 0.99 * max_profit_grid)
81+
self.assertAlmostEqual(m.q, best_q_grid, delta=0.1)
82+
83+
def test_kinked_demand_single_segment_only(self):
84+
"""Test kinked demand where monopolist serves only one segment."""
85+
# Use your exact suggestion: P = 10 - Q and P = 1 - 10*Q
86+
d1 = Demand(10, -1) # P = 10 - Q
87+
d2 = Demand(1, -10) # P = 1 - 10*Q (very steep, low willingness to pay)
88+
kinked_demand = d1 + d2
89+
90+
# With MC = 0, monopolist will choose Q where MR = 0 on the profitable segment
91+
cost = Cost(0, 0)
92+
m = Monopoly(kinked_demand, cost)
93+
94+
# For first segment P = 10 - Q, MR = 10 - 2Q
95+
# Setting MR = 0: Q = 5, P = 5
96+
# At Q = 5, first segment gives profit = 5 * 5 = 25
97+
# Second segment at any reasonable Q gives much lower prices/profits
98+
# So monopolist should choose Q = 5, P = 5 (still pricing above second segment)
99+
100+
self.assertAlmostEqual(m.q, 5.0, places=1)
101+
self.assertAlmostEqual(m.p, 5.0, places=1)
102+
# Verify this price is indeed above what second segment would offer
103+
self.assertGreater(m.p, 1.0) # Much higher than second segment's max price
104+
105+
# Verify optimality with grid search
106+
q_grid = np.linspace(0.1, 15, 1000)
107+
profits = []
108+
for q in q_grid:
109+
p = kinked_demand.p(q)
110+
if p > 0:
111+
profit = p * q - cost.cost(q)
112+
profits.append(profit)
113+
else:
114+
profits.append(-np.inf)
115+
116+
max_profit_grid = max(profits)
117+
self.assertGreater(m.profit, 0.99 * max_profit_grid)
118+
119+
def test_kinked_demand_both_segments(self):
120+
"""Test kinked demand where monopolist serves both segments."""
121+
# Create kinked demand where both segments are attractive
122+
d1 = Demand(15, -0.5) # P = 15 - 0.5*Q, gentle slope, high willingness to pay
123+
d2 = Demand(12, -1) # P = 12 - Q, steeper but still reasonable
124+
kinked_demand = d1 + d2
125+
126+
# Use marginal cost that makes serving both segments optimal
127+
cost = Cost(0, 2) # MC = 2
128+
m = Monopoly(kinked_demand, cost)
129+
130+
# Find the kink point (where the segments meet)
131+
# d1: P = 15 - 0.5*Q, d2: P = 12 - Q
132+
# They intersect when 15 - 0.5*Q = 12 - Q
133+
# 3 = -0.5*Q, so Q = 6, P = 12
134+
kink_q = 6.0
135+
kink_p = 12.0
136+
137+
# Monopolist should operate beyond the kink (serve both segments)
138+
self.assertGreater(m.q, kink_q) # Should serve both segments
139+
self.assertLess(m.p, kink_p) # Price should be below kink point
140+
141+
# Verify optimality with grid search
142+
q_grid = np.linspace(0.1, 20, 1500)
143+
profits = []
144+
for q in q_grid:
145+
p = kinked_demand.p(q)
146+
if p > 0:
147+
profit = p * q - cost.cost(q)
148+
profits.append(profit)
149+
else:
150+
profits.append(-np.inf)
151+
152+
max_profit_grid = max(profits)
153+
best_q_idx = np.argmax(profits)
154+
best_q_grid = q_grid[best_q_idx]
155+
156+
# Verify we found the optimum
157+
self.assertGreater(m.profit, 0.99 * max_profit_grid)
158+
self.assertAlmostEqual(m.q, best_q_grid, delta=0.1)
159+
25160

26161
if __name__ == "__main__":
27162
unittest.main()

0 commit comments

Comments
 (0)