Skip to content

Commit c414494

Browse files
committed
Watch for network changes and rotate PQ tickets when needed
1 parent 3f06406 commit c414494

6 files changed

Lines changed: 334 additions & 26 deletions

File tree

dnscrypt-proxy/crypto.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,13 @@ func (proxy *Proxy) Encrypt(
7979
serverInfo *ServerInfo,
8080
packet []byte,
8181
proto string,
82-
) (sharedKey *[32]byte, encrypted []byte, clientNonce []byte, err error) {
82+
) (sharedKey *[32]byte, encrypted []byte, clientNonce []byte, queryEpoch uint64, err error) {
8383
if serverInfo.CryptoConstruction == XWingPQ {
8484
return proxy.encryptPQ(serverInfo, packet, proto)
8585
}
8686
nonce, clientNonce := make([]byte, NonceSize), make([]byte, HalfNonceSize)
8787
if _, err := crypto_rand.Read(clientNonce); err != nil {
88-
return nil, nil, nil, err
88+
return nil, nil, nil, queryEpoch, err
8989
}
9090
copy(nonce, clientNonce)
9191
var publicKey *[PublicKeySize]byte
@@ -110,7 +110,7 @@ func (proxy *Proxy) Encrypt(
110110
} else {
111111
var xpad [1]byte
112112
if _, err := crypto_rand.Read(xpad[:]); err != nil {
113-
return nil, nil, nil, err
113+
return nil, nil, nil, queryEpoch, err
114114
}
115115
minQuestionSize += int(xpad[0])
116116
}
@@ -122,7 +122,7 @@ func (proxy *Proxy) Encrypt(
122122
}
123123
if QueryOverhead+len(packet)+1 > paddedLength {
124124
err = errors.New("Question too large; cannot be padded")
125-
return sharedKey, encrypted, clientNonce, err
125+
return sharedKey, encrypted, clientNonce, queryEpoch, err
126126
}
127127
encrypted = append(serverInfo.MagicQuery[:], publicKey[:]...)
128128
encrypted = append(encrypted, nonce[:HalfNonceSize]...)
@@ -134,14 +134,15 @@ func (proxy *Proxy) Encrypt(
134134
copy(xsalsaNonce[:], nonce)
135135
encrypted = secretbox.Seal(encrypted, padded, &xsalsaNonce, sharedKey)
136136
}
137-
return sharedKey, encrypted, clientNonce, err
137+
return sharedKey, encrypted, clientNonce, queryEpoch, err
138138
}
139139

140140
func (proxy *Proxy) Decrypt(
141141
serverInfo *ServerInfo,
142142
sharedKey *[32]byte,
143143
encrypted []byte,
144144
nonce []byte,
145+
queryEpoch uint64,
145146
) ([]byte, error) {
146147
serverMagicLen := len(ServerMagic)
147148
responseHeaderLen := serverMagicLen + NonceSize
@@ -172,7 +173,7 @@ func (proxy *Proxy) Decrypt(
172173
return encrypted, err
173174
}
174175
if serverInfo.CryptoConstruction == XWingPQ {
175-
packet, err = proxy.pqStripControl(serverInfo, sharedKey, nonce, packet)
176+
packet, err = proxy.pqStripControl(serverInfo, sharedKey, nonce, packet, queryEpoch)
176177
if err != nil {
177178
return encrypted, err
178179
}

dnscrypt-proxy/netmon.go

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"encoding/hex"
7+
"net"
8+
"sort"
9+
"strconv"
10+
"sync"
11+
"sync/atomic"
12+
"time"
13+
14+
"github.com/jedisct1/dlog"
15+
)
16+
17+
const (
18+
defaultNetworkMonitorInterval = 5 * time.Second
19+
offlineNetworkFingerprint = "offline"
20+
)
21+
22+
type networkInterfaceSnapshot struct {
23+
Name string
24+
Index int
25+
HardwareAddr net.HardwareAddr
26+
Addrs []*net.IPNet
27+
}
28+
29+
type networkMonitor struct {
30+
epochValue atomic.Uint64
31+
mu sync.Mutex
32+
last string
33+
fingerprint func() string
34+
}
35+
36+
func newNetworkMonitor() *networkMonitor {
37+
return &networkMonitor{fingerprint: currentNetworkFingerprint}
38+
}
39+
40+
func (monitor *networkMonitor) epoch() uint64 {
41+
if monitor == nil {
42+
return 0
43+
}
44+
return monitor.epochValue.Load()
45+
}
46+
47+
func (monitor *networkMonitor) init() {
48+
if monitor == nil {
49+
return
50+
}
51+
fingerprint := monitor.currentFingerprint()
52+
monitor.mu.Lock()
53+
monitor.last = fingerprint
54+
monitor.mu.Unlock()
55+
}
56+
57+
func (monitor *networkMonitor) start(ctx context.Context, interval time.Duration) {
58+
if monitor == nil {
59+
return
60+
}
61+
if interval <= 0 {
62+
interval = defaultNetworkMonitorInterval
63+
}
64+
monitor.init()
65+
ticker := time.NewTicker(interval)
66+
defer ticker.Stop()
67+
for {
68+
select {
69+
case <-ctx.Done():
70+
return
71+
case <-ticker.C:
72+
monitor.check()
73+
}
74+
}
75+
}
76+
77+
func (monitor *networkMonitor) check() {
78+
fingerprint := monitor.currentFingerprint()
79+
monitor.mu.Lock()
80+
defer monitor.mu.Unlock()
81+
if monitor.last == "" {
82+
monitor.last = fingerprint
83+
return
84+
}
85+
if monitor.last == fingerprint {
86+
return
87+
}
88+
monitor.last = fingerprint
89+
monitor.epochValue.Add(1)
90+
dlog.Notice("Network change detected; rotating PQ resumption tickets")
91+
}
92+
93+
func (monitor *networkMonitor) currentFingerprint() string {
94+
if monitor.fingerprint == nil {
95+
return offlineNetworkFingerprint
96+
}
97+
return monitor.fingerprint()
98+
}
99+
100+
func currentNetworkFingerprint() string {
101+
localIPs := discoverNetworkMonitorLocalIPs()
102+
if len(localIPs) == 0 {
103+
return offlineNetworkFingerprint
104+
}
105+
interfaces := snapshotNetworkInterfaces()
106+
return buildNetworkFingerprint(localIPs, interfaces)
107+
}
108+
109+
func discoverNetworkMonitorLocalIPs() []net.IP {
110+
probeAddrs := []string{"192.0.2.1:9", "[2001:db8::1]:9"}
111+
localIPs := make([]net.IP, 0, len(probeAddrs))
112+
seen := make(map[string]struct{}, len(probeAddrs))
113+
for _, probeAddr := range probeAddrs {
114+
conn, err := net.DialTimeout("udp", probeAddr, time.Second)
115+
if err != nil {
116+
continue
117+
}
118+
localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
119+
conn.Close()
120+
if !ok || localAddr.IP == nil || localAddr.IP.IsUnspecified() {
121+
continue
122+
}
123+
ip := append(net.IP(nil), localAddr.IP...)
124+
key := ip.String()
125+
if _, ok := seen[key]; ok {
126+
continue
127+
}
128+
seen[key] = struct{}{}
129+
localIPs = append(localIPs, ip)
130+
}
131+
return localIPs
132+
}
133+
134+
func snapshotNetworkInterfaces() []networkInterfaceSnapshot {
135+
interfaces, err := net.Interfaces()
136+
if err != nil {
137+
return nil
138+
}
139+
snapshots := make([]networkInterfaceSnapshot, 0, len(interfaces))
140+
for _, iface := range interfaces {
141+
addrs, err := iface.Addrs()
142+
if err != nil {
143+
continue
144+
}
145+
snapshot := networkInterfaceSnapshot{
146+
Name: iface.Name,
147+
Index: iface.Index,
148+
HardwareAddr: append(net.HardwareAddr(nil), iface.HardwareAddr...),
149+
}
150+
for _, addr := range addrs {
151+
ip, ipNet, err := net.ParseCIDR(addr.String())
152+
if err != nil {
153+
continue
154+
}
155+
ipNet.IP = ip
156+
snapshot.Addrs = append(snapshot.Addrs, ipNet)
157+
}
158+
snapshots = append(snapshots, snapshot)
159+
}
160+
return snapshots
161+
}
162+
163+
func buildNetworkFingerprint(localIPs []net.IP, interfaces []networkInterfaceSnapshot) string {
164+
if len(localIPs) == 0 {
165+
return offlineNetworkFingerprint
166+
}
167+
parts := make([]string, 0, len(localIPs))
168+
for _, ip := range localIPs {
169+
if ip == nil || ip.IsUnspecified() {
170+
continue
171+
}
172+
ip = append(net.IP(nil), ip...)
173+
iface, ok := findNetworkInterfaceForIP(ip, interfaces)
174+
part := "ip=" + ip.String()
175+
if ok {
176+
part += "|name=" + iface.Name + "|index=" + strconv.Itoa(iface.Index)
177+
if len(iface.HardwareAddr) > 0 {
178+
part += "|mac=" + iface.HardwareAddr.String()
179+
}
180+
}
181+
parts = append(parts, part)
182+
}
183+
if len(parts) == 0 {
184+
return offlineNetworkFingerprint
185+
}
186+
sort.Strings(parts)
187+
h := sha256.New()
188+
for _, part := range parts {
189+
h.Write([]byte(part))
190+
h.Write([]byte{0})
191+
}
192+
return hex.EncodeToString(h.Sum(nil))
193+
}
194+
195+
func (proxy *Proxy) networkEpoch() uint64 {
196+
if proxy == nil || proxy.netMonitor == nil {
197+
return 0
198+
}
199+
return proxy.netMonitor.epoch()
200+
}
201+
202+
func findNetworkInterfaceForIP(ip net.IP, interfaces []networkInterfaceSnapshot) (networkInterfaceSnapshot, bool) {
203+
matches := make([]networkInterfaceSnapshot, 0, 1)
204+
for _, iface := range interfaces {
205+
for _, addr := range iface.Addrs {
206+
if addr != nil && addr.Contains(ip) {
207+
matches = append(matches, iface)
208+
break
209+
}
210+
}
211+
}
212+
if len(matches) == 0 {
213+
return networkInterfaceSnapshot{}, false
214+
}
215+
sort.Slice(matches, func(i, j int) bool {
216+
if matches[i].Name != matches[j].Name {
217+
return matches[i].Name < matches[j].Name
218+
}
219+
return matches[i].Index < matches[j].Index
220+
})
221+
return matches[0], true
222+
}

0 commit comments

Comments
 (0)