Files
kubesphere/vendor/github.com/open-policy-agent/opa/rego/rego.go
hongming cfebd96a1f update dependencies (#6267)
Signed-off-by: hongming <coder.scala@gmail.com>
2024-11-06 10:27:06 +08:00

2890 lines
79 KiB
Go

// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// Package rego exposes high level APIs for evaluating Rego policies.
package rego
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/bundle"
bundleUtils "github.com/open-policy-agent/opa/internal/bundle"
"github.com/open-policy-agent/opa/internal/compiler/wasm"
"github.com/open-policy-agent/opa/internal/future"
"github.com/open-policy-agent/opa/internal/planner"
"github.com/open-policy-agent/opa/internal/rego/opa"
"github.com/open-policy-agent/opa/internal/wasm/encoding"
"github.com/open-policy-agent/opa/ir"
"github.com/open-policy-agent/opa/loader"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/plugins"
"github.com/open-policy-agent/opa/resolver"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/tracing"
"github.com/open-policy-agent/opa/types"
"github.com/open-policy-agent/opa/util"
)
// Package-internal string constants.
const (
// defaultPartialNamespace is the string "partial". Presumably the namespace
// applied to partial evaluation results when none is configured; its use is
// outside this view — TODO confirm.
defaultPartialNamespace = "partial"
// wasmVarPrefix is the string "^". Presumably a prefix for Wasm-related
// variable names; its use is outside this view — TODO confirm.
wasmVarPrefix = "^"
)
// Names of evaluation targets, compared against the Rego object's target
// field (for example when deciding whether raw input must be parsed).
// nolint: deadcode,varcheck
const (
// targetWasm selects the "wasm" target.
targetWasm = "wasm"
// targetRego selects the "rego" target.
targetRego = "rego"
)
// CompileResult represents the result of compiling a Rego query, zero or more
// Rego modules, and arbitrary contextual data into an executable.
type CompileResult struct {
// Bytes holds the compiled executable; serialized under the JSON key "bytes".
Bytes []byte `json:"bytes"`
}
// PartialQueries contains the queries and support modules produced by partial
// evaluation.
type PartialQueries struct {
// Queries are the residual queries; omitted from JSON when empty.
Queries []ast.Body `json:"queries,omitempty"`
// Support are the support modules; serialized under the JSON key "modules"
// and omitted when empty.
Support []*ast.Module `json:"modules,omitempty"`
}
// PartialResult represents the result of partial evaluation. The result can be
// used to generate a new query that can be run when inputs are known.
type PartialResult struct {
// compiler, store and body are handed to the Rego object created by the
// Rego method via the Compiler, Store and ParsedQuery options.
compiler *ast.Compiler
store storage.Store
body ast.Body
// builtinDecls and builtinFuncs are custom built-ins that the Rego method
// copies onto the Rego object it creates.
builtinDecls map[string]*ast.Builtin
builtinFuncs map[string]*topdown.Builtin
}
// Rego returns an object that can be evaluated to produce a query result.
//
// The caller's options are applied first, followed by the Compiler, Store and
// ParsedQuery options derived from the partial result, so the partial result's
// values win. Custom built-in declarations and implementations recorded on the
// partial result are then copied onto the new Rego object.
func (pr PartialResult) Rego(options ...func(*Rego)) *Rego {
	// Build a fresh slice rather than appending to the variadic argument:
	// appending in place could overwrite elements of the caller's backing
	// array when the caller passed a slice with spare capacity.
	opts := make([]func(*Rego), 0, len(options)+3)
	opts = append(opts, options...)
	opts = append(opts, Compiler(pr.compiler), Store(pr.store), ParsedQuery(pr.body))

	r := New(opts...)

	// Propagate any custom builtins.
	for name, decl := range pr.builtinDecls {
		r.builtinDecls[name] = decl
	}
	for name, impl := range pr.builtinFuncs {
		r.builtinFuncs[name] = impl
	}

	return r
}
// preparedQuery is a wrapper around a Rego object which has pre-processed
// state stored on it. Once prepared there are a more limited number of actions
// that can be taken with it. It will, however, be able to evaluate faster since
// it will not have to re-parse or compile as much.
type preparedQuery struct {
// r is the prepared Rego object that supplies defaults for each evaluation.
r *Rego
// cfg is the configuration the query was prepared with (type declared
// outside this view).
cfg *PrepareConfig
}
// EvalContext defines the set of options allowed to be set at evaluation
// time. Any other options will need to be set on a new Rego object.
type EvalContext struct {
// hasInput is set by EvalInput and EvalParsedInput; when false the input
// stored on the Rego object is used instead.
hasInput bool
time time.Time
seed io.Reader
// rawInput and parsedInput hold the input document in native and AST form.
rawInput *interface{}
parsedInput ast.Value
metrics metrics.Metrics
txn storage.Transaction
// instrument requests instrumentation; instrumentation is created from the
// metrics when instrument is true.
instrument bool
instrumentation *topdown.Instrumentation
partialNamespace string
queryTracers []topdown.QueryTracer
compiledQuery compiledQuery
unknowns []string
disableInlining []ast.Ref
parsedUnknowns []*ast.Term
// indexing and earlyExit both default to true when the context is created.
indexing bool
earlyExit bool
interQueryBuiltinCache cache.InterQueryCache
interQueryBuiltinValueCache cache.InterQueryValueCache
ndBuiltinCache builtins.NDBCache
resolvers []refResolver
sortSets bool
copyMaps bool
printHook print.Hook
capabilities *ast.Capabilities
strictBuiltinErrors bool
virtualCache topdown.VirtualCache
}
// RawInput returns the pointer to the raw (native Go) input value held by the
// eval context. It may be nil.
func (e *EvalContext) RawInput() *interface{} {
return e.rawInput
}
// ParsedInput returns the input document as an AST value. It may be nil.
func (e *EvalContext) ParsedInput() ast.Value {
return e.parsedInput
}
// Time returns the wall clock time configured for this evaluation (see
// EvalTime). It is the zero time.Time when none was set.
func (e *EvalContext) Time() time.Time {
return e.time
}
// Seed returns the reader configured to seed randomization for this
// evaluation (see EvalSeed). It may be nil.
func (e *EvalContext) Seed() io.Reader {
return e.seed
}
// InterQueryBuiltinCache returns the inter-query cache configured for this
// evaluation (see EvalInterQueryBuiltinCache).
func (e *EvalContext) InterQueryBuiltinCache() cache.InterQueryCache {
return e.interQueryBuiltinCache
}
// InterQueryBuiltinValueCache returns the inter-query value cache configured
// for this evaluation (see EvalInterQueryBuiltinValueCache).
func (e *EvalContext) InterQueryBuiltinValueCache() cache.InterQueryValueCache {
return e.interQueryBuiltinValueCache
}
// PrintHook returns the hook that handles print statement output for this
// evaluation.
func (e *EvalContext) PrintHook() print.Hook {
return e.printHook
}
// Metrics returns the metrics collection used for this evaluation.
func (e *EvalContext) Metrics() metrics.Metrics {
return e.metrics
}
// StrictBuiltinErrors returns the strict-builtin-errors flag carried by the
// eval context.
func (e *EvalContext) StrictBuiltinErrors() bool {
return e.strictBuiltinErrors
}
// NDBCache returns the non-deterministic builtin cache configured for this
// evaluation (see EvalNDBuiltinCache).
func (e *EvalContext) NDBCache() builtins.NDBCache {
return e.ndBuiltinCache
}
// CompiledQuery returns the body of the compiled query attached to the eval
// context.
func (e *EvalContext) CompiledQuery() ast.Body {
return e.compiledQuery.query
}
// Capabilities returns the capabilities carried by the eval context. It may
// be nil.
func (e *EvalContext) Capabilities() *ast.Capabilities {
return e.capabilities
}
// Transaction returns the storage transaction used for this evaluation.
func (e *EvalContext) Transaction() storage.Transaction {
return e.txn
}
// EvalOption defines a function that sets an option on an EvalContext.
type EvalOption func(*EvalContext)
// EvalInput supplies the input document, as a native Go value, for one
// evaluation of a prepared query. It also marks the context as having an
// explicit input so the Rego object's input is not used as a fallback.
func EvalInput(input interface{}) EvalOption {
	return func(ectx *EvalContext) {
		ectx.hasInput = true
		ectx.rawInput = &input
	}
}
// EvalParsedInput supplies the input document, already converted to an AST
// value, for one evaluation of a prepared query. It also marks the context as
// having an explicit input.
func EvalParsedInput(input ast.Value) EvalOption {
	return func(ectx *EvalContext) {
		ectx.hasInput = true
		ectx.parsedInput = input
	}
}
// EvalMetrics sets the metrics collection used while evaluating a prepared
// query.
func EvalMetrics(metric metrics.Metrics) EvalOption {
	return func(ectx *EvalContext) {
		ectx.metrics = metric
	}
}
// EvalTransaction sets the storage transaction used while evaluating a
// prepared query.
func EvalTransaction(txn storage.Transaction) EvalOption {
	return func(ectx *EvalContext) {
		ectx.txn = txn
	}
}
// EvalInstrument turns instrumentation on or off for one evaluation of a
// prepared query.
func EvalInstrument(instrument bool) EvalOption {
	return func(ectx *EvalContext) {
		ectx.instrument = instrument
	}
}
// EvalTracer adds a legacy tracer to one evaluation of a prepared query. The
// tracer is wrapped as a query tracer; a nil tracer is ignored.
// Deprecated: Use EvalQueryTracer instead.
func EvalTracer(tracer topdown.Tracer) EvalOption {
	return func(ectx *EvalContext) {
		if tracer == nil {
			return
		}
		ectx.queryTracers = append(ectx.queryTracers, topdown.WrapLegacyTracer(tracer))
	}
}
// EvalQueryTracer adds a query tracer to one evaluation of a prepared query.
// A nil tracer is ignored.
func EvalQueryTracer(tracer topdown.QueryTracer) EvalOption {
	return func(ectx *EvalContext) {
		if tracer == nil {
			return
		}
		ectx.queryTracers = append(ectx.queryTracers, tracer)
	}
}
// EvalPartialNamespace sets the namespace used for partial evaluation
// results. The namespace must be a valid package path component.
func EvalPartialNamespace(ns string) EvalOption {
	return func(ectx *EvalContext) {
		ectx.partialNamespace = ns
	}
}
// EvalUnknowns sets the values, given as strings, that partial evaluation
// treats as unknown.
func EvalUnknowns(unknowns []string) EvalOption {
	return func(ectx *EvalContext) {
		ectx.unknowns = unknowns
	}
}
// EvalDisableInlining sets the paths to exclude from partial evaluation
// inlining on the eval context.
func EvalDisableInlining(paths []ast.Ref) EvalOption {
	return func(ectx *EvalContext) {
		ectx.disableInlining = paths
	}
}
// EvalParsedUnknowns sets the values, given as parsed terms, that partial
// evaluation treats as unknown.
func EvalParsedUnknowns(unknowns []*ast.Term) EvalOption {
	return func(ectx *EvalContext) {
		ectx.parsedUnknowns = unknowns
	}
}
// EvalRuleIndexing enables or disables indexing optimizations for the
// evaluation; indexing is on unless this option turns it off. Disabling it
// should only be done when tracing in debug mode.
func EvalRuleIndexing(enabled bool) EvalOption {
	return func(ectx *EvalContext) {
		ectx.indexing = enabled
	}
}
// EvalEarlyExit enables or disables 'early exit' optimizations for the
// evaluation; early exit is on unless this option turns it off. Disabling it
// should only be done when tracing in debug mode.
func EvalEarlyExit(enabled bool) EvalOption {
	return func(ectx *EvalContext) {
		ectx.earlyExit = enabled
	}
}
// EvalTime sets the wall clock time used during policy evaluation.
// time.now_ns() calls will return this value.
func EvalTime(x time.Time) EvalOption {
	return func(ectx *EvalContext) {
		ectx.time = x
	}
}
// EvalSeed sets the reader that seeds randomization required by built-in
// functions. If a seed is not provided crypto/rand.Reader is used.
func EvalSeed(r io.Reader) EvalOption {
	return func(ectx *EvalContext) {
		ectx.seed = r
	}
}
// EvalInterQueryBuiltinCache sets the inter-query cache that built-in
// functions can use during evaluation.
func EvalInterQueryBuiltinCache(c cache.InterQueryCache) EvalOption {
	return func(ectx *EvalContext) {
		ectx.interQueryBuiltinCache = c
	}
}
// EvalInterQueryBuiltinValueCache sets the inter-query value cache that
// built-in functions can use during evaluation.
func EvalInterQueryBuiltinValueCache(c cache.InterQueryValueCache) EvalOption {
	return func(ectx *EvalContext) {
		ectx.interQueryBuiltinValueCache = c
	}
}
// EvalNDBuiltinCache sets the non-deterministic builtin cache that built-in
// functions can use during evaluation.
func EvalNDBuiltinCache(c builtins.NDBCache) EvalOption {
	return func(ectx *EvalContext) {
		ectx.ndBuiltinCache = c
	}
}
// EvalResolver registers a Resolver for the given ref path for this
// evaluation. The resolver is appended to those already on the context.
func EvalResolver(ref ast.Ref, r resolver.Resolver) EvalOption {
	return func(ectx *EvalContext) {
		entry := refResolver{ref, r}
		ectx.resolvers = append(ectx.resolvers, entry)
	}
}
// EvalSortSets controls whether the evaluator sorts sets before returning
// them as JSON arrays.
func EvalSortSets(yes bool) EvalOption {
	return func(ectx *EvalContext) {
		ectx.sortSets = yes
	}
}
// EvalCopyMaps controls whether the evaluator copies
// `map[string]interface{}`s before returning them.
func EvalCopyMaps(yes bool) EvalOption {
	return func(ectx *EvalContext) {
		ectx.copyMaps = yes
	}
}
// EvalPrintHook sets the object to use for handling print statement outputs.
func EvalPrintHook(ph print.Hook) EvalOption {
return func(e *EvalContext) {
e.printHook = ph
}
}
// EvalVirtualCache sets the topdown.VirtualCache used for evaluation. It is
// optional; when not set, the default cache is used.
func EvalVirtualCache(vc topdown.VirtualCache) EvalOption {
	return func(ectx *EvalContext) {
		ectx.virtualCache = vc
	}
}
// Modules returns a new map containing every parsed module known to the
// prepared query: the Rego object's parsed modules keyed by their name, plus
// the modules of each bundle keyed by module path. Bundle modules are written
// last, so they replace a parsed module stored under the same key.
func (pq preparedQuery) Modules() map[string]*ast.Module {
	result := make(map[string]*ast.Module)

	for filename, parsed := range pq.r.parsedModules {
		result[filename] = parsed
	}

	for _, bndl := range pq.r.bundles {
		for _, file := range bndl.Modules {
			result[file.Path] = file.Parsed
		}
	}

	return result
}
// newEvalContext creates a new EvalContext by overlaying the given
// EvalOptions on top of the defaults stored on the prepared Rego object.
//
// The returned function should be called once the evaluation is complete; it
// aborts the transaction opened here, and is a no-op when the caller supplied
// a transaction via an EvalOption. When an error is returned, any transaction
// opened by this call has already been aborted and the returned function is a
// no-op, so callers that return early on error do not leak a transaction.
func (pq preparedQuery) newEvalContext(ctx context.Context, options []EvalOption) (*EvalContext, func(context.Context), error) {
	// Clip the capacity of the shared resolver slice so that an EvalResolver
	// option appends into a fresh backing array instead of writing into
	// storage shared with the Rego object (and other evaluations).
	baseResolvers := pq.r.resolvers[:len(pq.r.resolvers):len(pq.r.resolvers)]

	ectx := &EvalContext{
		partialNamespace:    pq.r.partialNamespace,
		unknowns:            pq.r.unknowns,
		parsedUnknowns:      pq.r.parsedUnknowns,
		indexing:            true,
		earlyExit:           true,
		resolvers:           baseResolvers,
		printHook:           pq.r.printHook,
		capabilities:        pq.r.capabilities,
		strictBuiltinErrors: pq.r.strictBuiltinErrors,
	}

	for _, o := range options {
		o(ectx)
	}

	if ectx.metrics == nil {
		ectx.metrics = metrics.New()
	}

	if ectx.instrument {
		ectx.instrumentation = topdown.NewInstrumentation(ectx.metrics)
	}

	// Default to an empty "finish" function.
	noop := func(context.Context) {}

	// Always parse the Rego-level paths so invalid ones are still reported,
	// but only use them as the default: a value supplied through
	// EvalDisableInlining must not be silently overwritten.
	defaultDisableInlining, err := parseStringsToRefs(pq.r.disableInlining)
	if err != nil {
		return nil, noop, err
	}
	if ectx.disableInlining == nil {
		ectx.disableInlining = defaultDisableInlining
	}

	finishFunc := noop
	if ectx.txn == nil {
		txn, err := pq.r.store.NewTransaction(ctx)
		if err != nil {
			return nil, noop, err
		}
		ectx.txn = txn
		finishFunc = func(ctx context.Context) {
			pq.r.store.Abort(ctx, txn)
		}
	}

	// If we didn't get an input specified in the Eval options
	// then fall back to the Rego object's input fields.
	if !ectx.hasInput {
		ectx.rawInput = pq.r.rawInput
		ectx.parsedInput = pq.r.parsedInput
	}

	if ectx.parsedInput == nil {
		if ectx.rawInput == nil {
			// Fall back to the original Rego object's input if none was
			// specified. Note that it could still be nil.
			ectx.rawInput = pq.r.rawInput
		}

		// Only parse the raw input when no plugin claims this target and the
		// target is not Wasm.
		if pq.r.targetPlugin(pq.r.target) == nil && pq.r.target != targetWasm {
			parsed, err := pq.r.parseRawInput(ectx.rawInput, ectx.metrics)
			if err != nil {
				// Release the transaction opened above: callers return on
				// error without invoking the finish function.
				finishFunc(ctx)
				return nil, noop, err
			}
			ectx.parsedInput = parsed
		}
	}

	return ectx, finishFunc, nil
}
// PreparedEvalQuery holds the prepared Rego state that has been pre-processed
// for subsequent evaluations. It embeds preparedQuery and adds the Eval method.
type PreparedEvalQuery struct {
preparedQuery
}
// Eval evaluates the prepared query with additional eval options and returns
// a ResultSet.
// If options are provided they will override the original Rego options respective value.
// The original Rego object transaction will *not* be re-used. A new transaction will be opened
// if one is not provided with an EvalOption.
func (pq PreparedEvalQuery) Eval(ctx context.Context, options ...EvalOption) (ResultSet, error) {
	evalCtx, cleanup, err := pq.newEvalContext(ctx, options)
	if err != nil {
		return nil, err
	}
	// Release whatever newEvalContext acquired once evaluation is done.
	defer cleanup(ctx)

	evalCtx.compiledQuery = pq.r.compiledQueries[evalQueryType]

	return pq.r.eval(ctx, evalCtx)
}
// PreparedPartialQuery holds the prepared Rego state that has been pre-processed
// for partial evaluations. It embeds preparedQuery and adds the Partial method.
type PreparedPartialQuery struct {
preparedQuery
}
// Partial runs partial evaluation on the prepared query and returns the
// resulting queries and support modules.
// The original Rego object transaction will *not* be re-used. A new transaction will be opened
// if one is not provided with an EvalOption.
func (pq PreparedPartialQuery) Partial(ctx context.Context, options ...EvalOption) (*PartialQueries, error) {
	evalCtx, cleanup, err := pq.newEvalContext(ctx, options)
	if err != nil {
		return nil, err
	}
	// Release whatever newEvalContext acquired once evaluation is done.
	defer cleanup(ctx)

	evalCtx.compiledQuery = pq.r.compiledQueries[partialQueryType]

	return pq.r.partial(ctx, evalCtx)
}
// Errors represents a collection of errors returned when evaluating Rego.
type Errors []error

// Error renders the collection as a single message: "no error" when empty, a
// one-line summary for a single error, and otherwise a count header followed
// by one error message per line.
func (errs Errors) Error() string {
	switch len(errs) {
	case 0:
		return "no error"
	case 1:
		return fmt.Sprintf("1 error occurred: %v", errs[0].Error())
	}

	lines := make([]string, 0, len(errs)+1)
	lines = append(lines, fmt.Sprintf("%v errors occurred", len(errs)))
	for i := range errs {
		lines = append(lines, errs[i].Error())
	}

	return strings.Join(lines, "\n")
}
// errPartialEvaluationNotEffective is the sentinel error recognized by
// IsPartialEvaluationNotEffectiveErr.
var errPartialEvaluationNotEffective = errors.New("partial evaluation not effective")

// IsPartialEvaluationNotEffectiveErr returns true if err is an error returned by
// this package to indicate that partial evaluation was ineffective.
//
// It reports true when err is, or wraps, an Errors collection holding exactly
// one error that is, or wraps, the package's sentinel. A nil err, an error of
// another type, and an Errors value with zero or several entries all yield
// false.
func IsPartialEvaluationNotEffectiveErr(err error) bool {
	// errors.As / errors.Is also match errors that callers wrapped (e.g. with
	// fmt.Errorf and %w), which a plain type assertion and == would miss.
	var errs Errors
	if !errors.As(err, &errs) {
		return false
	}
	return len(errs) == 1 && errors.Is(errs[0], errPartialEvaluationNotEffective)
}
// compiledQuery pairs a compiled query body with the query compiler stored
// alongside it.
type compiledQuery struct {
// query is the compiled query body (returned by EvalContext.CompiledQuery).
query ast.Body
// compiler is the query compiler associated with query.
compiler ast.QueryCompiler
}
// queryType identifies which top-level Rego API a query was compiled for; it
// is the key type of Rego.compiledQueries.
type queryType int
// Define a query type for each of the top level Rego
// API's that compile queries differently.
const (
// evalQueryType is used by PreparedEvalQuery.Eval.
evalQueryType queryType = iota
partialResultQueryType
// partialQueryType is used by PreparedPartialQuery.Partial.
partialQueryType
compileQueryType
)
// loadPaths records the arguments given to the Load option: the filesystem
// paths to load and the filter to apply while loading.
type loadPaths struct {
paths []string
filter loader.Filter
}
// Rego constructs a query and can be evaluated to obtain results.
//
// Most fields are populated by the option functions in this package, which
// usually share the field's name (e.g. Query sets query, Store sets store).
type Rego struct {
// Query, in raw and parsed form, plus the per-API compiled variants.
query string
parsedQuery ast.Body
compiledQueries map[queryType]compiledQuery
// Package and imports forming the query's context, raw and parsed.
pkg string
parsedPackage *ast.Package
imports []string
parsedImports []*ast.Import
// Input document, raw and parsed.
rawInput *interface{}
parsedInput ast.Value
// Partial evaluation settings.
unknowns []string
parsedUnknowns []*ast.Term
disableInlining []string
shallowInlining bool
skipPartialNamespace bool
partialNamespace string
// Modules, raw and parsed (parsedModules is keyed by filename).
modules []rawModule
parsedModules map[string]*ast.Module
compiler *ast.Compiler
// Storage layer and transaction.
store storage.Store
ownStore bool
ownStoreReadAst bool
txn storage.Transaction
// Metrics, tracing and instrumentation.
metrics metrics.Metrics
queryTracers []topdown.QueryTracer
tracebuf *topdown.BufferTracer
trace bool
instrumentation *topdown.Instrumentation
instrument bool
capture map[*ast.Expr]ast.Var // map exprs to generated capture vars
termVarID int
dump io.Writer
runtime *ast.Term
time time.Time
seed io.Reader
capabilities *ast.Capabilities
// Custom built-in declarations and implementations, keyed by name.
builtinDecls map[string]*ast.Builtin
builtinFuncs map[string]*topdown.Builtin
unsafeBuiltins map[string]struct{}
// Filesystem paths and bundles to load (bundles is keyed by bundle name).
loadPaths loadPaths
bundlePaths []string
bundles map[string]*bundle.Bundle
skipBundleVerification bool
// Caches available to built-in functions.
interQueryBuiltinCache cache.InterQueryCache
interQueryBuiltinValueCache cache.InterQueryValueCache
ndBuiltinCache builtins.NDBCache
strictBuiltinErrors bool
builtinErrorList *[]topdown.Error
resolvers []refResolver
schemaSet *ast.SchemaSet
target string // target type (wasm, rego, etc.)
opa opa.EvalEngine
generateJSON func(*ast.Term, *EvalContext) (interface{}, error)
printHook print.Hook
enablePrintStatements bool
// NOTE(review): "Tacing" looks like a misspelling of "Tracing"; the field
// name is left unchanged because it may be referenced outside this view.
distributedTacingOpts tracing.Options
strict bool
pluginMgr *plugins.Manager
plugins []TargetPlugin
targetPrepState TargetPluginEval
regoVersion ast.RegoVersion
}
// Function represents a built-in function that is callable in Rego.
type Function struct {
// Name is the name the built-in is registered under.
Name string
// Description is copied into the built-in's declaration on registration.
Description string
// Decl is the function's type declaration.
Decl *types.Function
// Memoize enables caching of call results keyed by the function name and
// its arguments (see memoize).
Memoize bool
// Nondeterministic is copied into the built-in's declaration on registration.
Nondeterministic bool
}
// BuiltinContext contains additional attributes from the evaluator that
// built-in functions can use, e.g., the request context.Context, caches, etc.
// It is an alias for topdown.BuiltinContext.
type BuiltinContext = topdown.BuiltinContext
// Signatures for custom built-in function implementations, by arity. Each
// receives the BuiltinContext and its operands and returns the result term or
// an error.
type (
// Builtin1 defines a built-in function that accepts 1 argument.
Builtin1 func(bctx BuiltinContext, op1 *ast.Term) (*ast.Term, error)
// Builtin2 defines a built-in function that accepts 2 arguments.
Builtin2 func(bctx BuiltinContext, op1, op2 *ast.Term) (*ast.Term, error)
// Builtin3 defines a built-in function that accepts 3 arguments.
Builtin3 func(bctx BuiltinContext, op1, op2, op3 *ast.Term) (*ast.Term, error)
// Builtin4 defines a built-in function that accepts 4 arguments.
Builtin4 func(bctx BuiltinContext, op1, op2, op3, op4 *ast.Term) (*ast.Term, error)
// BuiltinDyn defines a built-in function that accepts a list of arguments.
BuiltinDyn func(bctx BuiltinContext, terms []*ast.Term) (*ast.Term, error)
)
// RegisterBuiltin1 registers a one-argument built-in function globally inside
// the OPA runtime: first its declaration, then its implementation. Calls are
// routed through memoize, so results are cached when decl.Memoize is set.
func RegisterBuiltin1(decl *Function, impl Builtin1) {
	builtin := &ast.Builtin{
		Name:             decl.Name,
		Description:      decl.Description,
		Decl:             decl.Decl,
		Nondeterministic: decl.Nondeterministic,
	}
	ast.RegisterBuiltin(builtin)

	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return impl(bctx, terms[0]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	topdown.RegisterBuiltinFunc(decl.Name, eval)
}
// RegisterBuiltin2 registers a two-argument built-in function globally inside
// the OPA runtime: first its declaration, then its implementation. Calls are
// routed through memoize, so results are cached when decl.Memoize is set.
func RegisterBuiltin2(decl *Function, impl Builtin2) {
	builtin := &ast.Builtin{
		Name:             decl.Name,
		Description:      decl.Description,
		Decl:             decl.Decl,
		Nondeterministic: decl.Nondeterministic,
	}
	ast.RegisterBuiltin(builtin)

	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	topdown.RegisterBuiltinFunc(decl.Name, eval)
}
// RegisterBuiltin3 registers a three-argument built-in function globally
// inside the OPA runtime: first its declaration, then its implementation.
// Calls are routed through memoize, so results are cached when decl.Memoize
// is set.
func RegisterBuiltin3(decl *Function, impl Builtin3) {
	builtin := &ast.Builtin{
		Name:             decl.Name,
		Description:      decl.Description,
		Decl:             decl.Decl,
		Nondeterministic: decl.Nondeterministic,
	}
	ast.RegisterBuiltin(builtin)

	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	topdown.RegisterBuiltinFunc(decl.Name, eval)
}
// RegisterBuiltin4 registers a four-argument built-in function globally
// inside the OPA runtime: first its declaration, then its implementation.
// Calls are routed through memoize, so results are cached when decl.Memoize
// is set.
func RegisterBuiltin4(decl *Function, impl Builtin4) {
	builtin := &ast.Builtin{
		Name:             decl.Name,
		Description:      decl.Description,
		Decl:             decl.Decl,
		Nondeterministic: decl.Nondeterministic,
	}
	ast.RegisterBuiltin(builtin)

	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return impl(bctx, terms[0], terms[1], terms[2], terms[3]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	topdown.RegisterBuiltinFunc(decl.Name, eval)
}
// RegisterBuiltinDyn registers a built-in function taking a list of arguments
// globally inside the OPA runtime: first its declaration, then its
// implementation. Calls are routed through memoize, so results are cached
// when decl.Memoize is set.
func RegisterBuiltinDyn(decl *Function, impl BuiltinDyn) {
	builtin := &ast.Builtin{
		Name:             decl.Name,
		Description:      decl.Description,
		Decl:             decl.Decl,
		Nondeterministic: decl.Nondeterministic,
	}
	ast.RegisterBuiltin(builtin)

	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return impl(bctx, terms) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	topdown.RegisterBuiltinFunc(decl.Name, eval)
}
// Function1 returns an option that adds a one-argument built-in function to
// the Rego object. Calls are routed through memoize, so results are cached
// when decl.Memoize is set.
func Function1(decl *Function, f Builtin1) func(*Rego) {
	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return f(bctx, terms[0]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	return newFunction(decl, eval)
}
// Function2 returns an option that adds a two-argument built-in function to
// the Rego object. Calls are routed through memoize, so results are cached
// when decl.Memoize is set.
func Function2(decl *Function, f Builtin2) func(*Rego) {
	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return f(bctx, terms[0], terms[1]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	return newFunction(decl, eval)
}
// Function3 returns an option that adds a three-argument built-in function to
// the Rego object. Calls are routed through memoize, so results are cached
// when decl.Memoize is set.
func Function3(decl *Function, f Builtin3) func(*Rego) {
	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	return newFunction(decl, eval)
}
// Function4 returns an option that adds a four-argument built-in function to
// the Rego object. Calls are routed through memoize, so results are cached
// when decl.Memoize is set.
func Function4(decl *Function, f Builtin4) func(*Rego) {
	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return f(bctx, terms[0], terms[1], terms[2], terms[3]) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	return newFunction(decl, eval)
}
// FunctionDyn returns an option that adds a built-in function taking a list
// of arguments to the Rego object. Calls are routed through memoize, so
// results are cached when decl.Memoize is set.
func FunctionDyn(decl *Function, f BuiltinDyn) func(*Rego) {
	eval := func(bctx BuiltinContext, terms []*ast.Term, iter func(*ast.Term) error) error {
		call := func() (*ast.Term, error) { return f(bctx, terms) }
		result, err := memoize(decl, bctx, terms, call)
		return finishFunction(decl.Name, bctx, result, err, iter)
	}
	return newFunction(decl, eval)
}
// FunctionDecl returns an option that adds a custom-built-in function
// __declaration__. NO implementation is provided. This is used for
// non-interpreter execution envs (e.g., Wasm).
//
// It is a thin exported wrapper around newDecl.
func FunctionDecl(decl *Function) func(*Rego) {
return newDecl(decl)
}
// newDecl returns an option that records the declaration of a custom built-in
// on the Rego object, keyed by the function's name. No implementation is
// registered.
//
// The description and nondeterministic flag are carried over as well, matching
// what the RegisterBuiltin* functions record for globally registered
// built-ins, so a declaration-only built-in is not silently treated as
// deterministic.
func newDecl(decl *Function) func(*Rego) {
	return func(r *Rego) {
		r.builtinDecls[decl.Name] = &ast.Builtin{
			Name:             decl.Name,
			Description:      decl.Description,
			Decl:             decl.Decl,
			Nondeterministic: decl.Nondeterministic,
		}
	}
}
// memo is the value memoize stores in the builtin cache: the result term and
// error produced by one call of a memoized built-in.
type memo struct {
term *ast.Term
err error
}
// memokey is the dedicated key type memoize uses for builtin cache entries.
type memokey string
// memoize runs ifEmpty and, when decl.Memoize is set, caches its result (term
// and error) in bctx.Cache so that later calls with the same function name and
// arguments reuse it instead of invoking ifEmpty again. When decl.Memoize is
// false, ifEmpty is simply called.
func memoize(decl *Function, bctx BuiltinContext, terms []*ast.Term, ifEmpty func() (*ast.Term, error)) (*ast.Term, error) {
	if !decl.Memoize {
		return ifEmpty()
	}

	// NOTE(tsandall): we assume memoization is applied to infrequent built-in
	// calls that do things like fetch data from remote locations. As such,
	// converting the terms to strings is acceptable for now.
	var b strings.Builder

	// Each key component is length-prefixed. Plain concatenation is ambiguous:
	// f(1, 23) and f(12, 3) would both produce the key "f123" and share a
	// cache entry.
	writeComponent := func(s string) {
		// Writes to a strings.Builder never fail.
		_, _ = fmt.Fprintf(&b, "%d:%s", len(s), s)
	}

	writeComponent(decl.Name)

	// The term slice _may_ include an output term depending on how the caller
	// referred to the built-in function. Only use the arguments as the cache
	// key. Unification ensures we don't get false positive matches.
	for i := 0; i < len(decl.Decl.Args()); i++ {
		writeComponent(terms[i].String())
	}

	key := memokey(b.String())

	if hit, ok := bctx.Cache.Get(key); ok {
		m := hit.(memo)
		return m.term, m.err
	}

	var m memo
	m.term, m.err = ifEmpty()
	bctx.Cache.Put(key, m)

	return m.term, m.err
}
// Dump returns an option that sets the writer debugging information is
// dumped to.
func Dump(w io.Writer) func(r *Rego) {
	return func(rego *Rego) {
		rego.dump = w
	}
}
// Query returns an option that sets the Rego query from its string form.
func Query(q string) func(r *Rego) {
	return func(rego *Rego) {
		rego.query = q
	}
}
// ParsedQuery returns an option that sets the Rego query from an already
// parsed body.
func ParsedQuery(q ast.Body) func(r *Rego) {
	return func(rego *Rego) {
		rego.parsedQuery = q
	}
}
// Package returns an option that sets, from its string form, the Rego package
// on the query's context.
func Package(p string) func(r *Rego) {
	return func(rego *Rego) {
		rego.pkg = p
	}
}
// ParsedPackage returns an option that sets, from an already parsed package,
// the Rego package on the query's context.
func ParsedPackage(pkg *ast.Package) func(r *Rego) {
	return func(rego *Rego) {
		rego.parsedPackage = pkg
	}
}
// Imports returns an option that adds Rego imports, given as strings, to the
// query's context. The imports are appended to any already present.
func Imports(p []string) func(r *Rego) {
	return func(rego *Rego) {
		rego.imports = append(rego.imports, p...)
	}
}
// ParsedImports returns an option that adds already parsed Rego imports to
// the query's context. The imports are appended to any already present.
func ParsedImports(imp []*ast.Import) func(r *Rego) {
	return func(rego *Rego) {
		rego.parsedImports = append(rego.parsedImports, imp...)
	}
}
// Input returns an option that sets the Rego input document. Input should be
// a native Go value representing the input document; a pointer to it is
// stored on the Rego object.
func Input(x interface{}) func(r *Rego) {
	return func(rego *Rego) {
		rego.rawInput = &x
	}
}
// ParsedInput returns an option that sets the Rego input document from an
// AST value.
func ParsedInput(x ast.Value) func(r *Rego) {
	return func(rego *Rego) {
		rego.parsedInput = x
	}
}
// Unknowns returns an option that sets the values, given as strings, to treat
// as unknown during partial evaluation.
func Unknowns(unknowns []string) func(r *Rego) {
	return func(rego *Rego) {
		rego.unknowns = unknowns
	}
}
// ParsedUnknowns returns an option that sets the values, given as parsed
// terms, to treat as unknown during partial evaluation.
func ParsedUnknowns(unknowns []*ast.Term) func(r *Rego) {
	return func(rego *Rego) {
		rego.parsedUnknowns = unknowns
	}
}
// DisableInlining returns an option that sets the paths to exclude from
// partial evaluation inlining. It replaces any previously configured paths.
func DisableInlining(paths []string) func(r *Rego) {
	return func(rego *Rego) {
		rego.disableInlining = paths
	}
}
// ShallowInlining returns an option that, when enabled, prevents rules that
// depend on unknown values from being inlined. Rules that only depend on known
// values are inlined.
func ShallowInlining(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.shallowInlining = yes
	}
}
// SkipPartialNamespace returns an option that disables namespacing of partial
// evaluation results for support rules generated from policy. Synthetic
// support rules are still namespaced.
func SkipPartialNamespace(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.skipPartialNamespace = yes
	}
}
// PartialNamespace returns an option that sets the namespace used for partial
// evaluation results. The namespace must be a valid package path component.
func PartialNamespace(ns string) func(r *Rego) {
	return func(rego *Rego) {
		rego.partialNamespace = ns
	}
}
// Module returns an option that adds a Rego module given as source text,
// together with the filename to associate with it.
func Module(filename, input string) func(r *Rego) {
	return func(rego *Rego) {
		raw := rawModule{
			filename: filename,
			module:   input,
		}
		rego.modules = append(rego.modules, raw)
	}
}
// ParsedModule returns an option that adds a parsed Rego module. The module
// is stored under the file name from its package location, or under a
// synthetic name derived from the module's pointer when no location is
// available. If a string module with the same filename is added, it will
// override the parsed module.
func ParsedModule(module *ast.Module) func(*Rego) {
	return func(rego *Rego) {
		loc := module.Package.Location
		if loc == nil {
			rego.parsedModules[fmt.Sprintf("module_%p.rego", module)] = module
			return
		}
		rego.parsedModules[loc.File] = module
	}
}
// Load returns an option that sets the filesystem paths to load data and Rego
// modules from. Any file with a *.rego, *.yaml, or *.json extension will be
// loaded. A path can be either a directory or a file; directories are loaded
// recursively. The optional filter can be used to control which files are
// used.
// Each use of the Load option replaces the previously configured paths and
// filter, so only the last one takes effect.
// Note: Loading files will require a write transaction on the store.
func Load(paths []string, filter loader.Filter) func(r *Rego) {
	return func(rego *Rego) {
		rego.loadPaths = loadPaths{paths, filter}
	}
}
// LoadBundle returns an option that adds a filesystem path to load a bundle
// from. The path can be a compressed bundle file or a directory to be loaded
// as a bundle. Paths accumulate across uses of the option.
// Note: Loading bundles will require a write transaction on the store.
func LoadBundle(path string) func(r *Rego) {
	return func(rego *Rego) {
		rego.bundlePaths = append(rego.bundlePaths, path)
	}
}
// ParsedBundle returns an option that adds an already parsed bundle to be
// loaded, stored under the given name.
func ParsedBundle(name string, b *bundle.Bundle) func(r *Rego) {
	return func(rego *Rego) {
		rego.bundles[name] = b
	}
}
// Compiler returns an option that sets the Rego compiler.
func Compiler(c *ast.Compiler) func(r *Rego) {
	return func(rego *Rego) {
		rego.compiler = c
	}
}
// Store returns an option that sets the policy engine's data storage layer.
//
// If using the Load, LoadBundle, or ParsedBundle options then a transaction
// must also be provided via the Transaction() option. After loading files
// or bundles the transaction should be aborted or committed.
func Store(s storage.Store) func(r *Rego) {
	return func(rego *Rego) {
		rego.store = s
	}
}
// StoreReadAST returns an option that sets whether the store should eagerly
// convert data to AST values.
//
// Only applicable when no store has been set on the Rego object through the
// Store option.
func StoreReadAST(enabled bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.ownStoreReadAst = enabled
	}
}
// Transaction returns an option that sets the transaction to use for storage
// layer operations.
//
// Requires the store associated with the transaction to be provided via the
// Store() option. If using Load(), LoadBundle(), or ParsedBundle() options
// the transaction will likely require write params.
func Transaction(txn storage.Transaction) func(r *Rego) {
	return func(rego *Rego) {
		rego.txn = txn
	}
}
// Metrics returns an option that sets the metrics collection.
func Metrics(m metrics.Metrics) func(r *Rego) {
	return func(rego *Rego) {
		rego.metrics = m
	}
}
// Instrument returns an option that enables instrumentation for diagnosing
// performance issues.
func Instrument(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.instrument = yes
	}
}
// Trace returns an option that enables or disables tracing on r.
func Trace(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.trace = yes
	}
}
// Tracer returns an option that adds a legacy tracer to r. The tracer is
// wrapped as a query tracer; a nil tracer is ignored.
// Deprecated: Use QueryTracer instead.
func Tracer(t topdown.Tracer) func(r *Rego) {
	return func(rego *Rego) {
		if t == nil {
			return
		}
		rego.queryTracers = append(rego.queryTracers, topdown.WrapLegacyTracer(t))
	}
}
// QueryTracer returns an option that adds a query tracer to r. A nil tracer
// is ignored.
func QueryTracer(t topdown.QueryTracer) func(r *Rego) {
	return func(rego *Rego) {
		if t == nil {
			return
		}
		rego.queryTracers = append(rego.queryTracers, t)
	}
}
// Runtime returns an option that sets the runtime data to provide to the
// evaluation engine.
func Runtime(term *ast.Term) func(r *Rego) {
	return func(rego *Rego) {
		rego.runtime = term
	}
}
// Time returns an option that sets the wall clock time to use during policy
// evaluation. Prepared queries do not inherit this parameter. Use EvalTime to
// set the wall clock time when executing a prepared query.
func Time(x time.Time) func(r *Rego) {
	return func(rego *Rego) {
		rego.time = x
	}
}
// Seed returns an option that sets the reader used to seed randomization
// required by built-in functions. If a seed is not provided crypto/rand.Reader
// is used.
func Seed(r io.Reader) func(*Rego) {
	return func(rego *Rego) {
		rego.seed = r
	}
}
// PrintTrace writes a human-readable rendering of the trace buffered on r to
// the writer w. It does nothing when r is nil or has no trace buffer.
func PrintTrace(w io.Writer, r *Rego) {
	if r != nil && r.tracebuf != nil {
		topdown.PrettyTrace(w, *r.tracebuf)
	}
}
// PrintTraceWithLocation writes a human-readable rendering of the trace
// buffered on r to the writer w, using the location-aware printer. It does
// nothing when r is nil or has no trace buffer.
func PrintTraceWithLocation(w io.Writer, r *Rego) {
	if r != nil && r.tracebuf != nil {
		topdown.PrettyTraceWithLocation(w, *r.tracebuf)
	}
}
// UnsafeBuiltins sets the built-in functions that are treated as unsafe and
// are not allowed. The option is ignored for module compilation when the
// caller supplies the compiler, but it is always honored for query
// compilation. Pass an empty (non-nil) map to disable checks on queries.
func UnsafeBuiltins(unsafeBuiltins map[string]struct{}) func(r *Rego) {
	return func(rego *Rego) {
		rego.unsafeBuiltins = unsafeBuiltins
	}
}
// SkipBundleVerification controls whether verification of a signed bundle is
// skipped.
func SkipBundleVerification(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.skipBundleVerification = yes
	}
}
// InterQueryBuiltinCache sets the inter-query cache available to built-in
// functions during evaluation.
func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) {
	return func(rego *Rego) {
		rego.interQueryBuiltinCache = c
	}
}
// InterQueryBuiltinValueCache sets the inter-query value cache available to
// built-in functions during evaluation.
func InterQueryBuiltinValueCache(c cache.InterQueryValueCache) func(r *Rego) {
	return func(rego *Rego) {
		rego.interQueryBuiltinValueCache = c
	}
}
// NDBuiltinCache sets the cache used for non-deterministic builtins.
func NDBuiltinCache(c builtins.NDBCache) func(r *Rego) {
	return func(rego *Rego) {
		rego.ndBuiltinCache = c
	}
}
// StrictBuiltinErrors tells the evaluator whether every built-in function
// error should be treated as a fatal error.
func StrictBuiltinErrors(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.strictBuiltinErrors = yes
	}
}
// BuiltinErrorList supplies the error slice in which built-in function errors
// are stored.
func BuiltinErrorList(list *[]topdown.Error) func(r *Rego) {
	return func(rego *Rego) {
		rego.builtinErrorList = list
	}
}
// Resolver registers the resolver r for the given ref path. Each use of the
// option appends one more resolver to the Rego object.
func Resolver(ref ast.Ref, r resolver.Resolver) func(r *Rego) {
	return func(rego *Rego) {
		entry := refResolver{ref, r}
		rego.resolvers = append(rego.resolvers, entry)
	}
}
// Schemas sets the schema set stored on r.
func Schemas(x *ast.SchemaSet) func(r *Rego) {
	return func(rego *Rego) {
		rego.schemaSet = x
	}
}
// Capabilities configures the capabilities of the underlying compiler. The
// option is ignored for module compilation when the caller supplies the
// compiler.
func Capabilities(c *ast.Capabilities) func(r *Rego) {
	return func(rego *Rego) {
		rego.capabilities = c
	}
}
// Target selects the runtime to exercise.
func Target(t string) func(r *Rego) {
	return func(rego *Rego) {
		rego.target = t
	}
}
// GenerateJSON sets the converter that turns AST terms into the JSON values
// placed in results.
func GenerateJSON(f func(*ast.Term, *EvalContext) (interface{}, error)) func(r *Rego) {
	return func(rego *Rego) {
		rego.generateJSON = f
	}
}
// PrintHook sets the object to use for handling print statement outputs.
func PrintHook(h print.Hook) func(r *Rego) {
return func(r *Rego) {
r.printHook = h
}
}
// DistributedTracingOpts sets the options used by distributed tracing.
func DistributedTracingOpts(tr tracing.Options) func(r *Rego) {
	return func(rego *Rego) {
		rego.distributedTacingOpts = tr
	}
}
// EnablePrintStatements enables print() calls. Without this option, print()
// calls are erased from the policy. The option only applies to queries and
// policies passed as raw strings, i.e., it has no effect when the caller
// supplies the ast.Compiler instance.
func EnablePrintStatements(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.enablePrintStatements = yes
	}
}
// Strict enables or disables strict mode in the compiler.
func Strict(yes bool) func(r *Rego) {
	return func(rego *Rego) {
		rego.strict = yes
	}
}
// SetRegoVersion returns an option that sets the Rego language version stored
// on r; that version is passed to the parser when modules and queries are
// parsed.
func SetRegoVersion(version ast.RegoVersion) func(r *Rego) {
	return func(rego *Rego) {
		rego.regoVersion = version
	}
}
// New builds a Rego object from the supplied options and then fills in
// defaults for everything the options left unset: the compiler, the store,
// the metrics, the partial namespace, and the JSON generator. It also collects
// target plugins from the plugin manager, when one was provided.
func New(options ...func(r *Rego)) *Rego {
	r := &Rego{
		parsedModules:   make(map[string]*ast.Module),
		capture:         make(map[*ast.Expr]ast.Var),
		compiledQueries: make(map[queryType]compiledQuery),
		builtinDecls:    make(map[string]*ast.Builtin),
		builtinFuncs:    make(map[string]*topdown.Builtin),
		bundles:         make(map[string]*bundle.Bundle),
	}

	for i := range options {
		options[i](r)
	}

	if r.compiler == nil {
		c := ast.NewCompiler().
			WithUnsafeBuiltins(r.unsafeBuiltins).
			WithBuiltins(r.builtinDecls).
			WithDebug(r.dump).
			WithSchemas(r.schemaSet).
			WithCapabilities(r.capabilities).
			WithEnablePrintStatements(r.enablePrintStatements).
			WithStrict(r.strict).
			WithUseTypeCheckAnnotations(true)

		// topdown could be target "" or "rego", but both could be overridden by
		// a target plugin (checked below)
		if r.target == targetWasm {
			c = c.WithEvalMode(ast.EvalModeIR)
		}
		r.compiler = c
	}

	// The store is "owned" only when it is created here on the caller's behalf.
	r.ownStore = r.store == nil
	if r.ownStore {
		r.store = inmem.NewWithOpts(inmem.OptReturnASTValuesOnRead(r.ownStoreReadAst))
	}

	if r.metrics == nil {
		r.metrics = metrics.New()
	}

	if r.instrument {
		r.instrumentation = topdown.NewInstrumentation(r.metrics)
		r.compiler.WithMetrics(r.metrics)
	}

	if r.trace {
		r.tracebuf = topdown.NewBufferTracer()
		r.queryTracers = append(r.queryTracers, r.tracebuf)
	}

	if r.partialNamespace == "" {
		r.partialNamespace = defaultPartialNamespace
	}

	if r.generateJSON == nil {
		r.generateJSON = generateJSON
	}

	if r.pluginMgr != nil {
		for _, name := range r.pluginMgr.Plugins() {
			if tp, isTarget := r.pluginMgr.Plugin(name).(TargetPlugin); isTarget {
				r.plugins = append(r.plugins, tp)
			}
		}
	}

	// A matching target plugin forces the compiler into IR evaluation mode.
	if r.targetPlugin(r.target) != nil {
		r.compiler = r.compiler.WithEvalMode(ast.EvalModeIR)
	}

	return r
}
// Eval prepares and evaluates this Rego object and returns the ResultSet.
//
// The transaction obtained from getTxn is always closed before returning. An
// evaluation error takes precedence over an error from closing the
// transaction.
func (r *Rego) Eval(ctx context.Context) (ResultSet, error) {
	txn, closeTxn, err := r.getTxn(ctx)
	r.txn = txn
	if err != nil {
		return nil, err
	}

	pq, err := r.PrepareForEval(ctx)
	if err != nil {
		_ = closeTxn(ctx, err) // Ignore error
		return nil, err
	}

	args := make([]EvalOption, 0, 8+len(r.queryTracers)+len(r.resolvers))
	args = append(args,
		EvalTransaction(r.txn),
		EvalMetrics(r.metrics),
		EvalInstrument(r.instrument),
		EvalTime(r.time),
		EvalInterQueryBuiltinCache(r.interQueryBuiltinCache),
		EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache),
		EvalSeed(r.seed),
	)

	if r.ndBuiltinCache != nil {
		args = append(args, EvalNDBuiltinCache(r.ndBuiltinCache))
	}

	for i := range r.queryTracers {
		args = append(args, EvalQueryTracer(r.queryTracers[i]))
	}

	for _, rr := range r.resolvers {
		args = append(args, EvalResolver(rr.ref, rr.r))
	}

	rs, evalErr := pq.Eval(ctx, args...)
	closeErr := closeTxn(ctx, evalErr) // Always call closer
	if evalErr != nil {
		return rs, evalErr
	}
	return rs, closeErr
}
// PartialEval is the former name of PartialResult and simply forwards to it.
//
// Deprecated: Use PartialResult instead.
func (r *Rego) PartialEval(ctx context.Context) (PartialResult, error) {
	result, err := r.PartialResult(ctx)
	return result, err
}
// PartialResult partially evaluates this Rego object and returns a
// PartialResult assembled from the prepared query's compiler, store, parsed
// query, and builtins.
func (r *Rego) PartialResult(ctx context.Context) (PartialResult, error) {
	var none PartialResult

	txn, closeTxn, err := r.getTxn(ctx)
	r.txn = txn
	if err != nil {
		return none, err
	}

	pq, err := r.PrepareForEval(ctx, WithPartialEval())
	closeErr := closeTxn(ctx, err) // Always call closer
	switch {
	case err != nil:
		return none, err
	case closeErr != nil:
		return none, closeErr
	}

	prepared := pq.r
	return PartialResult{
		compiler:     prepared.compiler,
		store:        prepared.store,
		body:         prepared.parsedQuery,
		builtinDecls: prepared.builtinDecls,
		builtinFuncs: prepared.builtinFuncs,
	}, nil
}
// Partial prepares r for partial evaluation, runs it, and returns the
// resulting queries and support modules.
//
// The transaction obtained from getTxn is always closed before returning. A
// partial evaluation error takes precedence over an error from closing the
// transaction.
func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) {
	txn, closeTxn, err := r.getTxn(ctx)
	r.txn = txn
	if err != nil {
		return nil, err
	}

	pq, err := r.PrepareForPartial(ctx)
	if err != nil {
		_ = closeTxn(ctx, err) // Ignore error
		return nil, err
	}

	args := make([]EvalOption, 0, 6+len(r.queryTracers)+len(r.resolvers))
	args = append(args,
		EvalTransaction(r.txn),
		EvalMetrics(r.metrics),
		EvalInstrument(r.instrument),
		EvalInterQueryBuiltinCache(r.interQueryBuiltinCache),
		EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache),
	)

	if r.ndBuiltinCache != nil {
		args = append(args, EvalNDBuiltinCache(r.ndBuiltinCache))
	}

	for i := range r.queryTracers {
		args = append(args, EvalQueryTracer(r.queryTracers[i]))
	}

	for _, rr := range r.resolvers {
		args = append(args, EvalResolver(rr.ref, rr.r))
	}

	pqs, partialErr := pq.Partial(ctx, args...)
	closeErr := closeTxn(ctx, partialErr) // Always call closer
	if partialErr != nil {
		return pqs, partialErr
	}
	return pqs, closeErr
}
// CompileOption defines a function that sets an option on a Compile call by
// mutating the CompileContext it is given.
type CompileOption func(*CompileContext)
// CompileContext contains options for Compile calls.
type CompileContext struct {
	// partial reports whether Compile runs partial evaluation first; it is
	// set by the CompilePartial option.
	partial bool
}
// CompilePartial returns an option that controls whether partial evaluation
// is run before the query is planned and compiled.
func CompilePartial(yes bool) CompileOption {
	return func(cc *CompileContext) {
		cc.partial = yes
	}
}
// Compile returns a compiled policy query.
//
// With CompilePartial(true), the query is partially evaluated first and the
// resulting queries are compiled; when a debug writer (r.dump) is configured,
// those queries and support modules are written to it. Otherwise the query is
// prepared inside a transaction that is closed before planning starts.
//
// An error is returned if a target plugin matches r.target, since compilation
// is not supported for target plugins.
func (r *Rego) Compile(ctx context.Context, opts ...CompileOption) (*CompileResult, error) {
	var cfg CompileContext

	for _, opt := range opts {
		opt(&cfg)
	}

	var queries []ast.Body

	modules := make([]*ast.Module, 0, len(r.compiler.Modules))

	if cfg.partial {
		pq, err := r.Partial(ctx)
		if err != nil {
			return nil, err
		}

		if r.dump != nil {
			if len(pq.Queries) != 0 {
				msg := fmt.Sprintf("QUERIES (%d total):", len(pq.Queries))
				fmt.Fprintln(r.dump, msg)
				fmt.Fprintln(r.dump, strings.Repeat("-", len(msg)))
				for i := range pq.Queries {
					// Write to the debug writer, not stdout, so the listing
					// stays together with its header.
					fmt.Fprintln(r.dump, pq.Queries[i])
				}
				fmt.Fprintln(r.dump)
			}
			if len(pq.Support) != 0 {
				msg := fmt.Sprintf("SUPPORT (%d total):", len(pq.Support))
				fmt.Fprintln(r.dump, msg)
				fmt.Fprintln(r.dump, strings.Repeat("-", len(msg)))
				for i := range pq.Support {
					fmt.Fprintln(r.dump, pq.Support[i])
				}
				fmt.Fprintln(r.dump)
			}
		}

		queries = pq.Queries
		modules = append(modules, pq.Support...)

		for _, module := range r.compiler.Modules {
			modules = append(modules, module)
		}
	} else {
		var err error
		// If creating a new transaction it should be closed before calling the
		// planner to avoid holding open the transaction longer than needed.
		//
		// TODO(tsandall): in future, planner could make use of store, in which
		// case this will need to change.
		var txnClose transactionCloser
		r.txn, txnClose, err = r.getTxn(ctx)
		if err != nil {
			return nil, err
		}

		err = r.prepare(ctx, compileQueryType, nil)
		txnErr := txnClose(ctx, err) // Always call closer
		if err != nil {
			return nil, err
		}
		if txnErr != nil {
			// Report the close failure itself; err is nil at this point.
			return nil, txnErr
		}

		for _, module := range r.compiler.Modules {
			modules = append(modules, module)
		}

		queries = []ast.Body{r.compiledQueries[compileQueryType].query}
	}

	if tgt := r.targetPlugin(r.target); tgt != nil {
		return nil, fmt.Errorf("unsupported for rego target plugins")
	}

	return r.compileWasm(modules, queries, compileQueryType) // TODO(sr) control flow is funky here
}
// compileWasm plans the given queries, compiles the resulting policy to a
// wasm module, and returns the encoded module bytes. The modules argument is
// unused.
func (r *Rego) compileWasm(_ []*ast.Module, queries []ast.Body, qType queryType) (*CompileResult, error) {
	policy, err := r.planQuery(queries, qType)
	if err != nil {
		return nil, err
	}

	module, err := wasm.New().WithPolicy(policy).Compile()
	if err != nil {
		return nil, err
	}

	buf := new(bytes.Buffer)
	if err = encoding.WriteModule(buf, module); err != nil {
		return nil, err
	}

	return &CompileResult{Bytes: buf.Bytes()}, nil
}
// PrepareOption defines a function that sets an option controlling the
// behavior of the Prepare call by mutating the PrepareConfig it is given.
type PrepareOption func(*PrepareConfig)
// PrepareConfig holds settings to control the behavior of the
// Prepare call.
type PrepareConfig struct {
	// doPartialEval is set by WithPartialEval and makes PrepareForEval run
	// partial evaluation while preparing the query.
	doPartialEval bool
	// disableInlining is set by WithNoInline; when non-nil it replaces the
	// Rego object's own list of paths excluded from inlining.
	disableInlining *[]string
	// builtinFuncs accumulates the function definitions passed through
	// WithBuiltinFuncs.
	builtinFuncs map[string]*topdown.Builtin
}
// WithPartialEval returns an option for PrepareForEval that makes it perform
// partial evaluation while preparing the query (similar to
// rego.Rego#PartialResult).
func WithPartialEval() PrepareOption {
	return func(cfg *PrepareConfig) {
		cfg.doPartialEval = true
	}
}
// WithNoInline returns an option that sets the paths to exclude from partial
// evaluation inlining.
func WithNoInline(paths []string) PrepareOption {
	return func(cfg *PrepareConfig) {
		cfg.disableInlining = &paths
	}
}
// WithBuiltinFuncs carries the rego.Function{1,2,3} per-query function
// definitions to the target plugins. Entries are merged into any definitions
// already present on the config; on a name clash the new entry wins.
func WithBuiltinFuncs(bis map[string]*topdown.Builtin) PrepareOption {
	return func(cfg *PrepareConfig) {
		if cfg.builtinFuncs == nil {
			cfg.builtinFuncs = make(map[string]*topdown.Builtin, len(bis))
		}
		for name := range bis {
			cfg.builtinFuncs[name] = bis[name]
		}
	}
}
// BuiltinFuncs returns the builtin funcs that were set via the PrepareOption
// WithBuiltinFuncs. The map is returned as stored, not copied, and is nil if
// the option was never applied.
func (p *PrepareConfig) BuiltinFuncs() map[string]*topdown.Builtin {
	return p.builtinFuncs
}
// PrepareForEval parses the input, modules, and query arguments and compiles
// them in preparation for evaluation.
//
// When WithPartialEval is supplied, partial evaluation runs first and the
// prepared query is built from its result. For the wasm target, the compiled
// policy and the data read from the store are loaded into the wasm engine.
// For any target other than "rego", a matching target plugin (if one exists)
// is asked to prepare the planned policy.
//
// The transaction obtained from getTxn is closed on every return path that
// follows its creation, and a failure that aborts preparation is handed to
// the closer.
func (r *Rego) PrepareForEval(ctx context.Context, opts ...PrepareOption) (PreparedEvalQuery, error) {
	if !r.hasQuery() {
		return PreparedEvalQuery{}, fmt.Errorf("cannot evaluate empty query")
	}

	pCfg := &PrepareConfig{}
	for _, o := range opts {
		o(pCfg)
	}

	var err error
	var txnClose transactionCloser
	r.txn, txnClose, err = r.getTxn(ctx)
	if err != nil {
		return PreparedEvalQuery{}, err
	}

	// fail closes the transaction with the causing error and reports that
	// error to the caller. An error from the closer itself is dropped because
	// the original failure is the one the caller needs to see.
	fail := func(cause error) (PreparedEvalQuery, error) {
		_ = txnClose(ctx, cause)
		return PreparedEvalQuery{}, cause
	}

	// If the caller wanted to do partial evaluation as part of preparation
	// do it now and use the new Rego object.
	if pCfg.doPartialEval {
		pr, err := r.partialResult(ctx, pCfg)
		if err != nil {
			return fail(err)
		}

		// Prepare the new query using the result of partial evaluation
		pq, err := pr.Rego(Transaction(r.txn)).PrepareForEval(ctx)
		txnErr := txnClose(ctx, err)
		if err != nil {
			return pq, err
		}
		return pq, txnErr
	}

	err = r.prepare(ctx, evalQueryType, []extraStage{
		{
			after: "ResolveRefs",
			stage: ast.QueryCompilerStageDefinition{
				Name:       "RewriteToCaptureValue",
				MetricName: "query_compile_stage_rewrite_to_capture_value",
				Stage:      r.rewriteQueryToCaptureValue,
			},
		},
	})
	if err != nil {
		return fail(err)
	}

	switch r.target {
	case targetWasm: // TODO(sr): make wasm a target plugin, too
		if r.hasWasmModule() {
			return fail(fmt.Errorf("wasm target not supported"))
		}

		modules := make([]*ast.Module, 0, len(r.compiler.Modules))
		for _, module := range r.compiler.Modules {
			modules = append(modules, module)
		}

		queries := []ast.Body{r.compiledQueries[evalQueryType].query}

		e, err := opa.LookupEngine(targetWasm)
		if err != nil {
			return fail(err)
		}

		// nolint: staticcheck // SA4006 false positive
		cr, err := r.compileWasm(modules, queries, evalQueryType)
		if err != nil {
			return fail(err)
		}

		// nolint: staticcheck // SA4006 false positive
		data, err := r.store.Read(ctx, r.txn, storage.Path{})
		if err != nil {
			return fail(err)
		}

		o, err := e.New().WithPolicyBytes(cr.Bytes).WithDataJSON(data).Init()
		if err != nil {
			return fail(err)
		}
		r.opa = o

	case targetRego: // do nothing, don't lookup default plugin
	default: // either a specific plugin target, or one that is default
		if tgt := r.targetPlugin(r.target); tgt != nil {
			queries := []ast.Body{r.compiledQueries[evalQueryType].query}
			pol, err := r.planQuery(queries, evalQueryType)
			if err != nil {
				return fail(err)
			}
			// always add the builtins provided via rego.FunctionN options
			opts = append(opts, WithBuiltinFuncs(r.builtinFuncs))
			r.targetPrepState, err = tgt.PrepareForEval(ctx, pol, opts...)
			if err != nil {
				return fail(err)
			}
		}
	}

	// Preparation succeeded; close the transaction without an error.
	if txnErr := txnClose(ctx, nil); txnErr != nil {
		return PreparedEvalQuery{}, txnErr
	}

	return PreparedEvalQuery{preparedQuery{r, pCfg}}, nil
}
// PrepareForPartial parses the input, modules, and query arguments and
// compiles them in preparation for partial evaluation. The transaction
// obtained from getTxn is closed once preparation has finished, whether or
// not it succeeded.
func (r *Rego) PrepareForPartial(ctx context.Context, opts ...PrepareOption) (PreparedPartialQuery, error) {
	var none PreparedPartialQuery

	if !r.hasQuery() {
		return none, fmt.Errorf("cannot evaluate empty query")
	}

	cfg := &PrepareConfig{}
	for _, apply := range opts {
		apply(cfg)
	}

	txn, closeTxn, err := r.getTxn(ctx)
	r.txn = txn
	if err != nil {
		return none, err
	}

	stages := []extraStage{
		{
			after: "CheckSafety",
			stage: ast.QueryCompilerStageDefinition{
				Name:       "RewriteEquals",
				MetricName: "query_compile_stage_rewrite_equals",
				Stage:      r.rewriteEqualsForPartialQueryCompile,
			},
		},
	}

	err = r.prepare(ctx, partialQueryType, stages)
	closeErr := closeTxn(ctx, err) // Always call closer
	switch {
	case err != nil:
		return none, err
	case closeErr != nil:
		return none, closeErr
	}

	return PreparedPartialQuery{preparedQuery{r, cfg}}, nil
}
// prepare runs the shared preparation pipeline: parse the input, load files
// and bundles, parse and compile the modules, then parse and compile the
// query (caching it under qType). It stops at the first error.
func (r *Rego) prepare(ctx context.Context, qType queryType, extras []extraStage) error {
	var err error

	if r.parsedInput, err = r.parseInput(); err != nil {
		return err
	}

	if err = r.loadFiles(ctx, r.txn, r.metrics); err != nil {
		return err
	}

	if err = r.loadBundles(ctx, r.txn, r.metrics); err != nil {
		return err
	}

	if err = r.parseModules(ctx, r.txn, r.metrics); err != nil {
		return err
	}

	// Compile the modules *before* the query, else functions
	// defined in the module won't be found...
	if err = r.compileModules(ctx, r.txn, r.metrics); err != nil {
		return err
	}

	imports, err := r.prepareImports()
	if err != nil {
		return err
	}

	// Only imports rooted at the future or rego documents are passed to the
	// query parser.
	queryImports := make([]*ast.Import, 0)
	for _, imp := range imports {
		ref := imp.Path.Value.(ast.Ref)
		if !ref.HasPrefix([]*ast.Term{ast.FutureRootDocument}) && !ref.HasPrefix([]*ast.Term{ast.RegoRootDocument}) {
			continue
		}
		queryImports = append(queryImports, imp)
	}

	if r.parsedQuery, err = r.parseQuery(queryImports, r.metrics); err != nil {
		return err
	}

	return r.compileAndCacheQuery(qType, r.parsedQuery, imports, r.metrics, extras)
}
// parseModules parses the policies saved in the store that the compiler does
// not already hold, followed by the modules passed to the Rego object, and
// records the results in r.parsedModules. It does nothing when no modules
// were passed. Parse errors are accumulated and returned together; a store
// error is returned immediately.
func (r *Rego) parseModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {
	if len(r.modules) == 0 {
		return nil
	}

	ids, err := r.store.ListPolicies(ctx, txn)
	if err != nil {
		return err
	}

	m.Timer(metrics.RegoModuleParse).Start()
	defer m.Timer(metrics.RegoModuleParse).Stop()

	popts := ast.ParserOptions{RegoVersion: r.regoVersion}

	var errs Errors

	// Parse any modules that are saved to the store, but only if
	// another compile step is going to occur (ie. we have parsed modules
	// that need to be compiled).
	for _, id := range ids {
		// if it is already on the compiler we're using
		// then don't bother to re-parse it from source
		if _, known := r.compiler.Modules[id]; known {
			continue
		}

		src, err := r.store.GetPolicy(ctx, txn, id)
		if err != nil {
			return err
		}

		mod, err := ast.ParseModuleWithOpts(id, string(src), popts)
		if err != nil {
			errs = append(errs, err)
		}

		r.parsedModules[id] = mod
	}

	// Parse any passed in as arguments to the Rego object
	for _, module := range r.modules {
		mod, err := module.ParseWithOpts(popts)
		if err != nil {
			if parseErrs, isList := err.(ast.Errors); isList {
				for _, e := range parseErrs {
					errs = append(errs, e)
				}
			} else {
				errs = append(errs, err)
			}
		}
		r.parsedModules[module.filename] = mod
	}

	if len(errs) == 0 {
		return nil
	}

	return errs
}
// loadFiles loads the configured load paths through the file loader, records
// the parsed modules in r.parsedModules, and writes any loaded documents to
// the root of the store. It does nothing when no load paths are configured.
func (r *Rego) loadFiles(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {
	if len(r.loadPaths.paths) == 0 {
		return nil
	}

	m.Timer(metrics.RegoLoadFiles).Start()
	defer m.Timer(metrics.RegoLoadFiles).Stop()

	fileLoader := loader.NewFileLoader().
		WithMetrics(m).
		WithProcessAnnotation(true).
		WithRegoVersion(r.regoVersion)

	loaded, err := fileLoader.Filtered(r.loadPaths.paths, r.loadPaths.filter)
	if err != nil {
		return err
	}

	for name, mod := range loaded.Modules {
		r.parsedModules[name] = mod.Parsed
	}

	if len(loaded.Documents) == 0 {
		return nil
	}

	return r.store.Write(ctx, txn, storage.AddOp, storage.Path{}, loaded.Documents)
}
// loadBundles loads every configured bundle path through the file loader and
// records the resulting bundle in r.bundles, keyed by its path. It does
// nothing when no bundle paths are configured.
//
// A loader failure is returned wrapped, so callers can still inspect the
// underlying error with errors.Is or errors.As.
func (r *Rego) loadBundles(_ context.Context, _ storage.Transaction, m metrics.Metrics) error {
	if len(r.bundlePaths) == 0 {
		return nil
	}

	m.Timer(metrics.RegoLoadBundles).Start()
	defer m.Timer(metrics.RegoLoadBundles).Stop()

	for _, path := range r.bundlePaths {
		bndl, err := loader.NewFileLoader().
			WithMetrics(m).
			WithProcessAnnotation(true).
			WithSkipBundleVerification(r.skipBundleVerification).
			WithRegoVersion(r.regoVersion).
			AsBundle(path)
		if err != nil {
			// %w keeps the message text unchanged while preserving the chain.
			return fmt.Errorf("loading error: %w", err)
		}
		r.bundles[path] = bndl
	}
	return nil
}
// parseInput returns the already-parsed input when there is one and otherwise
// parses the raw input.
func (r *Rego) parseInput() (ast.Value, error) {
	if r.parsedInput == nil {
		return r.parseRawInput(r.rawInput, r.metrics)
	}
	return r.parsedInput, nil
}
// parseRawInput converts the raw input into an ast.Value. A nil rawInput
// yields a nil value and no error.
func (r *Rego) parseRawInput(rawInput *interface{}, m metrics.Metrics) (ast.Value, error) {
	if rawInput == nil {
		return nil, nil
	}

	m.Timer(metrics.RegoInputParse).Start()
	defer m.Timer(metrics.RegoInputParse).Stop()

	ptr := util.Reference(rawInput)

	// roundtrip through json: this turns slices (e.g. []string, []bool) into
	// []interface{}, the only array type ast.InterfaceToValue can work with
	err := util.RoundTrip(ptr)
	if err != nil {
		return nil, err
	}

	return ast.InterfaceToValue(*ptr)
}
// parseQuery returns the already-parsed query when there is one. Otherwise it
// derives parser options from the query imports and the configured Rego
// version and parses r.query as a body with rule parsing skipped.
func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Body, error) {
	if r.parsedQuery != nil {
		return r.parsedQuery, nil
	}

	m.Timer(metrics.RegoQueryParse).Start()
	defer m.Timer(metrics.RegoQueryParse).Stop()

	opts, err := future.ParserOptionsFromFutureImports(queryImports)
	if err != nil {
		return nil, err
	}
	opts.RegoVersion = r.regoVersion

	// A rego.v1 import among the query imports overrides the configured version.
	if opts, err = parserOptionsFromRegoVersionImport(queryImports, opts); err != nil {
		return nil, err
	}
	opts.SkipRules = true

	return ast.ParseBodyWithOpts(r.query, opts)
}
// parserOptionsFromRegoVersionImport returns popts with RegoVersion switched
// to ast.RegoV1 when any import path equals ast.RegoV1CompatibleRef; otherwise
// popts is returned unchanged. The returned error is always nil.
func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) {
	for i := range imports {
		ref := imports[i].Path.Value.(ast.Ref)
		if ast.Compare(ref, ast.RegoV1CompatibleRef) != 0 {
			continue
		}
		popts.RegoVersion = ast.RegoV1
		break
	}
	return popts, nil
}
// compileModules activates the loaded bundles together with the parsed
// modules when there is anything new to compile, and then, unless resolvers
// were explicitly provided, loads the wasm resolvers configured in the store
// and registers one resolver entry per entrypoint.
func (r *Rego) compileModules(ctx context.Context, txn storage.Transaction, m metrics.Metrics) error {
	// Only compile again if there are new modules.
	if len(r.bundles) > 0 || len(r.parsedModules) > 0 {
		// The bundle.Activate call will activate any bundles passed in
		// (ie compile + handle data store changes), and include any of
		// the additional modules passed in. If no bundles are provided
		// it will only compile the passed in modules.
		// Use this as the single-point of compiling everything only a
		// single time.
		activation := &bundle.ActivateOpts{
			Ctx:           ctx,
			Store:         r.store,
			Txn:           txn,
			Compiler:      r.compilerForTxn(ctx, r.store, txn),
			Metrics:       m,
			Bundles:       r.bundles,
			ExtraModules:  r.parsedModules,
			ParserOptions: ast.ParserOptions{RegoVersion: r.regoVersion},
		}
		if err := bundle.Activate(activation); err != nil {
			return err
		}
	}

	// Ensure all configured resolvers from the store are loaded. Skip if any were explicitly provided.
	if len(r.resolvers) != 0 {
		return nil
	}

	loaded, err := bundleUtils.LoadWasmResolversFromStore(ctx, r.store, txn, r.bundles)
	if err != nil {
		return err
	}

	for _, wasmResolver := range loaded {
		for _, entrypoint := range wasmResolver.Entrypoints() {
			r.resolvers = append(r.resolvers, refResolver{entrypoint, wasmResolver})
		}
	}

	return nil
}
// compileAndCacheQuery compiles query and stores the result in
// r.compiledQueries under qType. If a complete entry (non-nil query and
// compiler) is already cached for qType, nothing is compiled.
func (r *Rego) compileAndCacheQuery(qType queryType, query ast.Body, imports []*ast.Import, m metrics.Metrics, extras []extraStage) error {
	m.Timer(metrics.RegoQueryCompile).Start()
	defer m.Timer(metrics.RegoQueryCompile).Stop()

	if cached, found := r.compiledQueries[qType]; found && cached.query != nil && cached.compiler != nil {
		return nil
	}

	queryCompiler, body, err := r.compileQuery(query, imports, m, extras)
	if err != nil {
		return err
	}

	// cache the query for future use
	r.compiledQueries[qType] = compiledQuery{query: body, compiler: queryCompiler}

	return nil
}
// prepareImports returns the pre-parsed imports followed by the imports
// obtained by parsing r.imports, each rendered as an "import ..." statement.
func (r *Rego) prepareImports() ([]*ast.Import, error) {
	imports := r.parsedImports

	if len(r.imports) == 0 {
		return imports, nil
	}

	stmts := make([]string, 0, len(r.imports))
	for _, imp := range r.imports {
		stmts = append(stmts, fmt.Sprintf("import %v", imp))
	}

	parsed, err := ast.ParseImports(strings.Join(stmts, "\n"))
	if err != nil {
		return nil, err
	}

	return append(imports, parsed...), nil
}
// compileQuery compiles query with a query compiler configured with the
// package (parsed from r.pkg when set, otherwise r.parsedPackage), the given
// imports, and the extra stages. It returns the query compiler alongside the
// compiled body. The metrics argument is unused.
func (r *Rego) compileQuery(query ast.Body, imports []*ast.Import, _ metrics.Metrics, extras []extraStage) (ast.QueryCompiler, ast.Body, error) {
	pkg := r.parsedPackage
	if r.pkg != "" {
		parsed, err := ast.ParsePackage(fmt.Sprintf("package %v", r.pkg))
		if err != nil {
			return nil, nil, err
		}
		pkg = parsed
	}

	queryCtx := ast.NewQueryContext().WithPackage(pkg).WithImports(imports)

	qc := r.compiler.QueryCompiler().
		WithContext(queryCtx).
		WithUnsafeBuiltins(r.unsafeBuiltins).
		WithEnablePrintStatements(r.enablePrintStatements).
		WithStrict(false)

	for i := range extras {
		qc = qc.WithStageAfter(extras[i].after, extras[i].stage)
	}

	body, err := qc.Compile(query)

	return qc, body, err
}
// eval evaluates the compiled query described by ectx. A prepared target
// plugin takes precedence, then the wasm target; everything else is evaluated
// with topdown. A nil ResultSet is returned when the query produces no
// results.
func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
	if r.targetPrepState != nil { // target plugin flow
		var runtimeValue ast.Value
		if r.runtime != nil {
			runtimeValue = r.runtime.Value
		}
		out, err := r.targetPrepState.Eval(ctx, ectx, runtimeValue)
		if err != nil {
			return nil, err
		}
		return r.valueToQueryResult(out, ectx)
	}

	if r.target == targetWasm {
		return r.evalWasm(ctx, ectx)
	}

	query := topdown.NewQuery(ectx.compiledQuery.query).
		WithQueryCompiler(ectx.compiledQuery.compiler).
		WithCompiler(r.compiler).
		WithStore(r.store).
		WithTransaction(ectx.txn).
		WithBuiltins(r.builtinFuncs).
		WithMetrics(ectx.metrics).
		WithInstrumentation(ectx.instrumentation).
		WithRuntime(r.runtime).
		WithIndexing(ectx.indexing).
		WithEarlyExit(ectx.earlyExit).
		WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache).
		WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache).
		WithStrictBuiltinErrors(r.strictBuiltinErrors).
		WithBuiltinErrorList(r.builtinErrorList).
		WithSeed(ectx.seed).
		WithPrintHook(ectx.printHook).
		WithDistributedTracingOpts(r.distributedTacingOpts).
		WithVirtualCache(ectx.virtualCache)

	if !ectx.time.IsZero() {
		query = query.WithTime(ectx.time)
	}

	if ectx.ndBuiltinCache != nil {
		query = query.WithNDBuiltinCache(ectx.ndBuiltinCache)
	}

	for _, tracer := range ectx.queryTracers {
		query = query.WithQueryTracer(tracer)
	}

	if ectx.parsedInput != nil {
		query = query.WithInput(ast.NewTerm(ectx.parsedInput))
	}

	for _, rr := range ectx.resolvers {
		query = query.WithResolver(rr.ref, rr.r)
	}

	// Cancel query if context is cancelled or deadline is reached.
	cancel := topdown.NewCancel()
	query = query.WithCancel(cancel)
	done := make(chan struct{})
	defer close(done)
	go waitForDone(ctx, done, func() {
		cancel.Cancel()
	})

	var results ResultSet

	collect := func(qr topdown.QueryResult) error {
		result, err := r.generateResult(qr, ectx)
		if err != nil {
			return err
		}
		results = append(results, result)
		return nil
	}

	if err := query.Iter(ctx, collect); err != nil {
		return nil, err
	}

	if len(results) == 0 {
		return nil, nil
	}

	return results, nil
}
// evalWasm evaluates the prepared wasm policy with the input and settings
// from ectx, parses the raw result as a term, and converts it into a
// ResultSet. The parsed input, when present, is used in place of the raw
// input.
func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, error) {
	input := ectx.rawInput
	if ectx.parsedInput != nil {
		var boxed interface{} = ectx.parsedInput
		input = &boxed
	}

	evalOpts := opa.EvalOpts{
		Metrics:                r.metrics,
		Input:                  input,
		Time:                   ectx.time,
		Seed:                   ectx.seed,
		InterQueryBuiltinCache: ectx.interQueryBuiltinCache,
		NDBuiltinCache:         ectx.ndBuiltinCache,
		PrintHook:              ectx.printHook,
		Capabilities:           ectx.capabilities,
	}

	out, err := r.opa.Eval(ctx, evalOpts)
	if err != nil {
		return nil, err
	}

	term, err := ast.ParseTerm(string(out.Result))
	if err != nil {
		return nil, err
	}

	return r.valueToQueryResult(term.Value, ectx)
}
// valueToQueryResult converts res, which must be a set of objects, into a
// ResultSet: each object becomes one query result whose keys are turned into
// variables. An empty set yields a nil ResultSet.
func (r *Rego) valueToQueryResult(res ast.Value, ectx *EvalContext) (ResultSet, error) {
	set, isSet := res.(ast.Set)
	if !isSet {
		return nil, fmt.Errorf("illegal result type")
	}

	if set.Len() == 0 {
		return nil, nil
	}

	var results ResultSet

	convert := func(term *ast.Term) error {
		bindings, isObject := term.Value.(ast.Object)
		if !isObject {
			return fmt.Errorf("illegal result type")
		}

		qr := topdown.QueryResult{}
		bindings.Foreach(func(key, value *ast.Term) {
			varTerm := ast.VarTerm(string(key.Value.(ast.String)))
			qr[varTerm.Value.(ast.Var)] = value
		})

		result, err := r.generateResult(qr, ectx)
		if err != nil {
			return err
		}
		results = append(results, result)
		return nil
	}

	err := set.Iter(convert)

	return results, err
}
// generateResult turns one query result into a Result. Bindings are keyed by
// the original (pre-rewrite) variable names, leaving out term, wasm,
// generated, and wildcard variables. One expression value is added per
// non-generated query expression: the captured value when the expression has
// a capture variable, and true otherwise.
func (r *Rego) generateResult(qr topdown.QueryResult, ectx *EvalContext) (Result, error) {
	rewritten := ectx.compiledQuery.compiler.RewrittenVars()

	result := newResult()

	for name, term := range qr {
		value, err := r.generateJSON(term, ectx)
		if err != nil {
			return result, err
		}

		if original, found := rewritten[name]; found {
			name = original
		}

		hidden := isTermVar(name) || isTermWasmVar(name) || name.IsGenerated() || name.IsWildcard()
		if !hidden {
			result.Bindings[string(name)] = value
		}
	}

	for _, expr := range ectx.compiledQuery.query {
		if expr.Generated {
			continue
		}

		var value interface{} = true
		if captured, found := r.capture[expr]; found {
			converted, err := r.generateJSON(qr[captured], ectx)
			if err != nil {
				return result, err
			}
			value = converted
		}

		result.Expressions = append(result.Expressions, newExpressionValue(expr, value))
	}

	return result, nil
}
// partialResult prepares the query for partial evaluation, runs partial
// evaluation, and folds the output back into r.compiler: the resulting
// queries become rules named __result__ in a module for the partial
// namespace, and the support modules are added next to it. The returned
// PartialResult refers to data.<partial namespace>.__result__.
//
// pCfg.disableInlining, when set, replaces r.disableInlining as the list of
// paths excluded from inlining.
//
// NOTE: r.compiler.Modules is mutated in place and the compiler is recompiled,
// so the Rego object's compiler reflects the partial evaluation output after
// this call.
func (r *Rego) partialResult(ctx context.Context, pCfg *PrepareConfig) (PartialResult, error) {

	err := r.prepare(ctx, partialResultQueryType, []extraStage{
		{
			after: "ResolveRefs",
			stage: ast.QueryCompilerStageDefinition{
				Name:       "RewriteForPartialEval",
				MetricName: "query_compile_stage_rewrite_for_partial_eval",
				Stage:      r.rewriteQueryForPartialEval,
			},
		},
	})
	if err != nil {
		return PartialResult{}, err
	}

	// Evaluation context for the partial evaluation run, built from the Rego
	// object's own settings.
	ectx := &EvalContext{
		parsedInput:         r.parsedInput,
		metrics:             r.metrics,
		txn:                 r.txn,
		partialNamespace:    r.partialNamespace,
		queryTracers:        r.queryTracers,
		compiledQuery:       r.compiledQueries[partialResultQueryType],
		instrumentation:     r.instrumentation,
		indexing:            true,
		resolvers:           r.resolvers,
		capabilities:        r.capabilities,
		strictBuiltinErrors: r.strictBuiltinErrors,
	}

	// The per-call setting, when present, overrides the one on the Rego object.
	disableInlining := r.disableInlining

	if pCfg.disableInlining != nil {
		disableInlining = *pCfg.disableInlining
	}

	ectx.disableInlining, err = parseStringsToRefs(disableInlining)
	if err != nil {
		return PartialResult{}, err
	}

	pq, err := r.partial(ctx, ectx)
	if err != nil {
		return PartialResult{}, err
	}

	// Construct module for queries.
	id := fmt.Sprintf("__partialresult__%s__", ectx.partialNamespace)

	module, err := ast.ParseModule(id, "package "+ectx.partialNamespace)
	if err != nil {
		// The underlying parse error is not included in the returned error.
		return PartialResult{}, fmt.Errorf("bad partial namespace")
	}

	// One __result__ rule per query produced by partial evaluation.
	module.Rules = make([]*ast.Rule, len(pq.Queries))
	for i, body := range pq.Queries {
		rule := &ast.Rule{
			Head:   ast.NewHead(ast.Var("__result__"), nil, ast.Wildcard),
			Body:   body,
			Module: module,
		}
		module.Rules[i] = rule
		// Fail with errPartialEvaluationNotEffective when the recursive-ref
		// check flags this body against the rule's own path.
		if checkPartialResultForRecursiveRefs(body, rule.Path()) {
			return PartialResult{}, Errors{errPartialEvaluationNotEffective}
		}
	}

	// Update compiler with partial evaluation output.
	r.compiler.Modules[id] = module
	for i, module := range pq.Support {
		r.compiler.Modules[fmt.Sprintf("__partialsupport__%s__%d__", ectx.partialNamespace, i)] = module
	}

	// Recompile every module, including the ones just added, and time it.
	r.metrics.Timer(metrics.RegoModuleCompile).Start()
	r.compilerForTxn(ctx, r.store, r.txn).Compile(r.compiler.Modules)
	r.metrics.Timer(metrics.RegoModuleCompile).Stop()

	if r.compiler.Failed() {
		return PartialResult{}, r.compiler.Errors
	}

	result := PartialResult{
		compiler:     r.compiler,
		store:        r.store,
		body:         ast.MustParseBody(fmt.Sprintf("data.%v.__result__", ectx.partialNamespace)),
		builtinDecls: r.builtinDecls,
		builtinFuncs: r.builtinFuncs,
	}

	return result, nil
}
// partial runs partial evaluation of the compiled query carried by ectx and
// returns the resulting queries together with any support modules.
//
// The set of unknowns is taken from ectx.parsedUnknowns when set, otherwise it
// is parsed from the strings in ectx.unknowns; if neither is provided the
// input document root is used as the only unknown. A parse failure of any
// unknown is returned as-is. The query is wired to a topdown.Cancel that is
// triggered when ctx is done. Before returning, every support module has its
// Rego version set based on r.regoVersion and r.capabilities.
func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, error) {
var unknowns []*ast.Term
switch {
case ectx.parsedUnknowns != nil:
// Pre-parsed unknowns take precedence over the string form.
unknowns = ectx.parsedUnknowns
case ectx.unknowns != nil:
unknowns = make([]*ast.Term, len(ectx.unknowns))
for i := range ectx.unknowns {
var err error
unknowns[i], err = ast.ParseTerm(ectx.unknowns[i])
if err != nil {
return nil, err
}
}
default:
// Use input document as unknown if caller has not specified any.
unknowns = []*ast.Term{ast.NewTerm(ast.InputRootRef)}
}
// Build the query from a mix of per-evaluation settings (ectx) and
// settings held on the Rego object itself (r).
q := topdown.NewQuery(ectx.compiledQuery.query).
WithQueryCompiler(ectx.compiledQuery.compiler).
WithCompiler(r.compiler).
WithStore(r.store).
WithTransaction(ectx.txn).
WithBuiltins(r.builtinFuncs).
WithMetrics(ectx.metrics).
WithInstrumentation(ectx.instrumentation).
WithUnknowns(unknowns).
WithDisableInlining(ectx.disableInlining).
WithRuntime(r.runtime).
WithIndexing(ectx.indexing).
WithEarlyExit(ectx.earlyExit).
WithPartialNamespace(ectx.partialNamespace).
WithSkipPartialNamespace(r.skipPartialNamespace).
WithShallowInlining(r.shallowInlining).
WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache).
WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache).
WithStrictBuiltinErrors(ectx.strictBuiltinErrors).
WithSeed(ectx.seed).
WithPrintHook(ectx.printHook)
// The remaining options are only applied when the caller configured them.
if !ectx.time.IsZero() {
q = q.WithTime(ectx.time)
}
if ectx.ndBuiltinCache != nil {
q = q.WithNDBuiltinCache(ectx.ndBuiltinCache)
}
for i := range ectx.queryTracers {
q = q.WithQueryTracer(ectx.queryTracers[i])
}
if ectx.parsedInput != nil {
q = q.WithInput(ast.NewTerm(ectx.parsedInput))
}
for i := range ectx.resolvers {
q = q.WithResolver(ectx.resolvers[i].ref, ectx.resolvers[i].r)
}
// Cancel query if context is cancelled or deadline is reached.
c := topdown.NewCancel()
q = q.WithCancel(c)
// exit is closed when this function returns, which releases the watcher
// goroutine started below if ctx never finishes.
exit := make(chan struct{})
defer close(exit)
go waitForDone(ctx, exit, func() {
c.Cancel()
})
queries, support, err := q.PartialRun(ctx)
if err != nil {
return nil, err
}
// If the target rego-version is v0, and the rego.v1 import is available, then we attempt to apply it to support modules.
// A nil r.capabilities is treated the same as the feature being available.
if r.regoVersion == ast.RegoV0 && (r.capabilities == nil || r.capabilities.ContainsFeature(ast.FeatureRegoV1Import)) {
for i, mod := range support {
// We can't apply the RegoV0CompatV1 version to the support module if it contains rules or vars that
// conflict with future keywords.
applyRegoVersion := true
// NOTE: inside this callback r is the visited rule; it shadows the
// method receiver.
ast.WalkRules(mod, func(r *ast.Rule) bool {
name := r.Head.Name
// When the head has no plain name, fall back to the first term of
// its reference.
if name == "" && len(r.Head.Reference) > 0 {
name = r.Head.Reference[0].Value.(ast.Var)
}
if ast.IsFutureKeyword(name.String()) {
applyRegoVersion = false
return true
}
return false
})
// Variables are only inspected when no rule name already conflicted.
if applyRegoVersion {
ast.WalkVars(mod, func(v ast.Var) bool {
if ast.IsFutureKeyword(v.String()) {
applyRegoVersion = false
return true
}
return false
})
}
if applyRegoVersion {
support[i].SetRegoVersion(ast.RegoV0CompatV1)
} else {
support[i].SetRegoVersion(r.regoVersion)
}
}
} else {
// If the target rego-version is not v0, then we apply the target rego-version to the support modules.
for i := range support {
support[i].SetRegoVersion(r.regoVersion)
}
}
pq := &PartialQueries{
Queries: queries,
Support: support,
}
return pq, nil
}
// rewriteQueryToCaptureValue rewrites the expressions of query so that the
// value each one produces is bound to a generated variable, and records that
// variable in r.capture keyed by the expression.
//
// Negated expressions, assignments and equality expressions are left alone.
// When the query has more than one expression, or iteration(query) reports
// true, an extra generated expression consisting only of the captured
// variable is appended to the query for every rewritten expression.
func (r *Rego) rewriteQueryToCaptureValue(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) {
	appendCheck := iteration(query) || len(query) > 1

	for _, expr := range query {
		if expr.Negated || expr.IsAssignment() || expr.IsEquality() {
			continue
		}

		var captured *ast.Term

		// If the expression can be evaluated as a function, rewrite it to
		// capture the return value. E.g., neq(1,2) becomes neq(1,2,x) but
		// plus(1,2,x) does not get rewritten.
		switch operands := expr.Terms.(type) {
		case *ast.Term:
			// A bare term becomes `term = <generated var>`.
			captured = r.generateTermVar()
			expr.Terms = ast.Equality.Expr(operands, captured).Terms
			r.capture[expr] = captured.Value.(ast.Var)
		case []*ast.Term:
			fnType := r.compiler.TypeEnv.Get(operands[0])
			// Only calls that return a value and were written without an
			// explicit output argument get one appended.
			if types.Void(fnType) || types.Arity(fnType) != len(operands)-1 {
				break
			}
			captured = r.generateTermVar()
			expr.Terms = append(operands, captured)
			r.capture[expr] = captured.Value.(ast.Var)
		}

		if captured == nil || !appendCheck {
			continue
		}

		// Append a generated expression holding just the captured variable,
		// with any `with` modifiers dropped.
		check := expr.Copy()
		check.Terms = captured
		check.Generated = true
		check.With = nil
		query.Append(check)
	}

	return query, nil
}
// rewriteQueryForPartialEval checks that query is exactly one expression
// whose term is a ground reference, and returns a new body that unifies a
// wildcard with that reference. Any other query shape yields an error that
// describes what was found instead.
func (r *Rego) rewriteQueryForPartialEval(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) {
	if len(query) != 1 {
		return nil, fmt.Errorf("partial evaluation requires single ref (not multiple expressions)")
	}

	target, isTerm := query[0].Terms.(*ast.Term)
	if !isTerm {
		return nil, fmt.Errorf("partial evaluation requires ref (not expression)")
	}

	switch v := target.Value.(type) {
	case ast.Ref:
		if !v.IsGround() {
			return nil, fmt.Errorf("partial evaluation requires ground ref")
		}
	default:
		return nil, fmt.Errorf("partial evaluation requires ref (not %v)", ast.TypeName(target.Value))
	}

	return ast.NewBody(ast.Equality.Expr(ast.Wildcard, target)), nil
}
// rewriteEqualsForPartialQueryCompile replaces the `==` operator with `=` in
// every two-operand call expression found in query. This is not done for
// ordinary evaluation; it exists for queries handled through the `Partial`
// API, where the rewrite can substantially simplify the result and the caller
// is unlikely to need the expression values. The expressions are modified in
// place and the same query is returned.
func (r *Rego) rewriteEqualsForPartialQueryCompile(_ ast.QueryCompiler, query ast.Body) (ast.Body, error) {
	equalRef, unifyRef := ast.Equal.Ref(), ast.Equality.Ref()

	rewrite := func(expr *ast.Expr) bool {
		if expr.IsCall() && expr.Operator().Equal(equalRef) && len(expr.Operands()) == 2 {
			expr.SetOperator(ast.NewTerm(unifyRef))
		}
		return false
	}

	ast.WalkExprs(query, rewrite)
	return query, nil
}
// generateTermVar increments r.termVarID and returns a fresh variable term
// named "<prefix>term<id>". The prefix is wasmVarPrefix when a target plugin
// exists for r.target or when the target is wasm; otherwise it is
// ast.WildcardPrefix.
func (r *Rego) generateTermVar() *ast.Term {
	r.termVarID++

	useWasmPrefix := r.targetPlugin(r.target) != nil || r.target == targetWasm

	prefix := ast.WildcardPrefix
	if useWasmPrefix {
		prefix = wasmVarPrefix
	}

	return ast.VarTerm(fmt.Sprintf("%sterm%v", prefix, r.termVarID))
}
// hasQuery reports whether a query was supplied, either as a string
// (r.query) or already parsed (r.parsedQuery).
func (r Rego) hasQuery() bool {
	if len(r.query) > 0 {
		return true
	}
	return len(r.parsedQuery) > 0
}
// hasWasmModule reports whether at least one of the configured bundles
// carries a Wasm module.
func (r Rego) hasWasmModule() bool {
	for _, bndl := range r.bundles {
		if len(bndl.WasmModules) != 0 {
			return true
		}
	}
	return false
}
// transactionCloser is the function type returned by getTxn for finishing the
// transaction it handed out. The err argument is the outcome of the work done
// with the transaction; the returned error is the result of closing it.
type transactionCloser func(ctx context.Context, err error) error
// getTxn returns a transaction suitable for the configured Rego object along
// with a function that closes it. The closer should always be called,
// regardless of whether the work done with the transaction succeeded.
//
// If a transaction was provided externally (r.txn) it is returned together
// with a no-op closer. Otherwise a new transaction is opened on r.store; it is
// a write transaction when bundles, bundle paths or load paths are configured,
// which is only permitted when r.ownStore is set. The closer commits a write
// transaction when it is passed a nil error and aborts in every other case,
// then resets r.txn to nil.
func (r *Rego) getTxn(ctx context.Context) (storage.Transaction, transactionCloser, error) {
	nothingToClose := func(context.Context, error) error {
		return nil // no-op default
	}

	// Externally provided txn
	if r.txn != nil {
		return r.txn, nothingToClose, nil
	}

	// Bundles and data paths may require writing data files or manifests to storage
	needsWrite := len(r.bundles) > 0 || len(r.bundlePaths) > 0 || len(r.loadPaths.paths) > 0

	// If we were given a store we will *not* write to it, only do that on one
	// which was created automatically on behalf of the user.
	if needsWrite && !r.ownStore {
		return nil, nothingToClose, errors.New("unable to start write transaction when store was provided")
	}

	// Create a new transaction.
	txn, err := r.store.NewTransaction(ctx, storage.TransactionParams{Write: needsWrite})
	if err != nil {
		return nil, nothingToClose, err
	}

	// Setup a closer function that will abort or commit as needed.
	closeTxn := func(ctx context.Context, txnErr error) error {
		var closeErr error
		if txnErr != nil || !needsWrite {
			r.store.Abort(ctx, txn)
		} else {
			closeErr = r.store.Commit(ctx, txn)
		}

		// Clear the auto created transaction now that it is closed.
		r.txn = nil
		return closeErr
	}

	return txn, closeTxn, nil
}
// compilerForTxn returns r.compiler configured with a path conflict check
// built from the given context, store and transaction.
func (r *Rego) compilerForTxn(ctx context.Context, store storage.Store, txn storage.Transaction) *ast.Compiler {
	// The check must be rebuilt for the current context and transaction to
	// be valid.
	conflictCheck := storage.NonEmpty(ctx, store, txn)
	return r.compiler.WithPathConflictsCheck(conflictCheck)
}
// checkPartialResultForRecursiveRefs reports whether body contains a
// reference that is a prefix of path.
func checkPartialResultForRecursiveRefs(body ast.Body, path ast.Ref) bool {
	found := false

	ast.WalkRefs(body, func(ref ast.Ref) bool {
		// Once a match has been seen no further refs are compared.
		if found {
			return true
		}
		found = path.HasPrefix(ref)
		return found
	})

	return found
}
// isTermVar reports whether v carries the wildcard-prefixed "term" name that
// generateTermVar produces for non-wasm targets.
func isTermVar(v ast.Var) bool {
	marker := ast.WildcardPrefix + "term"
	return strings.HasPrefix(string(v), marker)
}
// isTermWasmVar reports whether v carries the wasmVarPrefix-prefixed "term"
// name that generateTermVar produces for wasm and plugin targets.
func isTermWasmVar(v ast.Var) bool {
	marker := wasmVarPrefix + "term"
	return strings.HasPrefix(string(v), marker)
}
// waitForDone blocks until either ctx is done or exit is closed (or receives
// a value). f is invoked only when ctx finishes first; if exit fires first
// the function returns without calling f.
func waitForDone(ctx context.Context, exit chan struct{}, f func()) {
	select {
	case <-ctx.Done():
		f()
	case <-exit:
		// Caller finished before the context did; nothing to do.
	}
}
// rawModule holds the unparsed source text of a module together with the
// filename it is parsed under.
type rawModule struct {
// filename is passed to the parser as the module's filename.
filename string
// module is the module source text handed to the parser.
module string
}
// Parse parses the module source with ast.ParseModule, using the stored
// filename, and returns the parsed module or the parser's error.
func (m rawModule) Parse() (*ast.Module, error) {
return ast.ParseModule(m.filename, m.module)
}
// ParseWithOpts parses the module source with ast.ParseModuleWithOpts, using
// the stored filename and the supplied parser options.
func (m rawModule) ParseWithOpts(opts ast.ParserOptions) (*ast.Module, error) {
return ast.ParseModuleWithOpts(m.filename, m.module, opts)
}
// extraStage pairs a query compiler stage definition with the name of another
// stage.
// NOTE(review): presumably `stage` is inserted into the query compiler
// pipeline after the stage named by `after` — confirm where extraStage values
// are consumed.
type extraStage struct {
// after is a stage name (see the note on the type).
after string
// stage is the query compiler stage definition.
stage ast.QueryCompilerStageDefinition
}
// refResolver associates a reference with a resolver. Each pair is handed to
// the topdown query via WithResolver(ref, r) when a query is built.
type refResolver struct {
// ref is the reference the resolver is registered for.
ref ast.Ref
// r is the resolver registered for ref.
r resolver.Resolver
}
// iteration reports whether x contains a reference that either names a
// built-in marked as a relation or has a variable in any position after its
// first term.
//
// The visitor returns true for comprehension terms without inspecting them
// further; presumably this keeps the walk from descending into them (confirm
// against the ast.GenericVisitor contract).
func iteration(x interface{}) bool {
	var found bool

	visit := func(node interface{}) bool {
		switch v := node.(type) {
		case *ast.Term:
			if ast.IsComprehension(v.Value) {
				return true
			}
		case ast.Ref:
			if found {
				return true
			}
			// A reference to a relation built-in counts as iteration.
			if bi := ast.BuiltinMap[v.String()]; bi != nil && bi.Relation {
				found = true
				return true
			}
			// So does any variable after the head of the reference.
			for i := 1; i < len(v); i++ {
				if _, isVar := v[i].Value.(ast.Var); isVar {
					found = true
					return true
				}
			}
			return false
		}
		return found
	}

	ast.NewGenericVisitor(visit).Walk(x)
	return found
}
// parseStringsToRefs parses every string in s as a reference and returns the
// references in the same order. The first parse failure aborts the conversion
// and is returned with a nil slice.
func parseStringsToRefs(s []string) ([]ast.Ref, error) {
	parsed := make([]ast.Ref, len(s))

	for idx, raw := range s {
		ref, err := ast.ParseRef(raw)
		if err != nil {
			return nil, err
		}
		parsed[idx] = ref
	}

	return parsed, nil
}
// finishFunction completes a built-in function call.
//
// When err is nil, iter is invoked with result if the result is defined
// (non-nil); an undefined result yields nil without calling iter. When err is
// non-nil it is wrapped in a topdown.Error with code BuiltinErr, a message
// prefixed by the function name, and the call location from bctx. If the
// error chain contains a *HaltError, the wrapped error is additionally
// returned inside a topdown.Halt.
func finishFunction(name string, bctx topdown.BuiltinContext, result *ast.Term, err error, iter func(*ast.Term) error) error {
	if err == nil {
		if result == nil {
			return nil
		}
		return iter(result)
	}

	// asBuiltinErr wraps cause in a topdown built-in error for this call.
	asBuiltinErr := func(cause error) error {
		tdErr := &topdown.Error{
			Code:     topdown.BuiltinErr,
			Message:  fmt.Sprintf("%v: %v", name, cause.Error()),
			Location: bctx.Location,
		}
		return tdErr.Wrap(cause)
	}

	var halt *HaltError
	if errors.As(err, &halt) {
		// The HaltError found in the chain is what gets reported and wrapped.
		return topdown.Halt{Err: asBuiltinErr(halt)}
	}

	return asBuiltinErr(err)
}
// newFunction returns an option that registers a custom built-in function on
// a Rego object: the declaration is stored in r.builtinDecls and the
// implementation f, linked to that same declaration, in r.builtinFuncs, both
// keyed by decl.Name.
func newFunction(decl *Function, f topdown.BuiltinFunc) func(*Rego) {
	return func(r *Rego) {
		name := decl.Name

		builtin := &ast.Builtin{
			Name:             name,
			Decl:             decl.Decl,
			Nondeterministic: decl.Nondeterministic,
		}

		r.builtinDecls[name] = builtin
		r.builtinFuncs[name] = &topdown.Builtin{
			Decl: builtin,
			Func: f,
		}
	}
}
// generateJSON converts the value of term to its JSON representation via
// ast.JSONWithOpt, passing along the sortSets and copyMaps settings from
// ectx.
func generateJSON(term *ast.Term, ectx *EvalContext) (interface{}, error) {
	opts := ast.JSONOpt{
		SortSets: ectx.sortSets,
		CopyMaps: ectx.copyMaps,
	}
	return ast.JSONWithOpt(term.Value, opts)
}
// planQuery runs the planner over queries and all modules currently held by
// r.compiler, and returns the resulting IR policy.
//
// The built-in declarations given to the planner are ast.BuiltinMap overlaid
// with r.builtinDecls, so a custom declaration replaces a default one of the
// same name. The rewritten variables come from the query compiler stored for
// evalQueryType. When r.dump is set, a pretty-printed copy of the plan is
// written to it; an error from the pretty-printer is returned.
func (r *Rego) planQuery(queries []ast.Body, evalQueryType queryType) (*ir.Policy, error) {
	const queryName = "eval" // NOTE(tsandall): the query name is arbitrary

	mods := make([]*ast.Module, 0, len(r.compiler.Modules))
	for _, m := range r.compiler.Modules {
		mods = append(mods, m)
	}

	// Defaults first, then custom declarations so they take precedence.
	allDecls := make(map[string]*ast.Builtin, len(r.builtinDecls)+len(ast.BuiltinMap))
	for name, decl := range ast.BuiltinMap {
		allDecls[name] = decl
	}
	for name, decl := range r.builtinDecls {
		allDecls[name] = decl
	}

	querySets := []planner.QuerySet{
		{
			Name:          queryName,
			Queries:       queries,
			RewrittenVars: r.compiledQueries[evalQueryType].compiler.RewrittenVars(),
		},
	}

	policy, err := planner.New().
		WithQueries(querySets).
		WithModules(mods).
		WithBuiltinDecls(allDecls).
		WithDebug(r.dump).
		Plan()
	if err != nil {
		return nil, err
	}

	if r.dump == nil {
		return policy, nil
	}

	fmt.Fprintln(r.dump, "PLAN:")
	fmt.Fprintln(r.dump, "-----")
	if err := ir.Pretty(r.dump, policy); err != nil {
		return nil, err
	}
	fmt.Fprintln(r.dump)

	return policy, nil
}