fix(config): actually load threshold config (#696)

* fix(config): actually load threshold config

Signed-off-by: Xe Iaso <me@xeiaso.net>

* chore: spelling

Signed-off-by: Xe Iaso <me@xeiaso.net>

* test(lib): fix test failures

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso 2025-06-19 17:13:01 -04:00 committed by GitHub
parent 226cf36bf7
commit 7aa732c700
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 201 additions and 30 deletions

View file

@ -0,0 +1,55 @@
package config
import (
"errors"
"fmt"
"testing"
)
func TestASNsValid(t *testing.T) {
for _, tt := range []struct {
name string
input *ASNs
err error
}{
{
name: "basic valid",
input: &ASNs{
Match: []uint32{13335}, // Cloudflare
},
},
{
name: "private ASN",
input: &ASNs{
Match: []uint32{64513, 4206942069}, // 16 and 32 bit private ASN
},
err: ErrPrivateASN,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.input.Valid(); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong validation error")
}
})
}
}
func TestIsPrivateASN(t *testing.T) {
for _, tt := range []struct {
input uint32
output bool
}{
{13335, false}, // Cloudflare
{64513, true}, // 16 bit private ASN
{4206942069, true}, // 32 bit private ASN
} {
t.Run(fmt.Sprint(tt.input, "->", tt.output), func(t *testing.T) {
result := isPrivateASN(tt.input)
if result != tt.output {
t.Errorf("wanted isPrivateASN(%d) == %v, got: %v", tt.input, tt.output, result)
}
})
}
}

View file

@ -326,7 +326,7 @@ type fileConfig struct {
Bots []BotOrImport `json:"bots"`
DNSBL bool `json:"dnsbl"`
StatusCodes StatusCodes `json:"status_codes"`
Thresholds []Threshold `json:"threshold"`
Thresholds []Threshold `json:"thresholds"`
}
func (c *fileConfig) Valid() error {
@ -346,10 +346,6 @@ func (c *fileConfig) Valid() error {
errs = append(errs, err)
}
if len(c.Thresholds) == 0 {
errs = append(errs, ErrNoThresholdRulesDefined)
}
for i, t := range c.Thresholds {
if err := t.Valid(); err != nil {
errs = append(errs, fmt.Errorf("threshold %d: %w", i, err))
@ -369,7 +365,6 @@ func Load(fin io.Reader, fname string) (*Config, error) {
Challenge: http.StatusOK,
Deny: http.StatusOK,
},
Thresholds: DefaultThresholds,
}
if err := yaml.NewYAMLToJSONDecoder(fin).Decode(&c); err != nil {
@ -407,6 +402,10 @@ func Load(fin io.Reader, fname string) (*Config, error) {
}
}
if len(c.Thresholds) == 0 {
c.Thresholds = DefaultThresholds
}
for _, t := range c.Thresholds {
if err := t.Valid(); err != nil {
validationErrs = append(validationErrs, err)

View file

@ -8,7 +8,7 @@ import (
)
var (
countryCodeRegexp = regexp.MustCompile(`^\w{2}$`)
countryCodeRegexp = regexp.MustCompile(`^[a-zA-Z]{2}$`)
ErrNotCountryCode = errors.New("config.Bot: invalid country code")
)

View file

@ -0,0 +1,36 @@
package config
import (
"errors"
"testing"
)
func TestGeoIPValid(t *testing.T) {
for _, tt := range []struct {
name string
input *GeoIP
err error
}{
{
name: "basic valid",
input: &GeoIP{
Countries: []string{"CA"},
},
},
{
name: "invalid country",
input: &GeoIP{
Countries: []string{"XOB"},
},
err: ErrNotCountryCode,
},
} {
t.Run(tt.name, func(t *testing.T) {
if err := tt.input.Valid(); !errors.Is(err, tt.err) {
t.Logf("want: %v", tt.err)
t.Logf("got: %v", err)
t.Error("got wrong validation error")
}
})
}
}

View file

@ -0,0 +1,11 @@
bots:
- name: simple-weight-adjust
action: WEIGH
user_agent_regex: Mozilla
weight:
adjust: 5
thresholds:
- name: extreme-suspicion
expression: "true"
action: WEIGH

View file

@ -0,0 +1,15 @@
bots:
- name: simple-weight-adjust
action: WEIGH
user_agent_regex: Mozilla
weight:
adjust: 5
thresholds:
- name: extreme-suspicion
expression: "true"
action: WEIGH
challenge:
algorithm: fast
difficulty: 4
report_as: 4

View file

@ -3,6 +3,8 @@ package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"testing"
)
@ -90,3 +92,20 @@ func TestDefaultThresholdsValid(t *testing.T) {
})
}
}
func TestLoadActuallyLoadsThresholds(t *testing.T) {
fin, err := os.Open(filepath.Join(".", "testdata", "good", "thresholds.yaml"))
if err != nil {
t.Fatal(err)
}
defer fin.Close()
c, err := Load(fin, fin.Name())
if err != nil {
t.Fatal(err)
}
if len(c.Thresholds) != 4 {
t.Errorf("wanted 4 thresholds, got %d thresholds", len(c.Thresholds))
}
}