feat(og): Foward host header (#370)
* feat(ogtags): enhance target URL handling for OGTagCache, support Unix sockets Closes: #323 #319 Signed-off-by: Jason Cameron <git@jasoncameron.dev> * docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets Signed-off-by: Jason Cameron <git@jasoncameron.dev> * docs: update CHANGELOG.md to include Opengraph passthrough support for Unix sockets Signed-off-by: Jason Cameron <git@jasoncameron.dev> * feat(ogtags): add option to consider host in Open Graph tag cache key Signed-off-by: Jason Cameron <git@jasoncameron.dev> * feat(ogtags): add option to consider host in OG tag cache key Signed-off-by: Jason Cameron <git@jasoncameron.dev> * test(ogtags): enhance tests for OGTagCache with host consideration scenarios Signed-off-by: Jason Cameron <git@jasoncameron.dev> * refactor(ogtags): extract constants for HTTP timeout and max content length Signed-off-by: Jason Cameron <git@jasoncameron.dev> * refactor(ogtags): restore fetchHTMLDocument method for cache key generation Signed-off-by: Jason Cameron <git@jasoncameron.dev> * refactor(ogtags): replace maxContentLength field with constant and ensure HTTP scheme is set correctly Signed-off-by: Jason Cameron <git@jasoncameron.dev> * fix(fetch): add proxy headers Signed-off-by: Jason Cameron <git@jasoncameron.dev> --------- Signed-off-by: Jason Cameron <git@jasoncameron.dev>
This commit is contained in:
parent
7a20a46b0d
commit
4184b42282
14 changed files with 484 additions and 91 deletions
|
|
@ -8,18 +8,21 @@ import (
|
|||
)
|
||||
|
||||
// GetOGTags is the main function that retrieves Open Graph tags for a URL
|
||||
func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
|
||||
func (c *OGTagCache) GetOGTags(url *url.URL, originalHost string) (map[string]string, error) {
|
||||
if url == nil {
|
||||
return nil, errors.New("nil URL provided, cannot fetch OG tags")
|
||||
}
|
||||
urlStr := c.getTarget(url)
|
||||
|
||||
target := c.getTarget(url)
|
||||
cacheKey := c.generateCacheKey(target, originalHost)
|
||||
|
||||
// Check cache first
|
||||
if cachedTags := c.checkCache(urlStr); cachedTags != nil {
|
||||
if cachedTags := c.checkCache(cacheKey); cachedTags != nil {
|
||||
return cachedTags, nil
|
||||
}
|
||||
|
||||
// Fetch HTML content
|
||||
doc, err := c.fetchHTMLDocument(urlStr)
|
||||
// Fetch HTML content, passing the original host
|
||||
doc, err := c.fetchHTMLDocumentWithCache(target, originalHost, cacheKey)
|
||||
if errors.Is(err, syscall.ECONNREFUSED) {
|
||||
slog.Debug("Connection refused, returning empty tags")
|
||||
return nil, nil
|
||||
|
|
@ -35,17 +38,28 @@ func (c *OGTagCache) GetOGTags(url *url.URL) (map[string]string, error) {
|
|||
ogTags := c.extractOGTags(doc)
|
||||
|
||||
// Store in cache
|
||||
c.cache.Set(urlStr, ogTags, c.ogTimeToLive)
|
||||
c.cache.Set(cacheKey, ogTags, c.ogTimeToLive)
|
||||
|
||||
return ogTags, nil
|
||||
}
|
||||
|
||||
func (c *OGTagCache) generateCacheKey(target string, originalHost string) string {
|
||||
var cacheKey string
|
||||
|
||||
if c.ogCacheConsiderHost {
|
||||
cacheKey = target + "|" + originalHost
|
||||
} else {
|
||||
cacheKey = target
|
||||
}
|
||||
return cacheKey
|
||||
}
|
||||
|
||||
// checkCache checks if we have the tags cached and returns them if so
|
||||
func (c *OGTagCache) checkCache(urlStr string) map[string]string {
|
||||
if cachedTags, ok := c.cache.Get(urlStr); ok {
|
||||
func (c *OGTagCache) checkCache(cacheKey string) map[string]string {
|
||||
if cachedTags, ok := c.cache.Get(cacheKey); ok {
|
||||
slog.Debug("cache hit", "tags", cachedTags)
|
||||
return cachedTags
|
||||
}
|
||||
slog.Debug("cache miss", "url", urlStr)
|
||||
slog.Debug("cache miss", "url", cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCheckCache(t *testing.T) {
|
||||
cache := NewOGTagCache("http://example.com", true, time.Minute)
|
||||
cache := NewOGTagCache("http://example.com", true, time.Minute, false)
|
||||
|
||||
// Set up test data
|
||||
urlStr := "http://example.com/page"
|
||||
|
|
@ -17,18 +18,19 @@ func TestCheckCache(t *testing.T) {
|
|||
"og:title": "Test Title",
|
||||
"og:description": "Test Description",
|
||||
}
|
||||
cacheKey := cache.generateCacheKey(urlStr, "example.com")
|
||||
|
||||
// Test cache miss
|
||||
tags := cache.checkCache(urlStr)
|
||||
tags := cache.checkCache(cacheKey)
|
||||
if tags != nil {
|
||||
t.Errorf("expected nil tags on cache miss, got %v", tags)
|
||||
}
|
||||
|
||||
// Manually add to cache
|
||||
cache.cache.Set(urlStr, expectedTags, time.Minute)
|
||||
cache.cache.Set(cacheKey, expectedTags, time.Minute)
|
||||
|
||||
// Test cache hit
|
||||
tags = cache.checkCache(urlStr)
|
||||
tags = cache.checkCache(cacheKey)
|
||||
if tags == nil {
|
||||
t.Fatal("expected non-nil tags on cache hit, got nil")
|
||||
}
|
||||
|
|
@ -67,7 +69,7 @@ func TestGetOGTags(t *testing.T) {
|
|||
defer ts.Close()
|
||||
|
||||
// Create an instance of OGTagCache with a short TTL for testing
|
||||
cache := NewOGTagCache(ts.URL, true, 1*time.Minute)
|
||||
cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
|
||||
|
||||
// Parse the test server URL
|
||||
parsedURL, err := url.Parse(ts.URL)
|
||||
|
|
@ -76,7 +78,8 @@ func TestGetOGTags(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test fetching OG tags from the test server
|
||||
ogTags, err := cache.GetOGTags(parsedURL)
|
||||
// Pass the host from the parsed test server URL
|
||||
ogTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get OG tags: %v", err)
|
||||
}
|
||||
|
|
@ -95,13 +98,15 @@ func TestGetOGTags(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test fetching OG tags from the cache
|
||||
ogTags, err = cache.GetOGTags(parsedURL)
|
||||
// Pass the host from the parsed test server URL
|
||||
ogTags, err = cache.GetOGTags(parsedURL, parsedURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get OG tags from cache: %v", err)
|
||||
}
|
||||
|
||||
// Test fetching OG tags from the cache (3rd time)
|
||||
newOgTags, err := cache.GetOGTags(parsedURL)
|
||||
// Pass the host from the parsed test server URL
|
||||
newOgTags, err := cache.GetOGTags(parsedURL, parsedURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get OG tags from cache: %v", err)
|
||||
}
|
||||
|
|
@ -120,3 +125,116 @@ func TestGetOGTags(t *testing.T) {
|
|||
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOGTagsWithHostConsideration tests the behavior of the cache with and without host consideration and for multiple hosts in a theoretical setup.
|
||||
func TestGetOGTagsWithHostConsideration(t *testing.T) {
|
||||
var loadCount int // Counter to track how many times the test route is loaded
|
||||
|
||||
// Create a test server
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
loadCount++ // Increment counter on each request to the server
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta property="og:title" content="Test Title" />
|
||||
<meta property="og:description" content="Test Description" />
|
||||
</head>
|
||||
<body><p>Content</p></body>
|
||||
</html>
|
||||
`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
parsedURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse test server URL: %v", err)
|
||||
}
|
||||
|
||||
expectedTags := map[string]string{
|
||||
"og:title": "Test Title",
|
||||
"og:description": "Test Description",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ogCacheConsiderHost bool
|
||||
requests []struct {
|
||||
host string
|
||||
expectedLoadCount int // Expected load count *after* this request
|
||||
}
|
||||
}{
|
||||
{
|
||||
name: "Host Not Considered - Same Host",
|
||||
ogCacheConsiderHost: false,
|
||||
requests: []struct {
|
||||
host string
|
||||
expectedLoadCount int
|
||||
}{
|
||||
{"host1", 1}, // First request, miss
|
||||
{"host1", 1}, // Second request, same host, hit (host ignored)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Host Not Considered - Different Host",
|
||||
ogCacheConsiderHost: false,
|
||||
requests: []struct {
|
||||
host string
|
||||
expectedLoadCount int
|
||||
}{
|
||||
{"host1", 1}, // First request, miss
|
||||
{"host2", 1}, // Second request, different host, hit (host ignored)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Host Considered - Same Host",
|
||||
ogCacheConsiderHost: true,
|
||||
requests: []struct {
|
||||
host string
|
||||
expectedLoadCount int
|
||||
}{
|
||||
{"host1", 1}, // First request, miss
|
||||
{"host1", 1}, // Second request, same host, hit
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Host Considered - Different Host",
|
||||
ogCacheConsiderHost: true,
|
||||
requests: []struct {
|
||||
host string
|
||||
expectedLoadCount int
|
||||
}{
|
||||
{"host1", 1}, // First request, miss
|
||||
{"host2", 2}, // Second request, different host, miss
|
||||
{"host2", 2}, // Third request, same as second, hit
|
||||
{"host1", 2}, // Fourth request, same as first, hit
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
loadCount = 0 // Reset load count for each test case
|
||||
cache := NewOGTagCache(ts.URL, true, 1*time.Minute, tc.ogCacheConsiderHost)
|
||||
|
||||
for i, req := range tc.requests {
|
||||
ogTags, err := cache.GetOGTags(parsedURL, req.host)
|
||||
if err != nil {
|
||||
t.Errorf("Request %d (host: %s): unexpected error: %v", i+1, req.host, err)
|
||||
continue // Skip further checks for this request if error occurred
|
||||
}
|
||||
|
||||
// Verify tags are correct (should always be the same in this setup)
|
||||
if !reflect.DeepEqual(ogTags, expectedTags) {
|
||||
t.Errorf("Request %d (host: %s): expected tags %v, got %v", i+1, req.host, expectedTags, ogTags)
|
||||
}
|
||||
|
||||
// Verify the load count to check cache hit/miss behavior
|
||||
if loadCount != req.expectedLoadCount {
|
||||
t.Errorf("Request %d (host: %s): expected load count %d, got %d (cache hit/miss mismatch)", i+1, req.host, req.expectedLoadCount, loadCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package ogtags
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/html"
|
||||
|
|
@ -16,17 +17,35 @@ var (
|
|||
emptyMap = map[string]string{} // used to indicate an empty result in the cache. Can't use nil as it would be a cache miss.
|
||||
)
|
||||
|
||||
func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
|
||||
resp, err := c.client.Get(urlStr)
|
||||
// fetchHTMLDocumentWithCache fetches the HTML document from the given URL string,
|
||||
// preserving the original host header.
|
||||
func (c *OGTagCache) fetchHTMLDocumentWithCache(urlStr string, originalHost string, cacheKey string) (*html.Node, error) {
|
||||
req, err := http.NewRequestWithContext(context.Background(), "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
|
||||
// Set the Host header to the original host
|
||||
if originalHost != "" {
|
||||
req.Host = originalHost
|
||||
}
|
||||
|
||||
// Add proxy headers
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
|
||||
|
||||
// Send the request
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
slog.Debug("og: request timed out", "url", urlStr)
|
||||
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server
|
||||
c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive/2) // Cache empty result for half the TTL to not spam the server
|
||||
}
|
||||
return nil, fmt.Errorf("http get failed: %w", err)
|
||||
}
|
||||
// this defer will call MaxBytesReader's Close, which closes the original body.
|
||||
|
||||
// Ensure the response body is closed
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
|
|
@ -36,19 +55,17 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
|
|||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
slog.Debug("og: received non-OK status code", "url", urlStr, "status", resp.StatusCode)
|
||||
c.cache.Set(urlStr, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes
|
||||
c.cache.Set(cacheKey, emptyMap, c.ogTimeToLive) // Cache empty result for non-successful status codes
|
||||
return nil, fmt.Errorf("%w: page not found", ErrOgHandled)
|
||||
}
|
||||
|
||||
// Check content type
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if ct == "" {
|
||||
// assume non html body
|
||||
return nil, fmt.Errorf("missing Content-Type header")
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(ct)
|
||||
if err != nil {
|
||||
// Malformed Content-Type header
|
||||
slog.Debug("og: malformed Content-Type header", "url", urlStr, "contentType", ct)
|
||||
return nil, fmt.Errorf("%w malformed Content-Type header: %w", ErrOgHandled, err)
|
||||
}
|
||||
|
|
@ -59,17 +76,16 @@ func (c *OGTagCache) fetchHTMLDocument(urlStr string) (*html.Node, error) {
|
|||
}
|
||||
}
|
||||
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxContentLength)
|
||||
resp.Body = http.MaxBytesReader(nil, resp.Body, maxContentLength)
|
||||
|
||||
doc, err := html.Parse(resp.Body)
|
||||
if err != nil {
|
||||
// Check if the error is specifically because the limit was exceeded
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
slog.Debug("og: content exceeded max length", "url", urlStr, "limit", c.maxContentLength)
|
||||
return nil, fmt.Errorf("content too large: exceeded %d bytes", c.maxContentLength)
|
||||
slog.Debug("og: content exceeded max length", "url", urlStr, "limit", maxContentLength)
|
||||
return nil, fmt.Errorf("content too large: exceeded %d bytes", maxContentLength)
|
||||
}
|
||||
// parsing error (e.g., malformed HTML)
|
||||
return nil, fmt.Errorf("failed to parse HTML: %w", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package ogtags
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/net/html"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -78,8 +79,8 @@ func TestFetchHTMLDocument(t *testing.T) {
|
|||
}))
|
||||
defer ts.Close()
|
||||
|
||||
cache := NewOGTagCache("", true, time.Minute)
|
||||
doc, err := cache.fetchHTMLDocument(ts.URL)
|
||||
cache := NewOGTagCache("", true, time.Minute, false)
|
||||
doc, err := cache.fetchHTMLDocument(ts.URL, "anything")
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
|
|
@ -105,9 +106,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
|
|||
t.Skip("test requires theoretical network egress")
|
||||
}
|
||||
|
||||
cache := NewOGTagCache("", true, time.Minute)
|
||||
cache := NewOGTagCache("", true, time.Minute, false)
|
||||
|
||||
doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example")
|
||||
doc, err := cache.fetchHTMLDocument("http://invalid.url.that.doesnt.exist.example", "anything")
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid URL, got nil")
|
||||
|
|
@ -117,3 +118,9 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
|
|||
t.Error("expected nil document for invalid URL, got non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
// fetchHTMLDocument allows you to call fetchHTMLDocumentWithCache without a duplicate generateCacheKey call
|
||||
func (c *OGTagCache) fetchHTMLDocument(urlStr string, originalHost string) (*html.Node, error) {
|
||||
cacheKey := c.generateCacheKey(urlStr, originalHost)
|
||||
return c.fetchHTMLDocumentWithCache(urlStr, originalHost, cacheKey)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
|
|||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create cache instance
|
||||
cache := NewOGTagCache(ts.URL, true, 1*time.Minute)
|
||||
cache := NewOGTagCache(ts.URL, true, 1*time.Minute, false)
|
||||
|
||||
// Create URL for test
|
||||
testURL, _ := url.Parse(ts.URL)
|
||||
|
|
@ -112,7 +112,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
|
|||
testURL.RawQuery = tc.query
|
||||
|
||||
// Get OG tags
|
||||
ogTags, err := cache.GetOGTags(testURL)
|
||||
// Pass the host from the test URL
|
||||
ogTags, err := cache.GetOGTags(testURL, testURL.Host)
|
||||
|
||||
// Check error expectation
|
||||
if tc.expectError {
|
||||
|
|
@ -139,7 +140,8 @@ func TestIntegrationGetOGTags(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test cache retrieval
|
||||
cachedOGTags, err := cache.GetOGTags(testURL)
|
||||
// Pass the host from the test URL
|
||||
cachedOGTags, err := cache.GetOGTags(testURL, testURL.Host)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get OG tags from cache: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,51 +1,111 @@
|
|||
package ogtags
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/TecharoHQ/anubis/decaymap"
|
||||
)
|
||||
|
||||
const (
|
||||
maxContentLength = 16 << 20 // 16 MiB in bytes, if there is a reasonable reason that you need more than this...Why?
|
||||
httpTimeout = 5 * time.Second /*todo: make this configurable?*/
|
||||
)
|
||||
|
||||
type OGTagCache struct {
|
||||
cache *decaymap.Impl[string, map[string]string]
|
||||
target string
|
||||
ogPassthrough bool
|
||||
ogTimeToLive time.Duration
|
||||
approvedTags []string
|
||||
approvedPrefixes []string
|
||||
client *http.Client
|
||||
maxContentLength int64
|
||||
cache *decaymap.Impl[string, map[string]string]
|
||||
targetURL *url.URL
|
||||
ogCacheConsiderHost bool
|
||||
ogPassthrough bool
|
||||
ogTimeToLive time.Duration
|
||||
approvedTags []string
|
||||
approvedPrefixes []string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration) *OGTagCache {
|
||||
func NewOGTagCache(target string, ogPassthrough bool, ogTimeToLive time.Duration, ogTagsConsiderHost bool) *OGTagCache {
|
||||
// Predefined approved tags and prefixes
|
||||
// In the future, these could come from configuration
|
||||
defaultApprovedTags := []string{"description", "keywords", "author"}
|
||||
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second, /*make this configurable?*/
|
||||
|
||||
var parsedTargetURL *url.URL
|
||||
var err error
|
||||
|
||||
if target == "" {
|
||||
// Default to localhost if target is empty
|
||||
parsedTargetURL, _ = url.Parse("http://localhost")
|
||||
} else {
|
||||
parsedTargetURL, err = url.Parse(target)
|
||||
if err != nil {
|
||||
slog.Debug("og: failed to parse target URL, treating as non-unix", "target", target, "error", err)
|
||||
// If parsing fails, treat it as a non-unix target for backward compatibility or default behavior
|
||||
// For now, assume it's not a scheme issue but maybe an invalid char, etc.
|
||||
// A simple string target might be intended if it's not a full URL.
|
||||
parsedTargetURL = &url.URL{Scheme: "http", Host: target} // Assume http if scheme missing and host-like
|
||||
if !strings.Contains(target, "://") && !strings.HasPrefix(target, "unix:") {
|
||||
// If it looks like just a host/host:port (and not unix), prepend http:// (todo: is this bad...? Trace path to see if i can yell at user to do it right)
|
||||
parsedTargetURL, _ = url.Parse("http://" + target) // fetch cares about scheme but anubis doesn't
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const maxContentLength = 16 << 20 // 16 MiB in bytes
|
||||
client := &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
}
|
||||
|
||||
// Configure custom transport for Unix sockets
|
||||
if parsedTargetURL.Scheme == "unix" {
|
||||
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
|
||||
client.Transport = &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", socketPath)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &OGTagCache{
|
||||
cache: decaymap.New[string, map[string]string](),
|
||||
target: target,
|
||||
ogPassthrough: ogPassthrough,
|
||||
ogTimeToLive: ogTimeToLive,
|
||||
approvedTags: defaultApprovedTags,
|
||||
approvedPrefixes: defaultApprovedPrefixes,
|
||||
client: client,
|
||||
maxContentLength: maxContentLength,
|
||||
cache: decaymap.New[string, map[string]string](),
|
||||
targetURL: parsedTargetURL, // Store the parsed URL
|
||||
ogPassthrough: ogPassthrough,
|
||||
ogTimeToLive: ogTimeToLive,
|
||||
ogCacheConsiderHost: ogTagsConsiderHost, // todo: refactor to be a separate struct
|
||||
approvedTags: defaultApprovedTags,
|
||||
approvedPrefixes: defaultApprovedPrefixes,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// getTarget constructs the target URL string for fetching OG tags.
|
||||
// For Unix sockets, it creates a "fake" HTTP URL that the custom dialer understands.
|
||||
func (c *OGTagCache) getTarget(u *url.URL) string {
|
||||
return c.target + u.Path
|
||||
if c.targetURL.Scheme == "unix" {
|
||||
// The custom dialer ignores the host, but we need a valid http URL structure.
|
||||
// Use "unix" as a placeholder host. Path and Query from original request are appended.
|
||||
fakeURL := &url.URL{
|
||||
Scheme: "http", // Scheme must be http/https for client.Get
|
||||
Host: "unix", // Arbitrary host, ignored by custom dialer
|
||||
Path: u.Path,
|
||||
RawQuery: u.RawQuery,
|
||||
}
|
||||
return fakeURL.String()
|
||||
}
|
||||
|
||||
// For regular http/https targets
|
||||
target := *c.targetURL // Make a copy
|
||||
target.Path = u.Path
|
||||
target.RawQuery = u.RawQuery
|
||||
return target.String()
|
||||
|
||||
}
|
||||
|
||||
func (c *OGTagCache) Cleanup() {
|
||||
c.cache.Cleanup()
|
||||
if c.cache != nil {
|
||||
c.cache.Cleanup()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,16 @@
|
|||
package ogtags
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -29,14 +38,23 @@ func TestNewOGTagCache(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive)
|
||||
cache := NewOGTagCache(tt.target, tt.ogPassthrough, tt.ogTimeToLive, false)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("expected non-nil cache, got nil")
|
||||
}
|
||||
|
||||
if cache.target != tt.target {
|
||||
t.Errorf("expected target %s, got %s", tt.target, cache.target)
|
||||
// Check the parsed targetURL, handling the default case for empty target
|
||||
expectedURLStr := tt.target
|
||||
if tt.target == "" {
|
||||
// Default behavior when target is empty is now http://localhost
|
||||
expectedURLStr = "http://localhost"
|
||||
} else if !strings.Contains(tt.target, "://") && !strings.HasPrefix(tt.target, "unix:") {
|
||||
// Handle case where target is just host or host:port (and not unix)
|
||||
expectedURLStr = "http://" + tt.target
|
||||
}
|
||||
if cache.targetURL.String() != expectedURLStr {
|
||||
t.Errorf("expected targetURL %s, got %s", expectedURLStr, cache.targetURL.String())
|
||||
}
|
||||
|
||||
if cache.ogPassthrough != tt.ogPassthrough {
|
||||
|
|
@ -50,6 +68,45 @@ func TestNewOGTagCache(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestNewOGTagCache_UnixSocket specifically tests unix socket initialization
|
||||
func TestNewOGTagCache_UnixSocket(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
socketPath := filepath.Join(tempDir, "test.sock")
|
||||
target := "unix://" + socketPath
|
||||
|
||||
cache := NewOGTagCache(target, true, 5*time.Minute, false)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("expected non-nil cache, got nil")
|
||||
}
|
||||
|
||||
if cache.targetURL.Scheme != "unix" {
|
||||
t.Errorf("expected targetURL scheme 'unix', got '%s'", cache.targetURL.Scheme)
|
||||
}
|
||||
if cache.targetURL.Path != socketPath {
|
||||
t.Errorf("expected targetURL path '%s', got '%s'", socketPath, cache.targetURL.Path)
|
||||
}
|
||||
|
||||
// Check if the client transport is configured for Unix sockets
|
||||
transport, ok := cache.client.Transport.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("expected client transport to be *http.Transport, got %T", cache.client.Transport)
|
||||
}
|
||||
if transport.DialContext == nil {
|
||||
t.Fatal("expected client transport DialContext to be non-nil for unix socket")
|
||||
}
|
||||
|
||||
// Attempt a dummy dial to see if it uses the correct path (optional, more involved check)
|
||||
dummyConn, err := transport.DialContext(context.Background(), "", "")
|
||||
if err == nil {
|
||||
dummyConn.Close()
|
||||
t.Log("DialContext seems functional, but couldn't verify path without a listener")
|
||||
} else if !strings.Contains(err.Error(), "connect: connection refused") && !strings.Contains(err.Error(), "connect: no such file or directory") {
|
||||
// We expect connection refused or not found if nothing is listening
|
||||
t.Errorf("DialContext failed with unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTarget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -66,24 +123,39 @@ func TestGetTarget(t *testing.T) {
|
|||
expected: "http://example.com",
|
||||
},
|
||||
{
|
||||
name: "With complex path",
|
||||
target: "http://example.com",
|
||||
path: "/pag(#*((#@)ΓΓΓΓe/Γ",
|
||||
query: "id=123",
|
||||
expected: "http://example.com/pag(#*((#@)ΓΓΓΓe/Γ",
|
||||
name: "With complex path",
|
||||
target: "http://example.com",
|
||||
path: "/pag(#*((#@)ΓΓΓΓe/Γ",
|
||||
query: "id=123",
|
||||
// Expect URL encoding and query parameter
|
||||
expected: "http://example.com/pag%28%23%2A%28%28%23@%29%CE%93%CE%93%CE%93%CE%93e/%CE%93?id=123",
|
||||
},
|
||||
{
|
||||
name: "With query and path",
|
||||
target: "http://example.com",
|
||||
path: "/page",
|
||||
query: "id=123",
|
||||
expected: "http://example.com/page",
|
||||
expected: "http://example.com/page?id=123",
|
||||
},
|
||||
{
|
||||
name: "Unix socket target",
|
||||
target: "unix:/tmp/anubis.sock",
|
||||
path: "/some/path",
|
||||
query: "key=value&flag=true",
|
||||
expected: "http://unix/some/path?key=value&flag=true", // Scheme becomes http, host is 'unix'
|
||||
},
|
||||
{
|
||||
name: "Unix socket target with ///",
|
||||
target: "unix:///var/run/anubis.sock",
|
||||
path: "/",
|
||||
query: "",
|
||||
expected: "http://unix/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := NewOGTagCache(tt.target, false, time.Minute)
|
||||
cache := NewOGTagCache(tt.target, false, time.Minute, false)
|
||||
|
||||
u := &url.URL{
|
||||
Path: tt.path,
|
||||
|
|
@ -98,3 +170,86 @@ func TestGetTarget(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationGetOGTags_UnixSocket tests fetching OG tags via a Unix socket.
|
||||
func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
socketPath := filepath.Join(tempDir, "anubis-test.sock")
|
||||
|
||||
// Ensure the socket does not exist initially
|
||||
_ = os.Remove(socketPath)
|
||||
|
||||
// Create a simple HTTP server listening on the Unix socket
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen on unix socket %s: %v", socketPath, err)
|
||||
}
|
||||
defer func(listener net.Listener, socketPath string) {
|
||||
if listener != nil {
|
||||
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
t.Logf("Error closing listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
if err := os.Remove(socketPath); err != nil {
|
||||
t.Logf("Error removing socket file %s: %v", socketPath, err)
|
||||
}
|
||||
}
|
||||
}(listener, socketPath)
|
||||
|
||||
server := &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Unix Socket Test" /></head><body>Test</body></html>`)
|
||||
}),
|
||||
}
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Logf("Unix socket server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func(server *http.Server, ctx context.Context) {
|
||||
err := server.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Logf("Error shutting down server: %v", err)
|
||||
}
|
||||
}(server, context.Background()) // Ensure server is shut down
|
||||
|
||||
// Wait a moment for the server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create cache instance pointing to the Unix socket
|
||||
targetURL := "unix://" + socketPath
|
||||
cache := NewOGTagCache(targetURL, true, 1*time.Minute, false)
|
||||
|
||||
// Create a dummy URL for the request (path and query matter)
|
||||
testReqURL, _ := url.Parse("/some/page?query=1")
|
||||
|
||||
// Get OG tags
|
||||
// Pass an empty string for host, as it's irrelevant for unix sockets
|
||||
ogTags, err := cache.GetOGTags(testReqURL, "")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("GetOGTags failed for unix socket: %v", err)
|
||||
}
|
||||
|
||||
expectedTags := map[string]string{
|
||||
"og:title": "Unix Socket Test",
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(ogTags, expectedTags) {
|
||||
t.Errorf("Expected OG tags %v, got %v", expectedTags, ogTags)
|
||||
}
|
||||
|
||||
// Test cache retrieval (should hit cache)
|
||||
// Pass an empty string for host
|
||||
cachedTags, err := cache.GetOGTags(testReqURL, "")
|
||||
if err != nil {
|
||||
t.Fatalf("GetOGTags (cache hit) failed for unix socket: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(cachedTags, expectedTags) {
|
||||
t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import (
|
|||
// TestExtractOGTags updated with correct expectations based on filtering logic
|
||||
func TestExtractOGTags(t *testing.T) {
|
||||
// Use a cache instance that reflects the default approved lists
|
||||
testCache := NewOGTagCache("", false, time.Minute)
|
||||
testCache := NewOGTagCache("", false, time.Minute, false)
|
||||
// Manually set approved tags/prefixes based on the user request for clarity
|
||||
testCache.approvedTags = []string{"description"}
|
||||
testCache.approvedPrefixes = []string{"og:"}
|
||||
|
|
@ -189,7 +189,7 @@ func TestIsOGMetaTag(t *testing.T) {
|
|||
|
||||
func TestExtractMetaTagInfo(t *testing.T) {
|
||||
// Use a cache instance that reflects the default approved lists
|
||||
testCache := NewOGTagCache("", false, time.Minute)
|
||||
testCache := NewOGTagCache("", false, time.Minute, false)
|
||||
testCache.approvedTags = []string{"description"}
|
||||
testCache.approvedPrefixes = []string{"og:"}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue