fix(ogtags): respect target host/SNI/insecure flags in OG passthrough (#1283)

This commit is contained in:
Jason Cameron 2025-11-16 21:32:03 -05:00 committed by GitHub
parent c70b939651
commit 1d91bc99f2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 440 additions and 84 deletions

View file

@ -439,26 +439,29 @@ func main() {
} }
s, err := libanubis.New(libanubis.Options{ s, err := libanubis.New(libanubis.Options{
BasePrefix: *basePrefix, BasePrefix: *basePrefix,
StripBasePrefix: *stripBasePrefix, StripBasePrefix: *stripBasePrefix,
Next: rp, Next: rp,
Policy: policy, Policy: policy,
ServeRobotsTXT: *robotsTxt, TargetHost: *targetHost,
ED25519PrivateKey: ed25519Priv, TargetSNI: *targetSNI,
HS512Secret: []byte(*hs512Secret), TargetInsecureSkipVerify: *targetInsecureSkipVerify,
CookieDomain: *cookieDomain, ServeRobotsTXT: *robotsTxt,
CookieDynamicDomain: *cookieDynamicDomain, ED25519PrivateKey: ed25519Priv,
CookieExpiration: *cookieExpiration, HS512Secret: []byte(*hs512Secret),
CookiePartitioned: *cookiePartitioned, CookieDomain: *cookieDomain,
RedirectDomains: redirectDomainsList, CookieDynamicDomain: *cookieDynamicDomain,
Target: *target, CookieExpiration: *cookieExpiration,
WebmasterEmail: *webmasterEmail, CookiePartitioned: *cookiePartitioned,
OpenGraph: policy.OpenGraph, RedirectDomains: redirectDomainsList,
CookieSecure: *cookieSecure, Target: *target,
CookieSameSite: parseSameSite(*cookieSameSite), WebmasterEmail: *webmasterEmail,
PublicUrl: *publicUrl, OpenGraph: policy.OpenGraph,
JWTRestrictionHeader: *jwtRestrictionHeader, CookieSecure: *cookieSecure,
DifficultyInJWT: *difficultyInJWT, CookieSameSite: parseSameSite(*cookieSameSite),
PublicUrl: *publicUrl,
JWTRestrictionHeader: *jwtRestrictionHeader,
DifficultyInJWT: *difficultyInJWT,
}) })
if err != nil { if err != nil {
log.Fatalf("can't construct libanubis.Server: %v", err) log.Fatalf("can't construct libanubis.Server: %v", err)

View file

@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Allow Renovate as an OCI registry client. - Allow Renovate as an OCI registry client.
- Properly handle 4in6 addresses so that IP matching works with those addresses. - Properly handle 4in6 addresses so that IP matching works with those addresses.
- Add support to simple Valkey/Redis cluster mode - Add support to simple Valkey/Redis cluster mode
- Open Graph passthrough now reuses the configured target Host/SNI/TLS settings, so metadata fetches succeed when the upstream certificate differs from the public domain. ([1283](https://github.com/TecharoHQ/anubis/pull/1283))
- Stabilize the CVE-2025-24369 regression test by always submitting an invalid proof instead of relying on random POW failures. - Stabilize the CVE-2025-24369 regression test by always submitting an invalid proof instead of relying on random POW failures.
## v1.23.1: Lyse Hext - Echo 1 ## v1.23.1: Lyse Hext - Echo 1

View file

@ -24,7 +24,7 @@ func TestCacheReturnsDefault(t *testing.T) {
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
Override: want, Override: want,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
u, err := url.Parse("https://anubis.techaro.lol") u, err := url.Parse("https://anubis.techaro.lol")
if err != nil { if err != nil {
@ -52,7 +52,7 @@ func TestCheckCache(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Set up test data // Set up test data
urlStr := "http://example.com/page" urlStr := "http://example.com/page"
@ -115,7 +115,7 @@ func TestGetOGTags(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Parse the test server URL // Parse the test server URL
parsedURL, err := url.Parse(ts.URL) parsedURL, err := url.Parse(ts.URL)
@ -271,7 +271,7 @@ func TestGetOGTagsWithHostConsideration(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: tc.ogCacheConsiderHost, ConsiderHost: tc.ogCacheConsiderHost,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
for i, req := range tc.requests { for i, req := range tc.requests {
ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host) ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host)

View file

@ -27,16 +27,29 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri
} }
// Set the Host header to the original host // Set the Host header to the original host
if originalHost != "" { var hostForRequest string
req.Host = originalHost switch {
case c.targetHost != "":
hostForRequest = c.targetHost
case originalHost != "":
hostForRequest = originalHost
}
if hostForRequest != "" {
req.Host = hostForRequest
} }
// Add proxy headers // Add proxy headers
req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
serverName := hostForRequest
if serverName == "" {
serverName = req.URL.Hostname()
}
client := c.clientForSNI(serverName)
// Send the request // Send the request
resp, err := c.client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
var netErr net.Error var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() { if errors.As(err, &netErr) && netErr.Timeout() {

View file

@ -87,7 +87,7 @@ func TestFetchHTMLDocument(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything") doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything")
if tt.expectError { if tt.expectError {
@ -118,7 +118,7 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything") doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything")

View file

@ -111,7 +111,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Create URL for test // Create URL for test
testURL, _ := url.Parse(ts.URL) testURL, _ := url.Parse(ts.URL)

View file

@ -31,7 +31,7 @@ func BenchmarkGetTarget(b *testing.B) {
for _, tt := range tests { for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) { b.Run(tt.name, func(b *testing.B) {
cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context())) cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
urls := make([]*url.URL, len(tt.paths)) urls := make([]*url.URL, len(tt.paths))
for i, path := range tt.paths { for i, path := range tt.paths {
u, _ := url.Parse(path) u, _ := url.Parse(path)
@ -67,7 +67,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
</head><body><div><p>Content</p></div></body></html>`, </head><body><div><p>Content</p></div></body></html>`,
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
docs := make([]*html.Node, len(htmlSamples)) docs := make([]*html.Node, len(htmlSamples))
for i, sample := range htmlSamples { for i, sample := range htmlSamples {
@ -85,7 +85,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
// Memory usage test // Memory usage test
func TestMemoryUsage(t *testing.T) { func TestMemoryUsage(t *testing.T) {
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context()), TargetOptions{})
// Force GC and wait for it to complete // Force GC and wait for it to complete
runtime.GC() runtime.GC()

View file

@ -2,11 +2,13 @@ package ogtags
import ( import (
"context" "context"
"crypto/tls"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
@ -22,21 +24,34 @@ const (
) )
type OGTagCache struct { type OGTagCache struct {
ogOverride map[string]string
targetURL *url.URL targetURL *url.URL
client *http.Client client *http.Client
ogOverride map[string]string transport *http.Transport
cache store.JSON[map[string]string] cache store.JSON[map[string]string]
// Pre-built strings for optimization // Pre-built strings for optimization
unixPrefix string // "http://unix" unixPrefix string // "http://unix"
approvedTags []string targetSNI string
targetHost string
approvedPrefixes []string approvedPrefixes []string
approvedTags []string
ogTimeToLive time.Duration ogTimeToLive time.Duration
ogCacheConsiderHost bool
ogPassthrough bool ogPassthrough bool
ogCacheConsiderHost bool
targetSNIAuto bool
insecureSkipVerify bool
sniClients map[string]*http.Client
transportMu sync.RWMutex
} }
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface) *OGTagCache { type TargetOptions struct {
Host string
SNI string
InsecureSkipVerify bool
}
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface, targetOpts TargetOptions) *OGTagCache {
// Predefined approved tags and prefixes // Predefined approved tags and prefixes
defaultApprovedTags := []string{"description", "keywords", "author"} defaultApprovedTags := []string{"description", "keywords", "author"}
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"} defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
@ -62,20 +77,37 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
} }
} }
client := &http.Client{ transport := http.DefaultTransport.(*http.Transport).Clone()
Timeout: httpTimeout,
}
// Configure custom transport for Unix sockets // Configure custom transport for Unix sockets
if parsedTargetURL.Scheme == "unix" { if parsedTargetURL.Scheme == "unix" {
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
client.Transport = &http.Transport{ transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return net.Dial("unix", socketPath)
return net.Dial("unix", socketPath)
},
} }
} }
targetSNIAuto := targetOpts.SNI == "auto"
if targetOpts.SNI != "" && !targetSNIAuto {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.ServerName = targetOpts.SNI
}
if targetOpts.InsecureSkipVerify {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.InsecureSkipVerify = true
}
client := &http.Client{
Timeout: httpTimeout,
Transport: transport,
}
return &OGTagCache{ return &OGTagCache{
cache: store.JSON[map[string]string]{ cache: store.JSON[map[string]string]{
Underlying: backend, Underlying: backend,
@ -89,7 +121,13 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
approvedTags: defaultApprovedTags, approvedTags: defaultApprovedTags,
approvedPrefixes: defaultApprovedPrefixes, approvedPrefixes: defaultApprovedPrefixes,
client: client, client: client,
transport: transport,
unixPrefix: "http://unix", unixPrefix: "http://unix",
targetHost: targetOpts.Host,
targetSNI: targetOpts.SNI,
targetSNIAuto: targetSNIAuto,
insecureSkipVerify: targetOpts.InsecureSkipVerify,
sniClients: make(map[string]*http.Client),
} }
} }

View file

@ -48,7 +48,7 @@ func FuzzGetTarget(f *testing.F) {
} }
// Create cache - should not panic // Create cache - should not panic
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Create URL // Create URL
u := &url.URL{ u := &url.URL{
@ -132,7 +132,7 @@ func FuzzExtractOGTags(f *testing.F) {
return return
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Should not panic // Should not panic
tags := cache.extractOGTags(doc) tags := cache.extractOGTags(doc)
@ -188,7 +188,7 @@ func FuzzGetTargetRoundTrip(f *testing.F) {
t.Skip() t.Skip()
} }
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
u := &url.URL{Path: path, RawQuery: query} u := &url.URL{Path: path, RawQuery: query}
result := cache.getTarget(u) result := cache.getTarget(u)
@ -245,7 +245,7 @@ func FuzzExtractMetaTagInfo(f *testing.F) {
}, },
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Should not panic // Should not panic
property, content := cache.extractMetaTagInfo(node) property, content := cache.extractMetaTagInfo(node)
@ -298,7 +298,7 @@ func BenchmarkFuzzedGetTarget(b *testing.B) {
for _, input := range inputs { for _, input := range inputs {
b.Run(input.name, func(b *testing.B) { b.Run(input.name, func(b *testing.B) {
cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
u := &url.URL{Path: input.path, RawQuery: input.query} u := &url.URL{Path: input.path, RawQuery: input.query}
b.ResetTimer() b.ResetTimer()

View file

@ -2,15 +2,23 @@ package ogtags
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -45,7 +53,7 @@ func TestNewOGTagCache(t *testing.T) {
Enabled: tt.ogPassthrough, Enabled: tt.ogPassthrough,
TimeToLive: tt.ogTimeToLive, TimeToLive: tt.ogTimeToLive,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
if cache == nil { if cache == nil {
t.Fatal("expected non-nil cache, got nil") t.Fatal("expected non-nil cache, got nil")
@ -85,7 +93,7 @@ func TestNewOGTagCache_UnixSocket(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: 5 * time.Minute, TimeToLive: 5 * time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
if cache == nil { if cache == nil {
t.Fatal("expected non-nil cache, got nil") t.Fatal("expected non-nil cache, got nil")
@ -170,7 +178,7 @@ func TestGetTarget(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
u := &url.URL{ u := &url.URL{
Path: tt.path, Path: tt.path,
@ -243,7 +251,7 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Create a dummy URL for the request (path and query matter) // Create a dummy URL for the request (path and query matter)
testReqURL, _ := url.Parse("/some/page?query=1") testReqURL, _ := url.Parse("/some/page?query=1")
@ -274,3 +282,244 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags) t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
} }
} }
func TestGetOGTagsWithTargetHostOverride(t *testing.T) {
originalHost := "example.test"
overrideHost := "backend.internal"
seenHosts := make(chan string, 10)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seenHosts <- r.Host
w.Header().Set("Content-Type", "text/html")
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="HostOverride" /></head><body>ok</body></html>`)
}))
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
t.Run("default host uses original", func(t *testing.T) {
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("GetOGTags failed: %v", err)
}
select {
case host := <-seenHosts:
if host != originalHost {
t.Fatalf("expected host %q, got %q", originalHost, host)
}
case <-time.After(time.Second):
t.Fatal("server did not receive request")
}
})
t.Run("override host respected", func(t *testing.T) {
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
Host: overrideHost,
})
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("GetOGTags failed: %v", err)
}
select {
case host := <-seenHosts:
if host != overrideHost {
t.Fatalf("expected host %q, got %q", overrideHost, host)
}
case <-time.After(time.Second):
t.Fatal("server did not receive request")
}
})
}
func TestGetOGTagsWithInsecureSkipVerify(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Self-Signed" /></head><body>hello</body></html>`)
})
ts := httptest.NewTLSServer(handler)
defer ts.Close()
parsedURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
// Without skip verify we should get a TLS error
cacheStrict := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
if _, err := cacheStrict.GetOGTags(t.Context(), parsedURL, parsedURL.Host); err == nil {
t.Fatal("expected TLS verification error without InsecureSkipVerify")
}
cacheSkip := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
InsecureSkipVerify: true,
})
tags, err := cacheSkip.GetOGTags(t.Context(), parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("expected successful fetch with InsecureSkipVerify, got: %v", err)
}
if tags["og:title"] != "Self-Signed" {
t.Fatalf("expected og:title to be %q, got %q", "Self-Signed", tags["og:title"])
}
}
func TestGetOGTagsWithTargetSNI(t *testing.T) {
originalHost := "hecate.test"
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
t.Run("explicit SNI override", func(t *testing.T) {
expectedSNI := "backend.internal"
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Works" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheExplicit := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
SNI: expectedSNI,
InsecureSkipVerify: true,
})
if _, err := cacheExplicit.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch with explicit SNI, got: %v", err)
}
if got := recorder.last(); got != expectedSNI {
t.Fatalf("expected server to see SNI %q, got %q", expectedSNI, got)
}
})
t.Run("auto SNI uses original host", func(t *testing.T) {
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Auto" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheAuto := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
SNI: "auto",
InsecureSkipVerify: true,
})
if _, err := cacheAuto.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch with auto SNI, got: %v", err)
}
if got := recorder.last(); got != originalHost {
t.Fatalf("expected server to see SNI %q with auto, got %q", originalHost, got)
}
})
t.Run("default SNI uses backend host", func(t *testing.T) {
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Default" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheDefault := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
InsecureSkipVerify: true,
})
if _, err := cacheDefault.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch without explicit SNI, got: %v", err)
}
wantSNI := ""
if net.ParseIP(targetURL.Hostname()) == nil {
wantSNI = targetURL.Hostname()
}
if got := recorder.last(); got != wantSNI {
t.Fatalf("expected default SNI %q, got %q", wantSNI, got)
}
})
}
func newSNIServer(t *testing.T, body string) (*httptest.Server, *sniRecorder) {
t.Helper()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, body)
})
recorder := &sniRecorder{}
ts := httptest.NewUnstartedServer(handler)
cert := mustCertificateForHost(t, "sni.test")
ts.TLS = &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
recorder.record(hello.ServerName)
return &cert, nil
},
}
ts.StartTLS()
return ts, recorder
}
func mustCertificateForHost(t *testing.T, host string) tls.Certificate {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: host,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{host},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
t.Fatalf("failed to create certificate: %v", err)
}
return tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
}
type sniRecorder struct {
mu sync.Mutex
names []string
}
func (r *sniRecorder) record(name string) {
r.mu.Lock()
defer r.mu.Unlock()
r.names = append(r.names, name)
}
func (r *sniRecorder) last() string {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.names) == 0 {
return ""
}
return r.names[len(r.names)-1]
}

View file

@ -18,7 +18,7 @@ func TestExtractOGTags(t *testing.T) {
Enabled: false, Enabled: false,
ConsiderHost: false, ConsiderHost: false,
TimeToLive: time.Minute, TimeToLive: time.Minute,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Manually set approved tags/prefixes based on the user request for clarity // Manually set approved tags/prefixes based on the user request for clarity
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}
@ -199,7 +199,7 @@ func TestExtractMetaTagInfo(t *testing.T) {
Enabled: false, Enabled: false,
ConsiderHost: false, ConsiderHost: false,
TimeToLive: time.Minute, TimeToLive: time.Minute,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}

42
internal/ogtags/sni.go Normal file
View file

@ -0,0 +1,42 @@
package ogtags
import (
"crypto/tls"
"net/http"
)
// clientForSNI returns a cached client for the given server name, creating one if needed.
func (c *OGTagCache) clientForSNI(serverName string) *http.Client {
if !c.targetSNIAuto || serverName == "" {
return c.client
}
c.transportMu.RLock()
cli, ok := c.sniClients[serverName]
c.transportMu.RUnlock()
if ok {
return cli
}
c.transportMu.Lock()
defer c.transportMu.Unlock()
if cli, ok := c.sniClients[serverName]; ok {
return cli
}
tr := c.transport.Clone()
if tr.TLSClientConfig == nil {
tr.TLSClientConfig = &tls.Config{}
}
tr.TLSClientConfig.ServerName = serverName
if c.insecureSkipVerify {
tr.TLSClientConfig.InsecureSkipVerify = true
}
cli = &http.Client{
Timeout: httpTimeout,
Transport: tr,
}
c.sniClients[serverName] = cli
return cli
}

View file

@ -27,27 +27,30 @@ import (
) )
type Options struct { type Options struct {
Next http.Handler Next http.Handler
Policy *policy.ParsedConfig Policy *policy.ParsedConfig
Logger *slog.Logger Target string
OpenGraph config.OpenGraph TargetHost string
PublicUrl string TargetSNI string
CookieDomain string TargetInsecureSkipVerify bool
JWTRestrictionHeader string CookieDynamicDomain bool
BasePrefix string CookieDomain string
WebmasterEmail string CookieExpiration time.Duration
Target string CookiePartitioned bool
RedirectDomains []string BasePrefix string
ED25519PrivateKey ed25519.PrivateKey WebmasterEmail string
HS512Secret []byte RedirectDomains []string
CookieExpiration time.Duration ED25519PrivateKey ed25519.PrivateKey
CookieSameSite http.SameSite HS512Secret []byte
ServeRobotsTXT bool StripBasePrefix bool
CookieSecure bool OpenGraph config.OpenGraph
StripBasePrefix bool ServeRobotsTXT bool
CookiePartitioned bool CookieSecure bool
CookieDynamicDomain bool CookieSameSite http.SameSite
DifficultyInJWT bool Logger *slog.Logger
PublicUrl string
JWTRestrictionHeader string
DifficultyInJWT bool
} }
func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) {
@ -116,9 +119,13 @@ func New(opts Options) (*Server, error) {
hs512Secret: opts.HS512Secret, hs512Secret: opts.HS512Secret,
policy: opts.Policy, policy: opts.Policy,
opts: opts, opts: opts,
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store), OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{
store: opts.Policy.Store, Host: opts.TargetHost,
logger: opts.Logger, SNI: opts.TargetSNI,
InsecureSkipVerify: opts.TargetInsecureSkipVerify,
}),
store: opts.Policy.Store,
logger: opts.Logger,
} }
mux := http.NewServeMux() mux := http.NewServeMux()

View file

@ -62,11 +62,14 @@ type BotConfig struct {
Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"`
Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"`
Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"`
GeoIP *GeoIP `json:"geoip,omitempty"`
ASNs *ASNs `json:"asns,omitempty"` // Thoth features
Name string `json:"name" yaml:"name"` GeoIP *GeoIP `json:"geoip,omitempty"`
Action Rule `json:"action" yaml:"action"` ASNs *ASNs `json:"asns,omitempty"`
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
Name string `json:"name" yaml:"name"`
Action Rule `json:"action" yaml:"action"`
RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"`
} }
func (b BotConfig) Zero() bool { func (b BotConfig) Zero() bool {