Files
kubesphere/vendor/github.com/open-policy-agent/opa/topdown/eval.go
hongming 9769357005 update
Signed-off-by: hongming <talonwan@yunify.com>
2020-03-20 02:16:11 +08:00

2383 lines
53 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package topdown
import (
"context"
"fmt"
"sort"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/copypropagation"
)
type (
	// evalIterator is the continuation invoked with the eval once a query (or
	// the remainder of a query) has been satisfied.
	evalIterator func(*eval) error

	// unifyIterator is the continuation invoked after a successful
	// unification step.
	unifyIterator func() error
)
// queryIDFactory hands out sequential query identifiers.
type queryIDFactory struct {
	curr uint64
}

// Next returns the current identifier and then advances the counter, so the
// first call to Next() returns 0.
func (f *queryIDFactory) Next() uint64 {
	id := f.curr
	f.curr = id + 1
	return id
}
// eval holds the state used to evaluate one query (a body of expressions).
// Derived evals (see closure and child) are produced by copying this struct,
// so most fields are shared with the eval they were derived from.
type eval struct {
	// ctx is passed to storage reads and to builtins.
	ctx context.Context

	// queryID identifies this query in trace events and generated var names.
	// queryIDFact assigns IDs to evals derived via closure/child. parent is the
	// eval this one was derived from. caller's bindings are used to plug and
	// namespace results during partial evaluation (caller is set outside this
	// section of the file).
	queryID     uint64
	queryIDFact *queryIDFactory
	parent      *eval
	caller      *eval

	// cancel, when non-nil, is checked before every expression (evalExpr).
	cancel Cancel

	// query is the body being evaluated and index the position of the current
	// expression in it. queryCompiler, when set, is consulted for rewritten var
	// names (rewrittenVar).
	query         ast.Body
	queryCompiler ast.QueryCompiler
	index         int

	// indexing selects rule-index lookups (true) or all rules (false) in getRules.
	indexing bool

	// bindings holds the variable bindings for this query.
	bindings *bindings

	// store and txn are used to read base documents; baseCache caches their
	// AST conversion by ref.
	store     storage.Store
	baseCache *baseCache
	txn       storage.Transaction

	// compiler supplies the rule tree, rule index and rewritten vars.
	compiler *ast.Compiler

	// input and data hold the input document and data overrides; both are
	// replaced while `with` statements are in effect. targetStack records the
	// refs targeted by the active `with` statements.
	input       *ast.Term
	data        *ast.Term
	targetStack *refStack

	// tracers receive trace events; instr collects timers and counters.
	tracers []Tracer
	instr   *Instrumentation

	// builtins holds custom builtins by name; builtinCache is handed to every
	// builtin call.
	builtins     map[string]*Builtin
	builtinCache builtins.Cache

	// virtualCache is pushed/popped around each `with` statement; presumably it
	// caches virtual document results (TODO confirm).
	virtualCache *virtualCache

	// Partial evaluation state. A non-nil saveSet enables partial evaluation
	// (see partial). saveStack holds saved expressions, saveSupport collects
	// generated support rules, and saveNamespace is the path segment under data
	// where support rules are placed.
	saveSet       *saveSet
	saveStack     *saveStack
	saveSupport   *saveSupport
	saveNamespace *ast.Term

	// disableInlining holds one entry per active `with` statement listing the
	// refs whose inlining is disabled during partial evaluation.
	disableInlining [][]ast.Ref

	// genvarprefix prefixes every var name produced by generateVar.
	genvarprefix string

	// runtime is handed to builtins via BuiltinContext.
	runtime *ast.Term
}
// Run evaluates the query, tracing Enter before evaluation starts and Exit /
// Redo around every solution handed to iter.
func (e *eval) Run(iter evalIterator) error {
	e.traceEnter(e.query)

	onSolution := func(e *eval) error {
		e.traceExit(e.query)
		result := iter(e)
		e.traceRedo(e.query)
		return result
	}

	return e.eval(onSolution)
}
// builtinFunc resolves a builtin by name. Names declared in ast.BuiltinMap are
// resolved against the standard implementations; any other name is looked up
// among the custom builtins registered on this eval. The boolean result
// reports whether an implementation was found.
func (e *eval) builtinFunc(name string) (*ast.Builtin, BuiltinFunc, bool) {
	if decl, isStd := ast.BuiltinMap[name]; isStd {
		if impl, found := builtinFunctions[name]; found {
			return decl, impl, true
		}
		return nil, nil, false
	}

	if custom, found := e.builtins[name]; found {
		return custom.Decl, custom.Func, true
	}
	return nil, nil, false
}
// closure returns a copy of e that evaluates query from its first expression
// while sharing e's bindings. The copy receives a fresh query ID and records
// e as its parent.
func (e *eval) closure(query ast.Body) *eval {
	cpy := *e
	cpy.query, cpy.index = query, 0
	cpy.queryID = e.queryIDFact.Next()
	cpy.parent = e
	return &cpy
}
// child is like closure but gives the new eval its own, empty bindings (keyed
// by the new query ID), so variables from e are not visible to query.
func (e *eval) child(query ast.Body) *eval {
	c := e.closure(query)
	c.bindings = newBindings(c.queryID, e.instr)
	return c
}
// next evaluates the rest of the query starting at the following expression,
// then restores the current position so the caller can keep backtracking from
// where it was.
func (e *eval) next(iter evalIterator) error {
	e.index++
	defer func() { e.index-- }()
	return e.evalExpr(iter)
}
// partial reports whether partial evaluation is enabled, which is signalled by
// the presence of a save set.
func (e *eval) partial() bool {
	enabled := e.saveSet != nil
	return enabled
}
// unknown reports whether x depends on values in the save set under bindings b
// and therefore has to be saved. It is always false outside partial evaluation.
func (e *eval) unknown(x interface{}, b *bindings) bool {
	if !e.partial() {
		return false
	}

	// saveRequired works on terms, so a bare ast.Value (e.g., an ast.Ref) is
	// wrapped first.
	target := x
	if val, isValue := x.(ast.Value); isValue {
		target = ast.NewTerm(val)
	}

	return saveRequired(e.compiler, e.saveSet, b, target, false)
}
// traceEnter emits an Enter event for x.
func (e *eval) traceEnter(x ast.Node) { e.traceEvent(EnterOp, x, "") }

// traceExit emits an Exit event for x.
func (e *eval) traceExit(x ast.Node) { e.traceEvent(ExitOp, x, "") }

// traceEval emits an Eval event for x.
func (e *eval) traceEval(x ast.Node) { e.traceEvent(EvalOp, x, "") }

// traceFail emits a Fail event for x.
func (e *eval) traceFail(x ast.Node) { e.traceEvent(FailOp, x, "") }

// traceRedo emits a Redo event for x.
func (e *eval) traceRedo(x ast.Node) { e.traceEvent(RedoOp, x, "") }

// traceSave emits a Save event for x.
func (e *eval) traceSave(x ast.Node) { e.traceEvent(SaveOp, x, "") }

// traceIndex emits an Index event for x carrying msg.
func (e *eval) traceIndex(x ast.Node, msg string) { e.traceEvent(IndexOp, x, msg) }
// traceEvent builds an Event for node x and delivers it to every enabled
// tracer. It does nothing when tracing is disabled. The event carries the
// current local bindings, both as raw values (kept for backwards
// compatibility) and as metadata mapping each var to the name returned by
// rewrittenVar. It also carries metadata for rewritten vars that appear in x
// but are not bound yet. ParentID is 0 when e has no parent.
func (e *eval) traceEvent(op Op, x ast.Node, msg string) {

	if !traceIsEnabled(e.tracers) {
		return
	}

	locals := ast.NewValueMap()
	localMeta := map[ast.Var]VarMetadata{}

	// The callback never returns an error, so Iter's result is not checked.
	e.bindings.Iter(nil, func(k, v *ast.Term) error {
		original := k.Value.(ast.Var)
		rewritten, _ := e.rewrittenVar(original)
		localMeta[original] = VarMetadata{
			Name:     rewritten,
			Location: k.Loc(),
		}

		// For backwards compatibility save a copy of the values too..
		locals.Put(k.Value, v.Value)
		return nil
	})

	// Describe vars referenced by x that have no binding yet but do have a
	// rewritten name.
	ast.WalkTerms(x, func(term *ast.Term) bool {
		if v, ok := term.Value.(ast.Var); ok {
			if _, ok := localMeta[v]; !ok {
				if rewritten, ok := e.rewrittenVar(v); ok {
					localMeta[v] = VarMetadata{
						Name:     rewritten,
						Location: term.Loc(),
					}
				}
			}
		}
		return false
	})

	var parentID uint64
	if e.parent != nil {
		parentID = e.parent.queryID
	}

	evt := &Event{
		QueryID:       e.queryID,
		ParentID:      parentID,
		Op:            op,
		Node:          x,
		Location:      x.Loc(),
		Locals:        locals,
		LocalMetadata: localMeta,
		Message:       msg,
	}

	for i := range e.tracers {
		if e.tracers[i].Enabled() {
			e.tracers[i].Trace(evt)
		}
	}
}
// eval evaluates the query from the current expression, invoking iter for
// every solution.
func (e *eval) eval(iter evalIterator) error { return e.evalExpr(iter) }
// evalExpr evaluates the expression at e.index and, for each solution,
// continues with the remaining expressions. Once the end of the query is
// reached, iter is invoked. Expressions with `with` modifiers go through
// evalWith. Cancellation is checked before every expression.
func (e *eval) evalExpr(iter evalIterator) error {
	if e.cancel != nil && e.cancel.Cancelled() {
		return &Error{Code: CancelErr, Message: "caller cancelled query execution"}
	}

	if e.index >= len(e.query) {
		return iter(e)
	}

	expr := e.query[e.index]
	e.traceEval(expr)

	if len(expr.With) == 0 {
		return e.evalStep(func(e *eval) error { return e.next(iter) })
	}
	return e.evalWith(iter)
}
// evalStep evaluates the expression at e.index and invokes iter once for each
// way it can be satisfied. Negated expressions are delegated to evalNot.
// Equality expressions unify their two operands, other call expressions go
// through evalCall, and a bare term is unified with a generated var. A bare
// term succeeds unless its value is the boolean false. A Fail event is traced
// when no solution reached iter.
func (e *eval) evalStep(iter evalIterator) error {

	expr := e.query[e.index]

	if expr.Negated {
		return e.evalNot(iter)
	}

	// defined records whether at least one solution reached iter; it drives
	// the Fail trace event at the end.
	var defined bool
	var err error

	switch terms := expr.Terms.(type) {
	case []*ast.Term:
		if expr.IsEquality() {
			// terms[0] is the operator; unify the two operands.
			err = e.unify(terms[1], terms[2], func() error {
				defined = true
				err := iter(e)
				e.traceRedo(expr)
				return err
			})
		} else {
			err = e.evalCall(terms, func() error {
				defined = true
				err := iter(e)
				e.traceRedo(expr)
				return err
			})
		}
	case *ast.Term:
		// Bind the term to a generated var so refs and comprehensions are
		// evaluated, then test the resulting value.
		rterm := e.generateVar(fmt.Sprintf("term_%d_%d", e.queryID, e.index))
		err = e.unify(terms, rterm, func() error {
			// During partial evaluation an unknown result is saved as-is.
			// NOTE(review): defined is not set on this path, so a Fail event is
			// traced after saving even though evaluation continued via iter.
			if e.saveSet.Contains(rterm, e.bindings) {
				return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error {
					return iter(e)
				})
			}
			if !e.bindings.Plug(rterm).Equal(ast.BooleanTerm(false)) {
				defined = true
				err := iter(e)
				e.traceRedo(expr)
				return err
			}
			return nil
		})
	}

	if err != nil {
		return err
	}

	if !defined {
		e.traceFail(expr)
	}

	return nil
}
// evalNot evaluates the negated expression at e.index. During partial
// evaluation, an expression that depends on unknowns is handled by
// evalNotPartial. Otherwise the complement of the expression is evaluated in a
// closure that shares e's bindings, with its `with` modifiers stripped (when
// present they were already applied by evalWith before evalStep ran). iter is
// invoked only if that complement has no solution.
func (e *eval) evalNot(iter evalIterator) error {

	expr := e.query[e.index]

	if e.unknown(expr, e.bindings) {
		return e.evalNotPartial(iter)
	}

	negation := ast.NewBody(expr.Complement().NoWith())
	child := e.closure(negation)

	var defined bool
	child.traceEnter(negation)

	// Only the existence of a solution matters. The callback records it and
	// returns nil, so enumeration of the complement is not cut short.
	err := child.eval(func(*eval) error {
		child.traceExit(negation)
		defined = true
		child.traceRedo(negation)
		return nil
	})

	if err != nil {
		return err
	}

	// The negation holds exactly when the complement is undefined.
	if !defined {
		return iter(e)
	}

	e.traceFail(expr)
	return nil
}
// evalWith evaluates the expression at e.index with its `with` modifiers
// applied. Each modifier's value is plugged and merged into either the input or
// the data document depending on its target. The overrides are pushed while
// the expression itself is evaluated. They are popped before the rest of the
// query is evaluated and pushed again afterwards, so backtracking into this
// expression sees them again while later expressions never do. A merge
// conflict between modifier values is reported as a ConflictErr.
func (e *eval) evalWith(iter evalIterator) error {

	expr := e.query[e.index]
	var disable []ast.Ref

	if e.partial() {

		// If the value is unknown the with statement cannot be evaluated and so
		// the entire expression should be saved to be safe. In the future this
		// could be relaxed in certain cases (e.g., if the with statement would
		// have no effect.)
		for _, with := range expr.With {
			if e.saveSet.ContainsRecursive(with.Value, e.bindings) {
				return e.saveExpr(expr, e.bindings, func() error {
					return e.next(iter)
				})
			}
		}

		// Disable inlining on all references in the expression so the result of
		// partial evaluation has the same semantics w/ the with statements
		// preserved.
		ast.WalkRefs(expr, func(x ast.Ref) bool {
			disable = append(disable, x.GroundPrefix())
			return false
		})
	}

	pairsInput := [][2]*ast.Term{}
	pairsData := [][2]*ast.Term{}
	targets := []ast.Ref{}

	// Split the modifiers by target document; every target is recorded so
	// lookups under it can be recognized via the target stack.
	for i := range expr.With {
		plugged := e.bindings.Plug(expr.With[i].Value)
		if isInputRef(expr.With[i].Target) {
			pairsInput = append(pairsInput, [...]*ast.Term{expr.With[i].Target, plugged})
		} else if isDataRef(expr.With[i].Target) {
			pairsData = append(pairsData, [...]*ast.Term{expr.With[i].Target, plugged})
		}
		targets = append(targets, expr.With[i].Target.Value.(ast.Ref))
	}

	input, err := mergeTermWithValues(e.input, pairsInput)

	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}

	data, err := mergeTermWithValues(e.data, pairsData)
	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}

	oldInput, oldData := e.evalWithPush(input, data, targets, disable)

	err = e.evalStep(func(e *eval) error {
		// Remove the overrides while the rest of the query runs, then
		// restore them for any further solutions of this expression.
		e.evalWithPop(oldInput, oldData)
		err := e.next(iter)
		oldInput, oldData = e.evalWithPush(input, data, targets, disable)
		return err
	})

	e.evalWithPop(oldInput, oldData)

	return err
}
// evalWithPush installs the input/data overrides of a `with` statement (a nil
// argument leaves the corresponding document untouched). It pushes a new
// virtual cache frame, the statement's targets and its inlining-disabled refs.
// It returns the previous input and data, or nil for whichever was not
// replaced.
//
// NOTE(review): evalWithPop restores both values unconditionally, so a nil
// return resets that document to nil on pop. This is only correct if callers
// pass nil exactly when the current value is already nil; evalWith gets its
// arguments from mergeTermWithValues — confirm it returns the existing term
// when there is nothing to merge.
func (e *eval) evalWithPush(input *ast.Term, data *ast.Term, targets []ast.Ref, disable []ast.Ref) (*ast.Term, *ast.Term) {
	var prevInput, prevData *ast.Term

	if input != nil {
		prevInput, e.input = e.input, input
	}
	if data != nil {
		prevData, e.data = e.data, data
	}

	e.virtualCache.Push()
	e.targetStack.Push(targets)
	e.disableInlining = append(e.disableInlining, disable)

	return prevInput, prevData
}
// evalWithPop undoes evalWithPush. It drops the most recent inlining-disabled
// entry, target set and virtual cache frame, then restores data and input to
// the given values.
func (e *eval) evalWithPop(input *ast.Term, data *ast.Term) {
	top := len(e.disableInlining) - 1
	e.disableInlining = e.disableInlining[:top]
	e.targetStack.Pop()
	e.virtualCache.Pop()
	e.data, e.input = data, input
}
// evalNotPartial handles a negated expression that depends on unknowns during
// partial evaluation. The complement of the expression is partially evaluated
// in a closure, and each saved query it yields is plugged with the caller's
// bindings and simplified with copy propagation. If nothing is yielded, the
// negation always holds and iter is invoked directly. Otherwise the negation is
// either inlined as the complemented cross product of the results or, when that
// is not possible, expressed through a support rule (evalNotPartialSupport).
// Errors raised while evaluating the complement (e.g. cancellation) are
// returned to the caller.
func (e *eval) evalNotPartial(iter evalIterator) error {

	// Prepare query normally.
	expr := e.query[e.index]
	negation := expr.Complement().NoWith()
	child := e.closure(ast.NewBody(negation))

	// Unknowns is the set of variables that are marked as unknown. The variables
	// are namespaced with the query ID that they originate in. This ensures that
	// variables across two or more queries are identified uniquely.
	//
	// NOTE(tsandall): this is greedy in the sense that we only need variable
	// dependencies of the negation.
	unknowns := e.saveSet.Vars(e.caller.bindings)

	// Run partial evaluation, plugging the result and applying copy propagation to
	// each result. Since the result may require support, push a new query onto the
	// save stack to avoid mutating the current save query.
	p := copypropagation.New(unknowns).WithEnsureNonEmptyBody(true)

	var savedQueries []ast.Body
	e.saveStack.PushQuery(nil)

	err := child.eval(func(*eval) error {
		query := e.saveStack.Peek()
		plugged := query.Plug(e.caller.bindings)
		result := applyCopyPropagation(p, e.instr, plugged)
		savedQueries = append(savedQueries, result)
		return nil
	})

	// Always pop the query pushed above so the save stack stays balanced, then
	// propagate failures instead of treating a partially explored negation as
	// if it had been fully evaluated.
	e.saveStack.PopQuery()

	if err != nil {
		return err
	}

	// If partial evaluation produced no results, the expression is always undefined
	// so it does not have to be saved.
	if len(savedQueries) == 0 {
		return iter(e)
	}

	// Check if the partial evaluation result can be inlined in this query. If not,
	// generate support rules for the result. Depending on the size of the partial
	// evaluation result and the contents, it may or may not be inlinable. We treat
	// the unknowns as safe because vars in the save set will either be known to
	// the caller or made safe by an expression on the save stack.
	if !canInlineNegation(unknowns, savedQueries) {
		return e.evalNotPartialSupport(expr, unknowns, savedQueries, iter)
	}

	// If we can inline the result, we have to generate the cross product of the
	// queries. For example:
	//
	//	(A && B) || (C && D)
	//
	// Becomes:
	//
	//	(!A && !C) || (!A && !D) || (!B && !C) || (!B && !D)
	return complementedCartesianProduct(savedQueries, 0, nil, func(q ast.Body) error {
		return e.saveInlinedNegatedExprs(q, func() error {
			return iter(e)
		})
	})
}
// evalNotPartialSupport handles a negated expression whose partial evaluation
// results (queries) cannot be inlined. It generates a support rule named
// __not<queryID>_<index>__ under data.<saveNamespace>, with one rule body per
// query. The unknowns that occur in those bodies become the rule's arguments,
// sorted for deterministic output. It then saves a copy of the negated
// expression rewritten to refer to that support rule.
//
// iter is the continuation evalNot received from evalStep, which already
// advances to the next expression. It is therefore invoked directly, exactly
// like the inlined path in evalNotPartial. Calling e.next(iter) here would
// advance twice, and under a `with` statement the next expression would be
// evaluated while this expression's overrides were still pushed.
func (e *eval) evalNotPartialSupport(expr *ast.Expr, unknowns ast.VarSet, queries []ast.Body, iter evalIterator) error {

	// Prepare support rule head.
	supportName := fmt.Sprintf("__not%d_%d__", e.queryID, e.index)
	term := ast.RefTerm(ast.DefaultRootDocument, e.saveNamespace, ast.StringTerm(supportName))
	path := term.Value.(ast.Ref)
	head := ast.NewHead(ast.Var(supportName), nil, ast.BooleanTerm(true))

	// Only unknowns that actually occur in the support bodies become args.
	bodyVars := ast.NewVarSet()

	for _, q := range queries {
		bodyVars.Update(q.Vars(ast.VarVisitorParams{}))
	}

	unknowns = unknowns.Intersect(bodyVars)

	// Make rule args. Sort them to ensure order is deterministic.
	args := make([]*ast.Term, 0, len(unknowns))

	for v := range unknowns {
		args = append(args, ast.NewTerm(v))
	}

	sort.Slice(args, func(i, j int) bool {
		return args[i].Value.Compare(args[j].Value) < 0
	})

	if len(args) > 0 {
		head.Args = ast.Args(args)
	}

	// Save support rules: one rule per query, all sharing the same head.
	for _, query := range queries {
		e.saveSupport.Insert(path, &ast.Rule{
			Head: head,
			Body: query,
		})
	}

	// Save expression that refers to support rule set: a call passing the args
	// when there are any, otherwise a bare reference to the rule.
	expr = expr.Copy()
	if len(args) > 0 {
		terms := make([]*ast.Term, len(args)+1)
		terms[0] = term
		copy(terms[1:], args)
		expr.Terms = terms
	} else {
		expr.Terms = term
	}

	return e.saveInlinedNegatedExprs([]*ast.Expr{expr}, func() error {
		return iter(e)
	})
}
// evalCall evaluates a call expression whose operator is terms[0]. Operators
// rooted at data are user-defined functions and go through evalFunc; any other
// operator must name a builtin. A builtin call whose expression depends on
// unknowns is saved during partial evaluation. Otherwise the builtin is invoked
// with terms[1:] as operands.
func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {
	ref := terms[0].Value.(ast.Ref)

	if ref[0].Equal(ast.DefaultRootDocument) {
		fn := evalFunc{e: e, ref: ref, terms: terms}
		return fn.eval(iter)
	}

	current := e.query[e.index]

	bi, f, ok := e.builtinFunc(ref.String())
	if !ok {
		return unsupportedBuiltinErr(current.Location)
	}

	if e.unknown(current, e.bindings) {
		return e.saveCall(len(bi.Decl.Args()), terms, iter)
	}

	var parentID uint64
	if e.parent != nil {
		parentID = e.parent.queryID
	}

	call := evalBuiltin{
		e:     e,
		bi:    bi,
		f:     f,
		terms: terms[1:],
		bctx: BuiltinContext{
			Context:  e.ctx,
			Cancel:   e.cancel,
			Runtime:  e.runtime,
			Cache:    e.builtinCache,
			Location: current.Location,
			Tracers:  e.tracers,
			QueryID:  e.queryID,
			ParentID: parentID,
		},
	}
	return call.eval(iter)
}
// unify unifies a and b, both interpreted under e's own bindings.
func (e *eval) unify(a, b *ast.Term, iter unifyIterator) error {
	bs := e.bindings
	return e.biunify(a, b, bs, bs, iter)
}
// biunify unifies a (interpreted under b1) with b (interpreted under b2) and
// invokes iter for each way the two can be made equal. Both terms are first
// resolved through their bindings. Vars, refs, comprehensions and sets are
// handed to biunifyValues. Arrays and objects are unified element-wise against
// values of the same kind, or handed to biunifyValues against vars, refs or
// the matching comprehension kind. Scalars only unify with vars, refs or
// scalars of the same type. Any other combination cannot unify, so nil is
// returned without invoking iter.
func (e *eval) biunify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	a, b1 = b1.apply(a)
	b, b2 = b2.apply(b)
	switch vA := a.Value.(type) {
	case ast.Var, ast.Ref, *ast.ArrayComprehension, *ast.SetComprehension, *ast.ObjectComprehension:
		return e.biunifyValues(a, b, b1, b2, iter)
	case ast.Null:
		switch b.Value.(type) {
		case ast.Var, ast.Null, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Boolean:
		switch b.Value.(type) {
		case ast.Var, ast.Boolean, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Number:
		switch b.Value.(type) {
		case ast.Var, ast.Number, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.String:
		switch b.Value.(type) {
		case ast.Var, ast.String, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Array:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ArrayComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case ast.Array:
			return e.biunifyArrays(vA, vB, b1, b2, iter)
		}
	case ast.Object:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ObjectComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case ast.Object:
			return e.biunifyObjects(vA, vB, b1, b2, iter)
		}
	case ast.Set:
		// Sets are compared as a whole (after plugging) by biunifyValues.
		return e.biunifyValues(a, b, b1, b2, iter)
	}
	// Incompatible types: no solution.
	return nil
}
// biunifyArrays unifies two arrays element by element. Arrays of different
// lengths never unify.
func (e *eval) biunifyArrays(a, b ast.Array, b1, b2 *bindings, iter unifyIterator) error {
	if len(a) == len(b) {
		return e.biunifyArraysRec(a, b, b1, b2, iter, 0)
	}
	return nil
}
// biunifyArraysRec unifies a[idx] with b[idx] and recurses on the following
// index. iter is invoked once every position has been unified.
func (e *eval) biunifyArraysRec(a, b ast.Array, b1, b2 *bindings, iter unifyIterator, idx int) error {
	if idx == len(a) {
		return iter()
	}

	rest := func() error {
		return e.biunifyArraysRec(a, b, b1, b2, iter, idx+1)
	}
	return e.biunify(a[idx], b[idx], b1, b2, rest)
}
// biunifyObjects unifies two objects with the same keys value by value.
// Objects of different sizes never unify. Non-ground keys are plugged first
// because keys themselves cannot be unified.
func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyIterator) error {
	if a.Len() != b.Len() {
		return nil
	}

	// Similar to sets, plug both sides before comparing keys.
	if nonGroundKeys(a) {
		a = plugKeys(a, b1)
	}
	if nonGroundKeys(b) {
		b = plugKeys(b, b2)
	}

	keys := a.Keys()
	return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, 0)
}
// biunifyObjectsRec unifies the values stored under keys[idx] in a and b and
// recurses on the following key. A key missing from b means the objects cannot
// unify. iter is invoked once every key has been processed.
func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys []*ast.Term, idx int) error {
	if idx == len(keys) {
		return iter()
	}

	key := keys[idx]
	other := b.Get(key)
	if other == nil {
		return nil
	}

	return e.biunify(a.Get(key), other, b1, b2, func() error {
		return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1)
	})
}
// biunifyValues unifies two terms that have already been resolved through
// their bindings. The steps are:
//   - sets are checked recursively against the save set and other terms
//     directly; a ref that need not be saved is evaluated via biunifyRef;
//   - if either side must be saved (partial evaluation), the unification is
//     saved via saveUnify;
//   - comprehensions are handled by biunifyComprehension (the swap flag marks
//     that the operands were reversed);
//   - a var is bound to the other side for the duration of iter;
//   - otherwise sets are plugged and the terms are compared for equality.
func (e *eval) biunifyValues(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {

	// Try to evaluate refs and comprehensions. If partial evaluation is
	// enabled, then skip evaluation (and save the expression) if the term is
	// in the save set. Currently, comprehensions are not evaluated during
	// partial eval. This could be improved in the future.

	var saveA, saveB bool

	if _, ok := a.Value.(ast.Set); ok {
		saveA = e.saveSet.ContainsRecursive(a, b1)
	} else {
		saveA = e.saveSet.Contains(a, b1)
		if !saveA {
			if _, refA := a.Value.(ast.Ref); refA {
				return e.biunifyRef(a, b, b1, b2, iter)
			}
		}
	}

	if _, ok := b.Value.(ast.Set); ok {
		saveB = e.saveSet.ContainsRecursive(b, b2)
	} else {
		saveB = e.saveSet.Contains(b, b2)
		if !saveB {
			if _, refB := b.Value.(ast.Ref); refB {
				return e.biunifyRef(b, a, b2, b1, iter)
			}
		}
	}

	if saveA || saveB {
		return e.saveUnify(a, b, b1, b2, iter)
	}

	if ast.IsComprehension(a.Value) {
		return e.biunifyComprehension(a, b, b1, b2, false, iter)
	} else if ast.IsComprehension(b.Value) {
		return e.biunifyComprehension(b, a, b2, b1, true, iter)
	}

	// Perform standard unification.
	_, varA := a.Value.(ast.Var)
	_, varB := b.Value.(ast.Var)

	if varA && varB {
		// A var unified with itself under the same bindings trivially succeeds.
		if b1 == b2 && a.Equal(b) {
			return iter()
		}
		undo := b1.bind(a, b, b2)
		err := iter()
		undo.Undo()
		return err
	} else if varA && !varB {
		undo := b1.bind(a, b, b2)
		err := iter()
		undo.Undo()
		return err
	} else if varB && !varA {
		undo := b2.bind(b, a, b1)
		err := iter()
		undo.Undo()
		return err
	}

	// Sets must not contain unbound variables at this point as we cannot unify
	// them. So simply plug both sides (to substitute any bound variables with
	// values) and then check for equality.
	switch a.Value.(type) {
	case ast.Set:
		a = b1.Plug(a)
		b = b2.Plug(b)
	}

	if a.Equal(b) {
		return iter()
	}

	return nil
}
// biunifyRef evaluates ref a (under b1) and unifies its value(s) with b (under
// b2):
//   - refs rooted at data are walked through the compiler's rule tree
//     (evalTree), which combines base documents with rule results;
//   - refs rooted at input are evaluated against e.input (no result when
//     there is no input);
//   - any other ref has its head resolved through b1; if applying b1 leaves
//     the head unchanged (presumably an unbound var) there is no result.
func (e *eval) biunifyRef(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {

	ref := a.Value.(ast.Ref)

	if ref[0].Equal(ast.DefaultRootDocument) {
		node := e.compiler.RuleTree.Child(ref[0].Value)

		eval := evalTree{
			e:         e,
			ref:       ref,
			pos:       1,
			plugged:   ref.Copy(),
			bindings:  b1,
			rterm:     b,
			rbindings: b2,
			node:      node,
		}
		return eval.eval(iter)
	}

	var term *ast.Term
	var termbindings *bindings

	if ref[0].Equal(ast.InputRootDocument) {
		term = e.input
		termbindings = b1
	} else {
		term, termbindings = b1.apply(ref[0])
		if term == ref[0] {
			term = nil
		}
	}

	if term == nil {
		return nil
	}

	eval := evalTerm{
		e:            e,
		ref:          ref,
		pos:          1,
		bindings:     b1,
		term:         term,
		termbindings: termbindings,
		rterm:        b,
		rbindings:    b2,
	}

	return eval.eval(iter)
}
// biunifyComprehension evaluates comprehension a (under b1) and unifies the
// result with b (under b2). When the comprehension depends on unknowns, the
// unification is saved instead (see biunifyComprehensionPartial). swap reports
// that the caller reversed the operands. An error is returned if a does not
// hold a comprehension.
func (e *eval) biunifyComprehension(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {

	if e.unknown(a, b1) {
		return e.biunifyComprehensionPartial(a, b, b1, b2, swap, iter)
	}

	switch v := a.Value.(type) {
	case *ast.ArrayComprehension:
		return e.biunifyComprehensionArray(v, b, b1, b2, iter)
	case *ast.SetComprehension:
		return e.biunifyComprehensionSet(v, b, b1, b2, iter)
	case *ast.ObjectComprehension:
		return e.biunifyComprehensionObject(v, b, b1, b2, iter)
	default:
		// Report the offending value's type rather than *ast.Term, matching
		// biunifyComprehensionPartial.
		return fmt.Errorf("illegal comprehension %T", v)
	}
}
// biunifyComprehensionPartial saves the unification of comprehension a (under
// b1) with b (under b2) during partial evaluation, without evaluating the
// comprehension.
//
// For every binding visible in b1 (relative to the caller's bindings), an
// equality expression is appended to a copy of the comprehension body. This
// keeps the saved comprehension safe on its own. The copy's variables are then
// namespaced. Working on a copy matters because a can come straight from the
// compiled policy or query AST; amending it in place would leak the extra
// expressions into later evaluations of the same rule.
//
// swap reports that the caller reversed the operands; when set, the saved
// expression restores the original order.
func (e *eval) biunifyComprehensionPartial(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {

	// Capture bindings available to the comprehension. Currently this adds
	// _all_ bindings (even if they are not needed.) Eventually we may want to
	// make the logic a bit smarter.
	var extras []*ast.Expr

	err := b1.Iter(e.caller.bindings, func(k, v *ast.Term) error {
		extras = append(extras, ast.Equality.Expr(k, v))
		return nil
	})

	if err != nil {
		return err
	}

	// Amend a private copy so the shared AST is never mutated.
	a = a.Copy()

	var body *ast.Body

	switch v := a.Value.(type) {
	case *ast.ArrayComprehension:
		body = &v.Body
	case *ast.SetComprehension:
		body = &v.Body
	case *ast.ObjectComprehension:
		body = &v.Body
	default:
		return fmt.Errorf("illegal comprehension %T", v)
	}

	for _, extra := range extras {
		body.Append(extra)
	}

	// Namespace the variables in the body to avoid collisions when the final
	// queries are returned by partial evaluation.
	b1.Namespace(a, e.caller.bindings)

	// The other term might need to be plugged so include the bindings. The
	// bindings for the comprehension term are saved (for compatibility) but
	// the eventual plug operation on the comprehension will be a no-op.
	if !swap {
		return e.saveUnify(a, b, b1, b2, iter)
	}

	return e.saveUnify(b, a, b2, b1, iter)
}
// biunifyComprehensionArray evaluates an array comprehension in a closure that
// shares e's bindings. It collects the plugged term of every solution, in
// order, and unifies the resulting array with b.
func (e *eval) biunifyComprehensionArray(x *ast.ArrayComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	elems := ast.Array{}

	collect := func(child *eval) error {
		elems = append(elems, child.bindings.Plug(x.Term))
		return nil
	}

	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}

	return e.biunify(ast.NewTerm(elems), b, b1, b2, iter)
}
// biunifyComprehensionSet evaluates a set comprehension in a closure that
// shares e's bindings. It adds the plugged term of every solution to a set and
// unifies that set with b.
func (e *eval) biunifyComprehensionSet(x *ast.SetComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	set := ast.NewSet()

	collect := func(child *eval) error {
		set.Add(child.bindings.Plug(x.Term))
		return nil
	}

	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}

	return e.biunify(ast.NewTerm(set), b, b1, b2, iter)
}
// biunifyComprehensionObject evaluates an object comprehension in a closure
// that shares e's bindings and unifies the resulting object with b. If two
// solutions produce the same key with different values, an object key conflict
// error is returned.
func (e *eval) biunifyComprehensionObject(x *ast.ObjectComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	obj := ast.NewObject()

	collect := func(child *eval) error {
		k := child.bindings.Plug(x.Key)
		v := child.bindings.Plug(x.Value)
		if prev := obj.Get(k); prev != nil && !prev.Equal(v) {
			return objectDocKeyConflictErr(x.Key.Location)
		}
		obj.Insert(k, v)
		return nil
	}

	if err := e.closure(x.Body).Run(collect); err != nil {
		return err
	}

	return e.biunify(ast.NewTerm(obj), b, b1, b2, iter)
}
// savePair is a term that has to be added to the save set together with the
// bindings it must be interpreted under (produced by getSavePairs).
type savePair struct {
	term *ast.Term
	b    *bindings
}
// getSavePairs appends to result every var reachable from x, paired with the
// bindings it belongs to. A var term is returned as-is. For other terms, each
// var they contain (skipping closures and ref heads) is resolved through b and
// processed recursively.
func getSavePairs(x *ast.Term, b *bindings, result []savePair) []savePair {
	if _, isVar := x.Value.(ast.Var); isVar {
		return append(result, savePair{x, b})
	}

	vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
		SkipClosures: true,
		SkipRefHead:  true,
	})
	vis.Walk(x)

	for v := range vis.Vars() {
		resolved, owner := b.apply(ast.NewTerm(v))
		result = getSavePairs(resolved, owner, result)
	}

	return result
}
// saveExpr saves expr (carrying the current expression's `with` modifiers) on
// the save stack under bindings b for the duration of iter.
func (e *eval) saveExpr(expr *ast.Expr, b *bindings, iter unifyIterator) error {
	expr.With = e.query[e.index].With
	e.saveStack.Push(expr, b, b)
	e.traceSave(expr)
	defer e.saveStack.Pop()
	return iter()
}
// saveUnify saves the equality a = b (carrying the current expression's `with`
// modifiers) instead of evaluating it. For the duration of iter, the vars
// reachable from either side are added to the save set, so later expressions
// that depend on them are saved too.
func (e *eval) saveUnify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	e.instr.startTimer(partialOpSaveUnify)

	expr := ast.Equality.Expr(a, b)
	expr.With = e.query[e.index].With

	pushed := 0
	markUnknown := func(t *ast.Term, tb *bindings) {
		for _, p := range getSavePairs(t, tb, nil) {
			e.saveSet.Push([]*ast.Term{p.term}, p.b)
			pushed++
		}
	}
	markUnknown(a, b1)
	markUnknown(b, b2)

	e.saveStack.Push(expr, b1, b2)
	e.traceSave(expr)
	e.instr.stopTimer(partialOpSaveUnify)

	err := iter()

	e.saveStack.Pop()
	for ; pushed > 0; pushed-- {
		e.saveSet.Pop()
	}

	return err
}
// saveCall saves the call expression terms (carrying the current expression's
// `with` modifiers) instead of evaluating it. terms[0] is the operator. When
// there is one term beyond the decledeclared arguments, that last term is the
// output, and the vars reachable from it are added to the save set for the
// duration of iter.
func (e *eval) saveCall(declArgsLen int, terms []*ast.Term, iter unifyIterator) error {
	expr := ast.NewExpr(terms)
	expr.With = e.query[e.index].With

	var pushed int
	if hasOutput := declArgsLen == len(terms)-2; hasOutput {
		output := terms[len(terms)-1]
		for _, p := range getSavePairs(output, e.bindings, nil) {
			e.saveSet.Push([]*ast.Term{p.term}, p.b)
			pushed++
		}
	}

	e.saveStack.Push(expr, e.bindings, nil)
	e.traceSave(expr)

	err := iter()

	e.saveStack.Pop()
	for ; pushed > 0; pushed-- {
		e.saveSet.Pop()
	}

	return err
}
// saveInlinedNegatedExprs saves every expression in exprs (without bindings)
// for the duration of iter. `with` modifiers are not attached here because
// these expressions come out of partial evaluation and already carry any
// relevant modifiers.
func (e *eval) saveInlinedNegatedExprs(exprs []*ast.Expr, iter unifyIterator) error {
	for _, expr := range exprs {
		e.saveStack.Push(expr, nil, nil)
		e.traceSave(expr)
	}

	err := iter()

	for range exprs {
		e.saveStack.Pop()
	}

	return err
}
// getRules returns the rules that may apply to ref. With indexing enabled it
// uses an index lookup (which resolves values through e); otherwise it returns
// all rules. It returns nil, nil when the compiler has no index for ref. The
// number of matched rules is traced as "(matched 1 rule)" or
// "(matched N rules)".
func (e *eval) getRules(ref ast.Ref) (*ast.IndexResult, error) {
	e.instr.startTimer(evalOpRuleIndex)
	defer e.instr.stopTimer(evalOpRuleIndex)

	index := e.compiler.RuleIndex(ref)
	if index == nil {
		return nil, nil
	}

	var result *ast.IndexResult
	var err error
	if e.indexing {
		result, err = index.Lookup(e)
	} else {
		result, err = index.AllRules(e)
	}

	if err != nil {
		return nil, err
	}

	var msg string
	if len(result.Rules) == 1 {
		msg = "(matched 1 rule)"
	} else {
		// Previously the opening parenthesis was missing here, producing
		// "matched N rules)".
		var b strings.Builder
		b.Grow(len("(matched NNNN rules)"))
		b.WriteString("(matched ")
		b.WriteString(strconv.Itoa(len(result.Rules)))
		b.WriteString(" rules)")
		msg = b.String()
	}
	e.traceIndex(e.query[e.index], msg)

	return result, nil
}
// Resolve returns the value of a ground ref rooted at input or data, or nil
// when it is undefined. During partial evaluation, refs in the save set return
// ast.UnknownValueErr. For data refs, values from e.data override the base
// documents. Under an active `with` target only e.data is consulted;
// otherwise e.data is merged with the base document, which is read from the
// cache or from storage. Any other root yields an error.
func (e *eval) Resolve(ref ast.Ref) (ast.Value, error) {
	e.instr.startTimer(evalOpResolve)
	defer e.instr.stopTimer(evalOpResolve)

	if e.saveSet.Contains(ast.NewTerm(ref), nil) {
		return nil, ast.UnknownValueErr{}
	}

	if ref[0].Equal(ast.InputRootDocument) {
		if e.input == nil {
			return nil, nil
		}
		found, err := e.input.Value.Find(ref[1:])
		if err != nil {
			// A path missing from input is undefined, not an error.
			return nil, nil
		}
		return found, nil
	}

	if !ref[0].Equal(ast.DefaultRootDocument) {
		return nil, fmt.Errorf("illegal ref")
	}

	var repValue ast.Value
	if e.data != nil {
		if found, err := e.data.Value.Find(ref[1:]); err == nil {
			repValue = found
		}
	}

	if e.targetStack.Prefixed(ref) {
		return repValue, nil
	}

	// Converting large JSON values into AST values can be fairly expensive. For
	// example, a 2MB JSON value can take upwards of 30 milliseconds to convert.
	// We cache the result of conversion here in case the same base document is
	// being read multiple times during evaluation.
	realValue := e.baseCache.Get(ref)
	if realValue == nil {
		e.instr.counterIncr(evalOpBaseCacheMiss)
		return e.resolveReadFromStorage(ref, repValue)
	}

	e.instr.counterIncr(evalOpBaseCacheHit)
	if repValue == nil {
		return realValue, nil
	}

	merged, ok := merge(repValue, realValue)
	if !ok {
		return merged, mergeConflictErr(ref[0].Location)
	}
	return merged, nil
}
// resolveReadFromStorage reads the base document at ref from storage, converts
// it to an AST value, caches it in baseCache and merges override a (if any)
// into it. Refs with non-scalar terms and paths that do not exist in storage
// yield a unchanged. When the root document is read, the system document key
// is excluded from the result.
func (e *eval) resolveReadFromStorage(ref ast.Ref, a ast.Value) (ast.Value, error) {
	if refContainsNonScalar(ref) {
		return a, nil
	}

	path, err := storage.NewPathForRef(ref)
	if err != nil {
		if storage.IsNotFound(err) {
			return a, nil
		}
		return nil, err
	}

	blob, err := e.store.Read(e.ctx, e.txn, path)
	if err != nil {
		if storage.IsNotFound(err) {
			return a, nil
		}
		return nil, err
	}

	if len(path) == 0 {
		if root := blob.(map[string]interface{}); len(root) > 0 {
			filtered := make(map[string]interface{}, len(root)-1)
			for key, val := range root {
				if key != string(ast.SystemDocumentKey) {
					filtered[key] = val
				}
			}
			blob = filtered
		}
	}

	v, err := ast.InterfaceToValue(blob)
	if err != nil {
		return nil, err
	}

	e.baseCache.Put(ref, v)

	if a == nil {
		return v, nil
	}

	merged, ok := merge(a, v)
	if !ok {
		return nil, mergeConflictErr(ref[0].Location)
	}
	return merged, nil
}
// generateVar returns a var term named "<genvarprefix>_<suffix>".
func (e *eval) generateVar(suffix string) *ast.Term {
	name := e.genvarprefix + "_" + suffix
	return ast.VarTerm(name)
}
// rewrittenVar looks v up in the compiler's rewritten vars, then in the query
// compiler's. It returns the mapped var and true on a hit, otherwise v and
// false.
func (e *eval) rewrittenVar(v ast.Var) (ast.Var, bool) {
	if e.compiler != nil {
		if rw, found := e.compiler.RewrittenVars[v]; found {
			return rw, true
		}
	}

	if e.queryCompiler == nil {
		return v, false
	}

	rw, found := e.queryCompiler.RewrittenVars()[v]
	if !found {
		return v, false
	}
	return rw, true
}
// evalBuiltin evaluates one call to a builtin function.
type evalBuiltin struct {
	e     *eval          // eval the call belongs to
	bi    *ast.Builtin   // builtin declaration (provides the declared arg count)
	bctx  BuiltinContext // context handed to the implementation
	f     BuiltinFunc    // builtin implementation
	terms []*ast.Term    // call terms excluding the operator
}
// eval plugs the operands with the current bindings and invokes the builtin.
// When the number of operands equals the number of declared arguments, the
// call has no output term and any output other than false counts as success.
// Otherwise the last term is the output and is unified with the builtin's
// result. The builtin-call timer is paused while iter runs, so time spent on
// the rest of the query is not attributed to the builtin.
func (e evalBuiltin) eval(iter unifyIterator) error {

	operands := make([]*ast.Term, len(e.terms))

	for i := 0; i < len(e.terms); i++ {
		operands[i] = e.e.bindings.Plug(e.terms[i])
	}

	numDeclArgs := len(e.bi.Decl.Args())

	e.e.instr.startTimer(evalOpBuiltinCall)

	err := e.f(e.bctx, operands, func(output *ast.Term) error {

		e.e.instr.stopTimer(evalOpBuiltinCall)

		var err error
		if len(operands) == numDeclArgs {
			// No output term: treat the result as a boolean test.
			if output.Value.Compare(ast.Boolean(false)) != 0 {
				err = iter()
			}
		} else {
			err = e.e.unify(e.terms[len(e.terms)-1], output, iter)
		}

		e.e.instr.startTimer(evalOpBuiltinCall)
		return err
	})

	e.e.instr.stopTimer(evalOpBuiltinCall)
	return err
}
// evalFunc evaluates one call to a user-defined function.
type evalFunc struct {
	e     *eval       // eval the call belongs to
	ref   ast.Ref     // ref naming the function (rooted at data)
	terms []*ast.Term // call terms: operator, arguments, optional output
}
// eval evaluates the function call. Candidate rules come from getRules; if
// none apply there is no result. A function with else branches whose call
// depends on unknowns is saved instead of evaluated. Otherwise each rule is
// evaluated in turn. If a rule yields no value, its else chain is tried in
// order until one does. The last value produced is passed as prev to later
// rules so evalOneRule can detect conflicting results.
func (e evalFunc) eval(iter unifyIterator) error {

	ir, err := e.e.getRules(e.ref)
	if err != nil {
		return err
	}

	if ir.Empty() {
		return nil
	}

	if len(ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
		// Partial evaluation of ordered rules is not supported currently. Save the
		// expression and continue. This could be revisited in the future.
		return e.e.saveCall(len(ir.Rules[0].Head.Args), e.terms, iter)
	}

	var prev *ast.Term

	for i := range ir.Rules {
		next, err := e.evalOneRule(iter, ir.Rules[i], prev)
		if err != nil {
			return err
		}
		if next == nil {
			// The rule produced nothing: fall through its else chain.
			for _, rule := range ir.Else[ir.Rules[i]] {
				next, err = e.evalOneRule(iter, rule, prev)
				if err != nil {
					return err
				}
				if next != nil {
					break
				}
			}
		}
		if next != nil {
			prev = next
		}
	}

	return nil
}
// evalOneRule evaluates a single function rule in a child eval with fresh
// bindings. The call's arguments (e.terms[1:]) are unified with the rule's
// head args. When the call has one more term than the head has args, that
// extra output term is unified with the head value. For each solution of the
// body the head value is plugged, and a call without an output term treats
// false as failure. Outside partial evaluation, a value that differs from prev
// is a function conflict, and a value equal to prev is deduplicated (iter is
// not invoked again). It returns the last plugged head value, or nil if no
// solution was found.
func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, prev *ast.Term) (*ast.Term, error) {

	child := e.e.child(rule.Body)

	// args mirrors the call's operands: head args first, then (if the call
	// has an output term) the head value.
	args := make(ast.Array, len(e.terms)-1)

	for i := range rule.Head.Args {
		args[i] = rule.Head.Args[i]
	}

	if len(args) == len(rule.Head.Args)+1 {
		args[len(args)-1] = rule.Head.Value
	}

	var result *ast.Term

	child.traceEnter(rule)

	err := child.biunifyArrays(e.terms[1:], args, e.e.bindings, child.bindings, func() error {
		return child.eval(func(child *eval) error {
			child.traceExit(rule)
			result = child.bindings.Plug(rule.Head.Value)

			if len(rule.Head.Args) == len(e.terms)-1 {
				if result.Value.Compare(ast.Boolean(false)) == 0 {
					return nil
				}
			}

			// Partial evaluation should explore all rules and may not produce
			// a ground result so we do not perform conflict detection or
			// deduplication. See "ignore conflicts: functions" test case for
			// an example.
			if !e.e.partial() {
				if prev != nil {
					if ast.Compare(prev, result) != 0 {
						return functionConflictErr(rule.Location)
					}
					child.traceRedo(rule)
					return nil
				}
			}

			prev = result

			if err := iter(); err != nil {
				return err
			}

			child.traceRedo(rule)
			return nil
		})
	})

	return result, err
}
// evalTree evaluates a ref rooted at data by walking the compiler's rule tree
// alongside the ref, combining base documents with rule results.
type evalTree struct {
	e         *eval         // eval the ref is evaluated in
	ref       ast.Ref       // ref being evaluated
	plugged   ast.Ref       // copy of ref; terms before pos hold ground keys
	pos       int           // index in ref of the next term to process
	bindings  *bindings     // bindings for ref
	rterm     *ast.Term     // term the result is unified with
	rbindings *bindings     // bindings for rterm
	node      *ast.TreeNode // current rule-tree node; nil outside the tree
}
// eval processes the ref term at e.pos. Once the ref is exhausted the result
// is produced by finish. A term that is ground after plugging is followed
// directly; otherwise every candidate key is enumerated.
func (e evalTree) eval(iter unifyIterator) error {
	if e.pos == len(e.ref) {
		return e.finish(iter)
	}

	if key := e.bindings.Plug(e.ref[e.pos]); key.IsGround() {
		return e.next(iter, key)
	}
	return e.enumerate(iter)
}
// finish unifies the full extent of the document at e.plugged with e.rterm.
// An undefined document produces no result.
func (e evalTree) finish(iter unifyIterator) error {
	// During partial evaluation it may not be possible to compute the value
	// for this reference if it refers to a virtual document so save the entire
	// expression. See "save: full extent" test case for an example.
	if e.node != nil && e.e.unknown(e.ref, e.e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.plugged), e.rterm, e.bindings, e.rbindings, iter)
	}

	doc, err := e.extent()
	switch {
	case err != nil:
		return err
	case doc == nil:
		return nil
	}

	return e.e.biunify(e.rterm, doc, e.rbindings, e.bindings, iter)
}
// next continues the walk past position e.pos using the ground key plugged.
// If the prefix is not the target of an active `with` statement and the rule
// tree has a child for plugged that carries rules, evaluation switches to
// evalVirtual. Otherwise the walk continues one level deeper. In that case
// node becomes nil once the walk leaves the rule tree, or when the prefix is a
// `with` target.
//
// NOTE(review): cpy.plugged and r.plugged share their backing array with
// e.plugged, so the writes below modify the caller's slice in place. The walk
// appears to rely on each position being rewritten before it is read; confirm
// before changing ownership of plugged.
func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error {

	var node *ast.TreeNode

	cpy := e
	cpy.plugged[e.pos] = plugged
	cpy.pos++

	if !e.e.targetStack.Prefixed(cpy.plugged[:cpy.pos]) {
		if e.node != nil {
			node = e.node.Child(plugged.Value)
			if node != nil && len(node.Values) > 0 {
				r := evalVirtual{
					e:         e.e,
					ref:       e.ref,
					plugged:   e.plugged,
					pos:       e.pos,
					bindings:  e.bindings,
					rterm:     e.rterm,
					rbindings: e.rbindings,
				}
				r.plugged[e.pos] = plugged
				return r.eval(iter)
			}
		}
	}

	cpy.node = node
	return cpy.eval(iter)
}
// enumerate handles a non-ground ref term at e.pos by trying every candidate
// key. The candidates are the indices, keys or elements of the base document
// at the current prefix (read via Resolve), followed by the children of the
// current rule-tree node. Each candidate is unified with the ref term at
// e.pos, and on success the walk continues via next.
func (e evalTree) enumerate(iter unifyIterator) error {

	doc, err := e.e.Resolve(e.plugged[:e.pos])
	if err != nil {
		return err
	}

	if doc != nil {
		switch doc := doc.(type) {
		case ast.Array:
			for i := range doc {
				k := ast.IntNumberTerm(i)
				err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, k)
				})
				if err != nil {
					return err
				}
			}
		case ast.Object:
			err := doc.Iter(func(k, _ *ast.Term) error {
				return e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, k)
				})
			})
			if err != nil {
				return err
			}
		case ast.Set:
			err := doc.Iter(func(elem *ast.Term) error {
				return e.e.biunify(elem, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, elem)
				})
			})
			if err != nil {
				return err
			}
		}
	}

	// Outside the rule tree there are no virtual children to enumerate.
	if e.node == nil {
		return nil
	}

	for k := range e.node.Children {
		key := ast.NewTerm(k)
		if err := e.e.biunify(key, e.ref[e.pos], e.bindings, e.bindings, func() error {
			return e.next(iter, key)
		}); err != nil {
			return err
		}
	}

	return nil
}
// extent computes the full value of the document at e.plugged. It combines
// the base document resolved at that path with the rule values found beneath
// e.node (see leaves).
//
// It returns (nil, nil) when neither exists. When both exist and merge
// reports failure (for example because they are not both objects), it
// returns a merge-conflict error.
func (e evalTree) extent() (*ast.Term, error) {
	base, err := e.e.Resolve(e.plugged)
	if err != nil {
		return nil, err
	}

	virtual, err := e.leaves(e.plugged, e.node)
	if err != nil {
		return nil, err
	}

	switch {
	case virtual == nil && base == nil:
		return nil, nil
	case virtual == nil:
		return ast.NewTerm(base), nil
	case base == nil:
		return ast.NewTerm(virtual), nil
	}

	merged, ok := merge(base, virtual)
	if !ok {
		return nil, mergeConflictErr(e.plugged[0].Location)
	}
	return ast.NewTerm(merged), nil
}
// leaves collects the values of the rules located beneath node into an
// object keyed by child key. It walks the children of node, skipping those
// marked Hide:
//   - A child carrying rule values is evaluated by unifying the reference
//     plugged+key with a fresh variable. The plugged value of that variable
//     is recorded. If there are several solutions, the last one wins; if
//     there are none, nothing is recorded.
//   - Any other child is descended into recursively.
//
// It returns (nil, nil) when node is nil; otherwise it returns a possibly
// empty object. plugged is extended in place while descending and trimmed
// back after each child.
func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error) {
	// Guard the node actually being walked. Checking e.node only protected
	// the top-level call; a nil node would panic on node.Children below.
	if node == nil {
		return nil, nil
	}

	result := ast.NewObject()

	for _, child := range node.Children {
		if child.Hide {
			continue
		}

		plugged = append(plugged, ast.NewTerm(child.Key))

		var save ast.Value
		var err error

		if len(child.Values) > 0 {
			rterm := e.e.generateVar("leaf")
			err = e.e.unify(ast.NewTerm(plugged), rterm, func() error {
				save = e.e.bindings.Plug(rterm).Value
				return nil
			})
		} else {
			save, err = e.leaves(plugged, child)
		}

		if err != nil {
			return nil, err
		}

		if save != nil {
			v := ast.NewObject([2]*ast.Term{plugged[len(plugged)-1], ast.NewTerm(save)})
			result, _ = result.Merge(v)
		}

		plugged = plugged[:len(plugged)-1]
	}

	return result, nil
}
// evalVirtual evaluates a reference whose prefix plugged[:pos+1] is defined
// by rules. It dispatches to evalVirtualPartial or evalVirtualComplete
// according to the kind of those rules.
type evalVirtual struct {
	e *eval
	// ref is the reference being evaluated; plugged holds its terms as
	// resolved so far (evalTree.next fills it in position by position).
	ref, plugged ast.Ref
	// pos indexes the reference term that names the rule set.
	pos int
	// bindings are used for ref's variables.
	bindings *bindings
	// rterm is the term the reference's value is unified with, bound in
	// rbindings.
	rterm     *ast.Term
	rbindings *bindings
}
// eval looks up the rules for e.plugged[:e.pos+1] and evaluates them.
//
// Rule sets with else branches are not partially evaluated: when the
// reference is unknown, the unification with e.rterm is saved instead.
// Partial set and partial object rules go to evalVirtualPartial, seeded with
// an empty set or object respectively. Every other kind goes to
// evalVirtualComplete.
func (e evalVirtual) eval(iter unifyIterator) error {
	ir, err := e.e.getRules(e.plugged[:e.pos+1])
	if err != nil {
		return err
	}

	// Partial evaluation of ordered rules is not supported currently. Save the
	// expression and continue. This could be revisited in the future.
	if len(ir.Else) > 0 && e.e.unknown(e.ref, e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.ref), e.rterm, e.bindings, e.rbindings, iter)
	}

	var empty *ast.Term
	switch ir.Kind {
	case ast.PartialSetDoc:
		empty = ast.SetTerm()
	case ast.PartialObjectDoc:
		empty = ast.ObjectTerm()
	default:
		complete := evalVirtualComplete{
			e:         e.e,
			ref:       e.ref,
			plugged:   e.plugged,
			pos:       e.pos,
			ir:        ir,
			bindings:  e.bindings,
			rterm:     e.rterm,
			rbindings: e.rbindings,
		}
		return complete.eval(iter)
	}

	partial := evalVirtualPartial{
		e:         e.e,
		ref:       e.ref,
		plugged:   e.plugged,
		pos:       e.pos,
		ir:        ir,
		bindings:  e.bindings,
		rterm:     e.rterm,
		rbindings: e.rbindings,
		empty:     empty,
	}
	return partial.eval(iter)
}
// evalVirtualPartial evaluates a reference into a partial set or partial
// object document defined by the rules in ir.
type evalVirtualPartial struct {
	e *eval
	// ref is the reference being evaluated; plugged holds its terms as
	// resolved so far.
	ref, plugged ast.Ref
	// pos indexes the reference term naming the rule set; the key into the
	// document, if any, is at pos+1.
	pos int
	// ir holds the rules returned by getRules for plugged[:pos+1].
	ir *ast.IndexResult
	// bindings are used for ref's variables.
	bindings *bindings
	// rterm is the term the result is unified with, bound in rbindings.
	rterm     *ast.Term
	rbindings *bindings
	// empty seeds the full-extent accumulation (an empty set or object).
	empty *ast.Term
}
// eval evaluates a reference into a partial set or partial object document.
//
// When the reference ends at the rule position (len(e.ref) == e.pos+1), the
// whole document is requested. It is saved if the reference is unknown;
// otherwise it is built from all rules by evalAllRules. In every other case
// the term at e.pos+1 selects a key. Support rules are generated when
// inlining is disabled for this path; otherwise each rule is evaluated
// individually by evalOneRule.
func (e evalVirtualPartial) eval(iter unifyIterator) error {
	if len(e.ref) == e.pos+1 {
		// During partial evaluation, it may not be possible to produce a value
		// for this reference so save the entire expression. See "save: full
		// extent: partial object" test case for an example.
		if e.e.unknown(e.ref, e.bindings) {
			return e.e.saveUnify(ast.NewTerm(e.ref), e.rterm, e.bindings, e.rbindings, iter)
		}
		return e.evalAllRules(iter, e.ir.Rules)
	}
	// cacheKey stays nil unless a ground key into a partial object missed the
	// virtual cache; evalOneRule only writes to the cache when it is set.
	var cacheKey ast.Ref
	if e.ir.Kind == ast.PartialObjectDoc {
		plugged := e.bindings.Plug(e.ref[e.pos+1])
		if plugged.IsGround() {
			// NOTE: path is a sub-slice of e.plugged, so this assignment also
			// writes the key into e.plugged at index e.pos+1.
			path := e.plugged[:e.pos+2]
			path[len(path)-1] = plugged
			cached := e.e.virtualCache.Get(path)
			if cached != nil {
				e.e.instr.counterIncr(evalOpVirtualCacheHit)
				return e.evalTerm(iter, cached, e.bindings)
			}
			e.e.instr.counterIncr(evalOpVirtualCacheMiss)
			cacheKey = path
		}
	}
	// Emit support rules instead of inlining when some reference in
	// disableInlining has e.plugged[:e.pos+1] as a prefix.
	generateSupport := anyRefSetContainsPrefix(e.e.disableInlining, e.plugged[:e.pos+1])
	if generateSupport {
		return e.partialEvalSupport(iter)
	}
	for _, rule := range e.ir.Rules {
		if err := e.evalOneRule(iter, rule, cacheKey); err != nil {
			return err
		}
	}
	return nil
}
// evalAllRules computes the full extent of the partial document. It
// evaluates every rule in rules and folds each body solution into a single
// set or object, starting from e.empty and using reduce. The accumulated
// document is then unified with e.rterm.
func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule) error {
	result := e.empty
	for _, rule := range rules {
		child := e.e.child(rule.Body)
		child.traceEnter(rule)
		err := child.eval(func(*eval) error {
			child.traceExit(rule)
			var err error
			result, err = e.reduce(rule.Head, child.bindings, result)
			if err != nil {
				return err
			}
			child.traceRedo(rule)
			return nil
		})
		if err != nil {
			return err
		}
	}
	// e.rterm's variables live in e.rbindings, consistent with every other
	// unification against e.rterm (saveUnify above, evalTerm.eval). The
	// accumulated result is assembled from plugged terms.
	return e.e.biunify(result, e.rterm, e.bindings, e.rbindings, iter)
}
// evalOneRule evaluates a single partial set/object rule for the key given by
// e.ref[e.pos+1].
//
// The rule's head key is unified with that reference key. For each solution
// of the rule body, the head value is passed to evalTerm to continue with the
// rest of the reference; the head key is used instead when the head has no
// value. When cacheKey is non-nil, the plugged value is also stored in the
// virtual cache under cacheKey. A fail trace event is emitted when the head
// key never unifies.
func (e evalVirtualPartial) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.Ref) error {
	key := e.ref[e.pos+1]
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	// defined is set once the head key unifies, even if the body then fails.
	var defined bool
	err := child.biunify(rule.Head.Key, key, child.bindings, e.bindings, func() error {
		defined = true
		return child.eval(func(child *eval) error {
			child.traceExit(rule)
			term := rule.Head.Value
			if term == nil {
				term = rule.Head.Key
			}
			if cacheKey != nil {
				result := child.bindings.Plug(term)
				e.e.virtualCache.Put(cacheKey, result)
			}
			term, termbindings := child.bindings.apply(term)
			err := e.evalTerm(iter, term, termbindings)
			if err != nil {
				return err
			}
			child.traceRedo(rule)
			return nil
		})
	})
	if err != nil {
		return err
	}
	if !defined {
		child.traceFail(rule)
	}
	return nil
}
// partialEvalSupport emits support rules for the partial document under the
// save namespace. It generates them only if none exist yet at that path. It
// then saves a unification between the namespaced reference and e.rterm.
func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error {
	supportPath := e.plugged[:e.pos+1].Insert(e.e.saveNamespace, 1)

	if !e.e.saveSupport.Exists(supportPath) {
		for _, rule := range e.ir.Rules {
			if err := e.partialEvalSupportRule(iter, rule, supportPath); err != nil {
				return err
			}
		}
	}

	namespaced := e.ref.Insert(e.e.saveNamespace, 1)
	return e.e.saveUnify(ast.NewTerm(namespaced), e.rterm, e.bindings, e.rbindings, iter)
}
// partialEvalSupportRule partially evaluates rule and records each result as
// a support rule under path.
//
// A nil query is pushed on the save stack before the body is evaluated,
// presumably so that expressions saved during evaluation collect into it.
// For each body solution, that query is popped and plugged with the caller's
// bindings. It is then simplified by copy propagation, which keeps the head's
// variables. The result becomes the body of a new support rule whose key and
// value are plugged (namespaced) the same way. The query is pushed back
// before evaluation continues, and the final pop removes the entry pushed at
// the start.
func (e evalVirtualPartial) partialEvalSupportRule(iter unifyIterator, rule *ast.Rule, path ast.Ref) error {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	e.e.saveStack.PushQuery(nil)
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		current := e.e.saveStack.PopQuery()
		plugged := current.Plug(e.e.caller.bindings)
		var key, value *ast.Term
		if rule.Head.Key != nil {
			key = child.bindings.PlugNamespaced(rule.Head.Key, e.e.caller.bindings)
		}
		if rule.Head.Value != nil {
			value = child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)
		}
		head := ast.NewHead(rule.Head.Name, key, value)
		p := copypropagation.New(head.Vars()).WithEnsureNonEmptyBody(true)
		e.e.saveSupport.Insert(path, &ast.Rule{
			Head: head,
			// Use applyCopyPropagation (as evalVirtualComplete does) so this
			// step is covered by the partialOpCopyPropagation timer.
			Body:    applyCopyPropagation(p, e.e.instr, plugged),
			Default: rule.Default,
		})
		child.traceRedo(rule)
		e.e.saveStack.PushQuery(current)
		return nil
	})
	e.e.saveStack.PopQuery()
	return err
}
// evalTerm continues evaluation of the reference inside term, a value
// produced by a partial rule. Traversal resumes at e.pos+2, skipping both
// the rule-set term and the key term.
func (e evalVirtualPartial) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error {
	next := evalTerm{
		e:            e.e,
		ref:          e.ref,
		pos:          e.pos + 2,
		bindings:     e.bindings,
		term:         term,
		termbindings: termbindings,
		rterm:        e.rterm,
		rbindings:    e.rbindings,
	}
	return next.eval(iter)
}
// reduce folds one rule-body solution (whose bindings are b) into result.
// For a set, the plugged head key is added. For an object, the plugged
// key/value pair is inserted; if the key already maps to a different value,
// an objectDocKeyConflictErr is returned. result is updated in place and
// returned.
func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, error) {
	switch doc := result.Value.(type) {
	case ast.Set:
		doc.Add(b.Plug(head.Key))
	case ast.Object:
		k, v := b.Plug(head.Key), b.Plug(head.Value)
		if existing := doc.Get(k); existing != nil && !existing.Equal(v) {
			return nil, objectDocKeyConflictErr(head.Location)
		}
		doc.Insert(k, v)
		result.Value = doc
	}
	return result, nil
}
// evalVirtualComplete evaluates a reference into a complete document (or
// function reference) defined by the rules in ir.
type evalVirtualComplete struct {
	e *eval
	// ref is the reference being evaluated; plugged holds its terms as
	// resolved so far.
	ref, plugged ast.Ref
	// pos indexes the reference term naming the rule set.
	pos int
	// ir holds the rules (plus else branches and default) for
	// plugged[:pos+1].
	ir *ast.IndexResult
	// bindings are used for ref's variables.
	bindings *bindings
	// rterm is the term the result is unified with, bound in rbindings.
	rterm     *ast.Term
	rbindings *bindings
}
// eval evaluates a complete-document reference.
//
// Empty rule sets, and rule sets whose first rule has arguments (function
// references), produce nothing. A known reference is evaluated to its value.
// An unknown reference is partially evaluated, either inline or by emitting
// support rules. Support rules are emitted when the default value may be
// needed, or when inlining is disabled for this path.
func (e evalVirtualComplete) eval(iter unifyIterator) error {
	if e.ir.Empty() {
		return nil
	}

	if len(e.ir.Rules) > 0 && len(e.ir.Rules[0].Head.Args) > 0 {
		return nil
	}

	if !e.e.unknown(e.ref, e.bindings) {
		return e.evalValue(iter)
	}

	generateSupport := false
	if dflt := e.ir.Default; dflt != nil {
		// The default value _may_ be required unless the other term is a
		// constant that differs from it. Constant, non-default other terms
		// can be partially evaluated normally.
		other := e.rbindings.Plug(e.rterm)
		generateSupport = !ast.IsConstant(other.Value) || dflt.Head.Value.Equal(other)
	}

	if !generateSupport {
		generateSupport = anyRefSetContainsPrefix(e.e.disableInlining, e.plugged[:e.pos+1])
	}

	if generateSupport {
		return e.partialEvalSupport(iter)
	}
	return e.partialEval(iter)
}
// evalValue computes the value of the complete document and continues with
// the rest of the reference.
//
// A cached value at e.plugged[:e.pos+1] is used when present. Otherwise the
// rules are evaluated in order. When a rule yields nothing, its else chain is
// tried until one branch yields a value. Each value found becomes prev, which
// later rules must agree with. The default rule is evaluated only if no rule
// produced a value.
func (e evalVirtualComplete) evalValue(iter unifyIterator) error {
	if hit := e.e.virtualCache.Get(e.plugged[:e.pos+1]); hit != nil {
		e.e.instr.counterIncr(evalOpVirtualCacheHit)
		return e.evalTerm(iter, hit, e.bindings)
	}
	e.e.instr.counterIncr(evalOpVirtualCacheMiss)

	var prev *ast.Term

	for _, rule := range e.ir.Rules {
		value, err := e.evalValueRule(iter, rule, prev)
		if err != nil {
			return err
		}

		if value == nil {
			for _, alt := range e.ir.Else[rule] {
				if value, err = e.evalValueRule(iter, alt, prev); err != nil {
					return err
				}
				if value != nil {
					break
				}
			}
		}

		if value != nil {
			prev = value
		}
	}

	if e.ir.Default == nil || prev != nil {
		return nil
	}

	_, err := e.evalValueRule(iter, e.ir.Default, prev)
	return err
}
// evalValueRule evaluates one complete-document rule and continues with the
// rest of the reference for its value.
//
// prev is the value produced by an earlier rule (nil if none). Once prev is
// set, every body solution's plugged head value must equal it; otherwise a
// completeDocConflictErr is returned. When prev is nil on entry, the first
// solution's value becomes prev. That value is cached under
// e.plugged[:e.pos+1] and passed to evalTerm; later agreeing solutions only
// emit a redo trace event.
//
// The returned term is the plugged head value of the last body solution, or
// nil if the body had no solutions.
func (e evalVirtualComplete) evalValueRule(iter unifyIterator, rule *ast.Rule, prev *ast.Term) (*ast.Term, error) {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	var result *ast.Term
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		result = child.bindings.Plug(rule.Head.Value)
		if prev != nil {
			if ast.Compare(result, prev) != 0 {
				return completeDocConflictErr(rule.Location)
			}
			child.traceRedo(rule)
			return nil
		}
		prev = result
		e.e.virtualCache.Put(e.plugged[:e.pos+1], result)
		term, termbindings := child.bindings.apply(rule.Head.Value)
		err := e.evalTerm(iter, term, termbindings)
		if err != nil {
			return err
		}
		child.traceRedo(rule)
		return nil
	})
	return result, err
}
// partialEval evaluates every rule inline. For each body solution, the
// rule's head value, as applied by that solution's bindings, is passed to
// evalTerm to continue with the rest of the reference.
func (e evalVirtualComplete) partialEval(iter unifyIterator) error {
	for _, rule := range e.ir.Rules {
		sub := e.e.child(rule.Body)
		sub.traceEnter(rule)

		onSolution := func(sol *eval) error {
			sol.traceExit(rule)
			value, valueBindings := sol.bindings.apply(rule.Head.Value)
			if err := e.evalTerm(iter, value, valueBindings); err != nil {
				return err
			}
			sol.traceRedo(rule)
			return nil
		}

		if err := sub.eval(onSolution); err != nil {
			return err
		}
	}
	return nil
}
// partialEvalSupport emits support rules for the complete document under the
// save namespace. It covers every rule plus the default rule, if any, and
// does so only if none exist yet at that path. It then saves a unification
// between the namespaced reference and e.rterm.
func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error {
	supportPath := e.plugged[:e.pos+1].Insert(e.e.saveNamespace, 1)

	if !e.e.saveSupport.Exists(supportPath) {
		for _, rule := range e.ir.Rules {
			if err := e.partialEvalSupportRule(iter, rule, supportPath); err != nil {
				return err
			}
		}
		if dflt := e.ir.Default; dflt != nil {
			if err := e.partialEvalSupportRule(iter, dflt, supportPath); err != nil {
				return err
			}
		}
	}

	namespaced := e.ref.Insert(e.e.saveNamespace, 1)
	return e.e.saveUnify(ast.NewTerm(namespaced), e.rterm, e.bindings, e.rbindings, iter)
}
// partialEvalSupportRule partially evaluates rule and records each result as
// a support rule under path.
//
// A nil query is pushed on the save stack before the body is evaluated,
// presumably so that expressions saved during evaluation collect into it.
// For each body solution, that query is popped and plugged with the caller's
// bindings. It is then simplified by copy propagation, which keeps the head's
// variables, and becomes the body of a new support rule. That rule's head
// value is plugged (namespaced) with the caller's bindings. The query is
// pushed back before evaluation continues, and the final pop removes the
// entry pushed at the start.
func (e evalVirtualComplete) partialEvalSupportRule(iter unifyIterator, rule *ast.Rule, path ast.Ref) error {
	child := e.e.child(rule.Body)
	child.traceEnter(rule)
	e.e.saveStack.PushQuery(nil)
	err := child.eval(func(child *eval) error {
		child.traceExit(rule)
		current := e.e.saveStack.PopQuery()
		plugged := current.Plug(e.e.caller.bindings)
		head := ast.NewHead(rule.Head.Name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings))
		p := copypropagation.New(head.Vars()).WithEnsureNonEmptyBody(true)
		e.e.saveSupport.Insert(path, &ast.Rule{
			Head:    head,
			Body:    applyCopyPropagation(p, e.e.instr, plugged),
			Default: rule.Default,
		})
		child.traceRedo(rule)
		e.e.saveStack.PushQuery(current)
		return nil
	})
	e.e.saveStack.PopQuery()
	return err
}
// evalTerm continues evaluation of the reference inside term, the value of a
// complete rule. Traversal resumes at e.pos+1, just past the rule-set term.
func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error {
	next := evalTerm{
		e:            e.e,
		ref:          e.ref,
		pos:          e.pos + 1,
		bindings:     e.bindings,
		term:         term,
		termbindings: termbindings,
		rterm:        e.rterm,
		rbindings:    e.rbindings,
	}
	return next.eval(iter)
}
// evalTerm walks the remaining terms of a reference through an
// already-computed value, then unifies the reached value with rterm.
type evalTerm struct {
	e *eval
	// ref is the reference being evaluated.
	ref ast.Ref
	// pos is the index in ref of the next term to apply to term.
	pos int
	// bindings are used for ref's variables.
	bindings *bindings
	// term is the value reached so far; termbindings are its bindings.
	term         *ast.Term
	termbindings *bindings
	// rterm is the term the final value is unified with, bound in rbindings.
	rterm     *ast.Term
	rbindings *bindings
}
// eval walks the next reference term into e.term.
//
// Once the whole reference has been consumed, the reached value is unified
// with e.rterm. If e.term belongs to the save set, the rest of the
// reference is saved instead. Otherwise a ground key is followed directly,
// and a non-ground key is enumerated.
func (e evalTerm) eval(iter unifyIterator) error {
	switch {
	case len(e.ref) == e.pos:
		return e.e.biunify(e.term, e.rterm, e.termbindings, e.rbindings, iter)
	case e.e.saveSet.Contains(e.term, e.termbindings):
		return e.save(iter)
	}

	key := e.bindings.Plug(e.ref[e.pos])
	if !key.IsGround() {
		return e.enumerate(iter)
	}
	return e.next(iter, key)
}
// next descends into the element of e.term addressed by plugged and
// continues at the following reference position. When e.term has no such
// element, the path is undefined: it returns nil without calling iter.
func (e evalTerm) next(iter unifyIterator, plugged *ast.Term) error {
	child, childBindings := e.get(plugged)
	if child == nil {
		return nil
	}
	// e is a value receiver, so these updates only affect this copy.
	e.term, e.termbindings = child, childBindings
	e.pos++
	return e.eval(iter)
}
// enumerate tries every index, key or element of e.term as the value of the
// non-ground reference term e.ref[e.pos]. Each candidate that unifies
// continues through e.next. Scalars have nothing to enumerate.
func (e evalTerm) enumerate(iter unifyIterator) error {
	switch coll := e.term.Value.(type) {
	case ast.Array:
		for i := range coll {
			idx := ast.IntNumberTerm(i)
			err := e.e.biunify(idx, e.ref[e.pos], e.bindings, e.bindings, func() error {
				return e.next(iter, idx)
			})
			if err != nil {
				return err
			}
		}
		return nil
	case ast.Object:
		return coll.Iter(func(key, _ *ast.Term) error {
			return e.e.biunify(key, e.ref[e.pos], e.termbindings, e.bindings, func() error {
				return e.next(iter, e.termbindings.Plug(key))
			})
		})
	case ast.Set:
		return coll.Iter(func(member *ast.Term) error {
			return e.e.biunify(member, e.ref[e.pos], e.termbindings, e.bindings, func() error {
				return e.next(iter, e.termbindings.Plug(member))
			})
		})
	default:
		return nil
	}
}
// get looks up plugged inside e.term and returns the matching term, applied
// with e.termbindings. It returns (nil, nil) when there is no match.
//
// For sets, a match returns plugged itself. For objects, it returns the value
// stored under the key. For arrays, it returns the element at the index.
// Ground sets and objects use a direct lookup; non-ground ones are scanned,
// comparing each member or key after plugging it with e.termbindings.
func (e evalTerm) get(plugged *ast.Term) (*ast.Term, *bindings) {
	switch coll := e.term.Value.(type) {
	case ast.Set:
		if coll.IsGround() {
			if coll.Contains(plugged) {
				return e.termbindings.apply(plugged)
			}
			return nil, nil
		}
		var found *ast.Term
		var foundBindings *bindings
		hit := coll.Until(func(member *ast.Term) bool {
			if !e.termbindings.Plug(member).Equal(plugged) {
				return false
			}
			found, foundBindings = e.termbindings.apply(plugged)
			return true
		})
		if hit {
			return found, foundBindings
		}
	case ast.Object:
		if coll.IsGround() {
			if val := coll.Get(plugged); val != nil {
				return e.termbindings.apply(val)
			}
			return nil, nil
		}
		var found *ast.Term
		var foundBindings *bindings
		hit := coll.Until(func(key, val *ast.Term) bool {
			if !e.termbindings.Plug(key).Equal(plugged) {
				return false
			}
			found, foundBindings = e.termbindings.apply(val)
			return true
		})
		if hit {
			return found, foundBindings
		}
	case ast.Array:
		if elem := coll.Get(plugged); elem != nil {
			return e.termbindings.apply(elem)
		}
	}
	return nil, nil
}
// save builds a reference headed by e.term and followed by the unconsumed
// reference terms e.ref[e.pos:]. It unifies that reference with e.rterm, so
// the remaining lookup is deferred (saved) rather than performed now.
func (e evalTerm) save(iter unifyIterator) error {
	rest := e.ref[e.pos:]
	saved := make(ast.Ref, 0, len(rest)+1)
	saved = append(saved, e.term)
	saved = append(saved, rest...)
	return e.e.biunify(ast.NewTerm(saved), e.rterm, e.termbindings, e.rbindings, iter)
}
// applyCopyPropagation runs p over body and returns the rewritten body. The
// run is timed under partialOpCopyPropagation in instr.
func applyCopyPropagation(p *copypropagation.CopyPropagator, instr *Instrumentation, body ast.Body) ast.Body {
	instr.startTimer(partialOpCopyPropagation)
	defer instr.stopTimer(partialOpCopyPropagation)
	return p.Apply(body)
}
// nonGroundKeys reports whether at least one key of a is not ground.
func nonGroundKeys(a ast.Object) bool {
	isNonGround := func(key, _ *ast.Term) bool {
		return !key.IsGround()
	}
	return a.Until(isNonGround)
}
// plugKeys maps a through ast.Object.Map, plugging every key with b and
// keeping values unchanged, and returns the mapped object.
func plugKeys(a ast.Object, b *bindings) ast.Object {
	plugKey := func(key, val *ast.Term) (*ast.Term, *ast.Term, error) {
		return b.Plug(key), val, nil
	}
	// plugKey never fails, so the error from Map is always nil.
	out, _ := a.Map(plugKey)
	return out
}
// plugSlice returns a new slice holding each term of xs plugged with b.
func plugSlice(xs []*ast.Term, b *bindings) []*ast.Term {
	out := make([]*ast.Term, 0, len(xs))
	for _, x := range xs {
		out = append(out, b.Plug(x))
	}
	return out
}
// canInlineNegation reports whether a negation over queries can be inlined
// by complementing their expressions (see complementedCartesianProduct).
//
// It returns false in two cases. The first is when a positive expression
// references variables outside safe, ignoring ast.ReservedVars, because
// negating such an expression would make those variables unsafe. The second
// is when the number of resulting products, i.e. the product of the query
// lengths, exceeds maxProducts.
func canInlineNegation(safe ast.VarSet, queries []ast.Body) bool {
	// NOTE(tsandall): this limit is arbitrary; it's only in place to prevent
	// the partial evaluation result from blowing up. In the future, we could
	// make this configurable or do something more clever.
	const maxProducts = 16

	size := 1
	for _, query := range queries {
		size *= len(query)
		// Saturate once the limit is exceeded. Otherwise the running product
		// can overflow int for many queries (e.g. 64 queries of two
		// expressions wrap to 0) and wrongly pass the limit check. A later
		// empty query still drives size to 0, as before.
		if size > maxProducts {
			size = maxProducts + 1
		}
		for _, expr := range query {
			if !expr.Negated {
				// Positive expressions containing variables cannot be trivially negated
				// because they become unsafe (e.g., "x = 1" negated is "not x = 1" making x
				// unsafe.) We check if the vars in the expr are already safe.
				vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
					SkipRefCallHead: true,
					SkipClosures:    true,
				})
				vis.Walk(expr)
				unsafe := vis.Vars().Diff(safe).Diff(ast.ReservedVars)
				if len(unsafe) > 0 {
					return false
				}
			}
		}
	}
	return size <= maxProducts
}
// complementedCartesianProduct calls iter once for every way of choosing one
// expression from each of queries[depth:]. Each chosen expression is
// complemented and appended to prefix. prefix is reused as scratch space
// across calls, so iter must not retain the body it receives.
func complementedCartesianProduct(queries []ast.Body, depth int, prefix ast.Body, iter func(ast.Body) error) error {
	if depth == len(queries) {
		return iter(prefix)
	}
	for _, expr := range queries[depth] {
		prefix = append(prefix, expr.Complement())
		err := complementedCartesianProduct(queries, depth+1, prefix, iter)
		if err != nil {
			return err
		}
		prefix = prefix[:len(prefix)-1]
	}
	return nil
}
// isInputRef reports whether term is a reference that starts with
// ast.InputRootRef.
func isInputRef(term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	return isRef && ref.HasPrefix(ast.InputRootRef)
}
// isDataRef reports whether term is a reference that starts with
// ast.DefaultRootRef.
func isDataRef(term *ast.Term) bool {
	ref, isRef := term.Value.(ast.Ref)
	return isRef && ref.HasPrefix(ast.DefaultRootRef)
}
// merge combines a and b with mergeObjects. It reports false when either
// value is not an object.
func merge(a, b ast.Value) (ast.Value, bool) {
	left, isObj := a.(ast.Object)
	if !isObj {
		return nil, false
	}
	right, isObj := b.(ast.Object)
	if !isObj {
		return nil, false
	}
	return mergeObjects(left, right)
}
// mergeObjects returns a new object holding every key of objA and objB.
// Keys found in only one of them keep that object's value. When a key is
// present in both and both values are objects, the values are merged
// recursively. Otherwise the value from objA wins over the one from objB.
// ok is false only if a nested merge fails.
func mergeObjects(objA, objB ast.Object) (result ast.Object, ok bool) {
	result = ast.NewObject()

	failed := objA.Until(func(key, aVal *ast.Term) bool {
		bVal := objB.Get(key)
		if bVal == nil {
			result.Insert(key, aVal)
			return false
		}
		aObj, aIsObj := aVal.Value.(ast.Object)
		bObj, bIsObj := bVal.Value.(ast.Object)
		if !aIsObj || !bIsObj {
			result.Insert(key, aVal)
			return false
		}
		nested, nestedOK := mergeObjects(aObj, bObj)
		if !nestedOK {
			return true
		}
		result.Insert(key, ast.NewTerm(nested))
		return false
	})
	if failed {
		return nil, false
	}

	// Add the keys that only objB has.
	objB.Foreach(func(key, bVal *ast.Term) {
		if objA.Get(key) == nil {
			result.Insert(key, bVal)
		}
	})
	return result, true
}
// anyRefSetContainsPrefix reports whether any reference in any of the sets in
// s has prefix as a prefix.
func anyRefSetContainsPrefix(s [][]ast.Ref, prefix ast.Ref) bool {
	for i := range s {
		for j := range s[i] {
			if s[i][j].HasPrefix(prefix) {
				return true
			}
		}
	}
	return false
}
// refContainsNonScalar reports whether any term after the head of ref has a
// value for which ast.IsScalar reports false. A ref with no terms yields
// false instead of panicking on the ref[1:] slice expression.
func refContainsNonScalar(ref ast.Ref) bool {
	if len(ref) == 0 {
		return false
	}
	for _, term := range ref[1:] {
		if !ast.IsScalar(term.Value) {
			return true
		}
	}
	return false
}