Update dependencies (#5518)
This commit is contained in:
596
vendor/github.com/open-policy-agent/opa/ast/compile.go
generated
vendored
596
vendor/github.com/open-policy-agent/opa/ast/compile.go
generated
vendored
@@ -64,6 +64,7 @@ type Compiler struct {
|
||||
// p[1] { true }
|
||||
// p[2] { true }
|
||||
// q = true
|
||||
// a.b.c = 3
|
||||
//
|
||||
// root
|
||||
// |
|
||||
@@ -74,6 +75,12 @@ type Compiler struct {
|
||||
// +--- p (2 rules)
|
||||
// |
|
||||
// +--- q (1 rule)
|
||||
// |
|
||||
// +--- a
|
||||
// |
|
||||
// +--- b
|
||||
// |
|
||||
// +--- c (1 rule)
|
||||
RuleTree *TreeNode
|
||||
|
||||
// Graph contains dependencies between rules. An edge (u,v) is added to the
|
||||
@@ -95,26 +102,27 @@ type Compiler struct {
|
||||
metricName string
|
||||
f func()
|
||||
}
|
||||
maxErrs int
|
||||
sorted []string // list of sorted module names
|
||||
pathExists func([]string) (bool, error)
|
||||
after map[string][]CompilerStageDefinition
|
||||
metrics metrics.Metrics
|
||||
capabilities *Capabilities // user-supplied capabilities
|
||||
builtins map[string]*Builtin // universe of built-in functions
|
||||
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
|
||||
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
|
||||
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
|
||||
enablePrintStatements bool // indicates if print statements should be elided (default)
|
||||
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
|
||||
initialized bool // indicates if init() has been called
|
||||
debug debug.Debug // emits debug information produced during compilation
|
||||
schemaSet *SchemaSet // user-supplied schemas for input and data documents
|
||||
inputType types.Type // global input type retrieved from schema set
|
||||
annotationSet *AnnotationSet // hierarchical set of annotations
|
||||
strict bool // enforce strict compilation checks
|
||||
keepModules bool // whether to keep the unprocessed, parse modules (below)
|
||||
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
|
||||
maxErrs int
|
||||
sorted []string // list of sorted module names
|
||||
pathExists func([]string) (bool, error)
|
||||
after map[string][]CompilerStageDefinition
|
||||
metrics metrics.Metrics
|
||||
capabilities *Capabilities // user-supplied capabilities
|
||||
builtins map[string]*Builtin // universe of built-in functions
|
||||
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
|
||||
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
|
||||
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
|
||||
enablePrintStatements bool // indicates if print statements should be elided (default)
|
||||
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
|
||||
initialized bool // indicates if init() has been called
|
||||
debug debug.Debug // emits debug information produced during compilation
|
||||
schemaSet *SchemaSet // user-supplied schemas for input and data documents
|
||||
inputType types.Type // global input type retrieved from schema set
|
||||
annotationSet *AnnotationSet // hierarchical set of annotations
|
||||
strict bool // enforce strict compilation checks
|
||||
keepModules bool // whether to keep the unprocessed, parse modules (below)
|
||||
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
|
||||
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
|
||||
}
|
||||
|
||||
// CompilerStage defines the interface for stages in the compiler.
|
||||
@@ -221,6 +229,9 @@ type QueryCompiler interface {
|
||||
// ComprehensionIndex returns an index data structure for the given comprehension
|
||||
// term. If no index is found, returns nil.
|
||||
ComprehensionIndex(term *Term) *ComprehensionIndex
|
||||
|
||||
// WithStrict enables strict mode for the query compiler.
|
||||
WithStrict(strict bool) QueryCompiler
|
||||
}
|
||||
|
||||
// QueryCompilerStage defines the interface for stages in the query compiler.
|
||||
@@ -265,15 +276,16 @@ func NewCompiler() *Compiler {
|
||||
// load additional modules. If any stages run before resolution, they
|
||||
// need to be re-run after resolution.
|
||||
{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
|
||||
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
|
||||
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
|
||||
{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
|
||||
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
|
||||
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
|
||||
// The local variable generator must be initialized after references are
|
||||
// resolved and the dynamic module loader has run but before subsequent
|
||||
// stages that need to generate variables.
|
||||
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
|
||||
{"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs},
|
||||
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
|
||||
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
|
||||
{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
|
||||
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
|
||||
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs
|
||||
{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
|
||||
{"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls},
|
||||
{"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls},
|
||||
@@ -396,6 +408,12 @@ func (c *Compiler) WithKeepModules(y bool) *Compiler {
|
||||
return c
|
||||
}
|
||||
|
||||
// WithUseTypeCheckAnnotations use schema annotations during type checking
|
||||
func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler {
|
||||
c.useTypeCheckAnnotations = enabled
|
||||
return c
|
||||
}
|
||||
|
||||
// ParsedModules returns the parsed, unprocessed modules from the compiler.
|
||||
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
|
||||
// The map includes all modules loaded via the ModuleLoader, if one was used.
|
||||
@@ -403,10 +421,10 @@ func (c *Compiler) ParsedModules() map[string]*Module {
|
||||
return c.parsedModules
|
||||
}
|
||||
|
||||
// QueryCompiler returns a new QueryCompiler object.
|
||||
func (c *Compiler) QueryCompiler() QueryCompiler {
|
||||
c.init()
|
||||
return newQueryCompiler(c)
|
||||
c0 := *c
|
||||
return newQueryCompiler(&c0)
|
||||
}
|
||||
|
||||
// Compile runs the compilation process on the input modules. The compiled
|
||||
@@ -570,9 +588,10 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
|
||||
return rules
|
||||
}
|
||||
|
||||
func extractRules(s []util.T) (rules []*Rule) {
|
||||
for _, r := range s {
|
||||
rules = append(rules, r.(*Rule))
|
||||
func extractRules(s []util.T) []*Rule {
|
||||
rules := make([]*Rule, len(s))
|
||||
for i := range s {
|
||||
rules[i] = s[i].(*Rule)
|
||||
}
|
||||
return rules
|
||||
}
|
||||
@@ -768,13 +787,30 @@ func (c *Compiler) buildRuleIndices() {
|
||||
if len(node.Values) == 0 {
|
||||
return false
|
||||
}
|
||||
rules := extractRules(node.Values)
|
||||
hasNonGroundKey := false
|
||||
for _, r := range rules {
|
||||
if ref := r.Head.Ref(); len(ref) > 1 {
|
||||
if !ref[len(ref)-1].IsGround() {
|
||||
hasNonGroundKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasNonGroundKey {
|
||||
// collect children: as of now, this cannot go deeper than one level,
|
||||
// so we grab those, and abort the DepthFirst processing for this branch
|
||||
for _, n := range node.Children {
|
||||
rules = append(rules, extractRules(n.Values)...)
|
||||
}
|
||||
}
|
||||
|
||||
index := newBaseDocEqIndex(func(ref Ref) bool {
|
||||
return isVirtual(c.RuleTree, ref.GroundPrefix())
|
||||
})
|
||||
if rules := extractRules(node.Values); index.Build(rules) {
|
||||
c.ruleIndices.Put(rules[0].Path(), index)
|
||||
if index.Build(rules) {
|
||||
c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index)
|
||||
}
|
||||
return false
|
||||
return hasNonGroundKey // currently, we don't allow those branches to go deeper
|
||||
})
|
||||
|
||||
}
|
||||
@@ -811,7 +847,7 @@ func (c *Compiler) checkRecursion() {
|
||||
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
|
||||
tr := NewGraphTraversal(c.Graph)
|
||||
if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
|
||||
n := []string{}
|
||||
n := make([]string, 0, len(p))
|
||||
for _, x := range p {
|
||||
n = append(n, astNodeToString(x))
|
||||
}
|
||||
@@ -820,40 +856,69 @@ func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b
|
||||
}
|
||||
|
||||
func astNodeToString(x interface{}) string {
|
||||
switch x := x.(type) {
|
||||
case *Rule:
|
||||
return string(x.Head.Name)
|
||||
default:
|
||||
panic("not reached")
|
||||
}
|
||||
return x.(*Rule).Ref().String()
|
||||
}
|
||||
|
||||
// checkRuleConflicts ensures that rules definitions are not in conflict.
|
||||
func (c *Compiler) checkRuleConflicts() {
|
||||
rw := rewriteVarsInRef(c.RewrittenVars)
|
||||
|
||||
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
|
||||
if len(node.Values) == 0 {
|
||||
return false
|
||||
return false // go deeper
|
||||
}
|
||||
|
||||
kinds := map[DocKind]struct{}{}
|
||||
kinds := map[RuleKind]struct{}{}
|
||||
defaultRules := 0
|
||||
arities := map[int]struct{}{}
|
||||
name := ""
|
||||
var singleValueConflicts []Ref
|
||||
|
||||
for _, rule := range node.Values {
|
||||
r := rule.(*Rule)
|
||||
kinds[r.Head.DocKind()] = struct{}{}
|
||||
ref := r.Ref()
|
||||
name = rw(ref.Copy()).String() // varRewriter operates in-place
|
||||
kinds[r.Head.RuleKind()] = struct{}{}
|
||||
arities[len(r.Head.Args)] = struct{}{}
|
||||
if r.Default {
|
||||
defaultRules++
|
||||
}
|
||||
|
||||
// Single-value rules may not have any other rules in their extent: these pairs are invalid:
|
||||
//
|
||||
// data.p.q.r { true } # data.p.q is { "r": true }
|
||||
// data.p.q.r.s { true }
|
||||
//
|
||||
// data.p.q[r] { r := input.r } # data.p.q could be { "r": true }
|
||||
// data.p.q.r.s { true }
|
||||
|
||||
// But this is allowed:
|
||||
// data.p.q[r] = 1 { r := "r" }
|
||||
// data.p.q.s = 2
|
||||
|
||||
if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 {
|
||||
if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren
|
||||
for _, c := range node.Children {
|
||||
if len(c.Children) > 0 {
|
||||
singleValueConflicts = node.flattenChildren()
|
||||
break
|
||||
}
|
||||
}
|
||||
} else { // p.q.s and p.q.s.t => any children are in conflict
|
||||
singleValueConflicts = node.flattenChildren()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := Var(node.Key.(String))
|
||||
switch {
|
||||
case singleValueConflicts != nil:
|
||||
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts))
|
||||
|
||||
if len(kinds) > 1 || len(arities) > 1 {
|
||||
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
|
||||
} else if defaultRules > 1 {
|
||||
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
|
||||
case len(kinds) > 1 || len(arities) > 1:
|
||||
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name))
|
||||
|
||||
case defaultRules > 1:
|
||||
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -865,13 +930,21 @@ func (c *Compiler) checkRuleConflicts() {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(sr): depthfirst might better use sorted for stable errs?
|
||||
c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
|
||||
for _, mod := range node.Modules {
|
||||
for _, rule := range mod.Rules {
|
||||
if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
|
||||
ref := rule.Head.Ref().GroundPrefix()
|
||||
childNode, tail := node.find(ref)
|
||||
if childNode != nil {
|
||||
for _, childMod := range childNode.Modules {
|
||||
msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc())
|
||||
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
|
||||
if childMod.Equal(mod) {
|
||||
continue // don't self-conflict
|
||||
}
|
||||
if len(tail) == 0 {
|
||||
msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
|
||||
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1239,7 +1312,11 @@ func (c *Compiler) checkTypes() {
|
||||
WithSchemaSet(c.schemaSet).
|
||||
WithInputType(c.inputType).
|
||||
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
|
||||
env, errs := checker.CheckTypes(c.TypeEnv, sorted, c.annotationSet)
|
||||
var as *AnnotationSet
|
||||
if c.useTypeCheckAnnotations {
|
||||
as = c.annotationSet
|
||||
}
|
||||
env, errs := checker.CheckTypes(c.TypeEnv, sorted, as)
|
||||
for _, err := range errs {
|
||||
c.err(err)
|
||||
}
|
||||
@@ -1363,21 +1440,29 @@ func (c *Compiler) getExports() *util.HashMap {
|
||||
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
rv, ok := rules.Get(mod.Package.Path)
|
||||
if !ok {
|
||||
rv = []Var{}
|
||||
}
|
||||
rvs := rv.([]Var)
|
||||
|
||||
for _, rule := range mod.Rules {
|
||||
rvs = append(rvs, rule.Head.Name)
|
||||
hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix())
|
||||
}
|
||||
rules.Put(mod.Package.Path, rvs)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func hashMapAdd(rules *util.HashMap, pkg, rule Ref) {
|
||||
prev, ok := rules.Get(pkg)
|
||||
if !ok {
|
||||
rules.Put(pkg, []Ref{rule})
|
||||
return
|
||||
}
|
||||
for _, p := range prev.([]Ref) {
|
||||
if p.Equal(rule) {
|
||||
return
|
||||
}
|
||||
}
|
||||
rules.Put(pkg, append(prev.([]Ref), rule))
|
||||
}
|
||||
|
||||
func (c *Compiler) GetAnnotationSet() *AnnotationSet {
|
||||
return c.annotationSet
|
||||
}
|
||||
@@ -1450,6 +1535,15 @@ func checkKeywordOverrides(node interface{}, strict bool) Errors {
|
||||
// p[x] { bar[_] = x }
|
||||
//
|
||||
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
|
||||
//
|
||||
// Ref rules are resolved, too:
|
||||
//
|
||||
// package a.b
|
||||
// q { c.d.e == 1 }
|
||||
// c.d[e] := 1 if e := "e"
|
||||
//
|
||||
// The reference "c.d.e" would be resolved to "data.a.b.c.d.e".
|
||||
|
||||
func (c *Compiler) resolveAllRefs() {
|
||||
|
||||
rules := c.getExports()
|
||||
@@ -1457,9 +1551,9 @@ func (c *Compiler) resolveAllRefs() {
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
|
||||
var ruleExports []Var
|
||||
var ruleExports []Ref
|
||||
if x, ok := rules.Get(mod.Package.Path); ok {
|
||||
ruleExports = x.([]Var)
|
||||
ruleExports = x.([]Ref)
|
||||
}
|
||||
|
||||
globals := getGlobals(mod.Package, ruleExports, mod.Imports)
|
||||
@@ -1542,6 +1636,65 @@ func (c *Compiler) rewriteExprTerms() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) rewriteRuleHeadRefs() {
|
||||
f := newEqualityFactory(c.localvargen)
|
||||
for _, name := range c.sorted {
|
||||
WalkRules(c.Modules[name], func(rule *Rule) bool {
|
||||
|
||||
ref := rule.Head.Ref()
|
||||
// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
|
||||
// it's possible to construct Module{} instances from Golang code, so we need
|
||||
// to accommodate for that, too.
|
||||
if len(rule.Head.Reference) == 0 {
|
||||
rule.Head.Reference = ref
|
||||
}
|
||||
|
||||
cannotSpeakRefs := true
|
||||
for _, f := range c.capabilities.Features {
|
||||
if f == FeatureRefHeadStringPrefixes {
|
||||
cannotSpeakRefs = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if cannotSpeakRefs && rule.Head.Name == "" {
|
||||
c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference))
|
||||
return true
|
||||
}
|
||||
|
||||
for i := 1; i < len(ref); i++ {
|
||||
// NOTE(sr): In the first iteration, non-string values in the refs are forbidden
|
||||
// except for the last position, e.g.
|
||||
// OK: p.q.r[s]
|
||||
// NOT OK: p[q].r.s
|
||||
// TODO(sr): This is stricter than necessary. We could allow any non-var values there,
|
||||
// but we'll also have to adjust the type tree, for example.
|
||||
if i != len(ref)-1 { // last
|
||||
if _, ok := ref[i].Value.(String); !ok {
|
||||
c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i]))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Rewrite so that any non-scalar elements that in the last position of
|
||||
// the rule are vars:
|
||||
// p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z }
|
||||
// because that's what the RuleTree knows how to deal with.
|
||||
if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) {
|
||||
expr := f.Generate(ref[i])
|
||||
if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) {
|
||||
rule.Head.Key = expr.Operand(0)
|
||||
}
|
||||
rule.Head.Reference[i] = expr.Operand(0)
|
||||
rule.Body.Append(expr)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Compiler) checkVoidCalls() {
|
||||
for _, name := range c.sorted {
|
||||
mod := c.Modules[name]
|
||||
@@ -2044,6 +2197,9 @@ func (c *Compiler) rewriteLocalVars() {
|
||||
// Rewrite assignments in body.
|
||||
used := NewVarSet()
|
||||
|
||||
last := rule.Head.Ref()[len(rule.Head.Ref())-1]
|
||||
used.Update(last.Vars())
|
||||
|
||||
if rule.Head.Key != nil {
|
||||
used.Update(rule.Head.Key.Vars())
|
||||
}
|
||||
@@ -2076,6 +2232,9 @@ func (c *Compiler) rewriteLocalVars() {
|
||||
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
|
||||
}
|
||||
|
||||
for i := 1; i < len(rule.Head.Ref()); i++ {
|
||||
rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
|
||||
}
|
||||
if rule.Head.Key != nil {
|
||||
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
|
||||
}
|
||||
@@ -2277,6 +2436,11 @@ func newQueryCompiler(compiler *Compiler) QueryCompiler {
|
||||
return qc
|
||||
}
|
||||
|
||||
func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler {
|
||||
qc.compiler.WithStrict(strict)
|
||||
return qc
|
||||
}
|
||||
|
||||
func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
|
||||
qc.enablePrintStatements = yes
|
||||
return qc
|
||||
@@ -2406,10 +2570,10 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error
|
||||
pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)}
|
||||
}
|
||||
if pkg != nil {
|
||||
var ruleExports []Var
|
||||
var ruleExports []Ref
|
||||
rules := qc.compiler.getExports()
|
||||
if exist, ok := rules.Get(pkg.Path); ok {
|
||||
ruleExports = exist.([]Var)
|
||||
ruleExports = exist.([]Ref)
|
||||
}
|
||||
|
||||
globals = getGlobals(qctx.Package, ruleExports, qctx.Imports)
|
||||
@@ -2792,6 +2956,16 @@ type ModuleTreeNode struct {
|
||||
Hide bool
|
||||
}
|
||||
|
||||
func (n *ModuleTreeNode) String() string {
|
||||
var rules []string
|
||||
for _, m := range n.Modules {
|
||||
for _, r := range m.Rules {
|
||||
rules = append(rules, r.Head.String())
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("<ModuleTreeNode key:%v children:%v rules:%v hide:%v>", n.Key, n.Children, rules, n.Hide)
|
||||
}
|
||||
|
||||
// NewModuleTree returns a new ModuleTreeNode that represents the root
|
||||
// of the module tree populated with the given modules.
|
||||
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
|
||||
@@ -2836,13 +3010,43 @@ func (n *ModuleTreeNode) Size() int {
|
||||
return s
|
||||
}
|
||||
|
||||
// Child returns n's child with key k.
|
||||
func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode {
|
||||
switch k.(type) {
|
||||
case String, Var:
|
||||
return n.Children[k]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find dereferences ref along the tree. ref[0] is converted to a String
|
||||
// for convenience.
|
||||
func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) {
|
||||
if v, ok := ref[0].Value.(Var); ok {
|
||||
ref = Ref{StringTerm(string(v))}.Concat(ref[1:])
|
||||
}
|
||||
node := n
|
||||
for i, r := range ref {
|
||||
next := node.child(r.Value)
|
||||
if next == nil {
|
||||
tail := make(Ref, len(ref)-i)
|
||||
tail[0] = VarTerm(string(ref[i].Value.(String)))
|
||||
copy(tail[1:], ref[i+1:])
|
||||
return node, tail
|
||||
}
|
||||
node = next
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// DepthFirst performs a depth-first traversal of the module tree rooted at n.
|
||||
// If f returns true, traversal will not continue to the children of n.
|
||||
func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) {
|
||||
if !f(n) {
|
||||
for _, node := range n.Children {
|
||||
node.DepthFirst(f)
|
||||
}
|
||||
func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) {
|
||||
if f(n) {
|
||||
return
|
||||
}
|
||||
for _, node := range n.Children {
|
||||
node.DepthFirst(f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2856,49 +3060,56 @@ type TreeNode struct {
|
||||
Hide bool
|
||||
}
|
||||
|
||||
func (n *TreeNode) String() string {
|
||||
return fmt.Sprintf("<TreeNode key:%v values:%v sorted:%v hide:%v>", n.Key, n.Values, n.Sorted, n.Hide)
|
||||
}
|
||||
|
||||
// NewRuleTree returns a new TreeNode that represents the root
|
||||
// of the rule tree populated with the given rules.
|
||||
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
|
||||
root := TreeNode{
|
||||
Key: mtree.Key,
|
||||
}
|
||||
|
||||
ruleSets := map[String][]util.T{}
|
||||
|
||||
// Build rule sets for this package.
|
||||
for _, mod := range mtree.Modules {
|
||||
for _, rule := range mod.Rules {
|
||||
key := String(rule.Head.Name)
|
||||
ruleSets[key] = append(ruleSets[key], rule)
|
||||
mtree.DepthFirst(func(m *ModuleTreeNode) bool {
|
||||
for _, mod := range m.Modules {
|
||||
if len(mod.Rules) == 0 {
|
||||
root.add(mod.Package.Path, nil)
|
||||
}
|
||||
for _, rule := range mod.Rules {
|
||||
root.add(rule.Ref().GroundPrefix(), rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Each rule set becomes a leaf node.
|
||||
children := map[Value]*TreeNode{}
|
||||
sorted := make([]Value, 0, len(ruleSets))
|
||||
|
||||
for key, rules := range ruleSets {
|
||||
sorted = append(sorted, key)
|
||||
children[key] = &TreeNode{
|
||||
Key: key,
|
||||
Children: nil,
|
||||
Values: rules,
|
||||
}
|
||||
}
|
||||
|
||||
// Each module in subpackage becomes child node.
|
||||
for key, child := range mtree.Children {
|
||||
sorted = append(sorted, key)
|
||||
children[child.Key] = NewRuleTree(child)
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Compare(sorted[j]) < 0
|
||||
return false
|
||||
})
|
||||
|
||||
return &TreeNode{
|
||||
Key: mtree.Key,
|
||||
Values: nil,
|
||||
Children: children,
|
||||
Sorted: sorted,
|
||||
Hide: mtree.Hide,
|
||||
// ensure that data.system's TreeNode is hidden
|
||||
node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey)))
|
||||
if len(tail) == 0 { // found
|
||||
node.Hide = true
|
||||
}
|
||||
|
||||
root.DepthFirst(func(x *TreeNode) bool {
|
||||
x.sort()
|
||||
return false
|
||||
})
|
||||
|
||||
return &root
|
||||
}
|
||||
|
||||
func (n *TreeNode) add(path Ref, rule *Rule) {
|
||||
node, tail := n.find(path)
|
||||
if len(tail) > 0 {
|
||||
sub := treeNodeFromRef(tail, rule)
|
||||
if node.Children == nil {
|
||||
node.Children = make(map[Value]*TreeNode, 1)
|
||||
}
|
||||
node.Children[sub.Key] = sub
|
||||
node.Sorted = append(node.Sorted, sub.Key)
|
||||
} else {
|
||||
if rule != nil {
|
||||
node.Values = append(node.Values, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2914,33 +3125,95 @@ func (n *TreeNode) Size() int {
|
||||
// Child returns n's child with key k.
|
||||
func (n *TreeNode) Child(k Value) *TreeNode {
|
||||
switch k.(type) {
|
||||
case String, Var:
|
||||
case Ref, Call:
|
||||
return nil
|
||||
default:
|
||||
return n.Children[k]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find dereferences ref along the tree
|
||||
func (n *TreeNode) Find(ref Ref) *TreeNode {
|
||||
node := n
|
||||
for _, r := range ref {
|
||||
child := node.Child(r.Value)
|
||||
if child == nil {
|
||||
node = node.Child(r.Value)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
node = child
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) {
|
||||
node := n
|
||||
for i := range ref {
|
||||
next := node.Child(ref[i].Value)
|
||||
if next == nil {
|
||||
tail := make(Ref, len(ref)-i)
|
||||
copy(tail, ref[i:])
|
||||
return node, tail
|
||||
}
|
||||
node = next
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
|
||||
// f returns true, traversal will not continue to the children of n.
|
||||
func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
|
||||
if !f(n) {
|
||||
for _, node := range n.Children {
|
||||
node.DepthFirst(f)
|
||||
func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) {
|
||||
if f(n) {
|
||||
return
|
||||
}
|
||||
for _, node := range n.Children {
|
||||
node.DepthFirst(f)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TreeNode) sort() {
|
||||
sort.Slice(n.Sorted, func(i, j int) bool {
|
||||
return n.Sorted[i].Compare(n.Sorted[j]) < 0
|
||||
})
|
||||
}
|
||||
|
||||
func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode {
|
||||
depth := len(ref) - 1
|
||||
key := ref[depth].Value
|
||||
node := &TreeNode{
|
||||
Key: key,
|
||||
Children: nil,
|
||||
}
|
||||
if rule != nil {
|
||||
node.Values = []util.T{rule}
|
||||
}
|
||||
|
||||
for i := len(ref) - 2; i >= 0; i-- {
|
||||
key := ref[i].Value
|
||||
node = &TreeNode{
|
||||
Key: key,
|
||||
Children: map[Value]*TreeNode{ref[i+1].Value: node},
|
||||
Sorted: []Value{ref[i+1].Value},
|
||||
}
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
// flattenChildren flattens all children's rule refs into a sorted array.
|
||||
func (n *TreeNode) flattenChildren() []Ref {
|
||||
ret := newRefSet()
|
||||
for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away
|
||||
sub.DepthFirst(func(x *TreeNode) bool {
|
||||
for _, r := range x.Values {
|
||||
rule := r.(*Rule)
|
||||
ret.AddPrefix(rule.Ref())
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(ret.s, func(i, j int) bool {
|
||||
return ret.s[i].Compare(ret.s[j]) < 0
|
||||
})
|
||||
return ret.s
|
||||
}
|
||||
|
||||
// Graph represents the graph of dependencies between rules.
|
||||
@@ -3554,15 +3827,14 @@ func (l *localVarGenerator) Generate() Var {
|
||||
}
|
||||
}
|
||||
|
||||
func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]*usedRef {
|
||||
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
|
||||
|
||||
globals := map[Var]*usedRef{}
|
||||
globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
|
||||
|
||||
// Populate globals with exports within the package.
|
||||
for _, v := range rules {
|
||||
global := append(Ref{}, pkg.Path...)
|
||||
global = append(global, &Term{Value: String(v)})
|
||||
globals[v] = &usedRef{ref: global}
|
||||
for _, ref := range rules {
|
||||
v := ref[0].Value.(Var)
|
||||
globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
|
||||
}
|
||||
|
||||
// Populate globals with imports.
|
||||
@@ -3670,6 +3942,10 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
|
||||
ignore.Push(vars)
|
||||
ignore.Push(declaredVars(rule.Body))
|
||||
|
||||
ref := rule.Head.Ref()
|
||||
for i := 1; i < len(ref); i++ {
|
||||
ref[i] = resolveRefsInTerm(globals, ignore, ref[i])
|
||||
}
|
||||
if rule.Head.Key != nil {
|
||||
rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
|
||||
}
|
||||
@@ -4453,7 +4729,7 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u
|
||||
}
|
||||
|
||||
errs = checkUnusedAssignedVars(body[0].Loc(), stack, used, errs, strict)
|
||||
return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs)
|
||||
return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs)
|
||||
}
|
||||
|
||||
func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
|
||||
@@ -4491,7 +4767,7 @@ func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSe
|
||||
return errs
|
||||
}
|
||||
|
||||
func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
|
||||
func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
|
||||
|
||||
// NOTE(tsandall): Do not generate more errors if there are existing
|
||||
// declaration errors.
|
||||
@@ -4523,7 +4799,23 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe
|
||||
for _, gv := range unused.Sorted() {
|
||||
rv := dvs.reverse[gv]
|
||||
if !rv.IsGenerated() {
|
||||
errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", rv))
|
||||
// Scan through body exprs, looking for a match between the
|
||||
// bad var's original name, and each expr's declared vars.
|
||||
foundUnusedVarByName := false
|
||||
for i := range body {
|
||||
varsDeclaredInExpr := declaredVars(body[i])
|
||||
if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
|
||||
// TODO(philipc): Clean up the offset logic here when the parser
|
||||
// reports more accurate locations.
|
||||
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
||||
foundUnusedVarByName = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// Default error location returned.
|
||||
if !foundUnusedVarByName {
|
||||
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4894,7 +5186,7 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
|
||||
for _, v := range child.Values {
|
||||
if len(v.(*Rule).Head.Args) > 0 {
|
||||
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
|
||||
return false, err // may be nil
|
||||
return false, err // err may be nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4916,7 +5208,7 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
|
||||
}
|
||||
|
||||
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
|
||||
return false, err // may be nil
|
||||
return false, err // err may be nil
|
||||
}
|
||||
default:
|
||||
return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
|
||||
@@ -4985,7 +5277,7 @@ func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]stru
|
||||
}
|
||||
|
||||
func isVirtual(node *TreeNode, ref Ref) bool {
|
||||
for i := 0; i < len(ref); i++ {
|
||||
for i := range ref {
|
||||
child := node.Child(ref[i].Value)
|
||||
if child == nil {
|
||||
return false
|
||||
@@ -5095,3 +5387,57 @@ func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
|
||||
return i.(Ref)
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location
|
||||
// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it
|
||||
// public in the ast package, the compile package could take it from there, but it would also
|
||||
// increase our public interface. Let's reconsider if we need it in a third place.
|
||||
type refSet struct {
|
||||
s []Ref
|
||||
}
|
||||
|
||||
func newRefSet(x ...Ref) *refSet {
|
||||
result := &refSet{}
|
||||
for i := range x {
|
||||
result.AddPrefix(x[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set.
|
||||
func (rs *refSet) ContainsPrefix(r Ref) bool {
|
||||
for i := range rs.s {
|
||||
if r.HasPrefix(rs.s[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddPrefix inserts r into the set if r is not prefixed by any existing
|
||||
// refs in the set. If any existing refs are prefixed by r, those existing
|
||||
// refs are removed.
|
||||
func (rs *refSet) AddPrefix(r Ref) {
|
||||
if rs.ContainsPrefix(r) {
|
||||
return
|
||||
}
|
||||
cpy := []Ref{r}
|
||||
for i := range rs.s {
|
||||
if !rs.s[i].HasPrefix(r) {
|
||||
cpy = append(cpy, rs.s[i])
|
||||
}
|
||||
}
|
||||
rs.s = cpy
|
||||
}
|
||||
|
||||
// Sorted returns a sorted slice of terms for refs in the set.
|
||||
func (rs *refSet) Sorted() []*Term {
|
||||
terms := make([]*Term, len(rs.s))
|
||||
for i := range rs.s {
|
||||
terms[i] = NewTerm(rs.s[i])
|
||||
}
|
||||
sort.Slice(terms, func(i, j int) bool {
|
||||
return terms[i].Value.Compare(terms[j].Value) < 0
|
||||
})
|
||||
return terms
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user