Files
kubesphere/vendor/github.com/open-policy-agent/opa/ast/policy.go
hongming 9769357005 update
Signed-off-by: hongming <talonwan@yunify.com>
2020-03-20 02:16:11 +08:00

1353 lines
31 KiB
Go

// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"strings"
"time"
"github.com/open-policy-agent/opa/util"
)
// Initialize seed for term hashing. This is intentionally placed before the
// root document sets are constructed to ensure they use the same hash seed as
// subsequent lookups. If the hash seeds are out of sync, lookups will fail.
//
// hashSeed is a PRNG seeded from the current wall-clock time, so the derived
// seeds (and any hash values computed from them) differ from one process run
// to the next; hash values must therefore never be persisted or compared
// across processes.
var hashSeed = rand.New(rand.NewSource(time.Now().UnixNano()))
// hashSeed0 and hashSeed1 are 64-bit seeds, each assembled from two successive
// 32-bit draws of hashSeed: the first draw fills the high 32 bits and the
// second fills the low 32 bits (Go evaluates the calls left to right).
// NOTE(review): presumably consumed by the term hashing code in another file
// of this package — confirm.
var hashSeed0 = (uint64(hashSeed.Uint32()) << 32) | uint64(hashSeed.Uint32())
var hashSeed1 = (uint64(hashSeed.Uint32()) << 32) | uint64(hashSeed.Uint32())
// Root documents, the refs and sets derived from them, and related language
// constants. Declaration order is kept exactly as before: the sets below are
// built after hashSeed0/hashSeed1 have been initialized.
var (
	// DefaultRootDocument is the default root document.
	//
	// All package directives inside source files are implicitly prefixed with
	// the DefaultRootDocument value.
	DefaultRootDocument = VarTerm("data")

	// InputRootDocument names the document containing query arguments.
	InputRootDocument = VarTerm("input")

	// RootDocumentNames contains the names of top-level documents that can be
	// referred to in modules and queries.
	RootDocumentNames = NewSet(DefaultRootDocument, InputRootDocument)

	// DefaultRootRef is a reference to the root of the default document.
	//
	// All refs to data in the policy engine's storage layer are prefixed with
	// this ref.
	DefaultRootRef = Ref{DefaultRootDocument}

	// InputRootRef is a reference to the root of the input document.
	//
	// All refs to query arguments are prefixed with this ref.
	InputRootRef = Ref{InputRootDocument}

	// RootDocumentRefs contains the prefixes of top-level documents that all
	// non-local references start with.
	RootDocumentRefs = NewSet(NewTerm(DefaultRootRef), NewTerm(InputRootRef))

	// SystemDocumentKey is the name of the top-level key that identifies the
	// system document.
	SystemDocumentKey = String("system")

	// ReservedVars is the set of names that refer to implicitly ground vars
	// (the vars behind DefaultRootDocument and InputRootDocument).
	ReservedVars = NewVarSet(DefaultRootDocument.Value.(Var), InputRootDocument.Value.(Var))

	// Wildcard represents the wildcard variable as defined in the language.
	Wildcard = &Term{Value: Var("_")}

	// WildcardPrefix is the special character that all wildcard variables are
	// prefixed with when the statement they are contained in is parsed.
	WildcardPrefix = "$"
)
// Keywords contains strings that map to language keywords.
var Keywords = [...]string{
	"not",
	"package",
	"import",
	"as",
	"default",
	"else",
	"with",
	"null",
	"true",
	"false",
	"some",
}

// IsKeyword reports whether s is exactly (case-sensitively) one of the
// language keywords listed in Keywords.
func IsKeyword(s string) bool {
	for i := 0; i < len(Keywords); i++ {
		if Keywords[i] == s {
			return true
		}
	}
	return false
}
// Node is implemented by every element of an AST: statements in a policy
// module as well as the pieces of an ad-hoc query, expression, etc. Every node
// can render itself as a string and carries an optional source location.
type Node interface {
	fmt.Stringer
	Loc() *Location
	SetLoc(*Location)
}

// Statement is a Node that appears as a single statement in a policy module.
type Statement interface {
	Node
}
// Module represents a collection of policies (defined by rules) within a
// namespace (defined by the package) and optional dependencies on external
// documents (defined by imports). Comments collected from the source are kept
// alongside the statements.
type Module struct {
	Package  *Package   `json:"package"`
	Imports  []*Import  `json:"imports,omitempty"`
	Rules    []*Rule    `json:"rules,omitempty"`
	Comments []*Comment `json:"comments,omitempty"`
}

// Comment contains the raw text from the comment in the definition. Text does
// not include the leading '#' marker; String adds it back.
type Comment struct {
	Text     []byte
	Location *Location
}

// Package represents the namespace of the documents produced by rules inside
// the module.
type Package struct {
	Location *Location `json:"-"`
	Path     Ref       `json:"path"`
}

// Import represents a dependency on a document outside of the policy
// namespace. Imports are optional.
type Import struct {
	Location *Location `json:"-"`
	Path     *Term     `json:"path"`
	Alias    Var       `json:"alias,omitempty"`
}

// Rule represents a rule as defined in the language. Rules define the content
// of documents that represent policy decisions.
type Rule struct {
	Location *Location `json:"-"`
	Default  bool      `json:"default,omitempty"`
	Head     *Head     `json:"head"`
	Body     Body      `json:"body"`
	Else     *Rule     `json:"else,omitempty"`

	// Module is a pointer to the module containing this rule. If the rule
	// was NOT created while parsing/constructing a module, this should be
	// left unset. The pointer is not included in any standard operations
	// on the rule (e.g., printing, comparison, visiting, etc.)
	Module *Module `json:"-"`
}

// Head represents the head of a rule: its name plus optional function args,
// partial-document key, value, and whether it was declared with :=.
type Head struct {
	Location *Location `json:"-"`
	Name     Var       `json:"name"`
	Args     Args      `json:"args,omitempty"`
	Key      *Term     `json:"key,omitempty"`
	Value    *Term     `json:"value,omitempty"`
	Assign   bool      `json:"assign,omitempty"`
}

// Args represents zero or more arguments to a rule.
type Args []*Term

// Body represents one or more expressions contained inside a rule or user
// function.
type Body []*Expr

// Expr represents a single expression contained inside the body of a rule.
//
// Terms holds one of: a *Term (single-term expression), a []*Term (a call,
// with the operator first and the operands after it), or a *SomeDecl (a
// variable declaration). Index is the expression's position within its Body.
type Expr struct {
	Location  *Location   `json:"-"`
	Generated bool        `json:"generated,omitempty"`
	Index     int         `json:"index"`
	Negated   bool        `json:"negated,omitempty"`
	Terms     interface{} `json:"terms"`
	With      []*With     `json:"with,omitempty"`
}

// SomeDecl represents a variable declaration statement. The symbols are
// variables.
type SomeDecl struct {
	Location *Location `json:"-"`
	Symbols  []*Term   `json:"symbols"`
}

// With represents a modifier on an expression ("with <target> as <value>").
type With struct {
	Location *Location `json:"-"`
	Target   *Term     `json:"target"`
	Value    *Term     `json:"value"`
}
// Compare returns an integer indicating whether mod is less than, equal to,
// or greater than other. A nil module sorts before any non-nil module.
//
// Modules are ordered by package, then imports, then rules. Comments do not
// take part in the comparison.
func (mod *Module) Compare(other *Module) int {
	if mod == nil {
		if other == nil {
			return 0
		}
		return -1
	} else if other == nil {
		return 1
	}
	if cmp := mod.Package.Compare(other.Package); cmp != 0 {
		return cmp
	}
	if cmp := importsCompare(mod.Imports, other.Imports); cmp != 0 {
		return cmp
	}
	return rulesCompare(mod.Rules, other.Rules)
}

// Copy returns a deep copy of mod.
//
// The package, imports, rules (including else chains) and comments are all
// copied, so mutating the result never affects mod. Copied rules whose Module
// pointer referred to mod are re-pointed at the copy, so that e.g. Rule.Path
// resolves against the copied package. Rules carrying any other Module value
// (including nil) are left as they were.
func (mod *Module) Copy() *Module {
	cpy := *mod

	cpy.Rules = make([]*Rule, len(mod.Rules))
	for i := range mod.Rules {
		cpy.Rules[i] = mod.Rules[i].Copy()
		// Rule.Copy preserves the Module pointer, which would still reference
		// the original module; walk the whole else chain and fix it up.
		for r := cpy.Rules[i]; r != nil; r = r.Else {
			if r.Module == mod {
				r.Module = &cpy
			}
		}
	}

	cpy.Imports = make([]*Import, len(mod.Imports))
	for i := range mod.Imports {
		cpy.Imports[i] = mod.Imports[i].Copy()
	}

	// Comments used to be shared with the original module. Copy each comment
	// and its text buffer; a nil slice stays nil so the copy still compares
	// and serializes the same way.
	if mod.Comments != nil {
		cpy.Comments = make([]*Comment, len(mod.Comments))
		for i, c := range mod.Comments {
			if c == nil {
				continue
			}
			cc := *c
			// c.Text[:0:0] keeps nil-ness while forcing a fresh backing array.
			cc.Text = append(c.Text[:0:0], c.Text...)
			cpy.Comments[i] = &cc
		}
	}

	cpy.Package = mod.Package.Copy()
	return &cpy
}

// Equal returns true if mod equals other.
func (mod *Module) Equal(other *Module) bool {
	return mod.Compare(other) == 0
}
// String renders mod as policy source. The output is the package line, then
// (if any imports exist) a blank line followed by one entry per import, then
// (if any rules exist) a blank line followed by one entry per rule. Entries
// are separated by newlines. Comments are not rendered.
func (mod *Module) String() string {
	lines := []string{mod.Package.String()}
	if n := len(mod.Imports); n > 0 {
		lines = append(lines, "")
		for i := 0; i < n; i++ {
			lines = append(lines, mod.Imports[i].String())
		}
	}
	if n := len(mod.Rules); n > 0 {
		lines = append(lines, "")
		for i := 0; i < n; i++ {
			lines = append(lines, mod.Rules[i].String())
		}
	}
	return strings.Join(lines, "\n")
}

// RuleSet returns a RuleSet holding the rules in mod whose head name equals
// name. Rules that are Equal to one already collected are not added twice.
func (mod *Module) RuleSet(name Var) RuleSet {
	result := NewRuleSet()
	for i := range mod.Rules {
		if r := mod.Rules[i]; r.Head.Name.Equal(name) {
			result.Add(r)
		}
	}
	return result
}

// UnmarshalJSON parses bs and stores the result in mod. Every rule reachable
// from mod afterwards has its Module pointer set to mod.
func (mod *Module) UnmarshalJSON(bs []byte) error {
	// module has Module's fields but none of its methods, so decoding into it
	// does not recurse back into this UnmarshalJSON.
	type module Module
	err := util.UnmarshalJSON(bs, (*module)(mod))
	if err != nil {
		return err
	}
	WalkRules(mod, func(r *Rule) bool {
		r.Module = mod
		return false
	})
	return nil
}
// NewComment returns a Comment holding text and no location.
func NewComment(text []byte) *Comment {
	c := new(Comment)
	c.Text = text
	return c
}

// Loc returns the location of the comment in the definition, or nil when c is
// nil.
func (c *Comment) Loc() *Location {
	if c != nil {
		return c.Location
	}
	return nil
}

// SetLoc sets the location on c.
func (c *Comment) SetLoc(loc *Location) {
	c.Location = loc
}

// String returns the comment as it appears in source: "#" followed by the raw
// text.
func (c *Comment) String() string {
	return "#" + string(c.Text)
}

// Equal returns true if this comment equals the other comment.
// Unlike other equality checks on AST nodes, comment equality
// depends on location: both the location and the text must match.
func (c *Comment) Equal(other *Comment) bool {
	if !c.Location.Equal(other.Location) {
		return false
	}
	return bytes.Equal(c.Text, other.Text)
}
// Compare returns an integer indicating whether pkg is less than, equal to,
// or greater than other. Packages are ordered by path. A nil package sorts
// before any non-nil package, consistent with the other Compare methods in
// this file (previously a nil pkg or other caused a nil dereference).
func (pkg *Package) Compare(other *Package) int {
	if pkg == nil {
		if other == nil {
			return 0
		}
		return -1
	} else if other == nil {
		return 1
	}
	return Compare(pkg.Path, other.Path)
}

// Copy returns a deep copy of pkg. Copying a nil package yields nil instead of
// panicking.
func (pkg *Package) Copy() *Package {
	if pkg == nil {
		return nil
	}
	cpy := *pkg
	cpy.Path = pkg.Path.Copy()
	return &cpy
}

// Equal returns true if pkg is equal to other.
func (pkg *Package) Equal(other *Package) bool {
	return pkg.Compare(other) == 0
}

// Loc returns the location of the Package in the definition, or nil when pkg
// is nil.
func (pkg *Package) Loc() *Location {
	if pkg == nil {
		return nil
	}
	return pkg.Location
}

// SetLoc sets the location on pkg.
func (pkg *Package) SetLoc(loc *Location) {
	pkg.Location = loc
}

// String renders pkg as a package directive, e.g. "package a.b". Placeholder
// text is returned for a nil package or a path with no element after the root.
// NOTE: panics if the second path element is not a String.
func (pkg *Package) String() string {
	if pkg == nil {
		return "<illegal nil package>"
	} else if len(pkg.Path) <= 1 {
		return fmt.Sprintf("package <illegal path %q>", pkg.Path)
	}
	// Omit head as all packages have the DefaultRootDocument prepended at parse time.
	path := make(Ref, len(pkg.Path)-1)
	path[0] = VarTerm(string(pkg.Path[1].Value.(String)))
	copy(path[1:], pkg.Path[2:])
	return fmt.Sprintf("package %v", path)
}
// IsValidImportPath returns an error indicating if the import path is invalid.
// If the import path is valid, err is nil.
//
// A valid path is either the bare var data or input, or a non-empty ref whose
// head is a valid path and whose remaining elements are all strings.
func IsValidImportPath(v Value) (err error) {
	switch v := v.(type) {
	case Var:
		if !v.Equal(DefaultRootDocument.Value) && !v.Equal(InputRootDocument.Value) {
			return fmt.Errorf("invalid path %v: path must begin with input or data", v)
		}
	case Ref:
		// An empty ref has no root document; reject it instead of panicking on
		// v[0] below.
		if len(v) == 0 {
			return fmt.Errorf("invalid path %v: path must begin with input or data", v)
		}
		if IsValidImportPath(v[0].Value) != nil {
			return fmt.Errorf("invalid path %v: path must begin with input or data", v)
		}
		for _, e := range v[1:] {
			if _, ok := e.Value.(String); !ok {
				return fmt.Errorf("invalid path %v: path elements must be strings", v)
			}
		}
	default:
		return fmt.Errorf("invalid path %v: path must be ref or var", v)
	}
	return nil
}
// Compare returns an integer indicating whether imp is less than, equal to,
// or greater than other. Nil sorts first; otherwise imports are ordered by
// path and then by alias.
func (imp *Import) Compare(other *Import) int {
	switch {
	case imp == nil && other == nil:
		return 0
	case imp == nil:
		return -1
	case other == nil:
		return 1
	}
	if c := Compare(imp.Path, other.Path); c != 0 {
		return c
	}
	return Compare(imp.Alias, other.Alias)
}

// Copy returns a deep copy of imp.
func (imp *Import) Copy() *Import {
	dup := *imp
	dup.Path = imp.Path.Copy()
	return &dup
}

// Equal returns true if imp is equal to other.
func (imp *Import) Equal(other *Import) bool {
	return imp.Compare(other) == 0
}

// Loc returns the location of the Import in the definition, or nil when imp
// is nil.
func (imp *Import) Loc() *Location {
	if imp != nil {
		return imp.Location
	}
	return nil
}

// SetLoc sets the location on imp.
func (imp *Import) SetLoc(loc *Location) {
	imp.Location = loc
}

// Name returns the variable that is used to refer to the imported virtual
// document. This is the alias if defined, otherwise the last element in the
// path. It panics if the path is neither a Var nor a Ref, or if a ref's
// elements are not of the expected types.
func (imp *Import) Name() Var {
	if len(imp.Alias) > 0 {
		return imp.Alias
	}
	if v, ok := imp.Path.Value.(Var); ok {
		return v
	}
	if ref, ok := imp.Path.Value.(Ref); ok {
		if len(ref) == 1 {
			return ref[0].Value.(Var)
		}
		return Var(ref[len(ref)-1].Value.(String))
	}
	panic("illegal import")
}

// String renders imp as an import statement, e.g. "import data.x as y".
func (imp *Import) String() string {
	s := "import " + imp.Path.String()
	if len(imp.Alias) > 0 {
		s += " as " + imp.Alias.String()
	}
	return s
}
// Compare returns an integer indicating whether rule is less than, equal to,
// or greater than other. Nil sorts first; otherwise rules are ordered by head,
// default flag, body and finally their else chains.
func (rule *Rule) Compare(other *Rule) int {
	switch {
	case rule == nil && other == nil:
		return 0
	case rule == nil:
		return -1
	case other == nil:
		return 1
	}
	if c := rule.Head.Compare(other.Head); c != 0 {
		return c
	}
	if c := util.Compare(rule.Default, other.Default); c != 0 {
		return c
	}
	if c := rule.Body.Compare(other.Body); c != 0 {
		return c
	}
	// Rule.Compare is nil-safe on both sides, so missing else branches are fine.
	return rule.Else.Compare(other.Else)
}

// Copy returns a deep copy of rule, including its else chain. The Module
// pointer is carried over unchanged.
func (rule *Rule) Copy() *Rule {
	dup := *rule
	dup.Head = rule.Head.Copy()
	dup.Body = rule.Body.Copy()
	if rule.Else != nil {
		dup.Else = rule.Else.Copy()
	}
	return &dup
}

// Equal returns true if rule is equal to other.
func (rule *Rule) Equal(other *Rule) bool {
	return rule.Compare(other) == 0
}

// Loc returns the location of the Rule in the definition, or nil when rule is
// nil.
func (rule *Rule) Loc() *Location {
	if rule != nil {
		return rule.Location
	}
	return nil
}

// SetLoc sets the location on rule.
func (rule *Rule) SetLoc(loc *Location) {
	rule.Location = loc
}

// Path returns a ref referring to the document produced by this rule: the
// containing module's package path extended with the rule name. If rule is not
// contained in a module, this function panics.
func (rule *Rule) Path() Ref {
	mod := rule.Module
	if mod == nil {
		panic("assertion failed")
	}
	name := StringTerm(string(rule.Head.Name))
	return mod.Package.Path.Append(name)
}
// String renders rule as policy source. Default rules are printed as
// "default <head>" with no body; other rules as "<head> { <body> }". Any else
// branches follow on the same line.
func (rule *Rule) String() string {
	parts := make([]string, 0, 5)
	if rule.Default {
		parts = append(parts, "default", rule.Head.String())
	} else {
		parts = append(parts, rule.Head.String(), "{", rule.Body.String(), "}")
	}
	if rule.Else != nil {
		parts = append(parts, rule.Else.elseString())
	}
	return strings.Join(parts, " ")
}

// elseString renders rule as an else branch: "else", then "= <value>" when the
// head has a value, then the braced body, followed recursively by any further
// else branches.
func (rule *Rule) elseString() string {
	parts := []string{"else"}
	if v := rule.Head.Value; v != nil {
		parts = append(parts, "=", v.String())
	}
	parts = append(parts, "{", rule.Body.String(), "}")
	if rule.Else != nil {
		parts = append(parts, rule.Else.elseString())
	}
	return strings.Join(parts, " ")
}
// NewHead returns a new Head named name. Optional args fill the head in order:
// the first becomes the Key and the second the Value; further args are
// ignored.
func NewHead(name Var, args ...*Term) *Head {
	head := &Head{Name: name}
	switch {
	case len(args) >= 2:
		head.Key, head.Value = args[0], args[1]
	case len(args) == 1:
		head.Key = args[0]
	}
	return head
}
// DocKind represents the collection of document types that can be produced by rules.
type DocKind int

// Document kinds produced by rules. These are untyped integer constants
// (0, 1 and 2) that convert implicitly to DocKind.
const (
	// CompleteDoc represents a document that is completely defined by the rule.
	CompleteDoc = iota
	// PartialSetDoc represents a set document that is partially defined by the rule.
	PartialSetDoc
	// PartialObjectDoc represents an object document that is partially defined by the rule.
	PartialObjectDoc
)
// DocKind returns the type of document produced by this rule: a complete
// document when the head has no key, otherwise a partial set (key only) or a
// partial object (key and value).
func (head *Head) DocKind() DocKind {
	switch {
	case head.Key == nil:
		return CompleteDoc
	case head.Value == nil:
		return PartialSetDoc
	default:
		return PartialObjectDoc
	}
}

// Compare returns an integer indicating whether head is less than, equal to,
// or greater than other. Nil sorts first; then heads declared with := sort
// before those that are not; then args, name, key and value are compared in
// that order.
func (head *Head) Compare(other *Head) int {
	switch {
	case head == nil && other == nil:
		return 0
	case head == nil:
		return -1
	case other == nil:
		return 1
	}
	if head.Assign != other.Assign {
		if head.Assign {
			return -1
		}
		return 1
	}
	if c := Compare(head.Args, other.Args); c != 0 {
		return c
	}
	if c := Compare(head.Name, other.Name); c != 0 {
		return c
	}
	if c := Compare(head.Key, other.Key); c != 0 {
		return c
	}
	return Compare(head.Value, other.Value)
}

// Copy returns a deep copy of head.
func (head *Head) Copy() *Head {
	dup := *head
	dup.Args = head.Args.Copy()
	dup.Key = head.Key.Copy()
	dup.Value = head.Value.Copy()
	return &dup
}

// Equal returns true if this head equals other.
func (head *Head) Equal(other *Head) bool {
	return head.Compare(other) == 0
}
// String renders head as policy source: the name with its argument list,
// its "[key]" suffix, or nothing extra. This is followed by " := value" or
// " = value" when the head has a value.
func (head *Head) String() string {
	var lhs string
	switch {
	case len(head.Args) != 0:
		lhs = head.Name.String() + head.Args.String()
	case head.Key != nil:
		lhs = head.Name.String() + "[" + head.Key.String() + "]"
	default:
		lhs = head.Name.String()
	}
	if head.Value == nil {
		return lhs
	}
	op := "="
	if head.Assign {
		op = ":="
	}
	return lhs + " " + op + " " + head.Value.String()
}

// Vars returns a set of vars found in the head's args, key and value.
func (head *Head) Vars() VarSet {
	visitor := &VarVisitor{vars: VarSet{}}
	// TODO: improve test coverage for this.
	if head.Args != nil {
		visitor.Walk(head.Args)
	}
	if head.Key != nil {
		visitor.Walk(head.Key)
	}
	if head.Value != nil {
		visitor.Walk(head.Value)
	}
	return visitor.vars
}

// Loc returns the Location of head, or nil when head is nil.
func (head *Head) Loc() *Location {
	if head != nil {
		return head.Location
	}
	return nil
}

// SetLoc sets the location on head.
func (head *Head) SetLoc(loc *Location) {
	head.Location = loc
}
// Copy returns a deep copy of a, copying every term. The result is non-nil
// even when a is nil or empty.
func (a Args) Copy() Args {
	out := make(Args, len(a))
	for i, t := range a {
		out[i] = t.Copy()
	}
	return out
}

// String renders a as a parenthesized, comma-separated list, e.g. "(x, y)".
// An empty list renders as "()".
func (a Args) String() string {
	parts := make([]string, len(a))
	for i, t := range a {
		parts[i] = t.String()
	}
	return "(" + strings.Join(parts, ", ") + ")"
}

// Loc returns the Location of the first argument, or nil if a is empty.
func (a Args) Loc() *Location {
	if len(a) > 0 {
		return a[0].Location
	}
	return nil
}

// SetLoc sets loc on the first argument; it does nothing when a is empty.
func (a Args) SetLoc(loc *Location) {
	if len(a) == 0 {
		return
	}
	a[0].SetLocation(loc)
}

// Vars returns a set of vars that appear in a.
func (a Args) Vars() VarSet {
	visitor := &VarVisitor{vars: VarSet{}}
	visitor.Walk(a)
	return visitor.vars
}
// NewBody returns a Body made of exprs and renumbers each expression's Index
// to match its position. The exprs slice is reused, not copied.
func NewBody(exprs ...*Expr) Body {
	for i := range exprs {
		exprs[i].Index = i
	}
	return Body(exprs)
}

// MarshalJSON returns JSON encoded bytes representing body. Both nil and empty
// bodies encode as [] (the default encoding of a nil slice would be null).
func (body Body) MarshalJSON() ([]byte, error) {
	if len(body) > 0 {
		return json.Marshal([]*Expr(body))
	}
	return []byte(`[]`), nil
}

// Append adds expr to the end of body and sets expr.Index to its new position.
func (body *Body) Append(expr *Expr) {
	expr.Index = len(*body)
	*body = append(*body, expr)
}

// Set stores expr at position pos and sets expr.Index to pos. It panics if pos
// is out of range, before touching expr.
func (body Body) Set(expr *Expr, pos int) {
	body[pos] = expr
	expr.Index = pos
}
// Compare returns an integer indicating whether body is less than, equal to,
// or greater than other.
//
// Expressions are compared pairwise. If one body is a prefix of the other, the
// shorter body is considered less than the longer one.
func (body Body) Compare(other Body) int {
	n := len(body)
	if m := len(other); m < n {
		n = m
	}
	for i := 0; i < n; i++ {
		if c := body[i].Compare(other[i]); c != 0 {
			return c
		}
	}
	switch {
	case len(body) < len(other):
		return -1
	case len(body) > len(other):
		return 1
	}
	return 0
}

// Copy returns a deep copy of body.
func (body Body) Copy() Body {
	out := make(Body, len(body))
	for i, e := range body {
		out[i] = e.Copy()
	}
	return out
}

// Contains returns true if some expression in body is Equal to x.
func (body Body) Contains(x *Expr) bool {
	for i := range body {
		if body[i].Equal(x) {
			return true
		}
	}
	return false
}

// Equal returns true if this Body is equal to the other Body.
func (body Body) Equal(other Body) bool {
	return body.Compare(other) == 0
}
// Hash returns the hash code for the Body: the sum of its expressions' hashes.
func (body Body) Hash() int {
	sum := 0
	for i := range body {
		sum += body[i].Hash()
	}
	return sum
}

// IsGround returns true if all of the expressions in the Body are ground. An
// empty body is ground.
func (body Body) IsGround() bool {
	for i := range body {
		if !body[i].IsGround() {
			return false
		}
	}
	return true
}

// Loc returns the location of the Body in the definition, which is the
// location of its first expression, or nil for an empty body.
func (body Body) Loc() *Location {
	if len(body) > 0 {
		return body[0].Location
	}
	return nil
}

// SetLoc sets loc on the first expression of body; it does nothing when body
// is empty.
func (body Body) SetLoc(loc *Location) {
	if len(body) == 0 {
		return
	}
	body[0].SetLocation(loc)
}

// String renders body as its expressions separated by "; ".
func (body Body) String() string {
	parts := make([]string, len(body))
	for i, e := range body {
		parts[i] = e.String()
	}
	return strings.Join(parts, "; ")
}

// Vars returns a VarSet containing variables in body. The params can be set to
// control which vars are included.
func (body Body) Vars(params VarVisitorParams) VarSet {
	visitor := NewVarVisitor().WithParams(params)
	visitor.Walk(body)
	return visitor.Vars()
}
// NewExpr returns a new Expr wrapping terms. The expression is not negated,
// has Index 0 and has no with modifiers.
func NewExpr(terms interface{}) *Expr {
	return &Expr{Terms: terms}
}

// Complement returns a shallow copy of this expression with the negation flag
// flipped. Terms and With are shared with the original.
func (expr *Expr) Complement() *Expr {
	flipped := *expr
	flipped.Negated = !expr.Negated
	return &flipped
}

// Equal returns true if this Expr equals the other Expr.
func (expr *Expr) Equal(other *Expr) bool {
	return expr.Compare(other) == 0
}
// Compare returns an integer indicating whether expr is less than, equal to,
// or greater than other. A nil expression sorts before any non-nil one.
//
// Non-nil expressions are compared by the following criteria, in order:
//
//  1. Kind: declarations < single-term expressions < calls (built-ins).
//  2. Index: the preceding expression is less than the other expression.
//  3. Negation: non-negated expressions are less than negated expressions.
//
// Otherwise, the expression terms are compared normally. If both expressions
// have the same terms, the with modifiers are compared.
func (expr *Expr) Compare(other *Expr) int {
	switch {
	case expr == nil && other == nil:
		return 0
	case expr == nil:
		return -1
	case other == nil:
		return 1
	}

	if a, b := expr.sortOrder(), other.sortOrder(); a != b {
		if a < b {
			return -1
		}
		return 1
	}

	if expr.Index != other.Index {
		if expr.Index < other.Index {
			return -1
		}
		return 1
	}

	if expr.Negated != other.Negated {
		if expr.Negated {
			return 1
		}
		return -1
	}

	// Both expressions have the same kind (checked via sortOrder above), so the
	// assertions on other.Terms below cannot fail.
	switch t := expr.Terms.(type) {
	case *Term:
		if c := Compare(t.Value, other.Terms.(*Term).Value); c != 0 {
			return c
		}
	case []*Term:
		if c := termSliceCompare(t, other.Terms.([]*Term)); c != 0 {
			return c
		}
	case *SomeDecl:
		if c := Compare(t, other.Terms.(*SomeDecl)); c != 0 {
			return c
		}
	}

	return withSliceCompare(expr.With, other.With)
}

// sortOrder ranks the kind of expr.Terms for Compare: 0 for declarations, 1
// for single terms, 2 for calls, and -1 for anything else.
func (expr *Expr) sortOrder() int {
	switch expr.Terms.(type) {
	case *SomeDecl:
		return 0
	case *Term:
		return 1
	case []*Term:
		return 2
	default:
		return -1
	}
}
// Copy returns a deep copy of expr, including its terms and with modifiers.
// The With slice of the result is always non-nil.
func (expr *Expr) Copy() *Expr {
	dup := *expr

	switch ts := expr.Terms.(type) {
	case *SomeDecl:
		dup.Terms = ts.Copy()
	case []*Term:
		terms := make([]*Term, 0, len(ts))
		for _, t := range ts {
			terms = append(terms, t.Copy())
		}
		dup.Terms = terms
	case *Term:
		dup.Terms = ts.Copy()
	}

	withs := make([]*With, 0, len(expr.With))
	for _, w := range expr.With {
		withs = append(withs, w.Copy())
	}
	dup.With = withs

	return &dup
}

// Hash returns the hash code of the Expr. It is the sum of the index, the term
// hashes, one extra for negation, and the with-modifier hashes.
func (expr *Expr) Hash() int {
	h := expr.Index
	switch ts := expr.Terms.(type) {
	case *SomeDecl:
		h += ts.Hash()
	case []*Term:
		for i := range ts {
			h += ts[i].Value.Hash()
		}
	case *Term:
		h += ts.Value.Hash()
	}
	if expr.Negated {
		h++
	}
	for i := range expr.With {
		h += expr.With[i].Hash()
	}
	return h
}
// IncludeWith returns a shallow copy of expr with the modifier
// "with target as value" appended. expr and any earlier results derived from
// it are left unaffected.
func (expr *Expr) IncludeWith(target *Term, value *Term) *Expr {
	cpy := *expr
	// Cap the shared slice at its length so append always allocates a fresh
	// backing array. Otherwise, when expr.With has spare capacity, the new
	// modifier would be written into expr's backing array. Two IncludeWith
	// calls on the same expr would then overwrite each other's modifier.
	n := len(expr.With)
	cpy.With = append(expr.With[:n:n], &With{Target: target, Value: value})
	return &cpy
}

// NoWith returns a shallow copy of expr where the with modifiers have been
// removed.
func (expr *Expr) NoWith() *Expr {
	cpy := *expr
	cpy.With = nil
	return &cpy
}
// IsEquality returns true if this is an equality expression.
func (expr *Expr) IsEquality() bool {
	return isglobalbuiltin(expr, Var(Equality.Name))
}

// IsAssignment returns true if this an assignment expression.
func (expr *Expr) IsAssignment() bool {
	return isglobalbuiltin(expr, Var(Assign.Name))
}

// IsCall returns true if this expression calls a function (its Terms is a
// []*Term).
func (expr *Expr) IsCall() bool {
	_, ok := expr.Terms.([]*Term)
	return ok
}

// Operator returns the name of the function or built-in this expression refers
// to. If this expression is not a function call, or the call has no terms,
// returns nil. It panics if the operator term's value is not a Ref.
func (expr *Expr) Operator() Ref {
	terms, ok := expr.Terms.([]*Term)
	if !ok || len(terms) == 0 {
		return nil
	}
	return terms[0].Value.(Ref)
}

// Operand returns the term at the zero-based pos. If the expr does not include
// at least pos+1 terms, this function returns nil.
func (expr *Expr) Operand(pos int) *Term {
	terms, ok := expr.Terms.([]*Term)
	if !ok {
		return nil
	}
	idx := pos + 1
	if idx < len(terms) {
		return terms[idx]
	}
	return nil
}

// Operands returns the built-in function operands (every term after the
// operator). It returns nil when expr is not a call, and also when the call
// has no terms at all (previously terms[1:] panicked in that case).
func (expr *Expr) Operands() []*Term {
	terms, ok := expr.Terms.([]*Term)
	if !ok || len(terms) == 0 {
		return nil
	}
	return terms[1:]
}
// IsGround returns true if all of the expression terms are ground. For a call
// only the operands are checked; the operator is skipped. A call with no terms
// and expressions of any other kind (e.g. declarations) are reported as
// ground. With modifiers are not inspected.
func (expr *Expr) IsGround() bool {
	switch ts := expr.Terms.(type) {
	case []*Term:
		// Guard the operator skip so an empty call cannot panic on ts[1:].
		if len(ts) == 0 {
			return true
		}
		for _, t := range ts[1:] {
			if !t.IsGround() {
				return false
			}
		}
	case *Term:
		return ts.IsGround()
	}
	return true
}

// SetOperator sets the expr's operator and returns the expr itself. If expr is
// not a call expr (or is a call with no terms), this function will panic.
func (expr *Expr) SetOperator(term *Term) *Expr {
	expr.Terms.([]*Term)[0] = term
	return expr
}

// SetLocation sets the expr's location and returns the expr itself.
func (expr *Expr) SetLocation(loc *Location) *Expr {
	expr.Location = loc
	return expr
}

// Loc returns the Location of expr, or nil when expr is nil.
func (expr *Expr) Loc() *Location {
	if expr == nil {
		return nil
	}
	return expr.Location
}

// SetLoc sets the location on expr.
func (expr *Expr) SetLoc(loc *Location) {
	expr.SetLocation(loc)
}
// String renders expr as policy source. It prepends "not" when negated and
// prints two-operand equality calls in infix form; other calls use call
// syntax. Any with modifiers are appended.
func (expr *Expr) String() string {
	parts := make([]string, 0, 2+len(expr.With))
	if expr.Negated {
		parts = append(parts, "not")
	}
	switch t := expr.Terms.(type) {
	case []*Term:
		if expr.IsEquality() && validEqAssignArgCount(expr) {
			parts = append(parts, fmt.Sprintf("%v %v %v", t[1], Equality.Infix, t[2]))
		} else {
			parts = append(parts, Call(t).String())
		}
	case *Term:
		parts = append(parts, t.String())
	case *SomeDecl:
		parts = append(parts, t.String())
	}
	for _, w := range expr.With {
		parts = append(parts, w.String())
	}
	return strings.Join(parts, " ")
}

// UnmarshalJSON parses the byte array into a generic map and then decodes that
// map into expr.
func (expr *Expr) UnmarshalJSON(bs []byte) error {
	m := make(map[string]interface{})
	if err := util.UnmarshalJSON(bs, &m); err != nil {
		return err
	}
	return unmarshalExpr(expr, m)
}

// Vars returns a VarSet containing variables in expr. The params can be set to
// control which vars are included.
func (expr *Expr) Vars(params VarVisitorParams) VarSet {
	visitor := NewVarVisitor().WithParams(params)
	visitor.Walk(expr)
	return visitor.Vars()
}

// NewBuiltinExpr creates a new Expr object with the supplied terms.
// The builtin operator must be the first term.
func NewBuiltinExpr(terms ...*Term) *Expr {
	return &Expr{Terms: terms}
}
// String renders d as a some declaration, e.g. "some x, y".
func (d *SomeDecl) String() string {
	buf := make([]string, len(d.Symbols))
	for i := range buf {
		buf[i] = d.Symbols[i].String()
	}
	return "some " + strings.Join(buf, ", ")
}

// SetLoc sets the Location on d.
func (d *SomeDecl) SetLoc(loc *Location) {
	d.Location = loc
}

// Loc returns the Location of d, or nil when d is nil. This matches the
// nil-safe Loc methods of the other AST node types; previously a nil d
// panicked.
func (d *SomeDecl) Loc() *Location {
	if d == nil {
		return nil
	}
	return d.Location
}

// Copy returns a deep copy of d.
func (d *SomeDecl) Copy() *SomeDecl {
	cpy := *d
	cpy.Symbols = termSliceCopy(d.Symbols)
	return &cpy
}

// Compare returns an integer indicating whether d is less than, equal to, or
// greater than other, by comparing their symbols.
func (d *SomeDecl) Compare(other *SomeDecl) int {
	return termSliceCompare(d.Symbols, other.Symbols)
}

// Hash returns a hash code of d computed from its symbols.
func (d *SomeDecl) Hash() int {
	return termSliceHash(d.Symbols)
}
// String renders w as a with modifier: "with <target> as <value>".
func (w *With) String() string {
	target, value := w.Target.String(), w.Value.String()
	return "with " + target + " as " + value
}

// Equal returns true if this With is equals the other With.
func (w *With) Equal(other *With) bool {
	return Compare(w, other) == 0
}

// Compare returns an integer indicating whether w is less than, equal to, or
// greater than other. Nil sorts first; otherwise targets are compared, then
// values.
func (w *With) Compare(other *With) int {
	switch {
	case w == nil && other == nil:
		return 0
	case w == nil:
		return -1
	case other == nil:
		return 1
	}
	if c := Compare(w.Target, other.Target); c != 0 {
		return c
	}
	return Compare(w.Value, other.Value)
}

// Copy returns a deep copy of w.
func (w *With) Copy() *With {
	dup := *w
	dup.Value = w.Value.Copy()
	dup.Target = w.Target.Copy()
	return &dup
}

// Hash returns the hash code of the With: the sum of the target and value
// hashes. Note that, unlike the other methods, it uses a value receiver.
func (w With) Hash() int {
	return w.Target.Hash() + w.Value.Hash()
}

// SetLocation sets the location on w and returns w.
func (w *With) SetLocation(loc *Location) *With {
	w.Location = loc
	return w
}

// Loc returns the Location of w, or nil when w is nil.
func (w *With) Loc() *Location {
	if w != nil {
		return w.Location
	}
	return nil
}

// SetLoc sets the location on w.
func (w *With) SetLoc(loc *Location) {
	w.Location = loc
}
// RuleSet represents a collection of rules that produce a virtual document.
// When built through NewRuleSet/Add, it never holds two rules that are Equal.
type RuleSet []*Rule

// NewRuleSet returns a new RuleSet containing the given rules, with
// duplicates (by Rule.Equal) dropped.
func NewRuleSet(rules ...*Rule) RuleSet {
	rs := make(RuleSet, 0, len(rules))
	for i := range rules {
		rs.Add(rules[i])
	}
	return rs
}

// Add inserts the rule into rs unless an Equal rule is already present.
func (rs *RuleSet) Add(rule *Rule) {
	if rs.Contains(rule) {
		return
	}
	*rs = append(*rs, rule)
}

// Contains returns true if rs contains a rule Equal to rule.
func (rs RuleSet) Contains(rule *Rule) bool {
	for _, r := range rs {
		if r.Equal(rule) {
			return true
		}
	}
	return false
}

// Diff returns a new RuleSet containing rules in rs that are not in other.
func (rs RuleSet) Diff(other RuleSet) RuleSet {
	result := NewRuleSet()
	for _, r := range rs {
		if other.Contains(r) {
			continue
		}
		result.Add(r)
	}
	return result
}

// Equal returns true if rs and other contain the same rules, regardless of
// order.
func (rs RuleSet) Equal(other RuleSet) bool {
	if len(rs.Diff(other)) != 0 {
		return false
	}
	return len(other.Diff(rs)) == 0
}

// Merge returns a ruleset containing the union of rules from rs and other.
func (rs RuleSet) Merge(other RuleSet) RuleSet {
	result := NewRuleSet()
	for _, src := range []RuleSet{rs, other} {
		for _, r := range src {
			result.Add(r)
		}
	}
	return result
}

// String renders rs as "{rule1, rule2, ...}".
func (rs RuleSet) String() string {
	parts := make([]string, len(rs))
	for i, r := range rs {
		parts[i] = r.String()
	}
	return "{" + strings.Join(parts, ", ") + "}"
}
// ruleSlice adapts a slice of rules to sort.Interface, ordering rules with the
// package-level Compare.
type ruleSlice []*Rule

func (s ruleSlice) Less(i, j int) bool { return Compare(s[i], s[j]) < 0 }
func (s ruleSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
func (s ruleSlice) Len() int           { return len(s) }
// validEqAssignArgCount returns true if the equality or assignment expression
// referred to by expr has a valid number of arguments (exactly two operands).
func validEqAssignArgCount(expr *Expr) bool {
	return len(expr.Operands()) == 2
}

// isglobalbuiltin returns true if expr is a call whose operator is the
// single-element ref name, i.e. a non-namespaced (global) built-in function
// like eq, gt, plus, etc. It returns false for non-call expressions and for
// calls with no terms (previously terms[0] panicked in that case).
func isglobalbuiltin(expr *Expr, name Var) bool {
	terms, ok := expr.Terms.([]*Term)
	if !ok || len(terms) == 0 {
		return false
	}
	// NOTE(tsandall): do not use Term#Equal or Value#Compare to avoid
	// allocation here.
	ref, ok := terms[0].Value.(Ref)
	if !ok || len(ref) != 1 {
		return false
	}
	head, ok := ref[0].Value.(Var)
	return ok && head.Equal(name)
}