diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index 8e7cbad..af8c574 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -88,6 +88,7 @@ ^docs/manifest/.*$ ^docs/static/\.nojekyll$ ^lib/policy/config/testdata/bad/unparseable\.json$ +^internal/glob/glob_test.go$ ignore$ robots.txt ^lib/localization/locales/.*\.json$ diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 15ff184..b32a12e 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -43,6 +43,7 @@ A new ["proof of React"](./admin/configuration/challenges/preact.mdx) has been a - Add a default block rule for Alibaba Cloud. - Added support to use Traefik forwardAuth middleware. - Add X-Request-URI support so that Subrequest Authentication has path support. +- Added glob matching for `REDIRECT_DOMAINS`. You can pass `*.bugs.techaro.lol` to allow redirecting to anything ending with `.bugs.techaro.lol`. There is a limit of 4 wildcards. ### Fixes diff --git a/internal/glob/glob.go b/internal/glob/glob.go new file mode 100644 index 0000000..44c1a67 --- /dev/null +++ b/internal/glob/glob.go @@ -0,0 +1,61 @@ +package glob + +import "strings" + +const GLOB = "*" + +const maxGlobParts = 5 + +// Glob will test a string pattern, potentially containing globs, against a +// subject string. The result is a simple true/false, determining whether or +// not the glob pattern matched the subject text. +func Glob(pattern, subj string) bool { + // Empty pattern can only match empty subject + if pattern == "" { + return subj == pattern + } + + // If the pattern _is_ a glob, it matches everything + if pattern == GLOB { + return true + } + + parts := strings.Split(pattern, GLOB) + + if len(parts) > maxGlobParts { + return false // Pattern is too complex, reject it. + } + + if len(parts) == 1 { + // No globs in pattern, so test for equality + return subj == pattern + } + + leadingGlob := strings.HasPrefix(pattern, GLOB) + trailingGlob := strings.HasSuffix(pattern, GLOB) + end := len(parts) - 1 + + // Go over the leading parts and ensure they match. + for i := 0; i < end; i++ { + idx := strings.Index(subj, parts[i]) + + switch i { + case 0: + // Check the first section. Requires special handling. + if !leadingGlob && idx != 0 { + return false + } + default: + // Check that the middle parts match. + if idx < 0 { + return false + } + } + + // Trim evaluated text from subj as we loop over the pattern. + subj = subj[idx+len(parts[i]):] + } + + // Reached the last section. Requires special handling. + return trailingGlob || strings.HasSuffix(subj, parts[end]) +} diff --git a/internal/glob/glob_test.go b/internal/glob/glob_test.go new file mode 100644 index 0000000..257752c --- /dev/null +++ b/internal/glob/glob_test.go @@ -0,0 +1,189 @@ +package glob + +import "testing" + +func TestGlob_EqualityAndEmpty(t *testing.T) { + cases := []struct { + name string + pattern string + subj string + want bool + }{ + {"exact match", "hello", "hello", true}, + {"exact mismatch", "hello", "hell", false}, + {"empty pattern and subject", "", "", true}, + {"empty pattern with non-empty subject", "", "x", false}, + {"pattern star matches empty", "*", "", true}, + {"pattern star matches anything", "*", "anything at all", true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + }) + } +} + +func TestGlob_LeadingAndTrailing(t *testing.T) { + cases := []struct { + name string + pattern string + subj string + want bool + }{ + {"prefix match - minimal", "foo*", "foo", true}, + {"prefix match - extended", "foo*", "foobar", true}, + {"prefix mismatch - not at start", "foo*", "xfoo", false}, + {"suffix match - minimal", "*foo", "foo", true}, + {"suffix match - extended", "*foo", "xfoo", true}, + {"suffix mismatch - not at end", "*foo", "foox", false}, + {"contains match", "*foo*", "barfoobaz", true}, + {"contains mismatch - missing needle", "*foo*", "f", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + }) + } +} + +func TestGlob_MiddleAndOrder(t *testing.T) { + cases := []struct { + name string + pattern string + subj string + want bool + }{ + {"middle wildcard basic", "f*o", "fo", true}, + {"middle wildcard gap", "f*o", "fZZZo", true}, + {"middle wildcard requires start f", "f*o", "xfyo", false}, + {"order enforced across parts", "a*b*c*d", "axxbxxcxxd", true}, + {"order mismatch fails", "a*b*c*d", "abdc", false}, + {"must end with last part when no trailing *", "*foo*bar", "zzfooqqbar", true}, + {"failing when trailing chars remain", "*foo*bar", "zzfooqqbarzz", false}, + {"first part must start when no leading *", "foo*bar", "zzfooqqbar", false}, + {"works with overlapping content", "ab*ba", "ababa", true}, + {"needle not found fails", "foo*bar", "foobaz", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + }) + } +} + +func TestGlob_ConsecutiveStarsAndEmptyParts(t *testing.T) { + cases := []struct { + name string + pattern string + subj string + want bool + }{ + {"double star matches anything", "**", "", true}, + {"double star matches anything non-empty", "**", "abc", true}, + {"consecutive stars behave like single", "a**b", "ab", true}, + {"consecutive stars with gaps", "a**b", "axxxb", true}, + {"consecutive stars + trailing star", "a**b*", "axxbzzz", true}, + {"consecutive stars still enforce anchors", "a**b", "xaBy", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + }) + } +} + +func TestGlob_MaxPartsLimit(t *testing.T) { + // Allowed: up to 4 '*' (5 parts) + allowed := []struct { + pattern string + subj string + want bool + }{ + {"a*b*c*d*e", "axxbxxcxxdxxe", true}, // 4 stars -> 5 parts + {"*a*b*c*d", "zzzaaaabbbcccddd", true}, + {"a*b*c*d*e", "abcde", true}, + {"a*b*c*d*e", "abxdxe", false}, // missing 'c' should fail + } + for _, tc := range allowed { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("allowed pattern Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + } + + // Disallowed: 5 '*' (6 parts) -> always false by complexity check + disallowed := []struct { + pattern string + subj string + }{ + {"a*b*c*d*e*f", "aXXbYYcZZdQQeRRf"}, + {"*a*b*c*d*e*", "abcdef"}, + {"******", "anything"}, // 6 stars -> 7 parts + } + for _, tc := range disallowed { + if got := Glob(tc.pattern, tc.subj); got { + t.Fatalf("disallowed pattern should fail Glob(%q,%q) = %v, want false", tc.pattern, tc.subj, got) + } + } +} + +func TestGlob_CaseSensitivity(t *testing.T) { + cases := []struct { + pattern string + subj string + want bool + }{ + {"FOO*", "foo", false}, + {"*Bar", "bar", false}, + {"Foo*Bar", "FooZZZBar", true}, + } + for _, tc := range cases { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + } +} + +func TestGlob_EmptySubjectInteractions(t *testing.T) { + cases := []struct { + pattern string + subj string + want bool + }{ + {"*a", "", false}, + {"a*", "", false}, + {"**", "", true}, + {"*", "", true}, + } + for _, tc := range cases { + if got := Glob(tc.pattern, tc.subj); got != tc.want { + t.Fatalf("Glob(%q,%q) = %v, want %v", tc.pattern, tc.subj, got, tc.want) + } + } +} + +func BenchmarkGlob(b *testing.B) { + patterns := []string{ + "*", "*foo*", "foo*bar", "a*b*c*d*e", "a**b*", "*needle*end", + } + subjects := []string{ + "", "foo", "barfoo", "foobarbaz", "axxbxxcxxdxxe", "zzfooqqbarzz", + "lorem ipsum dolor sit amet, consectetur adipiscing elit", + } + for _, p := range patterns { + for _, s := range subjects { + b.Run(p+"::"+s, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Glob(p, s) + } + }) + } + } +} diff --git a/lib/anubis.go b/lib/anubis.go index 1e2ccca..7b8d4f1 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -11,7 +11,6 @@ import ( "net" "net/http" "net/url" - "slices" "strings" "time" @@ -435,7 +434,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) { s.respondWithError(w, r, localizer.T("redirect_not_parseable")) return } - if (len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !slices.Contains(s.opts.RedirectDomains, urlParsed.Host)) || urlParsed.Host != r.URL.Host { + if (len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !matchRedirectDomain(s.opts.RedirectDomains, urlParsed.Host)) || urlParsed.Host != r.URL.Host { lg.Debug("domain not allowed", "domain", urlParsed.Host) s.respondWithError(w, r, localizer.T("redirect_domain_not_allowed")) return diff --git a/lib/http.go b/lib/http.go index b1a449d..6110707 100644 --- a/lib/http.go +++ b/lib/http.go @@ -7,12 +7,12 @@ import ( "net/http" "net/url" "regexp" - "slices" "strings" "time" "github.com/TecharoHQ/anubis" "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/internal/glob" "github.com/TecharoHQ/anubis/lib/challenge" "github.com/TecharoHQ/anubis/lib/localization" "github.com/TecharoHQ/anubis/lib/policy" @@ -24,6 +24,26 @@ import ( var domainMatchRegexp = regexp.MustCompile(`^((xn--)?[a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) +// matchRedirectDomain returns true if host matches any of the allowed redirect +// domain patterns. Patterns may contain '*' which are matched using the +// internal glob matcher. Matching is case-insensitive on hostnames. +func matchRedirectDomain(allowed []string, host string) bool { + h := strings.ToLower(strings.TrimSpace(host)) + for _, pat := range allowed { + p := strings.ToLower(strings.TrimSpace(pat)) + if strings.Contains(p, glob.GLOB) { + if glob.Glob(p, h) { + return true + } + continue + } + if p == h { + return true + } + } + return false +} + type CookieOpts struct { Value string Host string @@ -217,8 +237,8 @@ func (s *Server) constructRedirectURL(r *http.Request) (string, error) { if proto == "" || host == "" || uri == "" { return "", errors.New(localizer.T("missing_required_forwarded_headers")) } - // Check if host is allowed in RedirectDomains - if len(s.opts.RedirectDomains) > 0 && !slices.Contains(s.opts.RedirectDomains, host) { + // Check if host is allowed in RedirectDomains (supports '*' via glob) + if len(s.opts.RedirectDomains) > 0 && !matchRedirectDomain(s.opts.RedirectDomains, host) { lg := internal.GetRequestLogger(s.logger, r) lg.Debug("domain not allowed", "domain", host) return "", errors.New(localizer.T("redirect_domain_not_allowed")) @@ -290,7 +310,7 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) { hostNotAllowed := len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && - !slices.Contains(s.opts.RedirectDomains, urlParsed.Host) + !matchRedirectDomain(s.opts.RedirectDomains, urlParsed.Host) hostMismatch := r.URL.Host != "" && urlParsed.Host != r.URL.Host if hostNotAllowed || hostMismatch {