package topdown

import (
	"context"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/open-policy-agent/opa/ast"
	"github.com/open-policy-agent/opa/storage"
	"github.com/open-policy-agent/opa/topdown/builtins"
	"github.com/open-policy-agent/opa/topdown/copypropagation"
)

// evalIterator is invoked by the evaluator each time the query (or the
// remainder of the query) is satisfied. Errors it returns are propagated back
// to the caller of the evaluation.
type evalIterator func(*eval) error

// unifyIterator is invoked each time a unification, builtin call, or function
// call succeeds. Errors it returns are propagated back to the caller.
type unifyIterator func() error

// queryIDFactory hands out increasing query identifiers.
type queryIDFactory struct {
	curr uint64
}

// Next returns the next query ID.
// Note: The first call to Next() returns 0.
func (f *queryIDFactory) Next() uint64 {
	curr := f.curr
	f.curr++
	return curr
}

// eval holds the state for evaluating one query body. Sub-evaluators are
// created by shallow-copying this struct (see closure and child).
type eval struct {
	ctx             context.Context
	queryID         uint64
	queryIDFact     *queryIDFactory
	parent          *eval
	caller          *eval
	cancel          Cancel
	query           ast.Body
	queryCompiler   ast.QueryCompiler
	index           int
	indexing        bool
	bindings        *bindings
	store           storage.Store
	baseCache       *baseCache
	txn             storage.Transaction
	compiler        *ast.Compiler
	input           *ast.Term
	data            *ast.Term
	targetStack     *refStack
	tracers         []Tracer
	instr           *Instrumentation
	builtins        map[string]*Builtin
	builtinCache    builtins.Cache
	virtualCache    *virtualCache
	saveSet         *saveSet
	saveStack       *saveStack
	saveSupport     *saveSupport
	saveNamespace   *ast.Term
	disableInlining [][]ast.Ref
	genvarprefix    string
	runtime         *ast.Term
}

// Run evaluates the query and calls iter for every result. Each result is
// wrapped with Exit/Redo trace events, and the whole query with an Enter event.
func (e *eval) Run(iter evalIterator) error {
	e.traceEnter(e.query)
	return e.eval(func(e *eval) error {
		e.traceExit(e.query)
		err := iter(e)
		e.traceRedo(e.query)
		return err
	})
}

// builtinFunc resolves a builtin by name. Names declared in ast.BuiltinMap are
// looked up in builtinFunctions; any other name is looked up in the
// evaluator's custom builtins. The final result is false when no
// implementation is found.
func (e *eval) builtinFunc(name string) (*ast.Builtin, BuiltinFunc, bool) {
	decl, ok := ast.BuiltinMap[name]
	if !ok {
		bi, ok := e.builtins[name]
		if ok {
			return bi.Decl, bi.Func, true
		}
	} else {
		f, ok := builtinFunctions[name]
		if ok {
			return decl, f, true
		}
	}
	return nil, nil, false
}

// closure returns an evaluator for query that shares e's bindings (so the
// query can read and bind e's variables) but gets its own query ID.
func (e *eval) closure(query ast.Body) *eval {
	cpy := *e
	cpy.index = 0
	cpy.query = query
	cpy.queryID = cpy.queryIDFact.Next()
	cpy.parent = e
	return &cpy
}

// child returns an evaluator for query with its own query ID and a newly
// created bindings set (see newBindings).
func (e *eval) child(query ast.Body) *eval {
	cpy := *e
	cpy.index = 0
	cpy.query = query
	cpy.queryID = cpy.queryIDFact.Next()
	cpy.bindings = newBindings(cpy.queryID, e.instr)
	cpy.parent = e
	return &cpy
}

// next advances to the following expression, evaluates it, then restores the
// index so that backtracking resumes at the current expression.
func (e *eval) next(iter evalIterator) error {
	e.index++
	err := e.evalExpr(iter)
	e.index--
	return err
}

// partial reports whether partial evaluation is enabled (a save set exists).
func (e *eval) partial() bool {
	return e.saveSet != nil
}

// unknown reports whether x must be saved during partial evaluation. It is
// always false when partial evaluation is disabled.
func (e *eval) unknown(x interface{}, b *bindings) bool {
	if !e.partial() {
		return false
	}

	// If the caller provided an ast.Value directly (e.g., an ast.Ref) wrap
	// it as an ast.Term because the saveSet Contains() function expects
	// ast.Term.
	if v, ok := x.(ast.Value); ok {
		x = ast.NewTerm(v)
	}

	return saveRequired(e.compiler, e.saveSet, b, x, false)
}

func (e *eval) traceEnter(x ast.Node) {
	e.traceEvent(EnterOp, x, "")
}

func (e *eval) traceExit(x ast.Node) {
	e.traceEvent(ExitOp, x, "")
}

func (e *eval) traceEval(x ast.Node) {
	e.traceEvent(EvalOp, x, "")
}

func (e *eval) traceFail(x ast.Node) {
	e.traceEvent(FailOp, x, "")
}

func (e *eval) traceRedo(x ast.Node) {
	e.traceEvent(RedoOp, x, "")
}

func (e *eval) traceSave(x ast.Node) {
	e.traceEvent(SaveOp, x, "")
}

func (e *eval) traceIndex(x ast.Node, msg string) {
	e.traceEvent(IndexOp, x, msg)
}

// traceEvent builds an Event for op on node x and delivers it to every
// enabled tracer. The event carries the current bindings and the original
// (pre-rewrite) names of bound variables and of variables appearing in x.
// It returns immediately when tracing is disabled.
func (e *eval) traceEvent(op Op, x ast.Node, msg string) {

	if !traceIsEnabled(e.tracers) {
		return
	}

	locals := ast.NewValueMap()
	localMeta := map[ast.Var]VarMetadata{}

	// The callback never fails, so the returned error is not checked.
	e.bindings.Iter(nil, func(k, v *ast.Term) error {
		original := k.Value.(ast.Var)
		rewritten, _ := e.rewrittenVar(original)
		localMeta[original] = VarMetadata{
			Name:     rewritten,
			Location: k.Loc(),
		}

		// For backwards compatibility save a copy of the values too.
		locals.Put(k.Value, v.Value)
		return nil
	})

	ast.WalkTerms(x, func(term *ast.Term) bool {
		if v, ok := term.Value.(ast.Var); ok {
			if _, ok := localMeta[v]; !ok {
				if rewritten, ok := e.rewrittenVar(v); ok {
					localMeta[v] = VarMetadata{
						Name:     rewritten,
						Location: term.Loc(),
					}
				}
			}
		}
		return false
	})

	var parentID uint64
	if e.parent != nil {
		parentID = e.parent.queryID
	}

	evt := &Event{
		QueryID:       e.queryID,
		ParentID:      parentID,
		Op:            op,
		Node:          x,
		Location:      x.Loc(),
		Locals:        locals,
		LocalMetadata: localMeta,
		Message:       msg,
	}

	for i := range e.tracers {
		if e.tracers[i].Enabled() {
			e.tracers[i].Trace(evt)
		}
	}
}

// eval evaluates the query starting at the current expression index.
func (e *eval) eval(iter evalIterator) error {
	return e.evalExpr(iter)
}

// evalExpr evaluates the expression at e.index and, for each way it is
// satisfied, the remaining expressions. iter is called once the end of the
// query is reached. A cancelled query yields a CancelErr.
func (e *eval) evalExpr(iter evalIterator) error {

	if e.cancel != nil && e.cancel.Cancelled() {
		return &Error{
			Code:    CancelErr,
			Message: "caller cancelled query execution",
		}
	}

	if e.index >= len(e.query) {
		return iter(e)
	}

	expr := e.query[e.index]

	e.traceEval(expr)

	if len(expr.With) > 0 {
		return e.evalWith(iter)
	}

	return e.evalStep(func(e *eval) error {
		return e.next(iter)
	})
}

// evalStep evaluates the current expression (without applying its with
// modifiers). It handles negation, equality, calls, and plain term
// expressions. A plain term is satisfied when its value is not false.
// A Fail trace event is emitted when the expression produced no result.
func (e *eval) evalStep(iter evalIterator) error {

	expr := e.query[e.index]

	if expr.Negated {
		return e.evalNot(iter)
	}

	var defined bool
	var err error

	// yield records that the expression was satisfied and continues.
	yield := func() error {
		defined = true
		err := iter(e)
		e.traceRedo(expr)
		return err
	}

	switch terms := expr.Terms.(type) {
	case []*ast.Term:
		if expr.IsEquality() {
			err = e.unify(terms[1], terms[2], yield)
		} else {
			err = e.evalCall(terms, yield)
		}
	case *ast.Term:
		rterm := e.generateVar(fmt.Sprintf("term_%d_%d", e.queryID, e.index))
		err = e.unify(terms, rterm, func() error {
			if e.saveSet.Contains(rterm, e.bindings) {
				return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error {
					return iter(e)
				})
			}
			if !e.bindings.Plug(rterm).Equal(ast.BooleanTerm(false)) {
				return yield()
			}
			return nil
		})
	}

	if err != nil {
		return err
	}

	if !defined {
		e.traceFail(expr)
	}

	return nil
}

// evalNot evaluates a negated expression: iter is called only when the
// complemented expression produces no results. If the expression depends on
// unknowns during partial evaluation, evalNotPartial is used instead.
func (e *eval) evalNot(iter evalIterator) error {

	expr := e.query[e.index]

	if e.unknown(expr, e.bindings) {
		return e.evalNotPartial(iter)
	}

	negation := ast.NewBody(expr.Complement().NoWith())
	child := e.closure(negation)

	var defined bool
	child.traceEnter(negation)

	err := child.eval(func(*eval) error {
		child.traceExit(negation)
		defined = true
		child.traceRedo(negation)
		return nil
	})

	if err != nil {
		return err
	}

	if !defined {
		return iter(e)
	}

	e.traceFail(expr)
	return nil
}

// evalWith evaluates the current expression with its with modifiers applied
// to input and/or data. The replacements are active only while the expression
// itself is evaluated; they are removed while the rest of the query runs.
func (e *eval) evalWith(iter evalIterator) error {

	expr := e.query[e.index]
	var disable []ast.Ref

	if e.partial() {

		// If the value is unknown the with statement cannot be evaluated and so
		// the entire expression should be saved to be safe. In the future this
		// could be relaxed in certain cases (e.g., if the with statement would
		// have no effect.)
		for _, with := range expr.With {
			if e.saveSet.ContainsRecursive(with.Value, e.bindings) {
				return e.saveExpr(expr, e.bindings, func() error {
					return e.next(iter)
				})
			}
		}

		// Disable inlining on all references in the expression so the result of
		// partial evaluation has the same semantics w/ the with statements
		// preserved.
		ast.WalkRefs(expr, func(x ast.Ref) bool {
			disable = append(disable, x.GroundPrefix())
			return false
		})
	}

	pairsInput := [][2]*ast.Term{}
	pairsData := [][2]*ast.Term{}
	targets := []ast.Ref{}

	for i := range expr.With {
		plugged := e.bindings.Plug(expr.With[i].Value)
		if isInputRef(expr.With[i].Target) {
			pairsInput = append(pairsInput, [...]*ast.Term{expr.With[i].Target, plugged})
		} else if isDataRef(expr.With[i].Target) {
			pairsData = append(pairsData, [...]*ast.Term{expr.With[i].Target, plugged})
		}
		targets = append(targets, expr.With[i].Target.Value.(ast.Ref))
	}

	input, err := mergeTermWithValues(e.input, pairsInput)
	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}

	data, err := mergeTermWithValues(e.data, pairsData)
	if err != nil {
		return &Error{
			Code:     ConflictErr,
			Location: expr.Location,
			Message:  err.Error(),
		}
	}

	oldInput, oldData := e.evalWithPush(input, data, targets, disable)

	err = e.evalStep(func(e *eval) error {
		// Restore the original documents while the rest of the query runs,
		// then re-apply the replacements for further backtracking.
		e.evalWithPop(oldInput, oldData)
		err := e.next(iter)
		oldInput, oldData = e.evalWithPush(input, data, targets, disable)
		return err
	})

	e.evalWithPop(oldInput, oldData)

	return err
}

// evalWithPush installs the non-nil input/data replacements and pushes the
// with targets, the virtual cache frame, and the inlining-disabled refs. It
// returns the input and data values in effect before the push so that
// evalWithPop can restore them exactly.
func (e *eval) evalWithPush(input *ast.Term, data *ast.Term, targets []ast.Ref, disable []ast.Ref) (*ast.Term, *ast.Term) {

	// Always capture the current values: evalWithPop assigns both fields
	// unconditionally, so returning nil for an unreplaced document would
	// clear it on pop.
	oldInput, oldData := e.input, e.data

	if input != nil {
		e.input = input
	}

	if data != nil {
		e.data = data
	}

	e.virtualCache.Push()
	e.targetStack.Push(targets)
	e.disableInlining = append(e.disableInlining, disable)

	return oldInput, oldData
}

// evalWithPop undoes evalWithPush and sets input and data to the given values.
func (e *eval) evalWithPop(input *ast.Term, data *ast.Term) {
	e.disableInlining = e.disableInlining[:len(e.disableInlining)-1]
	e.targetStack.Pop()
	e.virtualCache.Pop()
	e.data = data
	e.input = input
}

// evalNotPartial partially evaluates the complement of the current negated
// expression and saves the negated result. The result is either inlined or
// moved into a support rule.
func (e *eval) evalNotPartial(iter evalIterator) error {

	// Prepare query normally.
	expr := e.query[e.index]
	negation := expr.Complement().NoWith()
	child := e.closure(ast.NewBody(negation))

	// Unknowns is the set of variables that are marked as unknown. The variables
	// are namespaced with the query ID that they originate in. This ensures that
	// variables across two or more queries are identified uniquely.
	//
	// NOTE(tsandall): this is greedy in the sense that we only need variable
	// dependencies of the negation.
	unknowns := e.saveSet.Vars(e.caller.bindings)

	// Run partial evaluation, plugging the result and applying copy propagation to
	// each result. Since the result may require support, push a new query onto the
	// save stack to avoid mutating the current save query.
	p := copypropagation.New(unknowns).WithEnsureNonEmptyBody(true)
	var savedQueries []ast.Body
	e.saveStack.PushQuery(nil)

	err := child.eval(func(*eval) error {
		query := e.saveStack.Peek()
		plugged := query.Plug(e.caller.bindings)
		result := applyCopyPropagation(p, e.instr, plugged)
		savedQueries = append(savedQueries, result)
		return nil
	})

	// Keep the save stack balanced even when evaluation fails.
	e.saveStack.PopQuery()

	if err != nil {
		return err
	}

	// If partial evaluation produced no results, the expression is always undefined
	// so it does not have to be saved.
	if len(savedQueries) == 0 {
		return iter(e)
	}

	// Check if the partial evaluation result can be inlined in this query. If not,
	// generate support rules for the result. Depending on the size of the partial
	// evaluation result and the contents, it may or may not be inlinable. We treat
	// the unknowns as safe because vars in the save set will either be known to
	// the caller or made safe by an expression on the save stack.
	if !canInlineNegation(unknowns, savedQueries) {
		return e.evalNotPartialSupport(expr, unknowns, savedQueries, iter)
	}

	// If we can inline the result, we have to generate the cross product of the
	// queries. For example:
	//
	//	(A && B) || (C && D)
	//
	// Becomes:
	//
	//	(!A && !C) || (!A && !D) || (!B && !C) || (!B && !D)
	return complementedCartesianProduct(savedQueries, 0, nil, func(q ast.Body) error {
		return e.saveInlinedNegatedExprs(q, func() error {
			return iter(e)
		})
	})
}

// evalNotPartialSupport saves queries as support rules named
// __not<queryID>_<index>__ under the save namespace. It then saves expr
// rewritten to reference that rule (with the relevant unknowns as arguments).
func (e *eval) evalNotPartialSupport(expr *ast.Expr, unknowns ast.VarSet, queries []ast.Body, iter evalIterator) error {

	// Prepare support rule head.
	supportName := fmt.Sprintf("__not%d_%d__", e.queryID, e.index)
	term := ast.RefTerm(ast.DefaultRootDocument, e.saveNamespace, ast.StringTerm(supportName))
	path := term.Value.(ast.Ref)
	head := ast.NewHead(ast.Var(supportName), nil, ast.BooleanTerm(true))

	bodyVars := ast.NewVarSet()

	for _, q := range queries {
		bodyVars.Update(q.Vars(ast.VarVisitorParams{}))
	}

	unknowns = unknowns.Intersect(bodyVars)

	// Make rule args. Sort them to ensure order is deterministic.
	args := make([]*ast.Term, 0, len(unknowns))

	for v := range unknowns {
		args = append(args, ast.NewTerm(v))
	}

	sort.Slice(args, func(i, j int) bool {
		return args[i].Value.Compare(args[j].Value) < 0
	})

	if len(args) > 0 {
		head.Args = ast.Args(args)
	}

	// Save support rules.
	for _, query := range queries {
		e.saveSupport.Insert(path, &ast.Rule{
			Head: head,
			Body: query,
		})
	}

	// Save expression that refers to support rule set.
	expr = expr.Copy()
	if len(args) > 0 {
		terms := make([]*ast.Term, len(args)+1)
		terms[0] = term
		copy(terms[1:], args)
		expr.Terms = terms
	} else {
		expr.Terms = term
	}

	return e.saveInlinedNegatedExprs([]*ast.Expr{expr}, func() error {
		return e.next(iter)
	})
}

// evalCall evaluates a call expression. Refs rooted at data are user-defined
// functions; anything else must be a known builtin. During partial evaluation,
// builtin calls that depend on unknowns are saved instead of executed.
func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error {

	ref := terms[0].Value.(ast.Ref)

	if ref[0].Equal(ast.DefaultRootDocument) {
		eval := evalFunc{
			e:     e,
			ref:   ref,
			terms: terms,
		}
		return eval.eval(iter)
	}

	bi, f, ok := e.builtinFunc(ref.String())
	if !ok {
		return unsupportedBuiltinErr(e.query[e.index].Location)
	}

	if e.unknown(e.query[e.index], e.bindings) {
		return e.saveCall(len(bi.Decl.Args()), terms, iter)
	}

	var parentID uint64
	if e.parent != nil {
		parentID = e.parent.queryID
	}

	bctx := BuiltinContext{
		Context:  e.ctx,
		Cancel:   e.cancel,
		Runtime:  e.runtime,
		Cache:    e.builtinCache,
		Location: e.query[e.index].Location,
		Tracers:  e.tracers,
		QueryID:  e.queryID,
		ParentID: parentID,
	}

	eval := evalBuiltin{
		e:     e,
		bi:    bi,
		bctx:  bctx,
		f:     f,
		terms: terms[1:],
	}

	return eval.eval(iter)
}

// unify unifies a and b using e's bindings for both sides.
func (e *eval) unify(a, b *ast.Term, iter unifyIterator) error {
	return e.biunify(a, b, e.bindings, e.bindings, iter)
}

// biunify unifies a (resolved in b1) with b (resolved in b2), calling iter for
// each successful unification. Terms of incompatible types simply fail to
// unify (nil is returned without calling iter).
func (e *eval) biunify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	a, b1 = b1.apply(a)
	b, b2 = b2.apply(b)
	switch vA := a.Value.(type) {
	case ast.Var, ast.Ref, *ast.ArrayComprehension, *ast.SetComprehension, *ast.ObjectComprehension:
		return e.biunifyValues(a, b, b1, b2, iter)
	case ast.Null:
		switch b.Value.(type) {
		case ast.Var, ast.Null, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Boolean:
		switch b.Value.(type) {
		case ast.Var, ast.Boolean, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Number:
		switch b.Value.(type) {
		case ast.Var, ast.Number, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.String:
		switch b.Value.(type) {
		case ast.Var, ast.String, ast.Ref:
			return e.biunifyValues(a, b, b1, b2, iter)
		}
	case ast.Array:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ArrayComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case ast.Array:
			return e.biunifyArrays(vA, vB, b1, b2, iter)
		}
	case ast.Object:
		switch vB := b.Value.(type) {
		case ast.Var, ast.Ref, *ast.ObjectComprehension:
			return e.biunifyValues(a, b, b1, b2, iter)
		case ast.Object:
			return e.biunifyObjects(vA, vB, b1, b2, iter)
		}
	case ast.Set:
		return e.biunifyValues(a, b, b1, b2, iter)
	}
	return nil
}

// biunifyArrays unifies two arrays element-wise; arrays of different length
// never unify.
func (e *eval) biunifyArrays(a, b ast.Array, b1, b2 *bindings, iter unifyIterator) error {
	if len(a) != len(b) {
		return nil
	}
	return e.biunifyArraysRec(a, b, b1, b2, iter, 0)
}

// biunifyArraysRec unifies a[idx:] with b[idx:], calling iter once all
// elements have been unified.
func (e *eval) biunifyArraysRec(a, b ast.Array, b1, b2 *bindings, iter unifyIterator, idx int) error {
	if idx == len(a) {
		return iter()
	}
	return e.biunify(a[idx], b[idx], b1, b2, func() error {
		return e.biunifyArraysRec(a, b, b1, b2, iter, idx+1)
	})
}

// biunifyObjects unifies two objects by key; objects of different size never
// unify.
func (e *eval) biunifyObjects(a, b ast.Object, b1, b2 *bindings, iter unifyIterator) error {
	if a.Len() != b.Len() {
		return nil
	}

	// Objects must not contain unbound variables as keys at this point as we
	// cannot unify them. Similar to sets, plug both sides before comparing the
	// keys and unifying the values.
	if nonGroundKeys(a) {
		a = plugKeys(a, b1)
	}

	if nonGroundKeys(b) {
		b = plugKeys(b, b2)
	}

	return e.biunifyObjectsRec(a, b, b1, b2, iter, a.Keys(), 0)
}

// biunifyObjectsRec unifies the values of keys[idx:] in a and b. It fails
// (returns nil) if b lacks one of a's keys.
func (e *eval) biunifyObjectsRec(a, b ast.Object, b1, b2 *bindings, iter unifyIterator, keys []*ast.Term, idx int) error {
	if idx == len(keys) {
		return iter()
	}
	v2 := b.Get(keys[idx])
	if v2 == nil {
		return nil
	}
	return e.biunify(a.Get(keys[idx]), v2, b1, b2, func() error {
		return e.biunifyObjectsRec(a, b, b1, b2, iter, keys, idx+1)
	})
}

// biunifyValues unifies two terms after biunify has checked type
// compatibility. It evaluates refs and comprehensions, saves the unification
// when either side is unknown, binds variables, or compares values.
func (e *eval) biunifyValues(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {

	// Try to evaluate refs and comprehensions. If partial evaluation is
	// enabled, then skip evaluation (and save the expression) if the term is
	// in the save set. Currently, comprehensions are not evaluated during
	// partial eval. This could be improved in the future.

	var saveA, saveB bool

	if _, ok := a.Value.(ast.Set); ok {
		saveA = e.saveSet.ContainsRecursive(a, b1)
	} else {
		saveA = e.saveSet.Contains(a, b1)
		if !saveA {
			if _, refA := a.Value.(ast.Ref); refA {
				return e.biunifyRef(a, b, b1, b2, iter)
			}
		}
	}

	if _, ok := b.Value.(ast.Set); ok {
		saveB = e.saveSet.ContainsRecursive(b, b2)
	} else {
		saveB = e.saveSet.Contains(b, b2)
		if !saveB {
			if _, refB := b.Value.(ast.Ref); refB {
				return e.biunifyRef(b, a, b2, b1, iter)
			}
		}
	}

	if saveA || saveB {
		return e.saveUnify(a, b, b1, b2, iter)
	}

	if ast.IsComprehension(a.Value) {
		return e.biunifyComprehension(a, b, b1, b2, false, iter)
	} else if ast.IsComprehension(b.Value) {
		return e.biunifyComprehension(b, a, b2, b1, true, iter)
	}

	// Perform standard unification.
	_, varA := a.Value.(ast.Var)
	_, varB := b.Value.(ast.Var)

	if varA && varB {
		if b1 == b2 && a.Equal(b) {
			return iter()
		}
		undo := b1.bind(a, b, b2)
		err := iter()
		undo.Undo()
		return err
	} else if varA && !varB {
		undo := b1.bind(a, b, b2)
		err := iter()
		undo.Undo()
		return err
	} else if varB && !varA {
		undo := b2.bind(b, a, b1)
		err := iter()
		undo.Undo()
		return err
	}

	// Sets must not contain unbound variables at this point as we cannot unify
	// them. So simply plug both sides (to substitute any bound variables with
	// values) and then check for equality.
	switch a.Value.(type) {
	case ast.Set:
		a = b1.Plug(a)
		b = b2.Plug(b)
	}

	if a.Equal(b) {
		return iter()
	}

	return nil
}

// biunifyRef evaluates the ref in a and unifies its value(s) with b. Refs
// rooted at data are walked through the rule tree; refs rooted at input are
// evaluated against e.input; other refs are evaluated against the term bound
// to their head variable (failing if the head is unbound).
func (e *eval) biunifyRef(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {

	ref := a.Value.(ast.Ref)

	if ref[0].Equal(ast.DefaultRootDocument) {
		node := e.compiler.RuleTree.Child(ref[0].Value)

		eval := evalTree{
			e:         e,
			ref:       ref,
			pos:       1,
			plugged:   ref.Copy(),
			bindings:  b1,
			rterm:     b,
			rbindings: b2,
			node:      node,
		}
		return eval.eval(iter)
	}

	var term *ast.Term
	var termbindings *bindings

	if ref[0].Equal(ast.InputRootDocument) {
		term = e.input
		termbindings = b1
	} else {
		term, termbindings = b1.apply(ref[0])
		if term == ref[0] {
			// apply returned the head unchanged: the variable is unbound.
			term = nil
		}
	}

	if term == nil {
		return nil
	}

	eval := evalTerm{
		e:            e,
		ref:          ref,
		pos:          1,
		bindings:     b1,
		term:         term,
		termbindings: termbindings,
		rterm:        b,
		rbindings:    b2,
	}

	return eval.eval(iter)
}

// biunifyComprehension evaluates comprehension a and unifies the result with
// b. If a depends on unknowns, the unification is saved instead. swap records
// whether a was originally the right-hand operand.
func (e *eval) biunifyComprehension(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {

	if e.unknown(a, b1) {
		return e.biunifyComprehensionPartial(a, b, b1, b2, swap, iter)
	}

	switch x := a.Value.(type) {
	case *ast.ArrayComprehension:
		return e.biunifyComprehensionArray(x, b, b1, b2, iter)
	case *ast.SetComprehension:
		return e.biunifyComprehensionSet(x, b, b1, b2, iter)
	case *ast.ObjectComprehension:
		return e.biunifyComprehensionObject(x, b, b1, b2, iter)
	}

	// Report the type of the comprehension value, not of the wrapping term.
	return fmt.Errorf("illegal comprehension %T", a.Value)
}

// biunifyComprehensionPartial saves the unification of comprehension a with b
// during partial evaluation. The current bindings are appended to the
// comprehension body as equalities and its variables are namespaced.
func (e *eval) biunifyComprehensionPartial(a, b *ast.Term, b1, b2 *bindings, swap bool, iter unifyIterator) error {

	// Capture bindings available to the comprehension. We will add expressions
	// to the comprehension body that ensure the comprehension body is safe.
	// Currently this process adds _all_ bindings (even if they are not
	// needed.) Eventually we may want to make the logic a bit smarter.
	var extras []*ast.Expr

	err := b1.Iter(e.caller.bindings, func(k, v *ast.Term) error {
		extras = append(extras, ast.Equality.Expr(k, v))
		return nil
	})

	if err != nil {
		return err
	}

	// Namespace the variables in the body to avoid collisions when the
	// queries produced by partial evaluation are combined.
	var body *ast.Body

	switch a := a.Value.(type) {
	case *ast.ArrayComprehension:
		body = &a.Body
	case *ast.SetComprehension:
		body = &a.Body
	case *ast.ObjectComprehension:
		body = &a.Body
	default:
		return fmt.Errorf("illegal comprehension %T", a)
	}

	for _, extra := range extras {
		body.Append(extra)
	}

	b1.Namespace(a, e.caller.bindings)

	// The other term might need to be plugged so include the bindings. The
	// bindings for the comprehension term are saved (for compatibility) but
	// the eventual plug operation on the comprehension will be a no-op.
	if !swap {
		return e.saveUnify(a, b, b1, b2, iter)
	}

	return e.saveUnify(b, a, b2, b1, iter)
}

// biunifyComprehensionArray collects every plugged head term of x into an
// array and unifies it with b.
func (e *eval) biunifyComprehensionArray(x *ast.ArrayComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	result := ast.Array{}
	child := e.closure(x.Body)
	err := child.Run(func(child *eval) error {
		result = append(result, child.bindings.Plug(x.Term))
		return nil
	})
	if err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(result), b, b1, b2, iter)
}

// biunifyComprehensionSet collects every plugged head term of x into a set
// and unifies it with b.
func (e *eval) biunifyComprehensionSet(x *ast.SetComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	result := ast.NewSet()
	child := e.closure(x.Body)
	err := child.Run(func(child *eval) error {
		result.Add(child.bindings.Plug(x.Term))
		return nil
	})
	if err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(result), b, b1, b2, iter)
}

// biunifyComprehensionObject collects every plugged key/value pair of x into
// an object and unifies it with b. Two different values for the same key
// produce an object key conflict error.
func (e *eval) biunifyComprehensionObject(x *ast.ObjectComprehension, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	result := ast.NewObject()
	child := e.closure(x.Body)
	err := child.Run(func(child *eval) error {
		key := child.bindings.Plug(x.Key)
		value := child.bindings.Plug(x.Value)
		exist := result.Get(key)
		if exist != nil && !exist.Equal(value) {
			return objectDocKeyConflictErr(x.Key.Location)
		}
		result.Insert(key, value)
		return nil
	})
	if err != nil {
		return err
	}
	return e.biunify(ast.NewTerm(result), b, b1, b2, iter)
}

// savePair associates a term with the bindings it must be resolved in.
type savePair struct {
	term *ast.Term
	b    *bindings
}

// getSavePairs appends to result the variables of x (skipping closures and
// ref heads), following each variable through its bindings until a variable
// is reached.
func getSavePairs(x *ast.Term, b *bindings, result []savePair) []savePair {
	if _, ok := x.Value.(ast.Var); ok {
		result = append(result, savePair{x, b})
		return result
	}
	vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{
		SkipClosures: true,
		SkipRefHead:  true,
	})
	vis.Walk(x)
	for v := range vis.Vars() {
		y, next := b.apply(ast.NewTerm(v))
		result = getSavePairs(y, next, result)
	}
	return result
}

// saveExpr pushes expr (with the current expression's with modifiers) onto
// the save stack for the duration of iter.
func (e *eval) saveExpr(expr *ast.Expr, b *bindings, iter unifyIterator) error {
	expr.With = e.query[e.index].With
	e.saveStack.Push(expr, b, b)
	e.traceSave(expr)
	err := iter()
	e.saveStack.Pop()
	return err
}

// saveUnify saves the equality a = b and marks the variables on both sides as
// unknown (save set) for the duration of iter.
func (e *eval) saveUnify(a, b *ast.Term, b1, b2 *bindings, iter unifyIterator) error {
	e.instr.startTimer(partialOpSaveUnify)
	expr := ast.Equality.Expr(a, b)
	expr.With = e.query[e.index].With
	pops := 0
	if pairs := getSavePairs(a, b1, nil); len(pairs) > 0 {
		pops += len(pairs)
		for _, p := range pairs {
			e.saveSet.Push([]*ast.Term{p.term}, p.b)
		}
	}
	if pairs := getSavePairs(b, b2, nil); len(pairs) > 0 {
		pops += len(pairs)
		for _, p := range pairs {
			e.saveSet.Push([]*ast.Term{p.term}, p.b)
		}
	}
	e.saveStack.Push(expr, b1, b2)
	e.traceSave(expr)
	e.instr.stopTimer(partialOpSaveUnify)
	err := iter()

	e.saveStack.Pop()
	for i := 0; i < pops; i++ {
		e.saveSet.Pop()
	}

	return err
}

// saveCall saves the call expression terms for the duration of iter.
// declArgsLen is the number of declared operands of the callee.
func (e *eval) saveCall(declArgsLen int, terms []*ast.Term, iter unifyIterator) error {
	expr := ast.NewExpr(terms)
	expr.With = e.query[e.index].With

	// If call-site includes output value then partial eval must add vars in output
	// position to the save set.
	pops := 0
	if declArgsLen == len(terms)-2 {
		if pairs := getSavePairs(terms[len(terms)-1], e.bindings, nil); len(pairs) > 0 {
			pops += len(pairs)
			for _, p := range pairs {
				e.saveSet.Push([]*ast.Term{p.term}, p.b)
			}
		}
	}
	e.saveStack.Push(expr, e.bindings, nil)
	e.traceSave(expr)
	err := iter()

	e.saveStack.Pop()
	for i := 0; i < pops; i++ {
		e.saveSet.Pop()
	}
	return err
}

// saveInlinedNegatedExprs pushes exprs onto the save stack for the duration
// of iter.
func (e *eval) saveInlinedNegatedExprs(exprs []*ast.Expr, iter unifyIterator) error {

	// This function does not include with statements on the exprs because
	// they will have already been saved and therefore had any relevant
	// with statements set.
	for _, expr := range exprs {
		e.saveStack.Push(expr, nil, nil)
		e.traceSave(expr)
	}
	err := iter()
	for i := 0; i < len(exprs); i++ {
		e.saveStack.Pop()
	}
	return err
}

// getRules returns the rules for ref from the compiler's rule index (nil if
// there is no index). With indexing enabled only matching rules are returned,
// otherwise all rules. An Index trace event reports how many rules matched.
func (e *eval) getRules(ref ast.Ref) (*ast.IndexResult, error) {
	e.instr.startTimer(evalOpRuleIndex)
	defer e.instr.stopTimer(evalOpRuleIndex)

	index := e.compiler.RuleIndex(ref)
	if index == nil {
		return nil, nil
	}

	var result *ast.IndexResult
	var err error
	if e.indexing {
		result, err = index.Lookup(e)
	} else {
		result, err = index.AllRules(e)
	}

	if err != nil {
		return nil, err
	}

	var msg string
	if len(result.Rules) == 1 {
		msg = "(matched 1 rule)"
	} else {
		var b strings.Builder
		b.Grow(len("(matched NNNN rules)"))
		b.WriteString("(matched ")
		b.WriteString(strconv.Itoa(len(result.Rules)))
		b.WriteString(" rules)")
		msg = b.String()
	}
	e.traceIndex(e.query[e.index], msg)
	return result, nil
}

// Resolve returns the value of a ref rooted at input or data. It is handed to
// the rule index (see getRules). Refs in the save set yield
// ast.UnknownValueErr; refs that do not exist yield a nil value.
func (e *eval) Resolve(ref ast.Ref) (ast.Value, error) {
	e.instr.startTimer(evalOpResolve)
	defer e.instr.stopTimer(evalOpResolve)

	if e.saveSet.Contains(ast.NewTerm(ref), nil) {
		return nil, ast.UnknownValueErr{}
	}

	if ref[0].Equal(ast.InputRootDocument) {
		if e.input == nil {
			return nil, nil
		}
		v, err := e.input.Value.Find(ref[1:])
		if err != nil {
			// A failed lookup means the ref is undefined, not an error.
			v = nil
		}
		return v, nil
	}

	if ref[0].Equal(ast.DefaultRootDocument) {

		// repValue is the value supplied by a `with data...` replacement, if any.
		var repValue ast.Value
		if e.data != nil {
			if v, err := e.data.Value.Find(ref[1:]); err == nil {
				repValue = v
			}
		}

		// Refs targeted by a with statement see only the replacement value.
		if e.targetStack.Prefixed(ref) {
			return repValue, nil
		}

		// Converting large JSON values into AST values can be fairly expensive. For
		// example, a 2MB JSON value can take upwards of 30 milliseconds to convert.
		// We cache the result of conversion here in case the same base document is
		// being read multiple times during evaluation.
		realValue := e.baseCache.Get(ref)
		if realValue == nil {
			e.instr.counterIncr(evalOpBaseCacheMiss)
			return e.resolveReadFromStorage(ref, repValue)
		}

		e.instr.counterIncr(evalOpBaseCacheHit)
		if repValue == nil {
			return realValue, nil
		}

		merged, ok := merge(repValue, realValue)
		if !ok {
			return nil, mergeConflictErr(ref[0].Location)
		}
		return merged, nil
	}

	return nil, fmt.Errorf("illegal ref")
}

// resolveReadFromStorage reads ref from the store, converts it to an AST
// value, caches it in the base cache, and merges the replacement value a (if
// any) on top. Refs containing non-scalar terms and paths that are not found
// return a unchanged.
func (e *eval) resolveReadFromStorage(ref ast.Ref, a ast.Value) (ast.Value, error) {
	if refContainsNonScalar(ref) {
		return a, nil
	}

	path, err := storage.NewPathForRef(ref)
	if err != nil {
		if !storage.IsNotFound(err) {
			return nil, err
		}
		return a, nil
	}

	blob, err := e.store.Read(e.ctx, e.txn, path)
	if err != nil {
		if !storage.IsNotFound(err) {
			return nil, err
		}
		return a, nil
	}

	if len(path) == 0 {
		// Reading the root document: drop the system document key from a copy
		// of the result. The comma-ok form avoids a panic on a non-object root.
		if obj, ok := blob.(map[string]interface{}); ok && len(obj) > 0 {
			cpy := make(map[string]interface{}, len(obj)-1)
			for k, v := range obj {
				if string(ast.SystemDocumentKey) == k {
					continue
				}
				cpy[k] = v
			}
			blob = cpy
		}
	}

	v, err := ast.InterfaceToValue(blob)
	if err != nil {
		return nil, err
	}

	e.baseCache.Put(ref, v)

	if a == nil {
		return v, nil
	}

	merged, ok := merge(a, v)
	if !ok {
		return nil, mergeConflictErr(ref[0].Location)
	}

	return merged, nil
}

// generateVar returns a variable term named "<genvarprefix>_<suffix>".
func (e *eval) generateVar(suffix string) *ast.Term {
	return ast.VarTerm(fmt.Sprintf("%v_%v", e.genvarprefix, suffix))
}

// rewrittenVar maps a compiler-rewritten variable back to its original name,
// consulting the module compiler first and then the query compiler. It returns
// v and false when no mapping exists.
func (e *eval) rewrittenVar(v ast.Var) (ast.Var, bool) {
	if e.compiler != nil {
		if rw, ok := e.compiler.RewrittenVars[v]; ok {
			return rw, true
		}
	}
	if e.queryCompiler != nil {
		if rw, ok := e.queryCompiler.RewrittenVars()[v]; ok {
			return rw, true
		}
	}
	return v, false
}

// evalBuiltin evaluates a single builtin call.
type evalBuiltin struct {
	e     *eval
	bi    *ast.Builtin
	bctx  BuiltinContext
	f     BuiltinFunc
	terms []*ast.Term
}

// eval calls the builtin with its plugged operands. When the call site has no
// output operand, iter runs for every output that is not false; otherwise the
// output is unified with the last operand. Time spent in iter is excluded from
// the builtin-call timer.
func (e evalBuiltin) eval(iter unifyIterator) error {

	operands := make([]*ast.Term, len(e.terms))

	for i := 0; i < len(e.terms); i++ {
		operands[i] = e.e.bindings.Plug(e.terms[i])
	}

	numDeclArgs := len(e.bi.Decl.Args())

	e.e.instr.startTimer(evalOpBuiltinCall)

	err := e.f(e.bctx, operands, func(output *ast.Term) error {

		e.e.instr.stopTimer(evalOpBuiltinCall)

		var err error
		if len(operands) == numDeclArgs {
			if output.Value.Compare(ast.Boolean(false)) != 0 {
				err = iter()
			}
		} else {
			err = e.e.unify(e.terms[len(e.terms)-1], output, iter)
		}

		e.e.instr.startTimer(evalOpBuiltinCall)
		return err
	})

	e.e.instr.stopTimer(evalOpBuiltinCall)
	return err
}

// evalFunc evaluates a call to a user-defined function.
type evalFunc struct {
	e     *eval
	ref   ast.Ref
	terms []*ast.Term
}

// eval evaluates each function rule (falling back to its else chain when the
// rule produces nothing). Results from different rules must agree.
func (e evalFunc) eval(iter unifyIterator) error {

	ir, err := e.e.getRules(e.ref)
	if err != nil {
		return err
	}

	if ir.Empty() {
		return nil
	}

	if len(ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
		// Partial evaluation of ordered rules is not supported currently. Save the
		// expression and continue. This could be revisited in the future.
		return e.e.saveCall(len(ir.Rules[0].Head.Args), e.terms, iter)
	}

	var prev *ast.Term

	for i := range ir.Rules {
		next, err := e.evalOneRule(iter, ir.Rules[i], prev)
		if err != nil {
			return err
		}
		if next == nil {
			for _, rule := range ir.Else[ir.Rules[i]] {
				next, err = e.evalOneRule(iter, rule, prev)
				if err != nil {
					return err
				}
				if next != nil {
					break
				}
			}
		}
		if next != nil {
			prev = next
		}
	}

	return nil
}

// evalOneRule unifies the call's operands with rule's head and evaluates its
// body. It returns the last plugged head value (nil if none). Outside partial
// evaluation, a result differing from prev is a function conflict error, and
// a result equal to prev is not re-emitted.
func (e evalFunc) evalOneRule(iter unifyIterator, rule *ast.Rule, prev *ast.Term) (*ast.Term, error) {

	child := e.e.child(rule.Body)

	args := make(ast.Array, len(e.terms)-1)

	for i := range rule.Head.Args {
		args[i] = rule.Head.Args[i]
	}

	// When the call site passes an output operand, unify it with the head value.
	if len(args) == len(rule.Head.Args)+1 {
		args[len(args)-1] = rule.Head.Value
	}

	var result *ast.Term

	child.traceEnter(rule)

	err := child.biunifyArrays(e.terms[1:], args, e.e.bindings, child.bindings, func() error {
		return child.eval(func(child *eval) error {
			child.traceExit(rule)
			result = child.bindings.Plug(rule.Head.Value)

			if len(rule.Head.Args) == len(e.terms)-1 {
				if result.Value.Compare(ast.Boolean(false)) == 0 {
					return nil
				}
			}

			// Partial evaluation should explore all rules and may not produce
			// a ground result so we do not perform conflict detection or
			// deduplication. See "ignore conflicts: functions" test case for
			// an example.
			if !e.e.partial() {
				if prev != nil {
					if ast.Compare(prev, result) != 0 {
						return functionConflictErr(rule.Location)
					}
					child.traceRedo(rule)
					return nil
				}
			}

			prev = result

			if err := iter(); err != nil {
				return err
			}

			child.traceRedo(rule)
			return nil
		})
	})

	return result, err
}

// evalTree evaluates a ref rooted at data by walking the rule tree and the
// base documents in parallel.
type evalTree struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
	node      *ast.TreeNode
}

// eval processes the ref term at pos: ground terms are followed directly,
// non-ground terms are enumerated.
func (e evalTree) eval(iter unifyIterator) error {

	if len(e.ref) == e.pos {
		return e.finish(iter)
	}

	plugged := e.bindings.Plug(e.ref[e.pos])

	if plugged.IsGround() {
		return e.next(iter, plugged)
	}

	return e.enumerate(iter)
}

// finish unifies the full extent of the ref (base and virtual documents) with
// rterm.
func (e evalTree) finish(iter unifyIterator) error {

	// During partial evaluation it may not be possible to compute the value
	// for this reference if it refers to a virtual document so save the entire
	// expression. See "save: full extent" test case for an example.
	if e.node != nil && e.e.unknown(e.ref, e.e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.plugged), e.rterm, e.bindings, e.rbindings, iter)
	}

	v, err := e.extent()
	if err != nil || v == nil {
		return err
	}

	return e.e.biunify(e.rterm, v, e.rbindings, e.bindings, iter)
}

// next descends one level into the ref using the ground key plugged. If the
// rule tree has rules at that node, evaluation switches to evalVirtual.
func (e evalTree) next(iter unifyIterator, plugged *ast.Term) error {

	var node *ast.TreeNode

	cpy := e
	cpy.plugged[e.pos] = plugged
	cpy.pos++

	if !e.e.targetStack.Prefixed(cpy.plugged[:cpy.pos]) {
		if e.node != nil {
			node = e.node.Child(plugged.Value)
			if node != nil && len(node.Values) > 0 {
				r := evalVirtual{
					e:         e.e,
					ref:       e.ref,
					plugged:   e.plugged,
					pos:       e.pos,
					bindings:  e.bindings,
					rterm:     e.rterm,
					rbindings: e.rbindings,
				}
				r.plugged[e.pos] = plugged
				return r.eval(iter)
			}
		}
	}

	cpy.node = node
	return cpy.eval(iter)
}

// enumerate iterates the keys of the base document at the current prefix and
// the children of the current rule tree node, unifying each with the ref term
// at pos.
func (e evalTree) enumerate(iter unifyIterator) error {

	doc, err := e.e.Resolve(e.plugged[:e.pos])
	if err != nil {
		return err
	}

	if doc != nil {
		switch doc := doc.(type) {
		case ast.Array:
			for i := range doc {
				k := ast.IntNumberTerm(i)
				err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, k)
				})
				if err != nil {
					return err
				}
			}
		case ast.Object:
			err := doc.Iter(func(k, _ *ast.Term) error {
				return e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, k)
				})
			})
			if err != nil {
				return err
			}
		case ast.Set:
			err := doc.Iter(func(elem *ast.Term) error {
				return e.e.biunify(elem, e.ref[e.pos], e.bindings, e.bindings, func() error {
					return e.next(iter, elem)
				})
			})
			if err != nil {
				return err
			}
		}
	}

	if e.node == nil {
		return nil
	}

	for k := range e.node.Children {
		key := ast.NewTerm(k)
		if err := e.e.biunify(key, e.ref[e.pos], e.bindings, e.bindings, func() error {
			return e.next(iter, key)
		}); err != nil {
			return err
		}
	}

	return nil
}

// extent returns the merged base and virtual document at the plugged ref, or
// nil if neither exists.
func (e evalTree) extent() (*ast.Term, error) {

	base, err := e.e.Resolve(e.plugged)
	if err != nil {
		return nil, err
	}

	virtual, err := e.leaves(e.plugged, e.node)
	if err != nil {
		return nil, err
	}

	if virtual == nil {
		if base == nil {
			return nil, nil
		}
		return ast.NewTerm(base), nil
	}

	if base != nil {
		merged, ok := merge(base, virtual)
		if !ok {
			return nil, mergeConflictErr(e.plugged[0].Location)
		}
		return ast.NewTerm(merged), nil
	}

	return ast.NewTerm(virtual), nil
}

// leaves builds an object containing the values of all non-hidden virtual
// documents beneath node, evaluating the rules found at each child.
func (e evalTree) leaves(plugged ast.Ref, node *ast.TreeNode) (ast.Object, error) {

	if node == nil {
		return nil, nil
	}

	result := ast.NewObject()

	for _, child := range node.Children {
		if child.Hide {
			continue
		}

		plugged = append(plugged, ast.NewTerm(child.Key))

		var save ast.Value
		var err error

		if len(child.Values) > 0 {
			rterm := e.e.generateVar("leaf")
			err = e.e.unify(ast.NewTerm(plugged), rterm, func() error {
				save = e.e.bindings.Plug(rterm).Value
				return nil
			})
		} else {
			save, err = e.leaves(plugged, child)
		}

		if err != nil {
			return nil, err
		}

		if save != nil {
			v := ast.NewObject([2]*ast.Term{plugged[len(plugged)-1], ast.NewTerm(save)})
			result, _ = result.Merge(v)
		}

		plugged = plugged[:len(plugged)-1]
	}

	return result, nil
}

// evalVirtual evaluates a ref into a virtual document defined by rules.
type evalVirtual struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
}

// eval looks up the rules for the current prefix and dispatches on their kind
// (partial set, partial object, or complete document).
func (e evalVirtual) eval(iter unifyIterator) error {

	ir, err := e.e.getRules(e.plugged[:e.pos+1])
	if err != nil {
		return err
	}

	// Partial evaluation of ordered rules is not supported currently. Save the
	// expression and continue. This could be revisited in the future.
	if len(ir.Else) > 0 && e.e.unknown(e.ref, e.bindings) {
		return e.e.saveUnify(ast.NewTerm(e.ref), e.rterm, e.bindings, e.rbindings, iter)
	}

	switch ir.Kind {
	case ast.PartialSetDoc:
		eval := evalVirtualPartial{
			e:         e.e,
			ref:       e.ref,
			plugged:   e.plugged,
			pos:       e.pos,
			ir:        ir,
			bindings:  e.bindings,
			rterm:     e.rterm,
			rbindings: e.rbindings,
			empty:     ast.SetTerm(),
		}
		return eval.eval(iter)
	case ast.PartialObjectDoc:
		eval := evalVirtualPartial{
			e:         e.e,
			ref:       e.ref,
			plugged:   e.plugged,
			pos:       e.pos,
			ir:        ir,
			bindings:  e.bindings,
			rterm:     e.rterm,
			rbindings: e.rbindings,
			empty:     ast.ObjectTerm(),
		}
		return eval.eval(iter)
	default:
		eval := evalVirtualComplete{
			e:         e.e,
			ref:       e.ref,
			plugged:   e.plugged,
			pos:       e.pos,
			ir:        ir,
			bindings:  e.bindings,
			rterm:     e.rterm,
			rbindings: e.rbindings,
		}
		return eval.eval(iter)
	}
}

// evalVirtualPartial evaluates refs into partial set/object documents.
type evalVirtualPartial struct {
	e         *eval
	ref       ast.Ref
	plugged   ast.Ref
	pos       int
	ir        *ast.IndexResult
	bindings  *bindings
	rterm     *ast.Term
	rbindings *bindings
	empty     *ast.Term
}

// eval evaluates the whole document when the ref ends at the rule. Otherwise
// it evaluates the rules one key at a time, using the virtual cache for
// ground partial-object keys.
func (e evalVirtualPartial) eval(iter unifyIterator) error {

	if len(e.ref) == e.pos+1 {
		// During partial evaluation, it may not be possible to produce a value
		// for this reference so save the entire expression. See "save: full
		// extent: partial object" test case for an example.
		if e.e.unknown(e.ref, e.bindings) {
			return e.e.saveUnify(ast.NewTerm(e.ref), e.rterm, e.bindings, e.rbindings, iter)
		}
		return e.evalAllRules(iter, e.ir.Rules)
	}

	var cacheKey ast.Ref

	if e.ir.Kind == ast.PartialObjectDoc {
		plugged := e.bindings.Plug(e.ref[e.pos+1])

		if plugged.IsGround() {
			path := e.plugged[:e.pos+2]
			path[len(path)-1] = plugged
			cached := e.e.virtualCache.Get(path)

			if cached != nil {
				e.e.instr.counterIncr(evalOpVirtualCacheHit)
				return e.evalTerm(iter, cached, e.bindings)
			}

			e.e.instr.counterIncr(evalOpVirtualCacheMiss)
			cacheKey = path
		}
	}

	generateSupport := anyRefSetContainsPrefix(e.e.disableInlining, e.plugged[:e.pos+1])

	if generateSupport {
		return e.partialEvalSupport(iter)
	}

	for _, rule := range e.ir.Rules {
		if err := e.evalOneRule(iter, rule, cacheKey); err != nil {
			return err
		}
	}

	return nil
}

// evalAllRules reduces the results of all rules into a single document
// (starting from e.empty) and unifies it with rterm.
func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule) error {

	result := e.empty

	for _, rule := range rules {
		child := e.e.child(rule.Body)
		child.traceEnter(rule)

		err := child.eval(func(*eval) error {
			child.traceExit(rule)
			var err error
			result, err = e.reduce(rule.Head, child.bindings, result)
			if err != nil {
				return err
			}

			child.traceRedo(rule)
			return nil
		})

		if err != nil {
			return err
		}
	}

	return e.e.biunify(result, e.rterm, e.bindings, e.bindings, iter)
}

func (e evalVirtualPartial) evalOneRule(iter unifyIterator, rule *ast.Rule, cacheKey ast.Ref) error {

	key := e.ref[e.pos+1]
	child := e.e.child(rule.Body)

	child.traceEnter(rule)
	var defined bool

	err := child.biunify(rule.Head.Key, key, child.bindings, e.bindings, func() error {
		defined = true
		return child.eval(func(child *eval) error {

			child.traceExit(rule)

			term := rule.Head.Value
			if term == nil {
				term = rule.Head.Key
			}

			if cacheKey != nil {
				result := child.bindings.Plug(term)
				e.e.virtualCache.Put(cacheKey, result)
			}

			term, termbindings := child.bindings.apply(term)
			err := e.evalTerm(iter, term, termbindings)
			if err != nil {
				return err
			}

			child.traceRedo(rule)
			return nil
		})
	})

	if err !=
nil { return err } if !defined { child.traceFail(rule) } return nil } func (e evalVirtualPartial) partialEvalSupport(iter unifyIterator) error { path := e.plugged[:e.pos+1].Insert(e.e.saveNamespace, 1) if !e.e.saveSupport.Exists(path) { for i := range e.ir.Rules { err := e.partialEvalSupportRule(iter, e.ir.Rules[i], path) if err != nil { return err } } } rewritten := ast.NewTerm(e.ref.Insert(e.e.saveNamespace, 1)) return e.e.saveUnify(rewritten, e.rterm, e.bindings, e.rbindings, iter) } func (e evalVirtualPartial) partialEvalSupportRule(iter unifyIterator, rule *ast.Rule, path ast.Ref) error { child := e.e.child(rule.Body) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) err := child.eval(func(child *eval) error { child.traceExit(rule) current := e.e.saveStack.PopQuery() plugged := current.Plug(e.e.caller.bindings) var key, value *ast.Term if rule.Head.Key != nil { key = child.bindings.PlugNamespaced(rule.Head.Key, e.e.caller.bindings) } if rule.Head.Value != nil { value = child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings) } head := ast.NewHead(rule.Head.Name, key, value) p := copypropagation.New(head.Vars()).WithEnsureNonEmptyBody(true) e.e.saveSupport.Insert(path, &ast.Rule{ Head: head, Body: p.Apply(plugged), Default: rule.Default, }) child.traceRedo(rule) e.e.saveStack.PushQuery(current) return nil }) e.e.saveStack.PopQuery() return err } func (e evalVirtualPartial) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error { eval := evalTerm{ e: e.e, ref: e.ref, pos: e.pos + 2, bindings: e.bindings, term: term, termbindings: termbindings, rterm: e.rterm, rbindings: e.rbindings, } return eval.eval(iter) } func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, error) { switch v := result.Value.(type) { case ast.Set: v.Add(b.Plug(head.Key)) case ast.Object: key := b.Plug(head.Key) value := b.Plug(head.Value) exist := v.Get(key) if exist != nil && !exist.Equal(value) { return nil, 
objectDocKeyConflictErr(head.Location) } v.Insert(key, value) result.Value = v } return result, nil } type evalVirtualComplete struct { e *eval ref ast.Ref plugged ast.Ref pos int ir *ast.IndexResult bindings *bindings rterm *ast.Term rbindings *bindings } func (e evalVirtualComplete) eval(iter unifyIterator) error { if e.ir.Empty() { return nil } if len(e.ir.Rules) > 0 && len(e.ir.Rules[0].Head.Args) > 0 { return nil } if !e.e.unknown(e.ref, e.bindings) { return e.evalValue(iter) } var generateSupport bool if e.ir.Default != nil { // If the other term is not constant OR it's equal to the default value, then // a support rule must be produced as the default value _may_ be required. On // the other hand, if the other term is constant (i.e., it does not require // evaluation) and it differs from the default value then the default value is // _not_ required, so partially evaluate the rule normally. rterm := e.rbindings.Plug(e.rterm) generateSupport = !ast.IsConstant(rterm.Value) || e.ir.Default.Head.Value.Equal(rterm) } generateSupport = generateSupport || anyRefSetContainsPrefix(e.e.disableInlining, e.plugged[:e.pos+1]) if generateSupport { return e.partialEvalSupport(iter) } return e.partialEval(iter) } func (e evalVirtualComplete) evalValue(iter unifyIterator) error { cached := e.e.virtualCache.Get(e.plugged[:e.pos+1]) if cached != nil { e.e.instr.counterIncr(evalOpVirtualCacheHit) return e.evalTerm(iter, cached, e.bindings) } e.e.instr.counterIncr(evalOpVirtualCacheMiss) var prev *ast.Term for i := range e.ir.Rules { next, err := e.evalValueRule(iter, e.ir.Rules[i], prev) if err != nil { return err } if next == nil { for _, rule := range e.ir.Else[e.ir.Rules[i]] { next, err = e.evalValueRule(iter, rule, prev) if err != nil { return err } if next != nil { break } } } if next != nil { prev = next } } if e.ir.Default != nil && prev == nil { _, err := e.evalValueRule(iter, e.ir.Default, prev) return err } return nil } func (e evalVirtualComplete) evalValueRule(iter 
unifyIterator, rule *ast.Rule, prev *ast.Term) (*ast.Term, error) { child := e.e.child(rule.Body) child.traceEnter(rule) var result *ast.Term err := child.eval(func(child *eval) error { child.traceExit(rule) result = child.bindings.Plug(rule.Head.Value) if prev != nil { if ast.Compare(result, prev) != 0 { return completeDocConflictErr(rule.Location) } child.traceRedo(rule) return nil } prev = result e.e.virtualCache.Put(e.plugged[:e.pos+1], result) term, termbindings := child.bindings.apply(rule.Head.Value) err := e.evalTerm(iter, term, termbindings) if err != nil { return err } child.traceRedo(rule) return nil }) return result, err } func (e evalVirtualComplete) partialEval(iter unifyIterator) error { for _, rule := range e.ir.Rules { child := e.e.child(rule.Body) child.traceEnter(rule) err := child.eval(func(child *eval) error { child.traceExit(rule) term, termbindings := child.bindings.apply(rule.Head.Value) err := e.evalTerm(iter, term, termbindings) if err != nil { return err } child.traceRedo(rule) return nil }) if err != nil { return err } } return nil } func (e evalVirtualComplete) partialEvalSupport(iter unifyIterator) error { path := e.plugged[:e.pos+1].Insert(e.e.saveNamespace, 1) if !e.e.saveSupport.Exists(path) { for i := range e.ir.Rules { err := e.partialEvalSupportRule(iter, e.ir.Rules[i], path) if err != nil { return err } } if e.ir.Default != nil { err := e.partialEvalSupportRule(iter, e.ir.Default, path) if err != nil { return err } } } rewritten := ast.NewTerm(e.ref.Insert(e.e.saveNamespace, 1)) return e.e.saveUnify(rewritten, e.rterm, e.bindings, e.rbindings, iter) } func (e evalVirtualComplete) partialEvalSupportRule(iter unifyIterator, rule *ast.Rule, path ast.Ref) error { child := e.e.child(rule.Body) child.traceEnter(rule) e.e.saveStack.PushQuery(nil) err := child.eval(func(child *eval) error { child.traceExit(rule) current := e.e.saveStack.PopQuery() plugged := current.Plug(e.e.caller.bindings) head := ast.NewHead(rule.Head.Name, nil, 
child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)) p := copypropagation.New(head.Vars()).WithEnsureNonEmptyBody(true) e.e.saveSupport.Insert(path, &ast.Rule{ Head: head, Body: applyCopyPropagation(p, e.e.instr, plugged), Default: rule.Default, }) child.traceRedo(rule) e.e.saveStack.PushQuery(current) return nil }) e.e.saveStack.PopQuery() return err } func (e evalVirtualComplete) evalTerm(iter unifyIterator, term *ast.Term, termbindings *bindings) error { eval := evalTerm{ e: e.e, ref: e.ref, pos: e.pos + 1, bindings: e.bindings, term: term, termbindings: termbindings, rterm: e.rterm, rbindings: e.rbindings, } return eval.eval(iter) } type evalTerm struct { e *eval ref ast.Ref pos int bindings *bindings term *ast.Term termbindings *bindings rterm *ast.Term rbindings *bindings } func (e evalTerm) eval(iter unifyIterator) error { if len(e.ref) == e.pos { return e.e.biunify(e.term, e.rterm, e.termbindings, e.rbindings, iter) } if e.e.saveSet.Contains(e.term, e.termbindings) { return e.save(iter) } plugged := e.bindings.Plug(e.ref[e.pos]) if plugged.IsGround() { return e.next(iter, plugged) } return e.enumerate(iter) } func (e evalTerm) next(iter unifyIterator, plugged *ast.Term) error { term, bindings := e.get(plugged) if term == nil { return nil } cpy := e cpy.term = term cpy.termbindings = bindings cpy.pos++ return cpy.eval(iter) } func (e evalTerm) enumerate(iter unifyIterator) error { switch v := e.term.Value.(type) { case ast.Array: for i := range v { k := ast.IntNumberTerm(i) err := e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error { return e.next(iter, k) }) if err != nil { return err } } case ast.Object: return v.Iter(func(k, _ *ast.Term) error { return e.e.biunify(k, e.ref[e.pos], e.termbindings, e.bindings, func() error { return e.next(iter, e.termbindings.Plug(k)) }) }) case ast.Set: return v.Iter(func(elem *ast.Term) error { return e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error { return 
e.next(iter, e.termbindings.Plug(elem)) }) }) } return nil } func (e evalTerm) get(plugged *ast.Term) (*ast.Term, *bindings) { switch v := e.term.Value.(type) { case ast.Set: if v.IsGround() { if v.Contains(plugged) { return e.termbindings.apply(plugged) } } else { var t *ast.Term var b *bindings stop := v.Until(func(elem *ast.Term) bool { if e.termbindings.Plug(elem).Equal(plugged) { t, b = e.termbindings.apply(plugged) return true } return false }) if stop { return t, b } } case ast.Object: if v.IsGround() { term := v.Get(plugged) if term != nil { return e.termbindings.apply(term) } } else { var t *ast.Term var b *bindings stop := v.Until(func(k, v *ast.Term) bool { if e.termbindings.Plug(k).Equal(plugged) { t, b = e.termbindings.apply(v) return true } return false }) if stop { return t, b } } case ast.Array: term := v.Get(plugged) if term != nil { return e.termbindings.apply(term) } } return nil, nil } func (e evalTerm) save(iter unifyIterator) error { suffix := e.ref[e.pos:] ref := make(ast.Ref, len(suffix)+1) ref[0] = e.term for i := 0; i < len(suffix); i++ { ref[i+1] = suffix[i] } return e.e.biunify(ast.NewTerm(ref), e.rterm, e.termbindings, e.rbindings, iter) } func applyCopyPropagation(p *copypropagation.CopyPropagator, instr *Instrumentation, body ast.Body) ast.Body { instr.startTimer(partialOpCopyPropagation) result := p.Apply(body) instr.stopTimer(partialOpCopyPropagation) return result } func nonGroundKeys(a ast.Object) bool { return a.Until(func(k, _ *ast.Term) bool { return !k.IsGround() }) } func plugKeys(a ast.Object, b *bindings) ast.Object { plugged, _ := a.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) { return b.Plug(k), v, nil }) return plugged } func plugSlice(xs []*ast.Term, b *bindings) []*ast.Term { cpy := make([]*ast.Term, len(xs)) for i := range cpy { cpy[i] = b.Plug(xs[i]) } return cpy } func canInlineNegation(safe ast.VarSet, queries []ast.Body) bool { size := 1 for _, query := range queries { size *= len(query) for _, expr := 
range query { if !expr.Negated { // Positive expressions containing variables cannot be trivially negated // because they become unsafe (e.g., "x = 1" negated is "not x = 1" making x // unsafe.) We check if the vars in the expr are already safe. vis := ast.NewVarVisitor().WithParams(ast.VarVisitorParams{ SkipRefCallHead: true, SkipClosures: true, }) vis.Walk(expr) unsafe := vis.Vars().Diff(safe).Diff(ast.ReservedVars) if len(unsafe) > 0 { return false } } } } // NOTE(tsandall): this limit is arbitrary–it's only in place to prevent the // partial evaluation result from blowing up. In the future, we could make this // configurable or do something more clever. if size > 16 { return false } return true } func complementedCartesianProduct(queries []ast.Body, idx int, curr ast.Body, iter func(ast.Body) error) error { if idx == len(queries) { return iter(curr) } for _, expr := range queries[idx] { curr = append(curr, expr.Complement()) if err := complementedCartesianProduct(queries, idx+1, curr, iter); err != nil { return err } curr = curr[:len(curr)-1] } return nil } func isInputRef(term *ast.Term) bool { if ref, ok := term.Value.(ast.Ref); ok { if ref.HasPrefix(ast.InputRootRef) { return true } } return false } func isDataRef(term *ast.Term) bool { if ref, ok := term.Value.(ast.Ref); ok { if ref.HasPrefix(ast.DefaultRootRef) { return true } } return false } func merge(a, b ast.Value) (ast.Value, bool) { aObj, ok1 := a.(ast.Object) bObj, ok2 := b.(ast.Object) if ok1 && ok2 { return mergeObjects(aObj, bObj) } return nil, false } // mergeObjects returns a new Object containing the non-overlapping keys of // the objA and objB. If there are overlapping keys between objA and objB, // the values of associated with the keys are merged. Only // objects can be merged with other objects. If the values cannot be merged, // objB value will be overwritten by objA value. 
func mergeObjects(objA, objB ast.Object) (result ast.Object, ok bool) { result = ast.NewObject() stop := objA.Until(func(k, v *ast.Term) bool { if v2 := objB.Get(k); v2 == nil { result.Insert(k, v) } else { obj1, ok1 := v.Value.(ast.Object) obj2, ok2 := v2.Value.(ast.Object) if !ok1 || !ok2 { result.Insert(k, v) return false } obj3, ok := mergeObjects(obj1, obj2) if !ok { return true } result.Insert(k, ast.NewTerm(obj3)) } return false }) if stop { return nil, false } objB.Foreach(func(k, v *ast.Term) { if v2 := objA.Get(k); v2 == nil { result.Insert(k, v) } }) return result, true } func anyRefSetContainsPrefix(s [][]ast.Ref, prefix ast.Ref) bool { for _, refs := range s { for _, ref := range refs { if ref.HasPrefix(prefix) { return true } } } return false } func refContainsNonScalar(ref ast.Ref) bool { for _, term := range ref[1:] { if !ast.IsScalar(term.Value) { return true } } return false }