update dependencies (#6267)

Signed-off-by: hongming <coder.scala@gmail.com>
This commit is contained in:
hongming
2024-11-06 10:27:06 +08:00
committed by GitHub
parent faf255a084
commit cfebd96a1f
4263 changed files with 341374 additions and 132036 deletions

View File

@@ -120,34 +120,35 @@ type Compiler struct {
// Capabilities required by the modules that were compiled.
Required *Capabilities
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode
localvargen *localVarGenerator
moduleLoader ModuleLoader
ruleIndices *util.HashMap
stages []stage
maxErrs int
sorted []string // list of sorted module names
pathExists func([]string) (bool, error)
after map[string][]CompilerStageDefinition
metrics metrics.Metrics
capabilities *Capabilities // user-supplied capabilities
imports map[string][]*Import // saved imports from stripping
builtins map[string]*Builtin // universe of built-in functions
customBuiltins map[string]*Builtin // user-supplied custom built-in functions (deprecated: use capabilities)
unsafeBuiltinsMap map[string]struct{} // user-supplied set of unsafe built-ins functions to block (deprecated: use capabilities)
deprecatedBuiltinsMap map[string]struct{} // set of deprecated, but not removed, built-in functions
enablePrintStatements bool // indicates if print statements should be elided (default)
comprehensionIndices map[*Term]*ComprehensionIndex // comprehension key index
initialized bool // indicates if init() has been called
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *AnnotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
keepModules bool // whether to keep the unprocessed, parse modules (below)
parsedModules map[string]*Module // parsed, but otherwise unprocessed modules, kept track of when keepModules is true
useTypeCheckAnnotations bool // whether to provide annotated information (schemas) to the type checker
allowUndefinedFuncCalls bool // don't error on calls to unknown functions.
evalMode CompilerEvalMode //
rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
}
// CompilerStage defines the interface for stages in the compiler.
@@ -346,6 +347,7 @@ func NewCompiler() *Compiler {
{"CheckSafetyRuleBodies", "compile_stage_check_safety_rule_bodies", c.checkSafetyRuleBodies},
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"RewriteTestRulesForTracing", "compile_stage_rewrite_test_rules_for_tracing", c.rewriteTestRuleEqualities}, // must run after RewriteDynamicTerms
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
@@ -469,6 +471,13 @@ func (c *Compiler) WithEvalMode(e CompilerEvalMode) *Compiler {
return c
}
// WithRewriteTestRules toggles the rewriting of test rules so that dynamic
// values are captured in local variables and thereby become accessible to
// tracing. The receiver is returned to allow call chaining.
func (c *Compiler) WithRewriteTestRules(rewrite bool) *Compiler {
	c.rewriteTestRulesForTracing = rewrite
	return c
}
// ParsedModules returns the parsed, unprocessed modules from the compiler.
// It is `nil` if keeping modules wasn't enabled via `WithKeepModules(true)`.
// The map includes all modules loaded via the ModuleLoader, if one was used.
@@ -1585,7 +1594,7 @@ func (c *Compiler) compile() {
}
}
if c.allowUndefinedFuncCalls && s.name == "CheckUndefinedFuncs" {
if c.allowUndefinedFuncCalls && (s.name == "CheckUndefinedFuncs" || s.name == "CheckSafetyRuleBodies") {
continue
}
@@ -2167,6 +2176,43 @@ func (c *Compiler) rewriteDynamicTerms() {
}
}
// rewriteTestRuleEqualities rewrites equality expressions found in the bodies
// of test rules (rules whose head name starts with "test_") so that operands
// whose values tracing would otherwise not capture -- refs and comprehensions
// that are not unified with or assigned to a local var -- are bound to local
// vars. It is a no-op unless enabled through WithRewriteTestRules.
//
// For example, given the following module:
//
//	package test
//
//	p.q contains v if {
//		some v in numbers.range(1, 3)
//	}
//
//	p.r := "foo"
//
//	test_rule {
//		p == {
//			"q": {4, 5, 6}
//		}
//	}
//
// `p` in `test_rule` resolves to `data.test.p`, which won't be an entry in the
// virtual-cache and must therefore be calculated after-the-fact. If `p` isn't
// captured in a local var, there is no trivial way to retrieve its value for
// test reporting.
func (c *Compiler) rewriteTestRuleEqualities() {
	if !c.rewriteTestRulesForTracing {
		return
	}

	factory := newEqualityFactory(c.localvargen)

	// The same visitor is applied to every module; only test rules are touched.
	visit := func(r *Rule) bool {
		if isTestRule := strings.HasPrefix(string(r.Head.Name), "test_"); isTestRule {
			r.Body = rewriteTestEqualities(factory, r.Body)
		}
		return false
	}

	// Modules are processed in the order given by the sorted module names.
	for i := range c.sorted {
		WalkRules(c.Modules[c.sorted[i]], visit)
	}
}
func (c *Compiler) parseMetadataBlocks() {
// Only parse annotations if rego.metadata built-ins are called
regoMetadataCalled := false
@@ -2196,6 +2242,8 @@ func (c *Compiler) parseMetadataBlocks() {
for _, err := range errs {
c.err(err)
}
attachRuleAnnotations(mod)
}
}
}
@@ -4192,7 +4240,7 @@ func resolveRefsInRule(globals map[Var]*usedRef, rule *Rule) error {
// Object keys cannot be pattern matched so only walk values.
case *object:
x.Foreach(func(k, v *Term) {
x.Foreach(func(_, v *Term) {
vis.Walk(v)
})
@@ -4515,6 +4563,41 @@ func rewriteEquals(x interface{}) (modified bool) {
return modified
}
// rewriteTestEqualities returns a new body in which the operands of equality
// expressions have been passed through rewriteDynamicsShallow, with any
// generated expressions placed ahead of the expression they were extracted
// from. Bodies of `every` expressions are rewritten recursively. Negated and
// generated expressions are carried over untouched.
func rewriteTestEqualities(f *equalityFactory, body Body) Body {
	rewritten := make(Body, 0, len(body))

	for _, e := range body {
		// We can't rewrite negated expressions; if the extracted term is undefined, evaluation would fail before
		// reaching the negation check.
		if e.Negated || e.Generated {
			rewritten = appendExpr(rewritten, e)
			continue
		}

		if e.IsEquality() {
			operands := e.Terms.([]*Term)
			// Index 0 is the operator; 1 and 2 are the left- and right-hand operands.
			for i := 1; i <= 2; i++ {
				rewritten, operands[i] = rewriteDynamicsShallow(e, f, operands[i], rewritten)
			}
		} else if e.IsEvery() {
			// We rewrite equalities inside of every-bodies as a fail here will be the cause of the test-rule fail.
			// Failures inside other expressions with closures, such as comprehensions, won't cause the test-rule to fail, so we skip those.
			quantified := e.Terms.(*Every)
			quantified.Body = rewriteTestEqualities(f, quantified.Body)
		}

		rewritten = appendExpr(rewritten, e)
	}

	return rewritten
}
// rewriteDynamicsShallow extracts term into a generated expression when its
// value is a ref or an array/set/object comprehension. The generated
// expression inherits the `with` modifiers of original, is appended to result
// and is linked to original via connectGeneratedExprs. The returned term is
// operand 0 of the expression last appended to result (presumably the
// generated local var -- depends on equalityFactory.Generate). Terms of any
// other kind are returned unchanged together with the untouched result; unlike
// a deep rewrite, nested terms are not inspected.
func rewriteDynamicsShallow(original *Expr, f *equalityFactory, term *Term, result Body) (Body, *Term) {
	switch term.Value.(type) {
	case Ref, *ArrayComprehension, *SetComprehension, *ObjectComprehension:
		// Dynamic term: fall through to the extraction below.
	default:
		return result, term
	}

	eq := f.Generate(term)
	eq.With = original.With
	result.Append(eq)
	connectGeneratedExprs(original, eq)

	last := result[len(result)-1]
	return result, last.Operand(0)
}
// rewriteDynamics will rewrite the body so that dynamic terms (i.e., refs and
// comprehensions) are bound to vars earlier in the query. This translation
// results in eager evaluation.
@@ -4606,6 +4689,7 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
generated := f.Generate(term)
generated.With = original.With
result.Append(generated)
connectGeneratedExprs(original, generated)
return result, result[len(result)-1].Operand(0)
case *Array:
for i := 0; i < v.Len(); i++ {
@@ -4634,16 +4718,19 @@ func rewriteDynamicsOne(original *Expr, f *equalityFactory, term *Term, result B
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *SetComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
case *ObjectComprehension:
var extra *Expr
v.Body, extra = rewriteDynamicsComprehensionBody(original, f, v.Body, term)
result.Append(extra)
connectGeneratedExprs(original, extra)
return result, result[len(result)-1].Operand(0)
}
return result, term
@@ -4711,6 +4798,7 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
for i := 1; i < len(terms); i++ {
var extras []*Expr
extras, terms[i] = expandExprTerm(gen, terms[i])
connectGeneratedExprs(expr, extras...)
if len(expr.With) > 0 {
for i := range extras {
extras[i].With = expr.With
@@ -4721,16 +4809,14 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
result = append(result, expr)
case *Every:
var extras []*Expr
if _, ok := terms.Domain.Value.(Call); ok {
extras, terms.Domain = expandExprTerm(gen, terms.Domain)
} else {
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
eq.Generated = true
eq.With = expr.With
extras = append(extras, eq)
terms.Domain = term
}
term := NewTerm(gen.Generate()).SetLocation(terms.Domain.Location)
eq := Equality.Expr(term, terms.Domain).SetLocation(terms.Domain.Location)
eq.Generated = true
eq.With = expr.With
extras = expandExpr(gen, eq)
terms.Domain = term
terms.Body = rewriteExprTermsInBody(gen, terms.Body)
result = append(result, extras...)
result = append(result, expr)
@@ -4738,6 +4824,13 @@ func expandExpr(gen *localVarGenerator, expr *Expr) (result []*Expr) {
return
}
// connectGeneratedExprs records the provenance of generated expressions: each
// child gets parent as its generatedFrom, and all children are appended, in
// order, to parent's generates list.
func connectGeneratedExprs(parent *Expr, children ...*Expr) {
	if len(children) == 0 {
		return
	}
	for i := range children {
		children[i].generatedFrom = parent
	}
	parent.generates = append(parent.generates, children...)
}
func expandExprTerm(gen *localVarGenerator, term *Term) (support []*Expr, output *Term) {
output = term
switch v := term.Value.(type) {
@@ -5239,8 +5332,7 @@ func rewriteDeclaredVarsInExpr(g *localVarGenerator, stack *localDeclaredVars, e
case *Term:
stop, errs = rewriteDeclaredVarsInTerm(g, stack, x, errs, strict)
case *With:
errs = rewriteDeclaredVarsInTermRecursive(g, stack, x.Value, errs, strict)
stop = true
stop, errs = true, rewriteDeclaredVarsInWithRecursive(g, stack, x, errs, strict)
}
return stop
})
@@ -5373,20 +5465,38 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t
}
func rewriteDeclaredVarsInTermRecursive(g *localVarGenerator, stack *localDeclaredVars, term *Term, errs Errors, strict bool) Errors {
WalkNodes(term, func(n Node) bool {
WalkTerms(term, func(t *Term) bool {
var stop bool
switch n := n.(type) {
case *With:
errs = rewriteDeclaredVarsInTermRecursive(g, stack, n.Value, errs, strict)
stop = true
case *Term:
stop, errs = rewriteDeclaredVarsInTerm(g, stack, n, errs, strict)
}
stop, errs = rewriteDeclaredVarsInTerm(g, stack, t, errs, strict)
return stop
})
return errs
}
// rewriteDeclaredVarsInWithRecursive rewrites declared vars in both the target
// and the value of a `with` modifier, accumulating any errors into errs and
// returning the result.
//
// The target gets special treatment: if `input` has been shadowed by a local
// declaration, a leading `input` in the target is restored after rewriting so
// that `with input as ...` and `with input.a.b.c as ...` keep addressing the
// input document.
func rewriteDeclaredVarsInWithRecursive(g *localVarGenerator, stack *localDeclaredVars, w *With, errs Errors, strict bool) Errors {
	// NOTE(sr): `with input as` and `with input.a.b.c as` are deliberately skipped here: `input` could
	// have been shadowed by a local variable/argument but should NOT be replaced in the `with` target.
	//
	// We cannot drop `input` from the stack since it's conceivable to do `with input[input] as` where
	// the second input is meant to be the local var. It's a terrible idea, but when you're shadowing
	// `input` those might be your thing.
	errs = rewriteDeclaredVarsInTermRecursive(g, stack, w.Target, errs, strict)

	if sdwInput, ok := stack.Declared(InputRootDocument.Value.(Var)); ok { // Was "input" shadowed...
		switch value := w.Target.Value.(type) {
		case Var:
			if sdwInput.Equal(value) { // ...and replaced? If so, fix it
				w.Target.Value = InputRootRef
			}
		case Ref:
			// The head of a ref is not guaranteed to be a Var here, and the ref
			// is not guaranteed to be non-empty; check both instead of
			// panicking. Such targets are left as-is for later stages to
			// accept or reject.
			if len(value) > 0 {
				if head, isVar := value[0].Value.(Var); isVar && sdwInput.Equal(head) {
					// value shares its elements with w.Target.Value, so this
					// updates the target's head term in place.
					value[0].Value = InputRootDocument.Value
				}
			}
		}
	}

	// No special handling of the `with` value
	return rewriteDeclaredVarsInTermRecursive(g, stack, w.Value, errs, strict)
}
func rewriteDeclaredVarsInArrayComprehension(g *localVarGenerator, stack *localDeclaredVars, v *ArrayComprehension, errs Errors, strict bool) Errors {
used := NewVarSet()
used.Update(v.Term.Vars())
@@ -5492,26 +5602,34 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
return false, err
}
isAllowedUnknownFuncCall := false
if c.allowUndefinedFuncCalls {
switch target.Value.(type) {
case Ref, Var:
isAllowedUnknownFuncCall = true
}
}
switch {
case isDataRef(target):
ref := target.Value.(Ref)
node := c.RuleTree
targetNode := c.RuleTree
for i := 0; i < len(ref)-1; i++ {
child := node.Child(ref[i].Value)
child := targetNode.Child(ref[i].Value)
if child == nil {
break
} else if len(child.Values) > 0 {
return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)")
}
node = child
targetNode = child
}
if node != nil {
if targetNode != nil {
// NOTE(sr): at this point in the compiler stages, we don't have a fully-populated
// TypeEnv yet -- so we have to make do with this check to see if the replacement
// target is a function. It's probably wrong for arity-0 functions, but those are
// an edge case anyways.
if child := node.Child(ref[len(ref)-1].Value); child != nil {
if child := targetNode.Child(ref[len(ref)-1].Value); child != nil {
for _, v := range child.Values {
if len(v.(*Rule).Head.Args) > 0 {
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
@@ -5521,6 +5639,18 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
}
}
}
// If the with-value is a ref to a function, but not a call, we can't rewrite it
if r, ok := value.Value.(Ref); ok {
// TODO: check that target ref doesn't exist?
if valueNode := c.RuleTree.Find(r); valueNode != nil {
for _, v := range valueNode.Values {
if len(v.(*Rule).Head.Args) > 0 {
return false, nil
}
}
}
}
case isInputRef(target): // ok, valid
case isBuiltinRefOrVar:
@@ -5539,6 +5669,9 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
return false, err // err may be nil
}
case isAllowedUnknownFuncCall:
// The target isn't a ref to the input doc, data doc, or a known built-in, but it might be a ref to an unknown built-in.
return false, nil
default:
return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
}