// Copyright 2018 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. // Package planner contains a query planner for Rego queries. package planner import ( "errors" "fmt" "io" "sort" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/ast/location" "github.com/open-policy-agent/opa/internal/debug" "github.com/open-policy-agent/opa/ir" ) // QuerySet represents the input to the planner. type QuerySet struct { Name string Queries []ast.Body RewrittenVars map[ast.Var]ast.Var } type planiter func() error type planLocalIter func(ir.Local) error type stmtFactory func(ir.Local) ir.Stmt // Planner implements a query planner for Rego queries. type Planner struct { policy *ir.Policy // result of planning queries []QuerySet // input queries to plan modules []*ast.Module // input modules to support queries strings map[string]int // global string constant indices files map[string]int // global file constant indices externs map[string]*ast.Builtin // built-in functions that are required in execution environment decls map[string]*ast.Builtin // built-in functions that may be provided in execution environment rules *ruletrie // rules that may be planned mocks *functionMocksStack // replacements for built-in functions funcs *funcstack // functions that have been planned plan *ir.Plan // in-progress query plan curr *ir.Block // in-progress query block vars *varstack // in-scope variables ltarget ir.Operand // target variable or constant of last planned statement lnext ir.Local // next variable to use loc *location.Location // location currently "being planned" debug debug.Debug // debug information produced during planning } // debugf prepends the planner location. We're passing callstack depth 2 because // it should still log the file location of p.debugf. func (p *Planner) debugf(format string, args ...interface{}) { var msg string if p.loc != nil { msg = fmt.Sprintf("%s: "+format, append([]interface{}{p.loc}, args...)...) } else { msg = fmt.Sprintf(format, args...) } _ = p.debug.Output(2, msg) // ignore error } // New returns a new Planner object. func New() *Planner { return &Planner{ policy: &ir.Policy{ Static: &ir.Static{}, Plans: &ir.Plans{}, Funcs: &ir.Funcs{}, }, strings: map[string]int{}, files: map[string]int{}, externs: map[string]*ast.Builtin{}, lnext: ir.Unused, vars: newVarstack(map[ast.Var]ir.Local{ ast.InputRootDocument.Value.(ast.Var): ir.Input, ast.DefaultRootDocument.Value.(ast.Var): ir.Data, }), rules: newRuletrie(), funcs: newFuncstack(), mocks: newFunctionMocksStack(), debug: debug.Discard(), } } // WithBuiltinDecls tells the planner what built-in function may be available // inside the execution environment. func (p *Planner) WithBuiltinDecls(decls map[string]*ast.Builtin) *Planner { p.decls = decls return p } // WithQueries sets the query sets to generate a plan for. The rewritten collection provides // a mapping of rewritten query vars for each query set. The planner uses rewritten variables // but the result set key will be the original variable name. func (p *Planner) WithQueries(queries []QuerySet) *Planner { p.queries = queries return p } // WithModules sets the module set that contains query dependencies. func (p *Planner) WithModules(modules []*ast.Module) *Planner { p.modules = modules return p } // WithDebug sets where debug messages are written to. func (p *Planner) WithDebug(sink io.Writer) *Planner { if sink != nil { p.debug = debug.New(sink) } return p } // Plan returns a IR plan for the policy query. func (p *Planner) Plan() (*ir.Policy, error) { if err := p.buildFunctrie(); err != nil { return nil, err } if err := p.planQueries(); err != nil { return nil, err } if err := p.planExterns(); err != nil { return nil, err } return p.policy, nil } func (p *Planner) buildFunctrie() error { for _, module := range p.modules { // Create functrie node for empty packages so that extent queries return // empty objects. For example: // // package x.y // // Query: data.x // // Expected result: {"y": {}} if len(module.Rules) == 0 { _ = p.rules.LookupOrInsert(module.Package.Path) continue } for _, rule := range module.Rules { r := rule.Ref().StringPrefix() val := p.rules.LookupOrInsert(r) val.rules = val.DescendantRules() val.rules = append(val.rules, rule) val.children = nil } } return nil } func (p *Planner) planRules(rules []*ast.Rule) (string, error) { // We know the rules with closer to the root (shorter static path) are ordered first. pathRef := rules[0].Ref() // figure out what our rules' collective name/path is: // if we're planning both p.q.r and p.q[s], we'll name // the function p.q (for the mapping table) pieces := len(pathRef) for i := range rules { r := rules[i].Ref() for j, t := range r { if _, ok := t.Value.(ast.String); !ok && j > 0 && j < pieces { pieces = j } } } // control if p.a = 1 is to return 1 directly; or insert 1 under key "a" into an object buildObject := pieces != len(pathRef) var pathPieces []string for i := 1; /* skip `data` */ i < pieces; i++ { switch q := pathRef[i].Value.(type) { case ast.String: pathPieces = append(pathPieces, string(q)) default: panic("impossible") } } path := pathRef[:pieces].String() if funcName, ok := p.funcs.Get(path); ok { return funcName, nil } // Save current state of planner. // // TODO(tsandall): perhaps we would be better off using stacks here or // splitting block planner into separate struct that could be instantiated // for rule and comprehension bodies. pvars := p.vars pcurr := p.curr pltarget := p.ltarget plnext := p.lnext ploc := p.loc // Reset the variable counter for the function plan. p.lnext = ir.Input // Set the location to the rule head. p.loc = rules[0].Head.Loc() // Create function definition for rules. fn := &ir.Func{ Name: fmt.Sprintf("g%d.%s", p.funcs.gen(), path), Params: []ir.Local{ p.newLocal(), // input document p.newLocal(), // data document }, Return: p.newLocal(), Path: append([]string{fmt.Sprintf("g%d", p.funcs.gen())}, pathPieces...), } // Initialize parameters for functions. for i := 0; i < len(rules[0].Head.Args); i++ { fn.Params = append(fn.Params, p.newLocal()) } params := fn.Params[2:] // Initialize return value for partial set/object rules. Complete document // rules assign directly to `fn.Return`. switch rules[0].Head.RuleKind() { case ast.SingleValue: if buildObject { fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) } case ast.MultiValue: if buildObject { fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeObjectStmt{Target: fn.Return})) } else { fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.MakeSetStmt{Target: fn.Return})) } } // For complete document rules, allocate one local variable for output // of the rule body + else branches. // It is used to let ordered rules (else blocks) check if the previous // rule body returned a value. lresult := p.newLocal() // At this point the locals for the params and return value have been // allocated. This will be the first local that can be used in each block. lnext := p.lnext var defaultRule *ast.Rule var ruleLoc *location.Location // We sort rules by ref length, to ensure that when merged, we can detect conflicts when one // rule attempts to override values (deep and shallow) defined by another rule. sort.Slice(rules, func(i, j int) bool { return len(rules[i].Ref()) > len(rules[j].Ref()) }) // Generate function blocks for rules. for i := range rules { // Save location of first encountered rule for the ReturnLocalStmt below if i == 0 { ruleLoc = p.loc } // Save default rule for the end. if rules[i].Default { defaultRule = rules[i] continue } // Ordered rules are nested inside an additional block so that execution // can short-circuit. For unordered rules, blocks can be added directly // to the function. var blocks *[]*ir.Block if rules[i].Else == nil { blocks = &fn.Blocks } else { stmt := &ir.BlockStmt{} block := &ir.Block{Stmts: []ir.Stmt{stmt}} fn.Blocks = append(fn.Blocks, block) blocks = &stmt.Blocks } var prev *ast.Rule // Unordered rules are treated as a special case of ordered rules. for rule := rules[i]; rule != nil; prev, rule = rule, rule.Else { // Update the location for each ordered rule. p.loc = rule.Head.Loc() // Setup planner for block. p.lnext = lnext p.vars = newVarstack(map[ast.Var]ir.Local{ ast.InputRootDocument.Value.(ast.Var): fn.Params[0], ast.DefaultRootDocument.Value.(ast.Var): fn.Params[1], }) curr := &ir.Block{} *blocks = append(*blocks, curr) p.curr = curr if prev != nil { // Ordered rules are handled by short circuiting execution. The // plan will jump out to the extra block that was planned above. p.appendStmt(&ir.IsUndefinedStmt{Source: lresult}) } else { // The first rule body resets the local, so it can be reused. // TODO(sr): I don't think we need this anymore. Double-check? Perhaps multi-value rules need it. p.appendStmt(&ir.ResetLocalStmt{Target: lresult}) } // Complete and partial rules are treated as special cases of // functions. If there are no args, the first step is a no-op. err := p.planFuncParams(params, rule.Head.Args, 0, func() error { // Run planner on the rule body. return p.planQuery(rule.Body, 0, func() error { // Run planner on the result. switch rule.Head.RuleKind() { case ast.SingleValue: if buildObject { ref := rule.Ref() return p.planTerm(rule.Head.Value, func() error { value := p.ltarget return p.planNestedObjects(fn.Return, ref[pieces:len(ref)-1], func(obj ir.Local) error { return p.planTerm(ref[len(ref)-1], func() error { key := p.ltarget p.appendStmt(&ir.ObjectInsertOnceStmt{ Object: obj, Key: key, Value: value, }) return nil }) }) }) } return p.planTerm(rule.Head.Value, func() error { p.appendStmt(&ir.AssignVarOnceStmt{ Target: lresult, Source: p.ltarget, }) return nil }) case ast.MultiValue: if buildObject { ref := rule.Ref() // we drop the trailing set key from the ref return p.planNestedObjects(fn.Return, ref[pieces:len(ref)-1], func(obj ir.Local) error { // Last term on rule ref is the key an which the set is assigned in the deepest nested object return p.planTerm(ref[len(ref)-1], func() error { key := p.ltarget return p.planTerm(rule.Head.Key, func() error { value := p.ltarget factory := func(v ir.Local) ir.Stmt { return &ir.MakeSetStmt{Target: v} } return p.planDotOr(obj, key, factory, func(set ir.Local) error { p.appendStmt(&ir.SetAddStmt{ Set: set, Value: value, }) p.appendStmt(&ir.ObjectInsertStmt{Key: key, Value: op(set), Object: obj}) return nil }) }) }) }) } return p.planTerm(rule.Head.Key, func() error { p.appendStmt(&ir.SetAddStmt{ Set: fn.Return, Value: p.ltarget, }) return nil }) default: return fmt.Errorf("illegal rule kind") } }) }) if err != nil { return "", err } } // rule[i] and its else-rule(s), if present, are done if rules[i].Head.RuleKind() == ast.SingleValue && !buildObject { end := &ir.Block{} p.appendStmtToBlock(&ir.IsDefinedStmt{Source: lresult}, end) p.appendStmtToBlock( &ir.AssignVarOnceStmt{ Target: fn.Return, Source: op(lresult), }, end) *blocks = append(*blocks, end) } } // Default rules execute if the return is undefined. if defaultRule != nil { // Set the location for the default rule head. p.loc = defaultRule.Head.Loc() // NOTE(sr) for `default p = 1`, // defaultRule.Loc() is `default`, // defaultRule.Head.Loc() is `p = 1`. fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.IsUndefinedStmt{Source: fn.Return})) p.curr = fn.Blocks[len(fn.Blocks)-1] err := p.planQuery(defaultRule.Body, 0, func() error { p.loc = defaultRule.Head.Loc() return p.planTerm(defaultRule.Head.Value, func() error { p.appendStmt(&ir.AssignVarOnceStmt{ Target: fn.Return, Source: p.ltarget, }) return nil }) }) if err != nil { return "", err } } p.loc = ruleLoc // All rules return a value. fn.Blocks = append(fn.Blocks, p.blockWithStmt(&ir.ReturnLocalStmt{Source: fn.Return})) p.appendFunc(fn) p.funcs.Add(path, fn.Name) // Restore the state of the planner. p.lnext = plnext p.ltarget = pltarget p.vars = pvars p.curr = pcurr p.loc = ploc return fn.Name, nil } func (p *Planner) planDotOr(obj ir.Local, key ir.Operand, or stmtFactory, iter planLocalIter) error { // We're constructing the following plan: // // | block a // | | block b // | | | dot &{Source:Local Key:{Value:Local} Target:Local} // | | | break 1 // | | or &{Target:Local} // | iter &{Target:Local} # may update Local. // | *ir.ObjectInsertStmt &{Key:{Value:Local} Value:{Value:Local} Object:Local} prev := p.curr dotBlock := &ir.Block{} p.curr = dotBlock val := p.newLocal() p.appendStmt(&ir.DotStmt{ Source: op(obj), Key: key, Target: val, }) p.appendStmt(&ir.BreakStmt{Index: 1}) outerBlock := &ir.Block{ Stmts: []ir.Stmt{ &ir.BlockStmt{Blocks: []*ir.Block{dotBlock}}, // FIXME: Set Location or(val), }, } p.curr = prev p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{outerBlock}}) if err := iter(val); err != nil { return err } p.appendStmt(&ir.ObjectInsertStmt{Key: key, Value: op(val), Object: obj}) return nil } func (p *Planner) planNestedObjects(obj ir.Local, ref ast.Ref, iter planLocalIter) error { if len(ref) == 0 { //return fmt.Errorf("nested object construction didn't create object") return iter(obj) } t := ref[0] return p.planTerm(t, func() error { key := p.ltarget factory := func(v ir.Local) ir.Stmt { return &ir.MakeObjectStmt{Target: v} } return p.planDotOr(obj, key, factory, func(childObj ir.Local) error { return p.planNestedObjects(childObj, ref[1:], iter) }) }) } func (p *Planner) planFuncParams(params []ir.Local, args ast.Args, idx int, iter planiter) error { if idx >= len(args) { return iter() } return p.planUnifyLocal(op(params[idx]), args[idx], func() error { return p.planFuncParams(params, args, idx+1, iter) }) } func (p *Planner) planQueries() error { for _, qs := range p.queries { // Initialize the plan with a block that prepares the query result. p.plan = &ir.Plan{Name: qs.Name} p.policy.Plans.Plans = append(p.policy.Plans.Plans, p.plan) p.curr = &ir.Block{} // Build a set of variables appearing in the query and allocate strings for // each one. The strings will be used in the result set objects. qvs := ast.NewVarSet() for _, q := range qs.Queries { vs := q.Vars(ast.VarVisitorParams{SkipRefCallHead: true, SkipClosures: true}).Diff(ast.ReservedVars) qvs.Update(vs) } lvarnames := make(map[ast.Var]ir.StringIndex, len(qvs)) for _, qv := range qvs.Sorted() { qv = rewrittenVar(qs.RewrittenVars, qv) if !qv.IsGenerated() && !qv.IsWildcard() { lvarnames[qv] = ir.StringIndex(p.getStringConst(string(qv))) } } if len(p.curr.Stmts) > 0 { p.appendBlock(p.curr) } lnext := p.lnext for _, q := range qs.Queries { p.loc = q.Loc() p.lnext = lnext p.vars.Push(map[ast.Var]ir.Local{}) p.curr = &ir.Block{} defined := false qvs := q.Vars(ast.VarVisitorParams{SkipRefCallHead: true, SkipClosures: true}).Diff(ast.ReservedVars).Sorted() if err := p.planQuery(q, 0, func() error { // Add an object containing variable bindings into the result set. lr := p.newLocal() p.appendStmt(&ir.MakeObjectStmt{ Target: lr, }) for _, qv := range qvs { rw := rewrittenVar(qs.RewrittenVars, qv) if !rw.IsGenerated() && !rw.IsWildcard() { p.appendStmt(&ir.ObjectInsertStmt{ Object: lr, Key: op(lvarnames[rw]), Value: p.vars.GetOpOrEmpty(qv), }) } } p.appendStmt(&ir.ResultSetAddStmt{ Value: lr, }) defined = true return nil }); err != nil { return err } p.vars.Pop() if defined { p.appendBlock(p.curr) } } } return nil } func (p *Planner) planQuery(q ast.Body, index int, iter planiter) error { if index >= len(q) { return iter() } old := p.loc p.loc = q[index].Loc() err := p.planExpr(q[index], func() error { return p.planQuery(q, index+1, func() error { curr := p.loc p.loc = old err := iter() p.loc = curr return err }) }) p.loc = old return err } // TODO(tsandall): improve errors to include location information. func (p *Planner) planExpr(e *ast.Expr, iter planiter) error { switch { case e.Negated: return p.planNot(e, iter) case len(e.With) > 0: return p.planWith(e, iter) case e.IsCall(): return p.planExprCall(e, iter) case e.IsEvery(): return p.planExprEvery(e, iter) } return p.planExprTerm(e, iter) } func (p *Planner) planNot(e *ast.Expr, iter planiter) error { not := &ir.NotStmt{ Block: &ir.Block{}, } prev := p.curr p.curr = not.Block if err := p.planExpr(e.Complement(), func() error { return nil }); err != nil { return err } p.curr = prev p.appendStmt(not) return iter() } func (p *Planner) planWith(e *ast.Expr, iter planiter) error { // Plan the values that will be applied by the `with` modifiers. All values // must be defined for the overall expression to evaluate. values := make([]*ast.Term, 0, len(e.With)) // NOTE(sr): we could be overallocating if there are builtin replacements targets := make([]ast.Ref, 0, len(e.With)) mocks := frame{} for _, w := range e.With { v := w.Target.Value.(ast.Ref) switch { case p.isFunction(v): // nothing to do case ast.DefaultRootDocument.Equal(v[0]) || ast.InputRootDocument.Equal(v[0]): values = append(values, w.Value) targets = append(targets, w.Target.Value.(ast.Ref)) continue // not a mock } mocks[w.Target.String()] = w.Value } return p.planTermSlice(values, func(locals []ir.Operand) error { p.mocks.PushFrame(mocks) paths := make([][]int, len(targets)) saveVars := ast.NewVarSet() dataRefs := []ast.Ref{} for i, target := range targets { paths[i] = make([]int, len(target)-1) for j := 1; j < len(target); j++ { if s, ok := target[j].Value.(ast.String); ok { paths[i][j-1] = p.getStringConst(string(s)) } else { return errors.New("invalid with target") } } head := target[0].Value.(ast.Var) saveVars.Add(head) if head.Equal(ast.DefaultRootDocument.Value) { dataRefs = append(dataRefs, target) } } restore := make([][2]ir.Local, len(saveVars)) for i, v := range saveVars.Sorted() { lorig := p.vars.GetOrEmpty(v) lsave := p.newLocal() p.appendStmt(&ir.AssignVarStmt{Source: op(lorig), Target: lsave}) restore[i] = [2]ir.Local{lorig, lsave} } // If any of the `with` statements targeted the data document, overwriting // parts of the ruletrie; or if a function has been mocked, we shadow the // existing planned functions during expression planning. // This causes the planner to re-plan any rules that may be required during // planning of this expression (transitively). shadowing := p.dataRefsShadowRuletrie(dataRefs) || len(mocks) > 0 if shadowing { p.funcs.Push(map[string]string{}) for _, ref := range dataRefs { p.rules.Push(ref) } } err := p.planWithRec(e, paths, locals, 0, func() error { p.mocks.PopFrame() if shadowing { p.funcs.Pop() for i := len(dataRefs) - 1; i >= 0; i-- { p.rules.Pop(dataRefs[i]) } } err := p.planWithUndoRec(restore, 0, func() error { err := iter() p.mocks.PushFrame(mocks) if shadowing { p.funcs.Push(map[string]string{}) for _, ref := range dataRefs { p.rules.Push(ref) } } return err }) return err }) p.mocks.PopFrame() if shadowing { p.funcs.Pop() for i := len(dataRefs) - 1; i >= 0; i-- { p.rules.Pop(dataRefs[i]) } } return err }) } func (p *Planner) planWithRec(e *ast.Expr, targets [][]int, values []ir.Operand, index int, iter planiter) error { if index >= len(targets) { return p.planExpr(e.NoWith(), iter) } prev := p.curr p.curr = &ir.Block{} err := p.planWithRec(e, targets, values, index+1, iter) if err != nil { return err } block := p.curr p.curr = prev target := e.With[index].Target.Value.(ast.Ref) head := target[0].Value.(ast.Var) p.appendStmt(&ir.WithStmt{ Local: p.vars.GetOrEmpty(head), Path: targets[index], Value: values[index], Block: block, }) return nil } func (p *Planner) planWithUndoRec(restore [][2]ir.Local, index int, iter planiter) error { if index >= len(restore) { return iter() } prev := p.curr p.curr = &ir.Block{} if err := p.planWithUndoRec(restore, index+1, iter); err != nil { return err } block := p.curr p.curr = prev lorig := restore[index][0] lsave := restore[index][1] p.appendStmt(&ir.WithStmt{ Local: lorig, Value: op(lsave), Block: block, }) return nil } func (p *Planner) dataRefsShadowRuletrie(refs []ast.Ref) bool { for _, ref := range refs { if p.rules.Lookup(ref) != nil { return true } } return false } func (p *Planner) planExprTerm(e *ast.Expr, iter planiter) error { // NOTE(sr): There are only three cases to deal with when we see a naked term // in a rule body: // 1. it's `false` -- so we can stop, emit a break stmt // 2. it's a var or a ref, like `input` or `data.foo.bar`, where we need to // check what it ends up being (at run time) to determine if it's not false // 3. it's any other term -- `true`, a string, a number, whatever. We can skip // that, since it's true-ish enough for evaluating the rule body. switch t := e.Terms.(*ast.Term).Value.(type) { case ast.Boolean: if !bool(t) { // We know this cannot hold, break unconditionally p.appendStmt(&ir.BreakStmt{}) return iter() } case ast.Ref, ast.Var: // We don't know these at plan-time return p.planTerm(e.Terms.(*ast.Term), func() error { p.appendStmt(&ir.NotEqualStmt{ A: p.ltarget, B: op(ir.Bool(false)), }) return iter() }) } return iter() } func (p *Planner) planExprEvery(e *ast.Expr, iter planiter) error { every := e.Terms.(*ast.Every) cond0 := p.newLocal() // outer not cond1 := p.newLocal() // inner not // We're using condition variables together with IsDefinedStmt to encode // this: // every x, y in xs { p(x,y) } // ~> p(x1, y1) AND p(x2, y2) AND ... AND p(xn, yn) // ~> NOT (NOT p(x1, y1) OR NOT p(x2, y2) OR ... OR NOT p(xn, yn)) // // cond1 is initialized to 0, and set to TRUE if p(xi, yi) succeeds for // a binding of (xi, yi). We then use IsUndefined to check that this has NOT // happened (NOT p(xi, yi)). // cond0 is initialized to 0, and set to TRUE if cond1 happens to not // be set: it's encoding the NOT ( ... OR ... OR ... ) part of this. p.appendStmt(&ir.ResetLocalStmt{ Target: cond0, }) err := p.planTerm(every.Domain, func() error { return p.planScan(every.Key, func(ir.Local) error { p.appendStmt(&ir.ResetLocalStmt{ Target: cond1, }) nested := &ir.BlockStmt{Blocks: []*ir.Block{{}}} prev := p.curr p.curr = nested.Blocks[0] lval := p.ltarget err := p.planUnifyLocal(lval, every.Value, func() error { return p.planQuery(every.Body, 0, func() error { p.appendStmt(&ir.AssignVarStmt{ Source: op(ir.Bool(true)), Target: cond1, }) return nil }) }) if err != nil { return err } p.curr = prev p.appendStmt(nested) p.appendStmt(&ir.IsUndefinedStmt{ Source: cond1, }) p.appendStmt(&ir.AssignVarStmt{ Source: op(ir.Bool(true)), Target: cond0, }) return nil }) }) if err != nil { return err } p.appendStmt(&ir.IsUndefinedStmt{ Source: cond0, }) return iter() } func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error { operator := e.Operator().String() switch operator { case ast.Equality.Name: return p.planUnify(e.Operand(0), e.Operand(1), iter) default: var relation bool var name string var arity int var void bool var args []ir.Operand var err error operands := e.Operands() op := e.Operator() if replacement := p.mocks.Lookup(operator); replacement != nil { switch r := replacement.Value.(type) { case ast.Ref: if !r.HasPrefix(ast.DefaultRootRef) && !r.HasPrefix(ast.InputRootRef) { // replacement is builtin operator = r.String() bi := p.decls[operator] p.externs[operator] = bi // void functions and relations are forbidden; arity validation happened in compiler return p.planExprCallFunc(operator, len(bi.Decl.FuncArgs().Args), void, operands, args, iter) } // replacement is a function (rule) if node := p.rules.Lookup(r); node != nil { if node.Arity() > 0 { p.mocks.Push() // new scope name, err = p.planRules(node.Rules()) if err != nil { return err } p.mocks.Pop() return p.planExprCallFunc(name, node.Arity(), void, operands, p.defaultOperands(), iter) } // arity==0 => replacement is a ref to a rule to be used as value (fallthrough) } } // replacement is a value, or ref if bi, ok := p.decls[operator]; ok { return p.planExprCallValue(replacement, len(bi.Decl.FuncArgs().Args), operands, iter) } if node := p.rules.Lookup(op); node != nil { return p.planExprCallValue(replacement, node.Arity(), operands, iter) } return fmt.Errorf("illegal replacement of operator %q by %v", operator, replacement) // should be unreachable } if node := p.rules.Lookup(op); node != nil { name, err = p.planRules(node.Rules()) if err != nil { return err } arity = node.Arity() args = p.defaultOperands() } else if decl, ok := p.decls[operator]; ok { relation = decl.Relation arity = len(decl.Decl.Args()) void = decl.Decl.Result() == nil name = operator p.externs[operator] = decl } else { return fmt.Errorf("illegal call: unknown operator %q", operator) } if len(operands) < arity || len(operands) > arity+1 { return fmt.Errorf("illegal call: wrong number of operands: got %v, want %v)", len(operands), arity) } if relation { return p.planExprCallRelation(name, arity, operands, args, iter) } return p.planExprCallFunc(name, arity, void, operands, args, iter) } } func (p *Planner) planExprCallRelation(name string, arity int, operands []*ast.Term, args []ir.Operand, iter planiter) error { if len(operands) == arity { return p.planCallArgs(operands, 0, args, func(args []ir.Operand) error { p.ltarget = p.newOperand() ltarget := p.ltarget.Value.(ir.Local) p.appendStmt(&ir.CallStmt{ Func: name, Args: args, Result: ltarget, }) lsize := p.newLocal() p.appendStmt(&ir.LenStmt{ Source: op(ltarget), Target: lsize, }) lzero := p.newLocal() p.appendStmt(&ir.MakeNumberIntStmt{ Value: 0, Target: lzero, }) p.appendStmt(&ir.NotEqualStmt{ A: op(lsize), B: op(lzero), }) return iter() }) } return p.planCallArgs(operands[:len(operands)-1], 0, args, func(args []ir.Operand) error { p.ltarget = p.newOperand() p.appendStmt(&ir.CallStmt{ Func: name, Args: args, Result: p.ltarget.Value.(ir.Local), }) return p.planScanValues(operands[len(operands)-1], func(ir.Local) error { return iter() }) }) } func (p *Planner) planExprCallFunc(name string, arity int, void bool, operands []*ast.Term, args []ir.Operand, iter planiter) error { switch { case len(operands) == arity: // definition: f(x) = y { ... } // call: f(x) # result not captured return p.planCallArgs(operands, 0, args, func(args []ir.Operand) error { p.ltarget = p.newOperand() ltarget := p.ltarget.Value.(ir.Local) p.appendStmt(&ir.CallStmt{ Func: name, Args: args, Result: ltarget, }) if !void { p.appendStmt(&ir.NotEqualStmt{ A: op(ltarget), B: op(ir.Bool(false)), }) } return iter() }) case len(operands) == arity+1: // definition: f(x) = y { ... } // call: f(x, 1) # caller captures result return p.planCallArgs(operands[:len(operands)-1], 0, args, func(args []ir.Operand) error { result := p.newLocal() p.appendStmt(&ir.CallStmt{ Func: name, Args: args, Result: result, }) return p.planUnifyLocal(op(result), operands[len(operands)-1], iter) }) default: return fmt.Errorf("impossible replacement, arity mismatch") } } func (p *Planner) planExprCallValue(value *ast.Term, arity int, operands []*ast.Term, iter planiter) error { switch { case len(operands) == arity: // call: f(x) # result not captured return p.planCallArgs(operands, 0, nil, func([]ir.Operand) error { p.ltarget = p.newOperand() return p.planTerm(value, func() error { p.appendStmt(&ir.NotEqualStmt{ A: p.ltarget, B: op(ir.Bool(false)), }) return iter() }) }) case len(operands) == arity+1: // call: f(x, 1) # caller captures result return p.planCallArgs(operands[:len(operands)-1], 0, nil, func([]ir.Operand) error { p.ltarget = p.newOperand() return p.planTerm(value, func() error { return p.planUnifyLocal(p.ltarget, operands[len(operands)-1], iter) }) }) default: return fmt.Errorf("impossible replacement, arity mismatch") } } func (p *Planner) planCallArgs(terms []*ast.Term, idx int, args []ir.Operand, iter func([]ir.Operand) error) error { if idx >= len(terms) { return iter(args) } return p.planTerm(terms[idx], func() error { args = append(args, p.ltarget) return p.planCallArgs(terms, idx+1, args, iter) }) } func (p *Planner) planUnify(a, b *ast.Term, iter planiter) error { switch va := a.Value.(type) { case ast.Null, ast.Boolean, ast.Number, ast.String, ast.Ref, ast.Set, *ast.SetComprehension, *ast.ArrayComprehension, *ast.ObjectComprehension: return p.planTerm(a, func() error { return p.planUnifyLocal(p.ltarget, b, iter) }) case ast.Var: return p.planUnifyVar(va, b, iter) case *ast.Array: switch vb := b.Value.(type) { case ast.Var: return p.planUnifyVar(vb, a, iter) case ast.Ref: return p.planTerm(b, func() error { return p.planUnifyLocalArray(p.ltarget, va, iter) }) case *ast.ArrayComprehension: return p.planTerm(b, func() error { return p.planUnifyLocalArray(p.ltarget, va, iter) }) case *ast.Array: if va.Len() == vb.Len() { return p.planUnifyArraysRec(va, vb, 0, iter) } return nil } case ast.Object: switch vb := b.Value.(type) { case ast.Var: return p.planUnifyVar(vb, a, iter) case ast.Ref: return p.planTerm(b, func() error { return p.planUnifyLocalObject(p.ltarget, va, iter) }) case ast.Object: return p.planUnifyObjects(va, vb, iter) } } return fmt.Errorf("not implemented: unify(%v, %v)", a, b) } func (p *Planner) planUnifyVar(a ast.Var, b *ast.Term, iter planiter) error { if la, ok := p.vars.GetOp(a); ok { return p.planUnifyLocal(la, b, iter) } return p.planTerm(b, func() error { // `a` may have become known while planning b, like in `a = input.x[a]` la, ok := p.vars.GetOp(a) if ok { p.appendStmt(&ir.EqualStmt{ A: la, B: p.ltarget, }) } else { target := p.newLocal() p.vars.Put(a, target) p.appendStmt(&ir.AssignVarStmt{ Source: p.ltarget, Target: target, }) } return iter() }) } func (p *Planner) planUnifyLocal(a ir.Operand, b *ast.Term, iter planiter) error { // special cases: when a is StringIndex or Bool, and b is a string, or a bool, we can shortcut switch va := a.Value.(type) { case ir.StringIndex: if vb, ok := b.Value.(ast.String); ok { if va != ir.StringIndex(p.getStringConst(string(vb))) { p.appendStmt(&ir.BreakStmt{}) } return iter() // Don't plan EqualStmt{A: "foo", B: "foo"} } case ir.Bool: if vb, ok := b.Value.(ast.Boolean); ok { if va != ir.Bool(vb) { p.appendStmt(&ir.BreakStmt{}) } return iter() // Don't plan EqualStmt{A: true, B: true} } } switch vb := b.Value.(type) { case ast.Null, ast.Boolean, ast.Number, ast.String, ast.Ref, ast.Set, *ast.SetComprehension, *ast.ArrayComprehension, *ast.ObjectComprehension: return p.planTerm(b, func() error { p.appendStmt(&ir.EqualStmt{ A: a, B: p.ltarget, }) return iter() }) case ast.Var: if lv, ok := p.vars.GetOp(vb); ok { p.appendStmt(&ir.EqualStmt{ A: a, B: lv, }) return iter() } lv := p.newLocal() p.vars.Put(vb, lv) p.appendStmt(&ir.AssignVarStmt{ Source: a, Target: lv, }) return iter() case *ast.Array: return p.planUnifyLocalArray(a, vb, iter) case ast.Object: return p.planUnifyLocalObject(a, vb, iter) } return fmt.Errorf("not implemented: unifyLocal(%v, %v)", a, b) } func (p *Planner) planUnifyLocalArray(a ir.Operand, b *ast.Array, iter planiter) error { p.appendStmt(&ir.IsArrayStmt{ Source: a, }) blen := p.newLocal() alen := p.newLocal() p.appendStmt(&ir.LenStmt{ Source: a, Target: alen, }) p.appendStmt(&ir.MakeNumberIntStmt{ Value: int64(b.Len()), Target: blen, }) p.appendStmt(&ir.EqualStmt{ A: op(alen), B: op(blen), }) lkey := p.newLocal() p.appendStmt(&ir.MakeNumberIntStmt{ Target: lkey, }) lval := p.newLocal() return p.planUnifyLocalArrayRec(a, 0, b, lkey, lval, iter) } func (p *Planner) planUnifyLocalArrayRec(a ir.Operand, index int, b *ast.Array, lkey, lval ir.Local, iter planiter) error { if b.Len() == index { return iter() } p.appendStmt(&ir.AssignIntStmt{ Value: int64(index), Target: lkey, }) p.appendStmt(&ir.DotStmt{ Source: a, Key: op(lkey), Target: lval, }) return p.planUnifyLocal(op(lval), b.Elem(index), func() error { return p.planUnifyLocalArrayRec(a, index+1, b, lkey, lval, iter) }) } func (p *Planner) planUnifyObjects(a, b ast.Object, iter planiter) error { if a.Len() != b.Len() { return nil } aKeys := ast.NewSet(a.Keys()...) bKeys := ast.NewSet(b.Keys()...) unifyKeys := aKeys.Diff(bKeys) // planUnifyObjectsRec will actually set variables where possible; // planUnifyObjectLocals only asserts equality -- it won't assign // to any local return p.planUnifyObjectsRec(a, b, aKeys.Intersect(bKeys).Slice(), 0, func() error { if unifyKeys.Len() == 0 { return iter() } return p.planObject(a, func() error { la := p.ltarget return p.planObject(b, func() error { return p.planUnifyObjectLocals(la, p.ltarget, unifyKeys.Slice(), 0, p.newLocal(), p.newLocal(), iter) }) }) }) } func (p *Planner) planUnifyObjectLocals(a, b ir.Operand, keys []*ast.Term, index int, l0, l1 ir.Local, iter planiter) error { if index == len(keys) { return iter() } return p.planTerm(keys[index], func() error { p.appendStmt(&ir.DotStmt{ Source: a, Key: p.ltarget, Target: l0, }) p.appendStmt(&ir.DotStmt{ Source: b, Key: p.ltarget, Target: l1, }) p.appendStmt(&ir.EqualStmt{ A: op(l0), B: op(l1), }) return p.planUnifyObjectLocals(a, b, keys, index+1, l0, l1, iter) }) } func (p *Planner) planUnifyLocalObject(a ir.Operand, b ast.Object, iter planiter) error { p.appendStmt(&ir.IsObjectStmt{ Source: a, }) blen := p.newLocal() alen := p.newLocal() p.appendStmt(&ir.LenStmt{ Source: a, Target: alen, }) p.appendStmt(&ir.MakeNumberIntStmt{ Value: int64(b.Len()), Target: blen, }) p.appendStmt(&ir.EqualStmt{ A: op(alen), B: op(blen), }) lval := p.newLocal() bkeys := b.Keys() return p.planUnifyLocalObjectRec(a, 0, bkeys, b, lval, iter) } func (p *Planner) planUnifyLocalObjectRec(a ir.Operand, index int, keys []*ast.Term, b ast.Object, lval ir.Local, iter planiter) error { if index == len(keys) { return iter() } return p.planTerm(keys[index], func() error { p.appendStmt(&ir.DotStmt{ Source: a, Key: p.ltarget, Target: lval, }) return p.planUnifyLocal(op(lval), b.Get(keys[index]), func() error { return p.planUnifyLocalObjectRec(a, index+1, keys, b, lval, iter) }) }) } func (p *Planner) planUnifyArraysRec(a, b *ast.Array, index int, iter planiter) error { if index == a.Len() { return iter() } return p.planUnify(a.Elem(index), b.Elem(index), func() error { return p.planUnifyArraysRec(a, b, index+1, iter) }) } func (p *Planner) planUnifyObjectsRec(a, b ast.Object, keys []*ast.Term, index int, iter planiter) error { if index == len(keys) { return iter() } aval := a.Get(keys[index]) bval := b.Get(keys[index]) if aval == nil || bval == nil { return nil } return p.planUnify(aval, bval, func() error { return p.planUnifyObjectsRec(a, b, keys, index+1, iter) }) } func (p *Planner) planTerm(t *ast.Term, iter planiter) error { return p.planValue(t.Value, t.Loc(), iter) } func (p *Planner) planValue(t ast.Value, loc *ast.Location, iter planiter) error { switch v := t.(type) { case ast.Null: return p.planNull(v, iter) case ast.Boolean: return p.planBoolean(v, iter) case ast.Number: return p.planNumber(v, iter) case ast.String: return p.planString(v, iter) case ast.Var: return p.planVar(v, iter) case ast.Ref: return p.planRef(v, iter) case *ast.Array: return p.planArray(v, iter) case ast.Object: return p.planObject(v, iter) case ast.Set: return p.planSet(v, iter) case *ast.SetComprehension: p.loc = loc return p.planSetComprehension(v, iter) case *ast.ArrayComprehension: p.loc = loc return p.planArrayComprehension(v, iter) case *ast.ObjectComprehension: p.loc = loc return p.planObjectComprehension(v, iter) default: return fmt.Errorf("%v term not implemented", ast.TypeName(v)) } } func (p *Planner) planNull(null ast.Null, iter planiter) error { target := p.newLocal() p.appendStmt(&ir.MakeNullStmt{ Target: target, }) p.ltarget = op(target) return iter() } func (p *Planner) planBoolean(b ast.Boolean, iter planiter) error { p.ltarget = op(ir.Bool(b)) return iter() } func (p *Planner) planNumber(num ast.Number, iter planiter) error { index := p.getStringConst(string(num)) target := p.newLocal() p.appendStmt(&ir.MakeNumberRefStmt{ Index: index, Target: target, }) p.ltarget = op(target) return iter() } func (p *Planner) planString(str ast.String, iter planiter) error { p.ltarget = op(ir.StringIndex(p.getStringConst(string(str)))) return iter() } func (p *Planner) planVar(v ast.Var, iter planiter) error { p.ltarget = op(p.vars.GetOrElse(v, func() ir.Local { return p.newLocal() })) return iter() } func (p *Planner) planArray(arr *ast.Array, iter planiter) error { larr := p.newLocal() p.appendStmt(&ir.MakeArrayStmt{ Capacity: int32(arr.Len()), Target: larr, }) return p.planArrayRec(arr, 0, larr, iter) } func (p *Planner) planArrayRec(arr *ast.Array, index int, larr ir.Local, iter planiter) error { if index == arr.Len() { p.ltarget = op(larr) return iter() } return p.planTerm(arr.Elem(index), func() error { p.appendStmt(&ir.ArrayAppendStmt{ Value: p.ltarget, Array: larr, }) return p.planArrayRec(arr, index+1, larr, iter) }) } func (p *Planner) planObject(obj ast.Object, iter planiter) error { lobj := p.newLocal() p.appendStmt(&ir.MakeObjectStmt{ Target: lobj, }) return p.planObjectRec(obj, 0, obj.Keys(), lobj, iter) } func (p *Planner) planObjectRec(obj ast.Object, index int, keys []*ast.Term, lobj ir.Local, iter planiter) error { if index == len(keys) { p.ltarget = op(lobj) return iter() } return p.planTerm(keys[index], func() error { lkey := p.ltarget return p.planTerm(obj.Get(keys[index]), func() error { lval := p.ltarget p.appendStmt(&ir.ObjectInsertStmt{ Key: lkey, Value: lval, Object: lobj, }) return p.planObjectRec(obj, index+1, keys, lobj, iter) }) }) } func (p *Planner) planSet(set ast.Set, iter planiter) error { lset := p.newLocal() p.appendStmt(&ir.MakeSetStmt{ Target: lset, }) return p.planSetRec(set, 0, set.Slice(), lset, iter) } func (p *Planner) planSetRec(set ast.Set, index int, elems []*ast.Term, lset ir.Local, iter planiter) error { if index == len(elems) { p.ltarget = op(lset) return iter() } return p.planTerm(elems[index], func() error { p.appendStmt(&ir.SetAddStmt{ Value: p.ltarget, Set: lset, }) return p.planSetRec(set, index+1, elems, lset, iter) }) } func (p *Planner) planSetComprehension(sc *ast.SetComprehension, iter planiter) error { lset := p.newLocal() p.appendStmt(&ir.MakeSetStmt{ Target: lset, }) return p.planComprehension(sc.Body, func() error { return p.planTerm(sc.Term, func() error { p.appendStmt(&ir.SetAddStmt{ Value: p.ltarget, Set: lset, }) return nil }) }, lset, iter) } func (p *Planner) planArrayComprehension(ac *ast.ArrayComprehension, iter planiter) error { larr := p.newLocal() p.appendStmt(&ir.MakeArrayStmt{ Target: larr, }) return p.planComprehension(ac.Body, func() error { return p.planTerm(ac.Term, func() error { p.appendStmt(&ir.ArrayAppendStmt{ Value: p.ltarget, Array: larr, }) return nil }) }, larr, iter) } func (p *Planner) planObjectComprehension(oc *ast.ObjectComprehension, iter planiter) error { lobj := p.newLocal() p.appendStmt(&ir.MakeObjectStmt{ Target: lobj, }) return p.planComprehension(oc.Body, func() error { return p.planTerm(oc.Key, func() error { lkey := p.ltarget return p.planTerm(oc.Value, func() error { p.appendStmt(&ir.ObjectInsertOnceStmt{ Key: lkey, Value: p.ltarget, Object: lobj, }) return nil }) }) }, lobj, iter) } func (p *Planner) planComprehension(body ast.Body, closureIter planiter, target ir.Local, iter planiter) error { // Variables that have been introduced in this comprehension have // no effect on other parts of the policy, so they'll be dropped // below. p.vars.Push(map[ast.Var]ir.Local{}) prev := p.curr block := &ir.Block{} p.curr = block ploc := p.loc if err := p.planQuery(body, 0, closureIter); err != nil { return err } p.curr = prev p.loc = ploc p.vars.Pop() p.appendStmt(&ir.BlockStmt{ Blocks: []*ir.Block{ block, }, }) p.ltarget = op(target) return iter() } func (p *Planner) planRef(ref ast.Ref, iter planiter) error { head, ok := ref[0].Value.(ast.Var) if !ok { return fmt.Errorf("illegal ref: non-var head") } if head.Compare(ast.DefaultRootDocument.Value) == 0 { virtual := p.rules.Get(ref[0].Value) base := &baseptr{local: p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var))} return p.planRefData(virtual, base, ref, 1, iter) } if ref.Equal(ast.InputRootRef) { p.appendStmt(&ir.IsDefinedStmt{ Source: p.vars.GetOrEmpty(ast.InputRootDocument.Value.(ast.Var)), }) } p.ltarget, ok = p.vars.GetOp(head) if !ok { return fmt.Errorf("illegal ref: unsafe head") } return p.planRefRec(ref, 1, iter) } func (p *Planner) planRefRec(ref ast.Ref, index int, iter planiter) error { if len(ref) == index { return iter() } if !p.unseenVars(ref[index]) { return p.planDot(ref[index], func() error { return p.planRefRec(ref, index+1, iter) }) } return p.planScan(ref[index], func(ir.Local) error { return p.planRefRec(ref, index+1, iter) }) } type baseptr struct { local ir.Local path ast.Ref } // planRefData implements the virtual document model by generating the value of // the ref parameter and invoking the iterator with the planner target set to // the virtual document and all variables in the reference assigned. func (p *Planner) planRefData(virtual *ruletrie, base *baseptr, ref ast.Ref, index int, iter planiter) error { // Early-exit if the end of the reference has been reached. In this case the // plan has to materialize the full extent of the referenced value. if index >= len(ref) { return p.planRefDataExtent(virtual, base, iter) } // On the first iteration, we check if this can be optimized using a // CallDynamicStatement // NOTE(sr): we do it on the first index because later on, the recursion // on subtrees of virtual already lost parts of the path we've taken. if index == 1 && virtual != nil { rulesets, path, index, optimize := p.optimizeLookup(virtual, ref) if optimize { // If there are no rulesets in a situation that otherwise would // allow for a call_indirect optimization, then there's nothing // to do for this ref, except scanning the base document. if len(rulesets) == 0 { return p.planRefData(nil, base, ref, 1, iter) // ignore index returned by optimizeLookup } // plan rules for _, rules := range rulesets { if _, err := p.planRules(rules); err != nil { return err } } // We're planning a structure like this: // // block res // block a // block b // block c1 // opa_mapping_lookup || br c1 // call_indirect || br res // br b // end // block c2 // dot i || br c2 // dot i+1 || br c2 // br b // end // br a // end // dot i+2 || br res // dot i+3 || br res // end; a // [add_to_result_set] // end; res // // We have to do it like this because the dot IR stmts // are compiled to `br 0`, the innermost block, if they // fail. // The "c2" block will construct the reference from `data` // only, in case the mapping lookup doesn't yield a func // to call_dynamic. ltarget := p.newLocal() p.ltarget = op(ltarget) prev := p.curr callDynBlock := &ir.Block{} // "c1" in the sketch p.curr = callDynBlock p.appendStmt(&ir.CallDynamicStmt{ Args: []ir.Local{ p.vars.GetOrEmpty(ast.InputRootDocument.Value.(ast.Var)), p.vars.GetOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)), }, Path: path, Result: ltarget, }) p.appendStmt(&ir.BreakStmt{Index: 1}) dotBlock := &ir.Block{} // "c2" in the sketch above p.curr = dotBlock p.ltarget = p.vars.GetOpOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)) return p.planRefRec(ref[:index+1], 1, func() error { p.appendStmt(&ir.AssignVarStmt{ Source: p.ltarget, Target: ltarget, }) p.appendStmt(&ir.BreakStmt{Index: 1}) p.ltarget = op(ltarget) outerBlock := &ir.Block{Stmts: []ir.Stmt{ &ir.BlockStmt{Blocks: []*ir.Block{ { // block "b" in the sketch above Stmts: []ir.Stmt{ &ir.BlockStmt{Blocks: []*ir.Block{callDynBlock, dotBlock}}, &ir.BreakStmt{Index: 2}}, }, }}, }} p.curr = outerBlock if err := p.planRefRec(ref, index+1, iter); err != nil { // rest of the ref return err } p.curr = prev p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{outerBlock}}) return nil }) } } // If the reference operand is ground then either continue to the next // operand or invoke the function for the rule referred to by this operand. if ref[index].IsGround() { var vchild *ruletrie var rules []*ast.Rule if virtual != nil { vchild = virtual.Get(ref[index].Value) rules = vchild.Rules() // hit or miss } if len(rules) > 0 { p.ltarget = p.newOperand() funcName, err := p.planRules(rules) if err != nil { return err } p.appendStmt(&ir.CallStmt{ Func: funcName, Args: p.defaultOperands(), Result: p.ltarget.Value.(ir.Local), }) return p.planRefRec(ref, index+1, iter) } bchild := *base bchild.path = append(bchild.path, ref[index]) return p.planRefData(vchild, &bchild, ref, index+1, iter) } exclude := ast.NewSet() // The planner does not support dynamic dispatch so generate blocks to // evaluate each of the rulesets on the child nodes. if virtual != nil { stmt := &ir.BlockStmt{} for _, child := range virtual.Children() { block := &ir.Block{} prev := p.curr p.curr = block key := ast.NewTerm(child) exclude.Add(key) // Assignments in each block due to local unification must be undone // so create a new frame that will be popped after this key is // processed. p.vars.Push(map[ast.Var]ir.Local{}) if err := p.planTerm(key, func() error { return p.planUnifyLocal(p.ltarget, ref[index], func() error { // Create a copy of the reference with this operand plugged. // This will result in evaluation of the rulesets on the // child node. cpy := ref.Copy() cpy[index] = key return p.planRefData(virtual, base, cpy, index, iter) }) }); err != nil { return err } p.vars.Pop() p.curr = prev stmt.Blocks = append(stmt.Blocks, block) } p.appendStmt(stmt) } // If the virtual tree was enumerated then we do not want to enumerate base // trees that are rooted at the same key as any of the virtual sub trees. To // prevent this we build a set of keys that are to be excluded and check // below during the base scan. var lexclude *ir.Operand if exclude.Len() > 0 { if err := p.planSet(exclude, func() error { v := p.ltarget lexclude = &v return nil }); err != nil { return err } // Perform a scan of the base documents starting from the location referred // to by the 'path' data pointer. Use the `lexclude` set to avoid revisiting // sub trees. p.ltarget = op(base.local) return p.planRefRec(base.path, 0, func() error { return p.planScan(ref[index], func(lkey ir.Local) error { if lexclude != nil { lignore := p.newLocal() p.appendStmt(&ir.NotStmt{ Block: p.blockWithStmt(&ir.DotStmt{ Source: *lexclude, Key: op(lkey), Target: lignore, })}) } // Assume that virtual sub trees have been visited already so // recurse without the virtual node. return p.planRefData(nil, &baseptr{local: p.ltarget.Value.(ir.Local)}, ref, index+1, iter) }) }) } // There is nothing to exclude, so we do the same thing done above, but // use planRefRec to avoid the scan if ref[index] is ground or seen. p.ltarget = op(base.local) base.path = append(base.path, ref[index]) return p.planRefRec(base.path, 0, func() error { return p.planRefData(nil, &baseptr{local: p.ltarget.Value.(ir.Local)}, ref, index+1, iter) }) } // planRefDataExtent generates the full extent (combined) of the base and // virtual nodes and then invokes the iterator with the planner target set to // the full extent. func (p *Planner) planRefDataExtent(virtual *ruletrie, base *baseptr, iter planiter) error { vtarget := p.newLocal() // Generate the virtual document out of rules contained under the virtual // node (recursively). This document will _ONLY_ contain values generated by // rules. No base document values will be included. if virtual != nil { p.appendStmt(&ir.MakeObjectStmt{ Target: vtarget, }) anyKeyNonGround := false for _, key := range virtual.Children() { if !key.IsGround() { anyKeyNonGround = true break } } if anyKeyNonGround { var rules []*ast.Rule for _, key := range virtual.Children() { // TODO(sr): skip functions rules = append(rules, virtual.Get(key).Rules()...) } funcName, err := p.planRules(rules) if err != nil { return err } // Add leaf to object if defined. b := &ir.Block{} p.appendStmtToBlock(&ir.CallStmt{ Func: funcName, Args: p.defaultOperands(), Result: vtarget, }, b) p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) } else { for _, key := range virtual.Children() { child := virtual.Get(key) // Skip functions. if child.Arity() > 0 { continue } rules := child.Rules() err := p.planValue(key, nil, func() error { lkey := p.ltarget // Build object hierarchy depth-first. if len(rules) == 0 { return p.planRefDataExtent(child, nil, func() error { p.appendStmt(&ir.ObjectInsertStmt{ Object: vtarget, Key: lkey, Value: p.ltarget, }) return nil }) } // Generate virtual document for leaf. lvalue := p.newLocal() funcName, err := p.planRules(rules) if err != nil { return err } // Add leaf to object if defined. b := &ir.Block{} p.appendStmtToBlock(&ir.CallStmt{ Func: funcName, Args: p.defaultOperands(), Result: lvalue, }, b) p.appendStmtToBlock(&ir.ObjectInsertStmt{ Object: vtarget, Key: lkey, Value: op(lvalue), }, b) p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{b}}) return nil }) if err != nil { return err } } } // At this point vtarget refers to the full extent of the virtual // document at ref. If the base pointer is unset, no further processing // is required. if base == nil { p.ltarget = op(vtarget) return iter() } } // Obtain the base document value and merge (recursively) with the virtual // document value above if needed. prev := p.curr p.curr = &ir.Block{} p.ltarget = op(base.local) target := p.newLocal() err := p.planRefRec(base.path, 0, func() error { if virtual == nil { target = p.ltarget.Value.(ir.Local) } else { stmt := &ir.ObjectMergeStmt{ A: p.ltarget.Value.(ir.Local), B: vtarget, Target: target, } p.appendStmt(stmt) } p.appendStmt(&ir.BreakStmt{Index: 1}) return nil }) if err != nil { return err } inner := p.curr // Fallback to virtual document value if base document is undefined. // Otherwise, this block is undefined. p.curr = &ir.Block{} p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{inner}}) if virtual != nil { p.appendStmt(&ir.AssignVarStmt{ Source: op(vtarget), Target: target, }) } else { p.appendStmt(&ir.BreakStmt{Index: 1}) } outer := p.curr p.curr = prev p.appendStmt(&ir.BlockStmt{Blocks: []*ir.Block{outer}}) // At this point, target refers to either the full extent of the base and // virtual documents at ref or just the base document at ref. p.ltarget = op(target) return iter() } func (p *Planner) planDot(key *ast.Term, iter planiter) error { source := p.ltarget return p.planTerm(key, func() error { target := p.newLocal() p.appendStmt(&ir.DotStmt{ Source: source, Key: p.ltarget, Target: target, }) p.ltarget = op(target) return iter() }) } type scaniter func(ir.Local) error func (p *Planner) planScan(key *ast.Term, iter scaniter) error { scan := &ir.ScanStmt{ Source: p.ltarget.Value.(ir.Local), Key: p.newLocal(), Value: p.newLocal(), Block: &ir.Block{}, } prev := p.curr p.curr = scan.Block if err := p.planUnifyLocal(op(scan.Key), key, func() error { p.ltarget = op(scan.Value) return iter(scan.Key) }); err != nil { return err } p.curr = prev p.appendStmt(scan) return nil } func (p *Planner) planScanValues(val *ast.Term, iter scaniter) error { scan := &ir.ScanStmt{ Source: p.ltarget.Value.(ir.Local), Key: p.newLocal(), Value: p.newLocal(), Block: &ir.Block{}, } prev := p.curr p.curr = scan.Block if err := p.planUnifyLocal(op(scan.Value), val, func() error { p.ltarget = op(scan.Value) return iter(scan.Value) }); err != nil { return err } p.curr = prev p.appendStmt(scan) return nil } type termsliceiter func([]ir.Operand) error func (p *Planner) planTermSlice(terms []*ast.Term, iter termsliceiter) error { return p.planTermSliceRec(terms, make([]ir.Operand, len(terms)), 0, iter) } func (p *Planner) planTermSliceRec(terms []*ast.Term, locals []ir.Operand, index int, iter termsliceiter) error { if index >= len(terms) { return iter(locals) } return p.planTerm(terms[index], func() error { locals[index] = p.ltarget return p.planTermSliceRec(terms, locals, index+1, iter) }) } func (p *Planner) planExterns() error { p.policy.Static.BuiltinFuncs = make([]*ir.BuiltinFunc, 0, len(p.externs)) for name, decl := range p.externs { p.policy.Static.BuiltinFuncs = append(p.policy.Static.BuiltinFuncs, &ir.BuiltinFunc{Name: name, Decl: decl.Decl}) } sort.Slice(p.policy.Static.BuiltinFuncs, func(i, j int) bool { return p.policy.Static.BuiltinFuncs[i].Name < p.policy.Static.BuiltinFuncs[j].Name }) return nil } func (p *Planner) getStringConst(s string) int { index, ok := p.strings[s] if !ok { index = len(p.policy.Static.Strings) p.policy.Static.Strings = append(p.policy.Static.Strings, &ir.StringConst{ Value: s, }) p.strings[s] = index } return index } func (p *Planner) getFileConst(s string) int { index, ok := p.files[s] if !ok { index = len(p.policy.Static.Files) p.policy.Static.Files = append(p.policy.Static.Files, &ir.StringConst{ Value: s, }) p.files[s] = index } return index } func (p *Planner) appendStmt(s ir.Stmt) { p.appendStmtToBlock(s, p.curr) } func (p *Planner) appendStmtToBlock(s ir.Stmt, b *ir.Block) { if p.loc != nil { str := p.loc.File if str == "" { str = `` } s.SetLocation(p.getFileConst(str), p.loc.Row, p.loc.Col, str, string(p.loc.Text)) } b.Stmts = append(b.Stmts, s) } func (p *Planner) blockWithStmt(s ir.Stmt) *ir.Block { b := &ir.Block{} p.appendStmtToBlock(s, b) return b } func (p *Planner) appendBlock(b *ir.Block) { p.plan.Blocks = append(p.plan.Blocks, b) } func (p *Planner) appendFunc(f *ir.Func) { p.policy.Funcs.Funcs = append(p.policy.Funcs.Funcs, f) } func (p *Planner) newLocal() ir.Local { x := p.lnext p.lnext++ return x } func (p *Planner) newOperand() ir.Operand { return op(p.newLocal()) } func rewrittenVar(vars map[ast.Var]ast.Var, k ast.Var) ast.Var { rw, ok := vars[k] if !ok { return k } return rw } // optimizeLookup returns a set of rulesets and required statements planning // the locals (strings) needed with the used local variables, and the index // into ref's parth that is still to be planned; if the passed ref's vars // allow for optimization using CallDynamicStmt. // // It's possible if all of these conditions hold: // - all vars in ref have been seen // - all ground terms (strings) match some child key on their respective // layer of the ruletrie // - there are no child trees left (only rulesets) if we're done checking // ref // // The last condition is necessary because we don't deal with _which key a // var actually matched_ -- so we don't know which subtree to evaluate // with the results. func (p *Planner) optimizeLookup(t *ruletrie, ref ast.Ref) ([][]*ast.Rule, []ir.Operand, int, bool) { dont := func() ([][]*ast.Rule, []ir.Operand, int, bool) { return nil, nil, 0, false } if t == nil { p.debugf("no optimization of %s: trie is nil", ref) return dont() } nodes := []*ruletrie{t} opt := false var index int // ref[0] is data, ignore outer: for i := 1; i < len(ref); i++ { index = i r := ref[i] var nextNodes []*ruletrie switch r := r.Value.(type) { case ast.Var: // check if it's been "seen" before _, ok := p.vars.Get(r) if !ok { p.debugf("no optimization of %s: ref[%d] is unseen var: %v", ref, i, r) return dont() } opt = true // take all children, they might match for _, node := range nodes { for _, child := range node.Children() { if node := node.Get(child); node != nil { nextNodes = append(nextNodes, node) } } } case ast.String: // take all children that either match or have a var key for _, node := range nodes { if node := node.Get(r); node != nil { nextNodes = append(nextNodes, node) } } default: p.debugf("no optimization of %s: ref[%d] is type %T", ref, i, r) // TODO(sr): think more about this return dont() } nodes = nextNodes // if all nodes have rules() > 0, abort ref check and optimize // NOTE(sr): for a.b[c] = ... and a.b.d = ..., we stop at a.b, as its rules() // will collect the children rules // We keep the "all nodes have 0 children" check since it's cheaper and might // let us break, too. all := 0 for _, node := range nodes { all += node.ChildrenCount() } if all == 0 { p.debugf("ref %s: all nodes have 0 children, break", ref[0:index+1]) break } // NOTE(sr): we only need this check for the penultimate part: // When planning the ref data.pkg.a[input.x][input.y], // We want to capture this situation: // a.b[c] := "x" if c := "c" // a.b.d := "y" // // Not this: // a.b[c] := "x" if c := "c" // a.d := "y" // since the length doesn't add up. Even if input.x was "d", the second // rule (a.d) wouldn't contribute anything to the result, since we cannot // "dot it". if index == len(ref)-2 { for _, node := range nodes { anyNonGround := false for _, r := range node.Rules() { anyNonGround = anyNonGround || !r.Ref().IsGround() } if anyNonGround { p.debugf("ref %s: at least one node has 1+ non-ground ref rules, break", ref[0:index+1]) break outer } } } } var res [][]*ast.Rule // if there hasn't been any var, we're not making things better by // introducing CallDynamicStmt if !opt { p.debugf("no optimization of %s: no vars seen before trie descend encountered no children", ref) return dont() } for _, node := range nodes { // we're done with ref, check if there's only ruleset leaves; collect rules if index == len(ref)-1 { if len(node.Rules()) == 0 && node.ChildrenCount() > 0 { p.debugf("no optimization of %s: unbalanced ruletrie", ref) return dont() } } if rules := node.Rules(); len(rules) > 0 { res = append(res, rules) } } if len(res) == 0 { p.debugf("ref %s: nothing to plan, no rule leaves", ref[0:index+1]) return nil, nil, index, true } var path []ir.Operand // plan generation path = append(path, op(ir.StringIndex(p.getStringConst(fmt.Sprintf("g%d", p.funcs.gen()))))) for i := 1; i <= index; i++ { switch r := ref[i].Value.(type) { case ast.Var: lv, ok := p.vars.GetOp(r) if !ok { p.debugf("no optimization of %s: ref[%d] not a seen var: %v", ref, i, ref[i]) return dont() } path = append(path, lv) case ast.String: path = append(path, op(ir.StringIndex(p.getStringConst(string(r))))) } } return res, path, index, true } func (p *Planner) unseenVars(t *ast.Term) bool { unseen := false // any var unseen? ast.WalkVars(t, func(v ast.Var) bool { if !unseen { _, exists := p.vars.Get(v) if !exists { unseen = true } } return unseen }) return unseen } func (p *Planner) defaultOperands() []ir.Operand { return []ir.Operand{ p.vars.GetOpOrEmpty(ast.InputRootDocument.Value.(ast.Var)), p.vars.GetOpOrEmpty(ast.DefaultRootDocument.Value.(ast.Var)), } } func (p *Planner) isFunction(r ast.Ref) bool { if node := p.rules.Lookup(r); node != nil { return node.Arity() > 0 } return false } func op(v ir.Val) ir.Operand { return ir.Operand{Value: v} }