Files
kubesphere/vendor/github.com/open-policy-agent/opa/topdown/eval.go
KubeSphere CI Bot 447a51f08b feat: kubesphere 4.0 (#6115)
* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

---------

Signed-off-by: ci-bot <ci-bot@kubesphere.io>
Co-authored-by: ks-ci-bot <ks-ci-bot@example.com>
Co-authored-by: joyceliu <joyceliu@yunify.com>
2024-09-06 11:05:52 +08:00

3682 lines
87 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package topdown
import (
"context"
"fmt"
"io"
"sort"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/copypropagation"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/tracing"
"github.com/open-policy-agent/opa/types"
)
// evalIterator is the continuation invoked with an evaluator each time
// evaluation of a query body produces a result.
type evalIterator func(*eval) error

// unifyIterator is the continuation invoked each time a unification (or a
// function/built-in call) succeeds under the current bindings.
type unifyIterator func() error

// unifyRefIterator is a continuation that receives an integer position.
// NOTE(review): presumably a position within a reference being unified; its
// callers are not visible in this section — confirm.
type unifyRefIterator func(pos int) error
// queryIDFactory hands out sequential query identifiers.
type queryIDFactory struct {
	curr uint64
}

// Next returns the identifier currently held by the factory and advances the
// counter by one, so the very first call returns 0.
func (f *queryIDFactory) Next() uint64 {
	id := f.curr
	f.curr = id + 1
	return id
}
// builtinErrors holds a list of errors; the evaluator keeps a pointer to one
// in its builtinErrors field.
// NOTE(review): presumably collects errors raised by built-in calls for later
// reporting — the code that appends to errs is outside this section; confirm.
type builtinErrors struct {
	errs []error
}
// earlyExitError aborts iteration once early exit is possible. It records the
// evaluator that raised it and, when re-raised by an enclosing evaluator, the
// previous early-exit error it wraps.
type earlyExitError struct {
	prev error
	e    *eval
}

// Error implements the error interface. The message names the query of the
// evaluator that triggered the early exit.
func (ee *earlyExitError) Error() string {
	return fmt.Sprint(ee.e.query) + ": early exit"
}
// eval holds the state of a single query evaluation. Nested evaluators are
// created by copying this struct (see closure and child), so most fields are
// shared with the parent unless explicitly reset.
type eval struct {
	// Values forwarded to built-in functions through BuiltinContext (evalCall).
	ctx     context.Context
	metrics metrics.Metrics
	seed    io.Reader
	time    *ast.Term

	// Query identity and evaluator hierarchy. closure/child draw a fresh
	// queryID from queryIDFact and set parent to the spawning evaluator.
	// caller's bindings are used when plugging/namespacing saved results
	// during partial evaluation. cancel is checked before each expression.
	queryID     uint64
	queryIDFact *queryIDFactory
	parent      *eval
	caller      *eval
	cancel      Cancel

	// The body being evaluated and the position of the current expression.
	// indexing selects rule-index lookup vs. all rules in getRules; earlyExit
	// gates the EarlyExit flag of index results.
	query         ast.Body
	queryCompiler ast.QueryCompiler
	index         int
	indexing      bool
	earlyExit     bool

	// Variable bindings plus the documents and compiler evaluation runs
	// against. input and data are temporarily replaced by `with` modifiers
	// (evalWithPush/evalWithPop), which also push their targets on targetStack.
	// NOTE(review): store, baseCache, txn and external are not used in this
	// section; their roles are inferred from names only.
	bindings    *bindings
	store       storage.Store
	baseCache   *baseCache
	txn         storage.Transaction
	compiler    *ast.Compiler
	input       *ast.Term
	data        *ast.Term
	external    *resolverTrie
	targetStack *refStack

	// Tracing state (see traceEvent) and instrumentation timers/counters.
	tracers           []QueryTracer
	traceEnabled      bool
	traceLastLocation *ast.Location // Last location of a trace event.
	plugTraceVars     bool
	instr             *Instrumentation

	// Built-ins and caches. builtins holds custom built-ins consulted for
	// names not in ast.BuiltinMap; functionMocks holds `with` replacements;
	// the comprehension/virtual caches are pushed and popped per `with`.
	builtins               map[string]*Builtin
	builtinCache           builtins.Cache
	ndBuiltinCache         builtins.NDBCache
	functionMocks          *functionMocksStack
	virtualCache           *virtualCache
	comprehensionCache     *comprehensionCache
	interQueryBuiltinCache cache.InterQueryCache

	// Partial-evaluation state. A non-nil saveSet enables partial evaluation;
	// saveStack collects saved expressions and saveSupport collects support
	// rules placed under saveNamespace.
	// NOTE(review): skipSaveNamespace, genvarprefix and genvarid are not used
	// in this section — presumably consumed by generateVar and callers.
	saveSet           *saveSet
	saveStack         *saveStack
	saveSupport       *saveSupport
	saveNamespace     *ast.Term
	skipSaveNamespace bool
	inliningControl   *inliningControl
	genvarprefix      string
	genvarid          int

	// Miscellaneous options. findOne makes evalExpr stop after the first
	// result (outside partial evaluation).
	// NOTE(review): strictObjects is not used in this section.
	runtime       *ast.Term
	builtinErrors *builtinErrors
	printHook     print.Hook
	tracingOpts   tracing.Options
	findOne       bool
	strictObjects bool
}
// Run evaluates e's query, tracing Enter once up front. For each result it
// traces Exit, calls iter, then traces Redo before resuming the search. The
// first error returned by iter (or by evaluation) is returned.
func (e *eval) Run(iter evalIterator) error {
	e.traceEnter(e.query)
	onResult := func(cur *eval) error {
		cur.traceExit(cur.query)
		result := iter(cur)
		cur.traceRedo(cur.query)
		return result
	}
	return e.eval(onResult)
}
// String renders the evaluator and its chain of parents for debugging.
func (e *eval) String() string {
	var sb strings.Builder
	e.string(&sb)
	return sb.String()
}

// string writes "<query: Q index: I findOne: F" for e, then (space-separated
// and nested) the rendering of its parent, and finally the closing '>'.
func (e *eval) string(s *strings.Builder) {
	fmt.Fprintf(s, "<query: %v index: %d findOne: %v", e.query, e.index, e.findOne)
	if p := e.parent; p != nil {
		s.WriteByte(' ')
		p.string(s)
	}
	s.WriteByte('>')
}
// builtinFunc resolves name to a built-in declaration and implementation.
// Names declared in ast.BuiltinMap are resolved only against the registered
// builtinFunctions table. Any other name is looked up among the evaluator's
// custom builtins. The boolean result reports whether a match was found.
func (e *eval) builtinFunc(name string) (*ast.Builtin, BuiltinFunc, bool) {
	if decl, known := ast.BuiltinMap[name]; known {
		if impl, found := builtinFunctions[name]; found {
			return decl, impl, true
		}
		return nil, nil, false
	}
	if custom, found := e.builtins[name]; found {
		return custom.Decl, custom.Func, true
	}
	return nil, nil, false
}
// closure returns a copy of e that evaluates query from its first expression.
// The copy has a fresh query ID, e as its parent and findOne cleared. It
// shares e's bindings, so bindings made inside are visible to e.
func (e *eval) closure(query ast.Body) *eval {
	c := *e
	c.query, c.index = query, 0
	c.queryID = c.queryIDFact.Next()
	c.parent, c.findOne = e, false
	return &c
}

// child is like closure but gives the new evaluator its own, empty bindings
// keyed by its fresh query ID.
func (e *eval) child(query ast.Body) *eval {
	c := e.closure(query)
	c.bindings = newBindings(c.queryID, e.instr)
	return c
}
// next advances to the following expression of the query, evaluates from
// there, and restores the index afterwards so backtracking resumes here.
func (e *eval) next(iter evalIterator) error {
	e.index++
	result := e.evalExpr(iter)
	e.index--
	return result
}
// partial reports whether partial evaluation is enabled (a save set exists).
func (e *eval) partial() bool {
	return e.saveSet != nil
}

// unknown reports whether x must be saved rather than evaluated under b. It
// is always false outside partial evaluation.
func (e *eval) unknown(x interface{}, b *bindings) bool {
	if !e.partial() {
		return false
	}
	// saveRequired expects terms, so a bare ast.Value (e.g. an ast.Ref) is
	// wrapped before being checked.
	subject := x
	if v, isValue := x.(ast.Value); isValue {
		subject = ast.NewTerm(v)
	}
	return saveRequired(e.compiler, e.inliningControl, true, e.saveSet, b, subject, false)
}
// traceEnter emits an EnterOp event for x.
func (e *eval) traceEnter(x ast.Node) { e.traceEvent(EnterOp, x, "", nil) }

// traceExit emits an ExitOp event for x, tagged "early" when the evaluator
// stops after the first result.
func (e *eval) traceExit(x ast.Node) {
	msg := ""
	if e.findOne {
		msg = "early"
	}
	e.traceEvent(ExitOp, x, msg, nil)
}

// traceEval emits an EvalOp event for x.
func (e *eval) traceEval(x ast.Node) { e.traceEvent(EvalOp, x, "", nil) }

// traceDuplicate emits a DuplicateOp event for x.
func (e *eval) traceDuplicate(x ast.Node) { e.traceEvent(DuplicateOp, x, "", nil) }

// traceFail emits a FailOp event for x.
func (e *eval) traceFail(x ast.Node) { e.traceEvent(FailOp, x, "", nil) }

// traceRedo emits a RedoOp event for x.
func (e *eval) traceRedo(x ast.Node) { e.traceEvent(RedoOp, x, "", nil) }

// traceSave emits a SaveOp event for x.
func (e *eval) traceSave(x ast.Node) { e.traceEvent(SaveOp, x, "", nil) }

// traceIndex emits an IndexOp event for x with msg and the indexed ref.
func (e *eval) traceIndex(x ast.Node, msg string, target *ast.Ref) {
	e.traceEvent(IndexOp, x, msg, target)
}

// traceWasm emits a WasmOp event for x with the target ref.
func (e *eval) traceWasm(x ast.Node, target *ast.Ref) {
	e.traceEvent(WasmOp, x, "", target)
}
// traceEvent builds an Event for op on node x and delivers it to every
// configured tracer. It does nothing when tracing is disabled. When any
// tracer asked for variable plugging (plugTraceVars), the event also carries
// the current bindings (Locals) and rewritten-name metadata (LocalMetadata).
func (e *eval) traceEvent(op Op, x ast.Node, msg string, target *ast.Ref) {
	if !e.traceEnabled {
		return
	}

	parentID := uint64(0)
	if p := e.parent; p != nil {
		parentID = p.queryID
	}

	// Nodes without a location of their own inherit the last one seen.
	loc := x.Loc()
	if loc != nil {
		e.traceLastLocation = loc
	} else {
		loc = e.traceLastLocation
	}

	evt := Event{
		QueryID:  e.queryID,
		ParentID: parentID,
		Op:       op,
		Node:     x,
		Location: loc,
		Message:  msg,
		Ref:      target,
		input:    e.input,
		bindings: e.bindings,
	}

	if e.plugTraceVars {
		locals := ast.NewValueMap()
		meta := map[ast.Var]VarMetadata{}

		// Record every bound variable. The values are kept as well for
		// backwards compatibility. The iterator callback never fails.
		_ = e.bindings.Iter(nil, func(k, v *ast.Term) error {
			orig := k.Value.(ast.Var)
			name, _ := e.rewrittenVar(orig)
			meta[orig] = VarMetadata{Name: name, Location: k.Loc()}
			locals.Put(k.Value, v.Value)
			return nil
		})

		// Add metadata for unbound but rewritten variables appearing in x.
		ast.WalkTerms(x, func(term *ast.Term) bool {
			v, isVar := term.Value.(ast.Var)
			if !isVar {
				return false
			}
			if _, seen := meta[v]; seen {
				return false
			}
			if name, ok := e.rewrittenVar(v); ok {
				meta[v] = VarMetadata{Name: name, Location: term.Loc()}
			}
			return false
		})

		evt.Locals = locals
		evt.LocalMetadata = meta
	}

	for _, tracer := range e.tracers {
		tracer.TraceEvent(evt)
	}
}
// eval evaluates the query starting at the current expression index.
func (e *eval) eval(iter evalIterator) error {
	return e.evalExpr(iter)
}

// evalExpr evaluates the expression at e.index and, through next, everything
// after it. When the end of the body is reached, iter is called with e.
//
// With findOne set (and outside partial evaluation), an *earlyExitError is
// returned after the first result so enclosing evaluators stop searching.
// An early-exit error coming back from iter is swallowed unless findOne is
// set, in which case it is re-wrapped for this evaluator.
func (e *eval) evalExpr(iter evalIterator) error {
	if e.cancel != nil && e.cancel.Cancelled() {
		return &Error{
			Code:    CancelErr,
			Message: "caller cancelled query execution",
		}
	}

	if e.index < len(e.query) {
		expr := e.query[e.index]
		e.traceEval(expr)
		if len(expr.With) > 0 {
			return e.evalWith(iter)
		}
		return e.evalStep(func(e *eval) error {
			return e.next(iter)
		})
	}

	// End of the body: every expression has been satisfied.
	if err := iter(e); err != nil {
		prev, isEarly := err.(*earlyExitError)
		switch {
		case !isEarly:
			return err
		case !e.findOne:
			return nil
		default:
			return &earlyExitError{prev: prev, e: e}
		}
	}
	if e.findOne && !e.partial() { // we've found one!
		return &earlyExitError{e: e}
	}
	return nil
}
// evalStep evaluates the current expression (ignoring its with modifiers,
// which evalWith handles) and calls iter for each way it can be satisfied.
// Negated expressions are delegated to evalNot. If the expression is never
// satisfied, a Fail event is traced.
func (e *eval) evalStep(iter evalIterator) error {
	expr := e.query[e.index]
	if expr.Negated {
		return e.evalNot(iter)
	}

	var defined bool

	// succeed records that the expression produced a result, continues the
	// evaluation, and traces Redo when control backtracks into this step.
	succeed := func() error {
		defined = true
		err := iter(e)
		e.traceRedo(expr)
		return err
	}

	var err error
	switch terms := expr.Terms.(type) {
	case []*ast.Term:
		if expr.IsEquality() {
			err = e.unify(terms[1], terms[2], succeed)
		} else {
			err = e.evalCall(terms, succeed)
		}
	case *ast.Term:
		// A bare term is satisfied whenever its value is not false.
		rterm := e.generateVar(fmt.Sprintf("term_%d_%d", e.queryID, e.index))
		err = e.unify(terms, rterm, func() error {
			if e.saveSet.Contains(rterm, e.bindings) {
				return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error {
					return iter(e)
				})
			}
			if e.bindings.Plug(rterm).Equal(ast.BooleanTerm(false)) {
				return nil
			}
			return succeed()
		})
	case *ast.Every:
		every := evalEvery{
			e:    e,
			expr: expr,
			generator: ast.NewBody(
				ast.Equality.Expr(
					ast.RefTerm(terms.Domain, terms.Key).SetLocation(terms.Domain.Location),
					terms.Value,
				).SetLocation(terms.Domain.Location),
			),
			body: terms.Body,
		}
		err = every.eval(succeed)
	default: // guard-rail for adding extra (Expr).Terms types
		return fmt.Errorf("got %T terms: %[1]v", terms)
	}

	if err != nil {
		return err
	}
	if !defined {
		e.traceFail(expr)
	}
	return nil
}
// evalNot evaluates a negated expression. It evaluates the complement (without
// with modifiers) in a closure and continues with iter only if the complement
// has no result. During partial evaluation, expressions that depend on
// unknowns are handed to evalNotPartial instead.
func (e *eval) evalNot(iter evalIterator) error {
	expr := e.query[e.index]
	if e.unknown(expr, e.bindings) {
		return e.evalNotPartial(iter)
	}

	negation := ast.NewBody(expr.Complement().NoWith())
	child := e.closure(negation)

	found := false
	child.traceEnter(negation)
	err := child.eval(func(*eval) error {
		child.traceExit(negation)
		found = true
		child.traceRedo(negation)
		return nil
	})

	switch {
	case err != nil:
		return err
	case found:
		e.traceFail(expr)
		return nil
	default:
		return iter(e)
	}
}
// evalWith evaluates the current expression under its `with` modifiers.
// Each modifier is classified by target as a function mock (data function or
// replaced built-in), an input override or a data override. Input/data
// overrides are merged into the current documents; any merge error is
// reported as a ConflictErr at the expression's location. During partial
// evaluation the expression is saved as-is when an input/data override's
// value depends on unknowns.
func (e *eval) evalWith(iter evalIterator) error {
	expr := e.query[e.index]
	// Disable inlining on all references in the expression so the result of
	// partial evaluation has the same semantics w/ the with statements
	// preserved.
	var disable []ast.Ref
	disableRef := func(x ast.Ref) bool {
		disable = append(disable, x.GroundPrefix())
		return false
	}
	if e.partial() {
		// If the value is unknown the with statement cannot be evaluated and so
		// the entire expression should be saved to be safe. In the future this
		// could be relaxed in certain cases (e.g., if the with statement would
		// have no effect.)
		for _, with := range expr.With {
			if isFunction(e.compiler.TypeEnv, with.Target) || // non-builtin function replaced
				isOtherRef(with.Target) { // built-in replaced
				ast.WalkRefs(with.Value, disableRef)
				continue
			}
			// with target is data or input (not built-in)
			if e.saveSet.ContainsRecursive(with.Value, e.bindings) {
				return e.saveExprMarkUnknowns(expr, e.bindings, func() error {
					return e.next(iter)
				})
			}
			ast.WalkRefs(with.Target, disableRef)
			ast.WalkRefs(with.Value, disableRef)
		}
		ast.WalkRefs(expr.NoWith(), disableRef)
	}
	// Classify each modifier. Values are plugged with the current bindings
	// before being installed.
	pairsInput := [][2]*ast.Term{}
	pairsData := [][2]*ast.Term{}
	functionMocks := [][2]*ast.Term{}
	targets := []ast.Ref{}
	for i := range expr.With {
		target := expr.With[i].Target
		plugged := e.bindings.Plug(expr.With[i].Value)
		switch {
		// NOTE(sr): ordering matters here: isFunction's ref is also covered by isDataRef
		case isFunction(e.compiler.TypeEnv, target):
			functionMocks = append(functionMocks, [...]*ast.Term{target, plugged})
		case isInputRef(target):
			pairsInput = append(pairsInput, [...]*ast.Term{target, plugged})
		case isDataRef(target):
			pairsData = append(pairsData, [...]*ast.Term{target, plugged})
		default: // target must be builtin
			// Unknown built-in names fall through and are recorded as targets.
			if _, _, ok := e.builtinFunc(target.String()); ok {
				functionMocks = append(functionMocks, [...]*ast.Term{target, plugged})
				continue // don't append to disabled targets below
			}
		}
		targets = append(targets, target.Value.(ast.Ref))
	}
	input, err := mergeTermWithValues(e.input, pairsInput)
	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}
	data, err := mergeTermWithValues(e.data, pairsData)
	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}
	// The overrides apply only to this expression: they are popped before
	// continuing with the rest of the query and re-pushed afterwards, so that
	// backtracking into evalStep sees them again. The final pop restores the
	// original state.
	oldInput, oldData := e.evalWithPush(input, data, functionMocks, targets, disable)
	err = e.evalStep(func(e *eval) error {
		e.evalWithPop(oldInput, oldData)
		err := e.next(iter)
		oldInput, oldData = e.evalWithPush(input, data, functionMocks, targets, disable)
		return err
	})
	e.evalWithPop(oldInput, oldData)
	return err
}
// evalWithPush installs the overrides of a `with` expression. A non-nil input
// or data replaces the evaluator's document, and the previous value is
// returned; nil leaves the document untouched and yields nil. It then pushes
// a new frame on the comprehension and virtual caches, records targets,
// disables inlining for the refs in disable, and installs the function mocks.
func (e *eval) evalWithPush(input, data *ast.Term, functionMocks [][2]*ast.Term, targets, disable []ast.Ref) (*ast.Term, *ast.Term) {
	var prevInput, prevData *ast.Term
	if input != nil {
		prevInput, e.input = e.input, input
	}
	if data != nil {
		prevData, e.data = e.data, data
	}

	e.comprehensionCache.Push()
	e.virtualCache.Push()
	e.targetStack.Push(targets)
	e.inliningControl.PushDisable(disable, true)
	e.functionMocks.PutPairs(functionMocks)

	return prevInput, prevData
}

// evalWithPop undoes evalWithPush. It pops every pushed frame and then assigns
// input and data back to the evaluator unconditionally. These are normally the
// values evalWithPush returned.
func (e *eval) evalWithPop(input, data *ast.Term) {
	e.inliningControl.PopDisable()
	e.targetStack.Pop()
	e.virtualCache.Pop()
	e.comprehensionCache.Pop()
	e.functionMocks.PopPairs()
	e.data, e.input = data, input
}
// evalNotPartial handles a negated expression that depends on unknowns during
// partial evaluation. The complement is partially evaluated into a set of
// saved queries. If there are none, the negation always succeeds. Otherwise
// the negated result is either inlined as the complemented cross product of
// the queries or, when inlining is not possible, saved as a reference to
// generated support rules (evalNotPartialSupport).
func (e *eval) evalNotPartial(iter evalIterator) error {
	// Prepare query normally.
	expr := e.query[e.index]
	negation := expr.Complement().NoWith()
	child := e.closure(ast.NewBody(negation))
	// Unknowns is the set of variables that are marked as unknown. The variables
	// are namespaced with the query ID that they originate in. This ensures that
	// variables across two or more queries are identified uniquely.
	//
	// NOTE(tsandall): this is greedy in the sense that we only need variable
	// dependencies of the negation.
	unknowns := e.saveSet.Vars(e.caller.bindings)
	// Run partial evaluation. Since the result may require support, push a new
	// query onto the save stack to avoid mutating the current save query. If
	// shallow inlining is not enabled, run copy propagation to further simplify
	// the result.
	var cp *copypropagation.CopyPropagator
	if !e.inliningControl.shallow {
		cp = copypropagation.New(unknowns).WithEnsureNonEmptyBody(true).WithCompiler(e.compiler)
	}
	var savedQueries []ast.Body
	e.saveStack.PushQuery(nil)
	_ = child.eval(func(*eval) error {
		query := e.saveStack.Peek()
		plugged := query.Plug(e.caller.bindings)
		// Skip this rule body if it fails to type-check.
		// Type-checking failure means the rule body will never succeed.
		if !e.compiler.PassesTypeCheck(plugged) {
			return nil
		}
		if cp != nil {
			plugged = applyCopyPropagation(cp, e.instr, plugged)
		}
		savedQueries = append(savedQueries, plugged)
		return nil
	}) // cannot return error
	e.saveStack.PopQuery()
	// If partial evaluation produced no results, the expression is always undefined
	// so it does not have to be saved.
	if len(savedQueries) == 0 {
		return iter(e)
	}
	// Check if the partial evaluation result can be inlined in this query. If not,
	// generate support rules for the result. Depending on the size of the partial
	// evaluation result and the contents, it may or may not be inlinable. We treat
	// the unknowns as safe because vars in the save set will either be known to
	// the caller or made safe by an expression on the save stack.
	if !canInlineNegation(unknowns, savedQueries) {
		return e.evalNotPartialSupport(child.queryID, expr, unknowns, savedQueries, iter)
	}
	// If we can inline the result, we have to generate the cross product of the
	// queries. For example:
	//
	//	(A && B) || (C && D)
	//
	// Becomes:
	//
	//	(!A && !C) || (!A && !D) || (!B && !C) || (!B && !D)
	return complementedCartesianProduct(savedQueries, 0, nil, func(q ast.Body) error {
		return e.saveInlinedNegatedExprs(q, func() error {
			return iter(e)
		})
	})
}
// evalNotPartialSupport saves a negation that could not be inlined. Each query
// becomes the body of a generated support rule named
// "__not<queryID>_<index>_<negationID>__" under the save namespace. The
// unknowns that actually occur in the queries become the rule's arguments,
// sorted for determinism. A negated reference (or call) to that rule is then
// saved in place of the original expression.
func (e *eval) evalNotPartialSupport(negationID uint64, expr *ast.Expr, unknowns ast.VarSet, queries []ast.Body, iter evalIterator) error {
	name := fmt.Sprintf("__not%d_%d_%d__", e.queryID, e.index, negationID)
	ruleRef := ast.RefTerm(ast.DefaultRootDocument, e.saveNamespace, ast.StringTerm(name))
	path := ruleRef.Value.(ast.Ref)
	head := ast.NewHead(ast.Var(name), nil, ast.BooleanTerm(true))

	// Only unknowns used by some saved query need to be passed in.
	used := ast.NewVarSet()
	for _, q := range queries {
		used.Update(q.Vars(ast.VarVisitorParams{}))
	}
	params := unknowns.Intersect(used)

	args := make([]*ast.Term, 0, len(params))
	for v := range params {
		args = append(args, ast.NewTerm(v))
	}
	sort.Slice(args, func(i, j int) bool {
		return args[i].Value.Compare(args[j].Value) < 0
	})
	if len(args) > 0 {
		head.Args = args
	}

	for _, q := range queries {
		e.saveSupport.Insert(path, &ast.Rule{Head: head, Body: q})
	}

	call := expr.CopyWithoutTerms()
	if len(args) == 0 {
		call.Terms = ruleRef
	} else {
		call.Terms = append([]*ast.Term{ruleRef}, args...)
	}
	return e.saveInlinedNegatedExprs([]*ast.Expr{call}, func() error {
		return e.next(iter)
	})
}
// evalCall evaluates a call expression and invokes iter for each successful
// result. terms[0] is the function reference and the remaining terms are the
// operands, optionally followed by an output term.
//
// Resolution order:
//  1. If a `with` mock for the ref is itself a function ref, that function is
//     called instead. An extra functionMocks frame is pushed around the call
//     and popped while iter runs.
//  2. Refs under data (user-defined functions) are either answered by a value
//     mock or evaluated via evalFunc against the indexed rules.
//  3. Anything else must be a known built-in. A value mock replaces the call,
//     unknown inputs during partial evaluation save the call, and otherwise
//     the built-in runs with a BuiltinContext built from this evaluator.
func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
	ref := terms[0].Value.(ast.Ref)

	mock, mocked := e.functionMocks.Get(ref)
	if mocked {
		if m, ok := mock.Value.(ast.Ref); ok && isFunction(e.compiler.TypeEnv, m) { // builtin or data function
			mockCall := append([]*ast.Term{ast.NewTerm(m)}, terms[1:]...)

			// NOTE(review): the Push/Pop pairs presumably keep the mock from
			// applying inside its own replacement; iter runs with the outer
			// mock frames in place.
			e.functionMocks.Push()
			err := e.evalCall(mockCall, func() error {
				e.functionMocks.Pop()
				err := iter()
				e.functionMocks.Push()
				return err
			})
			e.functionMocks.Pop()
			return err
		}
	}

	// 'mocked' true now indicates that the replacement is a value: if
	// it was a ref to a function, we'd have called that above.

	if ref[0].Equal(ast.DefaultRootDocument) {
		if mocked {
			f := e.compiler.TypeEnv.Get(ref).(*types.Function)
			return e.evalCallValue(len(f.FuncArgs().Args), terms, mock, iter)
		}

		// During partial evaluation the operands may be unknown, so the
		// index is not narrowed by them.
		var ir *ast.IndexResult
		var err error
		if e.partial() {
			ir, err = e.getRules(ref, nil)
		} else {
			ir, err = e.getRules(ref, terms[1:])
		}
		if err != nil {
			return err
		}

		eval := evalFunc{
			e:     e,
			ref:   ref,
			terms: terms,
			ir:    ir,
		}
		return eval.eval(iter)
	}

	builtinName := ref.String()
	bi, f, ok := e.builtinFunc(builtinName)
	if !ok {
		return unsupportedBuiltinErr(e.query[e.index].Location)
	}

	if mocked { // value replacement of built-in call
		return e.evalCallValue(len(bi.Decl.Args()), terms, mock, iter)
	}

	if e.unknown(e.query[e.index], e.bindings) {
		return e.saveCall(len(bi.Decl.Args()), terms, iter)
	}

	var parentID uint64
	if e.parent != nil {
		parentID = e.parent.queryID
	}

	var capabilities *ast.Capabilities
	if e.compiler != nil {
		capabilities = e.compiler.Capabilities()
	}

	bctx := BuiltinContext{
		Context:                e.ctx,
		Metrics:                e.metrics,
		Seed:                   e.seed,
		Time:                   e.time,
		Cancel:                 e.cancel,
		Runtime:                e.runtime,
		Cache:                  e.builtinCache,
		InterQueryBuiltinCache: e.interQueryBuiltinCache,
		NDBuiltinCache:         e.ndBuiltinCache,
		Location:               e.query[e.index].Location,
		QueryTracers:           e.tracers,
		TraceEnabled:           e.traceEnabled,
		QueryID:                e.queryID,
		ParentID:               parentID,
		PrintHook:              e.printHook,
		DistributedTracingOpts: e.tracingOpts,
		Capabilities:           capabilities,
	}

	eval := evalBuiltin{
		e:     e,
		bi:    bi,
		bctx:  bctx,
		f:     f,
		terms: terms[1:],
	}
	return eval.eval(iter)
}
// evalCallValue answers a call whose function was replaced by a plain value
// via `with`. arity is the declared number of arguments. If the call site
// captures an output (terms has arity+2 entries), the mock value is unified
// with that output term. With arity+1 entries, the call succeeds unless the
// mock is false. Any other shape is a programming error and panics.
func (e *eval) evalCallValue(arity int, terms []*ast.Term, mock *ast.Term, iter unifyIterator) error {
	switch len(terms) - arity {
	case 2: // captured var
		return e.unify(terms[len(terms)-1], mock, iter)
	case 1:
		if mock.Value.Compare(ast.Boolean(false)) == 0 {
			return nil
		}
		return iter()
	}
	panic("unreachable")
}
// unify unifies a and b under the evaluator's own bindings.
func (e *eval) unify(a, b *ast.Term, iter unifyIterator) error {
	return e.biunify(a, b, e.bindings, e.bindings, iter)
}

// biunify unifies a (under bindings b1) with b (under bindings b2) and calls
// iter for each successful unification. Both terms are first resolved through
// their bindings, and a UnifyOp event is traced when tracing is enabled. The
// dispatch is on the dynamic type of a:
//   - vars, refs and comprehensions always go to biunifyValues;
//   - scalars (null, boolean, number, string) unify only with a var, a ref, or
//     the same scalar type;
//   - arrays and objects unify element-wise with their own kind, and via
//     biunifyValues with vars, refs or the matching comprehension kind;
//   - sets always go to biunifyValues.
//
// Any other combination fails silently: nil is returned without calling iter.
func (e *eval) biunify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	a, b1 = b1.apply(a)
	b, b2 = b2.apply(b)
	if e.traceEnabled {
		e.traceEvent(UnifyOp, ast.Equality.Expr(a, b), "", nil)
	}
	switch vA := a.Value.(type) {
	case ast.Var, ast.Ref, *ast.ArrayComprehension, *ast.SetComprehension, *ast.ObjectComprehension:
		return e.biunifyValues(a, b, b1, b2, iter)
	case ast.Null:
		switch b.Value.(type) {
		case ast.Var, ast.Null, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Boolean:
		switch b.Value.(type) {
		case ast.Var, ast.Boolean, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Number:
		switch b.Value.(type) {
		case ast.Var, ast.Number, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.String:
		switch b.Value.(type) {
		case ast.Var, ast.String, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case *ast.Array:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ArrayComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case *ast.Array:
			return e.biunifyArrays(vA, vB, b1, b2, iter)
		}
	case ast.Object:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ObjectComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case ast.Object:
			return e.biunifyObjects(vA, vB, b1, b2, iter)
		}
	case ast.Set:
		return e.biunifyValues(a, b, b1, b2, iter)
	}
	// Type mismatch: the terms cannot unify.
	return nil
}
// biunifyArrays unifies two arrays element by element. Arrays of different
// lengths never unify.
func (e *eval) biunifyArrays(a, b *ast.Array, b1, b2 *bindings, iter unifyIterator) error {
	if a.Len() == b.Len() {
		return e.biunifyArraysRec(a, b, b1, b2, iter, 0)
	}
	return nil
}

// biunifyArraysRec unifies the elements at idx and, for each success, recurses
// to idx+1. Once every element has been unified, iter is called.
func (e *eval) biunifyArraysRec(a, b *ast.Array, b1, b2 *bindings, iter unifyIterator, idx int) error {
	if idx != a.Len() {
		rest := func() error {
			return e.biunifyArraysRec(a, b, b1, b2, iter, idx+1)
		}
		return e.biunify(a.Elem(idx), b.Elem(idx), b1, b2, rest)
	}
	return iter()
}
// biunifyObjects unifies two objects key by key. Objects of different sizes
// never unify. Unbound variables cannot be unified when they appear as object
// keys, so, as with sets, key variables are first substituted with their
// bindings on each side before keys are matched and values unified.
func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyIterator) error {
	if a.Len() == b.Len() {
		left, right := a, b
		if nonGroundKeys(left) {
			left = plugKeys(left, b1)
		}
		if nonGroundKeys(right) {
			right = plugKeys(right, b2)
		}
		return e.biunifyObjectsRec(left, right, b1, b2, iter, left, left.KeysIterator())
	}
	return nil
}

// biunifyObjectsRec pulls the next key of a from oki and unifies a's value
// with b's value for that key, recursing for the remaining keys. A key missing
// from b fails the unification. iter is called once all keys are consumed.
//
// NOTE(review): oki is stateful and shared by every continuation. If a nested
// biunify ever calls its continuation more than once (multiple solutions),
// later invocations resume from an already-advanced iterator and skip keys.
// This appears to rely on object values having been reduced to single-solution
// terms beforehand — confirm.
func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys ast.Object, oki ast.ObjectKeysIterator) error {
	key, ok := oki.Next()
	if !ok {
		return iter()
	}
	other := b.Get(key)
	if other == nil {
		return nil
	}
	rest := func() error {
		return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, oki)
	}
	return e.biunify(a.Get(key), other, b1, b2, rest)
}
// biunifyValues performs the general unification of a (under b1) and b (under
// b2). Refs are evaluated via biunifyRef unless they are in the save set.
// Terms that need saving (partial evaluation) produce a saved equality. Other
// comprehensions are evaluated, and var/var, var/value and value/value cases
// are handled by binding or by equality checks. iter is called once per
// successful unification.
func (e *eval) biunifyValues(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	// Try to evaluate refs and comprehensions. If partial evaluation is
	// enabled, then skip evaluation (and save the expression) if the term is
	// in the save set. Currently, comprehensions are not evaluated during
	// partial eval. This could be improved in the future.
	var saveA, saveB bool
	if _, ok := a.Value.(ast.Set); ok {
		saveA = e.saveSet.ContainsRecursive(a, b1)
	} else {
		saveA = e.saveSet.Contains(a, b1)
		if !saveA {
			if _, refA := a.Value.(ast.Ref); refA {
				return e.biunifyRef(a, b, b1, b2, iter)
			}
		}
	}
	if _, ok := b.Value.(ast.Set); ok {
		saveB = e.saveSet.ContainsRecursive(b, b2)
	} else {
		saveB = e.saveSet.Contains(b, b2)
		if !saveB {
			if _, refB := b.Value.(ast.Ref); refB {
				// Swapped so that the ref is always the first argument.
				return e.biunifyRef(b, a, b2, b1, iter)
			}
		}
	}
	if saveA || saveB {
		return e.saveUnify(a, b, b1, b2, iter)
	}
	// The swap flag tells biunifyComprehension the operands were reversed.
	if ast.IsComprehension(a.Value) {
		return e.biunifyComprehension(a, b, b1, b2, false, iter)
	} else if ast.IsComprehension(b.Value) {
		return e.biunifyComprehension(b, a, b2, b1, true, iter)
	}
	// Perform standard unification.
	_, varA := a.Value.(ast.Var)
	_, varB := b.Value.(ast.Var)
	var undo undo
	if varA && varB {
		// Unifying a variable with itself in the same bindings always succeeds
		// without creating a binding.
		if b1 == b2 && a.Equal(b) {
			return iter()
		}
		b1.bind(a, b, b2, &undo)
		err := iter()
		undo.Undo()
		return err
	} else if varA && !varB {
		b1.bind(a, b, b2, &undo)
		err := iter()
		undo.Undo()
		return err
	} else if varB && !varA {
		b2.bind(b, a, b1, &undo)
		err := iter()
		undo.Undo()
		return err
	}
	// Sets must not contain unbound variables at this point as we cannot unify
	// them. So simply plug both sides (to substitute any bound variables with
	// values) and then check for equality.
	switch a.Value.(type) {
	case ast.Set:
		a = b1.Plug(a)
		b = b2.Plug(b)
	}
	if a.Equal(b) {
		return iter()
	}
	return nil
}
// biunifyRef evaluates the reference a (under b1) and unifies each of its
// values with b (under b2). Refs rooted at data are walked through the
// compiler's rule tree (evalTree). Refs rooted at input, or at a variable
// bound in b1, are walked through that term's value (evalTerm). A ref whose
// head resolves to nothing fails without calling iter.
func (e *eval) biunifyRef(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	ref := a.Value.(ast.Ref)
	head := ref[0]

	if head.Equal(ast.DefaultRootDocument) {
		node := e.compiler.RuleTree.Child(head.Value)
		tree := evalTree{
			e:         e,
			ref:       ref,
			pos:       1,
			plugged:   ref.Copy(),
			bindings:  b1,
			rterm:     b,
			rbindings: b2,
			node:      node,
		}
		return tree.eval(iter)
	}

	var root *ast.Term
	var rootBindings *bindings
	switch {
	case head.Equal(ast.InputRootDocument):
		root, rootBindings = e.input, b1
	default:
		root, rootBindings = b1.apply(head)
		// apply returning the head itself means the variable is unbound.
		if root == head {
			root = nil
		}
	}
	if root == nil {
		return nil
	}

	walk := evalTerm{
		e:            e,
		ref:          ref,
		pos:          1,
		bindings:     b1,
		term:         root,
		termbindings: rootBindings,
		rterm:        b,
		rbindings:    b2,
	}
	return walk.eval(iter)
}
// biunifyComprehension unifies comprehension a (under b1) with b (under b2).
// Comprehensions depending on unknowns are saved (see
// biunifyComprehensionPartial). Otherwise a cached result is used when the
// comprehension is indexable and already cached, and the comprehension is
// evaluated in full on a cache miss. swap records that the operands were
// reversed by the caller.
func (e *eval) biunifyComprehension(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {
	if e.unknown(a, b1) {
		return e.biunifyComprehensionPartial(a, b, b1, b2, swap, iter)
	}

	cached, err := e.buildComprehensionCache(a)
	if err != nil {
		return err
	}
	if cached != nil {
		return e.biunify(cached, b, b1, b2, iter)
	}
	e.instr.counterIncr(evalOpComprehensionCacheMiss)

	switch comp := a.Value.(type) {
	case *ast.ArrayComprehension:
		return e.biunifyComprehensionArray(comp, b, b1, b2, iter)
	case *ast.SetComprehension:
		return e.biunifyComprehensionSet(comp, b, b1, b2, iter)
	case *ast.ObjectComprehension:
		return e.biunifyComprehensionObject(comp, b, b1, b2, iter)
	}
	return internalErr(e.query[e.index].Location, "illegal comprehension type")
}
// buildComprehensionCache returns the cached value of comprehension a for the
// current values of its index keys. It returns (nil, nil) when the
// comprehension has no index, or when the cache has no entry for those key
// values. The cache for a is built on first use by evaluating the
// comprehension once over all key values. Hits, builds and skips are counted.
func (e *eval) buildComprehensionCache(a *ast.Term) (*ast.Term, error) {
	index := e.comprehensionIndex(a)
	if index == nil {
		e.instr.counterIncr(evalOpComprehensionCacheSkip)
		return nil, nil
	}

	elem, hit := e.comprehensionCache.Elem(a)
	if hit {
		e.instr.counterIncr(evalOpComprehensionCacheHit)
	} else {
		var err error
		switch comp := a.Value.(type) {
		case *ast.ArrayComprehension:
			elem, err = e.buildComprehensionCacheArray(comp, index.Keys)
		case *ast.SetComprehension:
			elem, err = e.buildComprehensionCacheSet(comp, index.Keys)
		case *ast.ObjectComprehension:
			elem, err = e.buildComprehensionCacheObject(comp, index.Keys)
		default:
			err = internalErr(e.query[e.index].Location, "illegal comprehension type")
		}
		if err != nil {
			return nil, err
		}
		e.comprehensionCache.Set(a, elem)
		e.instr.counterIncr(evalOpComprehensionCacheBuild)
	}

	keyValues := make([]*ast.Term, len(index.Keys))
	for i, k := range index.Keys {
		keyValues[i] = e.bindings.Plug(k)
	}
	return elem.Get(keyValues), nil
}
// buildComprehensionCacheArray evaluates the array comprehension body once in
// a child evaluator. Each produced element is appended to the array stored
// under the plugged values of keys. The cache node is returned together with
// any evaluation error.
func (e *eval) buildComprehensionCacheArray(x *ast.ArrayComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) {
	child := e.child(x.Body)
	node := newComprehensionCacheElem()
	err := child.Run(func(c *eval) error {
		keyValues := make([]*ast.Term, len(keys))
		for i, k := range keys {
			keyValues[i] = c.bindings.Plug(k)
		}
		head := c.bindings.Plug(x.Term)
		if existing := node.Get(keyValues); existing == nil {
			node.Put(keyValues, ast.ArrayTerm(head))
		} else {
			existing.Value = existing.Value.(*ast.Array).Append(head)
		}
		return nil
	})
	return node, err
}

// buildComprehensionCacheSet is the set counterpart of
// buildComprehensionCacheArray: each produced element is added to the set
// stored under the plugged key values.
func (e *eval) buildComprehensionCacheSet(x *ast.SetComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) {
	child := e.child(x.Body)
	node := newComprehensionCacheElem()
	err := child.Run(func(c *eval) error {
		keyValues := make([]*ast.Term, len(keys))
		for i, k := range keys {
			keyValues[i] = c.bindings.Plug(k)
		}
		head := c.bindings.Plug(x.Term)
		if existing := node.Get(keyValues); existing == nil {
			node.Put(keyValues, ast.SetTerm(head))
		} else {
			existing.Value.(ast.Set).Add(head)
		}
		return nil
	})
	return node, err
}

// buildComprehensionCacheObject is the object counterpart of
// buildComprehensionCacheArray: each produced key/value pair is inserted into
// the object stored under the plugged key values.
func (e *eval) buildComprehensionCacheObject(x *ast.ObjectComprehension, keys []*ast.Term) (*comprehensionCacheElem, error) {
	child := e.child(x.Body)
	node := newComprehensionCacheElem()
	err := child.Run(func(c *eval) error {
		keyValues := make([]*ast.Term, len(keys))
		for i, k := range keys {
			keyValues[i] = c.bindings.Plug(k)
		}
		k := c.bindings.Plug(x.Key)
		v := c.bindings.Plug(x.Value)
		if existing := node.Get(keyValues); existing == nil {
			node.Put(keyValues, ast.ObjectTerm(ast.Item(k, v)))
		} else {
			existing.Value.(ast.Object).Insert(k, v)
		}
		return nil
	})
	return node, err
}
// biunifyComprehensionPartial saves the unification of comprehension a (under
// b1) with b (under b2) during partial evaluation. Comprehension operands are
// first amended with the bindings they use (see amendComprehension). The
// original operand order is restored when swap is set.
func (e *eval) biunifyComprehensionPartial(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {
	cpyA, err := e.amendComprehension(a, b1)
	if err != nil {
		return err
	}

	if ast.IsComprehension(b.Value) {
		if b, err = e.amendComprehension(b, b2); err != nil {
			return err
		}
	}

	// The other term might need to be plugged so include the bindings. The
	// bindings for the comprehension term are saved (for compatibility) but
	// the eventual plug operation on the comprehension will be a no-op.
	if !swap {
		return e.saveUnify(cpyA, b, b1, b2, iter)
	}
	return e.saveUnify(b, cpyA, b2, b1, iter)
}
// amendComprehension returns a copy of comprehension a whose body is extended
// with `k = v` equalities for each binding in b1 whose variable occurs in a.
// The copy is then namespaced against the caller's bindings so that its
// variables do not collide with others in the final partially evaluated
// queries. A non-comprehension term is an error.
func (e *eval) amendComprehension(a *ast.Term, b1 *bindings) (*ast.Term, error) {
	amended := a.Copy()

	var body *ast.Body
	switch comp := amended.Value.(type) {
	case *ast.ArrayComprehension:
		body = &comp.Body
	case *ast.SetComprehension:
		body = &comp.Body
	case *ast.ObjectComprehension:
		body = &comp.Body
	default:
		return nil, fmt.Errorf("illegal comprehension %T", comp)
	}

	used := a.Vars()
	err := b1.Iter(e.caller.bindings, func(k, v *ast.Term) error {
		if used.Contains(k.Value.(ast.Var)) {
			body.Append(ast.Equality.Expr(k, v))
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	b1.Namespace(amended, e.caller.bindings)
	return amended, nil
}
// biunifyComprehensionArray fully evaluates the array comprehension x in a
// closure and unifies the resulting array with b.
func (e *eval) biunifyComprehensionArray(x *ast.ArrayComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	arr := ast.NewArray()
	collect := func(c *eval) error {
		arr = arr.Append(c.bindings.Plug(x.Term))
		return nil
	}
	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(arr), b, b1, b2, iter)
}

// biunifyComprehensionSet fully evaluates the set comprehension x in a closure
// and unifies the resulting set with b.
func (e *eval) biunifyComprehensionSet(x *ast.SetComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	set := ast.NewSet()
	collect := func(c *eval) error {
		set.Add(c.bindings.Plug(x.Term))
		return nil
	}
	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(set), b, b1, b2, iter)
}

// biunifyComprehensionObject fully evaluates the object comprehension x in a
// closure and unifies the resulting object with b. Producing the same key with
// two different values is an object key conflict error.
func (e *eval) biunifyComprehensionObject(x *ast.ObjectComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	obj := ast.NewObject()
	collect := func(c *eval) error {
		k := c.bindings.Plug(x.Key)
		v := c.bindings.Plug(x.Value)
		if prev := obj.Get(k); prev != nil && !prev.Equal(v) {
			return objectDocKeyConflictErr(x.Key.Location)
		}
		obj.Insert(k, v)
		return nil
	}
	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(obj), b, b1, b2, iter)
}
// saveExpr records expr (with bindings b for both sides) on the save stack for
// the duration of iter.
func (e *eval) saveExpr(expr *ast.Expr, b *bindings, iter unifyIterator) error {
	e.updateFromQuery(expr)
	e.saveStack.Push(expr, b, b)
	e.traceSave(expr)
	result := iter()
	e.saveStack.Pop()
	return result
}

// saveExprMarkUnknowns is like saveExpr, but also marks the terms reported by
// getSavePairsFromExpr as unknown in the save set while iter runs. The number
// of declared arguments comes from getDeclArgsLen, and its error, if any, is
// returned before anything is saved.
func (e *eval) saveExprMarkUnknowns(expr *ast.Expr, b *bindings, iter unifyIterator) error {
	e.updateFromQuery(expr)

	declArgsLen, err := e.getDeclArgsLen(expr)
	if err != nil {
		return err
	}

	pairs := getSavePairsFromExpr(declArgsLen, expr, b, nil)
	for _, p := range pairs {
		e.saveSet.Push([]*ast.Term{p.term}, p.b)
	}

	e.saveStack.Push(expr, b, b)
	e.traceSave(expr)
	err = iter()
	e.saveStack.Pop()

	for range pairs {
		e.saveSet.Pop()
	}
	return err
}
// saveUnify saves the equality `a = b` (with bindings b1 and b2) on the save
// stack while iter runs. The terms reported by getSavePairsFromTerm for each
// side are marked unknown in the save set for the same duration. Only the
// setup work is timed under partialOpSaveUnify.
func (e *eval) saveUnify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	e.instr.startTimer(partialOpSaveUnify)

	expr := ast.Equality.Expr(a, b)
	e.updateFromQuery(expr)

	pushed := 0
	for _, p := range getSavePairsFromTerm(a, b1, nil) {
		e.saveSet.Push([]*ast.Term{p.term}, p.b)
		pushed++
	}
	for _, p := range getSavePairsFromTerm(b, b2, nil) {
		e.saveSet.Push([]*ast.Term{p.term}, p.b)
		pushed++
	}

	e.saveStack.Push(expr, b1, b2)
	e.traceSave(expr)
	e.instr.stopTimer(partialOpSaveUnify)

	err := iter()

	e.saveStack.Pop()
	for ; pushed > 0; pushed-- {
		e.saveSet.Pop()
	}
	return err
}
// saveCall saves the call expression built from terms on the save stack while
// iter runs. When the call site includes an output term (terms has
// declArgsLen+2 entries), the variables in that output position are marked
// unknown in the save set for the same duration.
func (e *eval) saveCall(declArgsLen int, terms []*ast.Term, iter unifyIterator) error {
	expr := ast.NewExpr(terms)
	e.updateFromQuery(expr)

	pushed := 0
	if len(terms)-2 == declArgsLen {
		output := terms[len(terms)-1]
		for _, p := range getSavePairsFromTerm(output, e.bindings, nil) {
			e.saveSet.Push([]*ast.Term{p.term}, p.b)
			pushed++
		}
	}

	e.saveStack.Push(expr, e.bindings, nil)
	e.traceSave(expr)
	err := iter()
	e.saveStack.Pop()

	for ; pushed > 0; pushed-- {
		e.saveSet.Pop()
	}
	return err
}
// saveInlinedNegatedExprs saves each of exprs on the save stack in place of
// the current query expression. Every saved expression receives the current
// expression's with-modifiers, passed through updateSavedMocks. Those
// modifiers are copies whose values are plugged with e.bindings, namespaced
// against the caller's bindings. Note that the With field of each element of
// exprs is overwritten in place. iter runs while all of exprs are saved, and
// they are popped again afterwards.
func (e *eval) saveInlinedNegatedExprs(exprs []*ast.Expr, iter unifyIterator) error {
	current := e.query[e.index].With
	with := make([]*ast.With, len(current))
	for i, w := range current {
		cpy := w.Copy()
		cpy.Value = e.bindings.PlugNamespaced(cpy.Value, e.caller.bindings)
		with[i] = cpy
	}

	for _, expr := range exprs {
		expr.With = e.updateSavedMocks(with)
		e.saveStack.Push(expr, nil, nil)
		e.traceSave(expr)
	}

	err := iter()

	for range exprs {
		e.saveStack.Pop()
	}
	return err
}
// getRules returns the rules under ref that may apply.
//
// With indexing enabled, the rule index is consulted using args, the
// call-site arguments used to resolve function-argument refs. Otherwise all
// rules under ref are returned. It returns (nil, nil) when the compiler has
// no index for ref.
//
// The result's EarlyExit flag survives only if early exit is enabled for
// this eval. The lookup is reported to the tracer.
func (e *eval) getRules(ref ast.Ref, args []*ast.Term) (*ast.IndexResult, error) {
	e.instr.startTimer(evalOpRuleIndex)
	defer e.instr.stopTimer(evalOpRuleIndex)

	index := e.compiler.RuleIndex(ref)
	if index == nil {
		return nil, nil
	}

	resolver := &evalResolver{e: e}
	var (
		result *ast.IndexResult
		err    error
	)
	if e.indexing {
		resolver.args = args
		result, err = index.Lookup(resolver)
	} else {
		result, err = index.AllRules(resolver)
	}
	if err != nil {
		return nil, err
	}

	result.EarlyExit = result.EarlyExit && e.earlyExit

	// Trace message: "(matched N rule[s][, early exit])".
	var msg strings.Builder
	switch n := len(result.Rules); n {
	case 1:
		msg.WriteString("(matched 1 rule")
	default:
		msg.Grow(len("(matched NNNN rules)"))
		msg.WriteString("(matched ")
		msg.WriteString(strconv.Itoa(n))
		msg.WriteString(" rules")
	}
	if result.EarlyExit {
		msg.WriteString(", early exit")
	}
	msg.WriteRune(')')

	e.traceIndex(e.query[e.index], msg.String(), &ref)
	return result, nil
}
// Resolve resolves ref for this evaluation via evalResolver. No call-site
// arguments are supplied, so function-argument refs resolve as unknown.
func (e *eval) Resolve(ref ast.Ref) (ast.Value, error) {
	resolver := evalResolver{e: e}
	return resolver.Resolve(ref)
}
// evalResolver resolves ref values on behalf of an eval, for example when
// the rule index evaluates lookups (see getRules).
//
// args holds the call-site arguments of a function call, indexed by
// argument position. It is consulted for refs rooted at
// ast.FunctionArgRootDocument and may be nil.
type evalResolver struct {
	e    *eval
	args []*ast.Term
}
// Resolve looks up the value of ref on behalf of e.e. It is the resolver
// handed to the rule index by getRules, and it backs eval.Resolve.
//
//   - Refs for which inlining is disabled, or which are in the save set,
//     yield ast.UnknownValueErr.
//   - Refs rooted at ast.FunctionArgRootDocument return the plugged e.args[i],
//     where i is the number in ref[1]. Any other such ref is unknown.
//   - Refs rooted at ast.InputRootDocument are looked up in the eval's input.
//     A missing input or a failed lookup yields (nil, nil).
//   - Refs rooted at ast.DefaultRootDocument combine the value found in
//     e.e.data (if any) with the base document from the base cache or
//     storage. When the ref is prefixed by an entry of the target stack,
//     only the e.e.data value is returned.
//   - Any other root yields an "illegal ref" error.
//
// The evalOpResolve timer is stopped explicitly on every return path.
func (e *evalResolver) Resolve(ref ast.Ref) (ast.Value, error) {
	e.e.instr.startTimer(evalOpResolve)
	if e.e.inliningControl.Disabled(ref, true) || e.e.saveSet.Contains(ast.NewTerm(ref), nil) {
		e.e.instr.stopTimer(evalOpResolve)
		return nil, ast.UnknownValueErr{}
	}
	// Lookup of function argument values works by using the `args` ref[0],
	// where the ast.Number in ref[1] references the function argument of
	// that number. The callsite-local arguments are passed in e.args,
	// indexed by argument index.
	if ref[0].Equal(ast.FunctionArgRootDocument) {
		v, ok := ref[1].Value.(ast.Number)
		if ok {
			i, ok := v.Int()
			if ok && i >= 0 && i < len(e.args) {
				e.e.instr.stopTimer(evalOpResolve)
				// Plug the call-site argument with e.e.bindings, namespaced
				// against the caller's bindings.
				plugged := e.e.bindings.PlugNamespaced(e.args[i], e.e.caller.bindings)
				return plugged.Value, nil
			}
		}
		// Non-numeric or out-of-range argument index (or no args at all):
		// the value is treated as unknown.
		e.e.instr.stopTimer(evalOpResolve)
		return nil, ast.UnknownValueErr{}
	}
	if ref[0].Equal(ast.InputRootDocument) {
		if e.e.input != nil {
			v, err := e.e.input.Value.Find(ref[1:])
			if err != nil {
				// Any lookup error is treated as undefined (nil), not returned.
				v = nil
			}
			e.e.instr.stopTimer(evalOpResolve)
			return v, nil
		}
		e.e.instr.stopTimer(evalOpResolve)
		return nil, nil
	}
	if ref[0].Equal(ast.DefaultRootDocument) {
		// repValue is the value at ref in e.e.data, when data is set and the
		// path exists there; otherwise it stays nil.
		var repValue ast.Value
		if e.e.data != nil {
			if v, err := e.e.data.Value.Find(ref[1:]); err == nil {
				repValue = v
			}
		}
		if e.e.targetStack.Prefixed(ref) {
			e.e.instr.stopTimer(evalOpResolve)
			return repValue, nil
		}
		var merged ast.Value
		var err error
		// Converting large JSON values into AST values can be fairly expensive. For
		// example, a 2MB JSON value can take upwards of 30 milliseconds to convert.
		// We cache the result of conversion here in case the same base document is
		// being read multiple times during evaluation.
		realValue := e.e.baseCache.Get(ref)
		if realValue != nil {
			e.e.instr.counterIncr(evalOpBaseCacheHit)
			if repValue == nil {
				e.e.instr.stopTimer(evalOpResolve)
				return realValue, nil
			}
			var ok bool
			merged, ok = merge(repValue, realValue)
			if !ok {
				err = mergeConflictErr(ref[0].Location)
			}
		} else { // baseCache miss
			e.e.instr.counterIncr(evalOpBaseCacheMiss)
			merged, err = e.e.resolveReadFromStorage(ref, repValue)
		}
		e.e.instr.stopTimer(evalOpResolve)
		return merged, err
	}
	e.e.instr.stopTimer(evalOpResolve)
	return nil, fmt.Errorf("illegal ref")
}
// resolveReadFromStorage reads the base document at ref and merges it with
// a. The caller (evalResolver.Resolve) passes as a the value found in
// e.e.data, which may be nil.
//
//   - If refContainsNonScalar(ref) reports true, nothing is read and a is
//     returned as-is.
//   - e.external is consulted first. Only when it yields no value is the
//     store read within e.txn. A not-found path returns a unchanged.
//   - When the root is read (empty storage path), the ast.SystemDocumentKey
//     entry is stripped from the result.
//   - The resulting value is stored in the base cache before merging.
//
// A conflict while merging with a yields a merge-conflict error.
func (e *eval) resolveReadFromStorage(ref ast.Ref, a ast.Value) (ast.Value, error) {
	if refContainsNonScalar(ref) {
		return a, nil
	}
	v, err := e.external.Resolve(e, ref)
	if err != nil {
		return nil, err
	}
	if v == nil {
		path, err := storage.NewPathForRef(ref)
		if err != nil {
			if !storage.IsNotFound(err) {
				return nil, err
			}
			return a, nil
		}
		blob, err := e.store.Read(e.ctx, e.txn, path)
		if err != nil {
			if !storage.IsNotFound(err) {
				return nil, err
			}
			return a, nil
		}
		if len(path) == 0 {
			// Reading the root document: hide the system document key. Both raw
			// Go maps and ast.Object blobs are handled; other types pass through.
			switch obj := blob.(type) {
			case map[string]interface{}:
				if len(obj) > 0 {
					cpy := make(map[string]interface{}, len(obj)-1)
					for k, v := range obj {
						if string(ast.SystemDocumentKey) != k {
							cpy[k] = v
						}
					}
					blob = cpy
				}
			case ast.Object:
				if obj.Len() > 0 {
					cpy := ast.NewObject()
					if err := obj.Iter(func(k *ast.Term, v *ast.Term) error {
						if !ast.SystemDocumentKey.Equal(k.Value) {
							cpy.Insert(k, v)
						}
						return nil
					}); err != nil {
						return nil, err
					}
					blob = cpy
				}
			}
		}
		// Convert the raw blob into an ast.Value. Unless strict objects are
		// requested, Go maps are wrapped lazily instead of converted up front.
		switch blob := blob.(type) {
		case ast.Value:
			v = blob
		default:
			if blob, ok := blob.(map[string]interface{}); ok && !e.strictObjects {
				v = ast.LazyObject(blob)
				break
			}
			v, err = ast.InterfaceToValue(blob)
			if err != nil {
				return nil, err
			}
		}
	}
	e.baseCache.Put(ref, v)
	if a == nil {
		return v, nil
	}
	merged, ok := merge(a, v)
	if !ok {
		return nil, mergeConflictErr(ref[0].Location)
	}
	return merged, nil
}
// generateVar returns a variable term named "<genvarprefix>_<suffix>", for
// temporaries introduced by the evaluator itself.
func (e *eval) generateVar(suffix string) *ast.Term {
	name := fmt.Sprintf("%v_%v", e.genvarprefix, suffix)
	return ast.VarTerm(name)
}
// rewrittenVar looks v up in the RewrittenVars table of the module compiler
// and then in that of the query compiler, skipping any compiler that is
// unset. It returns the mapped variable and true on the first match, or v
// and false when neither table has an entry.
func (e *eval) rewrittenVar(v ast.Var) (ast.Var, bool) {
	if c := e.compiler; c != nil {
		if rw, ok := c.RewrittenVars[v]; ok {
			return rw, true
		}
	}
	if qc := e.queryCompiler; qc != nil {
		if rw, ok := qc.RewrittenVars()[v]; ok {
			return rw, true
		}
	}
	return v, false
}
// getDeclArgsLen returns the number of declared arguments of the function
// called by x. It returns -1 when x is not a call, or when its operator is
// neither a built-in nor backed by any rule.
//
// Built-ins report the length of their declaration's argument list.
// User-defined functions report the argument count of the first matched
// rule, falling back to the default rule when that is the only match.
func (e *eval) getDeclArgsLen(x *ast.Expr) (int, error) {
	if !x.IsCall() {
		return -1, nil
	}

	operator := x.Operator()
	if bi, _, ok := e.builtinFunc(operator.String()); ok {
		return len(bi.Decl.Args()), nil
	}

	ir, err := e.getRules(operator, nil)
	if err != nil {
		return -1, err
	}
	if ir == nil {
		return -1, nil
	}

	// A non-empty index result may hold only a default rule. In that case
	// indexing ir.Rules[0] unconditionally would panic. This mirrors the
	// arity lookup in evalFunc.eval.
	switch {
	case len(ir.Rules) > 0:
		return len(ir.Rules[0].Head.Args), nil
	case ir.Default != nil:
		return len(ir.Default.Head.Args), nil
	default:
		return -1, nil
	}
}
// updateFromQuery copies the Location and With fields of the query item
// currently being evaluated (`e.query[e.index]`) onto expr.
// The With values go through updateSavedMocks, which namespaces them so that
// replacement functions of mocked built-ins are properly referenced in the
// support module.
func (e *eval) updateFromQuery(expr *ast.Expr) {
	current := e.query[e.index]
	expr.With = e.updateSavedMocks(current.With)
	expr.Location = current.Location
}
// evalBuiltin evaluates one call to a built-in function.
//
//	e      the eval the call belongs to
//	bi     the built-in's declaration (name, types, determinism)
//	bctx   the context passed to the implementation
//	f      the built-in's implementation
//	terms  the call's operands; when there is one more operand than declared
//	       arguments, the last one is the output term the result unifies with
type evalBuiltin struct {
	e     *eval
	bi    *ast.Builtin
	bctx  BuiltinContext
	f     BuiltinFunc
	terms []*ast.Term
}
// canUseNDBCache reports whether bi is non-deterministic and the caller
// supplied an NDBuiltinCache in the builtin context. Only then are results
// of this call looked up in, and recorded into, that cache.
func (e *evalBuiltin) canUseNDBCache(bi *ast.Builtin) bool {
	if !bi.Nondeterministic {
		return false
	}
	return e.bctx.NDBuiltinCache != nil
}
// eval plugs the call's operands, runs the built-in and continues with iter
// for each output:
//   - built-ins without a declared result continue with iter directly;
//   - calls without an output operand (len(operands) == numDeclArgs)
//     continue only when the output is not false;
//   - otherwise the output is unified with the output operand.
//
// For non-deterministic built-ins with an NDBuiltinCache, a cached result is
// used instead of calling the built-in, and fresh results are recorded.
// Errors returned by the built-in implementation are collected in
// e.e.builtinErrors and not returned. Errors coming from iter are returned.
func (e evalBuiltin) eval(iter unifyIterator) error {
	operands := make([]*ast.Term, len(e.terms))
	for i := range e.terms {
		operands[i] = e.e.bindings.Plug(e.terms[i])
	}
	numDeclArgs := len(e.bi.Decl.FuncArgs().Args)
	e.e.instr.startTimer(evalOpBuiltinCall)
	var err error
	// NOTE(philipc): We sometimes have to drop the very last term off
	// the args list for cases where a builtin's result is used/assigned,
	// because the last term will be a generated term, not an actual
	// argument to the builtin.
	endIndex := len(operands)
	if len(operands) > numDeclArgs {
		endIndex--
	}
	// We skip evaluation of the builtin entirely if the NDBCache is
	// present, and we have a non-deterministic builtin already cached.
	if e.canUseNDBCache(e.bi) {
		e.e.instr.stopTimer(evalOpBuiltinCall)
		// Unify against the NDBCache result if present.
		if v, ok := e.bctx.NDBuiltinCache.Get(e.bi.Name, ast.NewArray(operands[:endIndex]...)); ok {
			switch {
			case e.bi.Decl.Result() == nil:
				return iter()
			case len(operands) == numDeclArgs:
				if v.Compare(ast.Boolean(false)) == 0 {
					return nil // nothing to do
				}
				return iter()
			default:
				return e.e.unify(e.terms[endIndex], ast.NewTerm(v), iter)
			}
		}
		// Otherwise, we'll need to go through the normal unify flow.
		e.e.instr.startTimer(evalOpBuiltinCall)
	}
	// Normal unification flow for builtins:
	err = e.f(e.bctx, operands, func(output *ast.Term) error {
		// The built-in timer is paused while the continuation runs, so that
		// only time spent inside the built-in is attributed to it.
		e.e.instr.stopTimer(evalOpBuiltinCall)
		var err error
		switch {
		case e.bi.Decl.Result() == nil:
			err = iter()
		case len(operands) == numDeclArgs:
			if output.Value.Compare(ast.Boolean(false)) != 0 {
				err = iter()
			} // else: nothing to do, don't iter()
		default:
			err = e.e.unify(e.terms[endIndex], output, iter)
		}
		// If the NDBCache is present, we can assume this builtin
		// call was not cached earlier.
		if e.canUseNDBCache(e.bi) {
			// Populate the NDBCache from the output term.
			e.bctx.NDBuiltinCache.Put(e.bi.Name, ast.NewArray(operands[:endIndex]...), output.Value)
		}
		if err != nil {
			// NOTE(sr): We wrap the errors here into Halt{} because we don't want to
			// record them into builtinErrors below. The errors set here are coming from
			// the call to iter(), not from the builtin implementation.
			err = Halt{Err: err}
		}
		e.e.instr.startTimer(evalOpBuiltinCall)
		return err
	})
	if err != nil {
		// Unwrap iterator errors (wrapped in Halt above) and return them.
		// Any other error came from the built-in itself: record it and
		// continue as if the call were undefined.
		if t, ok := err.(Halt); ok {
			err = t.Err
		} else {
			e.e.builtinErrors.errs = append(e.e.builtinErrors.errs, err)
			err = nil
		}
	}
	e.e.instr.stopTimer(evalOpBuiltinCall)
	return err
}
// evalFunc evaluates one call to a user-defined function.
//
//	e      the eval the call belongs to
//	ref    the function's ref
//	terms  the call terms: terms[0] is the operator, followed by the
//	       arguments and, optionally, a captured output term
//	ir     the index result holding the function's candidate rules
type evalFunc struct {
	e     *eval
	ref   ast.Ref
	terms []*ast.Term
	ir    *ast.IndexResult
}
// eval evaluates the function call described by e.
//
//   - If the function has else-branches and the current query expression is
//     unknown, the call is saved instead of evaluated.
//   - Under partial evaluation with shallow inlining, or with inlining
//     disabled for the function, a call whose function ref or arguments are
//     unknown is replaced by a call to generated support rules.
//   - In all other cases the rules are evaluated directly, and early-exit
//     signals are suppressed.
func (e evalFunc) eval(iter unifyIterator) error {
	if e.ir.Empty() {
		return nil
	}

	// Declared arity, taken from the first rule or, failing that, the default.
	argCount := 0
	switch {
	case len(e.ir.Rules) > 0:
		argCount = len(e.ir.Rules[0].Head.Args)
	case e.ir.Default != nil:
		argCount = len(e.ir.Default.Head.Args)
	}

	if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
		// Partial evaluation of ordered rules is not supported currently. Save the
		// expression and continue. This could be revisited in the future.
		return e.e.saveCall(argCount, e.terms, iter)
	}

	if e.e.partial() && (e.e.inliningControl.shallow || e.e.inliningControl.Disabled(e.ref, false)) {
		// check if the function definitions, or any of the arguments
		// contain something unknown
		unknown := e.e.unknown(e.ref, e.e.bindings)
		for i := 1; i <= argCount && !unknown; i++ {
			unknown = e.e.unknown(e.terms[i], e.e.bindings)
		}
		if unknown {
			return e.partialEvalSupport(argCount, iter)
		}
	}

	return suppressEarlyExit(e.evalValue(iter, argCount, e.ir.EarlyExit))
}
// evalValue evaluates the function's rules for the call in e.terms.
//
// Outside partial evaluation the virtual cache is consulted first, and a hit
// ends evaluation. Rules are tried in order. When a rule yields no value,
// its else-chain is tried until one branch does. The default rule runs only
// if no rule or else-branch produced a value. findOne is forwarded to every
// rule evaluation.
func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) error {
	var cacheKey ast.Ref
	if !e.e.partial() {
		key, hit, err := e.evalCache(argCount, iter)
		if err != nil {
			return err
		}
		if hit {
			return nil
		}
		cacheKey = key
	}

	var prev *ast.Term
	for _, rule := range e.ir.Rules {
		next, err := e.evalOneRule(iter, rule, cacheKey, prev, findOne)
		if err != nil {
			return err
		}
		elseRules := e.ir.Else[rule]
		for i := 0; next == nil && i < len(elseRules); i++ {
			if next, err = e.evalOneRule(iter, elseRules[i], cacheKey, prev, findOne); err != nil {
				return err
			}
		}
		if next != nil {
			prev = next
		}
	}

	if prev == nil && e.ir.Default != nil {
		_, err := e.evalOneRule(iter, e.ir.Default, cacheKey, prev, findOne)
		return err
	}
	return nil
}
// evalCache looks the call up in the virtual cache. The cache key is the
// plugged call terms, excluding a captured output term when there is one.
//
// On a hit, the cached value is used and the hit flag is true:
//   - for f(x), iter runs unless the cached value is false;
//   - for f(x, y), the cached value is unified with y.
//
// On a miss, the cache key is returned so the caller can populate the cache.
func (e evalFunc) evalCache(argCount int, iter unifyIterator) (ast.Ref, bool, error) {
	plen := len(e.terms)
	if plen == argCount+2 { // func name + output = 2
		plen--
	}
	cacheKey := make([]*ast.Term, plen)
	for i := range cacheKey {
		cacheKey[i] = e.e.bindings.Plug(e.terms[i])
	}

	cached, _ := e.e.virtualCache.Get(cacheKey)
	if cached == nil {
		e.e.instr.counterIncr(evalOpVirtualCacheMiss)
		return cacheKey, false, nil
	}

	e.e.instr.counterIncr(evalOpVirtualCacheHit)
	if argCount != len(e.terms)-1 {
		// f(x, y), y captured output value
		return nil, true, e.e.unify(e.terms[len(e.terms)-1] /* y */, cached, iter)
	}
	// f(x)
	if ast.Boolean(false).Equal(cached.Value) {
		return nil, true, nil
	}
	return nil, true, iter()
}
// evalOneRule evaluates a single function rule (regular, else or default)
// for the call in e.terms.
//
// The call-site operands e.terms[1:] are unified with the rule's arguments.
// When the call captures an output, the rule's head value is unified with it
// as well. For each solution of the rule body:
//   - the plugged head value is stored under cacheKey (when non-nil);
//   - outside partial evaluation, it is checked for conflicts against prev;
//   - it is then passed on via iter.
//
// A false result of a call without a captured output is only checked for
// conflicts and does not call iter. The plugged head value of the last
// solution is returned; it is nil when no solution assigned one.
func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.Ref, prev *ast.Term, findOne bool) (*ast.Term, error) {
	child := e.e.child(rule.Body)
	child.findOne = findOne
	// args holds the rule's formal arguments. When the call captures an
	// output, one extra slot holds the head value that the output unifies with.
	args := make([]*ast.Term, len(e.terms)-1)
	copy(args, rule.Head.Args)
	if len(args) == len(rule.Head.Args)+1 {
		args[len(args)-1] = rule.Head.Value
	}
	var result *ast.Term
	child.traceEnter(rule)
	err := child.biunifyArrays(ast.NewArray(e.terms[1:]...), ast.NewArray(args...), e.e.bindings, child.bindings, func() error {
		return child.eval(func(child *eval) error {
			child.traceExit(rule)
			// Partial evaluation must save an expression that tests the output value if the output value
			// was not captured to handle the case where the output value may be `false`.
			if len(rule.Head.Args) == len(e.terms)-1 && e.e.saveSet.Contains(rule.Head.Value, child.bindings) {
				err := e.e.saveExpr(ast.NewExpr(rule.Head.Value), child.bindings, iter)
				child.traceRedo(rule)
				return err
			}
			result = child.bindings.Plug(rule.Head.Value)
			if cacheKey != nil {
				e.e.virtualCache.Put(cacheKey, result) // the redos confirm this, or the evaluation is aborted
			}
			// No captured output: a false result makes the call undefined, so
			// iter is not called. It still takes part in conflict detection.
			if len(rule.Head.Args) == len(e.terms)-1 {
				if result.Value.Compare(ast.Boolean(false)) == 0 {
					if prev != nil && ast.Compare(prev, result) != 0 {
						return functionConflictErr(rule.Location)
					}
					prev = result
					return nil
				}
			}
			// Partial evaluation should explore all rules and may not produce
			// a ground result so we do not perform conflict detection or
			// deduplication. See "ignore conflicts: functions" test case for
			// an example.
			if !e.e.partial() {
				if prev != nil {
					if ast.Compare(prev, result) != 0 {
						return functionConflictErr(rule.Location)
					}
					child.traceRedo(rule)
					return nil
				}
			}
			// prev is this function's parameter (captured by the closure).
			// Updating it only affects later solutions of this same rule.
			prev = result
			if err := iter(); err != nil {
				return err
			}
			child.traceRedo(rule)
			return nil
		})
	})
	return result, err
}
// partialEvalSupport generates support rules for the function at the
// namespaced path of e.ref, unless support rules already exist at that path.
// It then saves a call to the support function in place of the original
// call. If no support rule could be produced, nothing is saved and this
// branch of evaluation ends.
func (e evalFunc) partialEvalSupport(declArgsLen int, iter unifyIterator) error {
	path := e.e.namespaceRef(e.ref)

	if !e.e.saveSupport.Exists(path) {
		for _, rule := range e.ir.Rules {
			if err := e.partialEvalSupportRule(rule, path); err != nil {
				return err
			}
		}
		if !e.e.saveSupport.Exists(path) { // we haven't saved anything, nothing to call
			return nil
		}
	}

	call := append([]*ast.Term{ast.NewTerm(path)}, e.terms[1:]...)
	return e.e.saveCall(declArgsLen, call, iter)
}
// partialEvalSupportRule partially evaluates one function rule.
//
// The variables in the rule's arguments are pushed onto the save set, so
// they are treated as unknown during body evaluation. A fresh query is
// pushed on the save stack to collect the expressions saved for each body
// solution. Each solution whose plugged body passes type checking is
// inserted as a support rule at path. Its head keeps the rule's name and
// reference, with the value and arguments plugged and namespaced.
func (e evalFunc) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) error {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	e.e.saveStack.PushQuery(nil)
	// treat the function arguments as unknown during rule body evaluation
	var args []*ast.Term
	ast.WalkVars(rule.Head.Args, func(v ast.Var) bool {
		args = append(args, ast.VarTerm(string(v)))
		return false
	})
	e.e.saveSet.Push(args, child.bindings)
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		// Take the expressions saved for this solution. The query is pushed
		// back below so that backtracking continues from the same state.
		current := e.e.saveStack.PopQuery()
		plugged := current.Plug(e.e.caller.bindings)
		// Skip this rule body if it fails to type-check.
		// Type-checking failure means the rule body will never succeed.
		if e.e.compiler.PassesTypeCheck(plugged) {
			head := &ast.Head{
				Name:      rule.Head.Name,
				Reference: rule.Head.Reference,
				Value:     child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings),
				Args:      make([]*ast.Term, len(rule.Head.Args)),
			}
			for i, a := range rule.Head.Args {
				head.Args[i] = child.bindings.PlugNamespaced(a, e.e.caller.bindings)
			}
			e.e.saveSupport.Insert(path, &ast.Rule{
				Head: head,
				Body: plugged,
			})
		}
		child.traceRedo(rule)
		e.e.saveStack.PushQuery(current)
		return nil
	})
	e.e.saveSet.Pop()
	e.e.saveStack.PopQuery()
	return err
}
// evalTree walks a ref through the rule tree one position at a time. Along
// the way it resolves base documents, and it hands off to evalVirtual once
// a rule tree node carrying rules is reached (see next).
//
//	ref        the ref being evaluated
//	plugged    ref with the positions before pos replaced by their ground
//	           values; the backing array is shared between copies
//	pos        the position of ref currently being evaluated
//	bindings   bindings for ref
//	rterm      the term the ref's value is unified with (under rbindings)
//	node       the rule tree node for plugged[:pos]; may be nil
type evalTree struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
	node      *ast.TreeNode
}
// eval advances through e.ref. If the element at e.pos plugs to a ground
// term, evaluation follows it directly. Otherwise candidate keys are
// enumerated. Once the whole ref is consumed, finish produces the result.
func (e evalTree) eval(iter unifyIterator) error {
	if e.pos == len(e.ref) {
		return e.finish(iter)
	}
	if plugged := e.bindings.Plug(e.ref[e.pos]); plugged.IsGround() {
		return e.next(iter, plugged)
	}
	return e.enumerate(iter)
}
// finish runs once every position of e.ref has been plugged.
//
// In some cases it is not possible to partially evaluate the ref: the path
// may refer to virtual docs that PE does not support, or to base documents
// where inlining has been disabled. Then the ref is unknown and the
// unification is saved. Otherwise the ref's extent is computed and unified
// with e.rterm. An undefined extent ends this branch without calling iter.
func (e evalTree) finish(iter unifyIterator) error {
	if e.e.unknown(e.plugged, e.e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.plugged), e.rterm, e.bindings, e.rbindings, iter)
	}

	v, err := e.extent()
	if err != nil {
		return err
	}
	if v == nil {
		return nil
	}
	return e.e.biunify(e.rterm, v, e.rbindings, e.bindings, iter)
}
// next fixes position e.pos of the ref to the ground term plugged and
// continues with the following position.
//
// Suppose the ref so far is not prefixed by an entry of the target stack,
// and the rule tree has a child for plugged that carries rule values. Then
// evaluation switches to evalVirtual for that child. Otherwise the tree walk
// continues one level deeper. node stays nil when there is no matching rule
// tree child, or when the target stack prefixes the ref.
func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error {
	var node *ast.TreeNode
	// cpy shares plugged's backing array with e, so this assignment is also
	// visible through e.plugged (and therefore through r.plugged below).
	cpy := e
	cpy.plugged[e.pos] = plugged
	cpy.pos++
	if !e.e.targetStack.Prefixed(cpy.plugged[:cpy.pos]) {
		if e.node != nil {
			node = e.node.Child(plugged.Value)
			if node != nil && len(node.Values) > 0 {
				r := evalVirtual{
					e:         e.e,
					ref:       e.ref,
					plugged:   e.plugged,
					pos:       e.pos,
					bindings:  e.bindings,
					rterm:     e.rterm,
					rbindings: e.rbindings,
				}
				r.plugged[e.pos] = plugged
				return r.eval(iter)
			}
		}
	}
	cpy.node = node
	return cpy.eval(iter)
}
// enumerate handles a non-ground ref element at e.pos by trying every
// candidate key. Candidates are the keys of the base document at
// e.plugged[:e.pos] (array indices, object keys or set elements), followed
// by the sorted children of the rule tree node. Each candidate is unified
// with the ref element before descending. If inlining is disabled for the
// prefix, the unification is saved instead.
func (e evalTree) enumerate(iter unifyIterator) error {
	if e.e.inliningControl.Disabled(e.plugged[:e.pos], true) {
		return e.e.saveUnify(ast.NewTerm(e.plugged), e.rterm, e.bindings, e.rbindings, iter)
	}

	doc, err := e.e.Resolve(e.plugged[:e.pos])
	if err != nil {
		return err
	}

	// visit binds the current ref element to key and descends one level.
	visit := func(key *ast.Term) error {
		return e.e.biunify(key, e.ref[e.pos], e.bindings, e.bindings, func() error {
			return e.next(iter, key)
		})
	}

	switch doc := doc.(type) {
	case *ast.Array:
		for i := 0; i < doc.Len(); i++ {
			if err := visit(ast.IntNumberTerm(i)); err != nil {
				return err
			}
		}
	case ast.Object:
		ki := doc.KeysIterator()
		for k, more := ki.Next(); more; k, more = ki.Next() {
			if err := visit(k); err != nil {
				return err
			}
		}
	case ast.Set:
		if err := doc.Iter(visit); err != nil {
			return err
		}
	}

	if e.node == nil {
		return nil
	}
	for _, k := range e.node.Sorted {
		if err := visit(ast.NewTerm(k)); err != nil {
			return err
		}
	}
	return nil
}
// extent computes the value at e.plugged by merging the base document there
// with the virtual documents below e.node. It returns nil when neither
// exists, and a merge-conflict error when the two cannot be merged.
func (e evalTree) extent() (*ast.Term, error) {
	base, err := e.e.Resolve(e.plugged)
	if err != nil {
		return nil, err
	}

	virtual, err := e.leaves(e.plugged, e.node)
	if err != nil {
		return nil, err
	}

	switch {
	case virtual == nil && base == nil:
		return nil, nil
	case virtual == nil:
		return ast.NewTerm(base), nil
	case base == nil:
		return ast.NewTerm(virtual), nil
	}

	merged, ok := merge(base, virtual)
	if !ok {
		return nil, mergeConflictErr(e.plugged[0].Location)
	}
	return ast.NewTerm(merged), nil
}
// leaves builds a tree from evaluating the full rule tree extent, by
// recursing into all branches and building up objects as it goes.
//
// plugged is the ref of node. It is extended with each child's key while
// that child is processed and trimmed back afterwards. Hidden children are
// skipped. Children that carry rules are evaluated by unifying their ref
// with a generated variable; other children are recursed into. It returns
// a nil object when node is nil.
func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error) {
	// Guard the node that is actually walked below. Recursive calls pass a
	// child node, so the receiver's e.node is not what gets dereferenced.
	if node == nil {
		return nil, nil
	}
	result := ast.NewObject()
	for _, k := range node.Sorted {
		child := node.Children[k]
		if child.Hide {
			continue
		}
		plugged = append(plugged, ast.NewTerm(child.Key))
		var save ast.Value
		var err error
		if len(child.Values) > 0 {
			rterm := e.e.generateVar("leaf")
			err = e.e.unify(ast.NewTerm(plugged), rterm, func() error {
				save = e.e.bindings.Plug(rterm).Value
				return nil
			})
		} else {
			save, err = e.leaves(plugged, child)
		}
		if err != nil {
			return nil, err
		}
		if save != nil {
			v := ast.NewObject([2]*ast.Term{plugged[len(plugged)-1], ast.NewTerm(save)})
			result, _ = result.Merge(v)
		}
		plugged = plugged[:len(plugged)-1]
	}
	return result, nil
}
// evalVirtual evaluates the part of a ref that is covered by rules. It
// starts at pos, the position where the rule tree walk found rules. Its
// eval method chooses between evalVirtualPartial and evalVirtualComplete.
// The fields have the same meaning as in evalTree.
type evalVirtual struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
}
// eval dispatches evaluation to the evaluator matching the kind of rules
// found for e.plugged[:e.pos+1]:
//   - multi-value rules go to evalVirtualPartial, building a set (ground
//     rule refs) or an object with set leaves;
//   - single-value rules with variables in their refs go to
//     evalVirtualPartial, building an object;
//   - single-value rules with ground refs go to evalVirtualComplete.
//
// An unknown ref over rules with else-branches is saved instead.
func (e evalVirtual) eval(iter unifyIterator) error {
	ir, err := e.e.getRules(e.plugged[:e.pos+1], nil)
	if err != nil {
		return err
	}

	// Partial evaluation of ordered rules is not supported currently. Save the
	// expression and continue. This could be revisited in the future.
	if len(ir.Else) > 0 && e.e.unknown(e.ref, e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.ref), e.rterm, e.bindings, e.rbindings, iter)
	}

	partial := evalVirtualPartial{
		e:         e.e,
		ref:       e.ref,
		plugged:   e.plugged,
		pos:       e.pos,
		ir:        ir,
		bindings:  e.bindings,
		rterm:     e.rterm,
		rbindings: e.rbindings,
	}

	switch ir.Kind {
	case ast.MultiValue:
		if ir.OnlyGroundRefs {
			// rule ref contains no vars, so we're building a set
			partial.empty = ast.SetTerm()
		} else {
			// rule ref contains vars, so we're building an object containing a set leaf
			partial.empty = ast.ObjectTerm()
		}
		return partial.eval(iter)
	case ast.SingleValue:
		if !ir.OnlyGroundRefs {
			partial.empty = ast.ObjectTerm()
			return partial.eval(iter)
		}
		complete := evalVirtualComplete{
			e:         e.e,
			ref:       e.ref,
			plugged:   e.plugged,
			pos:       e.pos,
			ir:        ir,
			bindings:  e.bindings,
			rterm:     e.rterm,
			rbindings: e.rbindings,
		}
		return complete.eval(iter)
	default:
		panic("unreachable")
	}
}
// evalVirtualPartial evaluates a ref into a virtual document that is built
// up from rule solutions: a set, or an object that may contain sets. ir
// holds the applicable rules. empty is the initial accumulator for the
// document (an empty set or object, chosen in evalVirtual.eval). The other
// fields have the same meaning as in evalVirtual.
type evalVirtualPartial struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	ir        *ast.IndexResult
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
	empty     *ast.Term
}
// evalVirtualPartialCacheHint is the outcome of evalVirtualPartial.evalCache.
//
//	key   the virtual-cache key the caller should store its result under
//	      (nil when nothing should be cached)
//	hit   a cached value was found and already evaluated
//	full  the full extent should be computed and cached under key
type evalVirtualPartialCacheHint struct {
	key  ast.Ref
	hit  bool
	full bool
}
// eval evaluates e.ref against the partial virtual document.
//
// When the ref ends at the rule prefix, the whole document is produced, or
// support rules are generated if the prefix is unknown. For longer refs,
// support rules are generated when the prefix is unknown under shallow
// inlining, or when inlining is disabled for the prefix. Otherwise each rule
// is evaluated against the remainder of the ref.
func (e evalVirtualPartial) eval(iter unifyIterator) error {
	prefix := e.ref[:e.pos+1]
	unknown := e.e.unknown(prefix, e.bindings)

	if len(e.ref) == e.pos+1 {
		if unknown {
			return e.partialEvalSupport(iter)
		}
		return e.evalAllRules(iter, e.ir.Rules)
	}

	if (unknown && e.e.inliningControl.shallow) || e.e.inliningControl.Disabled(prefix, false) {
		return e.partialEvalSupport(iter)
	}
	return e.evalEachRule(iter, unknown)
}
// maxRefLength returns the maximum length a ref can be without being longer
// than the longest rule ref in rules, capped at ceil. Multi-value rules count
// one extra position for their key. An empty rules slice yields 0.
func maxRefLength(rules []*ast.Rule, ceil int) int {
	longest := 0
	for _, r := range rules {
		n := len(r.Ref())
		if r.Head.RuleKind() == ast.MultiValue {
			n++ // the key occupies one more ref position
		}
		if n >= ceil {
			return ceil
		}
		if n > longest {
			longest = n
		}
	}
	return longest
}
// evalEachRule evaluates the rules of the partial document for a ref that
// extends beyond the rule prefix.
//
//   - If the part of e.ref that can overlap with rule refs is unknown, each
//     rule body is evaluated first and the head is unified afterwards
//     (evalOneRulePostUnify).
//   - Otherwise the virtual cache is consulted, and a hit finishes
//     immediately.
//   - On a full-extent miss, all rules are evaluated, the result is cached
//     and then evaluated.
//   - On any other miss, rules are evaluated one by one with the head
//     unified first. The subtree at hint.key is cached when found, and the
//     accumulated result is evaluated unless the prefix is unknown.
func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error {
	if e.ir.Empty() {
		return nil
	}
	// NOTE(review): if ir.Rules is empty (e.g. only a default rule matched),
	// m is 0 and the slice below would panic. Confirm that cannot happen here.
	m := maxRefLength(e.ir.Rules, len(e.ref))
	if e.e.unknown(e.ref[e.pos+1:m], e.bindings) {
		for _, rule := range e.ir.Rules {
			if err := e.evalOneRulePostUnify(iter, rule); err != nil {
				return err
			}
		}
		return nil
	}
	hint, err := e.evalCache(iter)
	if err != nil {
		return err
	} else if hint.hit {
		return nil
	}
	if hint.full {
		result, err := e.evalAllRulesNoCache(e.ir.Rules)
		if err != nil {
			return err
		}
		e.e.virtualCache.Put(hint.key, result)
		return e.evalTerm(iter, e.pos+1, result, e.bindings)
	}
	result := e.empty
	var visitedRefs []ast.Ref
	for _, rule := range e.ir.Rules {
		result, err = e.evalOneRulePreUnify(iter, rule, result, unknown, &visitedRefs)
		if err != nil {
			return err
		}
	}
	// Cache the subtree that corresponds to the longest ground prefix tried
	// by evalCache, if the accumulated result contains it.
	if hint.key != nil {
		if v, err := result.Value.Find(hint.key[e.pos+1:]); err == nil && v != nil {
			e.e.virtualCache.Put(hint.key, ast.NewTerm(v))
		}
	}
	// With an unknown prefix, evalOneRulePreUnify already fed each solution
	// to evalTerm, so the accumulated result is only evaluated otherwise.
	if !unknown {
		return e.evalTerm(iter, e.pos+1, result, e.bindings)
	}
	return nil
}
// evalAllRules unifies the full extent of the virtual document at
// e.plugged[:e.pos+1] with e.rterm. The extent comes from the virtual cache
// when available. Otherwise it is computed from rules and then cached.
func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule) error {
	cacheKey := e.plugged[:e.pos+1]

	if cached, _ := e.e.virtualCache.Get(cacheKey); cached != nil {
		e.e.instr.counterIncr(evalOpVirtualCacheHit)
		return e.e.biunify(cached, e.rterm, e.bindings, e.rbindings, iter)
	}
	e.e.instr.counterIncr(evalOpVirtualCacheMiss)

	result, err := e.evalAllRulesNoCache(rules)
	if err != nil {
		return err
	}
	if cacheKey != nil {
		e.e.virtualCache.Put(cacheKey, result)
	}
	return e.e.biunify(result, e.rterm, e.bindings, e.rbindings, iter)
}
// evalAllRulesNoCache evaluates every rule in rules and folds all body
// solutions into a single document, starting from e.empty. The virtual
// cache is neither read nor written.
func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, error) {
	acc := e.empty
	var visited []ast.Ref
	for _, rule := range rules {
		child := e.e.child(rule.Body)
		child.traceEnter(rule)
		if err := child.eval(func(*eval) error {
			child.traceExit(rule)
			var reduceErr error
			if acc, _, reduceErr = e.reduce(rule, child.bindings, acc, &visited); reduceErr != nil {
				return reduceErr
			}
			child.traceRedo(rule)
			return nil
		}); err != nil {
			return nil, err
		}
	}
	return acc, nil
}
// wrapInObjects nests leaf inside one single-entry object per element of
// ref, outermost first. For example, wrapInObjects(v, [a, b]) is {a: {b: v}}.
// The objects are built leaf-to-root to preserve ground:ness. An empty ref
// returns leaf itself.
func wrapInObjects(leaf *ast.Term, ref ast.Ref) *ast.Term {
	wrapped := leaf
	for i := len(ref) - 1; i >= 0; i-- {
		wrapped = ast.ObjectTerm(ast.Item(ref[i], wrapped))
	}
	return wrapped
}
// evalOneRulePreUnify evaluates rule after first unifying the dynamic part
// of e.ref (from e.pos+1) with the rule's ref and key. For each body solution:
//   - if unknown, the rule's value is turned into a term right away. The
//     head key is used when there is no value; multi-value rules wrap it in
//     a set; the rest of the rule ref wraps it in nested objects. The
//     remainder of e.ref is then evaluated against that term;
//   - otherwise the solution is folded into result via reduce, and a
//     duplicate solution is traced as such.
//
// It returns the possibly updated accumulated result. A rule whose head
// never unifies is traced as failed.
func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, result *ast.Term, unknown bool, visitedRefs *[]ast.Ref) (*ast.Term, error) {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	var defined bool
	// Fall back to the last element of the head reference when no key is set.
	headKey := rule.Head.Key
	if headKey == nil {
		headKey = rule.Head.Reference[len(rule.Head.Reference)-1]
	}
	// Walk the dynamic portion of rule ref and key to unify vars
	err := child.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, child.bindings, func(pos int) error {
		defined = true
		return child.eval(func(child *eval) error {
			child.traceExit(rule)
			term := rule.Head.Value
			if term == nil {
				term = headKey
			}
			if unknown {
				term, termbindings := child.bindings.apply(term)
				if rule.Head.RuleKind() == ast.MultiValue {
					term = ast.SetTerm(term)
				}
				objRef := rule.Ref()[e.pos+1:]
				term = wrapInObjects(term, objRef)
				err := e.evalTerm(iter, e.pos+1, term, termbindings)
				if err != nil {
					return err
				}
			} else {
				var dup bool
				var err error
				result, dup, err = e.reduce(rule, child.bindings, result, visitedRefs)
				if err != nil {
					return err
				} else if !unknown && dup {
					child.traceDuplicate(rule)
					return nil
				}
			}
			child.traceRedo(rule)
			return nil
		})
	})
	if err != nil {
		return nil, err
	}
	if !defined {
		child.traceFail(rule)
	}
	return result, nil
}
// biunifyRuleHead unifies ref, from pos onwards, with the rule's ref.
//
// For a multi-value rule whose ref is no longer than ref, where ref still
// has an element left after the rule ref, that element is also unified with
// the rule's key. The last element of the head reference is used when no
// key is set. iter receives the first position of ref that was not unified.
func (e *eval) biunifyRuleHead(pos int, ref ast.Ref, rule *ast.Rule, refBindings, ruleBindings *bindings, iter unifyRefIterator) error {
	ruleRef := rule.Ref()
	return e.biunifyDynamicRef(pos, ref, ruleRef, refBindings, ruleBindings, func(pos int) error {
		// FIXME: Is there a simpler, more robust way of figuring out that we should biunify the rule key?
		keyed := rule.Head.RuleKind() == ast.MultiValue && pos < len(ref) && len(ruleRef) <= len(ref)
		if !keyed {
			return iter(pos)
		}
		headKey := rule.Head.Key
		if headKey == nil {
			headKey = rule.Head.Reference[len(rule.Head.Reference)-1]
		}
		return e.biunify(ref[pos], headKey, refBindings, ruleBindings, func() error {
			return iter(pos + 1)
		})
	})
}
// biunifyDynamicRef unifies a[i] with b[i] for i = pos, pos+1, ... and stops
// at the end of the shorter ref. It then calls iter with the first position
// that was not unified.
func (e *eval) biunifyDynamicRef(pos int, a, b ast.Ref, b1, b2 *bindings, iter unifyRefIterator) error {
	if pos < len(a) && pos < len(b) {
		return e.biunify(a[pos], b[pos], b1, b2, func() error {
			return e.biunifyDynamicRef(pos+1, a, b, b1, b2, iter)
		})
	}
	return iter(pos)
}
// evalOneRulePostUnify evaluates the body of rule first, and only then
// unifies the rule head with e.ref. Each resulting binding continues via
// evalOneRuleContinue. A rule whose body never succeeds is traced as failed.
func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.Rule) error {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)

	defined := false
	err := child.eval(func(solution *eval) error {
		defined = true
		return e.e.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, solution.bindings, func(int) error {
			return e.evalOneRuleContinue(iter, rule, solution)
		})
	})
	if err != nil {
		return err
	}

	if !defined {
		child.traceFail(rule)
	}
	return nil
}
// evalOneRuleContinue turns one solution of rule into a term and evaluates
// the remainder of e.ref against it. The term is the head value, or the
// head key if there is no value. Multi-value rules wrap it in a set, and the
// rest of the rule ref wraps it in nested objects.
func (e evalVirtualPartial) evalOneRuleContinue(iter unifyIterator, rule *ast.Rule, child *eval) error {
	child.traceExit(rule)

	head := rule.Head
	term := head.Value
	if term == nil {
		term = head.Key
	}

	term, termbindings := child.bindings.apply(term)
	if head.RuleKind() == ast.MultiValue {
		term = ast.SetTerm(term)
	}
	term = wrapInObjects(term, rule.Ref()[e.pos+1:])

	if err := e.evalTerm(iter, e.pos+1, term, termbindings); err != nil {
		return err
	}
	child.traceRedo(rule)
	return nil
}
// partialEvalSupport makes sure support rules exist for the virtual document
// at e.plugged[:e.pos+1]; they are generated only once per namespaced path.
// It then saves a unification of the namespaced e.ref with e.rterm.
//
// If no rule body succeeded, the outcome depends on what was queried. For
// the whole document, the unification is saved against e.empty instead. For
// a sub-path, nothing is saved.
func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error {
	path := e.e.namespaceRef(e.plugged[:e.pos+1])
	term := ast.NewTerm(e.e.namespaceRef(e.ref))

	defined := e.e.saveSupport.Exists(path)
	if !defined {
		for _, rule := range e.ir.Rules {
			ok, err := e.partialEvalSupportRule(rule, path)
			if err != nil {
				return err
			}
			defined = defined || ok
		}
	}

	if !defined {
		if len(e.ref) != e.pos+1 {
			return nil
		}
		// the entire partial set/obj was queried, e.g. data.a.q (not data.a.q[x])
		term = e.empty
	}
	return e.e.saveUnify(term, e.rterm, e.bindings, e.rbindings, iter)
}
// partialEvalSupportRule partially evaluates rule. For each body solution
// whose plugged body passes type checking, it inserts a support rule into
// the rule's package.
//
// The support rule's head is built from the namespaced, plugged rule ref
// and the plugged head value:
//   - multi-value rules keep their plugged key;
//   - single-value rules with a two-element rule ref use the last element
//     as the key;
//   - an empty head name is filled from the first ref element for short refs.
//
// Unless inlining is shallow, copy propagation is applied to the body. The
// default flag is carried over. It reports whether the rule body succeeded
// at least once.
func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) (bool, error) {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	e.e.saveStack.PushQuery(nil)
	var defined bool
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		defined = true
		// Take the expressions saved for this solution. The query is pushed
		// back below so that backtracking continues from the same state.
		current := e.e.saveStack.PopQuery()
		plugged := current.Plug(e.e.caller.bindings)
		// Skip this rule body if it fails to type-check.
		// Type-checking failure means the rule body will never succeed.
		if e.e.compiler.PassesTypeCheck(plugged) {
			var value *ast.Term
			if rule.Head.Value != nil {
				value = child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)
			}
			// Namespace the rule ref, then plug every element after the root.
			ref := e.e.namespaceRef(rule.Ref())
			for i := 1; i < len(ref); i++ {
				ref[i] = child.bindings.plugNamespaced(ref[i], e.e.caller.bindings)
			}
			pkg, ruleRef := splitPackageAndRule(ref)
			head := ast.RefHead(ruleRef, value)
			// key is also part of ref in single-value rules, and can be dropped
			if rule.Head.Key != nil && rule.Head.RuleKind() == ast.MultiValue {
				head.Key = child.bindings.PlugNamespaced(rule.Head.Key, e.e.caller.bindings)
			}
			if rule.Head.RuleKind() == ast.SingleValue && len(ruleRef) == 2 {
				head.Key = ruleRef[len(ruleRef)-1]
			}
			if head.Name.Equal(ast.Var("")) && (len(ruleRef) == 1 || (len(ruleRef) == 2 && rule.Head.RuleKind() == ast.SingleValue)) {
				head.Name = ruleRef[0].Value.(ast.Var)
			}
			if !e.e.inliningControl.shallow {
				cp := copypropagation.New(head.Vars()).
					WithEnsureNonEmptyBody(true).
					WithCompiler(e.e.compiler)
				plugged = applyCopyPropagation(cp, e.e.instr, plugged)
			}
			e.e.saveSupport.InsertByPkg(pkg, &ast.Rule{
				Head:    head,
				Body:    plugged,
				Default: rule.Default,
			})
		}
		child.traceRedo(rule)
		e.e.saveStack.PushQuery(current)
		return nil
	})
	e.e.saveStack.PopQuery()
	return defined, err
}
// evalTerm continues evaluating e.ref from position pos against term, which
// is bound under termbindings. e.rterm and e.rbindings are passed along as
// the unification target.
func (e evalVirtualPartial) evalTerm(iter unifyIterator, pos int, term *ast.Term, termbindings *bindings) error {
	step := evalTerm{
		e:            e.e,
		ref:          e.ref,
		pos:          pos,
		bindings:     e.bindings,
		term:         term,
		termbindings: termbindings,
		rterm:        e.rterm,
		rbindings:    e.rbindings,
	}
	return step.eval(iter)
}
// evalCache consults the virtual cache for the partial document evaluated
// by e and reports the outcome in the returned hint.
//
//   - If the rule prefix is unknown, the hint is empty.
//   - If the full extent at e.plugged[:e.pos+1] is cached, it is evaluated
//     right away (hint.hit).
//   - If the next ref element plugs to a variable, the caller is asked to
//     compute and cache the full extent (hint.full, hint.key = the prefix).
//   - Otherwise successively longer ground prefixes of e.ref are tried,
//     bounded by maxRefLength, and the first cached subtree is evaluated.
//     On a miss, hint.key is the longest ground prefix tried, so the caller
//     can cache its result under it.
func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCacheHint, error) {
	var hint evalVirtualPartialCacheHint
	if e.e.unknown(e.ref[:e.pos+1], e.bindings) {
		// FIXME: Return empty hint if unknowns in any e.ref elem overlapping with applicable rule refs?
		return hint, nil
	}
	if cached, _ := e.e.virtualCache.Get(e.plugged[:e.pos+1]); cached != nil { // have full extent cached
		e.e.instr.counterIncr(evalOpVirtualCacheHit)
		hint.hit = true
		return hint, e.evalTerm(iter, e.pos+1, cached, e.bindings)
	}
	plugged := e.bindings.Plug(e.ref[e.pos+1])
	if _, ok := plugged.Value.(ast.Var); ok {
		hint.full = true
		hint.key = e.plugged[:e.pos+1]
		e.e.instr.counterIncr(evalOpVirtualCacheMiss)
		return hint, nil
	}
	m := maxRefLength(e.ir.Rules, len(e.ref))
	// Candidate keys are built in a slice of our own. Appending to
	// e.plugged[:i] directly would write the plugged terms into the backing
	// array that e.plugged shares with the enclosing evalTree/evalVirtual.
	var key ast.Ref
	for i := e.pos + 1; i < m; i++ {
		plugged = e.bindings.Plug(e.ref[i])
		if !plugged.IsGround() {
			break
		}
		if key == nil {
			// i < m <= len(e.ref), so this capacity fits every key built below.
			key = make(ast.Ref, i, len(e.ref))
			copy(key, e.plugged[:i])
		}
		key = append(key, plugged)
		hint.key = key
		if cached, _ := e.e.virtualCache.Get(hint.key); cached != nil {
			e.e.instr.counterIncr(evalOpVirtualCacheHit)
			hint.hit = true
			return hint, e.evalTerm(iter, i+1, cached, e.bindings)
		}
	}
	e.e.instr.counterIncr(evalOpVirtualCacheMiss)
	return hint, nil
}
// getNestedObject walks the chain of nested objects under *rootObj addressed
// by ref, plugging each element with b. Missing levels are created (and
// inserted) as empty objects. It returns a pointer to the innermost object,
// or objectDocKeyConflictErr(l) if an existing entry on the path is not an
// object. An empty ref returns rootObj itself.
func getNestedObject(ref ast.Ref, rootObj *ast.Object, b *bindings, l *ast.Location) (*ast.Object, error) {
	current := rootObj
	for _, term := range ref {
		key := b.Plug(term)
		child := (*current).Get(key)
		if child == nil {
			created := ast.NewObject()
			(*current).Insert(key, ast.NewTerm(created))
			current = &created
			continue
		}
		existing, ok := child.Value.(ast.Object)
		if !ok {
			return nil, objectDocKeyConflictErr(l)
		}
		current = &existing
	}
	return current, nil
}
// hasCollisions plugs path with b and reports whether it strictly extends
// any previously visited path, i.e. a visited path is a proper prefix of it.
// A path without collisions is appended to visitedRefs.
func hasCollisions(path ast.Ref, visitedRefs *[]ast.Ref, b *bindings) bool {
	plugged := b.Plug(ast.NewTerm(path)).Value.(ast.Ref)
	for _, seen := range *visitedRefs {
		if !plugged.Equal(seen) && plugged.HasPrefix(seen) {
			return true
		}
	}
	*visitedRefs = append(*visitedRefs, plugged)
	return false
}
// reduce folds one solution of rule, under bindings b, into result. result
// is mutated in place and also returned. It is either a set (the
// accumulator for multi-value rules with ground refs) or an object.
//
// For objects, the rule ref below e.pos+1 is turned into nested objects.
// At the leaf, single-value rules insert their value and multi-value rules
// add their key to a set. exists reports whether that value or key was
// already present. Conflicting paths, or differing values at the same key,
// yield objectDocKeyConflictErr.
func (e evalVirtualPartial) reduce(rule *ast.Rule, b *bindings, result *ast.Term, visitedRefs *[]ast.Ref) (*ast.Term, bool, error) {
	var exists bool
	head := rule.Head
	switch v := result.Value.(type) {
	case ast.Set:
		key := b.Plug(head.Key)
		exists = v.Contains(key)
		v.Add(key)
	case ast.Object:
		// data.p.q[r].s.t := 42 {...}
		//         |----|-|
		//          ^    ^
		//          |    leafKey
		//          objPath
		fullPath := rule.Ref()
		collisionPath := fullPath[e.pos+1:]
		if hasCollisions(collisionPath, visitedRefs, b) {
			return nil, false, objectDocKeyConflictErr(head.Location)
		}
		objPath := fullPath[e.pos+1 : len(fullPath)-1] // the portion of the ref that generates nested objects
		leafKey := b.Plug(fullPath[len(fullPath)-1])    // the portion of the ref that is the deepest nested key for the value
		leafObj, err := getNestedObject(objPath, &v, b, head.Location)
		if err != nil {
			return nil, false, err
		}
		if kind := head.RuleKind(); kind == ast.SingleValue {
			// We're inserting into an object
			val := b.Plug(head.Value) // head.Value instance is shared between rule enumerations;but this is ok, as we don't allow rules to modify each others values.
			if curr := (*leafObj).Get(leafKey); curr != nil {
				if !curr.Equal(val) {
					return nil, false, objectDocKeyConflictErr(head.Location)
				}
				exists = true
			} else {
				(*leafObj).Insert(leafKey, val)
			}
		} else {
			// We're inserting into a set
			var set *ast.Set
			if leaf := (*leafObj).Get(leafKey); leaf != nil {
				if s, ok := leaf.Value.(ast.Set); ok {
					set = &s
				} else {
					return nil, false, objectDocKeyConflictErr(head.Location)
				}
			} else {
				s := ast.NewSet()
				(*leafObj).Insert(leafKey, ast.NewTerm(s))
				set = &s
			}
			key := b.Plug(head.Key)
			exists = (*set).Contains(key)
			(*set).Add(key)
		}
	}
	return result, exists, nil
}
// evalVirtualComplete evaluates a reference into a complete virtual document
// (a rule set producing a single value) and unifies the value, or the part of
// it selected by the rest of the ref, with rterm.
type evalVirtualComplete struct {
	e *eval
	// ref is the reference being evaluated.
	ref ast.Ref
	// plugged is ref with bindings applied; plugged[:pos+1] names the
	// document and is used as its virtual-cache key.
	plugged ast.Ref
	// pos is the index in ref of the last element naming the rule.
	pos int
	// ir holds the candidate rules, their else chains and the default rule
	// (presumably as selected by the rule index).
	ir *ast.IndexResult
	// bindings apply to ref.
	bindings *bindings
	// rterm is the term the document value is unified with; rbindings apply to it.
	rterm     *ast.Term
	rbindings *bindings
}
// eval evaluates the complete document referred to by e.ref.
//
// Function rules (rules with arguments) have no full extent, so they are
// skipped. Without unknowns the value is evaluated directly, and early-exit
// signals are suppressed. With unknowns, the rules are partially evaluated:
// either into support rules (when the default value may be needed, or when
// inlining is shallow or disabled for this ref) or in place.
func (e evalVirtualComplete) eval(iter unifyIterator) error {
	if e.ir.Empty() {
		return nil
	}

	rulesAreFunctions := len(e.ir.Rules) > 0 && len(e.ir.Rules[0].Head.Args) > 0
	defaultIsFunction := e.ir.Default != nil && len(e.ir.Default.Head.Args) > 0
	if rulesAreFunctions || defaultIsFunction {
		return nil
	}

	if !e.e.unknown(e.ref, e.bindings) {
		return suppressEarlyExit(e.evalValue(iter, e.ir.EarlyExit))
	}

	// A support rule is needed when the default value might be the answer.
	// That is the case when the other side is not constant (it needs
	// evaluation), or when it is constant and equal to the default value. A
	// constant that differs from the default means the default is never
	// needed.
	needDefaultSupport := false
	if def := e.ir.Default; def != nil {
		other := e.rbindings.Plug(e.rterm)
		needDefaultSupport = !ast.IsConstant(other.Value) || def.Head.Value.Equal(other)
	}

	switch {
	case needDefaultSupport, e.e.inliningControl.shallow, e.e.inliningControl.Disabled(e.plugged[:e.pos+1], false):
		return e.partialEvalSupport(iter)
	default:
		return e.partialEval(iter)
	}
}
// evalValue evaluates the complete document identified by
// e.plugged[:e.pos+1] and continues unification of its value through iter.
//
// The virtual cache is consulted first. An entry reported as undefined ends
// evaluation with no results, and a cached value is used directly. On a miss,
// each rule is evaluated in order. When a rule's body produces no value, its
// else chain is tried until one branch does. The default rule is evaluated
// only if no rule produced a value. If nothing produced a value and there is
// no default rule, a nil entry is stored in the cache. (Presumably later Get
// calls report that nil entry as "undefined"; confirm against virtualCache.)
//
// findOne is forwarded to each rule's body evaluation.
func (e evalVirtualComplete) evalValue(iter unifyIterator, findOne bool) error {
	cached, undefined := e.e.virtualCache.Get(e.plugged[:e.pos+1])
	if undefined {
		e.e.instr.counterIncr(evalOpVirtualCacheHit)
		return nil
	}
	if cached != nil {
		e.e.instr.counterIncr(evalOpVirtualCacheHit)
		return e.evalTerm(iter, cached, e.bindings)
	}
	e.e.instr.counterIncr(evalOpVirtualCacheMiss)
	// prev is the value produced so far. It is handed to every subsequent
	// rule so that a differing value is reported as a conflict.
	var prev *ast.Term
	for _, rule := range e.ir.Rules {
		next, err := e.evalValueRule(iter, rule, prev, findOne)
		if err != nil {
			return err
		}
		if next == nil {
			// The rule's body never succeeded: fall back to its else chain,
			// stopping at the first branch that produces a value.
			for _, erule := range e.ir.Else[rule] {
				next, err = e.evalValueRule(iter, erule, prev, findOne)
				if err != nil {
					return err
				}
				if next != nil {
					break
				}
			}
		}
		if next != nil {
			prev = next
		}
	}
	if e.ir.Default != nil && prev == nil {
		_, err := e.evalValueRule(iter, e.ir.Default, prev, findOne)
		return err
	}
	if prev == nil {
		e.e.virtualCache.Put(e.plugged[:e.pos+1], nil)
	}
	return nil
}
// evalValueRule evaluates one rule of a complete document. For every
// successful evaluation of the rule body, the head value is plugged and then:
//   - if a value was produced earlier (prev, from a previous rule or from an
//     earlier solution of this rule) and differs, a complete-document conflict
//     error is returned;
//   - if it equals the earlier value, nothing more is done, because the
//     earlier evaluation already continued with it;
//   - otherwise the value is cached under e.plugged[:e.pos+1], and unification
//     continues with the remainder of the ref through evalTerm.
//
// It returns the last plugged head value, or nil if the body never succeeded,
// together with any evaluation error.
func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, prev *ast.Term, findOne bool) (*ast.Term, error) {
	child := e.e.child(rule.Body)
	child.findOne = findOne
	child.traceEnter(rule)
	var result *ast.Term
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		result = child.bindings.Plug(rule.Head.Value)
		if prev != nil {
			if ast.Compare(result, prev) != 0 {
				return completeDocConflictErr(rule.Location)
			}
			child.traceRedo(rule)
			return nil
		}
		// First value seen: remember it so further solutions are checked
		// against it, and cache it for later lookups of this document.
		prev = result
		e.e.virtualCache.Put(e.plugged[:e.pos+1], result)
		term, termbindings := child.bindings.apply(rule.Head.Value)
		err := e.evalTerm(iter, term, termbindings)
		if err != nil {
			return err
		}
		child.traceRedo(rule)
		return nil
	})
	return result, err
}
// partialEval partially evaluates each rule body of the complete document in
// place. For every solution, unification continues through iter with the
// rule's head value and the remainder of the ref.
func (e evalVirtualComplete) partialEval(iter unifyIterator) error {
	for _, rule := range e.ir.Rules {
		body := e.e.child(rule.Body)
		body.traceEnter(rule)

		onSolution := func(solved *eval) error {
			solved.traceExit(rule)
			value, valueBindings := solved.bindings.apply(rule.Head.Value)
			if err := e.evalTerm(iter, value, valueBindings); err != nil {
				return err
			}
			solved.traceRedo(rule)
			return nil
		}

		if err := body.eval(onSolution); err != nil {
			return err
		}
	}
	return nil
}
// partialEvalSupport partially evaluates the complete document into support
// rules and saves a unification between the namespaced ref and rterm.
//
// If support rules for the namespaced path do not exist yet, every rule and
// then the default rule (if any) are partially evaluated into support rules.
// When none of them is defined, nothing is saved and iter is not called.
func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error {
	path := e.e.namespaceRef(e.plugged[:e.pos+1])
	term := ast.NewTerm(e.e.namespaceRef(e.ref))

	defined := e.e.saveSupport.Exists(path)
	if !defined {
		// Every rule is processed, even after one is known to be defined,
		// so that all support rules get emitted.
		for _, rule := range e.ir.Rules {
			ok, err := e.partialEvalSupportRule(rule, path)
			if err != nil {
				return err
			}
			defined = defined || ok
		}
		if def := e.ir.Default; def != nil {
			ok, err := e.partialEvalSupportRule(def, path)
			if err != nil {
				return err
			}
			defined = defined || ok
		}
	}

	if !defined {
		return nil
	}
	return e.e.saveUnify(term, e.rterm, e.bindings, e.rbindings, iter)
}
// partialEvalSupportRule partially evaluates rule. For every successful
// evaluation of its body, it emits a support rule for path (split into
// package and rule ref) whose body is the query saved during that evaluation.
//
// Saved bodies that fail type-checking are skipped. Unless shallow inlining is
// enabled, copy propagation is applied to each saved body. The boolean result
// reports whether the rule body succeeded at least once, even if every
// resulting body was then dropped for failing type-checking.
func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref) (bool, error) {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	// Open a fresh save-stack frame so the expressions saved while
	// evaluating this body are collected separately.
	e.e.saveStack.PushQuery(nil)
	var defined bool
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		defined = true
		// Take the saved query for this solution. It is pushed back below so
		// that backtracking into further solutions sees the same frame.
		current := e.e.saveStack.PopQuery()
		plugged := current.Plug(e.e.caller.bindings)
		// Skip this rule body if it fails to type-check.
		// Type-checking failure means the rule body will never succeed.
		if e.e.compiler.PassesTypeCheck(plugged) {
			pkg, ruleRef := splitPackageAndRule(path)
			head := ast.RefHead(ruleRef, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings))
			if !e.e.inliningControl.shallow {
				cp := copypropagation.New(head.Vars()).
					WithEnsureNonEmptyBody(true).
					WithCompiler(e.e.compiler)
				plugged = applyCopyPropagation(cp, e.e.instr, plugged)
			}
			e.e.saveSupport.InsertByPkg(pkg, &ast.Rule{
				Head:    head,
				Body:    plugged,
				Default: rule.Default,
			})
		}
		child.traceRedo(rule)
		e.e.saveStack.PushQuery(current)
		return nil
	})
	// Discard the frame opened above.
	e.e.saveStack.PopQuery()
	return defined, err
}
// evalTerm continues evaluation of the ref past the rule name (from e.pos+1)
// into term, which carries termbindings, and unifies the result with
// e.rterm.
func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error {
	walker := evalTerm{
		e:            e.e,
		ref:          e.ref,
		pos:          e.pos + 1,
		bindings:     e.bindings,
		term:         term,
		termbindings: termbindings,
		rterm:        e.rterm,
		rbindings:    e.rbindings,
	}
	return walker.eval(iter)
}
// evalTerm walks the remainder of a reference (ref[pos:]) into an already
// computed value (term) and unifies whatever it reaches with rterm.
type evalTerm struct {
	e *eval
	// ref is the reference being evaluated. Elements before pos have
	// already been consumed.
	ref ast.Ref
	pos int
	// bindings apply to ref.
	bindings *bindings
	// term is the value reached so far; termbindings apply to it.
	term         *ast.Term
	termbindings *bindings
	// rterm is the term the final value is unified with; rbindings apply to it.
	rterm     *ast.Term
	rbindings *bindings
}
// eval advances one ref element into e.term.
//   - Once the whole ref is consumed, the reached term is unified with rterm.
//   - If the term is in the save set, the rest of the lookup is saved.
//   - A ground ref element selects a single child.
//   - A non-ground ref element is enumerated over all children.
func (e evalTerm) eval(iter unifyIterator) error {
	switch {
	case e.pos == len(e.ref):
		return e.e.biunify(e.term, e.rterm, e.termbindings, e.rbindings, iter)
	case e.e.saveSet.Contains(e.term, e.termbindings):
		return e.save(iter)
	}

	key := e.bindings.Plug(e.ref[e.pos])
	if !key.IsGround() {
		return e.enumerate(iter)
	}
	return e.next(iter, key)
}
// next descends into the child of e.term selected by the ground key plugged
// and continues evaluation at the following ref position. A missing child
// yields no results.
func (e evalTerm) next(iter unifyIterator, plugged *ast.Term) error {
	child, childBindings := e.get(plugged)
	if child == nil {
		return nil
	}
	// e is a value receiver, so these updates only affect this local copy.
	e.term, e.termbindings = child, childBindings
	e.pos++
	return e.eval(iter)
}
// enumerate handles a non-ground ref element at e.pos. It visits every key of
// the current term (array indices, object keys, or set elements), unifies the
// key with the ref element, and continues into the corresponding child. Other
// value types produce no results.
func (e evalTerm) enumerate(iter unifyIterator) error {
	step := e.ref[e.pos]
	switch coll := e.term.Value.(type) {
	case *ast.Array:
		for i := 0; i < coll.Len(); i++ {
			idx := ast.IntNumberTerm(i)
			err := e.e.biunify(idx, step, e.bindings, e.bindings, func() error {
				return e.next(iter, idx)
			})
			if err != nil {
				return err
			}
		}
		return nil
	case ast.Object:
		return coll.Iter(func(key, _ *ast.Term) error {
			return e.e.biunify(key, step, e.termbindings, e.bindings, func() error {
				return e.next(iter, e.termbindings.Plug(key))
			})
		})
	case ast.Set:
		return coll.Iter(func(elem *ast.Term) error {
			return e.e.biunify(elem, step, e.termbindings, e.bindings, func() error {
				return e.next(iter, e.termbindings.Plug(elem))
			})
		})
	default:
		return nil
	}
}
// get returns the child of e.term stored under the ground key plugged,
// together with the bindings to interpret it under. It returns (nil, nil)
// when there is no such child or when e.term is not a collection.
//
// For non-ground sets and objects, keys are compared after plugging them
// with e.termbindings.
func (e evalTerm) get(plugged *ast.Term) (*ast.Term, *bindings) {
	switch coll := e.term.Value.(type) {
	case ast.Set:
		if coll.IsGround() {
			if !coll.Contains(plugged) {
				return nil, nil
			}
			return e.termbindings.apply(plugged)
		}
		var (
			found         *ast.Term
			foundBindings *bindings
		)
		hit := coll.Until(func(elem *ast.Term) bool {
			if !e.termbindings.Plug(elem).Equal(plugged) {
				return false
			}
			found, foundBindings = e.termbindings.apply(plugged)
			return true
		})
		if hit {
			return found, foundBindings
		}
	case ast.Object:
		if coll.IsGround() {
			if value := coll.Get(plugged); value != nil {
				return e.termbindings.apply(value)
			}
			return nil, nil
		}
		var (
			found         *ast.Term
			foundBindings *bindings
		)
		hit := coll.Until(func(key, value *ast.Term) bool {
			if !e.termbindings.Plug(key).Equal(plugged) {
				return false
			}
			found, foundBindings = e.termbindings.apply(value)
			return true
		})
		if hit {
			return found, foundBindings
		}
	case *ast.Array:
		if elem := coll.Get(plugged); elem != nil {
			return e.termbindings.apply(elem)
		}
	}
	return nil, nil
}
// save binds the current term to a freshly generated variable and unifies a
// ref rooted at that variable (followed by the unconsumed ref suffix) with
// rterm. This defers the remainder of the lookup to the saved query.
func (e evalTerm) save(iter unifyIterator) error {
	placeholder := e.e.generateVar(fmt.Sprintf("ref_%d", e.e.genvarid))
	e.e.genvarid++
	return e.e.biunify(e.term, placeholder, e.termbindings, e.bindings, func() error {
		rest := e.ref[e.pos:]
		deferred := append(make(ast.Ref, 0, len(rest)+1), placeholder)
		deferred = append(deferred, rest...)
		return e.e.biunify(ast.NewTerm(deferred), e.rterm, e.bindings, e.rbindings, iter)
	})
}
// evalEvery evaluates an `every` expression: body must hold for every
// solution of generator.
type evalEvery struct {
	e *eval
	// expr is the original `every` expression, used for tracing and saving.
	expr *ast.Expr
	// generator enumerates the domain; body is checked for each of its solutions.
	generator ast.Body
	body      ast.Body
}
// eval evaluates the `every` expression.
//
// When the generator or the body depends on unknowns, the plugged expression
// is saved instead. Otherwise the generator (domain) is enumerated, and for
// each solution the body is evaluated looking for a single solution
// (findOne). iter is called once if every domain element has a body
// solution; otherwise the expression fails and iter is not called. Once an
// element without a body solution is found, the remaining domain elements are
// still enumerated, but their bodies are no longer evaluated.
func (e evalEvery) eval(iter unifyIterator) error {
	// unknowns in domain or body: save the expression, PE its body
	if e.e.unknown(e.generator, e.e.bindings) || e.e.unknown(e.body, e.e.bindings) {
		return e.save(iter)
	}
	domain := e.e.closure(e.generator)
	all := true // all generator evaluations yield one successful body evaluation
	domain.traceEnter(e.expr)
	err := domain.eval(func(child *eval) error {
		if !all {
			// NOTE(sr): Is this good enough? We don't have a "fail EE".
			// This would do extra work, like iterating needlessly if domain was a large array.
			return nil
		}
		body := child.closure(e.body)
		body.findOne = true
		body.traceEnter(e.body)
		// done records whether the body had at least one solution for this
		// domain element.
		done := false
		err := body.eval(func(*eval) error {
			body.traceExit(e.body)
			done = true
			body.traceRedo(e.body)
			return nil
		})
		if !done {
			all = false
		}
		child.traceRedo(e.expr)
		return err
	})
	if err != nil {
		return err
	}
	if all {
		err := iter()
		domain.traceExit(e.expr)
		return err
	}
	domain.traceFail(e.expr)
	return nil
}
// save records the `every` expression, plugged with the current bindings, as
// part of the partially evaluated query.
func (e *evalEvery) save(iter unifyIterator) error {
	plugged := e.plug(e.expr)
	return e.e.saveExpr(plugged, e.e.bindings, iter)
}
// plug returns a copy of expr (which must hold an *ast.Every) in which the
// body operands, key, value and domain are plugged with the current bindings,
// namespaced against the caller's bindings. Operators of call expressions in
// the body are left untouched, and nested `every` expressions are plugged
// recursively.
func (e *evalEvery) plug(expr *ast.Expr) *ast.Expr {
	plugNS := func(t *ast.Term) *ast.Term {
		return e.e.bindings.PlugNamespaced(t, e.e.caller.bindings)
	}

	out := expr.Copy()
	ev := out.Terms.(*ast.Every)

	for i, bodyExpr := range ev.Body {
		switch terms := bodyExpr.Terms.(type) {
		case *ast.Term:
			bodyExpr.Terms = plugNS(terms)
		case []*ast.Term:
			// terms[0] is the operator and is not plugged.
			for j := 1; j < len(terms); j++ {
				terms[j] = plugNS(terms[j])
			}
		case *ast.Every:
			ev.Body[i] = e.plug(bodyExpr)
		}
	}

	ev.Key = plugNS(ev.Key)
	ev.Value = plugNS(ev.Value)
	ev.Domain = plugNS(ev.Domain)
	out.Terms = ev
	return out
}
// comprehensionIndex returns the comprehension index for term. It prefers the
// query compiler when one is set and otherwise falls back to the module
// compiler.
func (e *eval) comprehensionIndex(term *ast.Term) *ast.ComprehensionIndex {
	if qc := e.queryCompiler; qc != nil {
		return qc.ComprehensionIndex(term)
	}
	return e.compiler.ComprehensionIndex(term)
}
// namespaceRef returns ref with the save namespace inserted at position 1.
// When namespacing is skipped, it returns a plain copy of ref.
func (e *eval) namespaceRef(ref ast.Ref) ast.Ref {
	if !e.skipSaveNamespace {
		return ref.Insert(e.saveNamespace, 1)
	}
	return ref.Copy()
}
// savePair couples a term collected by getSavePairsFromExpr/getSavePairsFromTerm
// with the bindings it must be interpreted under.
type savePair struct {
	term *ast.Term
	b    *bindings
}
// getSavePairsFromExpr appends to result the save pairs for the outputs of x:
//   - a single-term expression contributes its term;
//   - an equality contributes both operands (left, then right);
//   - a call with declArgsLen declared arguments plus one extra operand
//     contributes that last (output) operand.
//
// Other expressions contribute nothing.
func getSavePairsFromExpr(declArgsLen int, x *ast.Expr, b *bindings, result []savePair) []savePair {
	switch terms := x.Terms.(type) {
	case *ast.Term:
		result = getSavePairsFromTerm(terms, b, result)
	case []*ast.Term:
		switch {
		case x.IsEquality():
			result = getSavePairsFromTerm(terms[1], b, result)
			result = getSavePairsFromTerm(terms[2], b, result)
		case declArgsLen == len(terms)-2:
			result = getSavePairsFromTerm(terms[len(terms)-1], b, result)
		}
	}
	return result
}
// getSavePairsFromTerm appends to result the save pairs for x. A variable is
// recorded directly with b. For any other term, each variable it contains
// (skipping closures and ref heads) is resolved through b and processed
// recursively.
func getSavePairsFromTerm(x *ast.Term, b *bindings, result []savePair) []savePair {
	if _, isVar := x.Value.(ast.Var); isVar {
		return append(result, savePair{term: x, b: b})
	}

	visitor := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
		SkipClosures: true,
		SkipRefHead:  true,
	})
	visitor.Walk(x)

	for name := range visitor.Vars() {
		resolved, resolvedBindings := b.apply(ast.NewTerm(name))
		result = getSavePairsFromTerm(resolved, resolvedBindings, result)
	}
	return result
}
// applyCopyPropagation runs the copy propagator p over body. The time spent
// is recorded under the partialOpCopyPropagation timer of instr.
func applyCopyPropagation(p *copypropagation.CopyPropagator, instr *Instrumentation, body ast.Body) ast.Body {
	instr.startTimer(partialOpCopyPropagation)
	propagated := p.Apply(body)
	instr.stopTimer(partialOpCopyPropagation)
	return propagated
}
// nonGroundKeys reports whether at least one key of a is not ground.
func nonGroundKeys(a ast.Object) bool {
	isNonGround := func(key, _ *ast.Term) bool {
		return !key.IsGround()
	}
	return a.Until(isNonGround)
}
// plugKeys returns a copy of a whose keys are plugged with b; values are kept
// as-is. The mapper never returns an error, and any error reported by Map is
// deliberately ignored.
func plugKeys(a ast.Object, b *bindings) ast.Object {
	pluggedObj, _ := a.Map(func(key, value *ast.Term) (*ast.Term, *ast.Term, error) {
		return b.Plug(key), value, nil
	})
	return pluggedObj
}
// canInlineNegation reports whether the negation of the disjunction of queries
// can be inlined by complementing them (see complementedCartesianProduct).
// That requires all of the following:
//   - no expression contains a nested ref or call;
//   - every non-negated expression only uses variables that are already safe
//     (or reserved), so negating it does not make them unsafe;
//   - the complemented cartesian product has at most 16 bodies, where the
//     number of bodies is the product of the query lengths.
func canInlineNegation(safe ast.VarSet, queries []ast.Body) bool {
	// maxInlinedBodies bounds the size of the inlined cartesian product.
	const maxInlinedBodies = 16

	size := 1
	nested := newNestedCheckVisitor()

	for _, query := range queries {
		// Saturate rather than multiply without bound. With many queries the
		// raw product could overflow int and wrap to a value <= 16, wrongly
		// allowing a huge inline. An empty query still yields 0, as before.
		size *= len(query)
		if size > maxInlinedBodies {
			size = maxInlinedBodies + 1
		}

		for _, expr := range query {
			if containsNestedRefOrCall(nested, expr) {
				// Expressions containing nested refs or calls cannot be trivially negated
				// because the semantics would change. For example, the complement of `not f(input.x)`
				// is _not_ `f(input.x)`--it is `not input.x` OR `f(input.x)`.
				//
				// NOTE(tsandall): Since this would require the complement function to undo the
				// copy propagation optimization, just bail out here. If this becomes a problem
				// in the future, we can handle more cases.
				return false
			}
			if !expr.Negated {
				// Positive expressions containing variables cannot be trivially negated
				// because they become unsafe (e.g., "x = 1" negated is "not x = 1" making x
				// unsafe.) We check if the vars in the expr are already safe.
				varVis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
					SkipRefCallHead: true,
					SkipClosures:    true,
				})
				varVis.Walk(expr)
				unsafe := varVis.Vars().Diff(safe).Diff(ast.ReservedVars)
				if len(unsafe) > 0 {
					return false
				}
			}
		}
	}

	// NOTE(tsandall): this limit is arbitrary; it's only in place to prevent the
	// partial evaluation result from blowing up. In the future, we could make this
	// configurable or do something more clever.
	return size <= maxInlinedBodies
}
// nestedCheckVisitor wraps a generic AST visitor that records whether any
// ast.Ref or ast.Call has been visited.
type nestedCheckVisitor struct {
	vis *ast.GenericVisitor
	// found is sticky: visit sets it and never clears it.
	found bool
}
// newNestedCheckVisitor returns a nestedCheckVisitor whose generic visitor
// reports every node to its visit method.
func newNestedCheckVisitor() *nestedCheckVisitor {
	checker := new(nestedCheckVisitor)
	checker.vis = ast.NewGenericVisitor(checker.visit)
	return checker
}
// visit marks found when x is a ref or a call. It returns the (sticky) found
// flag, so the generic visitor stops descending once a match has been seen.
func (v *nestedCheckVisitor) visit(x interface{}) bool {
	_, isRef := x.(ast.Ref)
	_, isCall := x.(ast.Call)
	v.found = v.found || isRef || isCall
	return v.found
}
// containsNestedRefOrCall reports whether expr contains a ref or call nested
// inside it, which would make its complement differ from a plain negation.
//   - For equalities, each operand is checked; a top-level ref operand only
//     counts if its non-head elements contain refs or calls.
//   - For calls, any ref or call anywhere in an operand counts.
//   - For a single-term expression, the term is checked like an equality
//     operand.
//
// Any other expression form (e.g. an `every` expression, whose Terms is an
// *ast.Every) cannot be analyzed here. It is conservatively reported as
// containing a nested ref/call instead of panicking on a failed type
// assertion.
func containsNestedRefOrCall(vis *nestedCheckVisitor, expr *ast.Expr) bool {
	if expr.IsEquality() {
		for _, term := range expr.Operands() {
			if containsNestedRefOrCallInTerm(vis, term) {
				return true
			}
		}
		return false
	}

	if expr.IsCall() {
		for _, term := range expr.Operands() {
			vis.vis.Walk(term)
			if vis.found {
				return true
			}
		}
		return false
	}

	term, ok := expr.Terms.(*ast.Term)
	if !ok {
		return true
	}
	return containsNestedRefOrCallInTerm(vis, term)
}
// containsNestedRefOrCallInTerm reports whether term contains a nested ref or
// call.
//   - A ref only counts if one of its elements after the head contains a ref
//     or call.
//   - Any other value counts if a ref or call occurs anywhere within it.
func containsNestedRefOrCallInTerm(vis *nestedCheckVisitor, term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	if !isRef {
		vis.vis.Walk(term.Value)
		return vis.found
	}
	// Index loop (rather than ref[1:]) keeps an empty ref from panicking.
	for i := 1; i < len(ref); i++ {
		vis.vis.Walk(ref[i])
		if vis.found {
			return true
		}
	}
	return false
}
// complementedCartesianProduct calls iter once for every body formed by
// picking one expression from each of queries[idx:] and complementing it.
// Each such body is appended to curr. The first error returned by iter aborts
// the enumeration.
//
// curr's backing array is reused across iterations, so iter must not retain
// the body it is given without copying it.
func complementedCartesianProduct(queries []ast.Body, idx int, curr ast.Body, iter func(ast.Body) error) error {
	if idx == len(queries) {
		return iter(curr)
	}
	prefixLen := len(curr)
	for _, expr := range queries[idx] {
		curr = append(curr[:prefixLen], expr.Complement())
		if err := complementedCartesianProduct(queries, idx+1, curr, iter); err != nil {
			return err
		}
	}
	return nil
}
// isInputRef reports whether term is a ref rooted at `input`.
func isInputRef(term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	return isRef && ref.HasPrefix(ast.InputRootRef)
}
// isDataRef reports whether term is a ref rooted at `data`.
func isDataRef(term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	return isRef && ref.HasPrefix(ast.DefaultRootRef)
}
// isOtherRef reports whether term is a ref rooted at neither `data` nor
// `input`. It panics if term is not a ref at all.
func isOtherRef(term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	if !isRef {
		panic("unreachable")
	}
	switch {
	case ref.HasPrefix(ast.DefaultRootRef), ref.HasPrefix(ast.InputRootRef):
		return false
	default:
		return true
	}
}
// isFunction reports whether ref (an ast.Ref, or a *ast.Term wrapping one)
// has a function type in env. Non-ref values are never functions. Any other
// argument type causes a panic.
func isFunction(env *ast.TypeEnv, ref interface{}) bool {
	switch v := ref.(type) {
	case ast.Ref:
		_, isFn := env.Get(v).(*types.Function)
		return isFn
	case *ast.Term:
		return isFunction(env, v.Value)
	case ast.Value:
		return false
	default:
		panic("expected ast.Value or *ast.Term")
	}
}
// merge deep-merges b into a when both are objects (see mergeObjects). In
// every other case a is returned unchanged and the merge counts as
// successful.
func merge(a, b ast.Value) (ast.Value, bool) {
	if aObj, ok := a.(ast.Object); ok {
		if bObj, ok := b.(ast.Object); ok {
			return mergeObjects(aObj, bObj)
		}
	}
	// Nothing to merge: a wins.
	return a, true
}
// mergeObjects returns a new Object holding the union of the keys of objA and
// objB.
//   - A key present in only one object keeps that object's value.
//   - For a key present in both, the values are merged recursively when both
//     are objects; otherwise objA's value wins and objB's is discarded.
//
// The boolean result reports success, and the object is nil when it is false.
// With the current logic, the only failure source is a failed recursive merge,
// so the result is always true.
func mergeObjects(objA, objB ast.Object) (result ast.Object, ok bool) {
	result = ast.NewObject()

	failed := objA.Until(func(key, aVal *ast.Term) bool {
		bVal := objB.Get(key)
		if bVal == nil {
			result.Insert(key, aVal)
			return false
		}
		aObj, aIsObj := aVal.Value.(ast.Object)
		bObj, bIsObj := bVal.Value.(ast.Object)
		if !(aIsObj && bIsObj) {
			result.Insert(key, aVal)
			return false
		}
		merged, mergedOK := mergeObjects(aObj, bObj)
		if !mergedOK {
			return true
		}
		result.Insert(key, ast.NewTerm(merged))
		return false
	})
	if failed {
		return nil, false
	}

	objB.Foreach(func(key, bVal *ast.Term) {
		if objA.Get(key) == nil {
			result.Insert(key, bVal)
		}
	})
	return result, true
}
// refContainsNonScalar reports whether any element of ref after its head is a
// non-scalar value. It panics on an empty ref.
func refContainsNonScalar(ref ast.Ref) bool {
	operands := ref[1:]
	for i := range operands {
		if !ast.IsScalar(operands[i].Value) {
			return true
		}
	}
	return false
}
// suppressEarlyExit converts an early-exit signal back into the error it
// wrapped. That wrapped error is nil when evaluation simply finished early.
// Any other error is returned unchanged.
func suppressEarlyExit(err error) error {
	if ee, isEarlyExit := err.(*earlyExitError); isEarlyExit {
		return ee.prev
	}
	return err
}
// updateSavedMocks returns copies of the `with` modifiers that should be kept
// on saved expressions. Modifiers whose target is neither a data nor an input
// ref, or whose target is a function, are dropped.
func (e *eval) updateSavedMocks(withs []*ast.With) []*ast.With {
	kept := make([]*ast.With, 0, len(withs))
	for _, w := range withs {
		drop := isOtherRef(w.Target) || isFunction(e.compiler.TypeEnv, w.Target)
		if !drop {
			kept = append(kept, w.Copy())
		}
	}
	return kept
}