fix: make ogtags and dnsbl use the Store instead of memory (#760)

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso 2025-07-05 20:17:46 +00:00 committed by GitHub
parent e870ede120
commit 7d0c58d1a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 134 additions and 86 deletions

View file

@ -70,7 +70,6 @@ type Server struct {
next http.Handler
mux *http.ServeMux
policy *policy.ParsedConfig
DNSBLCache *decaymap.Impl[string, dnsbl.DroneBLResponse]
OGTags *ogtags.OGTagCache
ed25519Priv ed25519.PrivateKey
hs512Secret []byte
@ -279,15 +278,16 @@ func (s *Server) checkRules(w http.ResponseWriter, r *http.Request, cr policy.Ch
}
func (s *Server) handleDNSBL(w http.ResponseWriter, r *http.Request, ip string, lg *slog.Logger) bool {
db := &store.JSON[dnsbl.DroneBLResponse]{Underlying: s.store, Prefix: "dronebl:"}
if s.policy.DNSBL && ip != "" {
resp, ok := s.DNSBLCache.Get(ip)
if !ok {
resp, err := db.Get(r.Context(), ip)
if err != nil {
lg.Debug("looking up ip in dnsbl")
resp, err := dnsbl.Lookup(ip)
if err != nil {
lg.Error("can't look up ip in dnsbl", "err", err)
}
s.DNSBLCache.Set(ip, resp, 24*time.Hour)
db.Set(r.Context(), ip, resp, 24*time.Hour)
droneBLHits.WithLabelValues(resp.String()).Inc()
}
@ -551,8 +551,3 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error)
Rules: &checker.List{},
}, nil
}
func (s *Server) CleanupDecayMap() {
s.DNSBLCache.Cleanup()
s.OGTags.Cleanup()
}

View file

@ -15,9 +15,7 @@ import (
"github.com/TecharoHQ/anubis"
"github.com/TecharoHQ/anubis/data"
"github.com/TecharoHQ/anubis/decaymap"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/internal/dnsbl"
"github.com/TecharoHQ/anubis/internal/ogtags"
"github.com/TecharoHQ/anubis/lib/challenge"
"github.com/TecharoHQ/anubis/lib/localization"
@ -108,8 +106,7 @@ func New(opts Options) (*Server, error) {
hs512Secret: opts.HS512Secret,
policy: opts.Policy,
opts: opts,
DNSBLCache: decaymap.New[string, dnsbl.DroneBLResponse](),
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph),
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store),
store: opts.Policy.Store,
}

View file

@ -138,7 +138,7 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, rule *polic
var ogTags map[string]string = nil
if s.opts.OpenGraph.Enabled {
var err error
ogTags, err = s.OGTags.GetOGTags(r.URL, r.Host)
ogTags, err = s.OGTags.GetOGTags(r.Context(), r.URL, r.Host)
if err != nil {
lg.Error("failed to get OG tags", "err", err)
}

View file

@ -43,13 +43,22 @@ func z[T any]() T { return *new(T) }
type JSON[T any] struct {
Underlying Interface
Prefix string
}
func (j *JSON[T]) Delete(ctx context.Context, key string) error {
if j.Prefix != "" {
key = j.Prefix + key
}
return j.Underlying.Delete(ctx, key)
}
func (j *JSON[T]) Get(ctx context.Context, key string) (T, error) {
if j.Prefix != "" {
key = j.Prefix + key
}
data, err := j.Underlying.Get(ctx, key)
if err != nil {
return z[T](), err
@ -64,6 +73,10 @@ func (j *JSON[T]) Get(ctx context.Context, key string) (T, error) {
}
func (j *JSON[T]) Set(ctx context.Context, key string, value T, expiry time.Duration) error {
if j.Prefix != "" {
key = j.Prefix + key
}
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("%w: %w", ErrCantEncode, err)

50
lib/store/json_test.go Normal file
View file

@ -0,0 +1,50 @@
package store_test
import (
"testing"
"time"
"github.com/TecharoHQ/anubis/lib/store"
"github.com/TecharoHQ/anubis/lib/store/memory"
)
func TestJSON(t *testing.T) {
type data struct {
ID string `json:"id"`
}
st := memory.New(t.Context())
db := store.JSON[data]{
Underlying: st,
Prefix: "foo:",
}
if err := db.Set(t.Context(), "test", data{ID: t.Name()}, time.Minute); err != nil {
t.Fatal(err)
}
got, err := db.Get(t.Context(), "test")
if err != nil {
t.Fatal(err)
}
if got.ID != t.Name() {
t.Fatalf("got wrong data for key \"test\", wanted %q but got: %q", t.Name(), got.ID)
}
if err := db.Delete(t.Context(), "test"); err != nil {
t.Fatal(err)
}
if _, err := db.Get(t.Context(), "test"); err == nil {
t.Fatal("wanted invalid get to fail, it did not")
}
if err := st.Set(t.Context(), "foo:test", []byte("}"), time.Minute); err != nil {
t.Fatal(err)
}
if _, err := db.Get(t.Context(), "test"); err == nil {
t.Fatal("wanted invalid get to fail, it did not")
}
}