Update dependencies (#5518)

This commit is contained in:
hongming
2023-02-12 23:09:20 +08:00
committed by GitHub
parent d3b35fb2da
commit a979342f56
1486 changed files with 126660 additions and 71128 deletions

View File

@@ -64,6 +64,7 @@ type Compiler struct {
// p[1] { true }
// p[2] { true }
// q = true
// a.b.c = 3
//
// root
// |
@@ -74,6 +75,12 @@ type Compiler struct {
// +--- p (2 rules)
// |
// +--- q (1 rule)
// |
// +--- a
// |
// +--- b
// |
// +--- c (1 rule)
RuleTree *TreeNode
// Graph contains dependencies between rules. An edge (u,v) is added to the
@@ -95,26 +102,27 @@ type Compiler struct {
metricName string
f func()
}
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
}
// CompilerStage defines the interface for stages in the compiler.
@@ -221,6 +229,9 @@ type QueryCompiler interface {
// ComprehensionIndex returns an index data structure for the given comprehension
// term. If no index is found, returns nil.
ComprehensionIndex(term *Term) *ComprehensionIndex
// WithStrict enables strict mode for the query compiler.
WithStrict(strict bool) QueryCompiler
}
// QueryCompilerStage defines the interface for stages in the query compiler.
@@ -265,15 +276,16 @@ func NewCompiler() *Compiler {
// load additional modules. If any stages run before resolution, they
// need to be re-run after resolution.
{"ResolveRefs", "compile_stage_resolve_refs", c.resolveAllRefs},
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree},
// The local variable generator must be initialized after references are
// resolved and the dynamic module loader has run but before subsequent
// stages that need to generate variables.
{"InitLocalVarGen", "compile_stage_init_local_var_gen", c.initLocalVarGen},
{"RewriteRuleHeadRefs", "compile_stage_rewrite_rule_head_refs", c.rewriteRuleHeadRefs},
{"CheckKeywordOverrides", "compile_stage_check_keyword_overrides", c.checkKeywordOverrides},
{"CheckDuplicateImports", "compile_stage_check_duplicate_imports", c.checkDuplicateImports},
{"RemoveImports", "compile_stage_remove_imports", c.removeImports},
{"SetModuleTree", "compile_stage_set_module_tree", c.setModuleTree},
{"SetRuleTree", "compile_stage_set_rule_tree", c.setRuleTree}, // depends on RewriteRuleHeadRefs
{"RewriteLocalVars", "compile_stage_rewrite_local_vars", c.rewriteLocalVars},
{"CheckVoidCalls", "compile_stage_check_void_calls", c.checkVoidCalls},
{"RewritePrintCalls", "compile_stage_rewrite_print_calls", c.rewritePrintCalls},
@@ -396,6 +408,12 @@ func (c *Compiler) WithKeepModules(y bool) *Compiler {
return c
}
// WithUseTypeCheckAnnotations use schema annotations during type checking
func (c *Compiler) WithUseTypeCheckAnnotations(enabled bool) *Compiler {
c.useTypeCheckAnnotations = enabled
return c
}
// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
@@ -403,10 +421,10 @@ func (c *Compiler) ParsedModules() map[string]*Module {
return c.parsedModules
}
// QueryCompiler returns a new QueryCompiler object.
func (c *Compiler) QueryCompiler() QueryCompiler {
c.init()
return newQueryCompiler(c)
c0 := *c
return newQueryCompiler(&c0)
}
// Compile runs the compilation process on the input modules. The compiled
@@ -570,9 +588,10 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
return rules
}
func extractRules(s []util.T) (rules []*Rule) {
for _, r := range s {
rules = append(rules, r.(*Rule))
func extractRules(s []util.T) []*Rule {
rules := make([]*Rule, len(s))
for i := range s {
rules[i] = s[i].(*Rule)
}
return rules
}
@@ -768,13 +787,30 @@ func (c *Compiler) buildRuleIndices() {
if len(node.Values) == 0 {
return false
}
rules := extractRules(node.Values)
hasNonGroundKey := false
for _, r := range rules {
if ref := r.Head.Ref(); len(ref) > 1 {
if !ref[len(ref)-1].IsGround() {
hasNonGroundKey = true
}
}
}
if hasNonGroundKey {
// collect children: as of now, this cannot go deeper than one level,
// so we grab those, and abort the DepthFirst processing for this branch
for _, n := range node.Children {
rules = append(rules, extractRules(n.Values)...)
}
}
index := newBaseDocEqIndex(func(ref Ref) bool {
return isVirtual(c.RuleTree, ref.GroundPrefix())
})
if rules := extractRules(node.Values); index.Build(rules) {
c.ruleIndices.Put(rules[0].Path(), index)
if index.Build(rules) {
c.ruleIndices.Put(rules[0].Ref().GroundPrefix(), index)
}
return false
return hasNonGroundKey // currently, we don't allow those branches to go deeper
})
}
@@ -811,7 +847,7 @@ func (c *Compiler) checkRecursion() {
func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b util.T) {
tr := NewGraphTraversal(c.Graph)
if p := util.DFSPath(tr, eq, a, b); len(p) > 0 {
n := []string{}
n := make([]string, 0, len(p))
for _, x := range p {
n = append(n, astNodeToString(x))
}
@@ -820,40 +856,69 @@ func (c *Compiler) checkSelfPath(loc *Location, eq func(a, b util.T) bool, a, b
}
func astNodeToString(x interface{}) string {
switch x := x.(type) {
case *Rule:
return string(x.Head.Name)
default:
panic("not reached")
}
return x.(*Rule).Ref().String()
}
// checkRuleConflicts ensures that rules definitions are not in conflict.
func (c *Compiler) checkRuleConflicts() {
rw := rewriteVarsInRef(c.RewrittenVars)
c.RuleTree.DepthFirst(func(node *TreeNode) bool {
if len(node.Values) == 0 {
return false
return false // go deeper
}
kinds := map[DocKind]struct{}{}
kinds := map[RuleKind]struct{}{}
defaultRules := 0
arities := map[int]struct{}{}
name := ""
var singleValueConflicts []Ref
for _, rule := range node.Values {
r := rule.(*Rule)
kinds[r.Head.DocKind()] = struct{}{}
ref := r.Ref()
name = rw(ref.Copy()).String() // varRewriter operates in-place
kinds[r.Head.RuleKind()] = struct{}{}
arities[len(r.Head.Args)] = struct{}{}
if r.Default {
defaultRules++
}
// Single-value rules may not have any other rules in their extent: these pairs are invalid:
//
// data.p.q.r { true } # data.p.q is { "r": true }
// data.p.q.r.s { true }
//
// data.p.q[r] { r := input.r } # data.p.q could be { "r": true }
// data.p.q.r.s { true }
// But this is allowed:
// data.p.q[r] = 1 { r := "r" }
// data.p.q.s = 2
if r.Head.RuleKind() == SingleValue && len(node.Children) > 0 {
if len(ref) > 1 && !ref[len(ref)-1].IsGround() { // p.q[x] and p.q.s.t => check grandchildren
for _, c := range node.Children {
if len(c.Children) > 0 {
singleValueConflicts = node.flattenChildren()
break
}
}
} else { // p.q.s and p.q.s.t => any children are in conflict
singleValueConflicts = node.flattenChildren()
}
}
}
name := Var(node.Key.(String))
switch {
case singleValueConflicts != nil:
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "single-value rule %v conflicts with %v", name, singleValueConflicts))
if len(kinds) > 1 || len(arities) > 1 {
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules named %v found", name))
} else if defaultRules > 1 {
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules named %s found", name))
case len(kinds) > 1 || len(arities) > 1:
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "conflicting rules %v found", name))
case defaultRules > 1:
c.err(NewError(TypeErr, node.Values[0].(*Rule).Loc(), "multiple default rules %s found", name))
}
return false
@@ -865,13 +930,21 @@ func (c *Compiler) checkRuleConflicts() {
}
}
// NOTE(sr): depthfirst might better use sorted for stable errs?
c.ModuleTree.DepthFirst(func(node *ModuleTreeNode) bool {
for _, mod := range node.Modules {
for _, rule := range mod.Rules {
if childNode, ok := node.Children[String(rule.Head.Name)]; ok {
ref := rule.Head.Ref().GroundPrefix()
childNode, tail := node.find(ref)
if childNode != nil {
for _, childMod := range childNode.Modules {
msg := fmt.Sprintf("%v conflicts with rule defined at %v", childMod.Package, rule.Loc())
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
if childMod.Equal(mod) {
continue // don't self-conflict
}
if len(tail) == 0 {
msg := fmt.Sprintf("%v conflicts with rule %v defined at %v", childMod.Package, rule.Head.Ref(), rule.Loc())
c.err(NewError(TypeErr, mod.Package.Loc(), msg))
}
}
}
}
@@ -1239,7 +1312,11 @@ func (c *Compiler) checkTypes() {
WithSchemaSet(c.schemaSet).
WithInputType(c.inputType).
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
env, errs := checker.CheckTypes(c.TypeEnv, sorted, c.annotationSet)
var as *AnnotationSet
if c.useTypeCheckAnnotations {
as = c.annotationSet
}
env, errs := checker.CheckTypes(c.TypeEnv, sorted, as)
for _, err := range errs {
c.err(err)
}
@@ -1363,21 +1440,29 @@ func (c *Compiler) getExports() *util.HashMap {
for _, name := range c.sorted {
mod := c.Modules[name]
rv, ok := rules.Get(mod.Package.Path)
if !ok {
rv = []Var{}
}
rvs := rv.([]Var)
for _, rule := range mod.Rules {
rvs = append(rvs, rule.Head.Name)
hashMapAdd(rules, mod.Package.Path, rule.Head.Ref().GroundPrefix())
}
rules.Put(mod.Package.Path, rvs)
}
return rules
}
func hashMapAdd(rules *util.HashMap, pkg, rule Ref) {
prev, ok := rules.Get(pkg)
if !ok {
rules.Put(pkg, []Ref{rule})
return
}
for _, p := range prev.([]Ref) {
if p.Equal(rule) {
return
}
}
rules.Put(pkg, append(prev.([]Ref), rule))
}
func (c *Compiler) GetAnnotationSet() *AnnotationSet {
return c.annotationSet
}
@@ -1450,6 +1535,15 @@ func checkKeywordOverrides(node interface{}, strict bool) Errors {
// p[x] { bar[_] = x }
//
// The reference "bar[_]" would be resolved to "data.foo.bar[_]".
//
// Ref rules are resolved, too:
//
// package a.b
// q { c.d.e == 1 }
// c.d[e] := 1 if e := "e"
//
// The reference "c.d.e" would be resolved to "data.a.b.c.d.e".
func (c *Compiler) resolveAllRefs() {
rules := c.getExports()
@@ -1457,9 +1551,9 @@ func (c *Compiler) resolveAllRefs() {
for _, name := range c.sorted {
mod := c.Modules[name]
var ruleExports []Var
var ruleExports []Ref
if x, ok := rules.Get(mod.Package.Path); ok {
ruleExports = x.([]Var)
ruleExports = x.([]Ref)
}
globals := getGlobals(mod.Package, ruleExports, mod.Imports)
@@ -1542,6 +1636,65 @@ func (c *Compiler) rewriteExprTerms() {
}
}
func (c *Compiler) rewriteRuleHeadRefs() {
f := newEqualityFactory(c.localvargen)
for _, name := range c.sorted {
WalkRules(c.Modules[name], func(rule *Rule) bool {
ref := rule.Head.Ref()
// NOTE(sr): We're backfilling Refs here -- all parser code paths would have them, but
// it's possible to construct Module{} instances from Golang code, so we need
// to accommodate for that, too.
if len(rule.Head.Reference) == 0 {
rule.Head.Reference = ref
}
cannotSpeakRefs := true
for _, f := range c.capabilities.Features {
if f == FeatureRefHeadStringPrefixes {
cannotSpeakRefs = false
break
}
}
if cannotSpeakRefs && rule.Head.Name == "" {
c.err(NewError(CompileErr, rule.Loc(), "rule heads with refs are not supported: %v", rule.Head.Reference))
return true
}
for i := 1; i < len(ref); i++ {
// NOTE(sr): In the first iteration, non-string values in the refs are forbidden
// except for the last position, e.g.
// OK: p.q.r[s]
// NOT OK: p[q].r.s
// TODO(sr): This is stricter than necessary. We could allow any non-var values there,
// but we'll also have to adjust the type tree, for example.
if i != len(ref)-1 { // last
if _, ok := ref[i].Value.(String); !ok {
c.err(NewError(TypeErr, rule.Loc(), "rule head must only contain string terms (except for last): %v", ref[i]))
continue
}
}
// Rewrite so that any non-scalar elements that in the last position of
// the rule are vars:
// p.q.r[y.z] { ... } => p.q.r[__local0__] { __local0__ = y.z }
// because that's what the RuleTree knows how to deal with.
if _, ok := ref[i].Value.(Var); !ok && !IsScalar(ref[i].Value) {
expr := f.Generate(ref[i])
if i == len(ref)-1 && rule.Head.Key.Equal(ref[i]) {
rule.Head.Key = expr.Operand(0)
}
rule.Head.Reference[i] = expr.Operand(0)
rule.Body.Append(expr)
}
}
return true
})
}
}
func (c *Compiler) checkVoidCalls() {
for _, name := range c.sorted {
mod := c.Modules[name]
@@ -2044,6 +2197,9 @@ func (c *Compiler) rewriteLocalVars() {
// Rewrite assignments in body.
used := NewVarSet()
last := rule.Head.Ref()[len(rule.Head.Ref())-1]
used.Update(last.Vars())
if rule.Head.Key != nil {
used.Update(rule.Head.Key.Vars())
}
@@ -2076,6 +2232,9 @@ func (c *Compiler) rewriteLocalVars() {
rule.Head.Args[i], _ = transformTerm(localXform, rule.Head.Args[i])
}
for i := 1; i < len(rule.Head.Ref()); i++ {
rule.Head.Reference[i], _ = transformTerm(localXform, rule.Head.Ref()[i])
}
if rule.Head.Key != nil {
rule.Head.Key, _ = transformTerm(localXform, rule.Head.Key)
}
@@ -2277,6 +2436,11 @@ func newQueryCompiler(compiler *Compiler) QueryCompiler {
return qc
}
func (qc *queryCompiler) WithStrict(strict bool) QueryCompiler {
qc.compiler.WithStrict(strict)
return qc
}
func (qc *queryCompiler) WithEnablePrintStatements(yes bool) QueryCompiler {
qc.enablePrintStatements = yes
return qc
@@ -2406,10 +2570,10 @@ func (qc *queryCompiler) resolveRefs(qctx *QueryContext, body Body) (Body, error
pkg = &Package{Path: RefTerm(VarTerm("")).Value.(Ref)}
}
if pkg != nil {
var ruleExports []Var
var ruleExports []Ref
rules := qc.compiler.getExports()
if exist, ok := rules.Get(pkg.Path); ok {
ruleExports = exist.([]Var)
ruleExports = exist.([]Ref)
}
globals = getGlobals(qctx.Package, ruleExports, qctx.Imports)
@@ -2792,6 +2956,16 @@ type ModuleTreeNode struct {
Hide bool
}
func (n *ModuleTreeNode) String() string {
var rules []string
for _, m := range n.Modules {
for _, r := range m.Rules {
rules = append(rules, r.Head.String())
}
}
return fmt.Sprintf("<ModuleTreeNode key:%v children:%v rules:%v hide:%v>", n.Key, n.Children, rules, n.Hide)
}
// NewModuleTree returns a new ModuleTreeNode that represents the root
// of the module tree populated with the given modules.
func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
@@ -2836,13 +3010,43 @@ func (n *ModuleTreeNode) Size() int {
return s
}
// Child returns n's child with key k.
func (n *ModuleTreeNode) child(k Value) *ModuleTreeNode {
switch k.(type) {
case String, Var:
return n.Children[k]
}
return nil
}
// Find dereferences ref along the tree. ref[0] is converted to a String
// for convenience.
func (n *ModuleTreeNode) find(ref Ref) (*ModuleTreeNode, Ref) {
if v, ok := ref[0].Value.(Var); ok {
ref = Ref{StringTerm(string(v))}.Concat(ref[1:])
}
node := n
for i, r := range ref {
next := node.child(r.Value)
if next == nil {
tail := make(Ref, len(ref)-i)
tail[0] = VarTerm(string(ref[i].Value.(String)))
copy(tail[1:], ref[i+1:])
return node, tail
}
node = next
}
return node, nil
}
// DepthFirst performs a depth-first traversal of the module tree rooted at n.
// If f returns true, traversal will not continue to the children of n.
func (n *ModuleTreeNode) DepthFirst(f func(node *ModuleTreeNode) bool) {
if !f(n) {
for _, node := range n.Children {
node.DepthFirst(f)
}
func (n *ModuleTreeNode) DepthFirst(f func(*ModuleTreeNode) bool) {
if f(n) {
return
}
for _, node := range n.Children {
node.DepthFirst(f)
}
}
@@ -2856,49 +3060,56 @@ type TreeNode struct {
Hide bool
}
func (n *TreeNode) String() string {
return fmt.Sprintf("<TreeNode key:%v values:%v sorted:%v hide:%v>", n.Key, n.Values, n.Sorted, n.Hide)
}
// NewRuleTree returns a new TreeNode that represents the root
// of the rule tree populated with the given rules.
func NewRuleTree(mtree *ModuleTreeNode) *TreeNode {
root := TreeNode{
Key: mtree.Key,
}
ruleSets := map[String][]util.T{}
// Build rule sets for this package.
for _, mod := range mtree.Modules {
for _, rule := range mod.Rules {
key := String(rule.Head.Name)
ruleSets[key] = append(ruleSets[key], rule)
mtree.DepthFirst(func(m *ModuleTreeNode) bool {
for _, mod := range m.Modules {
if len(mod.Rules) == 0 {
root.add(mod.Package.Path, nil)
}
for _, rule := range mod.Rules {
root.add(rule.Ref().GroundPrefix(), rule)
}
}
}
// Each rule set becomes a leaf node.
children := map[Value]*TreeNode{}
sorted := make([]Value, 0, len(ruleSets))
for key, rules := range ruleSets {
sorted = append(sorted, key)
children[key] = &TreeNode{
Key: key,
Children: nil,
Values: rules,
}
}
// Each module in subpackage becomes child node.
for key, child := range mtree.Children {
sorted = append(sorted, key)
children[child.Key] = NewRuleTree(child)
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Compare(sorted[j]) < 0
return false
})
return &TreeNode{
Key: mtree.Key,
Values: nil,
Children: children,
Sorted: sorted,
Hide: mtree.Hide,
// ensure that data.system's TreeNode is hidden
node, tail := root.find(DefaultRootRef.Append(NewTerm(SystemDocumentKey)))
if len(tail) == 0 { // found
node.Hide = true
}
root.DepthFirst(func(x *TreeNode) bool {
x.sort()
return false
})
return &root
}
func (n *TreeNode) add(path Ref, rule *Rule) {
node, tail := n.find(path)
if len(tail) > 0 {
sub := treeNodeFromRef(tail, rule)
if node.Children == nil {
node.Children = make(map[Value]*TreeNode, 1)
}
node.Children[sub.Key] = sub
node.Sorted = append(node.Sorted, sub.Key)
} else {
if rule != nil {
node.Values = append(node.Values, rule)
}
}
}
@@ -2914,33 +3125,95 @@ func (n *TreeNode) Size() int {
// Child returns n's child with key k.
func (n *TreeNode) Child(k Value) *TreeNode {
switch k.(type) {
case String, Var:
case Ref, Call:
return nil
default:
return n.Children[k]
}
return nil
}
// Find dereferences ref along the tree
func (n *TreeNode) Find(ref Ref) *TreeNode {
node := n
for _, r := range ref {
child := node.Child(r.Value)
if child == nil {
node = node.Child(r.Value)
if node == nil {
return nil
}
node = child
}
return node
}
func (n *TreeNode) find(ref Ref) (*TreeNode, Ref) {
node := n
for i := range ref {
next := node.Child(ref[i].Value)
if next == nil {
tail := make(Ref, len(ref)-i)
copy(tail, ref[i:])
return node, tail
}
node = next
}
return node, nil
}
// DepthFirst performs a depth-first traversal of the rule tree rooted at n. If
// f returns true, traversal will not continue to the children of n.
func (n *TreeNode) DepthFirst(f func(node *TreeNode) bool) {
if !f(n) {
for _, node := range n.Children {
node.DepthFirst(f)
func (n *TreeNode) DepthFirst(f func(*TreeNode) bool) {
if f(n) {
return
}
for _, node := range n.Children {
node.DepthFirst(f)
}
}
func (n *TreeNode) sort() {
sort.Slice(n.Sorted, func(i, j int) bool {
return n.Sorted[i].Compare(n.Sorted[j]) < 0
})
}
func treeNodeFromRef(ref Ref, rule *Rule) *TreeNode {
depth := len(ref) - 1
key := ref[depth].Value
node := &TreeNode{
Key: key,
Children: nil,
}
if rule != nil {
node.Values = []util.T{rule}
}
for i := len(ref) - 2; i >= 0; i-- {
key := ref[i].Value
node = &TreeNode{
Key: key,
Children: map[Value]*TreeNode{ref[i+1].Value: node},
Sorted: []Value{ref[i+1].Value},
}
}
return node
}
// flattenChildren flattens all children's rule refs into a sorted array.
func (n *TreeNode) flattenChildren() []Ref {
ret := newRefSet()
for _, sub := range n.Children { // we only want the children, so don't use n.DepthFirst() right away
sub.DepthFirst(func(x *TreeNode) bool {
for _, r := range x.Values {
rule := r.(*Rule)
ret.AddPrefix(rule.Ref())
}
return false
})
}
sort.Slice(ret.s, func(i, j int) bool {
return ret.s[i].Compare(ret.s[j]) < 0
})
return ret.s
}
// Graph represents the graph of dependencies between rules.
@@ -3554,15 +3827,14 @@ func (l *localVarGenerator) Generate() Var {
}
}
func getGlobals(pkg *Package, rules []Var, imports []*Import) map[Var]*usedRef {
func getGlobals(pkg *Package, rules []Ref, imports []*Import) map[Var]*usedRef {
globals := map[Var]*usedRef{}
globals := make(map[Var]*usedRef, len(rules)) // NB: might grow bigger with imports
// Populate globals with exports within the package.
for _, v := range rules {
global := append(Ref{}, pkg.Path...)
global = append(global, &Term{Value: String(v)})
globals[v] = &usedRef{ref: global}
for _, ref := range rules {
v := ref[0].Value.(Var)
globals[v] = &usedRef{ref: pkg.Path.Append(StringTerm(string(v)))}
}
// Populate globals with imports.
@@ -3670,6 +3942,10 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
ignore.Push(vars)
ignore.Push(declaredVars(rule.Body))
ref := rule.Head.Ref()
for i := 1; i < len(ref); i++ {
ref[i] = resolveRefsInTerm(globals, ignore, ref[i])
}
if rule.Head.Key != nil {
rule.Head.Key = resolveRefsInTerm(globals, ignore, rule.Head.Key)
}
@@ -4453,7 +4729,7 @@ func rewriteDeclaredVarsInBody(g *localVarGenerator, stack *localDeclaredVars, u
}
errs = checkUnusedAssignedVars(body[0].Loc(), stack, used, errs, strict)
return cpy, checkUnusedDeclaredVars(body[0].Loc(), stack, used, cpy, errs)
return cpy, checkUnusedDeclaredVars(body, stack, used, cpy, errs)
}
func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSet, errs Errors, strict bool) Errors {
@@ -4491,7 +4767,7 @@ func checkUnusedAssignedVars(loc *Location, stack *localDeclaredVars, used VarSe
return errs
}
func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
func checkUnusedDeclaredVars(body Body, stack *localDeclaredVars, used VarSet, cpy Body, errs Errors) Errors {
// NOTE(tsandall): Do not generate more errors if there are existing
// declaration errors.
@@ -4523,7 +4799,23 @@ func checkUnusedDeclaredVars(loc *Location, stack *localDeclaredVars, used VarSe
for _, gv := range unused.Sorted() {
rv := dvs.reverse[gv]
if !rv.IsGenerated() {
errs = append(errs, NewError(CompileErr, loc, "declared var %v unused", rv))
// Scan through body exprs, looking for a match between the
// bad var's original name, and each expr's declared vars.
foundUnusedVarByName := false
for i := range body {
varsDeclaredInExpr := declaredVars(body[i])
if varsDeclaredInExpr.Contains(dvs.reverse[gv]) {
// TODO(philipc): Clean up the offset logic here when the parser
// reports more accurate locations.
errs = append(errs, NewError(CompileErr, body[i].Loc(), "declared var %v unused", dvs.reverse[gv]))
foundUnusedVarByName = true
break
}
}
// Default error location returned.
if !foundUnusedVarByName {
errs = append(errs, NewError(CompileErr, body[0].Loc(), "declared var %v unused", dvs.reverse[gv]))
}
}
}
@@ -4894,7 +5186,7 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
for _, v := range child.Values {
if len(v.(*Rule).Head.Args) > 0 {
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
return false, err // may be nil
return false, err // err may be nil
}
}
}
@@ -4916,7 +5208,7 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
}
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
return false, err // may be nil
return false, err // err may be nil
}
default:
return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
@@ -4985,7 +5277,7 @@ func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]stru
}
func isVirtual(node *TreeNode, ref Ref) bool {
for i := 0; i < len(ref); i++ {
for i := range ref {
child := node.Child(ref[i].Value)
if child == nil {
return false
@@ -5095,3 +5387,57 @@ func rewriteVarsInRef(vars ...map[Var]Var) varRewriter {
return i.(Ref)
}
}
// NOTE(sr): This is duplicated with compile/compile.go; but moving it into another location
// would cause a circular dependency -- the refSet definition needs ast.Ref. If we make it
// public in the ast package, the compile package could take it from there, but it would also
// increase our public interface. Let's reconsider if we need it in a third place.
type refSet struct {
s []Ref
}
func newRefSet(x ...Ref) *refSet {
result := &refSet{}
for i := range x {
result.AddPrefix(x[i])
}
return result
}
// ContainsPrefix returns true if r is prefixed by any of the existing refs in the set.
func (rs *refSet) ContainsPrefix(r Ref) bool {
for i := range rs.s {
if r.HasPrefix(rs.s[i]) {
return true
}
}
return false
}
// AddPrefix inserts r into the set if r is not prefixed by any existing
// refs in the set. If any existing refs are prefixed by r, those existing
// refs are removed.
func (rs *refSet) AddPrefix(r Ref) {
if rs.ContainsPrefix(r) {
return
}
cpy := []Ref{r}
for i := range rs.s {
if !rs.s[i].HasPrefix(r) {
cpy = append(cpy, rs.s[i])
}
}
rs.s = cpy
}
// Sorted returns a sorted slice of terms for refs in the set.
func (rs *refSet) Sorted() []*Term {
terms := make([]*Term, len(rs.s))
for i := range rs.s {
terms[i] = NewTerm(rs.s[i])
}
sort.Slice(terms, func(i, j int) bool {
return terms[i].Value.Compare(terms[j].Value) < 0
})
return terms
}