feat(checker): add CEL for matching complicated expressions (#421)

* feat(lib/policy): add support for CEL checkers

This adds the ability for administrators to use Common Expression
Language[0] (CEL) for more advanced check logic than Anubis previously
offered.

These can be as simple as:

```yaml
- name: allow-api-routes
  action: ALLOW
  expression:
    and:
    - '!(method == "HEAD" || method == "GET")'
    - path.startsWith("/api/")
```

or get as complicated as:

```yaml
- name: allow-git-clients
  action: ALLOW
  expression:
    and:
    - userAgent.startsWith("git/") || userAgent.contains("libgit") || userAgent.startsWith("go-git") || userAgent.startsWith("JGit/") || userAgent.startsWith("JGit-")
    - >
      "Git-Protocol" in headers && headers["Git-Protocol"] == "version=2"
```

Internally these are compiled and evaluated with cel-go[1]. This also
leaves room for extensibility should that be desired in the future. This
will intersect with #338 and eventually intersect with TLS fingerprints
as in #337.

[0]: https://cel.dev/
[1]: https://github.com/google/cel-go

Signed-off-by: Xe Iaso <me@xeiaso.net>

* feat(data/apps): add API route allow rule for non-HEAD/GET

Signed-off-by: Xe Iaso <me@xeiaso.net>

* docs: document expression syntax

Signed-off-by: Xe Iaso <me@xeiaso.net>

* fix: fixes in review

Signed-off-by: Xe Iaso <me@xeiaso.net>

---------

Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
Xe Iaso 2025-05-03 14:26:54 -04:00 committed by GitHub
parent af07691139
commit 865d513e35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 1166 additions and 14 deletions

108
lib/policy/celchecker.go Normal file
View file

@ -0,0 +1,108 @@
package policy
import (
"fmt"
"net/http"
"github.com/TecharoHQ/anubis/internal"
"github.com/TecharoHQ/anubis/lib/policy/config"
"github.com/TecharoHQ/anubis/lib/policy/expressions"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
)
type CELChecker struct {
src string
program cel.Program
}
func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) {
env, err := expressions.NewEnvironment()
if err != nil {
return nil, err
}
var src string
var ast *cel.Ast
if cfg.Expression != "" {
src = cfg.Expression
var iss *cel.Issues
interm, iss := env.Compile(src)
if iss != nil {
return nil, iss.Err()
}
ast, iss = env.Check(interm)
if iss != nil {
return nil, iss.Err()
}
}
if len(cfg.All) != 0 {
ast, err = expressions.Join(env, expressions.JoinAnd, cfg.All...)
}
if len(cfg.Any) != 0 {
ast, err = expressions.Join(env, expressions.JoinOr, cfg.Any...)
}
if err != nil {
return nil, err
}
program, err := expressions.Compile(env, ast)
if err != nil {
return nil, fmt.Errorf("can't compile CEL program: %w", err)
}
return &CELChecker{
src: src,
program: program,
}, nil
}
func (cc *CELChecker) Hash() string {
return internal.SHA256sum(cc.src)
}
func (cc *CELChecker) Check(r *http.Request) (bool, error) {
result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r})
if err != nil {
return false, err
}
if val, ok := result.(types.Bool); ok {
return bool(val), nil
}
return false, nil
}
type CELRequest struct {
*http.Request
}
func (cr *CELRequest) Parent() cel.Activation { return nil }
func (cr *CELRequest) ResolveName(name string) (any, bool) {
switch name {
case "remoteAddress":
return cr.Header.Get("X-Real-Ip"), true
case "host":
return cr.Host, true
case "method":
return cr.Method, true
case "userAgent":
return cr.UserAgent(), true
case "path":
return cr.URL.Path, true
case "query":
return expressions.URLValues{Values: cr.URL.Query()}, true
case "headers":
return expressions.HTTPHeaders{Header: cr.Header}, true
default:
return nil, false
}
}

View file

@ -55,9 +55,11 @@ type BotConfig struct {
UserAgentRegex *string `json:"user_agent_regex"`
PathRegex *string `json:"path_regex"`
HeadersRegex map[string]string `json:"headers_regex"`
Action Rule `json:"action"`
RemoteAddr []string `json:"remote_addresses"`
Challenge *ChallengeRules `json:"challenge,omitempty"`
Expression *ExpressionOrList `json:"expression"`
Action Rule `json:"action"`
Challenge *ChallengeRules `json:"challenge,omitempty"`
}
func (b BotConfig) Zero() bool {
@ -85,7 +87,12 @@ func (b BotConfig) Valid() error {
errs = append(errs, ErrBotMustHaveName)
}
if b.UserAgentRegex == nil && b.PathRegex == nil && len(b.RemoteAddr) == 0 && len(b.HeadersRegex) == 0 {
allFieldsEmpty := b.UserAgentRegex == nil &&
b.PathRegex == nil &&
len(b.RemoteAddr) == 0 &&
len(b.HeadersRegex) == 0
if allFieldsEmpty && b.Expression == nil {
errs = append(errs, ErrBotMustHaveUserAgentOrPath)
}
@ -137,6 +144,12 @@ func (b BotConfig) Valid() error {
}
}
if b.Expression != nil {
if err := b.Expression.Valid(); err != nil {
errs = append(errs, err)
}
}
switch b.Action {
case RuleAllow, RuleBenchmark, RuleChallenge, RuleDeny:
// okay

View file

@ -0,0 +1,62 @@
package config
import (
"encoding/json"
"errors"
"slices"
)
var (
ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object")
ErrExpressionEmpty = errors.New("config: this expression is empty")
ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types")
)
type ExpressionOrList struct {
Expression string `json:"-"`
All []string `json:"all"`
Any []string `json:"any"`
}
func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool {
if eol.Expression != rhs.Expression {
return false
}
if !slices.Equal(eol.All, rhs.All) {
return false
}
if !slices.Equal(eol.Any, rhs.Any) {
return false
}
return true
}
func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error {
switch string(data[0]) {
case `"`: // string
return json.Unmarshal(data, &eol.Expression)
case "{": // object
type RawExpressionOrList ExpressionOrList
var val RawExpressionOrList
if err := json.Unmarshal(data, &val); err != nil {
return err
}
eol.All = val.All
eol.Any = val.Any
return nil
}
return ErrExpressionOrListMustBeStringOrObject
}
func (eol *ExpressionOrList) Valid() error {
if len(eol.All) != 0 && len(eol.Any) != 0 {
return ErrExpressionCantHaveBoth
}
return nil
}

View file

@ -0,0 +1,73 @@
package config
import (
"encoding/json"
"errors"
"testing"
)
func TestExpressionOrListUnmarshal(t *testing.T) {
for _, tt := range []struct {
name string
inp string
err error
validErr error
result *ExpressionOrList
}{
{
name: "simple",
inp: `"\"User-Agent\" in headers"`,
result: &ExpressionOrList{
Expression: `"User-Agent" in headers`,
},
},
{
name: "object-and",
inp: `{
"all": ["\"User-Agent\" in headers"]
}`,
result: &ExpressionOrList{
All: []string{
`"User-Agent" in headers`,
},
},
},
{
name: "object-or",
inp: `{
"any": ["\"User-Agent\" in headers"]
}`,
result: &ExpressionOrList{
Any: []string{
`"User-Agent" in headers`,
},
},
},
{
name: "both-or-and",
inp: `{
"all": ["\"User-Agent\" in headers"],
"any": ["\"User-Agent\" in headers"]
}`,
validErr: ErrExpressionCantHaveBoth,
},
} {
t.Run(tt.name, func(t *testing.T) {
var eol ExpressionOrList
if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) {
t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err)
}
if tt.result != nil && !eol.Equal(tt.result) {
t.Logf("wanted: %#v", tt.result)
t.Logf("got: %#v", &eol)
t.Fatal("parsed expression is not what was expected")
}
if err := eol.Valid(); !errors.Is(err, tt.validErr) {
t.Errorf("wanted validation error: %v but got: %v", tt.err, err)
}
})
}
}

View file

@ -0,0 +1,17 @@
{
"bots": [
{
"name": "multiple-expression-types",
"action": "ALLOW",
"expression": {
"all": [
"userAgent.startsWith(\"git/\") || userAgent.contains(\"libgit\")",
"\"Git-Protocol\" in headers && headers[\"Git-Protocol\"] == \"version=2\"\n"
],
"any": [
"userAgent.startsWith(\"evilbot/\")"
]
}
}
]
}

View file

@ -0,0 +1,10 @@
bots:
- name: multiple-expression-types
action: ALLOW
expression:
all:
- userAgent.startsWith("git/") || userAgent.contains("libgit")
- >
"Git-Protocol" in headers && headers["Git-Protocol"] == "version=2"
any:
- userAgent.startsWith("evilbot/")

View file

@ -0,0 +1,14 @@
{
"bots": [
{
"name": "allow-git-clients",
"action": "ALLOW",
"expression": {
"all": [
"userAgent.startsWith(\"git/\") || userAgent.contains(\"libgit\")",
"\"Git-Protocol\" in headers && headers[\"Git-Protocol\"] == \"version=2\""
]
}
}
]
}

View file

@ -0,0 +1,8 @@
bots:
- name: allow-git-clients
action: ALLOW
expression:
all:
- userAgent.startsWith("git/") || userAgent.contains("libgit")
- >
"Git-Protocol" in headers && headers["Git-Protocol"] == "version=2"

View file

@ -0,0 +1,3 @@
# Expressions support
The expressions support is based on ideas from [go-away](https://git.gammaspectra.live/git/go-away) but with different opinions about how things should be done.

View file

@ -0,0 +1,45 @@
package expressions
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/ext"
)
// NewEnvironment creates a new CEL environment, this is the set of
// variables and functions that are passed into the CEL scope so that
// Anubis can fail loudly and early when something is invalid instead
// of blowing up at runtime.
func NewEnvironment() (*cel.Env, error) {
return cel.NewEnv(
ext.Strings(
ext.StringsLocale("en_US"),
ext.StringsValidateFormatCalls(true),
),
// default all timestamps to UTC
cel.DefaultUTCTimeZone(true),
// Variables exposed to CEL programs:
cel.Variable("remoteAddress", cel.StringType),
cel.Variable("host", cel.StringType),
cel.Variable("method", cel.StringType),
cel.Variable("userAgent", cel.StringType),
cel.Variable("path", cel.StringType),
cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)),
cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)),
// Functions exposed to CEL programs:
)
}
// Compile takes CEL environment and syntax tree then emits an optimized
// Program for execution.
func Compile(env *cel.Env, ast *cel.Ast) (cel.Program, error) {
return env.Program(
ast,
cel.EvalOptions(
// optimize regular expressions right now instead of on the fly
cel.OptOptimize,
),
)
}

View file

@ -0,0 +1,75 @@
package expressions
import (
"net/http"
"reflect"
"strings"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// HTTPHeaders is a type wrapper to expose HTTP headers into CEL programs.
type HTTPHeaders struct {
http.Header
}
func (h HTTPHeaders) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, ErrNotImplemented
}
func (h HTTPHeaders) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.MapType:
return h
case types.TypeType:
return types.MapType
}
return types.NewErr("can't convert from %q to %q", types.MapType, typeVal)
}
func (h HTTPHeaders) Equal(other ref.Val) ref.Val {
return types.Bool(false) // We don't want to compare header maps
}
func (h HTTPHeaders) Type() ref.Type {
return types.MapType
}
func (h HTTPHeaders) Value() any { return h }
func (h HTTPHeaders) Find(key ref.Val) (ref.Val, bool) {
k, ok := key.(types.String)
if !ok {
return nil, false
}
if _, ok := h.Header[string(k)]; !ok {
return nil, false
}
return types.String(strings.Join(h.Header.Values(string(k)), ",")), true
}
func (h HTTPHeaders) Contains(key ref.Val) ref.Val {
_, ok := h.Find(key)
return types.Bool(ok)
}
func (h HTTPHeaders) Get(key ref.Val) ref.Val {
result, ok := h.Find(key)
if !ok {
return types.ValOrErr(result, "no such key: %v", key)
}
return result
}
func (h HTTPHeaders) Iterator() traits.Iterator { panic("TODO(Xe): implement me") }
func (h HTTPHeaders) IsZeroValue() bool {
return len(h.Header) == 0
}
func (h HTTPHeaders) Size() ref.Val { return types.Int(len(h.Header)) }

View file

@ -0,0 +1,52 @@
package expressions
import (
"net/http"
"testing"
"github.com/google/cel-go/common/types"
)
func TestHTTPHeaders(t *testing.T) {
headers := HTTPHeaders{
Header: http.Header{
"Content-Type": {"application/json"},
"Cf-Worker": {"true"},
"User-Agent": {"Go-http-client/2"},
},
}
t.Run("contains-existing-header", func(t *testing.T) {
resp := headers.Contains(types.String("User-Agent"))
if !bool(resp.(types.Bool)) {
t.Fatal("headers does not contain User-Agent")
}
})
t.Run("not-contains-missing-header", func(t *testing.T) {
resp := headers.Contains(types.String("Xxx-Random-Header"))
if bool(resp.(types.Bool)) {
t.Fatal("headers does not contain User-Agent")
}
})
t.Run("get-existing-header", func(t *testing.T) {
val := headers.Get(types.String("User-Agent"))
switch val.(type) {
case types.String:
// ok
default:
t.Fatalf("result was wrong type %T", val)
}
})
t.Run("not-get-missing-header", func(t *testing.T) {
val := headers.Get(types.String("Xxx-Random-Header"))
switch val.(type) {
case *types.Err:
// ok
default:
t.Fatalf("result was wrong type %T", val)
}
})
}

View file

@ -0,0 +1,104 @@
package expressions
import (
"errors"
"fmt"
"strings"
"github.com/google/cel-go/cel"
)
// JoinOperator is a type wrapper for and/or operators.
//
// This is a separate type so that validation can be done at the type level.
type JoinOperator string
// Possible values for JoinOperator
const (
JoinAnd JoinOperator = "&&"
JoinOr JoinOperator = "||"
)
// Valid ensures that JoinOperator is semantically valid.
func (jo JoinOperator) Valid() error {
switch jo {
case JoinAnd, JoinOr:
return nil
default:
return ErrWrongJoinOperator
}
}
var (
ErrWrongJoinOperator = errors.New("expressions: invalid join operator")
ErrNoExpressions = errors.New("expressions: cannot join zero expressions")
ErrCantCompile = errors.New("expressions: can't compile one expression")
)
// JoinClauses joins a list of compiled clauses into one big if statement.
//
// Imagine the following two clauses:
//
// ball.color == "red"
// ball.shape == "round"
//
// JoinClauses would emit one "joined" clause such as:
//
// ( ball.color == "red" ) && ( ball.shape == "round" )
func JoinClauses(env *cel.Env, operator JoinOperator, clauses ...*cel.Ast) (*cel.Ast, error) {
if err := operator.Valid(); err != nil {
return nil, fmt.Errorf("%w: wanted && or ||, got: %q", err, operator)
}
switch len(clauses) {
case 0:
return nil, ErrNoExpressions
case 1:
return clauses[0], nil
}
var exprs []string
var errs []error
for _, clause := range clauses {
clauseStr, err := cel.AstToString(clause)
if err != nil {
errs = append(errs, err)
continue
}
exprs = append(exprs, "( "+clauseStr+" )")
}
if len(errs) != 0 {
return nil, fmt.Errorf("errors while decompiling statements: %w", errors.Join(errs...))
}
statement := strings.Join(exprs, " "+string(operator)+" ")
result, iss := env.Compile(statement)
if iss != nil {
return nil, iss.Err()
}
return result, nil
}
func Join(env *cel.Env, operator JoinOperator, clauses ...string) (*cel.Ast, error) {
var statements []*cel.Ast
var errs []error
for _, clause := range clauses {
stmt, iss := env.Compile(clause)
if iss != nil && iss.Err() != nil {
errs = append(errs, fmt.Errorf("%w: %q gave: %w", ErrCantCompile, clause, iss.Err()))
continue
}
statements = append(statements, stmt)
}
if len(errs) != 0 {
return nil, fmt.Errorf("errors while joining clauses: %w", errors.Join(errs...))
}
return JoinClauses(env, operator, statements...)
}

View file

@ -0,0 +1,90 @@
package expressions
import (
"errors"
"testing"
"github.com/google/cel-go/cel"
)
func TestJoin(t *testing.T) {
env, err := NewEnvironment()
if err != nil {
t.Fatal(err)
}
for _, tt := range []struct {
name string
clauses []string
op JoinOperator
err error
resultStr string
}{
{
name: "no-clauses",
clauses: []string{},
op: JoinAnd,
err: ErrNoExpressions,
},
{
name: "one-clause-identity",
clauses: []string{`remoteAddress == "8.8.8.8"`},
op: JoinAnd,
err: nil,
resultStr: `remoteAddress == "8.8.8.8"`,
},
{
name: "multi-clause-and",
clauses: []string{
`remoteAddress == "8.8.8.8"`,
`host == "anubis.techaro.lol"`,
},
op: JoinAnd,
err: nil,
resultStr: `remoteAddress == "8.8.8.8" && host == "anubis.techaro.lol"`,
},
{
name: "multi-clause-or",
clauses: []string{
`remoteAddress == "8.8.8.8"`,
`host == "anubis.techaro.lol"`,
},
op: JoinOr,
err: nil,
resultStr: `remoteAddress == "8.8.8.8" || host == "anubis.techaro.lol"`,
},
{
name: "git-user-agent",
clauses: []string{
`userAgent.startsWith("git/") || userAgent.contains("libgit")`,
`"Git-Protocol" in headers && headers["Git-Protocol"] == "version=2"`,
},
op: JoinAnd,
err: nil,
resultStr: `(userAgent.startsWith("git/") || userAgent.contains("libgit")) && "Git-Protocol" in headers &&
headers["Git-Protocol"] == "version=2"`,
},
} {
t.Run(tt.name, func(t *testing.T) {
result, err := Join(env, tt.op, tt.clauses...)
if !errors.Is(err, tt.err) {
t.Errorf("wanted error %v but got: %v", tt.err, err)
}
if tt.err != nil {
return
}
program, err := cel.AstToString(result)
if err != nil {
t.Fatalf("can't decompile program: %v", err)
}
if tt.resultStr != program {
t.Logf("wanted: %s", tt.resultStr)
t.Logf("got: %s", program)
t.Error("program did not compile as expected")
}
})
}
}

View file

@ -0,0 +1,78 @@
package expressions
import (
"errors"
"net/url"
"reflect"
"strings"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
var ErrNotImplemented = errors.New("expressions: not implemented")
// URLValues is a type wrapper to expose url.Values into CEL programs.
type URLValues struct {
url.Values
}
func (u URLValues) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, ErrNotImplemented
}
func (u URLValues) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.MapType:
return u
case types.TypeType:
return types.MapType
}
return types.NewErr("can't convert from %q to %q", types.MapType, typeVal)
}
func (u URLValues) Equal(other ref.Val) ref.Val {
return types.Bool(false) // We don't want to compare header maps
}
func (u URLValues) Type() ref.Type {
return types.MapType
}
func (u URLValues) Value() any { return u }
func (u URLValues) Find(key ref.Val) (ref.Val, bool) {
k, ok := key.(types.String)
if !ok {
return nil, false
}
if _, ok := u.Values[string(k)]; !ok {
return nil, false
}
return types.String(strings.Join(u.Values[string(k)], ",")), true
}
func (u URLValues) Contains(key ref.Val) ref.Val {
_, ok := u.Find(key)
return types.Bool(ok)
}
func (u URLValues) Get(key ref.Val) ref.Val {
result, ok := u.Find(key)
if !ok {
return types.ValOrErr(result, "no such key: %v", key)
}
return result
}
func (u URLValues) Iterator() traits.Iterator { panic("TODO(Xe): implement me") }
func (u URLValues) IsZeroValue() bool {
return len(u.Values) == 0
}
func (u URLValues) Size() ref.Val { return types.Int(len(u.Values)) }

View file

@ -0,0 +1,50 @@
package expressions
import (
"net/url"
"testing"
"github.com/google/cel-go/common/types"
)
func TestURLValues(t *testing.T) {
headers := URLValues{
Values: url.Values{
"format": {"json"},
},
}
t.Run("contains-existing-key", func(t *testing.T) {
resp := headers.Contains(types.String("format"))
if !bool(resp.(types.Bool)) {
t.Fatal("headers does not contain User-Agent")
}
})
t.Run("not-contains-missing-key", func(t *testing.T) {
resp := headers.Contains(types.String("not-there"))
if bool(resp.(types.Bool)) {
t.Fatal("headers does not contain User-Agent")
}
})
t.Run("get-existing-key", func(t *testing.T) {
val := headers.Get(types.String("format"))
switch val.(type) {
case types.String:
// ok
default:
t.Fatalf("result was wrong type %T", val)
}
})
t.Run("not-get-missing-key", func(t *testing.T) {
val := headers.Get(types.String("not-there"))
switch val.(type) {
case *types.Err:
// ok
default:
t.Fatalf("result was wrong type %T", val)
}
})
}

View file

@ -94,6 +94,15 @@ func ParseConfig(fin io.Reader, fname string, defaultDifficulty int) (*ParsedCon
}
}
if b.Expression != nil {
c, err := NewCELChecker(b.Expression)
if err != nil {
validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err))
} else {
cl = append(cl, c)
}
}
if b.Challenge == nil {
parsedBot.Challenge = &config.ChallengeRules{
Difficulty: defaultDifficulty,