feat(og): Forward host header (#370)

* feat(ogtags): enhance target URL handling for OGTagCache, support Unix sockets

Closes: #323 #319
Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* feat(ogtags): add option to consider host in Open Graph tag cache key

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* feat(ogtags): add option to consider host in OG tag cache key

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* test(ogtags): enhance tests for OGTagCache with host consideration scenarios

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): extract constants for HTTP timeout and max content length

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): restore fetchHTMLDocument method for cache key generation

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* refactor(ogtags): replace maxContentLength field with constant and ensure HTTP scheme is set correctly

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

* fix(fetch): add proxy headers

Signed-off-by: Jason Cameron <git@jasoncameron.dev>

---------

Signed-off-by: Jason Cameron <git@jasoncameron.dev>
This commit is contained in:
Jason Cameron 2025-04-29 08:20:04 -04:00 committed by GitHub
parent 7a20a46b0d
commit 4184b42282
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 484 additions and 91 deletions

View file

@ -8,18 +8,21 @@ import (
)
// GetOGTags is the main function that retrieves Open Graph tags for a URL
func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
func (c *OGTagCache) GetOGTags(url *url.URL, originalHost string) (map[string]string, error) {
if url == nil {
return nil, errors.New("nil URL provided, cannot fetch OG tags")
}
urlStr := c.getTarget(url)
target := c.getTarget(url)
cacheKey := c.generateCacheKey(target, originalHost)
// Check cache first
if cachedTags := c.checkCache(urlStr); cachedTags != nil {
if cachedTags := c.checkCache(cacheKey); cachedTags != nil {
return cachedTags, nil
}
// Fetch HTML content
doc, err := c.fetchHTMLDocument(urlStr)
// Fetch HTML content, passing the original host
doc, err := c.fetchHTMLDocumentWithCache(target, originalHost, cacheKey)
if errors.Is(err, syscall.ECONNREFUSED) {
slog.Debug("Connection refused, returning empty tags")
return nil, nil
@ -35,17 +38,28 @@ func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
ogTags := c.extractOGTags(doc)
// Store in cache
c.cache.Set(urlStr, ogTags, c.ogTimeToLive)
c.cache.Set(cacheKey, ogTags, c.ogTimeToLive)
return ogTags, nil
}
// generateCacheKey builds the key under which the OG tags for target are
// cached. When ogCacheConsiderHost is enabled, originalHost is appended after
// a "|" separator so that each host gets its own cache entry; otherwise the
// target string alone is the key and the host is ignored.
func (c *OGTagCache) generateCacheKey(target string, originalHost string) string {
	if !c.ogCacheConsiderHost {
		return target
	}
	return target + "|" + originalHost
}
// checkCache checks if we have the tags cached and returns them if so
func (c *OGTagCache) checkCache(urlStr string) map[string]string {
if cachedTags, ok := c.cache.Get(urlStr); ok {
func (c *OGTagCache) checkCache(cacheKey string) map[string]string {
if cachedTags, ok := c.cache.Get(cacheKey); ok {
slog.Debug("cache hit", "tags", cachedTags)
return cachedTags
}
slog.Debug("cache miss", "url", urlStr)
slog.Debug("cache miss", "url", cacheKey)
return nil
}

View file

@ -4,12 +4,13 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
"time"
)
func TestCheckCache(t *testing.T) {
cache := NewOGTagCache("http://example.com", true, time.Minute)
cache := NewOGTagCache("http://example.com", true, time.Minute, false)
// Set up test data
urlStr := "http://example.com/page"
@ -17,18 +18,19 @@ func TestCheckCache(t *testing.T) {
"og:title": "Test Title",
"og:description": "Test Description",
}
cacheKey := cache.generateCacheKey(urlStr, "example.com")
// Test cache miss
tags := cache.checkCache(urlStr)
tags := cache.checkCache(cacheKey)
if tags != nil {
t.Errorf("expected nil tags on cache miss, got %v", tags)
}
// Manually add to cache
cache.cache.Set(urlStr, expectedTags, time.Minute)
cache.cache.Set(cacheKey, expectedTags, time.Minute)
// Test cache hit
tags = cache.checkCache(urlStr)
tags = cache.checkCache(cacheKey)
if tags == nil {
t.Fatal("expected non-nil tags on cache hit, got nil")
}
@ -67,7 +69,7 @@ func TestGetOGTags(t *testing.T) {
defer ts.Close()
// Create an instance of OGTagCache with a short TTL for testing
cache := NewOGTagCache(ts.URL, true, 1*time.Minute)
cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
// Parse the test server URL
parsedURL, err := url.Parse(ts.URL)
@ -76,7 +78,8 @@ func TestGetOGTags(t *testing.T) {
}
// Test fetching OG tags from the test server
ogTags, err := cache.GetOGTags(parsedURL)
// Pass the host from the parsed test server URL
ogTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags: %v", err)
}
@ -95,13 +98,15 @@ func TestGetOGTags(t *testing.T) {
}
// Test fetching OG tags from the cache
ogTags, err = cache.GetOGTags(parsedURL)
// Pass the host from the parsed test server URL
ogTags, err = cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err)
}
// Test fetching OG tags from the cache (3rd time)
newOgTags, err := cache.GetOGTags(parsedURL)
// Pass the host from the parsed test server URL
newOgTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err)
}
@ -120,3 +125,116 @@ func TestGetOGTags(t *testing.T) {
}
}
// TestGetOGTagsWithHostConsideration exercises GetOGTags with the
// host-in-cache-key option both off and on. A counter on the test server
// records how many requests actually reached it, which distinguishes cache
// hits (counter unchanged) from cache misses (counter incremented).
func TestGetOGTagsWithHostConsideration(t *testing.T) {
	// step is one GetOGTags call made with the given host, together with the
	// number of upstream loads expected to have happened once it returns.
	type step struct {
		host              string
		expectedLoadCount int
	}

	const page = `
<!DOCTYPE html>
<html>
<head>
<meta property="og:title" content="Test Title" />
<meta property="og:description" content="Test Description" />
</head>
<body><p>Content</p></body>
</html>
`

	// loadCount tracks how many times the test route has been served.
	var loadCount int

	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		loadCount++
		w.Header().Set("Content-Type", "text/html")
		w.Write([]byte(page))
	}))
	defer ts.Close()

	parsedURL, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("failed to parse test server URL: %v", err)
	}

	expectedTags := map[string]string{
		"og:title":       "Test Title",
		"og:description": "Test Description",
	}

	testCases := []struct {
		name                string
		ogCacheConsiderHost bool
		requests            []step
	}{
		{
			name:                "Host Not Considered - Same Host",
			ogCacheConsiderHost: false,
			requests: []step{
				{host: "host1", expectedLoadCount: 1}, // first request, miss
				{host: "host1", expectedLoadCount: 1}, // same host, hit (host ignored)
			},
		},
		{
			name:                "Host Not Considered - Different Host",
			ogCacheConsiderHost: false,
			requests: []step{
				{host: "host1", expectedLoadCount: 1}, // first request, miss
				{host: "host2", expectedLoadCount: 1}, // different host, hit (host ignored)
			},
		},
		{
			name:                "Host Considered - Same Host",
			ogCacheConsiderHost: true,
			requests: []step{
				{host: "host1", expectedLoadCount: 1}, // first request, miss
				{host: "host1", expectedLoadCount: 1}, // same host, hit
			},
		},
		{
			name:                "Host Considered - Different Host",
			ogCacheConsiderHost: true,
			requests: []step{
				{host: "host1", expectedLoadCount: 1}, // first request, miss
				{host: "host2", expectedLoadCount: 2}, // different host, miss
				{host: "host2", expectedLoadCount: 2}, // same as second, hit
				{host: "host1", expectedLoadCount: 2}, // same as first, hit
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Every sub-test starts from a fresh counter and a fresh cache.
			loadCount = 0
			cache := NewOGTagCache(ts.URL, true, 1*time.Minute, tc.ogCacheConsiderHost)

			for i, s := range tc.requests {
				n := i + 1

				ogTags, err := cache.GetOGTags(parsedURL, s.host)
				if err != nil {
					// Nothing further can be checked for this step.
					t.Errorf("Request %d (host: %s): unexpected error: %v", n, s.host, err)
					continue
				}

				// The server always returns the same page, so the tags never vary.
				if !reflect.DeepEqual(ogTags, expectedTags) {
					t.Errorf("Request %d (host: %s): expected tags %v, got %v", n, s.host, expectedTags, ogTags)
				}

				// The load counter tells a cache hit apart from a miss.
				if loadCount != s.expectedLoadCount {
					t.Errorf("Request %d (host: %s): expected load count %d, got %d (cache hit/miss mismatch)", n, s.host, s.expectedLoadCount, loadCount)
				}
			}
		})
	}
}

View file

@ -1,6 +1,7 @@
package ogtags
import (
"context"
"errors"
"fmt"
"golang.org/x/net/html"
@ -16,17 +17,35 @@ var (
emptyMap = map[string]string{} // used to indicate an empty result in the cache. Can't use nil as it would be a cache miss.
)
func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
resp, err := c.client.Get(urlStr)
// fetchHTMLDocumentWithCache fetches the HTML document from the given URL string,
// preserving the original host header.
func (c *OGTagCache) fetchHTMLDocumentWithCache(urlStr string, originalHost string, cacheKey string) (*html.Node, error) {
req, err := http.NewRequestWithContext(context.Background(), "GET", urlStr, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
// Set the Host header to the original host
if originalHost != "" {
req.Host = originalHost
}
// Add proxy headers
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
// Send the request
resp, err := c.client.Do(req)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
slog.Debug("og: request timed out", "url", urlStr)
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server
c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server
}
return nil, fmt.Errorf("http get failed: %w", err)
}
// this defer will call MaxBytesReader's Close, which closes the original body.
// Ensure the response body is closed
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
@ -36,19 +55,17 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
if resp.StatusCode != http.StatusOK {
slog.Debug("og: received non-OK status code", "url", urlStr, "status", resp.StatusCode)
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes
c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes
return nil, fmt.Errorf("%w: page not found", ErrOgHandled)
}
// Check content type
ct := resp.Header.Get("Content-Type")
if ct == "" {
// assume non html body
return nil, fmt.Errorf("missing Content-Type header")
} else {
mediaType, _, err := mime.ParseMediaType(ct)
if err != nil {
// Malformed Content-Type header
slog.Debug("og: malformed Content-Type header", "url", urlStr, "contentType", ct)
return nil, fmt.Errorf("%w malformed Content-Type header: %w", ErrOgHandled, err)
}
@ -59,17 +76,16 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
}
}
resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxContentLength)
resp.Body = http.MaxBytesReader(nil, resp.Body, maxContentLength)
doc, err := html.Parse(resp.Body)
if err != nil {
// Check if the error is specifically because the limit was exceeded
var maxBytesErr *http.MaxBytesError
if errors.As(err, &maxBytesErr) {
slog.Debug("og: content exceeded max length", "url", urlStr, "limit", c.maxContentLength)
return nil, fmt.Errorf("content too large: exceeded %d bytes", c.maxContentLength)
slog.Debug("og: content exceeded max length", "url", urlStr, "limit", maxContentLength)
return nil, fmt.Errorf("content too large: exceeded %d bytes", maxContentLength)
}
// parsing error (e.g., malformed HTML)
return nil, fmt.Errorf("failed to parse HTML: %w", err)
}

View file

@ -2,6 +2,7 @@ package ogtags
import (
"fmt"
"golang.org/x/net/html"
"io"
"net/http"
"net/http/httptest"
@ -78,8 +79,8 @@ func TestFetchHTMLDocument(t *testing.T) {
}))
defer ts.Close()
cache := NewOGTagCache("", true, time.Minute)
doc, err := cache.fetchHTMLDocument(ts.URL)
cache := NewOGTagCache("", true, time.Minute, false)
doc, err := cache.fetchHTMLDocument(ts.URL, "anything")
if tt.expectError {
if err == nil {
@ -105,9 +106,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
t.Skip("test requires theoretical network egress")
}
cache := NewOGTagCache("", true, time.Minute)
cache := NewOGTagCache("", true, time.Minute, false)
doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example")
doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example", "anything")
if err == nil {
t.Error("expected error for invalid URL, got nil")
@ -117,3 +118,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
t.Error("expected nil document for invalid URL, got non-nil")
}
}
// fetchHTMLDocument is a convenience wrapper for the tests in this file: it
// derives the cache key from urlStr and originalHost via generateCacheKey and
// forwards everything to fetchHTMLDocumentWithCache, so callers do not have
// to build the key themselves.
func (c *OGTagCache) fetchHTMLDocument(urlStr string, originalHost string) (*html.Node, error) {
	return c.fetchHTMLDocumentWithCache(urlStr, originalHost, c.generateCacheKey(urlStr, originalHost))
}

View file

@ -104,7 +104,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create cache instance
cache := NewOGTagCache(ts.URL, true, 1*time.Minute)
cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
// Create URL for test
testURL, _ := url.Parse(ts.URL)
@ -112,7 +112,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
testURL.RawQuery = tc.query
// Get OG tags
ogTags, err := cache.GetOGTags(testURL)
// Pass the host from the test URL
ogTags, err := cache.GetOGTags(testURL, testURL.Host)
// Check error expectation
if tc.expectError {
@ -139,7 +140,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
}
// Test cache retrieval
cachedOGTags, err := cache.GetOGTags(testURL)
// Pass the host from the test URL
cachedOGTags, err := cache.GetOGTags(testURL, testURL.Host)
if err != nil {
t.Fatalf("failed to get OG tags from cache: %v", err)
}

View file

@ -1,51 +1,111 @@
package ogtags
import (
"context"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/TecharoHQ/anubis/decaymap"
)
const (
// maxContentLength is the upper bound, in bytes, on how much of a response
// body is read when fetching a page for OG tag extraction: 16 MiB.
maxContentLength = 16 << 20
// httpTimeout is the timeout applied to the HTTP client that fetches pages.
// TODO: make this configurable?
httpTimeout = 5 * time.Second
)
type OGTagCache struct {
cache *decaymap.Impl[string, map[string]string]
target string
ogPassthrough bool
ogTimeToLive time.Duration
approvedTags []string
approvedPrefixes []string
client *http.Client
maxContentLength int64
cache *decaymap.Impl[string, map[string]string]
targetURL *url.URL
ogCacheConsiderHost bool
ogPassthrough bool
ogTimeToLive time.Duration
approvedTags []string
approvedPrefixes []string
client *http.Client
}
func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration) *OGTagCache {
func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration, ogTagsConsiderHost bool) *OGTagCache {
// Predefined approved tags and prefixes
// In the future, these could come from configuration
defaultApprovedTags := []string{"description", "keywords", "author"}
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
client := &http.Client{
Timeout: 5 * time.Second, /*make this configurable?*/
var parsedTargetURL *url.URL
var err error
if target == "" {
// Default to localhost if target is empty
parsedTargetURL, _ = url.Parse("http://localhost")
} else {
parsedTargetURL, err = url.Parse(target)
if err != nil {
slog.Debug("og: failed to parse target URL, treating as non-unix", "target", target, "error", err)
// If parsing fails, treat it as a non-unix target for backward compatibility or default behavior
// For now, assume it's not a scheme issue but maybe an invalid char, etc.
// A simple string target might be intended if it's not a full URL.
parsedTargetURL = &url.URL{Scheme: "http", Host: target} // Assume http if scheme missing and host-like
if !strings.Contains(target, "://") && !strings.HasPrefix(target, "unix:") {
// If it looks like just a host/host:port (and not unix), prepend http:// (todo: is this bad...? Trace path to see if i can yell at user to do it right)
parsedTargetURL, _ = url.Parse("http://" + target) // fetch cares about scheme but anubis doesn't
}
}
}
const maxContentLength = 16 << 20 // 16 MiB in bytes
client := &http.Client{
Timeout: httpTimeout,
}
// Configure custom transport for Unix sockets
if parsedTargetURL.Scheme == "unix" {
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
client.Transport = &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", socketPath)
},
}
}
return &OGTagCache{
cache: decaymap.New[string, map[string]string](),
target: target,
ogPassthrough: ogPassthrough,
ogTimeToLive: ogTimeToLive,
approvedTags: defaultApprovedTags,
approvedPrefixes: defaultApprovedPrefixes,
client: client,
maxContentLength: maxContentLength,
cache: decaymap.New[string, map[string]string](),
targetURL: parsedTargetURL, // Store the parsed URL
ogPassthrough: ogPassthrough,
ogTimeToLive: ogTimeToLive,
ogCacheConsiderHost: ogTagsConsiderHost, // todo: refactor to be a separate struct
approvedTags: defaultApprovedTags,
approvedPrefixes: defaultApprovedPrefixes,
client: client,
}
}
// getTarget constructs the target URL string for fetching OG tags.
// For Unix sockets, it creates a "fake" HTTP URL that the custom dialer understands.
func (c *OGTagCache) getTarget(u *url.URL) string {
return c.target + u.Path
if c.targetURL.Scheme == "unix" {
// The custom dialer ignores the host, but we need a valid http URL structure.
// Use "unix" as a placeholder host. Path and Query from original request are appended.
fakeURL := &url.URL{
Scheme: "http", // Scheme must be http/https for client.Get
Host: "unix", // Arbitrary host, ignored by custom dialer
Path: u.Path,
RawQuery: u.RawQuery,
}
return fakeURL.String()
}
// For regular http/https targets
target := *c.targetURL // Make a copy
target.Path = u.Path
target.RawQuery = u.RawQuery
return target.String()
}
func (c *OGTagCache) Cleanup() {
c.cache.Cleanup()
if c.cache != nil {
c.cache.Cleanup()
}
}

View file

@ -1,7 +1,16 @@
package ogtags
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
)
@ -29,14 +38,23 @@ func TestNewOGTagCache(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive)
cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive, false)
if cache == nil {
t.Fatal("expected non-nil cache, got nil")
}
if cache.target != tt.target {
t.Errorf("expected target %s, got %s", tt.target, cache.target)
// Check the parsed targetURL, handling the default case for empty target
expectedURLStr := tt.target
if tt.target == "" {
// Default behavior when target is empty is now http://localhost
expectedURLStr = "http://localhost"
} else if !strings.Contains(tt.target, "://") && !strings.HasPrefix(tt.target, "unix:") {
// Handle case where target is just host or host:port (and not unix)
expectedURLStr = "http://" + tt.target
}
if cache.targetURL.String() != expectedURLStr {
t.Errorf("expected targetURL %s, got %s", expectedURLStr, cache.targetURL.String())
}
if cache.ogPassthrough != tt.ogPassthrough {
@ -50,6 +68,45 @@ func TestNewOGTagCache(t *testing.T) {
}
}
// TestNewOGTagCache_UnixSocket verifies that a "unix://" target is parsed
// into a unix-scheme targetURL whose path is the socket path, and that the
// HTTP client is given a transport with a custom DialContext.
func TestNewOGTagCache_UnixSocket(t *testing.T) {
	socketPath := filepath.Join(t.TempDir(), "test.sock")

	cache := NewOGTagCache("unix://"+socketPath, true, 5*time.Minute, false)
	if cache == nil {
		t.Fatal("expected non-nil cache, got nil")
	}

	if got := cache.targetURL.Scheme; got != "unix" {
		t.Errorf("expected targetURL scheme 'unix', got '%s'", got)
	}
	if got := cache.targetURL.Path; got != socketPath {
		t.Errorf("expected targetURL path '%s', got '%s'", socketPath, got)
	}

	// The client must carry an *http.Transport with a dialer installed.
	transport, ok := cache.client.Transport.(*http.Transport)
	if !ok {
		t.Fatalf("expected client transport to be *http.Transport, got %T", cache.client.Transport)
	}
	if transport.DialContext == nil {
		t.Fatal("expected client transport DialContext to be non-nil for unix socket")
	}

	// Best-effort dial: nothing listens on the socket, so the only acceptable
	// failures are "connection refused" and "no such file or directory".
	conn, err := transport.DialContext(context.Background(), "", "")
	switch {
	case err == nil:
		conn.Close()
		t.Log("DialContext seems functional, but couldn't verify path without a listener")
	case strings.Contains(err.Error(), "connect: connection refused"),
		strings.Contains(err.Error(), "connect: no such file or directory"):
		// Expected when no listener is present.
	default:
		t.Errorf("DialContext failed with unexpected error: %v", err)
	}
}
func TestGetTarget(t *testing.T) {
tests := []struct {
name string
@ -66,24 +123,39 @@ func TestGetTarget(t *testing.T) {
expected: "http://example.com",
},
{
name: "With complex path",
target: "http://example.com",
path: "/pag(#*((#@)ΓΓΓΓe/Γ",
query: "id=123",
expected: "http://example.com/pag(#*((#@)ΓΓΓΓe/Γ",
name: "With complex path",
target: "http://example.com",
path: "/pag(#*((#@)ΓΓΓΓe/Γ",
query: "id=123",
// Expect URL encoding and query parameter
expected: "http://example.com/pag%28%23%2A%28%28%23@%29%CE%93%CE%93%CE%93%CE%93e/%CE%93?id=123",
},
{
name: "With query and path",
target: "http://example.com",
path: "/page",
query: "id=123",
expected: "http://example.com/page",
expected: "http://example.com/page?id=123",
},
{
name: "Unix socket target",
target: "unix:/tmp/anubis.sock",
path: "/some/path",
query: "key=value&flag=true",
expected: "http://unix/some/path?key=value&flag=true", // Scheme becomes http, host is 'unix'
},
{
name: "Unix socket target with ///",
target: "unix:///var/run/anubis.sock",
path: "/",
query: "",
expected: "http://unix/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := NewOGTagCache(tt.target, false, time.Minute)
cache := NewOGTagCache(tt.target, false, time.Minute, false)
u := &url.URL{
Path: tt.path,
@ -98,3 +170,86 @@ func TestGetTarget(t *testing.T) {
})
}
}
// TestIntegrationGetOGTags_UnixSocket serves a small HTML page over a Unix
// domain socket and verifies that an OGTagCache whose target is
// "unix://<socket path>" returns the page's OG tags, both on the first call
// and on a repeated call (which is expected to be answered from the cache).
func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
	tempDir := t.TempDir()
	socketPath := filepath.Join(tempDir, "anubis-test.sock")

	// Ensure the socket does not exist initially
	_ = os.Remove(socketPath)

	// Create a simple HTTP server listening on the Unix socket
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatalf("Failed to listen on unix socket %s: %v", socketPath, err)
	}
	// Runs after the server shutdown below (defers are LIFO), so the listener
	// is normally already closed here; net.ErrClosed is therefore tolerated.
	defer func(listener net.Listener, socketPath string) {
		if listener != nil {
			if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
				t.Logf("Error closing listener: %v", err)
			}
		}
		if _, err := os.Stat(socketPath); err == nil {
			if err := os.Remove(socketPath); err != nil {
				t.Logf("Error removing socket file %s: %v", socketPath, err)
			}
		}
	}(listener, socketPath)

	server := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "text/html")
			fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Unix Socket Test" /></head><body>Test</body></html>`)
		}),
	}
	go func() {
		if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			t.Logf("Unix socket server error: %v", err)
		}
	}()
	defer func(server *http.Server, ctx context.Context) {
		err := server.Shutdown(ctx)
		if err != nil {
			t.Logf("Error shutting down server: %v", err)
		}
	}(server, context.Background()) // Ensure server is shut down

	// No start-up sleep is needed: net.Listen has already bound the socket, so
	// a connection made before Serve reaches Accept waits in the listen
	// backlog instead of being refused.

	// Create cache instance pointing to the Unix socket
	targetURL := "unix://" + socketPath
	cache := NewOGTagCache(targetURL, true, 1*time.Minute, false)

	// Create a dummy URL for the request (path and query matter)
	testReqURL, err := url.Parse("/some/page?query=1")
	if err != nil {
		t.Fatalf("failed to parse request URL: %v", err)
	}

	// Get OG tags
	// Pass an empty string for host, as it's irrelevant for unix sockets
	ogTags, err := cache.GetOGTags(testReqURL, "")
	if err != nil {
		t.Fatalf("GetOGTags failed for unix socket: %v", err)
	}

	expectedTags := map[string]string{
		"og:title": "Unix Socket Test",
	}
	if !reflect.DeepEqual(ogTags, expectedTags) {
		t.Errorf("Expected OG tags %v, got %v", expectedTags, ogTags)
	}

	// Test cache retrieval (should hit cache)
	// Pass an empty string for host
	cachedTags, err := cache.GetOGTags(testReqURL, "")
	if err != nil {
		t.Fatalf("GetOGTags (cache hit) failed for unix socket: %v", err)
	}
	if !reflect.DeepEqual(cachedTags, expectedTags) {
		t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
	}
}

View file

@ -12,7 +12,7 @@ import (
// TestExtractOGTags updated with correct expectations based on filtering logic
func TestExtractOGTags(t *testing.T) {
// Use a cache instance that reflects the default approved lists
testCache := NewOGTagCache("", false, time.Minute)
testCache := NewOGTagCache("", false, time.Minute, false)
// Manually set approved tags/prefixes based on the user request for clarity
testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"}
@ -189,7 +189,7 @@ func TestIsOGMetaTag(t *testing.T) {
func TestExtractMetaTagInfo(t *testing.T) {
// Use a cache instance that reflects the default approved lists
testCache := NewOGTagCache("", false, time.Minute)
testCache := NewOGTagCache("", false, time.Minute, false)
testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"}