diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index 65241b9..6ad1027 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -439,26 +439,29 @@ func main() { } s, err := libanubis.New(libanubis.Options{ - BasePrefix: *basePrefix, - StripBasePrefix: *stripBasePrefix, - Next: rp, - Policy: policy, - ServeRobotsTXT: *robotsTxt, - ED25519PrivateKey: ed25519Priv, - HS512Secret: []byte(*hs512Secret), - CookieDomain: *cookieDomain, - CookieDynamicDomain: *cookieDynamicDomain, - CookieExpiration: *cookieExpiration, - CookiePartitioned: *cookiePartitioned, - RedirectDomains: redirectDomainsList, - Target: *target, - WebmasterEmail: *webmasterEmail, - OpenGraph: policy.OpenGraph, - CookieSecure: *cookieSecure, - CookieSameSite: parseSameSite(*cookieSameSite), - PublicUrl: *publicUrl, - JWTRestrictionHeader: *jwtRestrictionHeader, - DifficultyInJWT: *difficultyInJWT, + BasePrefix: *basePrefix, + StripBasePrefix: *stripBasePrefix, + Next: rp, + Policy: policy, + TargetHost: *targetHost, + TargetSNI: *targetSNI, + TargetInsecureSkipVerify: *targetInsecureSkipVerify, + ServeRobotsTXT: *robotsTxt, + ED25519PrivateKey: ed25519Priv, + HS512Secret: []byte(*hs512Secret), + CookieDomain: *cookieDomain, + CookieDynamicDomain: *cookieDynamicDomain, + CookieExpiration: *cookieExpiration, + CookiePartitioned: *cookiePartitioned, + RedirectDomains: redirectDomainsList, + Target: *target, + WebmasterEmail: *webmasterEmail, + OpenGraph: policy.OpenGraph, + CookieSecure: *cookieSecure, + CookieSameSite: parseSameSite(*cookieSameSite), + PublicUrl: *publicUrl, + JWTRestrictionHeader: *jwtRestrictionHeader, + DifficultyInJWT: *difficultyInJWT, }) if err != nil { log.Fatalf("can't construct libanubis.Server: %v", err) diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 0c5858d..66e09ea 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Allow Renovate as an OCI registry client. - Properly handle 4in6 addresses so that IP matching works with those addresses. - Add support to simple Valkey/Redis cluster mode +- Open Graph passthrough now reuses the configured target Host/SNI/TLS settings, so metadata fetches succeed when the upstream certificate differs from the public domain. ([1283](https://github.com/TecharoHQ/anubis/pull/1283)) - Stabilize the CVE-2025-24369 regression test by always submitting an invalid proof instead of relying on random POW failures. ## v1.23.1: Lyse Hext - Echo 1 diff --git a/internal/ogtags/cache_test.go b/internal/ogtags/cache_test.go index 08bf4e3..89ba229 100644 --- a/internal/ogtags/cache_test.go +++ b/internal/ogtags/cache_test.go @@ -24,7 +24,7 @@ func TestCacheReturnsDefault(t *testing.T) { TimeToLive: time.Minute, ConsiderHost: false, Override: want, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) u, err := url.Parse("https://anubis.techaro.lol") if err != nil { @@ -52,7 +52,7 @@ func TestCheckCache(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Set up test data urlStr := "http://example.com/page" @@ -115,7 +115,7 @@ func TestGetOGTags(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Parse the test server URL parsedURL, err := url.Parse(ts.URL) @@ -271,7 +271,7 @@ func TestGetOGTagsWithHostConsideration(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: tc.ogCacheConsiderHost, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) for i, req := range tc.requests { ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host) diff --git a/internal/ogtags/fetch.go b/internal/ogtags/fetch.go index 26a0af2..0bfb0a1 100644 --- a/internal/ogtags/fetch.go +++ b/internal/ogtags/fetch.go @@ -27,16 +27,29 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri } // Set the Host header to the original host - if originalHost != "" { - req.Host = originalHost + var hostForRequest string + switch { + case c.targetHost != "": + hostForRequest = c.targetHost + case originalHost != "": + hostForRequest = originalHost + } + if hostForRequest != "" { + req.Host = hostForRequest } // Add proxy headers req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes + serverName := hostForRequest + if serverName == "" { + serverName = req.URL.Hostname() + } + client := c.clientForSNI(serverName) + // Send the request - resp, err := c.client.Do(req) + resp, err := client.Do(req) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { diff --git a/internal/ogtags/fetch_test.go b/internal/ogtags/fetch_test.go index c986272..864e8f2 100644 --- a/internal/ogtags/fetch_test.go +++ b/internal/ogtags/fetch_test.go @@ -87,7 +87,7 @@ func TestFetchHTMLDocument(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything") if tt.expectError { @@ -118,7 +118,7 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything") diff --git a/internal/ogtags/integration_test.go b/internal/ogtags/integration_test.go index 574172d..af56668 100644 --- a/internal/ogtags/integration_test.go +++ b/internal/ogtags/integration_test.go @@ -111,7 +111,7 @@ func TestIntegrationGetOGTags(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Create URL for test testURL, _ := url.Parse(ts.URL) diff --git a/internal/ogtags/mem_test.go b/internal/ogtags/mem_test.go index b415cda..7d2ac0c 100644 --- a/internal/ogtags/mem_test.go +++ b/internal/ogtags/mem_test.go @@ -31,7 +31,7 @@ func BenchmarkGetTarget(b *testing.B) { for _, tt := range tests { b.Run(tt.name, func(b *testing.B) { - cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context())) + cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()), TargetOptions{}) urls := make([]*url.URL, len(tt.paths)) for i, path := range tt.paths { u, _ := url.Parse(path) @@ -67,7 +67,7 @@ func BenchmarkExtractOGTags(b *testing.B) {

Content

`, } - cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context())) + cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context()), TargetOptions{}) docs := make([]*html.Node, len(htmlSamples)) for i, sample := range htmlSamples { @@ -85,7 +85,7 @@ func BenchmarkExtractOGTags(b *testing.B) { // Memory usage test func TestMemoryUsage(t *testing.T) { - cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context())) + cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context()), TargetOptions{}) // Force GC and wait for it to complete runtime.GC() diff --git a/internal/ogtags/ogtags.go b/internal/ogtags/ogtags.go index 37cf79d..f0c0adf 100644 --- a/internal/ogtags/ogtags.go +++ b/internal/ogtags/ogtags.go @@ -2,11 +2,13 @@ package ogtags import ( "context" + "crypto/tls" "log/slog" "net" "net/http" "net/url" "strings" + "sync" "time" "github.com/TecharoHQ/anubis/lib/policy/config" @@ -22,21 +24,34 @@ const ( ) type OGTagCache struct { + ogOverride map[string]string targetURL *url.URL client *http.Client - ogOverride map[string]string + transport *http.Transport cache store.JSON[map[string]string] // Pre-built strings for optimization unixPrefix string // "http://unix" - approvedTags []string + targetSNI string + targetHost string approvedPrefixes []string + approvedTags []string ogTimeToLive time.Duration - ogCacheConsiderHost bool ogPassthrough bool + ogCacheConsiderHost bool + targetSNIAuto bool + insecureSkipVerify bool + sniClients map[string]*http.Client + transportMu sync.RWMutex } -func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface) *OGTagCache { +type TargetOptions struct { + Host string + SNI string + InsecureSkipVerify bool +} + +func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface, targetOpts TargetOptions) *OGTagCache { // Predefined approved tags and prefixes defaultApprovedTags := []string{"description", "keywords", "author"} defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"} @@ -62,20 +77,37 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface } } - client := &http.Client{ - Timeout: httpTimeout, - } + transport := http.DefaultTransport.(*http.Transport).Clone() // Configure custom transport for Unix sockets if parsedTargetURL.Scheme == "unix" { socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path - client.Transport = &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) - }, + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) } } + targetSNIAuto := targetOpts.SNI == "auto" + + if targetOpts.SNI != "" && !targetSNIAuto { + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.ServerName = targetOpts.SNI + } + + if targetOpts.InsecureSkipVerify { + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.InsecureSkipVerify = true + } + + client := &http.Client{ + Timeout: httpTimeout, + Transport: transport, + } + return &OGTagCache{ cache: store.JSON[map[string]string]{ Underlying: backend, @@ -89,7 +121,13 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface approvedTags: defaultApprovedTags, approvedPrefixes: defaultApprovedPrefixes, client: client, + transport: transport, unixPrefix: "http://unix", + targetHost: targetOpts.Host, + targetSNI: targetOpts.SNI, + targetSNIAuto: targetSNIAuto, + insecureSkipVerify: targetOpts.InsecureSkipVerify, + sniClients: make(map[string]*http.Client), } } diff --git a/internal/ogtags/ogtags_fuzz_test.go b/internal/ogtags/ogtags_fuzz_test.go index 1656e58..499d9f5 100644 --- a/internal/ogtags/ogtags_fuzz_test.go +++ b/internal/ogtags/ogtags_fuzz_test.go @@ -48,7 +48,7 @@ func FuzzGetTarget(f *testing.F) { } // Create cache - should not panic - cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) + cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{}) // Create URL u := &url.URL{ @@ -132,7 +132,7 @@ func FuzzExtractOGTags(f *testing.F) { return } - cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) + cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{}) // Should not panic tags := cache.extractOGTags(doc) @@ -188,7 +188,7 @@ func FuzzGetTargetRoundTrip(f *testing.F) { t.Skip() } - cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) + cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{}) u := &url.URL{Path: path, RawQuery: query} result := cache.getTarget(u) @@ -245,7 +245,7 @@ func FuzzExtractMetaTagInfo(f *testing.F) { }, } - cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) + cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{}) // Should not panic property, content := cache.extractMetaTagInfo(node) @@ -298,7 +298,7 @@ func BenchmarkFuzzedGetTarget(b *testing.B) { for _, input := range inputs { b.Run(input.name, func(b *testing.B) { - cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background())) + cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{}) u := &url.URL{Path: input.path, RawQuery: input.query} b.ResetTimer() diff --git a/internal/ogtags/ogtags_test.go b/internal/ogtags/ogtags_test.go index c936b01..7441119 100644 --- a/internal/ogtags/ogtags_test.go +++ b/internal/ogtags/ogtags_test.go @@ -2,15 +2,23 @@ package ogtags import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "errors" "fmt" + "math/big" "net" "net/http" + "net/http/httptest" "net/url" "os" "path/filepath" "reflect" "strings" + "sync" "testing" "time" @@ -45,7 +53,7 @@ func TestNewOGTagCache(t *testing.T) { Enabled: tt.ogPassthrough, TimeToLive: tt.ogTimeToLive, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) if cache == nil { t.Fatal("expected non-nil cache, got nil") @@ -85,7 +93,7 @@ func TestNewOGTagCache_UnixSocket(t *testing.T) { Enabled: true, TimeToLive: 5 * time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) if cache == nil { t.Fatal("expected non-nil cache, got nil") @@ -170,7 +178,7 @@ func TestGetTarget(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) u := &url.URL{ Path: tt.path, @@ -243,7 +251,7 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Create a dummy URL for the request (path and query matter) testReqURL, _ := url.Parse("/some/page?query=1") @@ -274,3 +282,244 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) { t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags) } } + +func TestGetOGTagsWithTargetHostOverride(t *testing.T) { + originalHost := "example.test" + overrideHost := "backend.internal" + seenHosts := make(chan string, 10) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seenHosts <- r.Host + w.Header().Set("Content-Type", "text/html") + fmt.Fprintln(w, `ok`) + })) + defer ts.Close() + + targetURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + conf := config.OpenGraph{ + Enabled: true, + TimeToLive: time.Minute, + ConsiderHost: false, + } + + t.Run("default host uses original", func(t *testing.T) { + cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{}) + if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil { + t.Fatalf("GetOGTags failed: %v", err) + } + select { + case host := <-seenHosts: + if host != originalHost { + t.Fatalf("expected host %q, got %q", originalHost, host) + } + case <-time.After(time.Second): + t.Fatal("server did not receive request") + } + }) + + t.Run("override host respected", func(t *testing.T) { + cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{ + Host: overrideHost, + }) + if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil { + t.Fatalf("GetOGTags failed: %v", err) + } + select { + case host := <-seenHosts: + if host != overrideHost { + t.Fatalf("expected host %q, got %q", overrideHost, host) + } + case <-time.After(time.Second): + t.Fatal("server did not receive request") + } + }) +} + +func TestGetOGTagsWithInsecureSkipVerify(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprintln(w, `hello`) + }) + ts := httptest.NewTLSServer(handler) + defer ts.Close() + + parsedURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + conf := config.OpenGraph{ + Enabled: true, + TimeToLive: time.Minute, + ConsiderHost: false, + } + + // Without skip verify we should get a TLS error + cacheStrict := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{}) + if _, err := cacheStrict.GetOGTags(t.Context(), parsedURL, parsedURL.Host); err == nil { + t.Fatal("expected TLS verification error without InsecureSkipVerify") + } + + cacheSkip := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{ + InsecureSkipVerify: true, + }) + + tags, err := cacheSkip.GetOGTags(t.Context(), parsedURL, parsedURL.Host) + if err != nil { + t.Fatalf("expected successful fetch with InsecureSkipVerify, got: %v", err) + } + if tags["og:title"] != "Self-Signed" { + t.Fatalf("expected og:title to be %q, got %q", "Self-Signed", tags["og:title"]) + } +} + +func TestGetOGTagsWithTargetSNI(t *testing.T) { + originalHost := "hecate.test" + conf := config.OpenGraph{ + Enabled: true, + TimeToLive: time.Minute, + ConsiderHost: false, + } + + t.Run("explicit SNI override", func(t *testing.T) { + expectedSNI := "backend.internal" + ts, recorder := newSNIServer(t, `ok`) + defer ts.Close() + + targetURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + cacheExplicit := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{ + SNI: expectedSNI, + InsecureSkipVerify: true, + }) + if _, err := cacheExplicit.GetOGTags(t.Context(), targetURL, originalHost); err != nil { + t.Fatalf("expected successful fetch with explicit SNI, got: %v", err) + } + if got := recorder.last(); got != expectedSNI { + t.Fatalf("expected server to see SNI %q, got %q", expectedSNI, got) + } + }) + + t.Run("auto SNI uses original host", func(t *testing.T) { + ts, recorder := newSNIServer(t, `ok`) + defer ts.Close() + + targetURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + cacheAuto := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{ + SNI: "auto", + InsecureSkipVerify: true, + }) + if _, err := cacheAuto.GetOGTags(t.Context(), targetURL, originalHost); err != nil { + t.Fatalf("expected successful fetch with auto SNI, got: %v", err) + } + if got := recorder.last(); got != originalHost { + t.Fatalf("expected server to see SNI %q with auto, got %q", originalHost, got) + } + }) + + t.Run("default SNI uses backend host", func(t *testing.T) { + ts, recorder := newSNIServer(t, `ok`) + defer ts.Close() + + targetURL, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + cacheDefault := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{ + InsecureSkipVerify: true, + }) + if _, err := cacheDefault.GetOGTags(t.Context(), targetURL, originalHost); err != nil { + t.Fatalf("expected successful fetch without explicit SNI, got: %v", err) + } + wantSNI := "" + if net.ParseIP(targetURL.Hostname()) == nil { + wantSNI = targetURL.Hostname() + } + if got := recorder.last(); got != wantSNI { + t.Fatalf("expected default SNI %q, got %q", wantSNI, got) + } + }) +} + +func newSNIServer(t *testing.T, body string) (*httptest.Server, *sniRecorder) { + t.Helper() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, body) + }) + + recorder := &sniRecorder{} + ts := httptest.NewUnstartedServer(handler) + cert := mustCertificateForHost(t, "sni.test") + ts.TLS = &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + recorder.record(hello.ServerName) + return &cert, nil + }, + } + ts.StartTLS() + return ts, recorder +} + +func mustCertificateForHost(t *testing.T, host string) tls.Certificate { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: host, + }, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + BasicConstraintsValid: true, + DNSNames: []string{host}, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("failed to create certificate: %v", err) + } + + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: priv, + } +} + +type sniRecorder struct { + mu sync.Mutex + names []string +} + +func (r *sniRecorder) record(name string) { + r.mu.Lock() + defer r.mu.Unlock() + r.names = append(r.names, name) +} + +func (r *sniRecorder) last() string { + r.mu.Lock() + defer r.mu.Unlock() + if len(r.names) == 0 { + return "" + } + return r.names[len(r.names)-1] +} diff --git a/internal/ogtags/parse_test.go b/internal/ogtags/parse_test.go index d479d11..55e536a 100644 --- a/internal/ogtags/parse_test.go +++ b/internal/ogtags/parse_test.go @@ -18,7 +18,7 @@ func TestExtractOGTags(t *testing.T) { Enabled: false, ConsiderHost: false, TimeToLive: time.Minute, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Manually set approved tags/prefixes based on the user request for clarity testCache.approvedTags = []string{"description"} testCache.approvedPrefixes = []string{"og:"} @@ -199,7 +199,7 @@ func TestExtractMetaTagInfo(t *testing.T) { Enabled: false, ConsiderHost: false, TimeToLive: time.Minute, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) testCache.approvedTags = []string{"description"} testCache.approvedPrefixes = []string{"og:"} diff --git a/internal/ogtags/sni.go b/internal/ogtags/sni.go new file mode 100644 index 0000000..46cfe03 --- /dev/null +++ b/internal/ogtags/sni.go @@ -0,0 +1,42 @@ +package ogtags + +import ( + "crypto/tls" + "net/http" +) + +// clientForSNI returns a cached client for the given server name, creating one if needed. +func (c *OGTagCache) clientForSNI(serverName string) *http.Client { + if !c.targetSNIAuto || serverName == "" { + return c.client + } + + c.transportMu.RLock() + cli, ok := c.sniClients[serverName] + c.transportMu.RUnlock() + if ok { + return cli + } + + c.transportMu.Lock() + defer c.transportMu.Unlock() + if cli, ok := c.sniClients[serverName]; ok { + return cli + } + + tr := c.transport.Clone() + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + tr.TLSClientConfig.ServerName = serverName + if c.insecureSkipVerify { + tr.TLSClientConfig.InsecureSkipVerify = true + } + + cli = &http.Client{ + Timeout: httpTimeout, + Transport: tr, + } + c.sniClients[serverName] = cli + return cli +} diff --git a/lib/config.go b/lib/config.go index 1c5bc7c..1c4d7a0 100644 --- a/lib/config.go +++ b/lib/config.go @@ -27,27 +27,30 @@ import ( ) type Options struct { - Next http.Handler - Policy *policy.ParsedConfig - Logger *slog.Logger - OpenGraph config.OpenGraph - PublicUrl string - CookieDomain string - JWTRestrictionHeader string - BasePrefix string - WebmasterEmail string - Target string - RedirectDomains []string - ED25519PrivateKey ed25519.PrivateKey - HS512Secret []byte - CookieExpiration time.Duration - CookieSameSite http.SameSite - ServeRobotsTXT bool - CookieSecure bool - StripBasePrefix bool - CookiePartitioned bool - CookieDynamicDomain bool - DifficultyInJWT bool + Next http.Handler + Policy *policy.ParsedConfig + Target string + TargetHost string + TargetSNI string + TargetInsecureSkipVerify bool + CookieDynamicDomain bool + CookieDomain string + CookieExpiration time.Duration + CookiePartitioned bool + BasePrefix string + WebmasterEmail string + RedirectDomains []string + ED25519PrivateKey ed25519.PrivateKey + HS512Secret []byte + StripBasePrefix bool + OpenGraph config.OpenGraph + ServeRobotsTXT bool + CookieSecure bool + CookieSameSite http.SameSite + Logger *slog.Logger + PublicUrl string + JWTRestrictionHeader string + DifficultyInJWT bool } func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { @@ -116,9 +119,13 @@ func New(opts Options) (*Server, error) { hs512Secret: opts.HS512Secret, policy: opts.Policy, opts: opts, - OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store), - store: opts.Policy.Store, - logger: opts.Logger, + OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{ + Host: opts.TargetHost, + SNI: opts.TargetSNI, + InsecureSkipVerify: opts.TargetInsecureSkipVerify, + }), + store: opts.Policy.Store, + logger: opts.Logger, } mux := http.NewServeMux() diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index e2c62a2..577470a 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -62,11 +62,14 @@ type BotConfig struct { Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` - GeoIP *GeoIP `json:"geoip,omitempty"` - ASNs *ASNs `json:"asns,omitempty"` - Name string `json:"name" yaml:"name"` - Action Rule `json:"action" yaml:"action"` - RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` + + // Thoth features + GeoIP *GeoIP `json:"geoip,omitempty"` + ASNs *ASNs `json:"asns,omitempty"` + + Name string `json:"name" yaml:"name"` + Action Rule `json:"action" yaml:"action"` + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` } func (b BotConfig) Zero() bool {