diff --git a/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types.go b/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types.go index d6fdc9bdd..ae3d0152b 100644 --- a/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types.go +++ b/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types.go @@ -271,6 +271,23 @@ func (p *OIDCPolicy) redirectURL(igwURL *url.URL) (*url.URL, error) { return redirectURL, nil } +// GetBaseURL returns the base URL (scheme + host + port) for post-authentication redirects. +// It derives this from spec.provider.redirectURI if set, otherwise from igwURL. +func (p *OIDCPolicy) GetBaseURL(igwURL *url.URL) (*url.URL, error) { + redirectURL, err := p.redirectURL(igwURL) + if err != nil { + return nil, err + } + + // Extract base URL (scheme + host, no path or query) + baseURL := &url.URL{ + Scheme: redirectURL.Scheme, + Host: redirectURL.Host, + } + + return baseURL, nil +} + // +kubebuilder:object:root=true // OIDCPolicyList contains a list of OIDCPolicy diff --git a/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types_test.go b/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types_test.go index 0eea0df81..fced76e52 100644 --- a/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types_test.go +++ b/cmd/extensions/oidc-policy/api/v1alpha1/oidcpolicy_types_test.go @@ -163,6 +163,103 @@ func TestOIDCPolicyStatus_Equals(t *testing.T) { } } +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + redirectURI string + igwURL string + expectedBase string + }{ + { + name: "No custom redirectURI - uses igwURL", + redirectURI: "", + igwURL: "http://gateway.example.com:8001", + expectedBase: "http://gateway.example.com:8001", + }, + { + name: "Custom redirectURI with non-standard port", + redirectURI: "https://public.example.com:8443/auth/callback", + igwURL: "http://gateway.example.com:8001", + expectedBase: "https://public.example.com:8443", + }, + { + name: "Custom redirectURI with standard port", + redirectURI: "https://public.example.com/auth/callback", + igwURL: "http://gateway.example.com:8001", + expectedBase: "https://public.example.com", + }, + { + name: "Custom redirectURI with different scheme", + redirectURI: "http://external.example.com:9000/custom/callback", + igwURL: "https://gateway.example.com", + expectedBase: "http://external.example.com:9000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy := &OIDCPolicy{ + Spec: OIDCPolicySpec{ + OIDCPolicySpecProper: OIDCPolicySpecProper{ + Provider: &Provider{ + IssuerURL: "https://issuer.com", + ClientID: "client123", + RedirectURI: tt.redirectURI, + }, + }, + }, + } + + igwURL, err := url.Parse(tt.igwURL) + if err != nil { + t.Fatal(err) + } + + baseURL, err := policy.GetBaseURL(igwURL) + if err != nil { + t.Fatalf("GetBaseURL() error = %v", err) + } + + if baseURL.String() != tt.expectedBase { + t.Errorf("GetBaseURL() = %v, want %v", baseURL.String(), tt.expectedBase) + } + }) + } +} + +func TestGetBaseURL_ExtractsBaseFromRedirectURI(t *testing.T) { + policy := &OIDCPolicy{ + Spec: OIDCPolicySpec{ + OIDCPolicySpecProper: OIDCPolicySpecProper{ + Provider: &Provider{ + IssuerURL: "https://issuer.com", + ClientID: "client123", + RedirectURI: "https://public.example.com:8443/auth/callback?foo=bar", + }, + }, + }, + } + + igwURL, _ := url.Parse("http://gateway.example.com:8001") + baseURL, err := policy.GetBaseURL(igwURL) + if err != nil { + t.Fatalf("GetBaseURL() error = %v", err) + } + + // Base URL should only have scheme and host, no path or query + if baseURL.String() != "https://public.example.com:8443" { + t.Errorf("GetBaseURL() = %v, want https://public.example.com:8443", baseURL.String()) + } + + if baseURL.Path != "" { + t.Errorf("GetBaseURL() path should be empty, got %v", baseURL.Path) + } + + if baseURL.RawQuery != "" { + t.Errorf("GetBaseURL() query should be empty, got %v", baseURL.RawQuery) + } +} + func mockMinimalOIDCPolicy() *OIDCPolicy { return &OIDCPolicy{ TypeMeta: metav1.TypeMeta{}, diff --git a/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler.go b/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler.go index aad358e4b..973a23357 100644 --- a/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler.go +++ b/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler.go @@ -40,14 +40,21 @@ type ingressGatewayInfo struct { Name string `json:"name"` Namespace string `json:"namespace"` Protocol gatewayapiv1.ProtocolType `json:"protocol"` + Port int32 `json:"port"` url *url.URL } func (g *ingressGatewayInfo) GetURL() *url.URL { if g.url == nil { + host := g.Hostname + // Include port if it's not the standard port for the protocol + if (g.Protocol == gatewayapiv1.HTTPProtocolType && g.Port != 80) || + (g.Protocol == gatewayapiv1.HTTPSProtocolType && g.Port != 443) { + host = fmt.Sprintf("%s:%d", g.Hostname, g.Port) + } g.url = &url.URL{ Scheme: strings.ToLower(string(g.Protocol)), - Host: g.Hostname, + Host: host, } } return g.url @@ -109,6 +116,7 @@ func (r *OIDCPolicyReconciler) Reconcile(ctx context.Context, request reconcile. oidcPolicy, `{"protocol": self.findGateways()[0].spec.listeners[0].protocol, "hostname": self.findGateways()[0].spec.listeners[0].hostname, + "port": self.findGateways()[0].spec.listeners[0].port, "name": self.findGateways()[0].metadata.name, "namespace": self.findGateways()[0].metadata.namespace}`, true) @@ -282,6 +290,11 @@ func (r *OIDCPolicyReconciler) reconcileHTTPRoute(ctx context.Context, desired * return err } +func buildTargetCookieExpression(hostname string, protocol gatewayapiv1.ProtocolType) string { + return fmt.Sprintf(` +"target=" + request.path + (has(request.query) && request.query != "" ? "?" + request.query : "") + "; domain=%s; HttpOnly; %s SameSite=Lax; Path=/; Max-Age=3600"`, hostname, getSecureFlag(protocol)) +} + func buildMainAuthPolicy(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) (*kuadrantv1.AuthPolicy, error) { authorizeURL, err := pol.GetAuthorizeURL(igw.GetURL()) if err != nil { @@ -292,8 +305,7 @@ func buildMainAuthPolicy(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) (*ku return nil, err } - setCookie := fmt.Sprintf(` -"target=" + request.path + "; domain=%s; HttpOnly; %s SameSite=Lax; Path=/; Max-Age=3600"`, igw.Hostname, getSecureFlag(igw.Protocol)) + setCookie := buildTargetCookieExpression(igw.Hostname, igw.Protocol) var authorization = map[string]kuadrantv1.MergeableAuthorizationSpec{} var authPatterns []authorinov1beta3.PatternExpressionOrRef @@ -427,6 +439,14 @@ func buildCallbackHTTPRoute(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) * } } +func buildOpaAuthorizationRule(baseURL *url.URL, igwURL *url.URL, authorizeURL string) string { + return fmt.Sprintf(`cookies := { name: value | raw_cookies := input.request.headers.cookie; cookie_parts := split(raw_cookies, ";"); part := cookie_parts[_]; trimmed := trim(part, " "); eq_idx := indexof(trimmed, "="); eq_idx != -1; name := trim(substring(trimmed, 0, eq_idx), " "); value := trim(substring(trimmed, eq_idx + 1, -1), " ")} +location := concat("", ["%s", cookies.target]) { input.auth.metadata.token.id_token; cookies.target } +location := "%s" { input.auth.metadata.token.id_token; not cookies.target } +location := "%s" { not input.auth.metadata.token.id_token } +allow = true`, baseURL, igwURL, authorizeURL) +} + func buildCallbackAuthPolicy(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) (*kuadrantv1.AuthPolicy, error) { igwURL := igw.GetURL() tokenRequestURL, err := pol.GetTokenRequestURL() @@ -443,6 +463,12 @@ func buildCallbackAuthPolicy(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) return nil, err } + // Get the base URL for post-auth redirects (respects custom redirectURI if set) + baseURL, err := pol.GetBaseURL(igwURL) + if err != nil { + return nil, err + } + callbackRoute := gatewayapiv1alpha2.LocalPolicyTargetReference{ Group: gatewayapiv1alpha2.GroupName, Kind: gatewayapiv1alpha2.Kind("HTTPRoute"), @@ -455,11 +481,7 @@ func buildCallbackAuthPolicy(pol *v1alpha1.OIDCPolicy, igw *ingressGatewayInfo) Expression: authorinov1beta3.CelExpression(callBodyCelExpression), } - opaAuthorizationRule := fmt.Sprintf(`cookies := { name: value | raw_cookies := input.request.headers.cookie; cookie_parts := split(raw_cookies, ";"); part := cookie_parts[_]; kv := split(trim(part, " "), "="); count(kv) == 2; name := trim(kv[0], " "); value := trim(kv[1], " ")} -location := concat("", ["%s", cookies.target]) { input.auth.metadata.token.id_token; cookies.target } -location := "%s" { input.auth.metadata.token.id_token; not cookies.target } -location := "%s" { not input.auth.metadata.token.id_token } -allow = true`, igwURL, igwURL, authorizeURL) + opaAuthorizationRule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) return &kuadrantv1.AuthPolicy{ TypeMeta: metav1.TypeMeta{ diff --git a/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler_test.go b/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler_test.go new file mode 100644 index 000000000..b4b463e4e --- /dev/null +++ b/cmd/extensions/oidc-policy/internal/controller/oidcpolicy_reconciler_test.go @@ -0,0 +1,537 @@ +//go:build unit + +package controller + +import ( + "fmt" + "net/url" + "strings" + "testing" + + gatewayapiv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +func TestBuildOpaAuthorizationRule(t *testing.T) { + baseURL, err := url.Parse("https://gateway.example.com:8443") + if err != nil { + t.Fatal(err) + } + igwURL, err := url.Parse("https://gateway.example.com:8443") + if err != nil { + t.Fatal(err) + } + authorizeURL := "https://issuer.com/authorize?client_id=test" + + rule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) + fmt.Println(rule) + + // Verify the rule contains the correct cookie parser that handles JWT tokens + if !strings.Contains(rule, "eq_idx := indexof(trimmed, \"=\")") { + t.Error("OPA rule should use indexof to find first = character") + } + if !strings.Contains(rule, "substring(trimmed, 0, eq_idx)") { + t.Error("OPA rule should use substring to extract cookie name") + } + if !strings.Contains(rule, "substring(trimmed, eq_idx + 1, -1)") { + t.Error("OPA rule should use substring to extract cookie value") + } + + // Verify the rule does NOT use the broken split/count pattern + if strings.Contains(rule, "count(kv) == 2") { + t.Error("OPA rule should not use count check that breaks with = characters in values") + } + + // Verify URLs are correctly embedded + if !strings.Contains(rule, igwURL.String()) { + t.Errorf("OPA rule should contain gateway URL: %s", igwURL.String()) + } + if !strings.Contains(rule, authorizeURL) { + t.Errorf("OPA rule should contain authorize URL: %s", authorizeURL) + } +} + +func TestBuildOpaAuthorizationRule_CookieParserPattern(t *testing.T) { + baseURL, _ := url.Parse("http://example.com") + igwURL, _ := url.Parse("http://example.com") + authorizeURL := "http://issuer.com/auth" + + rule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) + + // The cookie parser should handle JWT tokens with = padding + // Example JWT: eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.Signature== + // Cookie: jwt=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.Signature== + + // The pattern should: + // 1. Find the index of first = character + expectedPatterns := []string{ + "trimmed := trim(part, \" \")", // trim the cookie part + "eq_idx := indexof(trimmed, \"=\")", // find first = + "eq_idx != -1", // ensure = was found + "substring(trimmed, 0, eq_idx)", // extract name (before =) + "substring(trimmed, eq_idx + 1, -1)", // extract value (after =, including any additional =) + } + + for _, pattern := range expectedPatterns { + if !strings.Contains(rule, pattern) { + t.Errorf("OPA rule missing expected pattern: %s", pattern) + } + } + + // Verify the location logic is present + expectedLocationLogic := []string{ + "location := concat", + "cookies.target", + "input.auth.metadata.token.id_token", + "allow = true", + } + + for _, logic := range expectedLocationLogic { + if !strings.Contains(rule, logic) { + t.Errorf("OPA rule missing expected logic: %s", logic) + } + } +} + +func TestBuildOpaAuthorizationRule_JWTScenarios(t *testing.T) { + tests := []struct { + name string + description string + jwtExample string + }{ + { + name: "JWT with single = padding", + description: "JWT token ending with single = character", + jwtExample: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature=", + }, + { + name: "JWT with double = padding", + description: "JWT token ending with double == characters", + jwtExample: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIn0.sig==", + }, + { + name: "JWT with no padding", + description: "JWT token with no = padding", + jwtExample: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature", + }, + } + + baseURL, _ := url.Parse("http://example.com") + igwURL, _ := url.Parse("http://example.com") + authorizeURL := "http://issuer.com/auth" + rule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) + + // Document that the cookie parser pattern can handle all these scenarios + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // The rule uses substring(trimmed, eq_idx + 1, -1) which extracts everything after the first = + // This correctly handles JWTs with any number of = characters + if !strings.Contains(rule, "substring(trimmed, eq_idx + 1, -1)") { + t.Errorf("Cookie parser should handle %s: %s", tt.description, tt.jwtExample) + } + }) + } +} + +func TestBuildTargetCookieExpression(t *testing.T) { + tests := []struct { + name string + hostname string + protocol gatewayapiv1.ProtocolType + want []string + }{ + { + name: "HTTP protocol", + hostname: "example.com", + protocol: gatewayapiv1.HTTPProtocolType, + want: []string{ + `"target=" + request.path`, + `has(request.query) && request.query != ""`, + `"?" + request.query`, + `domain=example.com`, + `HttpOnly`, + `SameSite=Lax`, + `Path=/`, + `Max-Age=3600`, + }, + }, + { + name: "HTTPS protocol with Secure flag", + hostname: "secure.example.com", + protocol: gatewayapiv1.HTTPSProtocolType, + want: []string{ + `"target=" + request.path`, + `has(request.query) && request.query != ""`, + `"?" + request.query`, + `domain=secure.example.com`, + `HttpOnly`, + `Secure`, + `SameSite=Lax`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildTargetCookieExpression(tt.hostname, tt.protocol) + + for _, expected := range tt.want { + if !strings.Contains(result, expected) { + t.Errorf("buildTargetCookieExpression() missing expected pattern: %s\nGot: %s", expected, result) + } + } + }) + } +} + +func TestBuildTargetCookieExpression_QueryStringHandling(t *testing.T) { + hostname := "example.com" + protocol := gatewayapiv1.HTTPProtocolType + + expression := buildTargetCookieExpression(hostname, protocol) + + // Verify the expression includes query string handling + requiredPatterns := []string{ + // CEL ternary operator to conditionally add query string + `has(request.query) && request.query != ""`, + `"?" + request.query`, + // The pattern should be: path + (has(request.query) ? "?" + query : "") + `request.path + (has(request.query) && request.query != "" ? "?" + request.query : "")`, + } + + for _, pattern := range requiredPatterns { + if !strings.Contains(expression, pattern) { + t.Errorf("Expression missing query string handling pattern: %s", pattern) + } + } + + // Verify it does NOT use the broken pattern that only stores the path + if strings.Contains(expression, `"target=" + request.path + "; domain=`) { + t.Error("Expression should not directly concatenate path with cookie attributes (missing query string logic)") + } +} + +func TestBuildTargetCookieExpression_Examples(t *testing.T) { + expression := buildTargetCookieExpression("example.com", gatewayapiv1.HTTPSProtocolType) + + // Document the expected behavior with examples + examples := []struct { + scenario string + requestPath string + query string + expected string + }{ + { + scenario: "Path with query parameters", + requestPath: "/dashboard", + query: "elicitation_id=123&user=456", + expected: "/dashboard?elicitation_id=123&user=456", + }, + { + scenario: "Path without query parameters", + requestPath: "/home", + query: "", + expected: "/home", + }, + { + scenario: "Path with complex query string", + requestPath: "/api/v1/resource", + query: "filter=active&sort=desc&limit=50", + expected: "/api/v1/resource?filter=active&sort=desc&limit=50", + }, + } + + for _, ex := range examples { + t.Run(ex.scenario, func(t *testing.T) { + // The CEL expression uses a ternary: request.path + (has(request.query) ? "?" + request.query : "") + // This should construct the full path with query when query is present + if !strings.Contains(expression, `request.path + (has(request.query) ? "?" + request.query : "")`) { + t.Errorf("Expression should handle scenario: %s\nExpected to preserve: %s", ex.scenario, ex.expected) + } + }) + } +} + +func TestIngressGatewayInfo_GetURL(t *testing.T) { + tests := []struct { + name string + hostname string + protocol gatewayapiv1.ProtocolType + port int32 + expectedScheme string + expectedHost string + expectedFullURL string + }{ + { + name: "HTTP standard port 80", + hostname: "example.com", + protocol: gatewayapiv1.HTTPProtocolType, + port: 80, + expectedScheme: "http", + expectedHost: "example.com", + expectedFullURL: "http://example.com", + }, + { + name: "HTTPS standard port 443", + hostname: "secure.example.com", + protocol: gatewayapiv1.HTTPSProtocolType, + port: 443, + expectedScheme: "https", + expectedHost: "secure.example.com", + expectedFullURL: "https://secure.example.com", + }, + { + name: "HTTP non-standard port 8080", + hostname: "example.com", + protocol: gatewayapiv1.HTTPProtocolType, + port: 8080, + expectedScheme: "http", + expectedHost: "example.com:8080", + expectedFullURL: "http://example.com:8080", + }, + { + name: "HTTP non-standard port 8001", + hostname: "example.com", + protocol: gatewayapiv1.HTTPProtocolType, + port: 8001, + expectedScheme: "http", + expectedHost: "example.com:8001", + expectedFullURL: "http://example.com:8001", + }, + { + name: "HTTPS non-standard port 8443", + hostname: "secure.example.com", + protocol: gatewayapiv1.HTTPSProtocolType, + port: 8443, + expectedScheme: "https", + expectedHost: "secure.example.com:8443", + expectedFullURL: "https://secure.example.com:8443", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + igw := &ingressGatewayInfo{ + Hostname: tt.hostname, + Protocol: tt.protocol, + Port: tt.port, + Name: "test-gateway", + Namespace: "default", + } + + url := igw.GetURL() + + if url.Scheme != tt.expectedScheme { + t.Errorf("GetURL() scheme = %v, want %v", url.Scheme, tt.expectedScheme) + } + if url.Host != tt.expectedHost { + t.Errorf("GetURL() host = %v, want %v", url.Host, tt.expectedHost) + } + if url.String() != tt.expectedFullURL { + t.Errorf("GetURL() = %v, want %v", url.String(), tt.expectedFullURL) + } + }) + } +} + +func TestIngressGatewayInfo_GetURL_CachesResult(t *testing.T) { + igw := &ingressGatewayInfo{ + Hostname: "example.com", + Protocol: gatewayapiv1.HTTPProtocolType, + Port: 8080, + Name: "test-gateway", + Namespace: "default", + } + + // Call GetURL multiple times + url1 := igw.GetURL() + url2 := igw.GetURL() + + // Should return the same cached instance + if url1 != url2 { + t.Error("GetURL() should cache and return the same URL instance") + } + + // Verify the URL is correct + expectedURL := "http://example.com:8080" + if url1.String() != expectedURL { + t.Errorf("GetURL() = %v, want %v", url1.String(), expectedURL) + } +} + +func TestIngressGatewayInfo_GetURL_PortInCookieDomain(t *testing.T) { + // Test that demonstrates Bug 1 is fixed: port is preserved in URL construction + igw := &ingressGatewayInfo{ + Hostname: "example.com", + Protocol: gatewayapiv1.HTTPProtocolType, + Port: 8001, + Name: "test-gateway", + Namespace: "default", + } + + url := igw.GetURL() + + // The URL should include the port + if url.Host != "example.com:8001" { + t.Errorf("Expected Host to include port: got %v, want example.com:8001", url.Host) + } + + // When this URL is used for redirect URI construction, the port will be preserved + redirectURI := url.String() + "/auth/callback" + expectedRedirectURI := "http://example.com:8001/auth/callback" + if redirectURI != expectedRedirectURI { + t.Errorf("Redirect URI = %v, want %v", redirectURI, expectedRedirectURI) + } + + // Cookie domain uses igw.Hostname (without port), which is correct + cookieExpr := buildTargetCookieExpression(igw.Hostname, igw.Protocol) + if !strings.Contains(cookieExpr, "domain=example.com") { + t.Error("Cookie domain should use hostname without port") + } +} + +func TestBuildOpaAuthorizationRule_UsesCorrectBaseURL(t *testing.T) { + tests := []struct { + name string + baseURL string + igwURL string + authorizeURL string + expectedInRule []string + notExpectedInRule []string + description string + }{ + { + name: "No custom redirectURI - both URLs are the same", + baseURL: "http://gateway.example.com:8001", + igwURL: "http://gateway.example.com:8001", + authorizeURL: "https://issuer.com/authorize", + expectedInRule: []string{ + "http://gateway.example.com:8001", + "https://issuer.com/authorize", + }, + description: "When no custom redirectURI is set, baseURL and igwURL are identical", + }, + { + name: "Custom redirect URI - baseURL differs from igwURL", + baseURL: "https://public.example.com:8443", + igwURL: "http://gateway.internal:8080", + authorizeURL: "https://issuer.com/authorize", + expectedInRule: []string{ + // Location 1: with cookies.target uses baseURL + `concat("", ["https://public.example.com:8443", cookies.target])`, + // Location 2: without cookies.target uses igwURL + `location := "http://gateway.internal:8080"`, + // Location 3: no auth uses authorizeURL + `location := "https://issuer.com/authorize"`, + }, + description: "When custom redirectURI is set, location 1 uses baseURL, location 2 uses igwURL", + }, + { + name: "Custom redirect URI without port", + baseURL: "https://public.example.com", + igwURL: "http://gateway.example.com:8001", + authorizeURL: "https://issuer.com/authorize?client_id=test", + expectedInRule: []string{ + `concat("", ["https://public.example.com", cookies.target])`, + `location := "http://gateway.example.com:8001"`, + `location := "https://issuer.com/authorize?client_id=test"`, + }, + description: "Handles custom redirectURI with standard port, igwURL with non-standard port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + baseURL, err := url.Parse(tt.baseURL) + if err != nil { + t.Fatal(err) + } + igwURL, err := url.Parse(tt.igwURL) + if err != nil { + t.Fatal(err) + } + + rule := buildOpaAuthorizationRule(baseURL, igwURL, tt.authorizeURL) + + // Verify expected strings are in the rule + for _, expected := range tt.expectedInRule { + if !strings.Contains(rule, expected) { + t.Errorf("OPA rule missing expected pattern: %s\nDescription: %s\nRule: %s", + expected, tt.description, rule) + } + } + + // Verify unexpected strings are NOT in the rule + for _, notExpected := range tt.notExpectedInRule { + if strings.Contains(rule, notExpected) { + t.Errorf("OPA rule should not contain: %s\nDescription: %s\nRule: %s", + notExpected, tt.description, rule) + } + } + }) + } +} + +func TestBuildOpaAuthorizationRule_LocationRedirects(t *testing.T) { + baseURL, _ := url.Parse("https://public.example.com:8443") + igwURL, _ := url.Parse("http://gateway.internal:8080") + authorizeURL := "https://issuer.com/authorize?client_id=test" + + rule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) + + // The rule should have three location assignments: + // 1. Successful auth with target cookie: concat baseURL with cookies.target (uses custom redirectURI base) + // 2. Successful auth without target cookie: redirect to igwURL (uses gateway URL as default) + // 3. Failed auth: redirect to authorizeURL + + expectedPatterns := []string{ + // Pattern 1: successful auth with target - uses baseURL + `location := concat("", ["https://public.example.com:8443", cookies.target])`, + `input.auth.metadata.token.id_token`, + `cookies.target`, + + // Pattern 2: successful auth without target - uses igwURL + `location := "http://gateway.internal:8080"`, + `not cookies.target`, + + // Pattern 3: failed auth + `location := "https://issuer.com/authorize?client_id=test"`, + `not input.auth.metadata.token.id_token`, + + // Allow statement + `allow = true`, + } + + for _, pattern := range expectedPatterns { + if !strings.Contains(rule, pattern) { + t.Errorf("OPA rule missing expected pattern: %s", pattern) + } + } +} + +func TestBuildOpaAuthorizationRule_CustomRedirectURI_Scenario(t *testing.T) { + // Real-world scenario: LoadBalancer exposes gateway on public URL, + // but internal gateway uses different host/port + baseURL, _ := url.Parse("https://app.example.com") // Custom redirectURI base + igwURL, _ := url.Parse("http://gateway.internal:8080") // Internal gateway URL + authorizeURL := "https://auth.example.com/authorize" + + rule := buildOpaAuthorizationRule(baseURL, igwURL, authorizeURL) + + // Scenario 1: User tried to access /dashboard?tab=settings + // After auth, they should be redirected to: https://app.example.com/dashboard?tab=settings + if !strings.Contains(rule, `concat("", ["https://app.example.com", cookies.target])`) { + t.Error("Location 1 should use custom baseURL for user's intended destination") + } + + // Scenario 2: User accessed callback directly (no target cookie) + // They should be redirected to the internal gateway URL as default: http://gateway.internal:8080 + if !strings.Contains(rule, `location := "http://gateway.internal:8080" { input.auth.metadata.token.id_token; not cookies.target }`) { + t.Error("Location 2 should use igwURL as default when no target cookie exists") + } + + // Scenario 3: Auth failed (no token) + // Redirect to authorize URL for re-authentication + if !strings.Contains(rule, `location := "https://auth.example.com/authorize" { not input.auth.metadata.token.id_token }`) { + t.Error("Location 3 should redirect to authorize URL when auth fails") + } +} diff --git a/internal/extension/reconciler.go b/internal/extension/reconciler.go index ccb793072..77c9624d2 100644 --- a/internal/extension/reconciler.go +++ b/internal/extension/reconciler.go @@ -234,12 +234,14 @@ func toListeners(listeners []v1.Listener) []*extpb.Listener { ls := make([]*extpb.Listener, len(listeners)) for i, l := range listeners { listener := extpb.Listener{} + listener.Name = string(l.Name) if l.Hostname != nil { listener.Hostname = string(*l.Hostname) } if l.Protocol != "" { listener.Protocol = string(l.Protocol) } + listener.Port = int32(l.Port) ls[i] = &listener } return ls diff --git a/pkg/extension/grpc/v1/common.pb.go b/pkg/extension/grpc/v1/common.pb.go index fe6e2a697..0d5db4394 100644 --- a/pkg/extension/grpc/v1/common.pb.go +++ b/pkg/extension/grpc/v1/common.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc-gen-go v1.36.6 +// protoc v5.29.3 // source: v1/common.proto package v1 diff --git a/pkg/extension/grpc/v1/descriptor_service.pb.go b/pkg/extension/grpc/v1/descriptor_service.pb.go index c756d9aa7..0def85ca2 100644 --- a/pkg/extension/grpc/v1/descriptor_service.pb.go +++ b/pkg/extension/grpc/v1/descriptor_service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc-gen-go v1.36.6 +// protoc v5.29.3 // source: v1/descriptor_service.proto package v1 diff --git a/pkg/extension/grpc/v1/descriptor_service_grpc.pb.go b/pkg/extension/grpc/v1/descriptor_service_grpc.pb.go index eaf6e99a3..2f2e68d61 100644 --- a/pkg/extension/grpc/v1/descriptor_service_grpc.pb.go +++ b/pkg/extension/grpc/v1/descriptor_service_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v3.19.6 +// - protoc-gen-go-grpc v1.5.1 +// - protoc v5.29.3 // source: v1/descriptor_service.proto package v1 @@ -69,7 +69,7 @@ type DescriptorServiceServer interface { type UnimplementedDescriptorServiceServer struct{} func (UnimplementedDescriptorServiceServer) GetServiceDescriptors(context.Context, *GetServiceDescriptorsRequest) (*GetServiceDescriptorsResponse, error) { - return nil, status.Error(codes.Unimplemented, "method GetServiceDescriptors not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetServiceDescriptors not implemented") } func (UnimplementedDescriptorServiceServer) mustEmbedUnimplementedDescriptorServiceServer() {} func (UnimplementedDescriptorServiceServer) testEmbeddedByValue() {} @@ -82,7 +82,7 @@ type UnsafeDescriptorServiceServer interface { } func RegisterDescriptorServiceServer(s grpc.ServiceRegistrar, srv DescriptorServiceServer) { - // If the following call panics, it indicates UnimplementedDescriptorServiceServer was + // If the following call pancis, it indicates UnimplementedDescriptorServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/pkg/extension/grpc/v1/gateway_api.pb.go b/pkg/extension/grpc/v1/gateway_api.pb.go index 0d253b133..a59ba3f00 100644 --- a/pkg/extension/grpc/v1/gateway_api.pb.go +++ b/pkg/extension/grpc/v1/gateway_api.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc-gen-go v1.36.6 +// protoc v5.29.3 // source: v1/gateway_api.proto package v1 @@ -146,6 +146,7 @@ type Listener struct { Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` Hostname string `protobuf:"bytes,2,opt,name=hostname,proto3" json:"hostname,omitempty"` Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + Port int32 `protobuf:"varint,4,opt,name=port,proto3" json:"port,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -201,6 +202,13 @@ func (x *Listener) GetProtocol() string { return "" } +func (x *Listener) GetPort() int32 { + if x != nil { + return x.Port + } + return 0 +} + type GatewayAddresses struct { state protoimpl.MessageState `protogen:"open.v1"` AddressType string `protobuf:"bytes,1,opt,name=addressType,proto3" json:"addressType,omitempty"` @@ -533,11 +541,12 @@ const file_v1_gateway_api_proto_rawDesc = "" + "\vGatewaySpec\x12*\n" + "\x10gatewayClassName\x18\x01 \x01(\tR\x10gatewayClassName\x123\n" + "\tlisteners\x18\x02 \x03(\v2\x15.kuadrant.v1.ListenerR\tlisteners\x12;\n" + - "\taddresses\x18\x03 \x03(\v2\x1d.kuadrant.v1.GatewayAddressesR\taddresses\"V\n" + + "\taddresses\x18\x03 \x03(\v2\x1d.kuadrant.v1.GatewayAddressesR\taddresses\"j\n" + "\bListener\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x1a\n" + "\bhostname\x18\x02 \x01(\tR\bhostname\x12\x1a\n" + - "\bprotocol\x18\x03 \x01(\tR\bprotocol\"J\n" + + "\bprotocol\x18\x03 \x01(\tR\bprotocol\x12\x12\n" + + "\x04port\x18\x04 \x01(\x05R\x04port\"J\n" + "\x10GatewayAddresses\x12 \n" + "\vaddressType\x18\x01 \x01(\tR\vaddressType\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value\"\xbf\x01\n" + diff --git a/pkg/extension/grpc/v1/gateway_api.proto b/pkg/extension/grpc/v1/gateway_api.proto index e797183ad..7e2c79ac6 100644 --- a/pkg/extension/grpc/v1/gateway_api.proto +++ b/pkg/extension/grpc/v1/gateway_api.proto @@ -22,6 +22,7 @@ message Listener { string name = 1; string hostname = 2; string protocol = 3; + int32 port = 4; } message GatewayAddresses { diff --git a/pkg/extension/grpc/v1/kuadrant.pb.go b/pkg/extension/grpc/v1/kuadrant.pb.go index 6eefcfa15..028533b1e 100644 --- a/pkg/extension/grpc/v1/kuadrant.pb.go +++ b/pkg/extension/grpc/v1/kuadrant.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc-gen-go v1.36.6 +// protoc v5.29.3 // source: v1/kuadrant.proto package v1 diff --git a/pkg/extension/grpc/v1/kuadrant_grpc.pb.go b/pkg/extension/grpc/v1/kuadrant_grpc.pb.go index 36c410860..dbbdbc429 100644 --- a/pkg/extension/grpc/v1/kuadrant_grpc.pb.go +++ b/pkg/extension/grpc/v1/kuadrant_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.1 -// - protoc v3.19.6 +// - protoc-gen-go-grpc v1.5.1 +// - protoc v5.29.3 // source: v1/kuadrant.proto package v1 @@ -169,25 +169,25 @@ type ExtensionServiceServer interface { type UnimplementedExtensionServiceServer struct{} func (UnimplementedExtensionServiceServer) Ping(context.Context, *PingRequest) (*PongResponse, error) { - return nil, status.Error(codes.Unimplemented, "method Ping not implemented") + return nil, status.Errorf(codes.Unimplemented, "method Ping not implemented") } func (UnimplementedExtensionServiceServer) Subscribe(*SubscribeRequest, grpc.ServerStreamingServer[SubscribeResponse]) error { - return status.Error(codes.Unimplemented, "method Subscribe not implemented") + return status.Errorf(codes.Unimplemented, "method Subscribe not implemented") } func (UnimplementedExtensionServiceServer) Resolve(context.Context, *ResolveRequest) (*ResolveResponse, error) { - return nil, status.Error(codes.Unimplemented, "method Resolve not implemented") + return nil, status.Errorf(codes.Unimplemented, "method Resolve not implemented") } func (UnimplementedExtensionServiceServer) RegisterMutator(context.Context, *RegisterMutatorRequest) (*empty.Empty, error) { - return nil, status.Error(codes.Unimplemented, "method RegisterMutator not implemented") + return nil, status.Errorf(codes.Unimplemented, "method RegisterMutator not implemented") } func (UnimplementedExtensionServiceServer) ClearPolicy(context.Context, *ClearPolicyRequest) (*ClearPolicyResponse, error) { - return nil, status.Error(codes.Unimplemented, "method ClearPolicy not implemented") + return nil, status.Errorf(codes.Unimplemented, "method ClearPolicy not implemented") } func (UnimplementedExtensionServiceServer) RegisterActionMethod(context.Context, *RegisterActionMethodRequest) (*empty.Empty, error) { - return nil, status.Error(codes.Unimplemented, "method RegisterActionMethod not implemented") + return nil, status.Errorf(codes.Unimplemented, "method RegisterActionMethod not implemented") } func (UnimplementedExtensionServiceServer) PipelineCommit(context.Context, *PipelineCommitRequest) (*empty.Empty, error) { - return nil, status.Error(codes.Unimplemented, "method PipelineCommit not implemented") + return nil, status.Errorf(codes.Unimplemented, "method PipelineCommit not implemented") } func (UnimplementedExtensionServiceServer) mustEmbedUnimplementedExtensionServiceServer() {} func (UnimplementedExtensionServiceServer) testEmbeddedByValue() {} @@ -200,7 +200,7 @@ type UnsafeExtensionServiceServer interface { } func RegisterExtensionServiceServer(s grpc.ServiceRegistrar, srv ExtensionServiceServer) { - // If the following call panics, it indicates UnimplementedExtensionServiceServer was + // If the following call pancis, it indicates UnimplementedExtensionServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/pkg/extension/grpc/v1/policy.pb.go b/pkg/extension/grpc/v1/policy.pb.go index ed974de8f..39b118b80 100644 --- a/pkg/extension/grpc/v1/policy.pb.go +++ b/pkg/extension/grpc/v1/policy.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 -// protoc v3.19.6 +// protoc-gen-go v1.36.6 +// protoc v5.29.3 // source: v1/policy.proto package v1