3420 lines
87 KiB
Go
3420 lines
87 KiB
Go
// Copyright 2016 The OPA Authors. All rights reserved.
|
|
// Use of this source code is governed by an Apache2
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package ast
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/open-policy-agent/opa/metrics"
|
|
"github.com/open-policy-agent/opa/util"
|
|
)
|
|
|
|
// CompileErrorLimitDefault is the default number of errors a compiler will allow before
|
|
// exiting.
|
|
const CompileErrorLimitDefault = 10
|
|
|
|
// errLimitReached is the sentinel error appended to Compiler.Errors and used
// as the panic value by Compiler.err once the configured error limit has been
// hit. Compiler.compile recovers this value to stop compilation quietly.
var errLimitReached = NewError(CompileErr, nil, "error limit reached")
|
|
|
|
// Compiler contains the state of a compilation process.
type Compiler struct {

	// Errors contains errors that occurred during the compilation process.
	// If there are one or more errors, the compilation process is considered
	// "failed".
	Errors Errors

	// Modules contains the compiled modules. The compiled modules are the
	// output of the compilation process. If the compilation process failed,
	// there is no guarantee about the state of the modules.
	Modules map[string]*Module

	// ModuleTree organizes the modules into a tree where each node is keyed by
	// an element in the module's package path. E.g., given modules containing
	// the following package directives: "a", "a.b", "a.c", and "a.b", the
	// resulting module tree would be:
	//
	//  root
	//    |
	//    +--- data (no modules)
	//           |
	//           +--- a (1 module)
	//                |
	//                +--- b (2 modules)
	//                |
	//                +--- c (1 module)
	//
	ModuleTree *ModuleTreeNode

	// RuleTree organizes rules into a tree where each node is keyed by an
	// element in the rule's path. The rule path is the concatenation of the
	// containing package and the stringified rule name. E.g., given the
	// following module:
	//
	//  package ex
	//  p[1] { true }
	//  p[2] { true }
	//  q = true
	//
	//  root
	//    |
	//    +--- data (no rules)
	//           |
	//           +--- ex (no rules)
	//                |
	//                +--- p (2 rules)
	//                |
	//                +--- q (1 rule)
	RuleTree *TreeNode

	// Graph contains dependencies between rules. An edge (u,v) is added to the
	// graph if rule 'u' refers to the virtual document defined by 'v'.
	Graph *Graph

	// TypeEnv holds type information for values inferred by the compiler.
	TypeEnv *TypeEnv

	// RewrittenVars is a mapping of variables that have been rewritten
	// with the key being the generated name and value being the original.
	RewrittenVars map[Var]Var

	// localvargen generates local variable names; it is created by the
	// InitLocalVarGen stage.
	localvargen *localVarGenerator
	// moduleLoader is the optional lazy module loader set via WithModuleLoader.
	moduleLoader ModuleLoader
	// ruleIndices maps a rule path (Ref) to the index built for that rule set
	// by the BuildRuleIndices stage.
	ruleIndices *util.HashMap
	// stages is the ordered list of built-in compilation stages executed by
	// compile.
	stages []struct {
		name       string
		metricName string
		f          func()
	}
	// maxErrs is the error limit; values <= 0 disable the limit.
	maxErrs int
	sorted  []string // list of sorted module names
	// pathExists is the optional callback set via WithPathConflictsCheck.
	pathExists func([]string) (bool, error)
	// after holds the custom stages registered via WithStageAfter, keyed by
	// the name of the built-in stage they follow.
	after map[string][]CompilerStageDefinition
	// metrics, when non-nil, is used to time each stage.
	metrics metrics.Metrics
	// builtins is the set of built-in functions known to this compiler.
	builtins map[string]*Builtin
	// unsafeBuiltinsMap is the set of built-in names checked by the
	// CheckUnsafeBuiltins stage.
	unsafeBuiltinsMap map[string]struct{}
}
|
|
|
|
// CompilerStage defines the interface for stages in the compiler. A stage
// receives the compiler and returns a non-nil *Error to report a failure.
type CompilerStage func(*Compiler) *Error
|
|
|
|
// CompilerStageDefinition defines a compiler stage that can be registered
// with Compiler.WithStageAfter.
type CompilerStageDefinition struct {
	// Name identifies the stage.
	Name string
	// MetricName is the name of the timer used when metrics are enabled.
	MetricName string
	// Stage is the function executed for this stage.
	Stage CompilerStage
}
|
|
|
|
// QueryContext contains contextual information for running an ad-hoc query.
//
// Ad-hoc queries can be run in the context of a package and imports may be
// included to provide concise access to data.
type QueryContext struct {
	// Package is the package the query runs in; it may be nil.
	Package *Package
	// Imports are the imports available to the query.
	Imports []*Import
}
|
|
|
|
// NewQueryContext returns a new QueryContext object.
|
|
func NewQueryContext() *QueryContext {
|
|
return &QueryContext{}
|
|
}
|
|
|
|
// WithPackage sets the pkg on qc.
|
|
func (qc *QueryContext) WithPackage(pkg *Package) *QueryContext {
|
|
if qc == nil {
|
|
qc = NewQueryContext()
|
|
}
|
|
qc.Package = pkg
|
|
return qc
|
|
}
|
|
|
|
// WithImports sets the imports on qc.
|
|
func (qc *QueryContext) WithImports(imports []*Import) *QueryContext {
|
|
if qc == nil {
|
|
qc = NewQueryContext()
|
|
}
|
|
qc.Imports = imports
|
|
return qc
|
|
}
|
|
|
|
// Copy returns a deep copy of qc.
|
|
func (qc *QueryContext) Copy() *QueryContext {
|
|
if qc == nil {
|
|
return nil
|
|
}
|
|
cpy := *qc
|
|
if cpy.Package != nil {
|
|
cpy.Package = qc.Package.Copy()
|
|
}
|
|
cpy.Imports = make([]*Import, len(qc.Imports))
|
|
for i := range qc.Imports {
|
|
cpy.Imports[i] = qc.Imports[i].Copy()
|
|
}
|
|
return &cpy
|
|
}
|
|
|
|
// QueryCompiler defines the interface for compiling ad-hoc queries.
type QueryCompiler interface {

	// Compile should be called to compile ad-hoc queries. The return value is
	// the compiled version of the query.
	Compile(q Body) (Body, error)

	// TypeEnv returns the type environment built after running type checking
	// on the query.
	TypeEnv() *TypeEnv

	// WithContext sets the QueryContext on the QueryCompiler. Subsequent calls
	// to Compile will take the QueryContext into account.
	WithContext(qctx *QueryContext) QueryCompiler

	// WithUnsafeBuiltins sets the built-in functions to treat as unsafe and not
	// allow inside of queries. By default the query compiler inherits the
	// compiler's unsafe built-in functions. This function allows callers to
	// override that set. If an empty (non-nil) map is provided, all built-ins
	// are allowed.
	WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler

	// WithStageAfter registers a stage to run during query compilation after
	// the named stage.
	WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler

	// RewrittenVars maps generated vars in the compiled query to vars from the
	// parsed query. For example, given the query "input := 1" the rewritten
	// query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
	RewrittenVars() map[Var]Var
}
|
|
|
|
// QueryCompilerStage defines the interface for stages in the query compiler.
// A stage receives the query compiler and a query body and returns a body
// together with an error.
type QueryCompilerStage func(QueryCompiler, Body) (Body, error)
|
|
|
|
// QueryCompilerStageDefinition defines a QueryCompiler stage that can be
// registered with QueryCompiler.WithStageAfter.
type QueryCompilerStageDefinition struct {
	// Name identifies the stage.
	Name string
	// MetricName is the metric name associated with the stage.
	MetricName string
	// Stage is the function executed for this stage.
	Stage QueryCompilerStage
}
|
|
|
|
// compileStageMetricPrefex holds the string "ast_compile_stage_".
// NOTE(review): the identifier is spelled "Prefex" (sic) and it is not
// referenced in the visible part of this file; confirm where the prefix is
// applied before renaming.
const compileStageMetricPrefex = "ast_compile_stage_"
|
|
|
|
// NewCompiler returns a new empty compiler.
|
|
func NewCompiler() *Compiler {
|
|
|
|
c := &Compiler{
|
|
Modules: map[string]*Module{},
|
|
TypeEnv: NewTypeEnv(),
|
|
RewrittenVars: map[Var]Var{},
|
|
ruleIndices: util.NewHashMap(func(a, b util.T) bool {
|
|
r1, r2 := a.(Ref), b.(Ref)
|
|
return r1.Equal(r2)
|
|
}, func(x util.T) int {
|
|
return x.(Ref).Hash()
|
|
}),
|
|
maxErrs: CompileErrorLimitDefault,
|
|
after: map[string][]CompilerStageDefinition{},
|
|
unsafeBuiltinsMap: map[string]struct{}{},
|
|
}
|
|
|
|
c.ModuleTree = NewModuleTree(nil)
|
|
c.RuleTree = NewRuleTree(c.ModuleTree)
|
|
|
|
// Initialize the compiler with the statically compiled built-in functions.
|
|
// If the caller customizes the compiler, a copy will be made.
|
|
c.builtins = BuiltinMap
|
|
checker := newTypeChecker()
|
|
c.TypeEnv = checker.checkLanguageBuiltins(nil, c.builtins)
|
|
|
|
c.stages = []struct {
|
|
name string
|
|
metricName string
|
|
f func()
|
|
}{
|
|
// Reference resolution should run first as it may be used to lazily
|
|
// load additional modules. If any stages run before resolution, they
|
|
// need to be re-run after resolution.
|
|
{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
|
|
|
|
// The local variable generator must be initialized after references are
|
|
// resolved and the dynamic module loader has run but before subsequent
|
|
// stages that need to generate variables.
|
|
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
|
|
|
|
{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
|
|
{"RewriteExprTerms", "compile_stage_rewrite_expr_terms", c.rewriteExprTerms},
|
|
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
|
|
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
|
|
{"SetGraph", "compile_stage_set_graph", c.setGraph},
|
|
{"RewriteComprehensionTerms", "compile_stage_rewrite_comprehension_terms", c.rewriteComprehensionTerms},
|
|
{"RewriteRefsInHead", "compile_stage_rewrite_refs_in_head", c.rewriteRefsInHead},
|
|
{"RewriteWithValues", "compile_stage_rewrite_with_values", c.rewriteWithModifiers},
|
|
{"CheckRuleConflicts", "compile_stage_check_rule_conflicts", c.checkRuleConflicts},
|
|
{"CheckUndefinedFuncs", "compile_stage_check_undefined_funcs", c.checkUndefinedFuncs},
|
|
{"CheckSafetyRuleHeads", "compile_stage_check_safety_rule_heads", c.checkSafetyRuleHeads},
|
|
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
|
|
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
|
|
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
|
|
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
|
|
{"CheckTypes", "compile_stage_check_types", c.checkTypes},
|
|
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
|
|
{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
|
|
}
|
|
|
|
return c
|
|
}
|
|
|
|
// SetErrorLimit sets the number of errors the compiler can encounter before it
|
|
// quits. Zero or a negative number indicates no limit.
|
|
func (c *Compiler) SetErrorLimit(limit int) *Compiler {
|
|
c.maxErrs = limit
|
|
return c
|
|
}
|
|
|
|
// WithPathConflictsCheck enables base-virtual document conflict
|
|
// detection. The compiler will check that rules don't overlap with
|
|
// paths that exist as determined by the provided callable.
|
|
func (c *Compiler) WithPathConflictsCheck(fn func([]string) (bool, error)) *Compiler {
|
|
c.pathExists = fn
|
|
return c
|
|
}
|
|
|
|
// WithStageAfter registers a stage to run during compilation after
|
|
// the named stage.
|
|
func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
|
|
c.after[after] = append(c.after[after], stage)
|
|
return c
|
|
}
|
|
|
|
// WithMetrics will set a metrics.Metrics and be used for profiling
|
|
// the Compiler instance.
|
|
func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
|
|
c.metrics = metrics
|
|
return c
|
|
}
|
|
|
|
// WithBuiltins adds a set of custom built-in functions to the compiler.
|
|
func (c *Compiler) WithBuiltins(builtins map[string]*Builtin) *Compiler {
|
|
if len(builtins) == 0 {
|
|
return c
|
|
}
|
|
cpy := make(map[string]*Builtin, len(c.builtins)+len(builtins))
|
|
for k, v := range c.builtins {
|
|
cpy[k] = v
|
|
}
|
|
for k, v := range builtins {
|
|
cpy[k] = v
|
|
}
|
|
c.builtins = cpy
|
|
// Build type env for custom functions and wrap existing one.
|
|
checker := newTypeChecker()
|
|
c.TypeEnv = checker.checkLanguageBuiltins(c.TypeEnv, builtins)
|
|
return c
|
|
}
|
|
|
|
// WithUnsafeBuiltins will add all built-ins in the map to the "blacklist".
|
|
func (c *Compiler) WithUnsafeBuiltins(unsafeBuiltins map[string]struct{}) *Compiler {
|
|
for name := range unsafeBuiltins {
|
|
c.unsafeBuiltinsMap[name] = struct{}{}
|
|
}
|
|
return c
|
|
}
|
|
|
|
// QueryCompiler returns a new QueryCompiler object.
|
|
func (c *Compiler) QueryCompiler() QueryCompiler {
|
|
return newQueryCompiler(c)
|
|
}
|
|
|
|
// Compile runs the compilation process on the input modules. The compiled
|
|
// version of the modules and associated data structures are stored on the
|
|
// compiler. If the compilation process fails for any reason, the compiler will
|
|
// contain a slice of errors.
|
|
func (c *Compiler) Compile(modules map[string]*Module) {
|
|
|
|
c.Modules = make(map[string]*Module, len(modules))
|
|
|
|
for k, v := range modules {
|
|
c.Modules[k] = v.Copy()
|
|
c.sorted = append(c.sorted, k)
|
|
}
|
|
|
|
sort.Strings(c.sorted)
|
|
|
|
c.compile()
|
|
}
|
|
|
|
// Failed returns true if a compilation error has been encountered.
|
|
func (c *Compiler) Failed() bool {
|
|
return len(c.Errors) > 0
|
|
}
|
|
|
|
// GetArity returns the number of args a function referred to by ref takes. If
|
|
// ref refers to built-in function, the built-in declaration is consulted,
|
|
// otherwise, the ref is used to perform a ruleset lookup.
|
|
func (c *Compiler) GetArity(ref Ref) int {
|
|
if bi := c.builtins[ref.String()]; bi != nil {
|
|
return len(bi.Decl.Args())
|
|
}
|
|
rules := c.GetRulesExact(ref)
|
|
if len(rules) == 0 {
|
|
return -1
|
|
}
|
|
return len(rules[0].Head.Args)
|
|
}
|
|
|
|
// GetRulesExact returns a slice of rules referred to by the reference.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[k] = v { ... } # rule1
|
|
// p[k1] = v1 { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesExact("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesExact("data.a.b.c.p.x") => nil
|
|
// GetRulesExact("data.a.b.c") => nil
|
|
func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return extractRules(node.Values)
|
|
}
|
|
|
|
// GetRulesForVirtualDocument returns a slice of rules that produce the virtual
|
|
// document referred to by the reference.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[k] = v { ... } # rule1
|
|
// p[k1] = v1 { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
|
|
// GetRulesForVirtualDocument("data.a.b.c") => nil
|
|
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {
|
|
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
if len(node.Values) > 0 {
|
|
return extractRules(node.Values)
|
|
}
|
|
}
|
|
|
|
return extractRules(node.Values)
|
|
}
|
|
|
|
// GetRulesWithPrefix returns a slice of rules that share the prefix ref.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[x] = y { ... } # rule1
|
|
// p[k] = v { ... } # rule2
|
|
// q { ... } # rule3
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2]
|
|
// GetRulesWithPrefix("data.a.b.c.p.a") => nil
|
|
// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3]
|
|
func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
|
|
|
|
node := c.RuleTree
|
|
|
|
for _, x := range ref {
|
|
if node = node.Child(x.Value); node == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var acc func(node *TreeNode)
|
|
|
|
acc = func(node *TreeNode) {
|
|
rules = append(rules, extractRules(node.Values)...)
|
|
for _, child := range node.Children {
|
|
if child.Hide {
|
|
continue
|
|
}
|
|
acc(child)
|
|
}
|
|
}
|
|
|
|
acc(node)
|
|
|
|
return rules
|
|
}
|
|
|
|
func extractRules(s []util.T) (rules []*Rule) {
|
|
for _, r := range s {
|
|
rules = append(rules, r.(*Rule))
|
|
}
|
|
return rules
|
|
}
|
|
|
|
// GetRules returns a slice of rules that are referred to by ref.
|
|
//
|
|
// E.g., given the following module:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// p[x] = y { q[x] = y; ... } # rule1
|
|
// q[x] = y { ... } # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRules("data.a.b.c.p") => [rule1]
|
|
// GetRules("data.a.b.c.p.x") => [rule1]
|
|
// GetRules("data.a.b.c.q") => [rule2]
|
|
// GetRules("data.a.b.c") => [rule1, rule2]
|
|
// GetRules("data.a.b.d") => nil
|
|
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {
|
|
|
|
set := map[*Rule]struct{}{}
|
|
|
|
for _, rule := range c.GetRulesForVirtualDocument(ref) {
|
|
set[rule] = struct{}{}
|
|
}
|
|
|
|
for _, rule := range c.GetRulesWithPrefix(ref) {
|
|
set[rule] = struct{}{}
|
|
}
|
|
|
|
for rule := range set {
|
|
rules = append(rules, rule)
|
|
}
|
|
|
|
return rules
|
|
}
|
|
|
|
// GetRulesDynamic returns a slice of rules that could be referred to by a ref.
|
|
// When parts of the ref are statically known, we use that information to narrow
|
|
// down which rules the ref could refer to, but in the most general case this
|
|
// will be an over-approximation.
|
|
//
|
|
// E.g., given the following modules:
|
|
//
|
|
// package a.b.c
|
|
//
|
|
// r1 = 1 # rule1
|
|
//
|
|
// and:
|
|
//
|
|
// package a.d.c
|
|
//
|
|
// r2 = 2 # rule2
|
|
//
|
|
// The following calls yield the rules on the right.
|
|
//
|
|
// GetRulesDynamic("data.a[x].c[y]") => [rule1, rule2]
|
|
// GetRulesDynamic("data.a[x].c.r2") => [rule2]
|
|
// GetRulesDynamic("data.a.b[x][y]") => [rule1]
|
|
func (c *Compiler) GetRulesDynamic(ref Ref) (rules []*Rule) {
|
|
node := c.RuleTree
|
|
|
|
set := map[*Rule]struct{}{}
|
|
var walk func(node *TreeNode, i int)
|
|
walk = func(node *TreeNode, i int) {
|
|
if i >= len(ref) {
|
|
// We've reached the end of the reference and want to collect everything
|
|
// under this "prefix".
|
|
node.DepthFirst(func(descendant *TreeNode) bool {
|
|
insertRules(set, descendant.Values)
|
|
return descendant.Hide
|
|
})
|
|
} else if i == 0 || IsConstant(ref[i].Value) {
|
|
// The head of the ref is always grounded. In case another part of the
|
|
// ref is also grounded, we can lookup the exact child. If it's not found
|
|
// we can immediately return...
|
|
if child := node.Child(ref[i].Value); child == nil {
|
|
return
|
|
} else if len(child.Values) > 0 {
|
|
// If there are any rules at this position, it's what the ref would
|
|
// refer to. We can just append those and stop here.
|
|
insertRules(set, child.Values)
|
|
} else {
|
|
// Otherwise, we continue using the child node.
|
|
walk(child, i+1)
|
|
}
|
|
} else {
|
|
// This part of the ref is a dynamic term. We can't know what it refers
|
|
// to and will just need to try all of the children.
|
|
for _, child := range node.Children {
|
|
if child.Hide {
|
|
continue
|
|
}
|
|
insertRules(set, child.Values)
|
|
walk(child, i+1)
|
|
}
|
|
}
|
|
}
|
|
|
|
walk(node, 0)
|
|
for rule := range set {
|
|
rules = append(rules, rule)
|
|
}
|
|
return rules
|
|
}
|
|
|
|
// Utility: add all rule values to the set.
|
|
func insertRules(set map[*Rule]struct{}, rules []util.T) {
|
|
for _, rule := range rules {
|
|
set[rule.(*Rule)] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// RuleIndex returns a RuleIndex built for the rule set referred to by path.
|
|
// The path must refer to the rule set exactly, i.e., given a rule set at path
|
|
// data.a.b.c.p, refs data.a.b.c.p.x and data.a.b.c would not return a
|
|
// RuleIndex built for the rule.
|
|
func (c *Compiler) RuleIndex(path Ref) RuleIndex {
|
|
r, ok := c.ruleIndices.Get(path)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return r.(RuleIndex)
|
|
}
|
|
|
|
// ModuleLoader defines the interface that callers can implement to enable lazy
// loading of modules during compilation. The loader is handed the compiler's
// current module map and returns additional parsed modules keyed by id, or an
// error.
type ModuleLoader func(resolved map[string]*Module) (parsed map[string]*Module, err error)
|
|
|
|
// WithModuleLoader sets f as the ModuleLoader on the compiler.
|
|
//
|
|
// The compiler will invoke the ModuleLoader after resolving all references in
|
|
// the current set of input modules. The ModuleLoader can return a new
|
|
// collection of parsed modules that are to be included in the compilation
|
|
// process. This process will repeat until the ModuleLoader returns an empty
|
|
// collection or an error. If an error is returned, compilation will stop
|
|
// immediately.
|
|
func (c *Compiler) WithModuleLoader(f ModuleLoader) *Compiler {
|
|
c.moduleLoader = f
|
|
return c
|
|
}
|
|
|
|
// buildRuleIndices constructs indices for rules.
|
|
func (c *Compiler) buildRuleIndices() {
|
|
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
if len(node.Values) == 0 {
|
|
return false
|
|
}
|
|
index := newBaseDocEqIndex(func(ref Ref) bool {
|
|
return isVirtual(c.RuleTree, ref.GroundPrefix())
|
|
})
|
|
if rules := extractRules(node.Values); index.Build(rules) {
|
|
c.ruleIndices.Put(rules[0].Path(), index)
|
|
}
|
|
return false
|
|
})
|
|
|
|
}
|
|
|
|
// checkRecursion ensures that there are no recursive definitions, i.e., there are
|
|
// no cycles in the Graph.
|
|
func (c *Compiler) checkRecursion() {
|
|
eq := func(a, b util.T) bool {
|
|
return a.(*Rule) == b.(*Rule)
|
|
}
|
|
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
for _, rule := range node.Values {
|
|
for node := rule.(*Rule); node != nil; node = node.Else {
|
|
c.checkSelfPath(node.Loc(), eq, node, node)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
|
|
tr := NewGraphTraversal(c.Graph)
|
|
if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
|
|
n := []string{}
|
|
for _, x := range p {
|
|
n = append(n, astNodeToString(x))
|
|
}
|
|
c.err(NewError(RecursionErr, loc, "rule %v is recursive: %v", astNodeToString(a), strings.Join(n, " -> ")))
|
|
}
|
|
}
|
|
|
|
func astNodeToString(x interface{}) string {
|
|
switch x := x.(type) {
|
|
case *Rule:
|
|
return string(x.Head.Name)
|
|
default:
|
|
panic("not reached")
|
|
}
|
|
}
|
|
|
|
// checkRuleConflicts ensures that rules definitions are not in conflict.
|
|
func (c *Compiler) checkRuleConflicts() {
|
|
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
|
if len(node.Values) == 0 {
|
|
return false
|
|
}
|
|
|
|
kinds := map[DocKind]struct{}{}
|
|
defaultRules := 0
|
|
arities := map[int]struct{}{}
|
|
declared := false
|
|
|
|
for _, rule := range node.Values {
|
|
r := rule.(*Rule)
|
|
kinds[r.Head.DocKind()] = struct{}{}
|
|
arities[len(r.Head.Args)] = struct{}{}
|
|
if r.Head.Assign {
|
|
declared = true
|
|
}
|
|
if r.Default {
|
|
defaultRules++
|
|
}
|
|
}
|
|
|
|
name := Var(node.Key.(String))
|
|
|
|
if declared && len(node.Values) > 1 {
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "rule named %v redeclared at %v", name, node.Values[1].(*Rule).Loc()))
|
|
} else if len(kinds) > 1 || len(arities) > 1 {
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
|
|
} else if defaultRules > 1 {
|
|
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
|
|
}
|
|
|
|
return false
|
|
})
|
|
|
|
if c.pathExists != nil {
|
|
for _, err := range CheckPathConflicts(c, c.pathExists) {
|
|
c.err(err)
|
|
}
|
|
}
|
|
|
|
c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
|
|
for _, mod := range node.Modules {
|
|
for _, rule := range mod.Rules {
|
|
if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
|
|
for _, childMod := range childNode.Modules {
|
|
msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc())
|
|
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
func (c *Compiler) checkUndefinedFuncs() {
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
for _, err := range checkUndefinedFuncs(m, c.GetArity) {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func checkUndefinedFuncs(x interface{}, arity func(Ref) int) Errors {
|
|
|
|
var errs Errors
|
|
|
|
WalkExprs(x, func(expr *Expr) bool {
|
|
if !expr.IsCall() {
|
|
return false
|
|
}
|
|
ref := expr.Operator()
|
|
if arity(ref) >= 0 {
|
|
return false
|
|
}
|
|
errs = append(errs, NewError(TypeErr, expr.Loc(), "undefined function %v", ref))
|
|
return true
|
|
})
|
|
|
|
return errs
|
|
}
|
|
|
|
// checkSafetyRuleBodies ensures that variables appearing in negated expressions or non-target
|
|
// positions of built-in expressions will be bound when evaluating the rule from left
|
|
// to right, re-ordering as necessary.
|
|
func (c *Compiler) checkSafetyRuleBodies() {
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
WalkRules(m, func(r *Rule) bool {
|
|
safe := ReservedVars.Copy()
|
|
safe.Update(r.Head.Args.Vars())
|
|
r.Body = c.checkBodySafety(safe, m, r.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) checkBodySafety(safe VarSet, m *Module, b Body) Body {
|
|
reordered, unsafe := reorderBodyForSafety(c.builtins, c.GetArity, safe, b)
|
|
if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
return b
|
|
}
|
|
return reordered
|
|
}
|
|
|
|
// safetyCheckVarVisitorParams are the visitor parameters used by
// checkSafetyRuleHeads when collecting the variables of a rule body.
var safetyCheckVarVisitorParams = VarVisitorParams{
	SkipRefCallHead: true,
	SkipClosures:    true,
}
|
|
|
|
// checkSafetyRuleHeads ensures that variables appearing in the head of a
|
|
// rule also appear in the body.
|
|
func (c *Compiler) checkSafetyRuleHeads() {
|
|
|
|
for _, name := range c.sorted {
|
|
m := c.Modules[name]
|
|
WalkRules(m, func(r *Rule) bool {
|
|
safe := r.Body.Vars(safetyCheckVarVisitorParams)
|
|
safe.Update(r.Head.Args.Vars())
|
|
unsafe := r.Head.Vars().Diff(safe)
|
|
for v := range unsafe {
|
|
if !v.IsGenerated() {
|
|
c.err(NewError(UnsafeVarErr, r.Loc(), "var %v is unsafe", v))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
// checkTypes runs the type checker on all rules. The type checker builds a
|
|
// TypeEnv that is stored on the compiler.
|
|
func (c *Compiler) checkTypes() {
|
|
// Recursion is caught in earlier step, so this cannot fail.
|
|
sorted, _ := c.Graph.Sort()
|
|
checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
|
|
env, errs := checker.CheckTypes(c.TypeEnv, sorted)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
c.TypeEnv = env
|
|
}
|
|
|
|
func (c *Compiler) checkUnsafeBuiltins() {
|
|
for _, name := range c.sorted {
|
|
errs := checkUnsafeBuiltins(c.unsafeBuiltinsMap, c.Modules[name])
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) runStage(metricName string, f func()) {
|
|
if c.metrics != nil {
|
|
c.metrics.Timer(metricName).Start()
|
|
defer c.metrics.Timer(metricName).Stop()
|
|
}
|
|
f()
|
|
}
|
|
|
|
func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
|
|
if c.metrics != nil {
|
|
c.metrics.Timer(metricName).Start()
|
|
defer c.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(c)
|
|
}
|
|
|
|
func (c *Compiler) compile() {
|
|
defer func() {
|
|
if r := recover(); r != nil && r != errLimitReached {
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
for _, s := range c.stages {
|
|
c.runStage(s.metricName, s.f)
|
|
if c.Failed() {
|
|
return
|
|
}
|
|
for _, s := range c.after[s.name] {
|
|
err := c.runStageAfter(s.MetricName, s.Stage)
|
|
if err != nil {
|
|
c.err(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) err(err *Error) {
|
|
if c.maxErrs > 0 && len(c.Errors) >= c.maxErrs {
|
|
c.Errors = append(c.Errors, errLimitReached)
|
|
panic(errLimitReached)
|
|
}
|
|
c.Errors = append(c.Errors, err)
|
|
}
|
|
|
|
func (c *Compiler) getExports() *util.HashMap {
|
|
|
|
rules := util.NewHashMap(func(a, b util.T) bool {
|
|
r1 := a.(Ref)
|
|
r2 := a.(Ref)
|
|
return r1.Equal(r2)
|
|
}, func(v util.T) int {
|
|
return v.(Ref).Hash()
|
|
})
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
rv, ok := rules.Get(mod.Package.Path)
|
|
if !ok {
|
|
rv = []Var{}
|
|
}
|
|
rvs := rv.([]Var)
|
|
|
|
for _, rule := range mod.Rules {
|
|
rvs = append(rvs, rule.Head.Name)
|
|
}
|
|
rules.Put(mod.Package.Path, rvs)
|
|
}
|
|
|
|
return rules
|
|
}
|
|
|
|
// resolveAllRefs resolves references in expressions to their fully qualified values.
|
|
//
|
|
// For instance, given the following module:
|
|
//
|
|
// package a.b
|
|
// import data.foo.bar
|
|
// p[x] { bar[_] = x }
|
|
//
|
|
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
|
|
func (c *Compiler) resolveAllRefs() {
|
|
|
|
rules := c.getExports()
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
|
|
var ruleExports []Var
|
|
if x, ok := rules.Get(mod.Package.Path); ok {
|
|
ruleExports = x.([]Var)
|
|
}
|
|
|
|
globals := getGlobals(mod.Package, ruleExports, mod.Imports)
|
|
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
err := resolveRefsInRule(globals, rule)
|
|
if err != nil {
|
|
c.err(NewError(CompileErr, rule.Location, err.Error()))
|
|
}
|
|
return false
|
|
})
|
|
|
|
// Once imports have been resolved, they are no longer needed.
|
|
mod.Imports = nil
|
|
}
|
|
|
|
if c.moduleLoader != nil {
|
|
|
|
parsed, err := c.moduleLoader(c.Modules)
|
|
if err != nil {
|
|
c.err(NewError(CompileErr, nil, err.Error()))
|
|
return
|
|
}
|
|
|
|
if len(parsed) == 0 {
|
|
return
|
|
}
|
|
|
|
for id, module := range parsed {
|
|
c.Modules[id] = module.Copy()
|
|
c.sorted = append(c.sorted, id)
|
|
}
|
|
|
|
sort.Strings(c.sorted)
|
|
c.resolveAllRefs()
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) initLocalVarGen() {
|
|
c.localvargen = newLocalVarGeneratorForModuleSet(c.sorted, c.Modules)
|
|
}
|
|
|
|
func (c *Compiler) rewriteComprehensionTerms() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
rewriteComprehensionTerms(f, mod)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteExprTerms() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
rewriteExprTermsInHead(c.localvargen, rule)
|
|
rule.Body = rewriteExprTermsInBody(c.localvargen, rule.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
// rewriteTermsInHead will rewrite rules so that the head does not contain any
|
|
// terms that require evaluation (e.g., refs or comprehensions). If the key or
|
|
// value contains or more of these terms, the key or value will be moved into
|
|
// the body and assigned to a new variable. The new variable will replace the
|
|
// key or value in the head.
|
|
//
|
|
// For instance, given the following rule:
|
|
//
|
|
// p[{"foo": data.foo[i]}] { i < 100 }
|
|
//
|
|
// The rule would be re-written as:
|
|
//
|
|
// p[__local0__] { i < 100; __local0__ = {"foo": data.foo[i]} }
|
|
func (c *Compiler) rewriteRefsInHead() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
if requiresEval(rule.Head.Key) {
|
|
expr := f.Generate(rule.Head.Key)
|
|
rule.Head.Key = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
if requiresEval(rule.Head.Value) {
|
|
expr := f.Generate(rule.Head.Value)
|
|
rule.Head.Value = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
for i := 0; i < len(rule.Head.Args); i++ {
|
|
if requiresEval(rule.Head.Args[i]) {
|
|
expr := f.Generate(rule.Head.Args[i])
|
|
rule.Head.Args[i] = expr.Operand(0)
|
|
rule.Body.Append(expr)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteEquals() {
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
rewriteEquals(mod)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteDynamicTerms() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
rule.Body = rewriteDynamics(f, rule.Body)
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteLocalVars() {
|
|
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
gen := c.localvargen
|
|
|
|
WalkRules(mod, func(rule *Rule) bool {
|
|
|
|
var errs Errors
|
|
|
|
// Rewrite assignments contained in head of rule. Assignments can
|
|
// occur in rule head if they're inside a comprehension. Note,
|
|
// assigned vars in comprehensions in the head will be rewritten
|
|
// first to preserve scoping rules. For example:
|
|
//
|
|
// p = [x | x := 1] { x := 2 } becomes p = [__local0__ | __local0__ = 1] { __local1__ = 2 }
|
|
//
|
|
// This behaviour is consistent scoping inside the body. For example:
|
|
//
|
|
// p = xs { x := 2; xs = [x | x := 1] } becomes p = xs { __local0__ = 2; xs = [__local1__ | __local1__ = 1] }
|
|
WalkTerms(rule.Head, func(term *Term) bool {
|
|
stop := false
|
|
stack := newLocalDeclaredVars()
|
|
switch v := term.Value.(type) {
|
|
case *ArrayComprehension:
|
|
errs = rewriteDeclaredVarsInArrayComprehension(gen, stack, v, errs)
|
|
stop = true
|
|
case *SetComprehension:
|
|
errs = rewriteDeclaredVarsInSetComprehension(gen, stack, v, errs)
|
|
stop = true
|
|
case *ObjectComprehension:
|
|
errs = rewriteDeclaredVarsInObjectComprehension(gen, stack, v, errs)
|
|
stop = true
|
|
}
|
|
|
|
for k, v := range stack.rewritten {
|
|
c.RewrittenVars[k] = v
|
|
}
|
|
|
|
return stop
|
|
})
|
|
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
|
|
// Rewrite assignments in body.
|
|
used := NewVarSet()
|
|
|
|
if rule.Head.Key != nil {
|
|
used.Update(rule.Head.Key.Vars())
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
used.Update(rule.Head.Value.Vars())
|
|
}
|
|
|
|
stack := newLocalDeclaredVars()
|
|
|
|
c.rewriteLocalArgVars(gen, stack, rule)
|
|
|
|
body, declared, errs := rewriteLocalVars(gen, stack, used, rule.Body)
|
|
for _, err := range errs {
|
|
c.err(err)
|
|
}
|
|
|
|
// For rewritten vars use the collection of all variables that
|
|
// were in the stack at some point in time.
|
|
for k, v := range stack.rewritten {
|
|
c.RewrittenVars[k] = v
|
|
}
|
|
|
|
rule.Body = body
|
|
|
|
// Rewrite vars in head that refer to locally declared vars in the body.
|
|
vis := NewGenericVisitor(func(x interface{}) bool {
|
|
|
|
term, ok := x.(*Term)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
switch v := term.Value.(type) {
|
|
case Object:
|
|
// Make a copy of the object because the keys may be mutated.
|
|
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
if vark, ok := k.Value.(Var); ok {
|
|
if gv, ok := declared[vark]; ok {
|
|
k = k.Copy()
|
|
k.Value = gv
|
|
}
|
|
}
|
|
return k, v, nil
|
|
})
|
|
term.Value = cpy
|
|
case Var:
|
|
if gv, ok := declared[v]; ok {
|
|
term.Value = gv
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
})
|
|
|
|
vis.Walk(rule.Head.Args)
|
|
|
|
if rule.Head.Key != nil {
|
|
vis.Walk(rule.Head.Key)
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
vis.Walk(rule.Head.Value)
|
|
}
|
|
|
|
return false
|
|
})
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteLocalArgVars(gen *localVarGenerator, stack *localDeclaredVars, rule *Rule) {
|
|
|
|
vis := &ruleArgLocalRewriter{
|
|
stack: stack,
|
|
gen: gen,
|
|
}
|
|
|
|
for i := range rule.Head.Args {
|
|
Walk(vis, rule.Head.Args[i])
|
|
}
|
|
|
|
for i := range vis.errs {
|
|
c.err(vis.errs[i])
|
|
}
|
|
}
|
|
|
|
type ruleArgLocalRewriter struct {
|
|
stack *localDeclaredVars
|
|
gen *localVarGenerator
|
|
errs []*Error
|
|
}
|
|
|
|
func (vis *ruleArgLocalRewriter) Visit(x interface{}) Visitor {
|
|
|
|
t, ok := x.(*Term)
|
|
if !ok {
|
|
return vis
|
|
}
|
|
|
|
switch v := t.Value.(type) {
|
|
case Var:
|
|
gv, ok := vis.stack.Declared(v)
|
|
if !ok {
|
|
gv = vis.gen.Generate()
|
|
vis.stack.Insert(v, gv, argVar)
|
|
}
|
|
t.Value = gv
|
|
return nil
|
|
case Object:
|
|
if cpy, err := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
vcpy := v.Copy()
|
|
Walk(vis, vcpy)
|
|
return k, vcpy, nil
|
|
}); err != nil {
|
|
vis.errs = append(vis.errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = cpy
|
|
}
|
|
return nil
|
|
case Null, Boolean, Number, String, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Set:
|
|
// Scalars are no-ops. Comprehensions are handled above. Sets must not
|
|
// contain variables.
|
|
return nil
|
|
default:
|
|
// Recurse on refs, arrays, and calls. Any embedded
|
|
// variables can be rewritten.
|
|
return vis
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) rewriteWithModifiers() {
|
|
f := newEqualityFactory(c.localvargen)
|
|
for _, name := range c.sorted {
|
|
mod := c.Modules[name]
|
|
t := NewGenericTransformer(func(x interface{}) (interface{}, error) {
|
|
body, ok := x.(Body)
|
|
if !ok {
|
|
return x, nil
|
|
}
|
|
body, err := rewriteWithModifiersInBody(c, f, body)
|
|
if err != nil {
|
|
c.err(err)
|
|
}
|
|
|
|
return body, nil
|
|
})
|
|
Transform(t, mod)
|
|
}
|
|
}
|
|
|
|
func (c *Compiler) setModuleTree() {
|
|
c.ModuleTree = NewModuleTree(c.Modules)
|
|
}
|
|
|
|
func (c *Compiler) setRuleTree() {
|
|
c.RuleTree = NewRuleTree(c.ModuleTree)
|
|
}
|
|
|
|
func (c *Compiler) setGraph() {
|
|
c.Graph = NewGraph(c.Modules, c.GetRulesDynamic)
|
|
}
|
|
|
|
type queryCompiler struct {
|
|
compiler *Compiler
|
|
qctx *QueryContext
|
|
typeEnv *TypeEnv
|
|
rewritten map[Var]Var
|
|
after map[string][]QueryCompilerStageDefinition
|
|
unsafeBuiltins map[string]struct{}
|
|
}
|
|
|
|
func newQueryCompiler(compiler *Compiler) QueryCompiler {
|
|
qc := &queryCompiler{
|
|
compiler: compiler,
|
|
qctx: nil,
|
|
after: map[string][]QueryCompilerStageDefinition{},
|
|
}
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithContext(qctx *QueryContext) QueryCompiler {
|
|
qc.qctx = qctx
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDefinition) QueryCompiler {
|
|
qc.after[after] = append(qc.after[after], stage)
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) WithUnsafeBuiltins(unsafe map[string]struct{}) QueryCompiler {
|
|
qc.unsafeBuiltins = unsafe
|
|
return qc
|
|
}
|
|
|
|
func (qc *queryCompiler) RewrittenVars() map[Var]Var {
|
|
return qc.rewritten
|
|
}
|
|
|
|
func (qc *queryCompiler) runStage(metricName string, qctx *QueryContext, query Body, s func(*QueryContext, Body) (Body, error)) (Body, error) {
|
|
if qc.compiler.metrics != nil {
|
|
qc.compiler.metrics.Timer(metricName).Start()
|
|
defer qc.compiler.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(qctx, query)
|
|
}
|
|
|
|
func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCompilerStage) (Body, error) {
|
|
if qc.compiler.metrics != nil {
|
|
qc.compiler.metrics.Timer(metricName).Start()
|
|
defer qc.compiler.metrics.Timer(metricName).Stop()
|
|
}
|
|
return s(qc, query)
|
|
}
|
|
|
|
func (qc *queryCompiler) Compile(query Body) (Body, error) {
|
|
|
|
query = query.Copy()
|
|
|
|
stages := []struct {
|
|
name string
|
|
metricName string
|
|
f func(*QueryContext, Body) (Body, error)
|
|
}{
|
|
{"ResolveRefs", "query_compile_stage_resolve_refs", qc.resolveRefs},
|
|
{"RewriteLocalVars", "query_compile_stage_rewrite_local_vars", qc.rewriteLocalVars},
|
|
{"RewriteExprTerms", "query_compile_stage_rewrite_expr_terms", qc.rewriteExprTerms},
|
|
{"RewriteComprehensionTerms", "query_compile_stage_rewrite_comprehension_terms", qc.rewriteComprehensionTerms},
|
|
{"RewriteWithValues", "query_compile_stage_rewrite_with_values", qc.rewriteWithModifiers},
|
|
{"CheckUndefinedFuncs", "query_compile_stage_check_undefined_funcs", qc.checkUndefinedFuncs},
|
|
{"CheckSafety", "query_compile_stage_check_safety", qc.checkSafety},
|
|
{"RewriteDynamicTerms", "query_compile_stage_rewrite_dynamic_terms", qc.rewriteDynamicTerms},
|
|
{"CheckTypes", "query_compile_stage_check_types", qc.checkTypes},
|
|
{"CheckUnsafeBuiltins", "query_compile_stage_check_unsafe_builtins", qc.checkUnsafeBuiltins},
|
|
}
|
|
|
|
qctx := qc.qctx.Copy()
|
|
|
|
for _, s := range stages {
|
|
var err error
|
|
query, err = qc.runStage(s.metricName, qctx, query, s.f)
|
|
if err != nil {
|
|
return nil, qc.applyErrorLimit(err)
|
|
}
|
|
for _, s := range qc.after[s.name] {
|
|
query, err = qc.runStageAfter(s.MetricName, query, s.Stage)
|
|
if err != nil {
|
|
return nil, qc.applyErrorLimit(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return query, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) TypeEnv() *TypeEnv {
|
|
return qc.typeEnv
|
|
}
|
|
|
|
func (qc *queryCompiler) applyErrorLimit(err error) error {
|
|
if errs, ok := err.(Errors); ok {
|
|
if qc.compiler.maxErrs > 0 && len(errs) > qc.compiler.maxErrs {
|
|
err = append(errs[:qc.compiler.maxErrs], errLimitReached)
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error) {
|
|
|
|
var globals map[Var]Ref
|
|
|
|
if qctx != nil && qctx.Package != nil {
|
|
var ruleExports []Var
|
|
rules := qc.compiler.getExports()
|
|
if exist, ok := rules.Get(qctx.Package.Path); ok {
|
|
ruleExports = exist.([]Var)
|
|
}
|
|
|
|
globals = getGlobals(qctx.Package, ruleExports, qc.qctx.Imports)
|
|
qctx.Imports = nil
|
|
}
|
|
|
|
ignore := &declaredVarStack{declaredVars(body)}
|
|
|
|
return resolveRefsInBody(globals, ignore, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteComprehensionTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
f := newEqualityFactory(gen)
|
|
node, err := rewriteComprehensionTerms(f, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return node.(Body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteDynamicTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
f := newEqualityFactory(gen)
|
|
return rewriteDynamics(f, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteExprTerms(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
return rewriteExprTermsInBody(gen, body), nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteLocalVars(_ *QueryContext, body Body) (Body, error) {
|
|
gen := newLocalVarGenerator("q", body)
|
|
stack := newLocalDeclaredVars()
|
|
body, _, err := rewriteLocalVars(gen, stack, nil, body)
|
|
if len(err) != 0 {
|
|
return nil, err
|
|
}
|
|
qc.rewritten = make(map[Var]Var, len(stack.rewritten))
|
|
for k, v := range stack.rewritten {
|
|
// The vars returned during the rewrite will include all seen vars,
|
|
// even if they're not declared with an assignment operation. We don't
|
|
// want to include these inside the rewritten set though.
|
|
qc.rewritten[k] = v
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkUndefinedFuncs(_ *QueryContext, body Body) (Body, error) {
|
|
if errs := checkUndefinedFuncs(body, qc.compiler.GetArity); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkSafety(_ *QueryContext, body Body) (Body, error) {
|
|
safe := ReservedVars.Copy()
|
|
reordered, unsafe := reorderBodyForSafety(qc.compiler.builtins, qc.compiler.GetArity, safe, body)
|
|
if errs := safetyErrorSlice(unsafe); len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return reordered, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkTypes(qctx *QueryContext, body Body) (Body, error) {
|
|
var errs Errors
|
|
checker := newTypeChecker().WithVarRewriter(rewriteVarsInRef(qc.rewritten, qc.compiler.RewrittenVars))
|
|
qc.typeEnv, errs = checker.CheckBody(qc.compiler.TypeEnv, body)
|
|
if len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) checkUnsafeBuiltins(qctx *QueryContext, body Body) (Body, error) {
|
|
var unsafe map[string]struct{}
|
|
if qc.unsafeBuiltins != nil {
|
|
unsafe = qc.unsafeBuiltins
|
|
} else {
|
|
unsafe = qc.compiler.unsafeBuiltinsMap
|
|
}
|
|
errs := checkUnsafeBuiltins(unsafe, body)
|
|
if len(errs) > 0 {
|
|
return nil, errs
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func (qc *queryCompiler) rewriteWithModifiers(qctx *QueryContext, body Body) (Body, error) {
|
|
f := newEqualityFactory(newLocalVarGenerator("q", body))
|
|
body, err := rewriteWithModifiersInBody(qc.compiler, f, body)
|
|
if err != nil {
|
|
return nil, Errors{err}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
// ModuleTreeNode represents a node in the module tree. The module
|
|
// tree is keyed by the package path.
|
|
type ModuleTreeNode struct {
|
|
Key Value
|
|
Modules []*Module
|
|
Children map[Value]*ModuleTreeNode
|
|
Hide bool
|
|
}
|
|
|
|
// NewModuleTree returns a new ModuleTreeNode that represents the root
|
|
// of the module tree populated with the given modules.
|
|
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
|
|
root := &ModuleTreeNode{
|
|
Children: map[Value]*ModuleTreeNode{},
|
|
}
|
|
for _, m := range mods {
|
|
node := root
|
|
for i, x := range m.Package.Path {
|
|
c, ok := node.Children[x.Value]
|
|
if !ok {
|
|
var hide bool
|
|
if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 {
|
|
hide = true
|
|
}
|
|
c = &ModuleTreeNode{
|
|
Key: x.Value,
|
|
Children: map[Value]*ModuleTreeNode{},
|
|
Hide: hide,
|
|
}
|
|
node.Children[x.Value] = c
|
|
}
|
|
node = c
|
|
}
|
|
node.Modules = append(node.Modules, m)
|
|
}
|
|
return root
|
|
}
|
|
|
|
// Size returns the number of modules in the tree.
|
|
func (n *ModuleTreeNode) Size() int {
|
|
s := len(n.Modules)
|
|
for _, c := range n.Children {
|
|
s += c.Size()
|
|
}
|
|
return s
|
|
}
|
|
|
|
// DepthFirst performs a depth-first traversal of the module tree rooted at n.
|
|
// If f returns true, traversal will not continue to the children of n.
|
|
func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) {
|
|
if !f(n) {
|
|
for _, node := range n.Children {
|
|
node.DepthFirst(f)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TreeNode represents a node in the rule tree. The rule tree is keyed by
|
|
// rule path.
|
|
type TreeNode struct {
|
|
Key Value
|
|
Values []util.T
|
|
Children map[Value]*TreeNode
|
|
Hide bool
|
|
}
|
|
|
|
// NewRuleTree returns a new TreeNode that represents the root
|
|
// of the rule tree populated with the given rules.
|
|
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
|
|
|
|
ruleSets := map[String][]util.T{}
|
|
|
|
// Build rule sets for this package.
|
|
for _, mod := range mtree.Modules {
|
|
for _, rule := range mod.Rules {
|
|
key := String(rule.Head.Name)
|
|
ruleSets[key] = append(ruleSets[key], rule)
|
|
}
|
|
}
|
|
|
|
// Each rule set becomes a leaf node.
|
|
children := map[Value]*TreeNode{}
|
|
|
|
for key, rules := range ruleSets {
|
|
children[key] = &TreeNode{
|
|
Key: key,
|
|
Children: nil,
|
|
Values: rules,
|
|
}
|
|
}
|
|
|
|
// Each module in subpackage becomes child node.
|
|
for _, child := range mtree.Children {
|
|
children[child.Key] = NewRuleTree(child)
|
|
}
|
|
|
|
return &TreeNode{
|
|
Key: mtree.Key,
|
|
Values: nil,
|
|
Children: children,
|
|
Hide: mtree.Hide,
|
|
}
|
|
}
|
|
|
|
// Size returns the number of rules in the tree.
|
|
func (n *TreeNode) Size() int {
|
|
s := len(n.Values)
|
|
for _, c := range n.Children {
|
|
s += c.Size()
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Child returns n's child with key k.
|
|
func (n *TreeNode) Child(k Value) *TreeNode {
|
|
switch k.(type) {
|
|
case String, Var:
|
|
return n.Children[k]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
|
|
// f returns true, traversal will not continue to the children of n.
|
|
func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
|
|
if !f(n) {
|
|
for _, node := range n.Children {
|
|
node.DepthFirst(f)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Graph represents the graph of dependencies between rules.
|
|
type Graph struct {
|
|
adj map[util.T]map[util.T]struct{}
|
|
nodes map[util.T]struct{}
|
|
sorted []util.T
|
|
}
|
|
|
|
// NewGraph returns a new Graph based on modules. The list function must return
|
|
// the rules referred to directly by the ref.
|
|
func NewGraph(modules map[string]*Module, list func(Ref) []*Rule) *Graph {
|
|
|
|
graph := &Graph{
|
|
adj: map[util.T]map[util.T]struct{}{},
|
|
nodes: map[util.T]struct{}{},
|
|
sorted: nil,
|
|
}
|
|
|
|
// Create visitor to walk a rule AST and add edges to the rule graph for
|
|
// each dependency.
|
|
vis := func(a *Rule) *GenericVisitor {
|
|
stop := false
|
|
return NewGenericVisitor(func(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case Ref:
|
|
for _, b := range list(x) {
|
|
for node := b; node != nil; node = node.Else {
|
|
graph.addDependency(a, node)
|
|
}
|
|
}
|
|
case *Rule:
|
|
if stop {
|
|
// Do not recurse into else clauses (which will be handled
|
|
// by the outer visitor.)
|
|
return true
|
|
}
|
|
stop = true
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
// Walk over all rules, add them to graph, and build adjencency lists.
|
|
for _, module := range modules {
|
|
WalkRules(module, func(a *Rule) bool {
|
|
graph.addNode(a)
|
|
vis(a).Walk(a)
|
|
return false
|
|
})
|
|
}
|
|
|
|
return graph
|
|
}
|
|
|
|
// Dependencies returns the set of rules that x depends on.
|
|
func (g *Graph) Dependencies(x util.T) map[util.T]struct{} {
|
|
return g.adj[x]
|
|
}
|
|
|
|
// Sort returns a slice of rules sorted by dependencies. If a cycle is found,
|
|
// ok is set to false.
|
|
func (g *Graph) Sort() (sorted []util.T, ok bool) {
|
|
if g.sorted != nil {
|
|
return g.sorted, true
|
|
}
|
|
|
|
sort := &graphSort{
|
|
sorted: make([]util.T, 0, len(g.nodes)),
|
|
deps: g.Dependencies,
|
|
marked: map[util.T]struct{}{},
|
|
temp: map[util.T]struct{}{},
|
|
}
|
|
|
|
for node := range g.nodes {
|
|
if !sort.Visit(node) {
|
|
return nil, false
|
|
}
|
|
}
|
|
|
|
g.sorted = sort.sorted
|
|
return g.sorted, true
|
|
}
|
|
|
|
func (g *Graph) addDependency(u util.T, v util.T) {
|
|
|
|
if _, ok := g.nodes[u]; !ok {
|
|
g.addNode(u)
|
|
}
|
|
|
|
if _, ok := g.nodes[v]; !ok {
|
|
g.addNode(v)
|
|
}
|
|
|
|
edges, ok := g.adj[u]
|
|
if !ok {
|
|
edges = map[util.T]struct{}{}
|
|
g.adj[u] = edges
|
|
}
|
|
|
|
edges[v] = struct{}{}
|
|
}
|
|
|
|
func (g *Graph) addNode(n util.T) {
|
|
g.nodes[n] = struct{}{}
|
|
}
|
|
|
|
type graphSort struct {
|
|
sorted []util.T
|
|
deps func(util.T) map[util.T]struct{}
|
|
marked map[util.T]struct{}
|
|
temp map[util.T]struct{}
|
|
}
|
|
|
|
func (sort *graphSort) Marked(node util.T) bool {
|
|
_, marked := sort.marked[node]
|
|
return marked
|
|
}
|
|
|
|
func (sort *graphSort) Visit(node util.T) (ok bool) {
|
|
if _, ok := sort.temp[node]; ok {
|
|
return false
|
|
}
|
|
if sort.Marked(node) {
|
|
return true
|
|
}
|
|
sort.temp[node] = struct{}{}
|
|
for other := range sort.deps(node) {
|
|
if !sort.Visit(other) {
|
|
return false
|
|
}
|
|
}
|
|
sort.marked[node] = struct{}{}
|
|
delete(sort.temp, node)
|
|
sort.sorted = append(sort.sorted, node)
|
|
return true
|
|
}
|
|
|
|
// GraphTraversal is a Traversal that understands the dependency graph
|
|
type GraphTraversal struct {
|
|
graph *Graph
|
|
visited map[util.T]struct{}
|
|
}
|
|
|
|
// NewGraphTraversal returns a Traversal for the dependency graph
|
|
func NewGraphTraversal(graph *Graph) *GraphTraversal {
|
|
return &GraphTraversal{
|
|
graph: graph,
|
|
visited: map[util.T]struct{}{},
|
|
}
|
|
}
|
|
|
|
// Edges lists all dependency connections for a given node
|
|
func (g *GraphTraversal) Edges(x util.T) []util.T {
|
|
r := []util.T{}
|
|
for v := range g.graph.Dependencies(x) {
|
|
r = append(r, v)
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Visited returns whether a node has been visited, setting a node to visited if not
|
|
func (g *GraphTraversal) Visited(u util.T) bool {
|
|
_, ok := g.visited[u]
|
|
g.visited[u] = struct{}{}
|
|
return ok
|
|
}
|
|
|
|
type unsafePair struct {
|
|
Expr *Expr
|
|
Vars VarSet
|
|
}
|
|
|
|
type unsafeVarLoc struct {
|
|
Var Var
|
|
Loc *Location
|
|
}
|
|
|
|
// unsafeVars maps an expression to the set of variables recorded as unsafe
// for it.
type unsafeVars map[*Expr]VarSet
|
|
|
|
func (vs unsafeVars) Add(e *Expr, v Var) {
|
|
if u, ok := vs[e]; ok {
|
|
u[v] = struct{}{}
|
|
} else {
|
|
vs[e] = VarSet{v: struct{}{}}
|
|
}
|
|
}
|
|
|
|
func (vs unsafeVars) Set(e *Expr, s VarSet) {
|
|
vs[e] = s
|
|
}
|
|
|
|
func (vs unsafeVars) Update(o unsafeVars) {
|
|
for k, v := range o {
|
|
if _, ok := vs[k]; !ok {
|
|
vs[k] = VarSet{}
|
|
}
|
|
vs[k].Update(v)
|
|
}
|
|
}
|
|
|
|
func (vs unsafeVars) Vars() (result []unsafeVarLoc) {
|
|
|
|
locs := map[Var]*Location{}
|
|
|
|
// If var appears in multiple sets then pick first by location.
|
|
for expr, vars := range vs {
|
|
for v := range vars {
|
|
if locs[v].Compare(expr.Location) > 0 {
|
|
locs[v] = expr.Location
|
|
}
|
|
}
|
|
}
|
|
|
|
for v, loc := range locs {
|
|
result = append(result, unsafeVarLoc{
|
|
Var: v,
|
|
Loc: loc,
|
|
})
|
|
}
|
|
|
|
sort.Slice(result, func(i, j int) bool {
|
|
return result[i].Loc.Compare(result[j].Loc) < 0
|
|
})
|
|
|
|
return result
|
|
}
|
|
|
|
func (vs unsafeVars) Slice() (result []unsafePair) {
|
|
for expr, vs := range vs {
|
|
result = append(result, unsafePair{
|
|
Expr: expr,
|
|
Vars: vs,
|
|
})
|
|
}
|
|
return
|
|
}
|
|
|
|
// reorderBodyForSafety returns a copy of the body ordered such that
|
|
// left to right evaluation of the body will not encounter unbound variables
|
|
// in input positions or negated expressions.
|
|
//
|
|
// Expressions are added to the re-ordered body as soon as they are considered
|
|
// safe. If multiple expressions become safe in the same pass, they are added
|
|
// in their original order. This results in minimal re-ordering of the body.
|
|
//
|
|
// If the body cannot be reordered to ensure safety, the second return value
|
|
// contains a mapping of expressions to unsafe variables in those expressions.
|
|
func reorderBodyForSafety(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
|
|
|
|
body, unsafe := reorderBodyForClosures(builtins, arity, globals, body)
|
|
if len(unsafe) != 0 {
|
|
return nil, unsafe
|
|
}
|
|
|
|
reordered := Body{}
|
|
safe := VarSet{}
|
|
|
|
for _, e := range body {
|
|
for v := range e.Vars(safetyCheckVarVisitorParams) {
|
|
if globals.Contains(v) {
|
|
safe.Add(v)
|
|
} else {
|
|
unsafe.Add(e, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
for {
|
|
n := len(reordered)
|
|
|
|
for _, e := range body {
|
|
if reordered.Contains(e) {
|
|
continue
|
|
}
|
|
|
|
safe.Update(outputVarsForExpr(e, builtins, arity, safe))
|
|
|
|
for v := range unsafe[e] {
|
|
if safe.Contains(v) {
|
|
delete(unsafe[e], v)
|
|
}
|
|
}
|
|
|
|
if len(unsafe[e]) == 0 {
|
|
delete(unsafe, e)
|
|
reordered = append(reordered, e)
|
|
}
|
|
}
|
|
|
|
if len(reordered) == n {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Recursively visit closures and perform the safety checks on them.
|
|
// Update the globals at each expression to include the variables that could
|
|
// be closed over.
|
|
g := globals.Copy()
|
|
for i, e := range reordered {
|
|
if i > 0 {
|
|
g.Update(reordered[i-1].Vars(safetyCheckVarVisitorParams))
|
|
}
|
|
vis := &bodySafetyVisitor{
|
|
builtins: builtins,
|
|
arity: arity,
|
|
current: e,
|
|
globals: g,
|
|
unsafe: unsafe,
|
|
}
|
|
NewGenericVisitor(vis.Visit).Walk(e)
|
|
}
|
|
|
|
// Need to reset expression indices as re-ordering may have
|
|
// changed them.
|
|
setExprIndices(reordered)
|
|
|
|
return reordered, unsafe
|
|
}
|
|
|
|
type bodySafetyVisitor struct {
|
|
builtins map[string]*Builtin
|
|
arity func(Ref) int
|
|
current *Expr
|
|
globals VarSet
|
|
unsafe unsafeVars
|
|
}
|
|
|
|
func (vis *bodySafetyVisitor) Visit(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case *Expr:
|
|
cpy := *vis
|
|
cpy.current = x
|
|
|
|
switch ts := x.Terms.(type) {
|
|
case *SomeDecl:
|
|
NewGenericVisitor(cpy.Visit).Walk(ts)
|
|
case []*Term:
|
|
for _, t := range ts {
|
|
NewGenericVisitor(cpy.Visit).Walk(t)
|
|
}
|
|
case *Term:
|
|
NewGenericVisitor(cpy.Visit).Walk(ts)
|
|
}
|
|
for i := range x.With {
|
|
NewGenericVisitor(cpy.Visit).Walk(x.With[i])
|
|
}
|
|
return true
|
|
case *ArrayComprehension:
|
|
vis.checkArrayComprehensionSafety(x)
|
|
return true
|
|
case *ObjectComprehension:
|
|
vis.checkObjectComprehensionSafety(x)
|
|
return true
|
|
case *SetComprehension:
|
|
vis.checkSetComprehensionSafety(x)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Check term for safety. This is analogous to the rule head safety check.
|
|
func (vis *bodySafetyVisitor) checkComprehensionSafety(tv VarSet, body Body) Body {
|
|
bv := body.Vars(safetyCheckVarVisitorParams)
|
|
bv.Update(vis.globals)
|
|
uv := tv.Diff(bv)
|
|
for v := range uv {
|
|
vis.unsafe.Add(vis.current, v)
|
|
}
|
|
|
|
// Check body for safety, reordering as necessary.
|
|
r, u := reorderBodyForSafety(vis.builtins, vis.arity, vis.globals, body)
|
|
if len(u) == 0 {
|
|
return r
|
|
}
|
|
|
|
vis.unsafe.Update(u)
|
|
return body
|
|
}
|
|
|
|
func (vis *bodySafetyVisitor) checkArrayComprehensionSafety(ac *ArrayComprehension) {
|
|
ac.Body = vis.checkComprehensionSafety(ac.Term.Vars(), ac.Body)
|
|
}
|
|
|
|
func (vis *bodySafetyVisitor) checkObjectComprehensionSafety(oc *ObjectComprehension) {
|
|
tv := oc.Key.Vars()
|
|
tv.Update(oc.Value.Vars())
|
|
oc.Body = vis.checkComprehensionSafety(tv, oc.Body)
|
|
}
|
|
|
|
func (vis *bodySafetyVisitor) checkSetComprehensionSafety(sc *SetComprehension) {
|
|
sc.Body = vis.checkComprehensionSafety(sc.Term.Vars(), sc.Body)
|
|
}
|
|
|
|
// reorderBodyForClosures returns a copy of the body ordered such that
|
|
// expressions (such as array comprehensions) that close over variables are ordered
|
|
// after other expressions that contain the same variable in an output position.
|
|
func reorderBodyForClosures(builtins map[string]*Builtin, arity func(Ref) int, globals VarSet, body Body) (Body, unsafeVars) {
|
|
|
|
reordered := Body{}
|
|
unsafe := unsafeVars{}
|
|
|
|
for {
|
|
n := len(reordered)
|
|
|
|
for _, e := range body {
|
|
if reordered.Contains(e) {
|
|
continue
|
|
}
|
|
|
|
// Collect vars that are contained in closures within this
|
|
// expression.
|
|
vs := VarSet{}
|
|
WalkClosures(e, func(x interface{}) bool {
|
|
vis := &VarVisitor{vars: vs}
|
|
vis.Walk(x)
|
|
return true
|
|
})
|
|
|
|
// Compute vars that are closed over from the body but not yet
|
|
// contained in the output position of an expression in the reordered
|
|
// body. These vars are considered unsafe.
|
|
cv := vs.Intersect(body.Vars(safetyCheckVarVisitorParams)).Diff(globals)
|
|
uv := cv.Diff(outputVarsForBody(reordered, builtins, arity, globals))
|
|
|
|
if len(uv) == 0 {
|
|
reordered = append(reordered, e)
|
|
delete(unsafe, e)
|
|
} else {
|
|
unsafe.Set(e, uv)
|
|
}
|
|
}
|
|
|
|
if len(reordered) == n {
|
|
break
|
|
}
|
|
}
|
|
|
|
return reordered, unsafe
|
|
}
|
|
|
|
func outputVarsForBody(body Body, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet {
|
|
o := safe.Copy()
|
|
for _, e := range body {
|
|
o.Update(outputVarsForExpr(e, builtins, arity, o))
|
|
}
|
|
return o.Diff(safe)
|
|
}
|
|
|
|
func outputVarsForExpr(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet) VarSet {
|
|
|
|
// Negated expressions must be safe.
|
|
if expr.Negated {
|
|
return VarSet{}
|
|
}
|
|
|
|
// With modifier inputs must be safe.
|
|
for _, with := range expr.With {
|
|
unsafe := false
|
|
WalkVars(with, func(v Var) bool {
|
|
if !safe.Contains(v) {
|
|
unsafe = true
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
if unsafe {
|
|
return VarSet{}
|
|
}
|
|
}
|
|
|
|
if !expr.IsCall() {
|
|
return outputVarsForExprRefs(expr, safe)
|
|
}
|
|
|
|
terms := expr.Terms.([]*Term)
|
|
name := terms[0].String()
|
|
|
|
if b := builtins[name]; b != nil {
|
|
if b.Name == Equality.Name {
|
|
return outputVarsForExprEq(expr, safe)
|
|
}
|
|
return outputVarsForExprBuiltin(expr, b, safe)
|
|
}
|
|
|
|
return outputVarsForExprCall(expr, builtins, arity, safe, terms)
|
|
}
|
|
|
|
func outputVarsForExprBuiltin(expr *Expr, b *Builtin, safe VarSet) VarSet {
|
|
|
|
output := outputVarsForExprRefs(expr, safe)
|
|
terms := expr.Terms.([]*Term)
|
|
|
|
// Check that all input terms are safe.
|
|
for i, t := range terms[1:] {
|
|
if b.IsTargetPos(i) {
|
|
continue
|
|
}
|
|
vis := NewVarVisitor().WithParams(VarVisitorParams{
|
|
SkipClosures: true,
|
|
SkipSets: true,
|
|
SkipObjectKeys: true,
|
|
SkipRefHead: true,
|
|
})
|
|
vis.Walk(t)
|
|
unsafe := vis.Vars().Diff(output).Diff(safe)
|
|
if len(unsafe) > 0 {
|
|
return VarSet{}
|
|
}
|
|
}
|
|
|
|
// Add vars in target positions to result.
|
|
for i, t := range terms[1:] {
|
|
if b.IsTargetPos(i) {
|
|
vis := NewVarVisitor().WithParams(VarVisitorParams{
|
|
SkipRefHead: true,
|
|
SkipSets: true,
|
|
SkipObjectKeys: true,
|
|
SkipClosures: true,
|
|
})
|
|
vis.Walk(t)
|
|
output.Update(vis.vars)
|
|
}
|
|
}
|
|
|
|
return output
|
|
}
|
|
|
|
func outputVarsForExprEq(expr *Expr, safe VarSet) VarSet {
|
|
if !validEqAssignArgCount(expr) {
|
|
return safe
|
|
}
|
|
output := outputVarsForExprRefs(expr, safe)
|
|
output.Update(safe)
|
|
output.Update(Unify(output, expr.Operand(0), expr.Operand(1)))
|
|
return output.Diff(safe)
|
|
}
|
|
|
|
func outputVarsForExprCall(expr *Expr, builtins map[string]*Builtin, arity func(Ref) int, safe VarSet, terms []*Term) VarSet {
|
|
|
|
output := outputVarsForExprRefs(expr, safe)
|
|
|
|
ref, ok := terms[0].Value.(Ref)
|
|
if !ok {
|
|
return VarSet{}
|
|
}
|
|
|
|
numArgs := arity(ref)
|
|
if numArgs == -1 {
|
|
return VarSet{}
|
|
}
|
|
|
|
numInputTerms := numArgs + 1
|
|
|
|
if numInputTerms >= len(terms) {
|
|
return output
|
|
}
|
|
|
|
vis := NewVarVisitor().WithParams(VarVisitorParams{
|
|
SkipClosures: true,
|
|
SkipSets: true,
|
|
SkipObjectKeys: true,
|
|
SkipRefHead: true,
|
|
})
|
|
|
|
vis.Walk(Args(terms[:numInputTerms]))
|
|
unsafe := vis.Vars().Diff(output).Diff(safe)
|
|
|
|
if len(unsafe) > 0 {
|
|
return VarSet{}
|
|
}
|
|
|
|
vis = NewVarVisitor().WithParams(VarVisitorParams{
|
|
SkipRefHead: true,
|
|
SkipSets: true,
|
|
SkipObjectKeys: true,
|
|
SkipClosures: true,
|
|
})
|
|
|
|
vis.Walk(Args(terms[numInputTerms:]))
|
|
output.Update(vis.vars)
|
|
return output
|
|
}
|
|
|
|
func outputVarsForExprRefs(expr *Expr, safe VarSet) VarSet {
|
|
output := VarSet{}
|
|
WalkRefs(expr, func(r Ref) bool {
|
|
if safe.Contains(r[0].Value.(Var)) {
|
|
output.Update(r.OutputVars())
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return output
|
|
}
|
|
|
|
type equalityFactory struct {
|
|
gen *localVarGenerator
|
|
}
|
|
|
|
func newEqualityFactory(gen *localVarGenerator) *equalityFactory {
|
|
return &equalityFactory{gen}
|
|
}
|
|
|
|
func (f *equalityFactory) Generate(other *Term) *Expr {
|
|
term := NewTerm(f.gen.Generate()).SetLocation(other.Location)
|
|
expr := Equality.Expr(term, other)
|
|
expr.Generated = true
|
|
expr.Location = other.Location
|
|
return expr
|
|
}
|
|
|
|
type localVarGenerator struct {
|
|
exclude VarSet
|
|
suffix string
|
|
next int
|
|
}
|
|
|
|
func newLocalVarGeneratorForModuleSet(sorted []string, modules map[string]*Module) *localVarGenerator {
|
|
exclude := NewVarSet()
|
|
vis := &VarVisitor{vars: exclude}
|
|
for _, key := range sorted {
|
|
vis.Walk(modules[key])
|
|
}
|
|
return &localVarGenerator{exclude: exclude, next: 0}
|
|
}
|
|
|
|
func newLocalVarGenerator(suffix string, node interface{}) *localVarGenerator {
|
|
exclude := NewVarSet()
|
|
vis := &VarVisitor{vars: exclude}
|
|
vis.Walk(node)
|
|
return &localVarGenerator{exclude: exclude, suffix: suffix, next: 0}
|
|
}
|
|
|
|
func (l *localVarGenerator) Generate() Var {
|
|
for {
|
|
result := Var("__local" + l.suffix + strconv.Itoa(l.next) + "__")
|
|
l.next++
|
|
if !l.exclude.Contains(result) {
|
|
return result
|
|
}
|
|
}
|
|
}
|
|
|
|
func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]Ref {
|
|
|
|
globals := map[Var]Ref{}
|
|
|
|
// Populate globals with exports within the package.
|
|
for _, v := range rules {
|
|
global := append(Ref{}, pkg.Path...)
|
|
global = append(global, &Term{Value: String(v)})
|
|
globals[v] = global
|
|
}
|
|
|
|
// Populate globals with imports.
|
|
for _, i := range imports {
|
|
if len(i.Alias) > 0 {
|
|
path := i.Path.Value.(Ref)
|
|
globals[i.Alias] = path
|
|
} else {
|
|
path := i.Path.Value.(Ref)
|
|
if len(path) == 1 {
|
|
globals[path[0].Value.(Var)] = path
|
|
} else {
|
|
v := path[len(path)-1].Value.(String)
|
|
globals[Var(v)] = path
|
|
}
|
|
}
|
|
}
|
|
|
|
return globals
|
|
}
|
|
|
|
func requiresEval(x *Term) bool {
|
|
if x == nil {
|
|
return false
|
|
}
|
|
return ContainsRefs(x) || ContainsComprehensions(x)
|
|
}
|
|
|
|
func resolveRef(globals map[Var]Ref, ignore *declaredVarStack, ref Ref) Ref {
|
|
|
|
r := Ref{}
|
|
for i, x := range ref {
|
|
switch v := x.Value.(type) {
|
|
case Var:
|
|
if g, ok := globals[v]; ok && !ignore.Contains(v) {
|
|
cpy := g.Copy()
|
|
for i := range cpy {
|
|
cpy[i].SetLocation(x.Location)
|
|
}
|
|
if i == 0 {
|
|
r = cpy
|
|
} else {
|
|
r = append(r, NewTerm(cpy).SetLocation(x.Location))
|
|
}
|
|
} else {
|
|
r = append(r, x)
|
|
}
|
|
case Ref, Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
r = append(r, resolveRefsInTerm(globals, ignore, x))
|
|
default:
|
|
r = append(r, x)
|
|
}
|
|
}
|
|
|
|
return r
|
|
}
|
|
|
|
func resolveRefsInRule(globals map[Var]Ref, rule *Rule) error {
|
|
ignore := &declaredVarStack{}
|
|
|
|
vars := NewVarSet()
|
|
var vis *GenericVisitor
|
|
var err error
|
|
|
|
// Walk args to collect vars and transform body so that callers can shadow
|
|
// root documents.
|
|
vis = NewGenericVisitor(func(x interface{}) bool {
|
|
if err != nil {
|
|
return true
|
|
}
|
|
switch x := x.(type) {
|
|
case Var:
|
|
vars.Add(x)
|
|
|
|
// Object keys cannot be pattern matched so only walk values.
|
|
case Object:
|
|
for _, k := range x.Keys() {
|
|
vis.Walk(x.Get(k))
|
|
}
|
|
|
|
// Skip terms that could contain vars that cannot be pattern matched.
|
|
case Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
return true
|
|
|
|
case *Term:
|
|
if _, ok := x.Value.(Ref); ok {
|
|
if RootDocumentRefs.Contains(x) {
|
|
// We could support args named input, data, etc. however
|
|
// this would require rewriting terms in the head and body.
|
|
// Preventing root document shadowing is simpler, and
|
|
// arguably, will prevent confusing names from being used.
|
|
err = fmt.Errorf("args must not shadow %v (use a different variable name)", x)
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
|
|
vis.Walk(rule.Head.Args)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ignore.Push(vars)
|
|
ignore.Push(declaredVars(rule.Body))
|
|
|
|
if rule.Head.Key != nil {
|
|
rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
|
|
}
|
|
|
|
if rule.Head.Value != nil {
|
|
rule.Head.Value = resolveRefsInTerm(globals, ignore, rule.Head.Value)
|
|
}
|
|
|
|
rule.Body = resolveRefsInBody(globals, ignore, rule.Body)
|
|
return nil
|
|
}
|
|
|
|
func resolveRefsInBody(globals map[Var]Ref, ignore *declaredVarStack, body Body) Body {
|
|
r := Body{}
|
|
for _, expr := range body {
|
|
r = append(r, resolveRefsInExpr(globals, ignore, expr))
|
|
}
|
|
return r
|
|
}
|
|
|
|
func resolveRefsInExpr(globals map[Var]Ref, ignore *declaredVarStack, expr *Expr) *Expr {
|
|
cpy := *expr
|
|
switch ts := expr.Terms.(type) {
|
|
case *Term:
|
|
cpy.Terms = resolveRefsInTerm(globals, ignore, ts)
|
|
case []*Term:
|
|
buf := make([]*Term, len(ts))
|
|
for i := 0; i < len(ts); i++ {
|
|
buf[i] = resolveRefsInTerm(globals, ignore, ts[i])
|
|
}
|
|
cpy.Terms = buf
|
|
}
|
|
for _, w := range cpy.With {
|
|
w.Target = resolveRefsInTerm(globals, ignore, w.Target)
|
|
w.Value = resolveRefsInTerm(globals, ignore, w.Value)
|
|
}
|
|
return &cpy
|
|
}
|
|
|
|
func resolveRefsInTerm(globals map[Var]Ref, ignore *declaredVarStack, term *Term) *Term {
|
|
switch v := term.Value.(type) {
|
|
case Var:
|
|
if g, ok := globals[v]; ok && !ignore.Contains(v) {
|
|
cpy := g.Copy()
|
|
for i := range cpy {
|
|
cpy[i].SetLocation(term.Location)
|
|
}
|
|
return NewTerm(cpy).SetLocation(term.Location)
|
|
}
|
|
return term
|
|
case Ref:
|
|
fqn := resolveRef(globals, ignore, v)
|
|
cpy := *term
|
|
cpy.Value = fqn
|
|
return &cpy
|
|
case Object:
|
|
cpy := *term
|
|
cpy.Value, _ = v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
k = resolveRefsInTerm(globals, ignore, k)
|
|
v = resolveRefsInTerm(globals, ignore, v)
|
|
return k, v, nil
|
|
})
|
|
return &cpy
|
|
case Array:
|
|
cpy := *term
|
|
cpy.Value = Array(resolveRefsInTermSlice(globals, ignore, v))
|
|
return &cpy
|
|
case Call:
|
|
cpy := *term
|
|
cpy.Value = Call(resolveRefsInTermSlice(globals, ignore, v))
|
|
return &cpy
|
|
case Set:
|
|
s, _ := v.Map(func(e *Term) (*Term, error) {
|
|
return resolveRefsInTerm(globals, ignore, e), nil
|
|
})
|
|
cpy := *term
|
|
cpy.Value = s
|
|
return &cpy
|
|
case *ArrayComprehension:
|
|
ac := &ArrayComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
ac.Term = resolveRefsInTerm(globals, ignore, v.Term)
|
|
ac.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = ac
|
|
ignore.Pop()
|
|
return &cpy
|
|
case *ObjectComprehension:
|
|
oc := &ObjectComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
oc.Key = resolveRefsInTerm(globals, ignore, v.Key)
|
|
oc.Value = resolveRefsInTerm(globals, ignore, v.Value)
|
|
oc.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = oc
|
|
ignore.Pop()
|
|
return &cpy
|
|
case *SetComprehension:
|
|
sc := &SetComprehension{}
|
|
ignore.Push(declaredVars(v.Body))
|
|
sc.Term = resolveRefsInTerm(globals, ignore, v.Term)
|
|
sc.Body = resolveRefsInBody(globals, ignore, v.Body)
|
|
cpy := *term
|
|
cpy.Value = sc
|
|
ignore.Pop()
|
|
return &cpy
|
|
default:
|
|
return term
|
|
}
|
|
}
|
|
|
|
func resolveRefsInTermSlice(globals map[Var]Ref, ignore *declaredVarStack, terms []*Term) []*Term {
|
|
cpy := make([]*Term, len(terms))
|
|
for i := 0; i < len(terms); i++ {
|
|
cpy[i] = resolveRefsInTerm(globals, ignore, terms[i])
|
|
}
|
|
return cpy
|
|
}
|
|
|
|
// declaredVarStack is a stack of var sets; the last element is the top. It is
// used while resolving refs to track vars that must not be treated as globals.
type declaredVarStack []VarSet
|
|
|
|
func (s declaredVarStack) Contains(v Var) bool {
|
|
for i := len(s) - 1; i >= 0; i-- {
|
|
if _, ok := s[i][v]; ok {
|
|
return ok
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s declaredVarStack) Add(v Var) {
|
|
s[len(s)-1].Add(v)
|
|
}
|
|
|
|
func (s *declaredVarStack) Push(vs VarSet) {
|
|
*s = append(*s, vs)
|
|
}
|
|
|
|
func (s *declaredVarStack) Pop() {
|
|
curr := *s
|
|
*s = curr[:len(curr)-1]
|
|
}
|
|
|
|
func declaredVars(x interface{}) VarSet {
|
|
vars := NewVarSet()
|
|
vis := NewGenericVisitor(func(x interface{}) bool {
|
|
switch x := x.(type) {
|
|
case *Expr:
|
|
if x.IsAssignment() && validEqAssignArgCount(x) {
|
|
WalkVars(x.Operand(0), func(v Var) bool {
|
|
vars.Add(v)
|
|
return false
|
|
})
|
|
} else if decl, ok := x.Terms.(*SomeDecl); ok {
|
|
for i := range decl.Symbols {
|
|
vars.Add(decl.Symbols[i].Value.(Var))
|
|
}
|
|
}
|
|
case *ArrayComprehension, *SetComprehension, *ObjectComprehension:
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
vis.Walk(x)
|
|
return vars
|
|
}
|
|
|
|
// rewriteComprehensionTerms will rewrite comprehensions so that the term part
|
|
// is bound to a variable in the body. This allows any type of term to be used
|
|
// in the term part (even if the term requires evaluation.)
|
|
//
|
|
// For instance, given the following comprehension:
|
|
//
|
|
// [x[0] | x = y[_]; y = [1,2,3]]
|
|
//
|
|
// The comprehension would be rewritten as:
|
|
//
|
|
// [__local0__ | x = y[_]; y = [1,2,3]; __local0__ = x[0]]
|
|
func rewriteComprehensionTerms(f *equalityFactory, node interface{}) (interface{}, error) {
|
|
return TransformComprehensions(node, func(x interface{}) (Value, error) {
|
|
switch x := x.(type) {
|
|
case *ArrayComprehension:
|
|
if requiresEval(x.Term) {
|
|
expr := f.Generate(x.Term)
|
|
x.Term = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
case *SetComprehension:
|
|
if requiresEval(x.Term) {
|
|
expr := f.Generate(x.Term)
|
|
x.Term = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
case *ObjectComprehension:
|
|
if requiresEval(x.Key) {
|
|
expr := f.Generate(x.Key)
|
|
x.Key = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
if requiresEval(x.Value) {
|
|
expr := f.Generate(x.Value)
|
|
x.Value = expr.Operand(0)
|
|
x.Body.Append(expr)
|
|
}
|
|
return x, nil
|
|
}
|
|
panic("illegal type")
|
|
})
|
|
}
|
|
|
|
// rewriteEquals will rewrite exprs under x as unification calls instead of ==
|
|
// calls. For example:
|
|
//
|
|
// data.foo == data.bar is rewritten as data.foo = data.bar
|
|
//
|
|
// This stage should only run the safety check (since == is a built-in with no
|
|
// outputs, so the inputs must not be marked as safe.)
|
|
//
|
|
// This stage is not executed by the query compiler by default because when
|
|
// callers specify == instead of = they expect to receive a true/false/undefined
|
|
// result back whereas with = the result is only ever true/undefined. For
|
|
// partial evaluation cases we do want to rewrite == to = to simplify the
|
|
// result.
|
|
func rewriteEquals(x interface{}) {
|
|
doubleEq := Equal.Ref()
|
|
unifyOp := Equality.Ref()
|
|
WalkExprs(x, func(x *Expr) bool {
|
|
if x.IsCall() {
|
|
operator := x.Operator()
|
|
if operator.Equal(doubleEq) && len(x.Operands()) == 2 {
|
|
x.SetOperator(NewTerm(unifyOp))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
|
|
// comprehensions) are bound to vars earlier in the query. This translation
|
|
// results in eager evaluation.
|
|
//
|
|
// For instance, given the following query:
|
|
//
|
|
// foo(data.bar) = 1
|
|
//
|
|
// The rewritten version will be:
|
|
//
|
|
// __local0__ = data.bar; foo(__local0__) = 1
|
|
func rewriteDynamics(f *equalityFactory, body Body) Body {
|
|
result := make(Body, 0, len(body))
|
|
for _, expr := range body {
|
|
if expr.IsEquality() {
|
|
result = rewriteDynamicsEqExpr(f, expr, result)
|
|
} else if expr.IsCall() {
|
|
result = rewriteDynamicsCallExpr(f, expr, result)
|
|
} else {
|
|
result = rewriteDynamicsTermExpr(f, expr, result)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
func appendExpr(body Body, expr *Expr) Body {
|
|
body.Append(expr)
|
|
return body
|
|
}
|
|
|
|
func rewriteDynamicsEqExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
if !validEqAssignArgCount(expr) {
|
|
return appendExpr(result, expr)
|
|
}
|
|
terms := expr.Terms.([]*Term)
|
|
result, terms[1] = rewriteDynamicsInTerm(expr, f, terms[1], result)
|
|
result, terms[2] = rewriteDynamicsInTerm(expr, f, terms[2], result)
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsCallExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
terms := expr.Terms.([]*Term)
|
|
for i := 1; i < len(terms); i++ {
|
|
result, terms[i] = rewriteDynamicsOne(expr, f, terms[i], result)
|
|
}
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsTermExpr(f *equalityFactory, expr *Expr, result Body) Body {
|
|
term := expr.Terms.(*Term)
|
|
result, expr.Terms = rewriteDynamicsInTerm(expr, f, term, result)
|
|
return appendExpr(result, expr)
|
|
}
|
|
|
|
func rewriteDynamicsInTerm(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
|
|
switch v := term.Value.(type) {
|
|
case Ref:
|
|
for i := 1; i < len(v); i++ {
|
|
result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
|
|
}
|
|
case *ArrayComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
case *SetComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
case *ObjectComprehension:
|
|
v.Body = rewriteDynamics(f, v.Body)
|
|
default:
|
|
result, term = rewriteDynamicsOne(original, f, term, result)
|
|
}
|
|
return result, term
|
|
}
|
|
|
|
func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
|
|
switch v := term.Value.(type) {
|
|
case Ref:
|
|
for i := 1; i < len(v); i++ {
|
|
result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
|
|
}
|
|
generated := f.Generate(term)
|
|
generated.With = original.With
|
|
result.Append(generated)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case Array:
|
|
for i := 0; i < len(v); i++ {
|
|
result, v[i] = rewriteDynamicsOne(original, f, v[i], result)
|
|
}
|
|
return result, term
|
|
case Object:
|
|
cpy := NewObject()
|
|
for _, key := range v.Keys() {
|
|
value := v.Get(key)
|
|
result, key = rewriteDynamicsOne(original, f, key, result)
|
|
result, value = rewriteDynamicsOne(original, f, value, result)
|
|
cpy.Insert(key, value)
|
|
}
|
|
return result, NewTerm(cpy).SetLocation(term.Location)
|
|
case Set:
|
|
cpy := NewSet()
|
|
for _, term := range v.Slice() {
|
|
var rw *Term
|
|
result, rw = rewriteDynamicsOne(original, f, term, result)
|
|
cpy.Add(rw)
|
|
}
|
|
return result, NewTerm(cpy).SetLocation(term.Location)
|
|
case *ArrayComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case *SetComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
case *ObjectComprehension:
|
|
var extra *Expr
|
|
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
|
|
result.Append(extra)
|
|
return result, result[len(result)-1].Operand(0)
|
|
}
|
|
return result, term
|
|
}
|
|
|
|
func rewriteDynamicsComprehensionBody(original *Expr, f *equalityFactory, body Body, term *Term) (Body, *Expr) {
|
|
body = rewriteDynamics(f, body)
|
|
generated := f.Generate(term)
|
|
generated.With = original.With
|
|
return body, generated
|
|
}
|
|
|
|
func rewriteExprTermsInHead(gen *localVarGenerator, rule *Rule) {
|
|
if rule.Head.Key != nil {
|
|
support, output := expandExprTerm(gen, rule.Head.Key)
|
|
for i := range support {
|
|
rule.Body.Append(support[i])
|
|
}
|
|
rule.Head.Key = output
|
|
}
|
|
if rule.Head.Value != nil {
|
|
support, output := expandExprTerm(gen, rule.Head.Value)
|
|
for i := range support {
|
|
rule.Body.Append(support[i])
|
|
}
|
|
rule.Head.Value = output
|
|
}
|
|
}
|
|
|
|
func rewriteExprTermsInBody(gen *localVarGenerator, body Body) Body {
|
|
cpy := make(Body, 0, len(body))
|
|
for i := 0; i < len(body); i++ {
|
|
for _, expr := range expandExpr(gen, body[i]) {
|
|
cpy.Append(expr)
|
|
}
|
|
}
|
|
return cpy
|
|
}
|
|
|
|
func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
|
|
for i := range expr.With {
|
|
extras, value := expandExprTerm(gen, expr.With[i].Value)
|
|
expr.With[i].Value = value
|
|
result = append(result, extras...)
|
|
}
|
|
switch terms := expr.Terms.(type) {
|
|
case *Term:
|
|
extras, term := expandExprTerm(gen, terms)
|
|
if len(expr.With) > 0 {
|
|
for i := range extras {
|
|
extras[i].With = expr.With
|
|
}
|
|
}
|
|
result = append(result, extras...)
|
|
expr.Terms = term
|
|
result = append(result, expr)
|
|
case []*Term:
|
|
for i := 1; i < len(terms); i++ {
|
|
var extras []*Expr
|
|
extras, terms[i] = expandExprTerm(gen, terms[i])
|
|
if len(expr.With) > 0 {
|
|
for i := range extras {
|
|
extras[i].With = expr.With
|
|
}
|
|
}
|
|
result = append(result, extras...)
|
|
}
|
|
result = append(result, expr)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
|
|
output = term
|
|
switch v := term.Value.(type) {
|
|
case Call:
|
|
for i := 1; i < len(v); i++ {
|
|
var extras []*Expr
|
|
extras, v[i] = expandExprTerm(gen, v[i])
|
|
support = append(support, extras...)
|
|
}
|
|
output = NewTerm(gen.Generate()).SetLocation(term.Location)
|
|
expr := v.MakeExpr(output).SetLocation(term.Location)
|
|
expr.Generated = true
|
|
support = append(support, expr)
|
|
case Ref:
|
|
support = expandExprRef(gen, v)
|
|
case Array:
|
|
support = expandExprTermSlice(gen, v)
|
|
case Object:
|
|
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
extras1, expandedKey := expandExprTerm(gen, k)
|
|
extras2, expandedValue := expandExprTerm(gen, v)
|
|
support = append(support, extras1...)
|
|
support = append(support, extras2...)
|
|
return expandedKey, expandedValue, nil
|
|
})
|
|
output = NewTerm(cpy).SetLocation(term.Location)
|
|
case Set:
|
|
cpy, _ := v.Map(func(x *Term) (*Term, error) {
|
|
extras, expanded := expandExprTerm(gen, x)
|
|
support = append(support, extras...)
|
|
return expanded, nil
|
|
})
|
|
output = NewTerm(cpy).SetLocation(term.Location)
|
|
case *ArrayComprehension:
|
|
support, term := expandExprTerm(gen, v.Term)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Term = term
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
case *SetComprehension:
|
|
support, term := expandExprTerm(gen, v.Term)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Term = term
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
case *ObjectComprehension:
|
|
support, key := expandExprTerm(gen, v.Key)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Key = key
|
|
support, value := expandExprTerm(gen, v.Value)
|
|
for i := range support {
|
|
v.Body.Append(support[i])
|
|
}
|
|
v.Value = value
|
|
v.Body = rewriteExprTermsInBody(gen, v.Body)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprRef(gen *localVarGenerator, v []*Term) (support []*Expr) {
|
|
// Start by calling a normal expandExprTerm on all terms.
|
|
support = expandExprTermSlice(gen, v)
|
|
|
|
// Rewrite references in order to support indirect references. We rewrite
|
|
// e.g.
|
|
//
|
|
// [1, 2, 3][i]
|
|
//
|
|
// to
|
|
//
|
|
// __local_var = [1, 2, 3]
|
|
// __local_var[i]
|
|
//
|
|
// to support these. This only impacts the reference subject, i.e. the
|
|
// first item in the slice.
|
|
var subject = v[0]
|
|
switch subject.Value.(type) {
|
|
case Array, Object, Set, *ArrayComprehension, *SetComprehension, *ObjectComprehension, Call:
|
|
f := newEqualityFactory(gen)
|
|
assignToLocal := f.Generate(subject)
|
|
support = append(support, assignToLocal)
|
|
v[0] = assignToLocal.Operand(0)
|
|
}
|
|
return
|
|
}
|
|
|
|
func expandExprTermSlice(gen *localVarGenerator, v []*Term) (support []*Expr) {
|
|
for i := 0; i < len(v); i++ {
|
|
var extras []*Expr
|
|
extras, v[i] = expandExprTerm(gen, v[i])
|
|
support = append(support, extras...)
|
|
}
|
|
return
|
|
}
|
|
|
|
type localDeclaredVars struct {
|
|
vars []*declaredVarSet
|
|
|
|
// rewritten contains a mapping of *all* user-defined variables
|
|
// that have been rewritten whereas vars contains the state
|
|
// from the current query (not not any nested queries, and all
|
|
// vars seen).
|
|
rewritten map[Var]Var
|
|
}
|
|
|
|
// varOccurrence classifies how a variable has occurred within a scope.
type varOccurrence int

const (
	// newVar is the zero value; it is what a lookup of an unrecorded variable
	// in an occurrence map yields.
	newVar varOccurrence = iota

	// argVar is reported as "arg ... redeclared" on a redeclaration attempt.
	argVar

	// seenVar marks a variable that was referenced before being declared.
	seenVar

	// assignedVar marks a variable introduced by an assignment.
	assignedVar

	// declaredVar marks a variable introduced by a some-declaration.
	declaredVar
)
|
|
|
|
type declaredVarSet struct {
|
|
vs map[Var]Var
|
|
reverse map[Var]Var
|
|
occurrence map[Var]varOccurrence
|
|
}
|
|
|
|
func newDeclaredVarSet() *declaredVarSet {
|
|
return &declaredVarSet{
|
|
vs: map[Var]Var{},
|
|
reverse: map[Var]Var{},
|
|
occurrence: map[Var]varOccurrence{},
|
|
}
|
|
}
|
|
|
|
func newLocalDeclaredVars() *localDeclaredVars {
|
|
return &localDeclaredVars{
|
|
vars: []*declaredVarSet{newDeclaredVarSet()},
|
|
rewritten: map[Var]Var{},
|
|
}
|
|
}
|
|
|
|
func (s *localDeclaredVars) Push() {
|
|
s.vars = append(s.vars, newDeclaredVarSet())
|
|
}
|
|
|
|
func (s *localDeclaredVars) Pop() *declaredVarSet {
|
|
sl := s.vars
|
|
curr := sl[len(sl)-1]
|
|
s.vars = sl[:len(sl)-1]
|
|
return curr
|
|
}
|
|
|
|
func (s localDeclaredVars) Peek() *declaredVarSet {
|
|
return s.vars[len(s.vars)-1]
|
|
}
|
|
|
|
func (s localDeclaredVars) Insert(x, y Var, occurrence varOccurrence) {
|
|
elem := s.vars[len(s.vars)-1]
|
|
elem.vs[x] = y
|
|
elem.reverse[y] = x
|
|
elem.occurrence[x] = occurrence
|
|
|
|
// If the variable has been rewritten (where x != y, with y being
|
|
// the generated value), store it in the map of rewritten vars.
|
|
// Assume that the generated values are unique for the compilation.
|
|
if !x.Equal(y) {
|
|
s.rewritten[y] = x
|
|
}
|
|
}
|
|
|
|
func (s localDeclaredVars) Declared(x Var) (y Var, ok bool) {
|
|
for i := len(s.vars) - 1; i >= 0; i-- {
|
|
if y, ok = s.vars[i].vs[x]; ok {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// Occurrence returns a flag that indicates whether x has occurred in the
|
|
// current scope.
|
|
func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
|
|
return s.vars[len(s.vars)-1].occurrence[x]
|
|
}
|
|
|
|
// rewriteLocalVars rewrites bodies to remove assignment/declaration
|
|
// expressions. For example:
|
|
//
|
|
// a := 1; p[a]
|
|
//
|
|
// Is rewritten to:
|
|
//
|
|
// __local0__ = 1; p[__local0__]
|
|
//
|
|
// During rewriting, assignees are validated to prevent use before declaration.
|
|
func rewriteLocalVars(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body) (Body, map[Var]Var, Errors) {
|
|
var errs Errors
|
|
body, errs = rewriteDeclaredVarsInBody(g, stack, used, body, errs)
|
|
return body, stack.Pop().vs, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, used VarSet, body Body, errs Errors) (Body, Errors) {
|
|
|
|
var cpy Body
|
|
|
|
for i := range body {
|
|
var expr *Expr
|
|
if body[i].IsAssignment() {
|
|
expr, errs = rewriteDeclaredAssignment(g, stack, body[i], errs)
|
|
} else if decl, ok := body[i].Terms.(*SomeDecl); ok {
|
|
errs = rewriteSomeDeclStatement(g, stack, decl, errs)
|
|
} else {
|
|
expr, errs = rewriteDeclaredVarsInExpr(g, stack, body[i], errs)
|
|
}
|
|
if expr != nil {
|
|
cpy.Append(expr)
|
|
}
|
|
}
|
|
|
|
// If the body only contained a var statement it will be empty at this
|
|
// point. Append true to the body to ensure that it's non-empty (zero length
|
|
// bodies are not supported.)
|
|
if len(cpy) == 0 {
|
|
cpy.Append(NewExpr(BooleanTerm(true)))
|
|
}
|
|
|
|
return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs)
|
|
}
|
|
|
|
func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
|
|
|
|
// NOTE(tsandall): Do not generate more errors if there are existing
|
|
// declaration errors.
|
|
if len(errs) > 0 {
|
|
return errs
|
|
}
|
|
|
|
dvs := stack.Peek()
|
|
declared := NewVarSet()
|
|
|
|
for v, occ := range dvs.occurrence {
|
|
if occ == declaredVar {
|
|
declared.Add(dvs.vs[v])
|
|
}
|
|
}
|
|
|
|
bodyvars := cpy.Vars(VarVisitorParams{})
|
|
|
|
for v := range used {
|
|
if gv, ok := stack.Declared(v); ok {
|
|
bodyvars.Add(gv)
|
|
} else {
|
|
bodyvars.Add(v)
|
|
}
|
|
}
|
|
|
|
unused := declared.Diff(bodyvars).Diff(used)
|
|
|
|
for _, gv := range unused.Sorted() {
|
|
errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", dvs.reverse[gv]))
|
|
}
|
|
|
|
return errs
|
|
}
|
|
|
|
func rewriteSomeDeclStatement(g *localVarGenerator, stack *localDeclaredVars, decl *SomeDecl, errs Errors) Errors {
|
|
for i := range decl.Symbols {
|
|
v := decl.Symbols[i].Value.(Var)
|
|
if _, err := rewriteDeclaredVar(g, stack, v, declaredVar); err != nil {
|
|
errs = append(errs, NewError(CompileErr, decl.Loc(), err.Error()))
|
|
}
|
|
}
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {
|
|
vis := NewGenericVisitor(func(x interface{}) bool {
|
|
var stop bool
|
|
switch x := x.(type) {
|
|
case *Term:
|
|
stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs)
|
|
case *With:
|
|
_, errs = rewriteDeclaredVarsInTerm(g, stack, x.Value, errs)
|
|
stop = true
|
|
}
|
|
return stop
|
|
})
|
|
vis.Walk(expr)
|
|
return expr, errs
|
|
}
|
|
|
|
func rewriteDeclaredAssignment(g *localVarGenerator, stack *localDeclaredVars, expr *Expr, errs Errors) (*Expr, Errors) {
|
|
|
|
if expr.Negated {
|
|
errs = append(errs, NewError(CompileErr, expr.Location, "cannot assign vars inside negated expression"))
|
|
return expr, errs
|
|
}
|
|
|
|
numErrsBefore := len(errs)
|
|
|
|
if !validEqAssignArgCount(expr) {
|
|
return expr, errs
|
|
}
|
|
|
|
// Rewrite terms on right hand side capture seen vars and recursively
|
|
// process comprehensions before left hand side is processed. Also
|
|
// rewrite with modifier.
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, expr.Operand(1), errs)
|
|
|
|
for _, w := range expr.With {
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs)
|
|
}
|
|
|
|
// Rewrite vars on left hand side with unique names. Catch redeclaration
|
|
// and invalid term types here.
|
|
var vis func(t *Term) bool
|
|
|
|
vis = func(t *Term) bool {
|
|
switch v := t.Value.(type) {
|
|
case Var:
|
|
if gv, err := rewriteDeclaredVar(g, stack, v, assignedVar); err != nil {
|
|
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = gv
|
|
}
|
|
return true
|
|
case Array:
|
|
return false
|
|
case Object:
|
|
v.Foreach(func(_, v *Term) {
|
|
WalkTerms(v, vis)
|
|
})
|
|
return true
|
|
case Ref:
|
|
if RootDocumentRefs.Contains(t) {
|
|
if gv, err := rewriteDeclaredVar(g, stack, v[0].Value.(Var), assignedVar); err != nil {
|
|
errs = append(errs, NewError(CompileErr, t.Location, err.Error()))
|
|
} else {
|
|
t.Value = gv
|
|
}
|
|
return true
|
|
}
|
|
}
|
|
errs = append(errs, NewError(CompileErr, t.Location, "cannot assign to %v", TypeName(t.Value)))
|
|
return true
|
|
}
|
|
|
|
WalkTerms(expr.Operand(0), vis)
|
|
|
|
if len(errs) == numErrsBefore {
|
|
loc := expr.Operator()[0].Location
|
|
expr.SetOperator(RefTerm(VarTerm(Equality.Name).SetLocation(loc)).SetLocation(loc))
|
|
}
|
|
|
|
return expr, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) (bool, Errors) {
|
|
switch v := term.Value.(type) {
|
|
case Var:
|
|
if gv, ok := stack.Declared(v); ok {
|
|
term.Value = gv
|
|
} else if stack.Occurrence(v) == newVar {
|
|
stack.Insert(v, v, seenVar)
|
|
}
|
|
case Ref:
|
|
if RootDocumentRefs.Contains(term) {
|
|
if gv, ok := stack.Declared(v[0].Value.(Var)); ok {
|
|
term.Value = gv
|
|
}
|
|
return true, errs
|
|
}
|
|
return false, errs
|
|
case Object:
|
|
cpy, _ := v.Map(func(k, v *Term) (*Term, *Term, error) {
|
|
kcpy := k.Copy()
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, kcpy, errs)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v, errs)
|
|
return kcpy, v, nil
|
|
})
|
|
term.Value = cpy
|
|
case Set:
|
|
cpy, _ := v.Map(func(elem *Term) (*Term, error) {
|
|
elemcpy := elem.Copy()
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, elemcpy, errs)
|
|
return elemcpy, nil
|
|
})
|
|
term.Value = cpy
|
|
case *ArrayComprehension:
|
|
errs = rewriteDeclaredVarsInArrayComprehension(g, stack, v, errs)
|
|
case *SetComprehension:
|
|
errs = rewriteDeclaredVarsInSetComprehension(g, stack, v, errs)
|
|
case *ObjectComprehension:
|
|
errs = rewriteDeclaredVarsInObjectComprehension(g, stack, v, errs)
|
|
default:
|
|
return false, errs
|
|
}
|
|
return true, errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors) Errors {
|
|
WalkNodes(term, func(n Node) bool {
|
|
var stop bool
|
|
switch n := n.(type) {
|
|
case *With:
|
|
_, errs = rewriteDeclaredVarsInTerm(g, stack, n.Value, errs)
|
|
stop = true
|
|
case *Term:
|
|
stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs)
|
|
}
|
|
return stop
|
|
})
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors) Errors {
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInSetComprehension(g *localVarGenerator, stack *localDeclaredVars, v *SetComprehension, errs Errors) Errors {
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Term, errs)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVarsInObjectComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ObjectComprehension, errs Errors) Errors {
|
|
stack.Push()
|
|
v.Body, errs = rewriteDeclaredVarsInBody(g, stack, nil, v.Body, errs)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Key, errs)
|
|
errs = rewriteDeclaredVarsInTermRecursive(g, stack, v.Value, errs)
|
|
stack.Pop()
|
|
return errs
|
|
}
|
|
|
|
func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, occ varOccurrence) (gv Var, err error) {
|
|
switch stack.Occurrence(v) {
|
|
case seenVar:
|
|
return gv, fmt.Errorf("var %v referenced above", v)
|
|
case assignedVar:
|
|
return gv, fmt.Errorf("var %v assigned above", v)
|
|
case declaredVar:
|
|
return gv, fmt.Errorf("var %v declared above", v)
|
|
case argVar:
|
|
return gv, fmt.Errorf("arg %v redeclared", v)
|
|
}
|
|
gv = g.Generate()
|
|
stack.Insert(v, gv, occ)
|
|
return
|
|
}
|
|
|
|
// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
|
|
// not contain terms that require evaluation as values. If this function
|
|
// encounters an invalid with modifier target then it will raise an error.
|
|
func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) {
|
|
var result Body
|
|
for i := range body {
|
|
exprs, err := rewriteWithModifier(c, f, body[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(exprs) > 0 {
|
|
for _, expr := range exprs {
|
|
result.Append(expr)
|
|
}
|
|
} else {
|
|
result.Append(body[i])
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
|
|
|
|
var result []*Expr
|
|
for i := range expr.With {
|
|
err := validateTarget(c, expr.With[i].Target)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if requiresEval(expr.With[i].Value) {
|
|
eq := f.Generate(expr.With[i].Value)
|
|
result = append(result, eq)
|
|
expr.With[i].Value = eq.Operand(0)
|
|
}
|
|
}
|
|
|
|
// If any of the with modifiers in this expression were rewritten then result
|
|
// will be non-empty. In this case, the expression will have been modified and
|
|
// it should also be added to the result.
|
|
if len(result) > 0 {
|
|
result = append(result, expr)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func validateTarget(c *Compiler, term *Term) *Error {
|
|
if !isInputRef(term) && !isDataRef(term) {
|
|
return NewError(TypeErr, term.Location, "with keyword target must start with %v or %v", InputRootDocument, DefaultRootDocument)
|
|
}
|
|
|
|
if isDataRef(term) {
|
|
ref := term.Value.(Ref)
|
|
node := c.RuleTree
|
|
for i := 0; i < len(ref)-1; i++ {
|
|
child := node.Child(ref[i].Value)
|
|
if child == nil {
|
|
break
|
|
} else if len(child.Values) > 0 {
|
|
return NewError(CompileErr, term.Loc(), "with keyword cannot partially replace virtual document(s)")
|
|
}
|
|
node = child
|
|
}
|
|
|
|
if node != nil {
|
|
if child := node.Child(ref[len(ref)-1].Value); child != nil {
|
|
for _, value := range child.Values {
|
|
if len(value.(*Rule).Head.Args) > 0 {
|
|
return NewError(CompileErr, term.Loc(), "with keyword cannot replace functions")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isInputRef(term *Term) bool {
|
|
if ref, ok := term.Value.(Ref); ok {
|
|
if ref.HasPrefix(InputRootRef) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isDataRef(term *Term) bool {
|
|
if ref, ok := term.Value.(Ref); ok {
|
|
if ref.HasPrefix(DefaultRootRef) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func isVirtual(node *TreeNode, ref Ref) bool {
|
|
for i := 0; i < len(ref); i++ {
|
|
child := node.Child(ref[i].Value)
|
|
if child == nil {
|
|
return false
|
|
} else if len(child.Values) > 0 {
|
|
return true
|
|
}
|
|
node = child
|
|
}
|
|
return true
|
|
}
|
|
|
|
func safetyErrorSlice(unsafe unsafeVars) (result Errors) {
|
|
|
|
if len(unsafe) == 0 {
|
|
return
|
|
}
|
|
|
|
for _, pair := range unsafe.Vars() {
|
|
if !pair.Var.IsGenerated() {
|
|
result = append(result, NewError(UnsafeVarErr, pair.Loc, "var %v is unsafe", pair.Var))
|
|
}
|
|
}
|
|
|
|
if len(result) > 0 {
|
|
return
|
|
}
|
|
|
|
// If the expression contains unsafe generated variables, report which
|
|
// expressions are unsafe instead of the variables that are unsafe (since
|
|
// the latter are not meaningful to the user.)
|
|
pairs := unsafe.Slice()
|
|
|
|
sort.Slice(pairs, func(i, j int) bool {
|
|
return pairs[i].Expr.Location.Compare(pairs[j].Expr.Location) < 0
|
|
})
|
|
|
|
// Report at most one error per generated variable.
|
|
seen := NewVarSet()
|
|
|
|
for _, expr := range pairs {
|
|
before := len(seen)
|
|
for v := range expr.Vars {
|
|
if v.IsGenerated() {
|
|
seen.Add(v)
|
|
}
|
|
}
|
|
if len(seen) > before {
|
|
result = append(result, NewError(UnsafeVarErr, expr.Expr.Location, "expression is unsafe"))
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func checkUnsafeBuiltins(unsafeBuiltinsMap map[string]struct{}, node interface{}) Errors {
|
|
errs := make(Errors, 0)
|
|
WalkExprs(node, func(x *Expr) bool {
|
|
if x.IsCall() {
|
|
operator := x.Operator().String()
|
|
if _, ok := unsafeBuiltinsMap[operator]; ok {
|
|
errs = append(errs, NewError(TypeErr, x.Loc(), "unsafe built-in function calls in expression: %v", operator))
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
return errs
|
|
}
|
|
|
|
func rewriteVarsInRef(vars ...map[Var]Var) func(Ref) Ref {
|
|
return func(node Ref) Ref {
|
|
i, _ := TransformVars(node, func(v Var) (Value, error) {
|
|
for _, m := range vars {
|
|
if u, ok := m[v]; ok {
|
|
return u, nil
|
|
}
|
|
}
|
|
return v, nil
|
|
})
|
|
return i.(Ref)
|
|
}
|
|
}
|
|
|
|
func rewriteVarsNop(node Ref) Ref {
|
|
return node
|
|
}
|