Implement FCrDNS and other DNS features (#1308)

* Implement FCrDNS and other DNS features

* Redesign DNS cache and methods

* Fix DNS cache

* Rename regexSafe arg

* Alter verifyFCrDNS(addr) behaviour

* Remove unused dnsCache field from Server struct

* Upd expressions docs

* Update docs/docs/CHANGELOG.md

Signed-off-by: Xe Iaso <me@xeiaso.net>

* refactor(dns): simplify FCrDNS logging

* docs: clarify verifyFCrDNS behavior

Add a note to the documentation for `verifyFCrDNS` to clarify that it returns true when no PTR records are found for the given IP address.

* fix(dns): Improve FCrDNS error handling and tests

The `VerifyFCrDNS` function previously ignored errors returned from reverse DNS lookups. This could lead to incorrect passes when a DNS failure (other than a simple 'not found') occurred. This change ensures that any error from a reverse lookup will cause the FCrDNS check to fail.

The test suite for FCrDNS has been updated to reflect this change. The mock DNS lookups now simulate both 'not found' errors and other generic DNS errors. The test cases have been updated to ensure that the function behaves correctly in both scenarios, resolving a situation where two test cases were effectively duplicates.

* docs: Update FCrDNS documentation and spelling

Corrected a typo in the `verifyFCrDNS` function documentation.

Additionally, updated the spelling exception list to include new terms and remove redundant entries.

* chore: update spelling

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
Co-authored-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
The Ninth 2025-11-27 06:24:45 +03:00 committed by GitHub
parent 4ead3ed16e
commit 00fa939acf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1652 additions and 480 deletions

70
internal/dns/cache.go Normal file
View file

@ -0,0 +1,70 @@
package dns
import (
"log/slog"
"time"
"github.com/TecharoHQ/anubis/lib/store"
_ "github.com/TecharoHQ/anubis/lib/store/all"
)
type DnsCache struct {
forward store.JSON[[]string]
reverse store.JSON[[]string]
forwardTTL time.Duration
reverseTTL time.Duration
}
func NewDNSCache(forwardTTL int, reverseTTL int, backend store.Interface) *DnsCache {
return &DnsCache{
forward: store.JSON[[]string]{
Underlying: backend,
Prefix: "forwardDNS",
},
reverse: store.JSON[[]string]{
Underlying: backend,
Prefix: "reverseDNS",
},
forwardTTL: time.Duration(forwardTTL) * time.Second,
reverseTTL: time.Duration(reverseTTL) * time.Second,
}
}
func (d *Dns) getCachedForward(host string) ([]string, bool) {
if d.cache == nil {
return nil, false
}
if cached, err := d.cache.forward.Get(d.ctx, host); err == nil {
slog.Debug("DNS: forward cache hit", "name", host, "ips", cached)
return cached, true
}
slog.Debug("DNS: forward cache miss", "name", host)
return nil, false
}
func (d *Dns) getCachedReverse(addr string) ([]string, bool) {
if d.cache == nil {
return nil, false
}
if cached, err := d.cache.reverse.Get(d.ctx, addr); err == nil {
slog.Debug("DNS: reverse cache hit", "addr", addr, "names", cached)
return cached, true
}
slog.Debug("DNS: reverse cache miss", "addr", addr)
return nil, false
}
func (d *Dns) forwardCachePut(host string, entries []string) {
if d.cache == nil {
return
}
d.cache.forward.Set(d.ctx, host, entries, d.cache.forwardTTL)
}
func (d *Dns) reverseCachePut(addr string, entries []string) {
if d.cache == nil {
return
}
d.cache.reverse.Set(d.ctx, addr, entries, d.cache.reverseTTL)
}

174
internal/dns/dns.go Normal file
View file

@ -0,0 +1,174 @@
package dns
import (
"context"
"encoding/hex"
"errors"
"fmt"
"log/slog"
"net"
"regexp"
"slices"
"strings"
)
var (
DNSLookupAddr = net.LookupAddr
DNSLookupHost = net.LookupHost
)
type Dns struct {
cache *DnsCache
ctx context.Context
}
func New(ctx context.Context, cache *DnsCache) *Dns {
return &Dns{
cache: cache,
ctx: ctx,
}
}
// ReverseDNS performs a reverse DNS lookup for the given IP address and trims the trailing dot from the results.
func (d *Dns) ReverseDNS(addr string) ([]string, error) {
slog.Debug("DNS: performing reverse lookup", "addr", addr)
if cached, ok := d.getCachedReverse(addr); ok {
return cached, nil
}
names, err := DNSLookupAddr(addr)
if err != nil {
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
slog.Debug("DNS: no PTR record found", "addr", addr)
return []string{}, nil
}
slog.Error("DNS: reverse lookup failed", "addr", addr, "err", err)
return nil, err
}
slog.Debug("DNS: reverse lookup successful", "addr", addr, "names", names)
trimmedNames := make([]string, len(names))
for i, name := range names {
trimmedNames[i] = strings.TrimSuffix(name, ".")
}
d.reverseCachePut(addr, trimmedNames)
return trimmedNames, nil
}
// LookupHost performs a forward DNS lookup for the given hostname.
func (d *Dns) LookupHost(host string) ([]string, error) {
slog.Debug("DNS: performing forward lookup", "host", host)
if cached, ok := d.getCachedForward(host); ok {
return cached, nil
}
addrs, err := DNSLookupHost(host)
if err != nil {
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
slog.Debug("DNS: no A/AAAA record found", "host", host)
return []string{}, nil
}
slog.Error("DNS: forward lookup failed", "host", host, "err", err)
return nil, err
}
slog.Debug("DNS: forward lookup successful", "host", host, "addrs", addrs)
d.forwardCachePut(host, addrs)
return addrs, nil
}
// verifyFCrDNSInternal performs the second half of the FCrDNS check, using a
// pre-fetched list of names to perform the forward lookups.
func (d *Dns) verifyFCrDNSInternal(addr string, names []string) bool {
for _, name := range names {
if cached, err := d.LookupHost(name); err == nil {
if slices.Contains(cached, addr) {
slog.Info("DNS: forward lookup confirmed original IP", "name", name, "addr", addr)
return true
}
continue
}
}
slog.Info("DNS: could not confirm original IP in forward lookups", "addr", addr)
return false
}
// VerifyFCrDNS performs a forward-confirmed reverse DNS (FCrDNS) lookup for the given IP address,
// optionally matching against a provided pattern.
func (d *Dns) VerifyFCrDNS(addr string, pattern *string) bool {
var patternVal string
if pattern != nil {
patternVal = *pattern
}
slog.Debug("DNS: performing FCrDNS lookup", "addr", addr, "pattern", patternVal)
names, err := d.ReverseDNS(addr)
if err != nil {
return false
}
if len(names) == 0 {
return pattern == nil // If no pattern specified, check is passed
}
// If a pattern is provided, check for a match.
if pattern != nil {
anyNameMatched := false
for _, name := range names {
matched, err := regexp.MatchString(*pattern, name)
if err != nil {
slog.Error("DNS: verifyFCrDNS invalid regex pattern", "err", err)
return false // Invalid pattern is a failure.
}
if matched {
anyNameMatched = true
break
}
}
if !anyNameMatched {
slog.Debug("DNS: FCrDNS no PTR matches the pattern", "addr", addr, "pattern", *pattern)
return false
}
slog.Debug("DNS: FCrDNS PTR matched pattern, proceeding with forward check", "addr", addr, "pattern", *pattern)
}
// If we're here, either there was no pattern, or the pattern matched.
// Proceed with the forward lookup confirmation.
return d.verifyFCrDNSInternal(addr, names)
}
// ArpaReverseIP performs translation from ip v4/v6 to arpa reverse notation
func (d *Dns) ArpaReverseIP(addr string) (string, error) {
ip := net.ParseIP(addr)
if ip == nil {
return addr, errors.New("invalid IP address")
}
if ipv4 := ip.To4(); ipv4 != nil {
return fmt.Sprintf("%d.%d.%d.%d", ipv4[3], ipv4[2], ipv4[1], ipv4[0]), nil
}
ipv6 := ip.To16()
if ipv6 == nil {
return addr, errors.New("invalid IPv6 address")
}
hexBytes := make([]byte, hex.EncodedLen(len(ipv6)))
hex.Encode(hexBytes, ipv6)
var sb strings.Builder
sb.Grow(len(hexBytes)*2 - 1)
for i := len(hexBytes) - 1; i >= 0; i-- {
sb.WriteByte(hexBytes[i])
if i > 0 {
sb.WriteByte('.')
}
}
return sb.String(), nil
}

308
internal/dns/dns_test.go Normal file
View file

@ -0,0 +1,308 @@
package dns
import (
"context"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/TecharoHQ/anubis/lib/store/memory"
)
// newTestDNS is a helper function to create a new Dns object with an in-memory cache for testing.
func newTestDNS(forwardTTL int, reverseTTL int) *Dns {
ctx := context.Background()
memStore := memory.New(ctx)
cache := NewDNSCache(forwardTTL, reverseTTL, memStore)
return New(ctx, cache)
}
// mockLookupAddr is a mock implementation of the net.LookupAddr function.
func mockLookupAddr(addr string) ([]string, error) {
switch addr {
case "8.8.8.8":
return []string{"dns.google."}, nil
case "1.1.1.1":
return []string{"one.one.one.one."}, nil
case "208.67.222.222":
return []string{"resolver1.opendns.com."}, nil
case "9.9.9.9":
return nil, &net.DNSError{Err: "no such host", Name: "9.9.9.9", IsNotFound: true}
case "1.2.3.4":
return nil, errors.New("unknown error")
default:
return nil, &net.DNSError{Err: "no such host", Name: addr, IsNotFound: true}
}
}
// mockLookupHost is a mock implementation of the net.LookupHost function.
func mockLookupHost(host string) ([]string, error) {
switch host {
case "dns.google":
return []string{"8.8.8.8", "8.8.4.4"}, nil
case "one.one.one.one":
return []string{"1.1.1.1", "1.0.0.1"}, nil
case "resolver1.opendns.com":
return []string{"208.67.222.222"}, nil
case "example.com":
return nil, &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
default:
return nil, &net.DNSError{Err: "no such host", Name: host, IsNotFound: true}
}
}
func TestMain(m *testing.M) {
// Before all tests
originalLookupAddr := DNSLookupAddr
originalLookupHost := DNSLookupHost
DNSLookupAddr = mockLookupAddr
DNSLookupHost = mockLookupHost
// Run tests
exitCode := m.Run()
// After all tests
DNSLookupAddr = originalLookupAddr
DNSLookupHost = originalLookupHost
// Exit
if exitCode != 0 {
panic(exitCode)
}
}
func TestDns_ArpaReverseIP(t *testing.T) {
d := newTestDNS(0, 0)
tests := []struct {
name string
ip string
want string
wantErr bool
}{
{"ipv4", "192.0.2.1", "1.2.0.192", false},
{"ipv6", "2001:db8::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", false},
{"invalid ip", "invalid", "invalid", true},
{"ipv4-mapped ipv6", "::ffff:192.0.2.1", "1.2.0.192", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := d.ArpaReverseIP(tt.ip)
if (err != nil) != tt.wantErr {
t.Errorf("ArpaReverseIP() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ArpaReverseIP() = %v, want %v", got, tt.want)
}
})
}
}
func TestDns_ReverseDNS(t *testing.T) {
d := newTestDNS(1, 1) // short TTL for testing cache
// First call - cache miss
t.Run("cache miss", func(t *testing.T) {
got, err := d.ReverseDNS("8.8.8.8")
if err != nil {
t.Fatalf("ReverseDNS() error = %v", err)
}
want := []string{"dns.google"}
if !reflect.DeepEqual(got, want) {
t.Errorf("ReverseDNS() = %v, want %v", got, want)
}
})
// Second call - cache hit
t.Run("cache hit", func(t *testing.T) {
// Temporarily replace lookup function to ensure cache is used
originalLookupAddr := DNSLookupAddr
DNSLookupAddr = func(addr string) ([]string, error) {
return nil, errors.New("should not be called")
}
defer func() { DNSLookupAddr = originalLookupAddr }()
got, err := d.ReverseDNS("8.8.8.8")
if err != nil {
t.Fatalf("ReverseDNS() error = %v", err)
}
want := []string{"dns.google"}
if !reflect.DeepEqual(got, want) {
t.Errorf("ReverseDNS() = %v, want %v", got, want)
}
})
// Test cache expiration
t.Run("cache expiration", func(t *testing.T) {
time.Sleep(2 * time.Second)
// Now the cache should be expired
// We expect the mock to be called again
// To test this we will change the mock to return something different
originalLookupAddr := DNSLookupAddr
DNSLookupAddr = func(addr string) ([]string, error) {
if addr == "8.8.8.8" {
return []string{"expired.google."}, nil
}
return mockLookupAddr(addr)
}
defer func() { DNSLookupAddr = originalLookupAddr }()
got, err := d.ReverseDNS("8.8.8.8")
if err != nil {
t.Fatalf("ReverseDNS() error = %v", err)
}
want := []string{"expired.google"}
if !reflect.DeepEqual(got, want) {
t.Errorf("ReverseDNS() = %v, want %v", got, want)
}
})
// Test not found
t.Run("not found", func(t *testing.T) {
got, err := d.ReverseDNS("9.9.9.9")
if err != nil {
t.Fatalf("ReverseDNS() error = %v", err)
}
if len(got) != 0 {
t.Errorf("ReverseDNS() = %v, want empty slice", got)
}
})
}
func TestDns_LookupHost(t *testing.T) {
d := newTestDNS(1, 1)
t.Run("cache miss", func(t *testing.T) {
got, err := d.LookupHost("dns.google")
if err != nil {
t.Fatalf("LookupHost() error = %v", err)
}
want := []string{"8.8.8.8", "8.8.4.4"}
if !reflect.DeepEqual(got, want) {
t.Errorf("LookupHost() = %v, want %v", got, want)
}
})
t.Run("cache hit", func(t *testing.T) {
originalLookupHost := DNSLookupHost
DNSLookupHost = func(host string) ([]string, error) {
return nil, errors.New("should not be called")
}
defer func() { DNSLookupHost = originalLookupHost }()
got, err := d.LookupHost("dns.google")
if err != nil {
t.Fatalf("LookupHost() error = %v", err)
}
want := []string{"8.8.8.8", "8.8.4.4"}
if !reflect.DeepEqual(got, want) {
t.Errorf("LookupHost() = %v, want %v", got, want)
}
})
t.Run("cache expiration", func(t *testing.T) {
time.Sleep(2 * time.Second)
originalLookupHost := DNSLookupHost
DNSLookupHost = func(host string) ([]string, error) {
if host == "dns.google" {
return []string{"9.9.9.9"}, nil
}
return mockLookupHost(host)
}
defer func() { DNSLookupHost = originalLookupHost }()
got, err := d.LookupHost("dns.google")
if err != nil {
t.Fatalf("LookupHost() error = %v", err)
}
want := []string{"9.9.9.9"}
if !reflect.DeepEqual(got, want) {
t.Errorf("LookupHost() = %v, want %v", got, want)
}
})
t.Run("not found", func(t *testing.T) {
got, err := d.LookupHost("example.com")
if err != nil {
t.Fatalf("LookupHost() error = %v", err)
}
if len(got) != 0 {
t.Errorf("LookupHost() = %v, want empty slice", got)
}
})
}
func TestDns_VerifyFCrDNS(t *testing.T) {
d := newTestDNS(1, 1)
// Helper to convert string to *string
p := func(s string) *string {
return &s
}
tests := []struct {
name string
ip string
pattern *string
want bool
}{
// Cases without pattern
{"valid no pattern", "8.8.8.8", nil, true},
{"valid partial no pattern", "1.1.1.1", nil, true},
{"not found no pattern", "9.9.9.9", nil, true},
{"unknown error no pattern", "1.2.3.4", nil, false},
// Cases with pattern
{"valid match", "8.8.8.8", p(`.*\.google$`), true},
{"valid no match", "8.8.8.8", p(`\.com$`), false},
{"not found with pattern", "9.9.9.9", p(".*"), false},
{"unknown error with pattern", "1.2.3.4", p(".*"), false},
{"invalid pattern", "8.8.8.8", p(`[`), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := d.VerifyFCrDNS(tt.ip, tt.pattern); got != tt.want {
t.Errorf("VerifyFCrDNS() = %v, want %v", got, tt.want)
}
})
}
t.Run("reverse cache hit", func(t *testing.T) {
// Prime the cache
if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
}
// Now test with a failing lookup to ensure cache is used
originalLookupAddr := DNSLookupAddr
DNSLookupAddr = func(addr string) ([]string, error) {
return nil, errors.New("should not be called")
}
defer func() { DNSLookupAddr = originalLookupAddr }()
if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
t.Errorf("VerifyFCrDNS() = %v, want true", got)
}
})
t.Run("forward cache hit", func(t *testing.T) {
// Prime the cache
if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
}
// Now test with a failing lookup to ensure cache is used
originalLookupHost := DNSLookupHost
DNSLookupHost = func(host string) ([]string, error) {
return nil, errors.New("should not be called")
}
defer func() { DNSLookupHost = originalLookupHost }()
if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
t.Errorf("VerifyFCrDNS() = %v, want true", got)
}
})
}