// Copyright 2016 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file.

package ast

import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/open-policy-agent/opa/metrics"
	"github.com/open-policy-agent/opa/util"
)

// CompileErrorLimitDefault is the default number errors a compiler will allow before
// exiting.
const CompileErrorLimitDefault = 10

// errLimitReached is appended to c.Errors and panicked (then recovered in
// compile()) to abort compilation once the error limit is hit.
var errLimitReached = NewError(CompileErr, nil, "error limit reached")

// Compiler contains the state of a compilation process.
type Compiler struct {

	// Errors contains errors that occurred during the compilation process.
	// If there are one or more errors, the compilation process is considered
	// "failed".
	Errors Errors

	// Modules contains the compiled modules. The compiled modules are the
	// output of the compilation process. If the compilation process failed,
	// there is no guarantee about the state of the modules.
	Modules map[string]*Module

	// ModuleTree organizes the modules into a tree where each node is keyed by
	// an element in the module's package path. E.g., given modules containing
	// the following package directives: "a", "a.b", "a.c", and "a.b", the
	// resulting module tree would be:
	//
	//  root
	//    |
	//    +--- data (no modules)
	//           |
	//           +--- a (1 module)
	//                 |
	//                 +--- b (2 modules)
	//                 |
	//                 +--- c (1 module)
	//
	ModuleTree *ModuleTreeNode

	// RuleTree organizes rules into a tree where each node is keyed by an
	// element in the rule's path. The rule path is the concatenation of the
	// containing package and the stringified rule name. E.g., given the
	// following module:
	//
	//  package ex
	//  p[1] { true }
	//  p[2] { true }
	//  q = true
	//
	//  root
	//    |
	//    +--- data (no rules)
	//           |
	//           +--- ex (no rules)
	//                 |
	//                 +--- p (2 rules)
	//                 |
	//                 +--- q (1 rule)
	RuleTree *TreeNode

	// Graph contains dependencies between rules. An edge (u,v) is added to the
	// graph if rule 'u' refers to the virtual document defined by 'v'.
	Graph *Graph

	// TypeEnv holds type information for values inferred by the compiler.
	TypeEnv *TypeEnv

	// RewrittenVars is a mapping of variables that have been rewritten
	// with the key being the generated name and value being the original.
	RewrittenVars map[Var]Var

	// Unexported fields hold internal compilation state.
	localvargen  *localVarGenerator // generator for fresh local variable names
	moduleLoader ModuleLoader       // optional callback for lazily loading modules
	ruleIndices  *util.HashMap      // rule path (Ref) -> RuleIndex
	stages       []struct {         // ordered compilation stages
		name       string
		metricName string
		f          func()
	}
	maxErrs           int                                  // error limit; <= 0 means unlimited
	sorted            []string                             // list of sorted module names
	pathExists        func([]string) (bool, error)         // optional base-document conflict callback
	after             map[string][]CompilerStageDefinition // caller-registered stages, keyed by stage name
	metrics           metrics.Metrics                      // optional stage timing
	builtins          map[string]*Builtin                  // known built-in functions
	unsafeBuiltinsMap map[string]struct{}                  // built-ins disallowed in policies
}

// CompilerStage defines the interface for stages in the compiler.
type CompilerStage func(*Compiler) *Error

// CompilerStageDefinition defines a compiler stage
type CompilerStageDefinition struct {
	Name       string
	MetricName string
	Stage      CompilerStage
}

// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data.
type QueryContext struct {
	Package *Package
	Imports []*Import
}

// NewQueryContext returns a new QueryContext object.
func NewQueryContext() *QueryContext {
	return &QueryContext{}
}

// WithPackage sets the pkg on qc.
func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
	if qc == nil {
		qc = NewQueryContext()
	}
	qc.Package = pkg
	return qc
}

// WithImports sets the imports on qc.
func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
	if qc == nil {
		qc = NewQueryContext()
	}
	qc.Imports = imports
	return qc
}

// Copy returns a deep copy of qc.
func (qc *QueryContext) Copy() *QueryContext { if qc == nil { return nil } cpy := *qc if cpy.Package != nil { cpy.Package = qc.Package.Copy() } cpy.Imports = make([]*Import, len(qc.Imports)) for i := range qc.Imports { cpy.Imports[i] = qc.Imports[i].Copy() } return &cpy } // QueryCompiler defines the interface for compiling ad-hoc queries. type QueryCompiler interface { // Compile should be called to compile ad-hoc queries. The return value is // the compiled version of the query. Compile(q Body) (Body, error) // TypeEnv returns the type environment built after running type checking // on the query. TypeEnv() *TypeEnv // WithContext sets the QueryContext on the QueryCompiler. Subsequent calls // to Compile will take the QueryContext into account. WithContext(qctx *QueryContext) QueryCompiler // WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not // allow inside of queries. By default the query compiler inherits the // compiler's unsafe built-in functions. This function allows callers to // override that set. If an empty (non-nil) map is provided, all built-ins // are allowed. WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler // WithStageAfter registers a stage to run during query compilation after // the named stage. WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler // RewrittenVars maps generated vars in the compiled query to vars from the // parsed query. For example, given the query "input := 1" the rewritten // query would be "__local0__ = 1". The mapping would then be {__local0__: input}. RewrittenVars() map[Var]Var } // QueryCompilerStage defines the interface for stages in the query compiler. 
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)

// QueryCompilerStageDefinition defines a QueryCompiler stage
type QueryCompilerStageDefinition struct {
	Name       string
	MetricName string
	Stage      QueryCompilerStage
}

// NOTE(review): "Prefex" looks like a typo of "Prefix". Renaming is unsafe
// from this view because other parts of the package may reference the
// constant, so it is left as-is.
const compileStageMetricPrefex = "ast_compile_stage_"

// NewCompiler returns a new empty compiler.
func NewCompiler() *Compiler {

	c := &Compiler{
		Modules:       map[string]*Module{},
		TypeEnv:       NewTypeEnv(),
		RewrittenVars: map[Var]Var{},
		// Rule indices are keyed by rule path (a Ref); equality and hash
		// functions must be supplied because Ref is not a comparable type.
		ruleIndices: util.NewHashMap(func(a, b util.T) bool {
			r1, r2 := a.(Ref), b.(Ref)
			return r1.Equal(r2)
		}, func(x util.T) int {
			return x.(Ref).Hash()
		}),
		maxErrs:           CompileErrorLimitDefault,
		after:             map[string][]CompilerStageDefinition{},
		unsafeBuiltinsMap: map[string]struct{}{},
	}

	c.ModuleTree = NewModuleTree(nil)
	c.RuleTree = NewRuleTree(c.ModuleTree)

	// Initialize the compiler with the statically compiled built-in functions.
	// If the caller customizes the compiler, a copy will be made.
	c.builtins = BuiltinMap
	checker := newTypeChecker()
	c.TypeEnv = checker.checkLanguageBuiltins(nil, c.builtins)

	// Stages run in exactly this order; each mutates the modules in place and
	// reports problems via c.err.
	c.stages = []struct {
		name       string
		metricName string
		f          func()
	}{
		// Reference resolution should run first as it may be used to lazily
		// load additional modules. If any stages run before resolution, they
		// need to be re-run after resolution.
		{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
		// The local variable generator must be initialized after references are
		// resolved and the dynamic module loader has run but before subsequent
		// stages that need to generate variables.
		{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
		{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
		{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
		{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
		{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
		{"SetGraph", "compile_stage_set_graph", c.setGraph},
		{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
		{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
		{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
		{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
		{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
		{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
		{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
		{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
		{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
		{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
		{"CheckTypes", "compile_stage_check_types", c.checkTypes},
		// NOTE(review): "compile_state_..." below looks like a typo of
		// "compile_stage_..." — left unchanged because the string is an
		// externally visible metric key and changing it would alter emitted
		// metrics.
		{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
		{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
	}
	return c
}

// SetErrorLimit sets the number of errors the compiler can encounter before it
// quits. Zero or a negative number indicates no limit.
func (c *Compiler) SetErrorLimit(limit int) *Compiler {
	c.maxErrs = limit
	return c
}

// WithPathConflictsCheck enables base-virtual document conflict
// detection. The compiler will check that rules don't overlap with
// paths that exist as determined by the provided callable.
func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler { c.pathExists = fn return c } // WithStageAfter registers a stage to run during compilation after // the named stage. func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler { c.after[after] = append(c.after[after], stage) return c } // WithMetrics will set a metrics.Metrics and be used for profiling // the Compiler instance. func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler { c.metrics = metrics return c } // WithBuiltins adds a set of custom built-in functions to the compiler. func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler { if len(builtins) == 0 { return c } cpy := make(map[string]*Builtin, len(c.builtins)+len(builtins)) for k, v := range c.builtins { cpy[k] = v } for k, v := range builtins { cpy[k] = v } c.builtins = cpy // Build type env for custom functions and wrap existing one. checker := newTypeChecker() c.TypeEnv = checker.checkLanguageBuiltins(c.TypeEnv, builtins) return c } // WithUnsafeBuiltins will add all built-ins in the map to the "blacklist". func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler { for name := range unsafeBuiltins { c.unsafeBuiltinsMap[name] = struct{}{} } return c } // QueryCompiler returns a new QueryCompiler object. func (c *Compiler) QueryCompiler() QueryCompiler { return newQueryCompiler(c) } // Compile runs the compilation process on the input modules. The compiled // version of the modules and associated data structures are stored on the // compiler. If the compilation process fails for any reason, the compiler will // contain a slice of errors. func (c *Compiler) Compile(modules map[string]*Module) { c.Modules = make(map[string]*Module, len(modules)) for k, v := range modules { c.Modules[k] = v.Copy() c.sorted = append(c.sorted, k) } sort.Strings(c.sorted) c.compile() } // Failed returns true if a compilation error has been encountered. 
func (c *Compiler) Failed() bool {
	return len(c.Errors) > 0
}

// GetArity returns the number of args a function referred to by ref takes. If
// ref refers to built-in function, the built-in declaration is consulted,
// otherwise, the ref is used to perform a ruleset lookup.
func (c *Compiler) GetArity(ref Ref) int {
	if bi := c.builtins[ref.String()]; bi != nil {
		return len(bi.Decl.Args())
	}
	rules := c.GetRulesExact(ref)
	if len(rules) == 0 {
		// Neither a built-in nor a known function rule.
		return -1
	}
	return len(rules[0].Head.Args)
}

// GetRulesExact returns a slice of rules referred to by the reference.
//
// E.g., given the following module:
//
//  package a.b.c
//
//  p[k] = v { ... } # rule1
//  p[k1] = v1 { ... } # rule2
//
// The following calls yield the rules on the right.
//
//  GetRulesExact("data.a.b.c.p")   => [rule1, rule2]
//  GetRulesExact("data.a.b.c.p.x") => nil
//  GetRulesExact("data.a.b.c")     => nil
func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
	node := c.RuleTree

	for _, x := range ref {
		if node = node.Child(x.Value); node == nil {
			return nil
		}
	}

	return extractRules(node.Values)
}

// GetRulesForVirtualDocument returns a slice of rules that produce the virtual
// document referred to by the reference.
//
// E.g., given the following module:
//
//  package a.b.c
//
//  p[k] = v { ... } # rule1
//  p[k1] = v1 { ... } # rule2
//
// The following calls yield the rules on the right.
//
//  GetRulesForVirtualDocument("data.a.b.c.p")   => [rule1, rule2]
//  GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
//  GetRulesForVirtualDocument("data.a.b.c")     => nil
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {

	node := c.RuleTree

	for _, x := range ref {
		if node = node.Child(x.Value); node == nil {
			return nil
		}
		// Stop at the first node that holds rules: deeper ref elements index
		// into the document those rules produce.
		if len(node.Values) > 0 {
			return extractRules(node.Values)
		}
	}

	return extractRules(node.Values)
}

// GetRulesWithPrefix returns a slice of rules that share the prefix ref.
//
// E.g., given the following module:
//
//  package a.b.c
//
//  p[x] = y { ... } # rule1
//  p[k] = v { ... } # rule2
//  q { ... }        # rule3
//
// The following calls yield the rules on the right.
//
//  GetRulesWithPrefix("data.a.b.c.p")   => [rule1, rule2]
//  GetRulesWithPrefix("data.a.b.c.p.a") => nil
//  GetRulesWithPrefix("data.a.b.c")     => [rule1, rule2, rule3]
func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {

	node := c.RuleTree

	for _, x := range ref {
		if node = node.Child(x.Value); node == nil {
			return nil
		}
	}

	// Collect rules from the subtree rooted at node, skipping hidden children.
	var acc func(node *TreeNode)

	acc = func(node *TreeNode) {
		rules = append(rules, extractRules(node.Values)...)
		for _, child := range node.Children {
			if child.Hide {
				continue
			}
			acc(child)
		}
	}

	acc(node)

	return rules
}

// extractRules converts the untyped tree-node values back into *Rule.
func extractRules(s []util.T) (rules []*Rule) {
	for _, r := range s {
		rules = append(rules, r.(*Rule))
	}
	return rules
}

// GetRules returns a slice of rules that are referred to by ref.
//
// E.g., given the following module:
//
//  package a.b.c
//
//  p[x] = y { q[x] = y; ... } # rule1
//  q[x] = y { ... }           # rule2
//
// The following calls yield the rules on the right.
//
//  GetRules("data.a.b.c.p")   => [rule1]
//  GetRules("data.a.b.c.p.x") => [rule1]
//  GetRules("data.a.b.c.q")   => [rule2]
//  GetRules("data.a.b.c")     => [rule1, rule2]
//  GetRules("data.a.b.d")     => nil
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {

	// Union of both lookups, deduplicated via a set of rule pointers.
	set := map[*Rule]struct{}{}

	for _, rule := range c.GetRulesForVirtualDocument(ref) {
		set[rule] = struct{}{}
	}

	for _, rule := range c.GetRulesWithPrefix(ref) {
		set[rule] = struct{}{}
	}

	for rule := range set {
		rules = append(rules, rule)
	}

	return rules
}

// GetRulesDynamic returns a slice of rules that could be referred to by a ref.
// When parts of the ref are statically known, we use that information to narrow
// down which rules the ref could refer to, but in the most general case this
// will be an over-approximation.
//
// E.g., given the following modules:
//
//  package a.b.c
//
//  r1 = 1 # rule1
//
// and:
//
//  package a.d.c
//
//  r2 = 2 # rule2
//
// The following calls yield the rules on the right.
//
//  GetRulesDynamic("data.a[x].c[y]") => [rule1, rule2]
//  GetRulesDynamic("data.a[x].c.r2") => [rule2]
//  GetRulesDynamic("data.a.b[x][y]") => [rule1]
func (c *Compiler) GetRulesDynamic(ref Ref) (rules []*Rule) {
	node := c.RuleTree

	set := map[*Rule]struct{}{}
	var walk func(node *TreeNode, i int)
	walk = func(node *TreeNode, i int) {
		if i >= len(ref) {
			// We've reached the end of the reference and want to collect everything
			// under this "prefix".
			node.DepthFirst(func(descendant *TreeNode) bool {
				insertRules(set, descendant.Values)
				return descendant.Hide
			})
		} else if i == 0 || IsConstant(ref[i].Value) {
			// The head of the ref is always grounded. In case another part of the
			// ref is also grounded, we can lookup the exact child. If it's not found
			// we can immediately return...
			if child := node.Child(ref[i].Value); child == nil {
				return
			} else if len(child.Values) > 0 {
				// If there are any rules at this position, it's what the ref would
				// refer to. We can just append those and stop here.
				insertRules(set, child.Values)
			} else {
				// Otherwise, we continue using the child node.
				walk(child, i+1)
			}
		} else {
			// This part of the ref is a dynamic term. We can't know what it refers
			// to and will just need to try all of the children.
			for _, child := range node.Children {
				if child.Hide {
					continue
				}
				insertRules(set, child.Values)
				walk(child, i+1)
			}
		}
	}

	walk(node, 0)
	for rule := range set {
		rules = append(rules, rule)
	}
	return rules
}

// Utility: add all rule values to the set.
func insertRules(set map[*Rule]struct{}, rules []util.T) {
	for _, rule := range rules {
		set[rule.(*Rule)] = struct{}{}
	}
}

// RuleIndex returns a RuleIndex built for the rule set referred to by path.
// The path must refer to the rule set exactly, i.e., given a rule set at path
// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a
// RuleIndex built for the rule.
func (c *Compiler) RuleIndex(path Ref) RuleIndex {
	r, ok := c.ruleIndices.Get(path)
	if !ok {
		return nil
	}
	return r.(RuleIndex)
}

// ModuleLoader defines the interface that callers can implement to enable lazy
// loading of modules during compilation.
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)

// WithModuleLoader sets f as the ModuleLoader on the compiler.
//
// The compiler will invoke the ModuleLoader after resolving all references in
// the current set of input modules. The ModuleLoader can return a new
// collection of parsed modules that are to be included in the compilation
// process. This process will repeat until the ModuleLoader returns an empty
// collection or an error. If an error is returned, compilation will stop
// immediately.
func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
	c.moduleLoader = f
	return c
}

// buildRuleIndices constructs indices for rules.
func (c *Compiler) buildRuleIndices() {

	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) == 0 {
			return false
		}
		index := newBaseDocEqIndex(func(ref Ref) bool {
			return isVirtual(c.RuleTree, ref.GroundPrefix())
		})
		// Only store the index if building it succeeded for this rule set.
		if rules := extractRules(node.Values); index.Build(rules) {
			c.ruleIndices.Put(rules[0].Path(), index)
		}
		return false
	})

}

// checkRecursion ensures that there are no recursive definitions, i.e., there are
// no cycles in the Graph.
func (c *Compiler) checkRecursion() {
	eq := func(a, b util.T) bool {
		return a.(*Rule) == b.(*Rule)
	}

	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		for _, rule := range node.Values {
			// Check the rule and each of its else branches.
			for node := rule.(*Rule); node != nil; node = node.Else {
				c.checkSelfPath(node.Loc(), eq, node, node)
			}
		}
		return false
	})
}

// checkSelfPath reports a recursion error if a path exists in the dependency
// graph from a back to b (here both are the same rule).
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
	tr := NewGraphTraversal(c.Graph)
	if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
		n := []string{}
		for _, x := range p {
			n = append(n, astNodeToString(x))
		}
		c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> ")))
	}
}

// astNodeToString renders a graph node (currently only *Rule) for error
// messages.
func astNodeToString(x interface{}) string {
	switch x := x.(type) {
	case *Rule:
		return string(x.Head.Name)
	default:
		panic("not reached")
	}
}

// checkRuleConflicts ensures that rules definitions are not in conflict.
func (c *Compiler) checkRuleConflicts() {
	c.RuleTree.DepthFirst(func(node *TreeNode) bool {
		if len(node.Values) == 0 {
			return false
		}

		// Gather per-rule-set facts: document kinds, function arities, whether
		// any rule used := (declared), and how many defaults exist.
		kinds := map[DocKind]struct{}{}
		defaultRules := 0
		arities := map[int]struct{}{}
		declared := false

		for _, rule := range node.Values {
			r := rule.(*Rule)
			kinds[r.Head.DocKind()] = struct{}{}
			arities[len(r.Head.Args)] = struct{}{}
			if r.Head.Assign {
				declared = true
			}
			if r.Default {
				defaultRules++
			}
		}

		name := Var(node.Key.(String))

		if declared && len(node.Values) > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc()))
		} else if len(kinds) > 1 || len(arities) > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
		} else if defaultRules > 1 {
			c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
		}

		return false
	})

	// Optional base-document conflict detection (see WithPathConflictsCheck).
	if c.pathExists != nil {
		for _, err := range CheckPathConflicts(c, c.pathExists) {
			c.err(err)
		}
	}

	// A rule conflicts with any package nested below it (e.g., rule p in
	// package a.b vs. package a.b.p).
	c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
		for _, mod := range node.Modules {
			for _, rule := range mod.Rules {
				if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
					for _, childMod := range childNode.Modules {
						msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc())
						c.err(NewError(TypeErr, mod.Package.Loc(), msg))
					}
				}
			}
		}
		return false
	})
}

// checkUndefinedFuncs checks every module for calls to functions that are
// neither built-ins nor defined rules.
func (c *Compiler) checkUndefinedFuncs() {
	for _, name := range c.sorted {
		m := c.Modules[name]
		for _, err := range checkUndefinedFuncs(m, c.GetArity) {
			c.err(err)
		}
	}
}

// checkUndefinedFuncs walks x and reports an error for each call expression
// whose operator has no known arity (arity returns a negative value).
func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {

	var errs Errors

	WalkExprs(x, func(expr *Expr) bool {
		if !expr.IsCall() {
			return false
		}
		ref := expr.Operator()
		if arity(ref) >= 0 {
			return false
		}
		errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
		return true
	})

	return errs
}

// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
// positions of built-in expressions will be bound when evaluating the rule from left
// to right, re-ordering as necessary.
func (c *Compiler) checkSafetyRuleBodies() {
	for _, name := range c.sorted {
		m := c.Modules[name]
		WalkRules(m, func(r *Rule) bool {
			safe := ReservedVars.Copy()
			// Function arguments are bound by the caller and therefore safe.
			safe.Update(r.Head.Args.Vars())
			r.Body = c.checkBodySafety(safe, m, r.Body)
			return false
		})
	}
}

// checkBodySafety reorders b for safety; on failure it records the errors and
// returns the body unchanged.
func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body) Body {
	reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)
	if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
		for _, err := range errs {
			c.err(err)
		}
		return b
	}
	return reordered
}

// safetyCheckVarVisitorParams configures the var visitor used for head-safety
// checking: closures and ref/call heads are handled separately.
var safetyCheckVarVisitorParams = VarVisitorParams{
	SkipRefCallHead: true,
	SkipClosures:    true,
}

// checkSafetyRuleHeads ensures that variables appearing in the head of a
// rule also appear in the body.
func (c *Compiler) checkSafetyRuleHeads() { for _, name := range c.sorted { m := c.Modules[name] WalkRules(m, func(r *Rule) bool { safe := r.Body.Vars(safetyCheckVarVisitorParams) safe.Update(r.Head.Args.Vars()) unsafe := r.Head.Vars().Diff(safe) for v := range unsafe { if !v.IsGenerated() { c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v)) } } return false }) } } // checkTypes runs the type checker on all rules. The type checker builds a // TypeEnv that is stored on the compiler. func (c *Compiler) checkTypes() { // Recursion is caught in earlier step, so this cannot fail. sorted, _ := c.Graph.Sort() checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars)) env, errs := checker.CheckTypes(c.TypeEnv, sorted) for _, err := range errs { c.err(err) } c.TypeEnv = env } func (c *Compiler) checkUnsafeBuiltins() { for _, name := range c.sorted { errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name]) for _, err := range errs { c.err(err) } } } func (c *Compiler) runStage(metricName string, f func()) { if c.metrics != nil { c.metrics.Timer(metricName).Start() defer c.metrics.Timer(metricName).Stop() } f() } func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error { if c.metrics != nil { c.metrics.Timer(metricName).Start() defer c.metrics.Timer(metricName).Stop() } return s(c) } func (c *Compiler) compile() { defer func() { if r := recover(); r != nil && r != errLimitReached { panic(r) } }() for _, s := range c.stages { c.runStage(s.metricName, s.f) if c.Failed() { return } for _, s := range c.after[s.name] { err := c.runStageAfter(s.MetricName, s.Stage) if err != nil { c.err(err) } } } } func (c *Compiler) err(err *Error) { if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs { c.Errors = append(c.Errors, errLimitReached) panic(errLimitReached) } c.Errors = append(c.Errors, err) } func (c *Compiler) getExports() *util.HashMap { rules := util.NewHashMap(func(a, b util.T) bool { r1 := a.(Ref) r2 := a.(Ref) return 
r1.Equal(r2) }, func(v util.T) int { return v.(Ref).Hash() }) for _, name := range c.sorted { mod := c.Modules[name] rv, ok := rules.Get(mod.Package.Path) if !ok { rv = []Var{} } rvs := rv.([]Var) for _, rule := range mod.Rules { rvs = append(rvs, rule.Head.Name) } rules.Put(mod.Package.Path, rvs) } return rules } // resolveAllRefs resolves references in expressions to their fully qualified values. // // For instance, given the following module: // // package a.b // import data.foo.bar // p[x] { bar[_] = x } // // The reference "bar[_]" would be resolved to "data.foo.bar[_]". func (c *Compiler) resolveAllRefs() { rules := c.getExports() for _, name := range c.sorted { mod := c.Modules[name] var ruleExports []Var if x, ok := rules.Get(mod.Package.Path); ok { ruleExports = x.([]Var) } globals := getGlobals(mod.Package, ruleExports, mod.Imports) WalkRules(mod, func(rule *Rule) bool { err := resolveRefsInRule(globals, rule) if err != nil { c.err(NewError(CompileErr, rule.Location, err.Error())) } return false }) // Once imports have been resolved, they are no longer needed. 
mod.Imports = nil } if c.moduleLoader != nil { parsed, err := c.moduleLoader(c.Modules) if err != nil { c.err(NewError(CompileErr, nil, err.Error())) return } if len(parsed) == 0 { return } for id, module := range parsed { c.Modules[id] = module.Copy() c.sorted = append(c.sorted, id) } sort.Strings(c.sorted) c.resolveAllRefs() } } func (c *Compiler) initLocalVarGen() { c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules) } func (c *Compiler) rewriteComprehensionTerms() { f := newEqualityFactory(c.localvargen) for _, name := range c.sorted { mod := c.Modules[name] rewriteComprehensionTerms(f, mod) } } func (c *Compiler) rewriteExprTerms() { for _, name := range c.sorted { mod := c.Modules[name] WalkRules(mod, func(rule *Rule) bool { rewriteExprTermsInHead(c.localvargen, rule) rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body) return false }) } } // rewriteTermsInHead will rewrite rules so that the head does not contain any // terms that require evaluation (e.g., refs or comprehensions). If the key or // value contains or more of these terms, the key or value will be moved into // the body and assigned to a new variable. The new variable will replace the // key or value in the head. 
//
// For instance, given the following rule:
//
//  p[{"foo": data.foo[i]}] { i < 100 }
//
// The rule would be re-written as:
//
//  p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
func (c *Compiler) rewriteRefsInHead() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		mod := c.Modules[name]
		WalkRules(mod, func(rule *Rule) bool {
			if requiresEval(rule.Head.Key) {
				expr := f.Generate(rule.Head.Key)
				rule.Head.Key = expr.Operand(0)
				rule.Body.Append(expr)
			}
			if requiresEval(rule.Head.Value) {
				expr := f.Generate(rule.Head.Value)
				rule.Head.Value = expr.Operand(0)
				rule.Body.Append(expr)
			}
			for i := 0; i < len(rule.Head.Args); i++ {
				if requiresEval(rule.Head.Args[i]) {
					expr := f.Generate(rule.Head.Args[i])
					rule.Head.Args[i] = expr.Operand(0)
					rule.Body.Append(expr)
				}
			}
			return false
		})
	}
}

// rewriteEquals rewrites == expressions in all modules.
func (c *Compiler) rewriteEquals() {
	for _, name := range c.sorted {
		mod := c.Modules[name]
		rewriteEquals(mod)
	}
}

// rewriteDynamicTerms rewrites dynamic terms in rule bodies so they are
// evaluated via generated equality expressions.
func (c *Compiler) rewriteDynamicTerms() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		mod := c.Modules[name]
		WalkRules(mod, func(rule *Rule) bool {
			rule.Body = rewriteDynamics(f, rule.Body)
			return false
		})
	}
}

// rewriteLocalVars rewrites declared (:=) variables in all rules to generated
// names (__localN__), recording the mapping in c.RewrittenVars.
func (c *Compiler) rewriteLocalVars() {

	for _, name := range c.sorted {
		mod := c.Modules[name]
		gen := c.localvargen

		WalkRules(mod, func(rule *Rule) bool {

			var errs Errors

			// Rewrite assignments contained in head of rule. Assignments can
			// occur in rule head if they're inside a comprehension. Note,
			// assigned vars in comprehensions in the head will be rewritten
			// first to preserve scoping rules. For example:
			//
			// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
			//
			// This behaviour is consistent scoping inside the body. For example:
			//
			// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
			WalkTerms(rule.Head, func(term *Term) bool {
				stop := false
				stack := newLocalDeclaredVars()
				switch v := term.Value.(type) {
				case *ArrayComprehension:
					errs = rewriteDeclaredVarsInArrayComprehension(gen, stack, v, errs)
					stop = true
				case *SetComprehension:
					errs = rewriteDeclaredVarsInSetComprehension(gen, stack, v, errs)
					stop = true
				case *ObjectComprehension:
					errs = rewriteDeclaredVarsInObjectComprehension(gen, stack, v, errs)
					stop = true
				}
				for k, v := range stack.rewritten {
					c.RewrittenVars[k] = v
				}
				return stop
			})

			for _, err := range errs {
				c.err(err)
			}

			// Rewrite assignments in body.
			used := NewVarSet()

			if rule.Head.Key != nil {
				used.Update(rule.Head.Key.Vars())
			}

			if rule.Head.Value != nil {
				used.Update(rule.Head.Value.Vars())
			}

			stack := newLocalDeclaredVars()

			c.rewriteLocalArgVars(gen, stack, rule)

			body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body)
			for _, err := range errs {
				c.err(err)
			}

			// For rewritten vars use the collection of all variables that
			// were in the stack at some point in time.
			for k, v := range stack.rewritten {
				c.RewrittenVars[k] = v
			}

			rule.Body = body

			// Rewrite vars in head that refer to locally declared vars in the body.
			vis := NewGenericVisitor(func(x interface{}) bool {

				term, ok := x.(*Term)
				if !ok {
					return false
				}

				switch v := term.Value.(type) {
				case Object:
					// Make a copy of the object because the keys may be mutated.
					cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
						if vark, ok := k.Value.(Var); ok {
							if gv, ok := declared[vark]; ok {
								k = k.Copy()
								k.Value = gv
							}
						}
						return k, v, nil
					})
					term.Value = cpy
				case Var:
					if gv, ok := declared[v]; ok {
						term.Value = gv
						return true
					}
				}

				return false
			})

			vis.Walk(rule.Head.Args)

			if rule.Head.Key != nil {
				vis.Walk(rule.Head.Key)
			}

			if rule.Head.Value != nil {
				vis.Walk(rule.Head.Value)
			}

			return false
		})
	}
}

// rewriteLocalArgVars rewrites variables appearing in the rule's argument list
// to generated names, collecting any errors on the compiler.
func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {

	vis := &ruleArgLocalRewriter{
		stack: stack,
		gen:   gen,
	}

	for i := range rule.Head.Args {
		Walk(vis, rule.Head.Args[i])
	}

	for i := range vis.errs {
		c.err(vis.errs[i])
	}
}

// ruleArgLocalRewriter is a visitor that rewrites vars inside rule arguments.
type ruleArgLocalRewriter struct {
	stack *localDeclaredVars
	gen   *localVarGenerator
	errs  []*Error
}

func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {

	t, ok := x.(*Term)
	if !ok {
		return vis
	}

	switch v := t.Value.(type) {
	case Var:
		gv, ok := vis.stack.Declared(v)
		if !ok {
			gv = vis.gen.Generate()
			vis.stack.Insert(v, gv, argVar)
		}
		t.Value = gv
		return nil
	case Object:
		if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
			vcpy := v.Copy()
			Walk(vis, vcpy)
			return k, vcpy, nil
		}); err != nil {
			vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
		} else {
			t.Value = cpy
		}
		return nil
	case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
		// Scalars are no-ops. Comprehensions are handled above. Sets must not
		// contain variables.
		return nil
	default:
		// Recurse on refs, arrays, and calls. Any embedded
		// variables can be rewritten.
		return vis
	}
}

// rewriteWithModifiers rewrites `with` modifier values in every body of every
// module.
func (c *Compiler) rewriteWithModifiers() {
	f := newEqualityFactory(c.localvargen)
	for _, name := range c.sorted {
		mod := c.Modules[name]
		t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
			body, ok := x.(Body)
			if !ok {
				return x, nil
			}
			body, err := rewriteWithModifiersInBody(c, f, body)
			if err != nil {
				c.err(err)
			}
			return body, nil
		})
		Transform(t, mod)
	}
}

// setModuleTree rebuilds the module tree from the current module set.
func (c *Compiler) setModuleTree() {
	c.ModuleTree = NewModuleTree(c.Modules)
}

// setRuleTree rebuilds the rule tree from the module tree.
func (c *Compiler) setRuleTree() {
	c.RuleTree = NewRuleTree(c.ModuleTree)
}

// setGraph rebuilds the rule dependency graph.
func (c *Compiler) setGraph() {
	c.Graph = NewGraph(c.Modules, c.GetRulesDynamic)
}

// queryCompiler is the concrete implementation of QueryCompiler. It shares
// state with the Compiler that produced it.
type queryCompiler struct {
	compiler       *Compiler
	qctx           *QueryContext
	typeEnv        *TypeEnv
	rewritten      map[Var]Var
	after          map[string][]QueryCompilerStageDefinition
	unsafeBuiltins map[string]struct{}
}

// newQueryCompiler returns a QueryCompiler bound to compiler.
func newQueryCompiler(compiler *Compiler) QueryCompiler {
	qc := &queryCompiler{
		compiler: compiler,
		qctx:     nil,
		after:    map[string][]QueryCompilerStageDefinition{},
	}
	return qc
}

func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
	qc.qctx = qctx
	return qc
}

func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
	qc.after[after] = append(qc.after[after], stage)
	return qc
}

func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
	qc.unsafeBuiltins = unsafe
	return qc
}

func (qc *queryCompiler) RewrittenVars() map[Var]Var {
	return qc.rewritten
}

// runStage executes a built-in query compilation stage, timing it when
// metrics are enabled on the underlying compiler.
func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
	if qc.compiler.metrics != nil {
		qc.compiler.metrics.Timer(metricName).Start()
		defer qc.compiler.metrics.Timer(metricName).Stop()
	}
	return s(qctx, query)
}

// runStageAfter executes a caller-registered query compilation stage.
func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
	if qc.compiler.metrics != nil {
		qc.compiler.metrics.Timer(metricName).Start()
		defer qc.compiler.metrics.Timer(metricName).Stop()
	}
	return s(qc, query)
}

func (qc *queryCompiler) Compile(query Body) (Body, error) {

	// Work on a copy so the caller's query is never mutated.
	query = query.Copy()

	stages := []struct {
		name       string
		metricName string
		f          func(*QueryContext, Body) (Body, error)
	}{
		{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
		{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
		{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
		{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
		{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
		{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
		{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
		{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
		{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
		{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
	}

	qctx := qc.qctx.Copy()

	for _, s := range stages {
		var err error
		query, err = qc.runStage(s.metricName, qctx, query, s.f)
		if err != nil {
			return nil, qc.applyErrorLimit(err)
		}
		for _, s := range qc.after[s.name] {
			query, err = qc.runStageAfter(s.MetricName, query, s.Stage)
			if err != nil {
				return nil, qc.applyErrorLimit(err)
			}
		}
	}

	return query, nil
}

func (qc *queryCompiler) TypeEnv() *TypeEnv {
	return qc.typeEnv
}

// applyErrorLimit truncates an Errors slice to the compiler's error limit,
// appending the sentinel limit-reached error.
func (qc *queryCompiler) applyErrorLimit(err error) error {
	if errs, ok := err.(Errors); ok {
		if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs {
			err = append(errs[:qc.compiler.maxErrs], errLimitReached)
		}
	}
	return err
}

// resolveRefs resolves references in the query using the package and imports
// from the query context (if any).
func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {

	var globals map[Var]Ref

	if qctx != nil && qctx.Package != nil {
		var ruleExports []Var
		rules := qc.compiler.getExports()
		if exist, ok := rules.Get(qctx.Package.Path); ok {
			ruleExports = exist.([]Var)
		}

		globals = getGlobals(qctx.Package, ruleExports, qc.qctx.Imports)
		// Imports are no longer needed once they are resolved into globals.
		qctx.Imports = nil
	}

	ignore := &declaredVarStack{declaredVars(body)}

	return resolveRefsInBody(globals, ignore, body), nil
}

// rewriteComprehensionTerms rewrites comprehension terms in the query body.
func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	f := newEqualityFactory(gen)
	node, err := rewriteComprehensionTerms(f, body)
	if err != nil {
		return nil, err
	}
	return node.(Body), nil
}

// rewriteDynamicTerms rewrites dynamic terms in the query body.
func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	f := newEqualityFactory(gen)
	return rewriteDynamics(f, body), nil
}

// rewriteExprTerms rewrites expression terms in the query body.
func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	return rewriteExprTermsInBody(gen, body), nil
}

// rewriteLocalVars rewrites declared (:=) variables in the query to generated
// names and records the mapping for RewrittenVars().
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
	gen := newLocalVarGenerator("q", body)
	stack := newLocalDeclaredVars()
	body, _, err := rewriteLocalVars(gen, stack, nil, body)
	if len(err) != 0 {
		return nil, err
	}
	qc.rewritten = make(map[Var]Var, len(stack.rewritten))
	for k, v := range stack.rewritten {
		// The vars returned during the rewrite will include all seen vars,
		// even if they're not declared with an assignment operation. We don't
		// want to include these inside the rewritten set though.
qc.rewritten[k] = v } return body, nil } func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) { if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 { return nil, errs } return body, nil } func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) { safe := ReservedVars.Copy() reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body) if errs := safetyErrorSlice(unsafe); len(errs) > 0 { return nil, errs } return reordered, nil } func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) { var errs Errors checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars)) qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body) if len(errs) > 0 { return nil, errs } return body, nil } func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) { var unsafe map[string]struct{} if qc.unsafeBuiltins != nil { unsafe = qc.unsafeBuiltins } else { unsafe = qc.compiler.unsafeBuiltinsMap } errs := checkUnsafeBuiltins(unsafe, body) if len(errs) > 0 { return nil, errs } return body, nil } func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) { f := newEqualityFactory(newLocalVarGenerator("q", body)) body, err := rewriteWithModifiersInBody(qc.compiler, f, body) if err != nil { return nil, Errors{err} } return body, nil } // ModuleTreeNode represents a node in the module tree. The module // tree is keyed by the package path. type ModuleTreeNode struct { Key Value Modules []*Module Children map[Value]*ModuleTreeNode Hide bool } // NewModuleTree returns a new ModuleTreeNode that represents the root // of the module tree populated with the given modules. 
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { root := &ModuleTreeNode{ Children: map[Value]*ModuleTreeNode{}, } for _, m := range mods { node := root for i, x := range m.Package.Path { c, ok := node.Children[x.Value] if !ok { var hide bool if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 { hide = true } c = &ModuleTreeNode{ Key: x.Value, Children: map[Value]*ModuleTreeNode{}, Hide: hide, } node.Children[x.Value] = c } node = c } node.Modules = append(node.Modules, m) } return root } // Size returns the number of modules in the tree. func (n *ModuleTreeNode) Size() int { s := len(n.Modules) for _, c := range n.Children { s += c.Size() } return s } // DepthFirst performs a depth-first traversal of the module tree rooted at n. // If f returns true, traversal will not continue to the children of n. func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) { if !f(n) { for _, node := range n.Children { node.DepthFirst(f) } } } // TreeNode represents a node in the rule tree. The rule tree is keyed by // rule path. type TreeNode struct { Key Value Values []util.T Children map[Value]*TreeNode Hide bool } // NewRuleTree returns a new TreeNode that represents the root // of the rule tree populated with the given rules. func NewRuleTree(mtree *ModuleTreeNode) *TreeNode { ruleSets := map[String][]util.T{} // Build rule sets for this package. for _, mod := range mtree.Modules { for _, rule := range mod.Rules { key := String(rule.Head.Name) ruleSets[key] = append(ruleSets[key], rule) } } // Each rule set becomes a leaf node. children := map[Value]*TreeNode{} for key, rules := range ruleSets { children[key] = &TreeNode{ Key: key, Children: nil, Values: rules, } } // Each module in subpackage becomes child node. for _, child := range mtree.Children { children[child.Key] = NewRuleTree(child) } return &TreeNode{ Key: mtree.Key, Values: nil, Children: children, Hide: mtree.Hide, } } // Size returns the number of rules in the tree. 
func (n *TreeNode) Size() int {
	s := len(n.Values)
	for _, c := range n.Children {
		s += c.Size()
	}
	return s
}

// Child returns n's child with key k. Only String and Var keys are valid; any
// other key type returns nil.
func (n *TreeNode) Child(k Value) *TreeNode {
	switch k.(type) {
	case String, Var:
		return n.Children[k]
	}
	return nil
}

// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
// f returns true, traversal will not continue to the children of n.
func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
	if !f(n) {
		for _, node := range n.Children {
			node.DepthFirst(f)
		}
	}
}

// Graph represents the graph of dependencies between rules.
type Graph struct {
	adj    map[util.T]map[util.T]struct{} // adjacency lists: rule -> rules it depends on
	nodes  map[util.T]struct{}            // all rules in the graph
	sorted []util.T                       // cached topological order (see Sort)
}

// NewGraph returns a new Graph based on modules. The list function must return
// the rules referred to directly by the ref.
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {
	graph := &Graph{
		adj:    map[util.T]map[util.T]struct{}{},
		nodes:  map[util.T]struct{}{},
		sorted: nil,
	}
	// Create visitor to walk a rule AST and add edges to the rule graph for
	// each dependency.
	vis := func(a *Rule) *GenericVisitor {
		stop := false
		return NewGenericVisitor(func(x interface{}) bool {
			switch x := x.(type) {
			case Ref:
				// An edge is added from a to every rule (including its else
				// chain) that the ref resolves to.
				for _, b := range list(x) {
					for node := b; node != nil; node = node.Else {
						graph.addDependency(a, node)
					}
				}
			case *Rule:
				if stop {
					// Do not recurse into else clauses (which will be handled
					// by the outer visitor.)
					return true
				}
				// First *Rule seen is a itself; keep walking into it.
				stop = true
			}
			return false
		})
	}
	// Walk over all rules, add them to graph, and build adjacency lists.
	for _, module := range modules {
		WalkRules(module, func(a *Rule) bool {
			graph.addNode(a)
			vis(a).Walk(a)
			return false
		})
	}
	return graph
}

// Dependencies returns the set of rules that x depends on.
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
	return g.adj[x]
}

// Sort returns a slice of rules sorted by dependencies. If a cycle is found,
// ok is set to false.
func (g *Graph) Sort() (sorted []util.T, ok bool) {
	// Return the cached result if the graph was already sorted.
	if g.sorted != nil {
		return g.sorted, true
	}
	sort := &graphSort{
		sorted: make([]util.T, 0, len(g.nodes)),
		deps:   g.Dependencies,
		marked: map[util.T]struct{}{},
		temp:   map[util.T]struct{}{},
	}
	for node := range g.nodes {
		if !sort.Visit(node) {
			// Cycle detected.
			return nil, false
		}
	}
	g.sorted = sort.sorted
	return g.sorted, true
}

// addDependency adds an edge u -> v, inserting both nodes if needed.
func (g *Graph) addDependency(u util.T, v util.T) {
	if _, ok := g.nodes[u]; !ok {
		g.addNode(u)
	}
	if _, ok := g.nodes[v]; !ok {
		g.addNode(v)
	}
	edges, ok := g.adj[u]
	if !ok {
		edges = map[util.T]struct{}{}
		g.adj[u] = edges
	}
	edges[v] = struct{}{}
}

func (g *Graph) addNode(n util.T) {
	g.nodes[n] = struct{}{}
}

// graphSort implements a depth-first topological sort with cycle detection.
type graphSort struct {
	sorted []util.T                       // accumulated post-order (dependencies first)
	deps   func(util.T) map[util.T]struct{}
	marked map[util.T]struct{} // permanently visited nodes
	temp   map[util.T]struct{} // nodes on the current DFS path (cycle check)
}

func (sort *graphSort) Marked(node util.T) bool {
	_, marked := sort.marked[node]
	return marked
}

// Visit performs the DFS from node; ok is false if a cycle was found.
func (sort *graphSort) Visit(node util.T) (ok bool) {
	if _, ok := sort.temp[node]; ok {
		// node is already on the current path: cycle.
		return false
	}
	if sort.Marked(node) {
		return true
	}
	sort.temp[node] = struct{}{}
	for other := range sort.deps(node) {
		if !sort.Visit(other) {
			return false
		}
	}
	sort.marked[node] = struct{}{}
	delete(sort.temp, node)
	sort.sorted = append(sort.sorted, node)
	return true
}

// GraphTraversal is a Traversal that understands the dependency graph
type GraphTraversal struct {
	graph   *Graph
	visited map[util.T]struct{}
}

// NewGraphTraversal returns a Traversal for the dependency graph
func NewGraphTraversal(graph *Graph) *GraphTraversal {
	return &GraphTraversal{
		graph:   graph,
		visited: map[util.T]struct{}{},
	}
}

// Edges lists all dependency connections for a given node
func (g *GraphTraversal) Edges(x util.T) []util.T {
	r := []util.T{}
	for v := range g.graph.Dependencies(x) {
		r = append(r, v)
	}
	return r
}

// Visited returns whether a node has been visited, setting a node to visited if not
func (g *GraphTraversal) Visited(u util.T) bool {
	_, ok := g.visited[u]
	g.visited[u] = struct{}{}
	return ok
}

// unsafePair pairs an expression with the vars that are unsafe in it.
type unsafePair struct {
	Expr *Expr
	Vars VarSet
}

// unsafeVarLoc pairs an unsafe var with the location it first appears at.
type unsafeVarLoc struct {
	Var Var
	Loc *Location
}

// unsafeVars maps expressions to the vars that are unsafe within them.
type unsafeVars map[*Expr]VarSet

func (vs unsafeVars) Add(e *Expr, v Var) {
	if u, ok := vs[e]; ok {
		u[v] = struct{}{}
	} else {
		vs[e] = VarSet{v: struct{}{}}
	}
}

func (vs unsafeVars) Set(e *Expr, s VarSet) {
	vs[e] = s
}

// Update merges o into vs.
func (vs unsafeVars) Update(o unsafeVars) {
	for k, v := range o {
		if _, ok := vs[k]; !ok {
			vs[k] = VarSet{}
		}
		vs[k].Update(v)
	}
}

// Vars flattens vs into a slice of var/location pairs sorted by location.
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {
	locs := map[Var]*Location{}

	// If var appears in multiple sets then pick first by location.
	// NOTE(review): this relies on (*Location).Compare treating a nil
	// receiver as greater than a non-nil location so that the first
	// occurrence is recorded — confirm against Location.Compare.
	for expr, vars := range vs {
		for v := range vars {
			if locs[v].Compare(expr.Location) > 0 {
				locs[v] = expr.Location
			}
		}
	}

	for v, loc := range locs {
		result = append(result, unsafeVarLoc{
			Var: v,
			Loc: loc,
		})
	}

	sort.Slice(result, func(i, j int) bool {
		return result[i].Loc.Compare(result[j].Loc) < 0
	})

	return result
}

// Slice flattens vs into a slice of expression/var-set pairs (unordered).
func (vs unsafeVars) Slice() (result []unsafePair) {
	for expr, vs := range vs {
		result = append(result, unsafePair{
			Expr: expr,
			Vars: vs,
		})
	}
	return
}

// reorderBodyForSafety returns a copy of the body ordered such that
// left to right evaluation of the body will not encounter unbound variables
// in input positions or negated expressions.
//
// Expressions are added to the re-ordered body as soon as they are considered
// safe. If multiple expressions become safe in the same pass, they are added
// in their original order. This results in minimal re-ordering of the body.
//
// If the body cannot be reordered to ensure safety, the second return value
// contains a mapping of expressions to unsafe variables in those expressions.
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { body, unsafe := reorderBodyForClosures(builtins, arity, globals, body) if len(unsafe) != 0 { return nil, unsafe } reordered := Body{} safe := VarSet{} for _, e := range body { for v := range e.Vars(safetyCheckVarVisitorParams) { if globals.Contains(v) { safe.Add(v) } else { unsafe.Add(e, v) } } } for { n := len(reordered) for _, e := range body { if reordered.Contains(e) { continue } safe.Update(outputVarsForExpr(e, builtins, arity, safe)) for v := range unsafe[e] { if safe.Contains(v) { delete(unsafe[e], v) } } if len(unsafe[e]) == 0 { delete(unsafe, e) reordered = append(reordered, e) } } if len(reordered) == n { break } } // Recursively visit closures and perform the safety checks on them. // Update the globals at each expression to include the variables that could // be closed over. g := globals.Copy() for i, e := range reordered { if i > 0 { g.Update(reordered[i-1].Vars(safetyCheckVarVisitorParams)) } vis := &bodySafetyVisitor{ builtins: builtins, arity: arity, current: e, globals: g, unsafe: unsafe, } NewGenericVisitor(vis.Visit).Walk(e) } // Need to reset expression indices as re-ordering may have // changed them. 
setExprIndices(reordered) return reordered, unsafe } type bodySafetyVisitor struct { builtins map[string]*Builtin arity func(Ref) int current *Expr globals VarSet unsafe unsafeVars } func (vis *bodySafetyVisitor) Visit(x interface{}) bool { switch x := x.(type) { case *Expr: cpy := *vis cpy.current = x switch ts := x.Terms.(type) { case *SomeDecl: NewGenericVisitor(cpy.Visit).Walk(ts) case []*Term: for _, t := range ts { NewGenericVisitor(cpy.Visit).Walk(t) } case *Term: NewGenericVisitor(cpy.Visit).Walk(ts) } for i := range x.With { NewGenericVisitor(cpy.Visit).Walk(x.With[i]) } return true case *ArrayComprehension: vis.checkArrayComprehensionSafety(x) return true case *ObjectComprehension: vis.checkObjectComprehensionSafety(x) return true case *SetComprehension: vis.checkSetComprehensionSafety(x) return true } return false } // Check term for safety. This is analogous to the rule head safety check. func (vis *bodySafetyVisitor) checkComprehensionSafety(tv VarSet, body Body) Body { bv := body.Vars(safetyCheckVarVisitorParams) bv.Update(vis.globals) uv := tv.Diff(bv) for v := range uv { vis.unsafe.Add(vis.current, v) } // Check body for safety, reordering as necessary. 
r, u := reorderBodyForSafety(vis.builtins, vis.arity, vis.globals, body) if len(u) == 0 { return r } vis.unsafe.Update(u) return body } func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) { ac.Body = vis.checkComprehensionSafety(ac.Term.Vars(), ac.Body) } func (vis *bodySafetyVisitor) checkObjectComprehensionSafety(oc *ObjectComprehension) { tv := oc.Key.Vars() tv.Update(oc.Value.Vars()) oc.Body = vis.checkComprehensionSafety(tv, oc.Body) } func (vis *bodySafetyVisitor) checkSetComprehensionSafety(sc *SetComprehension) { sc.Body = vis.checkComprehensionSafety(sc.Term.Vars(), sc.Body) } // reorderBodyForClosures returns a copy of the body ordered such that // expressions (such as array comprehensions) that close over variables are ordered // after other expressions that contain the same variable in an output position. func reorderBodyForClosures(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) { reordered := Body{} unsafe := unsafeVars{} for { n := len(reordered) for _, e := range body { if reordered.Contains(e) { continue } // Collect vars that are contained in closures within this // expression. vs := VarSet{} WalkClosures(e, func(x interface{}) bool { vis := &VarVisitor{vars: vs} vis.Walk(x) return true }) // Compute vars that are closed over from the body but not yet // contained in the output position of an expression in the reordered // body. These vars are considered unsafe. 
cv := vs.Intersect(body.Vars(safetyCheckVarVisitorParams)).Diff(globals) uv := cv.Diff(outputVarsForBody(reordered, builtins, arity, globals)) if len(uv) == 0 { reordered = append(reordered, e) delete(unsafe, e) } else { unsafe.Set(e, uv) } } if len(reordered) == n { break } } return reordered, unsafe } func outputVarsForBody(body Body, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet { o := safe.Copy() for _, e := range body { o.Update(outputVarsForExpr(e, builtins, arity, o)) } return o.Diff(safe) } func outputVarsForExpr(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet { // Negated expressions must be safe. if expr.Negated { return VarSet{} } // With modifier inputs must be safe. for _, with := range expr.With { unsafe := false WalkVars(with, func(v Var) bool { if !safe.Contains(v) { unsafe = true return true } return false }) if unsafe { return VarSet{} } } if !expr.IsCall() { return outputVarsForExprRefs(expr, safe) } terms := expr.Terms.([]*Term) name := terms[0].String() if b := builtins[name]; b != nil { if b.Name == Equality.Name { return outputVarsForExprEq(expr, safe) } return outputVarsForExprBuiltin(expr, b, safe) } return outputVarsForExprCall(expr, builtins, arity, safe, terms) } func outputVarsForExprBuiltin(expr *Expr, b *Builtin, safe VarSet) VarSet { output := outputVarsForExprRefs(expr, safe) terms := expr.Terms.([]*Term) // Check that all input terms are safe. for i, t := range terms[1:] { if b.IsTargetPos(i) { continue } vis := NewVarVisitor().WithParams(VarVisitorParams{ SkipClosures: true, SkipSets: true, SkipObjectKeys: true, SkipRefHead: true, }) vis.Walk(t) unsafe := vis.Vars().Diff(output).Diff(safe) if len(unsafe) > 0 { return VarSet{} } } // Add vars in target positions to result. 
for i, t := range terms[1:] { if b.IsTargetPos(i) { vis := NewVarVisitor().WithParams(VarVisitorParams{ SkipRefHead: true, SkipSets: true, SkipObjectKeys: true, SkipClosures: true, }) vis.Walk(t) output.Update(vis.vars) } } return output } func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet { if !validEqAssignArgCount(expr) { return safe } output := outputVarsForExprRefs(expr, safe) output.Update(safe) output.Update(Unify(output, expr.Operand(0), expr.Operand(1))) return output.Diff(safe) } func outputVarsForExprCall(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet, terms []*Term) VarSet { output := outputVarsForExprRefs(expr, safe) ref, ok := terms[0].Value.(Ref) if !ok { return VarSet{} } numArgs := arity(ref) if numArgs == -1 { return VarSet{} } numInputTerms := numArgs + 1 if numInputTerms >= len(terms) { return output } vis := NewVarVisitor().WithParams(VarVisitorParams{ SkipClosures: true, SkipSets: true, SkipObjectKeys: true, SkipRefHead: true, }) vis.Walk(Args(terms[:numInputTerms])) unsafe := vis.Vars().Diff(output).Diff(safe) if len(unsafe) > 0 { return VarSet{} } vis = NewVarVisitor().WithParams(VarVisitorParams{ SkipRefHead: true, SkipSets: true, SkipObjectKeys: true, SkipClosures: true, }) vis.Walk(Args(terms[numInputTerms:])) output.Update(vis.vars) return output } func outputVarsForExprRefs(expr *Expr, safe VarSet) VarSet { output := VarSet{} WalkRefs(expr, func(r Ref) bool { if safe.Contains(r[0].Value.(Var)) { output.Update(r.OutputVars()) return false } return true }) return output } type equalityFactory struct { gen *localVarGenerator } func newEqualityFactory(gen *localVarGenerator) *equalityFactory { return &equalityFactory{gen} } func (f *equalityFactory) Generate(other *Term) *Expr { term := NewTerm(f.gen.Generate()).SetLocation(other.Location) expr := Equality.Expr(term, other) expr.Generated = true expr.Location = other.Location return expr } type localVarGenerator struct { exclude VarSet suffix string next 
int } func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator { exclude := NewVarSet() vis := &VarVisitor{vars: exclude} for _, key := range sorted { vis.Walk(modules[key]) } return &localVarGenerator{exclude: exclude, next: 0} } func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator { exclude := NewVarSet() vis := &VarVisitor{vars: exclude} vis.Walk(node) return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0} } func (l *localVarGenerator) Generate() Var { for { result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__") l.next++ if !l.exclude.Contains(result) { return result } } } func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]Ref { globals := map[Var]Ref{} // Populate globals with exports within the package. for _, v := range rules { global := append(Ref{}, pkg.Path...) global = append(global, &Term{Value: String(v)}) globals[v] = global } // Populate globals with imports. for _, i := range imports { if len(i.Alias) > 0 { path := i.Path.Value.(Ref) globals[i.Alias] = path } else { path := i.Path.Value.(Ref) if len(path) == 1 { globals[path[0].Value.(Var)] = path } else { v := path[len(path)-1].Value.(String) globals[Var(v)] = path } } } return globals } func requiresEval(x *Term) bool { if x == nil { return false } return ContainsRefs(x) || ContainsComprehensions(x) } func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref { r := Ref{} for i, x := range ref { switch v := x.Value.(type) { case Var: if g, ok := globals[v]; ok && !ignore.Contains(v) { cpy := g.Copy() for i := range cpy { cpy[i].SetLocation(x.Location) } if i == 0 { r = cpy } else { r = append(r, NewTerm(cpy).SetLocation(x.Location)) } } else { r = append(r, x) } case Ref, Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: r = append(r, resolveRefsInTerm(globals, ignore, x)) default: r = append(r, x) } } return r } func 
resolveRefsInRule(globals map[Var]Ref, rule *Rule) error { ignore := &declaredVarStack{} vars := NewVarSet() var vis *GenericVisitor var err error // Walk args to collect vars and transform body so that callers can shadow // root documents. vis = NewGenericVisitor(func(x interface{}) bool { if err != nil { return true } switch x := x.(type) { case Var: vars.Add(x) // Object keys cannot be pattern matched so only walk values. case Object: for _, k := range x.Keys() { vis.Walk(x.Get(k)) } // Skip terms that could contain vars that cannot be pattern matched. case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call: return true case *Term: if _, ok := x.Value.(Ref); ok { if RootDocumentRefs.Contains(x) { // We could support args named input, data, etc. however // this would require rewriting terms in the head and body. // Preventing root document shadowing is simpler, and // arguably, will prevent confusing names from being used. err = fmt.Errorf("args must not shadow %v (use a different variable name)", x) return true } } } return false }) vis.Walk(rule.Head.Args) if err != nil { return err } ignore.Push(vars) ignore.Push(declaredVars(rule.Body)) if rule.Head.Key != nil { rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key) } if rule.Head.Value != nil { rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value) } rule.Body = resolveRefsInBody(globals, ignore, rule.Body) return nil } func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body { r := Body{} for _, expr := range body { r = append(r, resolveRefsInExpr(globals, ignore, expr)) } return r } func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr) *Expr { cpy := *expr switch ts := expr.Terms.(type) { case *Term: cpy.Terms = resolveRefsInTerm(globals, ignore, ts) case []*Term: buf := make([]*Term, len(ts)) for i := 0; i < len(ts); i++ { buf[i] = resolveRefsInTerm(globals, ignore, ts[i]) } cpy.Terms = buf } for _, w 
:= range cpy.With { w.Target = resolveRefsInTerm(globals, ignore, w.Target) w.Value = resolveRefsInTerm(globals, ignore, w.Value) } return &cpy } func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term) *Term { switch v := term.Value.(type) { case Var: if g, ok := globals[v]; ok && !ignore.Contains(v) { cpy := g.Copy() for i := range cpy { cpy[i].SetLocation(term.Location) } return NewTerm(cpy).SetLocation(term.Location) } return term case Ref: fqn := resolveRef(globals, ignore, v) cpy := *term cpy.Value = fqn return &cpy case Object: cpy := *term cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) { k = resolveRefsInTerm(globals, ignore, k) v = resolveRefsInTerm(globals, ignore, v) return k, v, nil }) return &cpy case Array: cpy := *term cpy.Value = Array(resolveRefsInTermSlice(globals, ignore, v)) return &cpy case Call: cpy := *term cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v)) return &cpy case Set: s, _ := v.Map(func(e *Term) (*Term, error) { return resolveRefsInTerm(globals, ignore, e), nil }) cpy := *term cpy.Value = s return &cpy case *ArrayComprehension: ac := &ArrayComprehension{} ignore.Push(declaredVars(v.Body)) ac.Term = resolveRefsInTerm(globals, ignore, v.Term) ac.Body = resolveRefsInBody(globals, ignore, v.Body) cpy := *term cpy.Value = ac ignore.Pop() return &cpy case *ObjectComprehension: oc := &ObjectComprehension{} ignore.Push(declaredVars(v.Body)) oc.Key = resolveRefsInTerm(globals, ignore, v.Key) oc.Value = resolveRefsInTerm(globals, ignore, v.Value) oc.Body = resolveRefsInBody(globals, ignore, v.Body) cpy := *term cpy.Value = oc ignore.Pop() return &cpy case *SetComprehension: sc := &SetComprehension{} ignore.Push(declaredVars(v.Body)) sc.Term = resolveRefsInTerm(globals, ignore, v.Term) sc.Body = resolveRefsInBody(globals, ignore, v.Body) cpy := *term cpy.Value = sc ignore.Pop() return &cpy default: return term } } func resolveRefsInTermSlice(globals map[Var]Ref, ignore *declaredVarStack, terms 
[]*Term) []*Term { cpy := make([]*Term, len(terms)) for i := 0; i < len(terms); i++ { cpy[i] = resolveRefsInTerm(globals, ignore, terms[i]) } return cpy } type declaredVarStack []VarSet func (s declaredVarStack) Contains(v Var) bool { for i := len(s) - 1; i >= 0; i-- { if _, ok := s[i][v]; ok { return ok } } return false } func (s declaredVarStack) Add(v Var) { s[len(s)-1].Add(v) } func (s *declaredVarStack) Push(vs VarSet) { *s = append(*s, vs) } func (s *declaredVarStack) Pop() { curr := *s *s = curr[:len(curr)-1] } func declaredVars(x interface{}) VarSet { vars := NewVarSet() vis := NewGenericVisitor(func(x interface{}) bool { switch x := x.(type) { case *Expr: if x.IsAssignment() && validEqAssignArgCount(x) { WalkVars(x.Operand(0), func(v Var) bool { vars.Add(v) return false }) } else if decl, ok := x.Terms.(*SomeDecl); ok { for i := range decl.Symbols { vars.Add(decl.Symbols[i].Value.(Var)) } } case *ArrayComprehension, *SetComprehension, *ObjectComprehension: return true } return false }) vis.Walk(x) return vars } // rewriteComprehensionTerms will rewrite comprehensions so that the term part // is bound to a variable in the body. This allows any type of term to be used // in the term part (even if the term requires evaluation.) 
//
// For instance, given the following comprehension:
//
//	[x[0] | x = y[_]; y = [1,2,3]]
//
// The comprehension would be rewritten as:
//
//	[__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
	return TransformComprehensions(node, func(x interface{}) (Value, error) {
		switch x := x.(type) {
		case *ArrayComprehension:
			// Only rewrite terms that require evaluation; constants and
			// plain vars are left in place.
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *SetComprehension:
			if requiresEval(x.Term) {
				expr := f.Generate(x.Term)
				x.Term = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		case *ObjectComprehension:
			// Both the key and the value may independently require
			// evaluation.
			if requiresEval(x.Key) {
				expr := f.Generate(x.Key)
				x.Key = expr.Operand(0)
				x.Body.Append(expr)
			}
			if requiresEval(x.Value) {
				expr := f.Generate(x.Value)
				x.Value = expr.Operand(0)
				x.Body.Append(expr)
			}
			return x, nil
		}
		// TransformComprehensions only invokes the callback for the three
		// comprehension types handled above; anything else is a bug.
		panic("illegal type")
	})
}

// rewriteEquals will rewrite exprs under x as unification calls instead of ==
// calls. For example:
//
//	data.foo == data.bar is rewritten as data.foo = data.bar
//
// This stage should only run the safety check (since == is a built-in with no
// outputs, so the inputs must not be marked as safe.)
//
// This stage is not executed by the query compiler by default because when
// callers specify == instead of = they expect to receive a true/false/undefined
// result back whereas with = the result is only ever true/undefined. For
// partial evaluation cases we do want to rewrite == to = to simplify the
// result.
func rewriteEquals(x interface{}) {
	doubleEq := Equal.Ref()
	unifyOp := Equality.Ref()
	WalkExprs(x, func(x *Expr) bool {
		if x.IsCall() {
			operator := x.Operator()
			// Only binary == calls are rewritten; other arities are left
			// for the type checker to reject.
			if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
				x.SetOperator(NewTerm(unifyOp))
			}
		}
		return false
	})
}

// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query.
// This translation results in eager evaluation.
//
// For instance, given the following query:
//
//	foo(data.bar) = 1
//
// The rewritten version will be:
//
//	__local0__ = data.bar; foo(__local0__) = 1
func rewriteDynamics(f *equalityFactory, body Body) Body {
	result := make(Body, 0, len(body))
	for _, expr := range body {
		// Dispatch on expression form: x = y, f(x), or bare term.
		if expr.IsEquality() {
			result = rewriteDynamicsEqExpr(f, expr, result)
		} else if expr.IsCall() {
			result = rewriteDynamicsCallExpr(f, expr, result)
		} else {
			result = rewriteDynamicsTermExpr(f, expr, result)
		}
	}
	return result
}

// appendExpr appends expr to body and returns body (small helper to keep the
// rewrite functions in expression style).
func appendExpr(body Body, expr *Expr) Body {
	body.Append(expr)
	return body
}

// rewriteDynamicsEqExpr rewrites both operands of an equality expression,
// appending any generated assignments to result before the expression itself.
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
	if !validEqAssignArgCount(expr) {
		return appendExpr(result, expr)
	}
	terms := expr.Terms.([]*Term)
	// terms[0] is the operator; only the two operands are rewritten.
	result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
	result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
	return appendExpr(result, expr)
}

// rewriteDynamicsCallExpr rewrites every argument of a call expression.
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
	terms := expr.Terms.([]*Term)
	for i := 1; i < len(terms); i++ {
		result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
	}
	return appendExpr(result, expr)
}

// rewriteDynamicsTermExpr rewrites a bare-term expression in place.
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
	term := expr.Terms.(*Term)
	result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
	return appendExpr(result, expr)
}

// rewriteDynamicsInTerm rewrites dynamic values nested inside term. Unlike
// rewriteDynamicsOne, a top-level ref or comprehension here is NOT itself
// replaced by a generated var — only its operands/bodies are processed.
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		// Rewrite the ref operands (v[0] is the root and stays as-is).
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
	case *ArrayComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *SetComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	case *ObjectComprehension:
		v.Body = rewriteDynamics(f, v.Body)
	default:
		result, term = rewriteDynamicsOne(original, f, term, result)
	}
	return result, term
}

// rewriteDynamicsOne replaces dynamic terms (refs and comprehensions) with
// generated vars, appending the binding expressions to result. Composite
// values are rebuilt so their elements are rewritten too. The generated
// expressions inherit the original expression's with modifiers.
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch v := term.Value.(type) {
	case Ref:
		// Rewrite operands first, then bind the whole ref to a fresh var.
		for i := 1; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
		generated := f.Generate(term)
		generated.With = original.With
		result.Append(generated)
		return result, result[len(result)-1].Operand(0)
	case Array:
		for i := 0; i < len(v); i++ {
			result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
		}
		return result, term
	case Object:
		// Objects are rebuilt since keys may be rewritten.
		cpy := NewObject()
		for _, key := range v.Keys() {
			value := v.Get(key)
			result, key = rewriteDynamicsOne(original, f, key, result)
			result, value = rewriteDynamicsOne(original, f, value, result)
			cpy.Insert(key, value)
		}
		return result, NewTerm(cpy).SetLocation(term.Location)
	case Set:
		// Sets are rebuilt since rewritten elements may hash differently.
		cpy := NewSet()
		for _, term := range v.Slice() {
			var rw *Term
			result, rw = rewriteDynamicsOne(original, f, term, result)
			cpy.Add(rw)
		}
		return result, NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *SetComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	case *ObjectComprehension:
		var extra *Expr
		v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
		result.Append(extra)
		return result, result[len(result)-1].Operand(0)
	}
	// Scalars and vars are already static.
	return result, term
}

// rewriteDynamicsComprehensionBody recursively rewrites a comprehension body
// and returns it along with a generated expression binding the comprehension
// term to a fresh var (inheriting the original's with modifiers).
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
	body = rewriteDynamics(f, body)
	generated := f.Generate(term)
	generated.With = original.With
	return body, generated
}

// rewriteExprTermsInHead expands call terms appearing in the rule head's key
// and value, appending the generated support expressions to the rule body.
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
	if rule.Head.Key != nil {
		support, output := expandExprTerm(gen, rule.Head.Key)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Key = output
	}
	if rule.Head.Value != nil {
		support, output := expandExprTerm(gen, rule.Head.Value)
		for i := range support {
			rule.Body.Append(support[i])
		}
		rule.Head.Value = output
	}
}

// rewriteExprTermsInBody expands every expression in body, returning a new
// body in which nested calls have been hoisted into separate expressions.
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
	cpy := make(Body, 0, len(body))
	for i := 0; i < len(body); i++ {
		for _, expr := range expandExpr(gen, body[i]) {
			cpy.Append(expr)
		}
	}
	return cpy
}

// expandExpr expands nested calls inside expr (including with-modifier
// values) into separate support expressions, returning the support followed
// by the (mutated) original expression. Support expressions generated from
// the operands of an expression carrying with modifiers inherit those
// modifiers.
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
	for i := range expr.With {
		extras, value := expandExprTerm(gen, expr.With[i].Value)
		expr.With[i].Value = value
		result = append(result, extras...)
	}
	switch terms := expr.Terms.(type) {
	case *Term:
		extras, term := expandExprTerm(gen, terms)
		if len(expr.With) > 0 {
			for i := range extras {
				extras[i].With = expr.With
			}
		}
		result = append(result, extras...)
		expr.Terms = term
		result = append(result, expr)
	case []*Term:
		// terms[0] is the operator; only the operands are expanded.
		for i := 1; i < len(terms); i++ {
			var extras []*Expr
			extras, terms[i] = expandExprTerm(gen, terms[i])
			if len(expr.With) > 0 {
				for i := range extras {
					extras[i].With = expr.With
				}
			}
			result = append(result, extras...)
		}
		result = append(result, expr)
	}
	return
}

// expandExprTerm hoists calls nested inside term into generated support
// expressions and returns the term to use in their place (a fresh var for
// calls, a rebuilt composite otherwise). NOTE: in the comprehension cases the
// `support, ... :=` declarations deliberately shadow the named return; the
// comprehension-local support expressions are appended to the comprehension
// body rather than propagated to the caller.
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
	output = term
	switch v := term.Value.(type) {
	case Call:
		// Expand arguments first, then bind the call result to a fresh var.
		for i := 1; i < len(v); i++ {
			var extras []*Expr
			extras, v[i] = expandExprTerm(gen, v[i])
			support = append(support, extras...)
		}
		output = NewTerm(gen.Generate()).SetLocation(term.Location)
		expr := v.MakeExpr(output).SetLocation(term.Location)
		expr.Generated = true
		support = append(support, expr)
	case Ref:
		support = expandExprRef(gen, v)
	case Array:
		support = expandExprTermSlice(gen, v)
	case Object:
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			extras1, expandedKey := expandExprTerm(gen, k)
			extras2, expandedValue := expandExprTerm(gen, v)
			support = append(support, extras1...)
			support = append(support, extras2...)
			return expandedKey, expandedValue, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case Set:
		cpy, _ := v.Map(func(x *Term) (*Term, error) {
			extras, expanded := expandExprTerm(gen, x)
			support = append(support, extras...)
			return expanded, nil
		})
		output = NewTerm(cpy).SetLocation(term.Location)
	case *ArrayComprehension:
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *SetComprehension:
		support, term := expandExprTerm(gen, v.Term)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Term = term
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	case *ObjectComprehension:
		support, key := expandExprTerm(gen, v.Key)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Key = key
		support, value := expandExprTerm(gen, v.Value)
		for i := range support {
			v.Body.Append(support[i])
		}
		v.Value = value
		v.Body = rewriteExprTermsInBody(gen, v.Body)
	}
	return
}

func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
	// Start by calling a normal expandExprTerm on all terms.
	support = expandExprTermSlice(gen, v)

	// Rewrite references in order to support indirect references. We rewrite
	// e.g.
	//
	//	[1, 2, 3][i]
	//
	// to
	//
	//	__local_var = [1, 2, 3]
	//	__local_var[i]
	//
	// to support these. This only impacts the reference subject, i.e. the
	// first item in the slice.
	var subject = v[0]
	switch subject.Value.(type) {
	case Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
		f := newEqualityFactory(gen)
		assignToLocal := f.Generate(subject)
		support = append(support, assignToLocal)
		v[0] = assignToLocal.Operand(0)
	}
	return
}

// expandExprTermSlice expands each element of v in place, collecting all
// generated support expressions.
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
	for i := 0; i < len(v); i++ {
		var extras []*Expr
		extras, v[i] = expandExprTerm(gen, v[i])
		support = append(support, extras...)
	}
	return
}

// localDeclaredVars tracks a stack of per-query var rewrite state plus a
// global record of all rewritten user vars.
type localDeclaredVars struct {
	vars []*declaredVarSet

	// rewritten contains a mapping of *all* user-defined variables
	// that have been rewritten whereas vars contains the state
	// from the current query (not any nested queries, and all
	// vars seen).
	rewritten map[Var]Var
}

// varOccurrence classifies how a var was first seen in the current scope.
type varOccurrence int

const (
	newVar varOccurrence = iota // var not seen yet
	argVar                      // var bound as a function argument
	seenVar                     // var referenced without declaration
	assignedVar                 // var bound via := assignment
	declaredVar                 // var declared via some
)

// declaredVarSet holds the rewrite state for one query scope.
type declaredVarSet struct {
	vs         map[Var]Var           // original -> generated name
	reverse    map[Var]Var           // generated -> original name
	occurrence map[Var]varOccurrence // original -> how it occurred
}

func newDeclaredVarSet() *declaredVarSet {
	return &declaredVarSet{
		vs:         map[Var]Var{},
		reverse:    map[Var]Var{},
		occurrence: map[Var]varOccurrence{},
	}
}

func newLocalDeclaredVars() *localDeclaredVars {
	return &localDeclaredVars{
		vars:      []*declaredVarSet{newDeclaredVarSet()},
		rewritten: map[Var]Var{},
	}
}

// Push enters a new (e.g., comprehension) scope.
func (s *localDeclaredVars) Push() {
	s.vars = append(s.vars, newDeclaredVarSet())
}

// Pop leaves the current scope and returns its state.
func (s *localDeclaredVars) Pop() *declaredVarSet {
	sl := s.vars
	curr := sl[len(sl)-1]
	s.vars = sl[:len(sl)-1]
	return curr
}

// Peek returns the current scope's state without removing it.
func (s localDeclaredVars) Peek() *declaredVarSet {
	return s.vars[len(s.vars)-1]
}

// Insert records that x was rewritten to y in the current scope with the
// given occurrence kind.
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
	elem := s.vars[len(s.vars)-1]

	elem.vs[x] = y
	elem.reverse[y] = x
	elem.occurrence[x] = occurrence

	// If the variable has been rewritten (where x != y, with y being
	// the generated value), store it in the map of rewritten vars.
	// Assume that the generated values are unique for the compilation.
	if !x.Equal(y) {
		s.rewritten[y] = x
	}
}

// Declared returns the generated name for x, searching enclosing scopes from
// innermost to outermost.
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
	for i := len(s.vars) - 1; i >= 0; i-- {
		if y, ok = s.vars[i].vs[x]; ok {
			return
		}
	}
	return
}

// Occurrence returns a flag that indicates whether x has occurred in the
// current scope.
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
	return s.vars[len(s.vars)-1].occurrence[x]
}

// rewriteLocalVars rewrites bodies to remove assignment/declaration
// expressions.
// For example:
//
//	a := 1; p[a]
//
// Is rewritten to:
//
//	__local0__ = 1; p[__local0__]
//
// During rewriting, assignees are validated to prevent use before declaration.
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) {
	var errs Errors
	body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs)
	// Pop the scope pushed by the caller and return its original->generated
	// mapping alongside the rewritten body.
	return body, stack.Pop().vs, errs
}

// rewriteDeclaredVarsInBody rewrites each expression in body: := assignments
// become = unifications over generated vars, some declarations are consumed
// (they produce no expression), and all other expressions have their declared
// var references replaced.
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors) (Body, Errors) {
	var cpy Body
	for i := range body {
		var expr *Expr
		if body[i].IsAssignment() {
			expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs)
		} else if decl, ok := body[i].Terms.(*SomeDecl); ok {
			errs = rewriteSomeDeclStatement(g, stack, decl, errs)
		} else {
			expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs)
		}
		if expr != nil {
			cpy.Append(expr)
		}
	}

	// If the body only contained a var statement it will be empty at this
	// point. Append true to the body to ensure that it's non-empty (zero length
	// bodies are not supported.)
	if len(cpy) == 0 {
		cpy.Append(NewExpr(BooleanTerm(true)))
	}
	return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs)
}

// checkUnusedDeclaredVars reports an error for every some-declared var that
// is referenced neither in the rewritten body nor in the used set.
func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
	// NOTE(tsandall): Do not generate more errors if there are existing
	// declaration errors.
	if len(errs) > 0 {
		return errs
	}

	// Collect the generated names of vars declared via some.
	dvs := stack.Peek()
	declared := NewVarSet()
	for v, occ := range dvs.occurrence {
		if occ == declaredVar {
			declared.Add(dvs.vs[v])
		}
	}

	// Vars from the caller-supplied used set may refer to originals; map
	// them to their generated names before comparing.
	bodyvars := cpy.Vars(VarVisitorParams{})
	for v := range used {
		if gv, ok := stack.Declared(v); ok {
			bodyvars.Add(gv)
		} else {
			bodyvars.Add(v)
		}
	}

	unused := declared.Diff(bodyvars).Diff(used)
	for _, gv := range unused.Sorted() {
		// Report with the original (pre-rewrite) var name.
		errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", dvs.reverse[gv]))
	}
	return errs
}

// rewriteSomeDeclStatement registers each symbol of a some declaration as a
// declared var, collecting redeclaration errors.
func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, decl *SomeDecl, errs Errors) Errors {
	for i := range decl.Symbols {
		v := decl.Symbols[i].Value.(Var)
		if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
			errs = append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
		}
	}
	return errs
}

// rewriteDeclaredVarsInExpr replaces references to declared vars inside expr
// with their generated names. With-modifier values are handled via the
// non-recursive term rewrite and then skipped (stop=true).
func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {
	vis := NewGenericVisitor(func(x interface{}) bool {
		var stop bool
		switch x := x.(type) {
		case *Term:
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs)
		case *With:
			_, errs = rewriteDeclaredVarsInTerm(g, stack, x.Value, errs)
			stop = true
		}
		return stop
	})
	vis.Walk(expr)
	return expr, errs
}

// rewriteDeclaredAssignment rewrites a := expression: the right hand side and
// with modifiers are processed first (capturing seen vars), then the left
// hand side assignees are given generated names, and finally the operator is
// replaced with = if no new errors were produced.
func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {

	if expr.Negated {
		errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
		return expr, errs
	}

	numErrsBefore := len(errs)

	if !validEqAssignArgCount(expr) {
		return expr, errs
	}

	// Rewrite terms on right hand side capture seen vars and recursively
	// process comprehensions before left hand side is processed. Also
	// rewrite with modifier.
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs)

	for _, w := range expr.With {
		errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs)
	}

	// Rewrite vars on left hand side with unique names. Catch redeclaration
	// and invalid term types here.
	var vis func(t *Term) bool

	vis = func(t *Term) bool {
		switch v := t.Value.(type) {
		case Var:
			if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
				errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
			} else {
				t.Value = gv
			}
			return true
		case Array:
			// Descend into array elements; each element is an assignee.
			return false
		case Object:
			// Only object values are assignees; keys are left alone.
			v.Foreach(func(_, v *Term) {
				WalkTerms(v, vis)
			})
			return true
		case Ref:
			if RootDocumentRefs.Contains(t) {
				if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
					errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
				} else {
					t.Value = gv
				}
				return true
			}
		}
		// Any other term type (including non-root refs) cannot be assigned to.
		errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
		return true
	}

	WalkTerms(expr.Operand(0), vis)

	if len(errs) == numErrsBefore {
		loc := expr.Operator()[0].Location
		expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
	}

	return expr, errs
}

// rewriteDeclaredVarsInTerm rewrites a single term, replacing declared vars
// with their generated names. The returned bool tells the caller's visitor
// whether to stop descending into this term.
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) (bool, Errors) {
	switch v := term.Value.(type) {
	case Var:
		if gv, ok := stack.Declared(v); ok {
			term.Value = gv
		} else if stack.Occurrence(v) == newVar {
			// First sighting without declaration: record as seen so a later
			// := or some on the same var can be flagged.
			stack.Insert(v, v, seenVar)
		}
	case Ref:
		if RootDocumentRefs.Contains(term) {
			if gv, ok := stack.Declared(v[0].Value.(Var)); ok {
				term.Value = gv
			}
			return true, errs
		}
		return false, errs
	case Object:
		// Keys are copied before rewriting since mutating them in place
		// would corrupt the object's hashing.
		cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
			kcpy := k.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs)
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs)
			return kcpy, v, nil
		})
		term.Value = cpy
	case Set:
		// Same as objects: elements are copied before rewriting.
		cpy, _ := v.Map(func(elem *Term) (*Term, error) {
			elemcpy := elem.Copy()
			errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs)
			return elemcpy, nil
		})
		term.Value = cpy
	case *ArrayComprehension:
		errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs)
	case *SetComprehension:
		errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs)
	case *ObjectComprehension:
		errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs)
	default:
		return false, errs
	}
	return true, errs
}

// rewriteDeclaredVarsInTermRecursive walks term, rewriting declared vars at
// every level (with-modifier values are rewritten non-recursively).
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) Errors {
	WalkNodes(term, func(n Node) bool {
		var stop bool
		switch n := n.(type) {
		case *With:
			_, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs)
			stop = true
		case *Term:
			stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs)
		}
		return stop
	})
	return errs
}

// rewriteDeclaredVarsInArrayComprehension rewrites a comprehension in its own
// scope: body first (so declarations there are visible), then the term.
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors) Errors {
	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
	stack.Pop()
	return errs
}

// rewriteDeclaredVarsInSetComprehension mirrors the array comprehension case.
func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors) Errors {
	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
	stack.Pop()
	return errs
}

// rewriteDeclaredVarsInObjectComprehension mirrors the array comprehension
// case, rewriting both the key and the value.
func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors) Errors {
	stack.Push()
	v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs)
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs)
	stack.Pop()
	return errs
}

// rewriteDeclaredVar generates a fresh name for v and records it in the
// current scope with the given occurrence, or returns an error if v was
// already seen/assigned/declared/bound-as-arg in this scope.
func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) {
	switch stack.Occurrence(v) {
	case seenVar:
		return gv, fmt.Errorf("var %v referenced above", v)
	case assignedVar:
		return gv, fmt.Errorf("var %v assigned above", v)
	case declaredVar:
		return gv, fmt.Errorf("var %v declared above", v)
	case argVar:
		return gv, fmt.Errorf("arg %v redeclared", v)
	}
	gv = g.Generate()
	stack.Insert(v, gv, occ)
	return
}

// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
// not contain terms that require evaluation as values. If this function
// encounters an invalid with modifier target then it will raise an error.
func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) {
	var result Body
	for i := range body {
		exprs, err := rewriteWithModifier(c, f, body[i])
		if err != nil {
			return nil, err
		}
		if len(exprs) > 0 {
			for _, expr := range exprs {
				result.Append(expr)
			}
		} else {
			// Expression was unchanged; keep it as-is.
			result.Append(body[i])
		}
	}
	return result, nil
}

// rewriteWithModifier validates each with-modifier target on expr and binds
// any value requiring evaluation to a generated var.
func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {

	var result []*Expr
	for i := range expr.With {
		err := validateTarget(c, expr.With[i].Target)
		if err != nil {
			return nil, err
		}
		if requiresEval(expr.With[i].Value) {
			eq := f.Generate(expr.With[i].Value)
			result = append(result, eq)
			expr.With[i].Value = eq.Operand(0)
		}
	}

	// If any of the with modifiers in this expression were rewritten then result
	// will be non-empty. In this case, the expression will have been modified and
	// it should also be added to the result.
	if len(result) > 0 {
		result = append(result, expr)
	}
	return result, nil
}

// validateTarget checks that a with-modifier target is rooted at input or
// data and, for data targets, that it does not (partially) replace a virtual
// document or replace a function.
func validateTarget(c *Compiler, term *Term) *Error {
	if !isInputRef(term) && !isDataRef(term) {
		return NewError(TypeErr, term.Location, "with keyword target must start with %v or %v", InputRootDocument, DefaultRootDocument)
	}

	if isDataRef(term) {
		ref := term.Value.(Ref)
		node := c.RuleTree
		// Walk the rule tree along the ref prefix; any rules found on the
		// way mean the target would only partially replace a virtual doc.
		for i := 0; i < len(ref)-1; i++ {
			child := node.Child(ref[i].Value)
			if child == nil {
				break
			} else if len(child.Values) > 0 {
				return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)")
			}
			node = child
		}

		if node != nil {
			// Functions (rules with args) cannot be replaced via with.
			if child := node.Child(ref[len(ref)-1].Value); child != nil {
				for _, value := range child.Values {
					if len(value.(*Rule).Head.Args) > 0 {
						return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions")
					}
				}
			}
		}
	}
	return nil
}

// isInputRef reports whether term is a ref rooted at input.
func isInputRef(term *Term) bool {
	if ref, ok := term.Value.(Ref); ok {
		if ref.HasPrefix(InputRootRef) {
			return true
		}
	}
	return false
}

// isDataRef reports whether term is a ref rooted at data.
func isDataRef(term *Term) bool {
	if ref, ok := term.Value.(Ref); ok {
		if ref.HasPrefix(DefaultRootRef) {
			return true
		}
	}
	return false
}

// isVirtual reports whether ref points at (or under) a node in the rule tree
// that has rules defined along its path.
func isVirtual(node *TreeNode, ref Ref) bool {
	for i := 0; i < len(ref); i++ {
		child := node.Child(ref[i].Value)
		if child == nil {
			return false
		} else if len(child.Values) > 0 {
			return true
		}
		node = child
	}
	return true
}

// safetyErrorSlice converts the unsafe var set into user-facing errors,
// preferring per-var errors for user-named vars and falling back to
// per-expression errors when only generated vars are unsafe.
func safetyErrorSlice(unsafe unsafeVars) (result Errors) {

	if len(unsafe) == 0 {
		return
	}

	for _, pair := range unsafe.Vars() {
		if !pair.Var.IsGenerated() {
			result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", pair.Var))
		}
	}

	if len(result) > 0 {
		return
	}

	// If the expression contains unsafe generated variables, report which
	// expressions are unsafe instead of the variables that are unsafe (since
	// the latter are not meaningful to the user.)
	pairs := unsafe.Slice()

	// Sort by location so error output is deterministic.
	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
	})

	// Report at most one error per generated variable: an expression only
	// produces an error if it contributes a generated var not seen before.
	seen := NewVarSet()
	for _, expr := range pairs {
		before := len(seen)
		for v := range expr.Vars {
			if v.IsGenerated() {
				seen.Add(v)
			}
		}
		if len(seen) > before {
			result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
		}
	}

	return
}

// checkUnsafeBuiltins returns a type error for every call in node whose
// operator appears in unsafeBuiltinsMap.
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
	errs := make(Errors, 0)
	WalkExprs(node, func(x *Expr) bool {
		if x.IsCall() {
			operator := x.Operator().String()
			if _, ok := unsafeBuiltinsMap[operator]; ok {
				errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator))
			}
		}
		return false
	})
	return errs
}

// rewriteVarsInRef returns a ref transform that replaces vars using the first
// mapping (in order) that contains them; vars found in no mapping are left
// unchanged.
func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
	return func(node Ref) Ref {
		i, _ := TransformVars(node, func(v Var) (Value, error) {
			for _, m := range vars {
				if u, ok := m[v]; ok {
					return u, nil
				}
			}
			return v, nil
		})
		return i.(Ref)
	}
}

// rewriteVarsNop is the identity ref transform, used when no var rewriting
// applies.
func rewriteVarsNop(node Ref) Ref {
	return node
}