Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ permissions:
checks: write

env:
GO_VERSION: "1.24.9"
GO_VERSION: "1.25.11"

jobs:
# Static analysis and code quality check
Expand Down Expand Up @@ -88,18 +88,19 @@ jobs:

- name: Run Go Vulnerability Check
run: |
go install golang.org/x/vuln/cmd/govulncheck@latest
go install golang.org/x/vuln/cmd/govulncheck@v1.3.0
govulncheck ./...

- name: Run dependency scan
uses: aquasecurity/trivy-action@0.33.1
uses: aquasecurity/trivy-action@v0.36.0
with:
scan-type: "fs"
scan-ref: "."
format: "sarif"
output: "trivy-results.sarif"
severity: "CRITICAL,HIGH,MEDIUM"
timeout: "10m"
version: "v0.71.2"

- name: Upload security scan results
uses: github/codeql-action/upload-sarif@v4
Expand Down
11 changes: 11 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ type Config struct {
Mode string // "native" or "proxy"
Provider string // "hmac", "okta", "google", "azure"
RedirectURIs string // Redirect URIs (single or comma-separated)
// AllowedClientRedirectURIs are exact client callback URIs allowed in fixed
// redirect mode in addition to localhost/loopback callbacks.
AllowedClientRedirectURIs string

// OIDC configuration
Issuer string
Expand Down Expand Up @@ -187,6 +190,13 @@ func (b *ConfigBuilder) WithRedirectURIs(uris string) *ConfigBuilder {
return b
}

// WithAllowedClientRedirectURIs sets exact client callback URIs allowed in
// fixed redirect mode in addition to localhost/loopback callbacks.
func (b *ConfigBuilder) WithAllowedClientRedirectURIs(uris string) *ConfigBuilder {
b.config.AllowedClientRedirectURIs = uris
return b
}

// WithIssuer sets the OIDC issuer
func (b *ConfigBuilder) WithIssuer(issuer string) *ConfigBuilder {
b.config.Issuer = issuer
Expand Down Expand Up @@ -285,6 +295,7 @@ func FromEnv() (*Config, error) {
WithMode(getEnv("OAUTH_MODE", "")).
WithProvider(getEnv("OAUTH_PROVIDER", "")).
WithRedirectURIs(getEnv("OAUTH_REDIRECT_URIS", "")).
WithAllowedClientRedirectURIs(getEnv("OAUTH_ALLOWED_CLIENT_REDIRECT_URIS", "")).
WithIssuer(getEnv("OIDC_ISSUER", "")).
WithAudience(getEnv("OIDC_AUDIENCE", "")).
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
Expand Down
43 changes: 43 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,49 @@ func TestConfigBuilder(t *testing.T) {
}
}

func TestAllowedClientRedirectURIsConfig(t *testing.T) {
cfg, err := NewConfigBuilder().
WithMode("proxy").
WithProvider("okta").
WithIssuer("https://okta.example.com").
WithAudience("test-audience").
WithClientID("client-123").
WithClientSecret("secret-456").
WithRedirectURIs("https://mcp-server.com/oauth/callback").
WithAllowedClientRedirectURIs("cursor://anysphere.cursor-mcp/oauth/callback").
Build()
if err != nil {
t.Fatalf("Build() error = %v", err)
}

oauth2Config := NewOAuth2ConfigFromConfig(cfg, "test")
if got, want := oauth2Config.AllowedClientRedirectURIs, "cursor://anysphere.cursor-mcp/oauth/callback"; got != want {
t.Fatalf("AllowedClientRedirectURIs = %q, want %q", got, want)
}
}

func TestAllowedClientRedirectURIsEnvFallback(t *testing.T) {
t.Setenv("OAUTH_ALLOWED_CLIENT_REDIRECT_URIS", "cursor://anysphere.cursor-mcp/oauth/callback")

cfg, err := NewConfigBuilder().
WithMode("proxy").
WithProvider("okta").
WithIssuer("https://okta.example.com").
WithAudience("test-audience").
WithClientID("client-123").
WithClientSecret("secret-456").
WithRedirectURIs("https://mcp-server.com/oauth/callback").
Build()
if err != nil {
t.Fatalf("Build() error = %v", err)
}

oauth2Config := NewOAuth2ConfigFromConfig(cfg, "test")
if got, want := oauth2Config.AllowedClientRedirectURIs, "cursor://anysphere.cursor-mcp/oauth/callback"; got != want {
t.Fatalf("AllowedClientRedirectURIs = %q, want %q", got, want)
}
}

func TestOAuth2HandlerRequestsProviderDefaultScopes(t *testing.T) {
tests := []struct {
name string
Expand Down
13 changes: 13 additions & 0 deletions docs/CONFIGURATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,17 @@ RedirectURIs: "https://your-server.com/oauth/callback"

Server uses this URI with provider. For security, client redirects must be localhost only.

To support a native client callback in fixed redirect mode, keep `RedirectURIs`
as the server callback and add exact client callbacks separately:

```go
RedirectURIs: "https://your-server.com/oauth/callback"
AllowedClientRedirectURIs: "cursor://anysphere.cursor-mcp/oauth/callback"
```

This preserves dynamic localhost callbacks for development tools while allowing
only the configured native client callbacks.

**Multiple URIs (Allowlist):**

```go
Expand Down Expand Up @@ -554,6 +565,7 @@ OAUTH_CLIENT_ID=your-client-id
OAUTH_CLIENT_SECRET=your-client-secret
OAUTH_SERVER_URL=https://your-server.com
OAUTH_REDIRECT_URIS=https://your-server.com/oauth/callback
OAUTH_ALLOWED_CLIENT_REDIRECT_URIS=cursor://anysphere.cursor-mcp/oauth/callback
```

Load in code:
Expand All @@ -572,6 +584,7 @@ func main() {
ClientSecret: os.Getenv("OAUTH_CLIENT_SECRET"),
ServerURL: os.Getenv("OAUTH_SERVER_URL"),
RedirectURIs: os.Getenv("OAUTH_REDIRECT_URIS"),
AllowedClientRedirectURIs: os.Getenv("OAUTH_ALLOWED_CLIENT_REDIRECT_URIS"),
JWTSecret: []byte(os.Getenv("JWT_SECRET")),
})
}
Expand Down
139 changes: 134 additions & 5 deletions fixed_redirect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@ package oauth

import (
"crypto/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"golang.org/x/oauth2"
)

func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
Expand Down Expand Up @@ -39,19 +45,19 @@ func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
name: "HTTPS production domain rejected",
clientURI: "https://evil.com/callback",
shouldPass: false,
expectedError: "Fixed redirect mode only allows localhost",
expectedError: "fixed redirect mode only allows localhost",
},
{
name: "HTTP production domain rejected",
clientURI: "http://evil.com/callback",
shouldPass: false,
expectedError: "HTTPS required for non-localhost",
expectedError: "https required for non-localhost",
},
{
name: "localhost subdomain rejected",
clientURI: "https://localhost.evil.com/callback",
shouldPass: false,
expectedError: "Fixed redirect mode only allows localhost",
expectedError: "fixed redirect mode only allows localhost",
},
{
name: "URI with fragment rejected",
Expand All @@ -63,7 +69,7 @@ func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
name: "Custom scheme rejected",
clientURI: "custom://localhost:8080/callback",
shouldPass: false,
expectedError: "Invalid redirect_uri scheme",
expectedError: "invalid redirect_uri scheme",
},
}

Expand All @@ -75,7 +81,7 @@ func TestFixedRedirectModeLocalhostOnly(t *testing.T) {
t.Errorf("Expected localhost detection to pass for %s", tt.clientURI)
}

if !tt.shouldPass && isLocalhost && tt.expectedError != "must not contain fragment" && tt.expectedError != "Invalid redirect_uri scheme" {
if !tt.shouldPass && isLocalhost && tt.expectedError != "must not contain fragment" && tt.expectedError != "invalid redirect_uri scheme" {
t.Errorf("Expected localhost detection to fail for %s", tt.clientURI)
}

Expand All @@ -100,3 +106,126 @@ func TestFixedRedirectModeSecurityModel(t *testing.T) {
t.Log("Use Case: Development tools (MCP Inspector) running on localhost")
t.Log("Production: Use allowlist mode instead")
}

func TestFixedRedirectModeAllowsConfiguredClientRedirectURI(t *testing.T) {
handler := newFixedRedirectTestHandler(t, "cursor://anysphere.cursor-mcp/oauth/callback")

req := httptest.NewRequest(http.MethodGet, "/oauth/authorize?client_id=test-client&redirect_uri="+url.QueryEscape("cursor://anysphere.cursor-mcp/oauth/callback")+"&response_type=code&code_challenge=test&code_challenge_method=S256&state=test-state", nil)
recorder := httptest.NewRecorder()

handler.HandleAuthorize(recorder, req)

if recorder.Code != http.StatusTemporaryRedirect {
t.Fatalf("status = %d, expected %d, body: %s", recorder.Code, http.StatusTemporaryRedirect, recorder.Body.String())
}

location := recorder.Header().Get("Location")
if !strings.HasPrefix(location, "https://okta.example/authorize?") {
t.Fatalf("Location = %q, expected Okta authorize redirect", location)
}
if !strings.Contains(location, url.QueryEscape("https://mcp-server.com/oauth/callback")) {
t.Fatalf("Location = %q, expected provider redirect_uri to remain the fixed server callback", location)
}
}

func TestFixedRedirectModeRejectsUnconfiguredCustomScheme(t *testing.T) {
handler := newFixedRedirectTestHandler(t, "")

req := httptest.NewRequest(http.MethodGet, "/oauth/authorize?client_id=test-client&redirect_uri="+url.QueryEscape("cursor://anysphere.cursor-mcp/oauth/callback")+"&response_type=code&code_challenge=test&code_challenge_method=S256&state=test-state", nil)
recorder := httptest.NewRecorder()

handler.HandleAuthorize(recorder, req)

if recorder.Code != http.StatusBadRequest {
t.Fatalf("status = %d, expected %d", recorder.Code, http.StatusBadRequest)
}
if !strings.Contains(recorder.Body.String(), "invalid redirect_uri scheme") {
t.Fatalf("body = %q, expected invalid redirect_uri scheme", recorder.Body.String())
}
}

func TestFixedRedirectModeStillAllowsLocalhostRedirectURI(t *testing.T) {
handler := newFixedRedirectTestHandler(t, "")

req := httptest.NewRequest(http.MethodGet, "/oauth/authorize?client_id=test-client&redirect_uri="+url.QueryEscape("http://127.0.0.1:3333/oauth/callback")+"&response_type=code&code_challenge=test&code_challenge_method=S256&state=test-state", nil)
recorder := httptest.NewRecorder()

handler.HandleAuthorize(recorder, req)

if recorder.Code != http.StatusTemporaryRedirect {
t.Fatalf("status = %d, expected %d, body: %s", recorder.Code, http.StatusTemporaryRedirect, recorder.Body.String())
}
}

func TestFixedRedirectCallbackAllowsConfiguredClientRedirectURI(t *testing.T) {
handler := newFixedRedirectTestHandler(t, "cursor://anysphere.cursor-mcp/oauth/callback")
signedState, err := handler.signState(map[string]string{
"state": "client-state",
"redirect": "cursor://anysphere.cursor-mcp/oauth/callback",
})
if err != nil {
t.Fatalf("sign state: %v", err)
}

req := httptest.NewRequest(http.MethodGet, "/oauth/callback?code=auth-code&state="+url.QueryEscape(signedState), nil)
recorder := httptest.NewRecorder()

handler.HandleCallback(recorder, req)

if recorder.Code != http.StatusFound {
t.Fatalf("status = %d, expected %d, body: %s", recorder.Code, http.StatusFound, recorder.Body.String())
}
location := recorder.Header().Get("Location")
if !strings.HasPrefix(location, "cursor://anysphere.cursor-mcp/oauth/callback?") {
t.Fatalf("Location = %q, expected Cursor callback redirect", location)
}
if !strings.Contains(location, "code=auth-code") || !strings.Contains(location, "state=client-state") {
t.Fatalf("Location = %q, expected code and original state", location)
}
}

func TestFixedRedirectCallbackRejectsUnconfiguredCustomScheme(t *testing.T) {
handler := newFixedRedirectTestHandler(t, "")
signedState, err := handler.signState(map[string]string{
"state": "client-state",
"redirect": "cursor://anysphere.cursor-mcp/oauth/callback",
})
if err != nil {
t.Fatalf("sign state: %v", err)
}

req := httptest.NewRequest(http.MethodGet, "/oauth/callback?code=auth-code&state="+url.QueryEscape(signedState), nil)
recorder := httptest.NewRecorder()

handler.HandleCallback(recorder, req)

if recorder.Code != http.StatusBadRequest {
t.Fatalf("status = %d, expected %d", recorder.Code, http.StatusBadRequest)
}
if !strings.Contains(recorder.Body.String(), "Invalid redirect URI in state") {
t.Fatalf("body = %q, expected Invalid redirect URI in state", recorder.Body.String())
}
}

func newFixedRedirectTestHandler(t *testing.T, allowedClientRedirectURIs string) *OAuth2Handler {
t.Helper()

key := make([]byte, 32)
_, _ = rand.Read(key)

return &OAuth2Handler{
config: &OAuth2Config{
RedirectURIs: "https://mcp-server.com/oauth/callback",
AllowedClientRedirectURIs: allowedClientRedirectURIs,
stateSigningKey: key,
},
oauth2Config: &oauth2.Config{
ClientID: "test-client",
RedirectURL: "https://mcp-server.com/oauth/callback",
Endpoint: oauth2.Endpoint{
AuthURL: "https://okta.example/authorize",
},
},
logger: &defaultLogger{},
}
}
13 changes: 8 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
module github.com/Vungle/oauth-mcp-proxy

go 1.24.9
go 1.25.11

require (
github.com/coreos/go-oidc/v3 v3.16.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/mark3labs/mcp-go v0.41.1
github.com/modelcontextprotocol/go-sdk v1.0.0
golang.org/x/oauth2 v0.32.0
github.com/modelcontextprotocol/go-sdk v1.4.1
golang.org/x/oauth2 v0.34.0
)

require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/go-jose/go-jose/v4 v4.1.4 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/spf13/cast v1.8.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/sys v0.40.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading
Loading