diff --git a/apps/web/app/api/domains/[domain]/validate/route.ts b/apps/web/app/api/domains/[domain]/validate/route.ts index 629706460e3..1904d0741fc 100644 --- a/apps/web/app/api/domains/[domain]/validate/route.ts +++ b/apps/web/app/api/domains/[domain]/validate/route.ts @@ -1,6 +1,7 @@ import { isValidDomain } from "@/lib/api/domains/is-valid-domain"; import { domainExists } from "@/lib/api/domains/utils"; import { validateDubLinkSubdomain } from "@/lib/api/domains/validate-dub-link-subdomain"; +import { safeFetch } from "@/lib/api/safe-fetch"; import { withSession } from "@/lib/auth"; import dns from "dns/promises"; import { NextResponse } from "next/server"; @@ -50,25 +51,21 @@ export const GET = withSession(async ({ params }) => { // Helper function to check if a site is active on the domain async function hasSiteConfigured(domain: string): Promise { try { - // Try HTTP HEAD request first (both HTTP and HTTPS) + // Try HTTP HEAD request first (both HTTP and HTTPS). + // safeFetch enforces SSRF guards on the user-supplied domain (the path + // param flows in here) so this can't be used to probe internal hosts. const urls = [`https://${domain}`, `http://${domain}`]; for (const url of urls) { try { - const controller = new AbortController(); - const timeoutId = setTimeout(() => controller.abort(), 3000); // 3 second timeout - const response = await fetch(url, { - method: "HEAD", - signal: controller.signal, - }); - clearTimeout(timeoutId); + const response = await safeFetch( + url, + { method: "HEAD" }, + { timeoutMs: 3000 }, + ); if (response.ok) return true; - } catch (e) { - if (e instanceof DOMException && e.name === "AbortError") { - // If request was aborted due to timeout, continue to next check - continue; - } - // Continue to next URL if this one fails + } catch { + // Continue to next URL if this one fails (timeout, SSRF guard, etc.) continue; } } diff --git a/apps/web/app/api/links/iframeable/route.ts b/apps/web/app/api/links/iframeable/route.ts index 7e6a447921e..71202c24c92 100644 --- a/apps/web/app/api/links/iframeable/route.ts +++ b/apps/web/app/api/links/iframeable/route.ts @@ -1,4 +1,5 @@ import { handleAndReturnErrorResponse } from "@/lib/api/errors"; +import { safeFetch } from "@/lib/api/safe-fetch"; import { ratelimitOrThrow } from "@/lib/api/utils"; import { getDomainQuerySchema, @@ -17,7 +18,11 @@ export async function GET(req: NextRequest) { await ratelimitOrThrow(req, "iframeable"); - const iframeable = await isIframeable({ url, requestDomain: domain }); + const res = await safeFetch(url); + const iframeable = isIframeable({ + headers: res.headers, + requestDomain: domain, + }); return NextResponse.json({ iframeable }); } catch (error) { diff --git a/apps/web/app/api/links/metatags/utils.ts b/apps/web/app/api/links/metatags/utils.ts index 97e57052304..5da6d9f86ef 100644 --- a/apps/web/app/api/links/metatags/utils.ts +++ b/apps/web/app/api/links/metatags/utils.ts @@ -1,13 +1,14 @@ +import { safeFetch } from "@/lib/api/safe-fetch"; import { recordMetatags } from "@/lib/upstash"; import { linkPreviewImageBase64PrefixRegex } from "@/lib/zod/schemas/images"; -import { fetchWithTimeout, isValidUrl } from "@dub/utils"; +import { isValidUrl } from "@dub/utils"; import { waitUntil } from "@vercel/functions"; import he from "he"; import { parse } from "node-html-parser"; export const getHtml = async (url: string) => { try { - const response = await fetchWithTimeout(url); + const response = await safeFetch(url); if (!response.ok) { // If we get a 406 or other error, check if it's a Cloudflare-protected site diff --git a/apps/web/app/api/providers/route.ts b/apps/web/app/api/providers/route.ts index ec9ac97ef3e..6540b49f0fa 100644 --- a/apps/web/app/api/providers/route.ts +++ b/apps/web/app/api/providers/route.ts @@ -1,4 +1,5 @@ import { handleAndReturnErrorResponse } from "@/lib/api/errors"; +import { safeFetch } from "@/lib/api/safe-fetch"; import { ratelimitOrThrow } from "@/lib/api/utils"; import { getUrlQuerySchema } from "@/lib/zod/schemas/links"; import { fetchWithTimeout } from "@dub/utils"; @@ -41,12 +42,11 @@ export async function GET(req: NextRequest) { urlObject.pathname = "/xyz"; - const headers = await fetchWithTimeout(urlObject.toString(), { - headers: { - method: "HEAD", - }, - redirect: "manual", - }) + const headers = await safeFetch( + urlObject.toString(), + { method: "HEAD" }, + { maxRedirects: 0 }, + ) .then((r) => ({ engine: r.headers.get("engine"), poweredBy: r.headers.get("x-powered-by"), diff --git a/apps/web/lib/api/safe-fetch.ts b/apps/web/lib/api/safe-fetch.ts new file mode 100644 index 00000000000..87f2de5b9a8 --- /dev/null +++ b/apps/web/lib/api/safe-fetch.ts @@ -0,0 +1,297 @@ +/** + * SSRF-safe outbound fetch. + * + * Wraps `fetch` with three protections against attacker-controlled URLs: + * 1. Scheme allowlist: only `http:` and `https:` are permitted. + * 2. Host check: the hostname is resolved (IPv4 + IPv6 via Google DoH) and + * every returned address is checked against a deny-list of private, + * loopback, link-local, cloud-metadata, multicast, and reserved ranges. + * 3. Manual redirect handling: each hop's `Location` is re-validated, so a + * whitelisted host cannot bounce us into an internal IP. + */ + +import { DubApiError } from "@/lib/api/errors"; + +const ALLOWED_PROTOCOLS = new Set(["http:", "https:"]); +const DEFAULT_TIMEOUT_MS = 5000; +const DEFAULT_MAX_REDIRECTS = 5; + +const IPV4_REGEX = /^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/; + +const BIG_ZERO = BigInt(0); +const BIG_16 = BigInt(16); +const BIG_32 = BigInt(32); +const BIG_FFFF = BigInt(0xffff); +const BIG_FFFFFFFF = BigInt(0xffffffff); + +const PRIVATE_IPV4_CIDRS = [ + "0.0.0.0/8", // current network + "10.0.0.0/8", // RFC1918 + "100.64.0.0/10", // CGNAT + "127.0.0.0/8", // loopback + "169.254.0.0/16", // link-local + cloud metadata + "172.16.0.0/12", // RFC1918 + "192.0.0.0/24", // IETF protocol assignments + "192.168.0.0/16", // RFC1918 + "198.18.0.0/15", // benchmarking + "224.0.0.0/4", // multicast + "240.0.0.0/4", // reserved +]; + +const PRIVATE_IPV6_CIDRS = [ + "::/128", // unspecified + "::1/128", // loopback + "fc00::/7", // unique local + "fe80::/10", // link-local + "ff00::/8", // multicast + "64:ff9b::/96", // IPv4/IPv6 translation + "2001:db8::/32", // documentation +]; + +function ipv4ToInt(ip: string): number | null { + const m = IPV4_REGEX.exec(ip); + if (!m) return null; + const octets = [m[1], m[2], m[3], m[4]].map((s) => parseInt(s, 10)); + if (octets.some((n) => Number.isNaN(n) || n < 0 || n > 255)) return null; + return ( + ((octets[0] << 24) | (octets[1] << 16) | (octets[2] << 8) | octets[3]) >>> 0 + ); +} + +function ipv4InCidr(ipInt: number, cidr: string): boolean { + const [rangeIp, prefix] = cidr.split("/"); + const prefixLen = parseInt(prefix, 10); + const rangeInt = ipv4ToInt(rangeIp); + if (rangeInt === null || prefixLen < 0 || prefixLen > 32) return false; + if (prefixLen === 0) return true; + const mask = (0xffffffff << (32 - prefixLen)) >>> 0; + return (ipInt & mask) === (rangeInt & mask); +} + +function isPrivateIpv4(ip: string): boolean { + const ipInt = ipv4ToInt(ip); + // Unparseable → fail closed. + if (ipInt === null) return true; + return PRIVATE_IPV4_CIDRS.some((c) => ipv4InCidr(ipInt, c)); +} + +/** Expand an IPv6 string (possibly with embedded IPv4) to a 128-bit bigint. */ +function ipv6ToBigInt(ip: string): bigint | null { + let value = ip.trim(); + if (value.startsWith("[") && value.endsWith("]")) { + value = value.slice(1, -1); + } + + // Handle embedded IPv4 (e.g. ::ffff:127.0.0.1) by converting the dotted + // suffix into two hextets. + const lastColon = value.lastIndexOf(":"); + if (lastColon !== -1) { + const suffix = value.slice(lastColon + 1); + if (suffix.includes(".")) { + const v4Int = ipv4ToInt(suffix); + if (v4Int === null) return null; + const hi = ((v4Int >>> 16) & 0xffff).toString(16); + const lo = (v4Int & 0xffff).toString(16); + value = `${value.slice(0, lastColon)}:${hi}:${lo}`; + } + } + + let head: string[] = []; + let tail: string[] = []; + const doubleColonIdx = value.indexOf("::"); + if (doubleColonIdx !== -1) { + // Reject more than one "::" + if (value.indexOf("::", doubleColonIdx + 1) !== -1) return null; + const left = value.slice(0, doubleColonIdx); + const right = value.slice(doubleColonIdx + 2); + head = left ? left.split(":") : []; + tail = right ? right.split(":") : []; + } else { + head = value.split(":"); + } + + const missing = 8 - head.length - tail.length; + if (missing < 0 || (doubleColonIdx === -1 && missing !== 0)) return null; + + const groups = [...head, ...Array(missing).fill("0"), ...tail]; + if (groups.length !== 8) return null; + + let result = BIG_ZERO; + for (const g of groups) { + if (!/^[0-9a-fA-F]{1,4}$/.test(g)) return null; + result = (result << BIG_16) | BigInt(parseInt(g, 16)); + } + return result; +} + +function ipv6InCidr(ipBig: bigint, cidr: string): boolean { + const [rangeIp, prefix] = cidr.split("/"); + const prefixLen = parseInt(prefix, 10); + const rangeBig = ipv6ToBigInt(rangeIp); + if (rangeBig === null || prefixLen < 0 || prefixLen > 128) return false; + if (prefixLen === 0) return true; + const shift = BigInt(128 - prefixLen); + return ipBig >> shift === rangeBig >> shift; +} + +function isPrivateIpv6(ip: string): boolean { + const ipBig = ipv6ToBigInt(ip); + if (ipBig === null) return true; + + // IPv4-mapped IPv6 (::ffff:0:0/96): unwrap the last 32 bits and re-check. + if (ipBig >> BIG_32 === BIG_FFFF) { + const v4Int = Number(ipBig & BIG_FFFFFFFF); + const v4Str = `${(v4Int >>> 24) & 0xff}.${(v4Int >>> 16) & 0xff}.${ + (v4Int >>> 8) & 0xff + }.${v4Int & 0xff}`; + return isPrivateIpv4(v4Str); + } + + return PRIVATE_IPV6_CIDRS.some((c) => ipv6InCidr(ipBig, c)); +} + +export function isPrivateIp(ip: string): boolean { + if (IPV4_REGEX.test(ip)) return isPrivateIpv4(ip); + if (ip.includes(":")) return isPrivateIpv6(ip); + // Anything we cannot classify, treat as private (fail closed). + return true; +} + +/** Returns the IP literal if `hostname` is one, else null. */ +function ipLiteralFromHostname(hostname: string): string | null { + if (IPV4_REGEX.test(hostname)) return hostname; + if (hostname.startsWith("[") && hostname.endsWith("]")) { + return hostname.slice(1, -1); + } + return null; +} + +async function resolveHostnameToIps(hostname: string): Promise { + const query = async (type: "A" | "AAAA"): Promise => { + try { + const res = await fetch( + `https://dns.google/resolve?name=${encodeURIComponent(hostname)}&type=${type}`, + ); + if (!res.ok) return []; + const data = (await res.json()) as { + Answer?: { data: string; type: number }[]; + }; + // Type 1 = A, type 28 = AAAA. Ignore CNAME / other intermediate records. + const wantedType = type === "A" ? 1 : 28; + return (data.Answer ?? []) + .filter((a) => a.type === wantedType) + .map((a) => a.data); + } catch { + return []; + } + }; + + const [a, aaaa] = await Promise.all([query("A"), query("AAAA")]); + return [...a, ...aaaa]; +} + +async function assertUrlIsSafe(url: URL): Promise { + if (!ALLOWED_PROTOCOLS.has(url.protocol)) { + throw new DubApiError({ + code: "unprocessable_entity", + message: `URL protocol "${url.protocol}" is not allowed.`, + }); + } + + const literal = ipLiteralFromHostname(url.hostname); + if (literal !== null) { + if (isPrivateIp(literal)) { + throw new DubApiError({ + code: "unprocessable_entity", + message: "URL resolves to a disallowed IP address.", + }); + } + return; + } + + const ips = await resolveHostnameToIps(url.hostname); + if (ips.length === 0) { + throw new DubApiError({ + code: "unprocessable_entity", + message: `Could not resolve hostname: ${url.hostname}`, + }); + } + if (ips.some((ip) => isPrivateIp(ip))) { + throw new DubApiError({ + code: "unprocessable_entity", + message: "URL resolves to a disallowed IP address.", + }); + } +} + +type SafeFetchOptions = { + /** Total timeout across all redirect hops. Defaults to 5000ms. */ + timeoutMs?: number; + /** Max number of redirects to follow. Defaults to 5. Set to 0 to disable. */ + maxRedirects?: number; +}; + +export async function safeFetch( + url: string, + init?: RequestInit, + opts: SafeFetchOptions = {}, +): Promise { + const timeoutMs = opts.timeoutMs ?? DEFAULT_TIMEOUT_MS; + const maxRedirects = opts.maxRedirects ?? DEFAULT_MAX_REDIRECTS; + + let currentUrl: URL; + try { + currentUrl = new URL(url); + } catch { + throw new DubApiError({ + code: "unprocessable_entity", + message: `Invalid URL: ${url}`, + }); + } + + const deadline = Date.now() + timeoutMs; + let hops = 0; + + while (true) { + await assertUrlIsSafe(currentUrl); + + const remaining = deadline - Date.now(); + if (remaining <= 0) { + throw new Error("Request timed out"); + } + + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), remaining); + + let response: Response; + try { + response = await fetch(currentUrl.toString(), { + ...init, + redirect: "manual", + signal: controller.signal, + }); + } finally { + clearTimeout(timer); + } + + const isRedirect = response.status >= 300 && response.status < 400; + const location = response.headers.get("location"); + if (!isRedirect || !location) { + return response; + } + + if (hops >= maxRedirects) { + throw new Error("Too many redirects"); + } + + try { + currentUrl = new URL(location, currentUrl); + } catch { + throw new DubApiError({ + code: "unprocessable_entity", + message: `Invalid redirect location: ${location}`, + }); + } + hops += 1; + } +} diff --git a/apps/web/lib/sitemaps/import-tracked-sitemaps.ts b/apps/web/lib/sitemaps/import-tracked-sitemaps.ts index 15256e684ca..0e9a290c90d 100644 --- a/apps/web/lib/sitemaps/import-tracked-sitemaps.ts +++ b/apps/web/lib/sitemaps/import-tracked-sitemaps.ts @@ -1,8 +1,8 @@ import { bulkCreateLinks } from "@/lib/api/links/bulk-create-links"; +import { safeFetch } from "@/lib/api/safe-fetch"; import type { TrackedSitemap } from "@/lib/sitemaps/site-visit-tracking"; import { ProcessedLinkProps } from "@/lib/types"; import { prisma } from "@dub/prisma"; -import { fetchWithTimeout } from "@dub/utils/src"; import { XMLParser } from "fast-xml-parser"; type SitemapXmlUrlEntry = { @@ -72,7 +72,12 @@ async function decompressIfGzip(buffer: ArrayBuffer): Promise { async function fetchAndParseSitemap( sitemapUrl: string, ): Promise { - const response = await fetchWithTimeout(sitemapUrl, { redirect: "error" }); // don't follow redirects + const response = await safeFetch(sitemapUrl, undefined, { maxRedirects: 0 }); + if (!response.ok) { + throw new Error( + `Failed to fetch sitemap: ${response.status} ${response.statusText}`, + ); + } const MAX_SITEMAP_BYTES = 10 * 1024 * 1024; // 10 MB const contentLength = response.headers.get("content-length"); if (contentLength && parseInt(contentLength, 10) > MAX_SITEMAP_BYTES) { diff --git a/apps/web/tests/misc/import-tracked-sitemaps.test.ts b/apps/web/tests/misc/import-tracked-sitemaps.test.ts index 492e7e937db..b90c0c26587 100644 --- a/apps/web/tests/misc/import-tracked-sitemaps.test.ts +++ b/apps/web/tests/misc/import-tracked-sitemaps.test.ts @@ -6,6 +6,8 @@ import { parseTrackedSitemaps } from "@/lib/sitemaps/site-visit-tracking"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { gzipSync } from "zlib"; +vi.mock("server-only", () => ({})); + function toArrayBuffer(buf: Buffer): ArrayBuffer { const ab = new ArrayBuffer(buf.length); new Uint8Array(ab).set(buf); @@ -22,11 +24,14 @@ function makeFetchResponse( status, headers: new Headers(headers), arrayBuffer: () => Promise.resolve(toArrayBuffer(body)), + json: () => Promise.resolve(JSON.parse(body.toString("utf8"))), } as unknown as Response; } -function redirectResponse(location: string, status = 302): Response { - return makeFetchResponse(Buffer.alloc(0), status, { Location: location }); +function dohResponseFor(ip: string, type: 1 | 28 = 1): Response { + return makeFetchResponse( + Buffer.from(JSON.stringify({ Answer: [{ data: ip, type }] })), + ); } const urlsetXml = (urls: string[]) => @@ -134,7 +139,15 @@ describe("parseTrackedSitemaps", () => { }); describe("crawlSitemapUrls", () => { - const mockFetch = vi.fn(); + const mockContentFetch = vi.fn(); + const mockFetch = vi.fn((url: string | URL | Request) => { + const href = typeof url === "string" ? url : url.toString(); + if (href.includes("dns.google/resolve")) { + // Return a public IP so safeFetch lets the request through. + return Promise.resolve(dohResponseFor("93.184.216.34")); + } + return mockContentFetch(href); + }); beforeEach(() => { vi.stubGlobal("fetch", mockFetch); @@ -151,7 +164,7 @@ describe("crawlSitemapUrls", () => { "https://example.com/page-1", "https://example.com/page-2", ]; - mockFetch.mockResolvedValue( + mockContentFetch.mockResolvedValue( makeFetchResponse(Buffer.from(urlsetXml(pageUrls))), ); @@ -165,7 +178,7 @@ describe("crawlSitemapUrls", () => { }); it("deduplicates URLs across multiple entries", async () => { - mockFetch.mockResolvedValue( + mockContentFetch.mockResolvedValue( makeFetchResponse( Buffer.from( urlsetXml([ @@ -186,7 +199,7 @@ describe("crawlSitemapUrls", () => { describe("sitemapindex", () => { it("does not fetch nested sitemaps from a sitemap index", async () => { - mockFetch.mockResolvedValueOnce( + mockContentFetch.mockResolvedValueOnce( makeFetchResponse( Buffer.from( sitemapindexXml([ @@ -203,11 +216,11 @@ describe("crawlSitemapUrls", () => { expect(urls).toEqual([]); expect(hadErrors).toBe(false); - expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockContentFetch).toHaveBeenCalledTimes(1); }); it("returns no page URLs when the index only references itself", async () => { - mockFetch.mockResolvedValue( + mockContentFetch.mockResolvedValue( makeFetchResponse( Buffer.from(sitemapindexXml(["https://example.com/sitemap.xml"])), ), @@ -218,13 +231,13 @@ describe("crawlSitemapUrls", () => { ); expect(urls).toEqual([]); - expect(mockFetch).toHaveBeenCalledTimes(1); + expect(mockContentFetch).toHaveBeenCalledTimes(1); }); }); describe("nested xml in urlset", () => { it("skips .xml loc entries instead of treating them as pages", async () => { - mockFetch.mockResolvedValue( + mockContentFetch.mockResolvedValue( makeFetchResponse( Buffer.from( urlsetXml([ @@ -248,7 +261,7 @@ describe("crawlSitemapUrls", () => { const many = Array.from({ length: MAX_URLS_PER_SITEMAP + 50 }, (_, i) => i === 0 ? "https://example.com/first" : `https://example.com/p/${i}`, ); - mockFetch.mockResolvedValue( + mockContentFetch.mockResolvedValue( makeFetchResponse(Buffer.from(urlsetXml(many))), ); @@ -264,7 +277,7 @@ describe("crawlSitemapUrls", () => { it("decompresses a gzip-compressed sitemap response", async () => { const xml = urlsetXml(["https://example.com/page-1"]); const compressed = gzipSync(xml); - mockFetch.mockResolvedValue(makeFetchResponse(compressed)); + mockContentFetch.mockResolvedValue(makeFetchResponse(compressed)); const { urls } = await crawlSitemapUrls( "https://example.com/sitemap.xml.gz", @@ -275,7 +288,7 @@ describe("crawlSitemapUrls", () => { it("handles a plain XML response without attempting decompression", async () => { const xml = urlsetXml(["https://example.com/page-1"]); - mockFetch.mockResolvedValue(makeFetchResponse(Buffer.from(xml))); + mockContentFetch.mockResolvedValue(makeFetchResponse(Buffer.from(xml))); const { urls } = await crawlSitemapUrls( "https://example.com/sitemap.xml", @@ -284,4 +297,22 @@ describe("crawlSitemapUrls", () => { expect(urls).toContain("https://example.com/page-1"); }); }); + + describe("non-2xx responses", () => { + it("reports hadErrors instead of parsing the body on a non-2xx response", async () => { + mockContentFetch.mockResolvedValue( + makeFetchResponse( + Buffer.from(urlsetXml(["https://example.com/page-1"])), + 500, + ), + ); + + const { urls, hadErrors } = await crawlSitemapUrls( + "https://example.com/sitemap.xml", + ); + + expect(urls).toEqual([]); + expect(hadErrors).toBe(true); + }); + }); }); diff --git a/apps/web/tests/misc/safe-fetch.test.ts b/apps/web/tests/misc/safe-fetch.test.ts new file mode 100644 index 00000000000..fd366accd74 --- /dev/null +++ b/apps/web/tests/misc/safe-fetch.test.ts @@ -0,0 +1,100 @@ +import { describe, expect, it, vi } from "vitest"; + +// `server-only` throws on import outside of a Next.js server bundle. Stub it +// so we can import `safe-fetch.ts`, which depends transitively on it via +// `@/lib/api/errors`. +vi.mock("server-only", () => ({})); + +import { isPrivateIp } from "@/lib/api/safe-fetch"; + +describe("isPrivateIp", () => { + describe("IPv4", () => { + it.each([ + "0.0.0.0", + "10.0.0.1", + "10.255.255.255", + "100.64.0.1", + "127.0.0.1", + "127.255.255.254", + "169.254.169.254", // AWS / GCP / Azure metadata + "172.16.0.1", + "172.31.255.254", + "192.168.0.1", + "192.168.1.1", + "198.18.0.1", + "224.0.0.1", // multicast + "240.0.0.1", // reserved + "255.255.255.255", + ])("flags %s as private", (ip) => { + expect(isPrivateIp(ip)).toBe(true); + }); + + it.each([ + "1.1.1.1", + "8.8.8.8", + "9.9.9.9", + "11.0.0.1", // just outside 10.0.0.0/8 + "100.63.255.255", // just outside CGNAT + "172.15.255.255", // just outside 172.16.0.0/12 + "172.32.0.1", // just outside 172.16.0.0/12 + "192.0.1.1", // just outside 192.0.0.0/24 + "192.167.255.255", // just outside 192.168.0.0/16 + "193.0.0.1", + "223.255.255.255", // just outside multicast + ])("does not flag %s as private", (ip) => { + expect(isPrivateIp(ip)).toBe(false); + }); + }); + + describe("IPv6", () => { + it.each([ + "::", + "::1", + "fc00::", + "fc00::1", + "fd00::1", + "fdff::ffff", + "fe80::1", + "fe80::abcd:1234", + "ff00::1", + "ff02::1", + "64:ff9b::1.2.3.4", + "2001:db8::1", + ])("flags %s as private", (ip) => { + expect(isPrivateIp(ip)).toBe(true); + }); + + it.each([ + "::ffff:127.0.0.1", // IPv4-mapped loopback + "::ffff:10.0.0.1", // IPv4-mapped RFC1918 + "::ffff:169.254.169.254", // IPv4-mapped metadata + "::ffff:192.168.1.1", + ])("flags IPv4-mapped %s as private", (ip) => { + expect(isPrivateIp(ip)).toBe(true); + }); + + it.each([ + "2606:4700:4700::1111", // Cloudflare DNS + "2001:4860:4860::8888", // Google DNS + "2620:fe::fe", // Quad9 + "::ffff:1.1.1.1", // IPv4-mapped public + "::ffff:8.8.8.8", + ])("does not flag %s as private", (ip) => { + expect(isPrivateIp(ip)).toBe(false); + }); + }); + + describe("invalid inputs", () => { + it.each([ + "", + "not-an-ip", + "999.999.999.999", + "256.0.0.1", + "1.2.3", + "::gggg", + ":::1", + ])("treats %s as private (fail closed)", (value) => { + expect(isPrivateIp(value)).toBe(true); + }); + }); +}); diff --git a/apps/web/tests/setupTests.ts b/apps/web/tests/setupTests.ts index f73f75a3615..4d6837a3262 100644 --- a/apps/web/tests/setupTests.ts +++ b/apps/web/tests/setupTests.ts @@ -20,6 +20,9 @@ vi.mock("@axiomhq/logging", () => ({ AxiomJSTransport: class { constructor(_config: any) {} }, + ConsoleTransport: class { + constructor(_config?: any) {} + }, Logger: class { constructor(_config: any) {} log = vi.fn(); diff --git a/packages/utils/src/functions/is-iframeable.ts b/packages/utils/src/functions/is-iframeable.ts index 7b6a26a3b21..836d51c43e3 100644 --- a/packages/utils/src/functions/is-iframeable.ts +++ b/packages/utils/src/functions/is-iframeable.ts @@ -1,14 +1,14 @@ -// check if a link can be displayed in an iframe -export const isIframeable = async ({ - url, +// Determine whether a fetched page can be displayed in an iframe by the +// requesting domain, given the response's CSP / X-Frame-Options headers. +// The caller is responsible for fetching the URL (use an SSRF-safe fetcher). +export const isIframeable = ({ + headers, requestDomain, }: { - url: string; + headers: Headers; requestDomain: string; }) => { - const res = await fetch(url); - - const cspHeader = res.headers.get("content-security-policy"); + const cspHeader = headers.get("content-security-policy"); if (cspHeader) { const frameAncestorsMatch = cspHeader.match( /frame-ancestors\s+([\s\S]+?)(?=;|$)/i, @@ -24,7 +24,13 @@ export const isIframeable = async ({ } } - const xFrameOptions = res.headers.get("X-Frame-Options"); + // X-Frame-Options values are tokens per RFC 7034 but some servers send + // them lowercased, padded, or duplicated. Normalize before comparing. + const xFrameOptions = headers + .get("x-frame-options") + ?.split(",")[0] + ?.trim() + .toUpperCase(); if (xFrameOptions === "DENY" || xFrameOptions === "SAMEORIGIN") { return false; }