Files
kubesphere/vendor/github.com/open-policy-agent/opa/ast/term.go
hongming 9769357005 update
Signed-off-by: hongming <talonwan@yunify.com>
2020-03-20 02:16:11 +08:00

2573 lines
58 KiB
Go

// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package ast
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/big"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/OneOfOne/xxhash"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/util"
)
// errFindNotFound is the error returned by the Find implementations in this
// file when a path cannot be resolved against a value.
var errFindNotFound = fmt.Errorf("find: not found")
// Location records a position in source code.
type Location struct {
	Text []byte `json:"-"`    // The original text fragment from the source.
	File string `json:"file"` // The name of the source file (which may be empty).
	Row  int    `json:"row"`  // The line in the source.
	Col  int    `json:"col"`  // The column in the row.
}

// NewLocation returns a new Location built from the given source text
// fragment, file name, row, and column.
func NewLocation(text []byte, file string, row int, col int) *Location {
	loc := new(Location)
	loc.Text = text
	loc.File = file
	loc.Row = row
	loc.Col = col
	return loc
}

// Equal checks if two locations are equal to each other. Unlike Compare, the
// source text takes part in the comparison as well as file, row, and column.
func (loc *Location) Equal(other *Location) bool {
	if loc.File != other.File || loc.Row != other.Row || loc.Col != other.Col {
		return false
	}
	return bytes.Equal(loc.Text, other.Text)
}
// Errorf returns a new error (via github.com/pkg/errors, as imported by this
// file) whose message is f formatted with a and prefixed by the location info
// (e.g., line, column, filename, etc.), as produced by Format.
func (loc *Location) Errorf(f string, a ...interface{}) error {
	msg := loc.Format(f, a...)
	return errors.New(msg)
}

// Wrapf returns a new error that wraps err (via github.com/pkg/errors) with a
// message built the same way as Errorf, i.e. prefixed by the location info.
func (loc *Location) Wrapf(err error, f string, a ...interface{}) error {
	msg := loc.Format(f, a...)
	return errors.Wrap(err, msg)
}
// Format returns fmt.Sprintf(f, a...) prefixed with the location information:
// "file:row: " when the location has a file name, otherwise "row:col: ".
//
// The prefix is rendered separately and never passed through the formatter,
// so '%' characters in the file name are emitted literally. They are not
// interpreted as formatting verbs that would consume the arguments meant for
// f and garble the message.
func (loc *Location) Format(f string, a ...interface{}) string {
	var prefix string
	if len(loc.File) > 0 {
		prefix = fmt.Sprintf("%v:%v: ", loc.File, loc.Row)
	} else {
		prefix = fmt.Sprintf("%v:%v: ", loc.Row, loc.Col)
	}
	return prefix + fmt.Sprintf(f, a...)
}
// String returns "file:row" when the location has a file name, otherwise the
// original source text when there is any, and "row:col" as a last resort.
func (loc *Location) String() string {
	switch {
	case loc.File != "":
		return fmt.Sprintf("%v:%v", loc.File, loc.Row)
	case len(loc.Text) != 0:
		return string(loc.Text)
	default:
		return fmt.Sprintf("%v:%v", loc.Row, loc.Col)
	}
}

// Compare returns -1, 0, or 1 to indicate if this loc is less than, equal to,
// or greater than the other. Comparison is performed on the file, row, and
// column of the Location (but not on the text.) Nil locations are greater than
// non-nil locations.
func (loc *Location) Compare(other *Location) int {
	switch {
	case loc == nil && other == nil:
		return 0
	case loc == nil:
		return 1
	case other == nil:
		return -1
	}
	// Lexicographic order: file first, then row, then column.
	if loc.File != other.File {
		if loc.File < other.File {
			return -1
		}
		return 1
	}
	if loc.Row != other.Row {
		if loc.Row < other.Row {
			return -1
		}
		return 1
	}
	if loc.Col != other.Col {
		if loc.Col < other.Col {
			return -1
		}
		return 1
	}
	return 0
}
// Value declares the common interface for all Term values. Every kind of Term value
// in the language is represented as a type that implements this interface:
//
// - Null, Boolean, Number, String
// - Object, Array, Set
// - Variables, References
// - Array, Set, and Object Comprehensions
// - Calls
type Value interface {
	// Compare returns <0, 0, or >0 if this Value is less than, equal to, or
	// greater than other, respectively.
	Compare(other Value) int

	// Find returns value referred to by path or an error if path is not found.
	Find(path Ref) (Value, error)

	// Hash returns hash code of the value.
	Hash() int

	// IsGround returns true if this value is not a variable or contains no
	// variables.
	IsGround() bool

	// String returns a human readable string representation of the value.
	String() string
}
// InterfaceToValue converts a native Go value x to a Value.
//
// Supported inputs are nil, bool, json.Number, int64, float64, int, string,
// []interface{}, map[string]interface{}, and map[string]string. Slices and
// map values are converted recursively, and the first conversion error is
// returned. Any other type yields an "ast: illegal value" error.
func InterfaceToValue(x interface{}) (Value, error) {
	switch x := x.(type) {
	case nil:
		return Null{}, nil
	case bool:
		return Boolean(x), nil
	case json.Number:
		return Number(x), nil
	case int64:
		return int64Number(x), nil
	case float64:
		return floatNumber(x), nil
	case int:
		return intNumber(x), nil
	case string:
		return String(x), nil
	case []interface{}:
		r := make(Array, 0, len(x))
		for _, elem := range x {
			val, err := InterfaceToValue(elem)
			if err != nil {
				return nil, err
			}
			r = append(r, NewTerm(val))
		}
		return r, nil
	case map[string]interface{}:
		r := newobject(len(x))
		for k, elem := range x {
			val, err := InterfaceToValue(elem)
			if err != nil {
				return nil, err
			}
			// Go string keys always convert to String, so there is no need
			// to round-trip them through InterfaceToValue (which cannot fail
			// for strings).
			r.Insert(StringTerm(k), NewTerm(val))
		}
		return r, nil
	case map[string]string:
		r := newobject(len(x))
		for k, elem := range x {
			// Both sides are plain strings; conversion cannot fail.
			r.Insert(StringTerm(k), StringTerm(elem))
		}
		return r, nil
	default:
		return nil, fmt.Errorf("ast: illegal value: %T", x)
	}
}
// ValueFromReader decodes one JSON value from r using the decoder returned by
// util.NewJSONDecoder and converts the result with InterfaceToValue.
func ValueFromReader(r io.Reader) (Value, error) {
	var decoded interface{}
	dec := util.NewJSONDecoder(r)
	if err := dec.Decode(&decoded); err != nil {
		return nil, err
	}
	return InterfaceToValue(decoded)
}

// As converts v into a Go native type referred to by x. It renders v with
// its String method and decodes that text into x with util.NewJSONDecoder;
// any decoding error is returned.
func As(v Value, x interface{}) error {
	text := bytes.NewBufferString(v.String())
	return util.NewJSONDecoder(text).Decode(x)
}
type (
	// Resolver defines the interface for resolving references to native Go values.
	Resolver interface {
		Resolve(ref Ref) (value interface{}, err error)
	}

	// ValueResolver defines the interface for resolving references to AST values.
	ValueResolver interface {
		Resolve(ref Ref) (value Value, err error)
	}
)
// UnknownValueErr indicates a ValueResolver was unable to resolve a reference
// because the reference refers to an unknown value.
type UnknownValueErr struct{}

// Error implements the error interface.
func (UnknownValueErr) Error() string {
	const msg = "unknown value"
	return msg
}

// IsUnknownValueErr returns true if the err is an UnknownValueErr. Only the
// value type matches directly; a *UnknownValueErr or an error wrapping one
// does not.
func IsUnknownValueErr(err error) bool {
	switch err.(type) {
	case UnknownValueErr:
		return true
	default:
		return false
	}
}
// illegalResolver is a Resolver that rejects every reference. JSON passes it
// to ValueToInterface so that any ref in the value becomes an error.
type illegalResolver struct{}

// Resolve always fails with an "illegal value" error that names ref.
func (r illegalResolver) Resolve(ref Ref) (interface{}, error) {
	return nil, fmt.Errorf("illegal value: %v", ref)
}
// ValueToInterface returns the Go representation of an AST value. The AST
// value should not contain any values that require evaluation (e.g., vars,
// comprehensions, etc.)
//
// Null, Boolean, Number and String map to nil, bool, json.Number and string.
// Arrays and Sets map to a non-nil []interface{}, so empty ones still encode
// as [] in JSON. Objects map to map[string]interface{}, and a key that does
// not convert to a Go string is an error. Refs are delegated to resolver. Any
// other value type returns a "requires evaluation" error.
func ValueToInterface(v Value, resolver Resolver) (interface{}, error) {
	switch v := v.(type) {
	case Null:
		return nil, nil
	case Boolean:
		return bool(v), nil
	case Number:
		return json.Number(v), nil
	case String:
		return string(v), nil
	case Array:
		// The element count is known, so allocate once.
		buf := make([]interface{}, 0, len(v))
		for _, x := range v {
			x1, err := ValueToInterface(x.Value, resolver)
			if err != nil {
				return nil, err
			}
			buf = append(buf, x1)
		}
		return buf, nil
	case Object:
		buf := map[string]interface{}{}
		err := v.Iter(func(key, val *Term) error {
			ki, err := ValueToInterface(key.Value, resolver)
			if err != nil {
				return err
			}
			asStr, stringKey := ki.(string)
			if !stringKey {
				return fmt.Errorf("object value has non-string key (%T)", ki)
			}
			vi, err := ValueToInterface(val.Value, resolver)
			if err != nil {
				return err
			}
			buf[asStr] = vi
			return nil
		})
		if err != nil {
			return nil, err
		}
		return buf, nil
	case Set:
		// Set.Len is part of the Set interface; allocate once.
		buf := make([]interface{}, 0, v.Len())
		err := v.Iter(func(x *Term) error {
			x1, err := ValueToInterface(x.Value, resolver)
			if err != nil {
				return err
			}
			buf = append(buf, x1)
			return nil
		})
		if err != nil {
			return nil, err
		}
		return buf, nil
	case Ref:
		return resolver.Resolve(v)
	default:
		return nil, fmt.Errorf("%v requires evaluation", TypeName(v))
	}
}
// JSON returns the JSON representation of v. The value must not contain any
// refs or terms that require evaluation (e.g., vars, comprehensions, etc.);
// refs are rejected because illegalResolver fails on every reference.
func JSON(v Value) (interface{}, error) {
	var rejectRefs illegalResolver
	return ValueToInterface(v, rejectRefs)
}

// MustInterfaceToValue converts a native Go value x to a Value. If the
// conversion fails, this function will panic with the conversion error. This
// function is mostly for test purposes.
func MustInterfaceToValue(x interface{}) Value {
	v, err := InterfaceToValue(x)
	if err == nil {
		return v
	}
	panic(err)
}
// Term is an argument to a function.
type Term struct {
	Value    Value     `json:"value"` // the value of the Term as represented in Go
	Location *Location `json:"-"`     // the location of the Term in the source
}

// NewTerm returns a new Term wrapping v, with no location.
func NewTerm(v Value) *Term {
	t := new(Term)
	t.Value = v
	return t
}

// SetLocation updates the term's Location and returns the term itself, so
// calls can be chained.
func (term *Term) SetLocation(loc *Location) *Term {
	term.Location = loc
	return term
}

// Loc returns the Location of term, or nil when term itself is nil.
func (term *Term) Loc() *Location {
	if term != nil {
		return term.Location
	}
	return nil
}

// SetLoc sets the location on term. It is SetLocation without the return
// value.
func (term *Term) SetLoc(loc *Location) {
	term.Location = loc
}
// Copy returns a deep copy of term, or nil if term is nil. The struct is
// copied field by field, so the Location pointer is shared with the
// original. Composite values (refs, arrays, sets, objects, comprehensions,
// calls) are duplicated through their own Copy methods. Every other value,
// including the value-typed Null, Boolean, Number, String and Var, is
// carried over as-is.
func (term *Term) Copy() *Term {
	if term == nil {
		return nil
	}
	dup := *term
	switch v := term.Value.(type) {
	case Ref:
		dup.Value = v.Copy()
	case Array:
		dup.Value = v.Copy()
	case Set:
		dup.Value = v.Copy()
	case Object:
		dup.Value = v.Copy()
	case *ArrayComprehension:
		dup.Value = v.Copy()
	case *ObjectComprehension:
		dup.Value = v.Copy()
	case *SetComprehension:
		dup.Value = v.Copy()
	case Call:
		dup.Value = v.Copy()
	}
	return &dup
}
// Equal returns true if this term equals the other term. Equality is
// defined for each kind of term. Two nil terms are equal, and a nil term
// never equals a non-nil one.
func (term *Term) Equal(other *Term) bool {
	if term == other {
		return true
	}
	if term == nil || other == nil {
		return false
	}
	// TODO(tsandall): This early-exit avoids allocations for types that have
	// Equal() functions that just use == underneath. We should revisit the
	// other types and implement Equal() functions that do not require
	// allocations.
	switch v := term.Value.(type) {
	case Null:
		return v.Equal(other.Value)
	case Boolean:
		return v.Equal(other.Value)
	case Number:
		return v.Equal(other.Value)
	case String:
		return v.Equal(other.Value)
	case Var:
		return v.Equal(other.Value)
	default:
		return term.Value.Compare(other.Value) == 0
	}
}
// Get returns a value referred to by name from the term. Arrays and objects
// delegate to their own Get. Sets return name itself when it is a member.
// Every other value, or a miss on a set, yields nil.
func (term *Term) Get(name *Term) *Term {
	switch v := term.Value.(type) {
	case Array:
		return v.Get(name)
	case Object:
		return v.Get(name)
	case Set:
		if !v.Contains(name) {
			return nil
		}
		return name
	default:
		return nil
	}
}

// Hash returns the hash code of the Term's value.
func (term *Term) Hash() int {
	v := term.Value
	return v.Hash()
}

// IsGround returns true if this terms' Value is ground.
func (term *Term) IsGround() bool {
	v := term.Value
	return v.IsGround()
}
// MarshalJSON returns the JSON encoding of the term as an object with a
// "type" hint (from TypeName) and the "value" itself.
//
// Specialized marshalling logic is required to include a type hint for Value.
func (term *Term) MarshalJSON() ([]byte, error) {
	return json.Marshal(map[string]interface{}{
		"type":  TypeName(term.Value),
		"value": term.Value,
	})
}

// String returns the string form of the term's value.
func (term *Term) String() string {
	v := term.Value
	return v.String()
}

// UnmarshalJSON parses the byte array and stores the result in term. Only
// Value is set; Location is left untouched.
// Specialized unmarshalling is required to handle Value.
func (term *Term) UnmarshalJSON(bs []byte) error {
	fields := map[string]interface{}{}
	if err := util.UnmarshalJSON(bs, &fields); err != nil {
		return err
	}
	val, err := unmarshalValue(fields)
	if err != nil {
		return err
	}
	term.Value = val
	return nil
}

// Vars returns a VarSet with variables contained in this term.
func (term *Term) Vars() VarSet {
	visitor := &VarVisitor{vars: VarSet{}}
	visitor.Walk(term)
	return visitor.vars
}
// IsConstant returns true if the AST value is constant, i.e. the visitor
// finds no vars, refs, comprehensions, or calls anywhere inside v.
func IsConstant(v Value) bool {
	constant := true
	vis := GenericVisitor{
		func(x interface{}) bool {
			switch x.(type) {
			case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call:
				constant = false
				return true
			default:
				return false
			}
		},
	}
	vis.Walk(v)
	return constant
}

// IsComprehension returns true if the supplied value is a comprehension.
func IsComprehension(x Value) bool {
	switch x.(type) {
	case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
		return true
	default:
		return false
	}
}
// ContainsRefs returns true if the Value v contains refs.
func ContainsRefs(v interface{}) bool {
	var found bool
	WalkRefs(v, func(Ref) bool {
		found = true
		return true
	})
	return found
}

// ContainsComprehensions returns true if the Value v contains comprehensions.
func ContainsComprehensions(v interface{}) bool {
	var found bool
	WalkClosures(v, func(x interface{}) bool {
		switch x.(type) {
		case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
			found = true
		}
		// Once a comprehension has been seen, report true for every later
		// node too.
		return found
	})
	return found
}

// IsScalar returns true if the AST value is a scalar (String, Number,
// Boolean, or Null).
func IsScalar(v Value) bool {
	switch v.(type) {
	case String, Number, Boolean, Null:
		return true
	default:
		return false
	}
}
// Null represents the null value defined by JSON.
type Null struct{}

// NullTerm creates a new Term with a Null value.
func NullTerm() *Term {
	return NewTerm(Null{})
}

// Equal returns true if the other term Value is also Null.
func (null Null) Equal(other Value) bool {
	_, isNull := other.(Null)
	return isNull
}

// Compare compares null to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (null Null) Compare(other Value) int {
	return Compare(null, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (null Null) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return null, nil
}

// Hash returns the hash code for the Value, which is always 0.
func (null Null) Hash() int {
	const nullHash = 0
	return nullHash
}

// IsGround always returns true.
func (null Null) IsGround() bool {
	return true
}

// String returns "null".
func (null Null) String() string {
	return "null"
}
// Boolean represents a boolean value defined by JSON.
type Boolean bool

// BooleanTerm creates a new Term with a Boolean value.
func BooleanTerm(b bool) *Term {
	return NewTerm(Boolean(b))
}

// Equal returns true if the other Value is a Boolean and is equal.
func (bol Boolean) Equal(other Value) bool {
	o, ok := other.(Boolean)
	return ok && bol == o
}

// Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (bol Boolean) Compare(other Value) int {
	return Compare(bol, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (bol Boolean) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return bol, nil
}

// Hash returns the hash code for the Value: 1 for true, 0 for false.
func (bol Boolean) Hash() int {
	h := 0
	if bol {
		h = 1
	}
	return h
}

// IsGround always returns true.
func (bol Boolean) IsGround() bool {
	return true
}

// String returns "true" or "false".
func (bol Boolean) String() string {
	return strconv.FormatBool(bool(bol))
}
// Number represents a numeric value as defined by JSON.
type Number json.Number

// NumberTerm creates a new Term with a Number value.
func NumberTerm(n json.Number) *Term {
	return NewTerm(Number(n))
}

// IntNumberTerm creates a new Term with an integer Number value.
func IntNumberTerm(i int) *Term {
	return NewTerm(intNumber(i))
}

// FloatNumberTerm creates a new Term with a floating point Number value,
// formatted with the shortest representation that round-trips.
func FloatNumberTerm(f float64) *Term {
	return NewTerm(floatNumber(f))
}

// Equal returns true if the other Value is a Number and is equal.
func (num Number) Equal(other Value) bool {
	o, ok := other.(Number)
	return ok && Compare(num, o) == 0
}

// Compare compares num to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (num Number) Compare(other Value) int {
	return Compare(num, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (num Number) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return num, nil
}

// Hash returns the hash code for the Value. Numbers that parse as float64
// hash to their truncated integer value, so numerically equal spellings such
// as "1" and "1.0" collide. Anything else falls back to an xxhash of the raw
// text.
func (num Number) Hash() int {
	if f, err := json.Number(num).Float64(); err == nil {
		return int(f)
	}
	return int(xxhash.Checksum64([]byte(num)))
}
// Int returns the int representation of num if possible. It reports false
// when num is not an integer literal that fits in an int64, and also when the
// value does not fit in an int. On 32-bit platforms int is narrower than
// int64, and a plain conversion would silently truncate.
func (num Number) Int() (int, bool) {
	i64, ok := num.Int64()
	if !ok {
		return 0, false
	}
	i := int(i64)
	if int64(i) != i64 {
		// Truncation would yield a different number (e.g. a huge array
		// index wrapping to a small one).
		return 0, false
	}
	return i, true
}

// Int64 returns the int64 representation of num if possible.
func (num Number) Int64() (int64, bool) {
	i, err := json.Number(num).Int64()
	if err != nil {
		return 0, false
	}
	return i, true
}

// Float64 returns the float64 representation of num if possible.
func (num Number) Float64() (float64, bool) {
	f, err := json.Number(num).Float64()
	if err != nil {
		return 0, false
	}
	return f, true
}
// IsGround always returns true.
func (num Number) IsGround() bool {
	return true
}

// MarshalJSON returns JSON encoded bytes representing num, emitted as a raw
// JSON number via json.Number.
func (num Number) MarshalJSON() ([]byte, error) {
	n := json.Number(num)
	return json.Marshal(n)
}

// String returns the number's original textual form.
func (num Number) String() string {
	return string(num)
}

// intNumber converts i to a base-10 Number.
func intNumber(i int) Number {
	return Number(strconv.FormatInt(int64(i), 10))
}

// int64Number converts i to a base-10 Number.
func int64Number(i int64) Number {
	return Number(strconv.FormatInt(i, 10))
}

// floatNumber converts f to a Number using the shortest 'g' formatting that
// round-trips.
func floatNumber(f float64) Number {
	return Number(strconv.FormatFloat(f, 'g', -1, 64))
}
// String represents a string value as defined by JSON.
type String string

// StringTerm creates a new Term with a String value.
func StringTerm(s string) *Term {
	return NewTerm(String(s))
}

// Equal returns true if the other Value is a String and is equal.
func (str String) Equal(other Value) bool {
	o, ok := other.(String)
	return ok && str == o
}

// Compare compares str to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (str String) Compare(other Value) int {
	return Compare(str, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (str String) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return str, nil
}

// IsGround always returns true.
func (str String) IsGround() bool {
	return true
}

// String returns the value as a double-quoted Go string literal
// (strconv.Quote).
func (str String) String() string {
	return strconv.Quote(string(str))
}

// Hash returns the hash code for the Value (seeded xxhash of the raw string).
func (str String) Hash() int {
	return int(xxhash.ChecksumString64S(string(str), hashSeed0))
}
// Var represents a variable as defined by the language.
type Var string

// VarTerm creates a new Term with a Variable value.
func VarTerm(v string) *Term {
	return NewTerm(Var(v))
}

// Equal returns true if the other Value is a Variable and has the same value
// (name).
func (v Var) Equal(other Value) bool {
	o, ok := other.(Var)
	return ok && v == o
}

// Compare compares v to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (v Var) Compare(other Value) int {
	return Compare(v, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (v Var) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return v, nil
}

// Hash returns the hash code for the Value (seeded xxhash of the name).
func (v Var) Hash() int {
	return int(xxhash.ChecksumString64S(string(v), hashSeed0))
}

// IsGround always returns false.
func (v Var) IsGround() bool {
	return false
}

// IsWildcard returns true if this is a wildcard variable, i.e. its name
// starts with WildcardPrefix.
func (v Var) IsWildcard() bool {
	return strings.HasPrefix(string(v), WildcardPrefix)
}

// IsGenerated returns true if this variable was generated during compilation
// (its name starts with "__local").
func (v Var) IsGenerated() bool {
	return strings.HasPrefix(string(v), "__local")
}

// String returns the variable name. Wildcards are special-cased so that the
// string representation is parseable: the parser mangles wildcard variables
// to make their names unique and uses an illegal variable name character
// (WildcardPrefix) to avoid conflicts, so they must be rendered back as the
// plain wildcard.
func (v Var) String() string {
	if !v.IsWildcard() {
		return string(v)
	}
	return Wildcard.String()
}
// Ref represents a reference as defined by the language.
type Ref []*Term

// EmptyRef returns a new, empty (non-nil) reference.
func EmptyRef() Ref {
	return Ref{}
}

// PtrRef returns a new reference against the head for the pointer
// s. Leading and trailing slashes are ignored, and each path component is
// unescaped with url.PathUnescape and becomes a String operand. The first
// unescape error is returned.
func PtrRef(head *Term, s string) (Ref, error) {
	trimmed := strings.Trim(s, "/")
	if trimmed == "" {
		return Ref{head}, nil
	}
	parts := strings.Split(trimmed, "/")
	ref := make(Ref, 0, len(parts)+1)
	ref = append(ref, head)
	for _, part := range parts {
		unescaped, err := url.PathUnescape(part)
		if err != nil {
			return nil, err
		}
		ref = append(ref, StringTerm(unescaped))
	}
	return ref, nil
}

// RefTerm creates a new Term with a Ref value.
func RefTerm(r ...*Term) *Term {
	return NewTerm(Ref(r))
}
// Append returns a copy of ref with the term appended to the end. ref itself
// is never modified.
func (ref Ref) Append(term *Term) Ref {
	n := len(ref)
	dst := make(Ref, n+1)
	copy(dst, ref)
	dst[n] = term
	return dst
}

// Insert returns a copy of the ref with x inserted at pos. If pos < len(ref),
// existing elements are shifted to the right. pos == len(ref) appends. If pos
// is negative or greater than len(ref) this function panics with
// "illegal index".
func (ref Ref) Insert(x *Term, pos int) Ref {
	if pos < 0 || pos > len(ref) {
		// The old bound (pos > len(ref)+1) let pos == len(ref)+1 and negative
		// positions through. Those then failed with a runtime
		// index-out-of-range panic instead of this one.
		panic("illegal index")
	}
	cpy := make(Ref, len(ref)+1)
	copy(cpy, ref[:pos])
	cpy[pos] = x
	copy(cpy[pos+1:], ref[pos:])
	return cpy
}
// Extend returns a copy of ref with the terms from other appended. The head of
// other is copied and its Var value converted to a String. The remaining
// terms of other are shared, not copied. other must be non-empty and headed
// by a Var, otherwise this function panics.
func (ref Ref) Extend(other Ref) Ref {
	dst := make(Ref, len(ref)+len(other))
	n := copy(dst, ref)
	head := other[0].Copy()
	head.Value = String(head.Value.(Var))
	dst[n] = head
	copy(dst[n+1:], other[1:])
	return dst
}

// Concat returns a ref with the terms appended. When terms is empty, ref
// itself is returned; otherwise the result is a new slice.
func (ref Ref) Concat(terms []*Term) Ref {
	if len(terms) == 0 {
		return ref
	}
	cpy := make(Ref, len(ref)+len(terms))
	n := copy(cpy, ref)
	copy(cpy[n:], terms)
	return cpy
}
// Dynamic returns the offset of the first non-constant operand of ref, or -1
// if every operand after the head is constant. A ref headed by a Call is
// dynamic from offset 0.
func (ref Ref) Dynamic() int {
	if _, isCall := ref[0].Value.(Call); isCall {
		return 0
	}
	for i, operand := range ref[1:] {
		if !IsConstant(operand.Value) {
			return i + 1
		}
	}
	return -1
}

// Copy returns a deep copy of ref.
func (ref Ref) Copy() Ref {
	return termSliceCopy(ref)
}

// Equal returns true if ref is equal to other.
func (ref Ref) Equal(other Value) bool {
	return ref.Compare(other) == 0
}

// Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (ref Ref) Compare(other Value) int {
	return Compare(ref, other)
}

// Find returns the current value for an empty path, otherwise a not found
// error.
func (ref Ref) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return ref, nil
}

// Hash returns the hash code for the Value.
func (ref Ref) Hash() int {
	return termSliceHash(ref)
}
// HasPrefix returns true if the other ref is a prefix of this ref (compared
// term by term with Term.Equal).
func (ref Ref) HasPrefix(other Ref) bool {
	if len(other) > len(ref) {
		return false
	}
	for i, term := range other {
		if !ref[i].Equal(term) {
			return false
		}
	}
	return true
}

// ConstantPrefix returns a copy of the constant portion of the ref starting
// from the head.
func (ref Ref) ConstantPrefix() Ref {
	cpy := ref.Copy()
	if i := cpy.Dynamic(); i >= 0 {
		return cpy[:i]
	}
	return cpy
}

// GroundPrefix returns the ground portion of the ref starting from the head. By
// definition, the head of the reference is always ground.
func (ref Ref) GroundPrefix() Ref {
	n := len(ref)
	for i := 1; i < len(ref); i++ {
		if !ref[i].IsGround() {
			n = i
			break
		}
	}
	prefix := make(Ref, n, len(ref))
	copy(prefix, ref[:n])
	return prefix
}

// IsGround returns true if all of the parts of the Ref are ground. The head is
// not inspected.
func (ref Ref) IsGround() bool {
	return len(ref) == 0 || termSliceIsGround(ref[1:])
}

// IsNested returns true if this ref contains other Refs.
func (ref Ref) IsNested() bool {
	for i := range ref {
		if _, isRef := ref[i].Value.(Ref); isRef {
			return true
		}
	}
	return false
}
// Ptr returns a slash-separated path string for this ref, omitting the head.
// If the ref contains non-string terms after the head this function returns
// an error. Path components are escaped with url.PathEscape. An empty ref, or
// one consisting only of its head, yields "".
func (ref Ref) Ptr() (string, error) {
	if len(ref) == 0 {
		// Without this guard both the capacity len(ref)-1 and ref[1:] panic.
		return "", nil
	}
	parts := make([]string, 0, len(ref)-1)
	for _, term := range ref[1:] {
		str, ok := term.Value.(String)
		if !ok {
			return "", fmt.Errorf("invalid path value type")
		}
		parts = append(parts, url.PathEscape(string(str)))
	}
	return strings.Join(parts, "/"), nil
}
// varRegexp matches strings that can be written as a bare identifier.
var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$")

// String returns a human-readable form of the ref: the head, then ".key" for
// each String operand that is a bare identifier and not a keyword, and
// "[operand]" for every other operand. An empty ref yields "".
func (ref Ref) String() string {
	if len(ref) == 0 {
		return ""
	}
	var sb strings.Builder
	sb.WriteString(ref[0].Value.String())
	for _, operand := range ref[1:] {
		if s, ok := operand.Value.(String); ok && varRegexp.MatchString(string(s)) && !IsKeyword(string(s)) {
			sb.WriteByte('.')
			sb.WriteString(string(s))
			continue
		}
		sb.WriteByte('[')
		sb.WriteString(operand.Value.String())
		sb.WriteByte(']')
	}
	return sb.String()
}
// OutputVars returns a VarSet containing variables that would be bound by evaluating
// this expression in isolation. The ref head is skipped.
func (ref Ref) OutputVars() VarSet {
	params := VarVisitorParams{SkipRefHead: true}
	visitor := NewVarVisitor().WithParams(params)
	visitor.Walk(ref)
	return visitor.Vars()
}

// QueryIterator defines the interface for querying AST documents with references.
type QueryIterator func(map[Var]Value, Value) error
// Array represents an array as defined by the language. Arrays are similar to the
// same types as defined by JSON with the exception that they can contain Vars
// and References.
type Array []*Term

// ArrayTerm creates a new Term with an Array value.
func ArrayTerm(a ...*Term) *Term {
	return NewTerm(Array(a))
}

// Copy returns a deep copy of arr.
func (arr Array) Copy() Array {
	return termSliceCopy(arr)
}

// Equal returns true if arr is equal to other.
func (arr Array) Equal(other Value) bool {
	return arr.Compare(other) == 0
}

// Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (arr Array) Compare(other Value) int {
	return Compare(arr, other)
}

// Find returns arr for an empty path. Otherwise the first path element must
// be an integer Number in range, and the lookup continues into that element.
// Anything else yields a not found error.
func (arr Array) Find(path Ref) (Value, error) {
	if len(path) == 0 {
		return arr, nil
	}
	idx, ok := path[0].Value.(Number)
	if !ok {
		return nil, errFindNotFound
	}
	i, ok := idx.Int()
	if !ok || i < 0 || i >= len(arr) {
		return nil, errFindNotFound
	}
	return arr[i].Value.Find(path[1:])
}
// Get returns the element at pos, or nil when pos is not an integer Number or
// is out of range.
func (arr Array) Get(pos *Term) *Term {
	idx, ok := pos.Value.(Number)
	if !ok {
		return nil
	}
	i, ok := idx.Int()
	if !ok || i < 0 || i >= len(arr) {
		return nil
	}
	return arr[i]
}

// Sorted returns a new Array that contains the sorted elements of arr; arr
// itself is left unchanged.
func (arr Array) Sorted() Array {
	cpy := make(Array, len(arr))
	copy(cpy, arr)
	sort.Sort(termSlice(cpy))
	return cpy
}

// Hash returns the hash code for the Value.
func (arr Array) Hash() int {
	return termSliceHash(arr)
}

// IsGround returns true if all of the Array elements are ground.
func (arr Array) IsGround() bool {
	return termSliceIsGround(arr)
}

// MarshalJSON returns JSON encoded bytes representing arr. An empty array is
// encoded as [] rather than null.
func (arr Array) MarshalJSON() ([]byte, error) {
	if len(arr) != 0 {
		return json.Marshal([]*Term(arr))
	}
	return json.Marshal([]interface{}{})
}

// String returns the elements' string forms joined by ", " inside brackets.
func (arr Array) String() string {
	var sb strings.Builder
	sb.WriteByte('[')
	for i, elem := range arr {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(elem.String())
	}
	sb.WriteByte(']')
	return sb.String()
}
// Set represents a set as defined by the language.
type Set interface {
	Value

	// Size and membership.
	Len() int
	Contains(*Term) bool

	// Mutation and copying.
	Add(*Term)
	Copy() Set

	// Set algebra.
	Diff(Set) Set
	Intersect(Set) Set
	Union(Set) Set

	// Iteration and transformation.
	Iter(func(*Term) error) error
	Until(func(*Term) bool) bool
	Foreach(func(*Term))
	Map(func(*Term) (*Term, error)) (Set, error)
	Reduce(*Term, func(*Term, *Term) (*Term, error)) (*Term, error)

	// Materialization.
	Sorted() Array
	Slice() []*Term
}
// NewSet returns a new Set containing t. Duplicates are collapsed by Add.
func NewSet(t ...*Term) Set {
	s := newset(len(t))
	for _, elem := range t {
		s.Add(elem)
	}
	return s
}

// newset allocates an empty set sized for n elements. The keys slice stays
// nil when n is 0.
func newset(n int) *set {
	s := &set{elems: make(map[int]*Term, n)}
	if n > 0 {
		s.keys = make([]*Term, 0, n)
	}
	return s
}

// SetTerm returns a new Term representing a set containing terms t.
func SetTerm(t ...*Term) *Term {
	return &Term{Value: NewSet(t...)}
}
// set is the Set implementation used by NewSet.
type set struct {
	elems map[int]*Term // keyed by hash; collisions probe hash+1, hash+2, ... (see insert)
	keys  []*Term       // elements in insertion order; drives iteration
}

// Copy returns a deep copy of s, preserving element order.
func (s *set) Copy() Set {
	cpy := newset(s.Len())
	for _, elem := range s.keys {
		cpy.Add(elem.Copy())
	}
	return cpy
}

// IsGround returns true if all terms in s are ground.
func (s *set) IsGround() bool {
	for _, elem := range s.keys {
		if !elem.IsGround() {
			return false
		}
	}
	return true
}

// Hash returns a hash code for s: the sum of the element hashes, so it does
// not depend on element order.
func (s *set) Hash() int {
	sum := 0
	for _, elem := range s.keys {
		sum += elem.Hash()
	}
	return sum
}

// String returns "set()" for an empty set, otherwise the elements in
// iteration order joined by ", " inside braces.
func (s *set) String() string {
	if s.Len() == 0 {
		return "set()"
	}
	parts := make([]string, 0, s.Len())
	for _, elem := range s.keys {
		parts = append(parts, fmt.Sprint(elem))
	}
	return "{" + strings.Join(parts, ", ") + "}"
}
// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other. Values of a different kind are ordered by sortOrder.
// Two sets are compared element-wise in sorted order.
//
// Sorted copies of the elements are compared, so Compare no longer reorders
// either set's keys in place. That side effect changed later Iter/Slice/String
// order, and it raced when a shared set was compared from several goroutines.
func (s *set) Compare(other Value) int {
	o1 := sortOrder(s)
	o2 := sortOrder(other)
	if o1 < o2 {
		return -1
	} else if o1 > o2 {
		return 1
	}
	t := other.(*set)
	return termSliceCompare([]*Term(s.Sorted()), []*Term(t.Sorted()))
}
// Find returns the set for an empty path. Otherwise the first path element
// must be a member, and the lookup continues into that element itself. A
// non-member yields a not found error.
func (s *set) Find(path Ref) (Value, error) {
	if len(path) == 0 {
		return s, nil
	}
	if !s.Contains(path[0]) {
		return nil, errFindNotFound
	}
	return path[0].Value.Find(path[1:])
}

// Diff returns elements in s that are not in other.
func (s *set) Diff(other Set) Set {
	r := NewSet()
	s.Foreach(func(x *Term) {
		if !other.Contains(x) {
			r.Add(x)
		}
	})
	return r
}

// Intersect returns the set containing elements in both s and other. It
// iterates over the smaller set and probes the larger one. Only Set interface
// methods are used, so any Set implementation is accepted; the previous
// other.(*set) assertion panicked for anything else.
func (s *set) Intersect(other Set) Set {
	var small, large Set = s, other
	if other.Len() < s.Len() {
		small, large = other, s
	}
	r := newset(small.Len())
	small.Foreach(func(x *Term) {
		if large.Contains(x) {
			r.Add(x)
		}
	})
	return r
}

// Union returns the set containing all elements of s and other, with s's
// elements first. The result is pre-sized for the worst case (no overlap).
func (s *set) Union(other Set) Set {
	r := newset(s.Len() + other.Len())
	s.Foreach(func(x *Term) {
		r.Add(x)
	})
	other.Foreach(func(x *Term) {
		r.Add(x)
	})
	return r
}
// Add updates s to include t; adding an element already present is a no-op.
func (s *set) Add(t *Term) {
	s.insert(t)
}

// Iter calls f on each element in s. If f returns an error, iteration stops
// and the return value is the error.
func (s *set) Iter(f func(*Term) error) error {
	for idx := range s.keys {
		if err := f(s.keys[idx]); err != nil {
			return err
		}
	}
	return nil
}

// errStop is a sentinel error intended for ending an Iter callback early.
var errStop = errors.New("stop")

// Until calls f on each element in s. If f returns true, iteration stops and
// Until returns true; otherwise it returns false.
func (s *set) Until(f func(*Term) bool) bool {
	for idx := range s.keys {
		if f(s.keys[idx]) {
			return true
		}
	}
	return false
}

// Foreach calls f on each element in s.
func (s *set) Foreach(f func(*Term)) {
	for idx := range s.keys {
		f(s.keys[idx])
	}
}

// Map returns a new Set obtained by applying f to each value in s. The first
// error from f aborts the mapping and is returned.
func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) {
	mapped := NewSet()
	err := s.Iter(func(x *Term) error {
		term, err := f(x)
		if err != nil {
			return err
		}
		mapped.Add(term)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return mapped, nil
}

// Reduce returns a Term produced by applying f to each value in s. The first
// argument to f is the reduced value (starting with i) and the second argument
// to f is the element in s. On error, the value f returned is returned
// together with the error.
func (s *set) Reduce(i *Term, f func(*Term, *Term) (*Term, error)) (*Term, error) {
	err := s.Iter(func(x *Term) error {
		var err error
		i, err = f(i, x)
		return err
	})
	return i, err
}

// Contains returns true if t is in s.
func (s *set) Contains(t *Term) bool {
	return s.get(t) != nil
}

// Len returns the number of elements in the set.
func (s *set) Len() int {
	return len(s.keys)
}
// MarshalJSON returns JSON encoded bytes representing s as an array in
// iteration order. A set with a nil keys slice encodes as [] rather than
// null.
func (s *set) MarshalJSON() ([]byte, error) {
	if s.keys != nil {
		return json.Marshal(s.keys)
	}
	return json.Marshal([]interface{}{})
}

// Sorted returns an Array that contains the sorted elements of s; s itself is
// left unchanged.
func (s *set) Sorted() Array {
	cpy := make(Array, len(s.keys))
	copy(cpy, s.keys)
	sort.Sort(termSlice(cpy))
	return cpy
}

// Slice returns a slice of terms contained in the set. The set's internal
// slice is returned directly, not a copy.
func (s *set) Slice() []*Term {
	return s.keys
}
// insert adds x to s unless s already holds an element considered equal to x.
//
// Elements are stored in s.elems keyed by hash with linear probing: starting
// at x.Hash(), each occupied slot is checked for equality, and the next
// integer key (hash+1, hash+2, ...) is tried until a free slot is found. x is
// also appended to s.keys, which records insertion order for iteration.
//
// It panics with "illegal value" if a Number involved in the comparison
// cannot be parsed by big.Float.
func (s *set) insert(x *Term) {
	hash := x.Hash()
	// Build an equality predicate specialised for x's value type, presumably
	// so scalar lookups avoid the generic Compare.
	var equal func(v Value) bool
	switch x := x.Value.(type) {
	case Null, Boolean, String, Var:
		// Interface ==: equal only if dynamic type and value both match.
		equal = func(y Value) bool { return x == y }
	case Number:
		// Integer fast path: x parses as int64, so compare only against
		// stored Numbers that also parse as int64.
		// NOTE(review): this is not symmetric with the big.Float path below.
		// An int64-parseable x never matches a Number that fails Int64
		// (e.g. "1.0"), whereas x = "1.0" does match a stored "1".
		if xi, err := json.Number(x).Int64(); err == nil {
			equal = func(y Value) bool {
				if y, ok := y.(Number); ok {
					if yi, err := json.Number(y).Int64(); err == nil {
						return xi == yi
					}
				}
				return false
			}
			break
		}
		// General path: compare numerically as arbitrary-precision floats.
		a, ok := new(big.Float).SetString(string(x))
		if !ok {
			panic("illegal value")
		}
		equal = func(b Value) bool {
			if b, ok := b.(Number); ok {
				b, ok := new(big.Float).SetString(string(b))
				if !ok {
					panic("illegal value")
				}
				return a.Cmp(b) == 0
			}
			return false
		}
	default:
		// Composite and other values fall back to the generic ordering.
		equal = func(y Value) bool { return Compare(x, y) == 0 }
	}
	// Probe occupied slots; bail out if an equal element is already stored.
	for curr, ok := s.elems[hash]; ok; {
		if equal(curr.Value) {
			return
		}
		hash++
		curr, ok = s.elems[hash]
	}
	// hash now names the first free slot in the probe sequence.
	s.elems[hash] = x
	s.keys = append(s.keys, x)
}
// get returns the element of s considered equal to x, or nil if there is
// none. It follows the same linear-probe sequence as insert: starting at
// x.Hash(), each occupied slot of s.elems is checked, advancing one integer
// key at a time, and the search ends at the first empty slot.
//
// It panics with "illegal value" if a Number involved in the comparison
// cannot be parsed by big.Float.
func (s *set) get(x *Term) *Term {
	hash := x.Hash()
	// Equality predicate specialised for x's value type; this must match
	// the one built in insert so lookups find what insert stored.
	var equal func(v Value) bool
	switch x := x.Value.(type) {
	case Null, Boolean, String, Var:
		// Interface ==: equal only if dynamic type and value both match.
		equal = func(y Value) bool { return x == y }
	case Number:
		// Integer fast path: only other int64-parseable Numbers can match.
		if xi, err := json.Number(x).Int64(); err == nil {
			equal = func(y Value) bool {
				if y, ok := y.(Number); ok {
					if yi, err := json.Number(y).Int64(); err == nil {
						return xi == yi
					}
				}
				return false
			}
			break
		}
		// General path: numeric comparison as arbitrary-precision floats.
		a, ok := new(big.Float).SetString(string(x))
		if !ok {
			panic("illegal value")
		}
		equal = func(b Value) bool {
			if b, ok := b.(Number); ok {
				b, ok := new(big.Float).SetString(string(b))
				if !ok {
					panic("illegal value")
				}
				return a.Cmp(b) == 0
			}
			return false
		}
	default:
		// Composite and other values fall back to the generic ordering.
		equal = func(y Value) bool { return Compare(x, y) == 0 }
	}
	for curr, ok := s.elems[hash]; ok; {
		if equal(curr.Value) {
			return curr
		}
		hash++
		curr, ok = s.elems[hash]
	}
	return nil
}
// Object represents an object as defined by the language: a collection of
// key/value Term pairs with unique keys. The implementation in this file
// (*object) preserves key insertion order.
type Object interface {
	Value
	// Len returns the number of key/value pairs.
	Len() int
	// Get returns the value stored under the key, or nil if it is absent.
	Get(*Term) *Term
	// Copy returns a deep copy of the object.
	Copy() Object
	// Insert stores the value under the key, replacing any existing value.
	Insert(*Term, *Term)
	// Iter calls the function for each key/value pair, stopping at and
	// returning the first non-nil error.
	Iter(func(*Term, *Term) error) error
	// Until calls the function for each key/value pair until it returns
	// true, and reports whether iteration stopped early.
	Until(func(*Term, *Term) bool) bool
	// Foreach calls the function for each key/value pair.
	Foreach(func(*Term, *Term))
	// Map returns a new Object built from the pairs the function produces.
	Map(func(*Term, *Term) (*Term, *Term, error)) (Object, error)
	// Diff returns the pairs whose keys do not appear in other.
	Diff(other Object) Object
	// Intersect returns a [key, this value, other value] triple for each
	// key present in both objects.
	Intersect(other Object) [][3]*Term
	// Merge combines both objects, recursively merging object values under
	// shared keys; the bool is false if a shared key cannot be merged.
	Merge(other Object) (Object, bool)
	// MergeWith combines both objects, resolving shared keys with
	// conflictResolver; the bool is false if the resolver asked to stop.
	MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool)
	// Filter returns the part of the object selected by filter.
	Filter(filter Object) (Object, error)
	// Keys returns the object's keys.
	Keys() []*Term
}
// NewObject returns an Object populated with the given key/value pairs,
// inserted in order. When two pairs have equal keys, the later value replaces
// the earlier one.
func NewObject(t ...[2]*Term) Object {
	obj := newobject(len(t))
	for _, kv := range t {
		obj.Insert(kv[0], kv[1])
	}
	return obj
}

// ObjectTerm returns a Term whose Value is NewObject(o...).
func ObjectTerm(o ...[2]*Term) *Term {
	obj := NewObject(o...)
	return &Term{Value: obj}
}
// object is the hash-based implementation of Object.
type object struct {
	elems  map[int]*objectElem // buckets keyed by key hash; collisions chain via objectElem.next
	keys   []*Term             // keys in insertion order
	ground bool                // false once any inserted key or value is non-ground
}

// newobject allocates an empty object sized for n entries. The keys slice is
// only preallocated when n is positive; otherwise it stays nil.
func newobject(n int) *object {
	obj := &object{
		elems:  make(map[int]*objectElem, n),
		ground: true,
	}
	if n > 0 {
		obj.keys = make([]*Term, 0, n)
	}
	return obj
}

// objectElem is one entry in an object bucket. Entries whose keys share a
// hash are chained through next.
type objectElem struct {
	key   *Term
	value *Term
	next  *objectElem
}

// Item bundles key and value into the [2]*Term pair form accepted by
// NewObject and ObjectTerm.
func Item(key, value *Term) [2]*Term {
	pair := [2]*Term{key, value}
	return pair
}
// Compare compares obj to other, returning <0, 0, or >0 if obj is less than,
// equal to, or greater than other.
//
// Values of different kinds are ordered by sortOrder. Two objects are compared
// pairwise in sorted key order: the first differing key, or the first
// differing value under equal keys, decides the result. If one object's sorted
// keys are a prefix of the other's, the object with fewer keys sorts first.
// Neither object is modified.
func (obj *object) Compare(other Value) int {
	o1 := sortOrder(obj)
	o2 := sortOrder(other)
	if o1 < o2 {
		return -1
	} else if o2 < o1 {
		return 1
	}
	a := obj
	b := other.(*object)

	// Sort copies: Keys returns the object's internal insertion-ordered
	// slice. Sorting it in place would silently reorder iteration and JSON
	// output of both objects as a side effect of comparing them. It would
	// also race with any concurrent reader of either object.
	keysA := make([]*Term, len(a.Keys()))
	copy(keysA, a.Keys())
	sort.Sort(termSlice(keysA))
	keysB := make([]*Term, len(b.Keys()))
	copy(keysB, b.Keys())
	sort.Sort(termSlice(keysB))

	minLen := len(keysA)
	if len(keysB) < minLen {
		minLen = len(keysB)
	}
	for i := 0; i < minLen; i++ {
		keysCmp := Compare(keysA[i], keysB[i])
		if keysCmp < 0 {
			return -1
		}
		if keysCmp > 0 {
			return 1
		}
		valA := a.Get(keysA[i])
		valB := b.Get(keysB[i])
		if valCmp := Compare(valA, valB); valCmp != 0 {
			return valCmp
		}
	}
	switch {
	case len(keysA) < len(keysB):
		return -1
	case len(keysB) < len(keysA):
		return 1
	}
	return 0
}
// Find resolves path against obj. An empty path yields obj itself. Otherwise
// the first path element selects a value by key, and the remainder of the path
// is resolved against that value. A missing key yields errFindNotFound.
func (obj *object) Find(path Ref) (Value, error) {
	if len(path) == 0 {
		return obj, nil
	}
	head, rest := path[0], path[1:]
	child := obj.Get(head)
	if child == nil {
		return nil, errFindNotFound
	}
	return child.Value.Find(rest)
}

// Insert stores v under k. If obj already has an equal key, only its value is
// replaced; otherwise k is appended to obj's key order.
func (obj *object) Insert(k, v *Term) {
	obj.insert(k, v)
}

// Get returns the value stored under k, or nil when obj has no such key.
func (obj *object) Get(k *Term) *Term {
	elem := obj.get(k)
	if elem == nil {
		return nil
	}
	return elem.value
}
// Hash returns the sum of the hashes of every key and value in obj. Because
// addition is commutative, the result does not depend on insertion order.
func (obj *object) Hash() int {
	total := 0
	obj.Foreach(func(key, value *Term) {
		total += key.Value.Hash() + value.Value.Hash()
	})
	return total
}

// IsGround returns obj's ground flag, which insert maintains incrementally
// rather than recomputing it on each call.
func (obj *object) IsGround() bool {
	return obj.ground
}

// Copy returns a deep copy of obj: every key and value Term is copied into a
// newly allocated object.
func (obj *object) Copy() Object {
	deepCopy := func(k, v *Term) (*Term, *Term, error) {
		return k.Copy(), v.Copy(), nil
	}
	// deepCopy never fails, so Map cannot return an error here.
	cpy, _ := obj.Map(deepCopy)
	return cpy
}
// Diff returns a new Object holding the key/value pairs of obj whose keys are
// not present in other.
func (obj *object) Diff(other Object) Object {
	result := NewObject()
	obj.Foreach(func(k, v *Term) {
		if other.Get(k) != nil {
			return
		}
		result.Insert(k, v)
	})
	return result
}

// Intersect returns a [key, obj value, other value] triple for every key of
// obj that is also present in other, in obj's insertion order. The returned
// slice is non-nil even when there are no shared keys.
func (obj *object) Intersect(other Object) [][3]*Term {
	triples := make([][3]*Term, 0)
	obj.Foreach(func(k, v *Term) {
		v2 := other.Get(k)
		if v2 == nil {
			return
		}
		triples = append(triples, [3]*Term{k, v, v2})
	})
	return triples
}
// Iter calls f with each key and its value in insertion order. Iteration halts
// at the first non-nil error from f, which is then returned. It panics with
// "corrupt object" if a key recorded in obj.keys is missing from the hash
// table.
func (obj *object) Iter(f func(*Term, *Term) error) error {
	for idx := range obj.keys {
		key := obj.keys[idx]
		entry := obj.get(key)
		if entry == nil {
			panic("corrupt object")
		}
		if err := f(key, entry.value); err != nil {
			return err
		}
	}
	return nil
}

// Until calls f for each key/value pair until f reports true. It returns true
// if f stopped the iteration and false if every pair was visited.
func (obj *object) Until(f func(*Term, *Term) bool) bool {
	return obj.Iter(func(k, v *Term) error {
		if !f(k, v) {
			return nil
		}
		return errStop
	}) != nil
}

// Foreach calls f for each key/value pair in insertion order.
func (obj *object) Foreach(f func(*Term, *Term)) {
	// The callback never fails, so Iter's result is always nil.
	_ = obj.Iter(func(k, v *Term) error {
		f(k, v)
		return nil
	})
}
// Map returns a new Object constructed by mapping each element in the object
// using the function f.
func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) {
cpy := newobject(obj.Len())
err := obj.Iter(func(k, v *Term) error {
var err error
k, v, err = f(k, v)
if err != nil {
return err
}
cpy.insert(k, v)
return nil
})
if err != nil {
return nil, err
}
return cpy, nil
}
// Keys returns the keys of obj.
func (obj *object) Keys() []*Term {
return obj.keys
}
// MarshalJSON returns JSON encoded bytes representing obj.
func (obj *object) MarshalJSON() ([]byte, error) {
sl := make([][2]*Term, obj.Len())
for i := range obj.keys {
k := obj.keys[i]
sl[i] = Item(k, obj.get(k).value)
}
return json.Marshal(sl)
}
// Merge returns a new Object containing the keys of both obj and other. When a
// key appears in both, the two values must themselves be objects, and they
// are merged recursively. If any overlapping values cannot be merged (they are
// not both objects, or their recursive merge fails), Merge returns
// (nil, false).
//
// The pointer receiver matches every other *object method and avoids copying
// the struct on each call.
func (obj *object) Merge(other Object) (Object, bool) {
	return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) {
		obj1, ok1 := v1.Value.(Object)
		obj2, ok2 := v2.Value.(Object)
		if !ok1 || !ok2 {
			// Two non-object values under the same key conflict: stop.
			return nil, true
		}
		obj3, ok := obj1.Merge(obj2)
		if !ok {
			return nil, true
		}
		return NewTerm(obj3), false
	})
}
// MergeWith returns a new Object containing the merged keys of obj and other.
//
// Keys present only in obj keep their obj value, and keys present only in
// other keep their other value. For a key present in both,
// conflictResolver(objValue, otherValue) is called. It returns the merged
// value and a stop flag; if stop is true, MergeWith abandons the merge and
// returns (nil, false). Pairs from obj come first in the result, followed
// by pairs whose keys exist only in other.
func (obj *object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) {
	result := NewObject()
	stop := obj.Until(func(k, v *Term) bool {
		v2 := other.Get(k)
		// The key didn't exist in other: keep the original value.
		if v2 == nil {
			result.Insert(k, v)
			return false
		}
		// The key exists in both: resolve the conflict if possible.
		merged, stop := conflictResolver(v, v2)
		if !stop {
			result.Insert(k, merged)
		}
		return stop
	})
	if stop {
		return nil, false
	}
	// Copy in any values from other for keys that don't exist in obj.
	other.Foreach(func(k, v *Term) {
		if obj.Get(k) == nil {
			result.Insert(k, v)
		}
	})
	return result, true
}
// Filter returns a new Object containing only the parts of obj selected by
// filter. A key of obj is kept when it is also a key of filter, and its value
// is filtered recursively by the corresponding filter value. A null filter
// value keeps the value unchanged. Array elements are selected by their index
// written as a decimal string key, e.g. "0".
func (obj *object) Filter(filter Object) (Object, error) {
	result, err := filterObject(obj, filter)
	if err != nil {
		return nil, err
	}
	return result.(Object), nil
}
// Len returns the number of elements in the object.
func (obj object) Len() int {
return len(obj.keys)
}
func (obj object) String() string {
var buf []string
obj.Foreach(func(k, v *Term) {
buf = append(buf, fmt.Sprintf("%s: %s", k, v))
})
return "{" + strings.Join(buf, ", ") + "}"
}
// get returns the entry of obj whose key is considered equal to k, or nil if
// there is none. It walks the collision chain stored at obj.elems[k.Hash()].
//
// It panics with "illegal value" if a Number involved in the comparison
// cannot be parsed by big.Float.
func (obj *object) get(k *Term) *objectElem {
	hash := k.Hash()
	// Equality predicate specialised for k's value type. It must agree with
	// the predicate insert uses, otherwise stored keys could not be found.
	var equal func(v Value) bool
	switch x := k.Value.(type) {
	case Null, Boolean, String, Var:
		// Interface ==: equal only if dynamic type and value both match.
		equal = func(y Value) bool { return x == y }
	case Number:
		// Integer fast path: only other int64-parseable Numbers can match.
		// NOTE(review): this is not symmetric with the big.Float path below
		// (e.g. key "1" never matches a stored "1.0", but "1.0" matches "1").
		if xi, err := json.Number(x).Int64(); err == nil {
			equal = func(y Value) bool {
				if y, ok := y.(Number); ok {
					if yi, err := json.Number(y).Int64(); err == nil {
						return xi == yi
					}
				}
				return false
			}
			break
		}
		// General path: numeric comparison as arbitrary-precision floats.
		a, ok := new(big.Float).SetString(string(x))
		if !ok {
			panic("illegal value")
		}
		equal = func(b Value) bool {
			if b, ok := b.(Number); ok {
				b, ok := new(big.Float).SetString(string(b))
				if !ok {
					panic("illegal value")
				}
				return a.Cmp(b) == 0
			}
			return false
		}
	default:
		// Composite and other values fall back to the generic ordering.
		equal = func(y Value) bool { return Compare(x, y) == 0 }
	}
	for curr := obj.elems[hash]; curr != nil; curr = curr.next {
		if equal(curr.key.Value) {
			return curr
		}
	}
	return nil
}
// insert stores v under key k.
//
// Entries live in obj.elems, bucketed by k.Hash(); keys that share a hash are
// chained through objectElem.next. If an entry with an equal key already
// exists, only its value is replaced, and the original key Term keeps its
// position in obj.keys. Otherwise a new entry is pushed onto the front of its
// bucket, and k is appended to obj.keys to preserve insertion order.
//
// obj.ground is kept conservatively. It becomes false as soon as any stored
// key or value is non-ground, including via replacement. It is not recomputed
// when a non-ground value is later replaced by a ground one.
//
// It panics with "illegal value" if a Number involved in the comparison
// cannot be parsed by big.Float.
func (obj *object) insert(k, v *Term) {
	hash := k.Hash()
	head := obj.elems[hash]
	// Equality predicate specialised for k's value type; it must agree with
	// the predicate used by get.
	var equal func(v Value) bool
	switch x := k.Value.(type) {
	case Null, Boolean, String, Var:
		equal = func(y Value) bool { return x == y }
	case Number:
		// Integer fast path: only other int64-parseable Numbers can match.
		if xi, err := json.Number(x).Int64(); err == nil {
			equal = func(y Value) bool {
				if y, ok := y.(Number); ok {
					if yi, err := json.Number(y).Int64(); err == nil {
						return xi == yi
					}
				}
				return false
			}
			break
		}
		// General path: numeric comparison as arbitrary-precision floats.
		a, ok := new(big.Float).SetString(string(x))
		if !ok {
			panic("illegal value")
		}
		equal = func(b Value) bool {
			if b, ok := b.(Number); ok {
				b, ok := new(big.Float).SetString(string(b))
				if !ok {
					panic("illegal value")
				}
				return a.Cmp(b) == 0
			}
			return false
		}
	default:
		equal = func(y Value) bool { return Compare(x, y) == 0 }
	}
	for curr := head; curr != nil; curr = curr.next {
		if equal(curr.key.Value) {
			curr.value = v
			// The new value may be non-ground even if the old one was
			// ground. Without this update, IsGround could report true
			// for an object that now holds a variable.
			obj.ground = obj.ground && v.IsGround()
			return
		}
	}
	obj.elems[hash] = &objectElem{
		key:   k,
		value: v,
		next:  head,
	}
	obj.keys = append(obj.keys, k)
	obj.ground = obj.ground && k.IsGround() && v.IsGround()
}
// filterObject returns the parts of o selected by filter:
//
//   - A null filter selects o unchanged.
//   - Otherwise filter must be an Object; anything else is an error.
//   - Scalars (String, Number, Boolean, Null) are returned unchanged.
//   - For an Array, the element at index i is kept (and filtered
//     recursively) only if filter has the key "i" (decimal string). Kept
//     elements are compacted into a new Array, so indices are not preserved.
//   - For a Set, an element is kept (and filtered recursively) if it is a
//     key of filter.
//   - For an Object, a key is kept (and its value filtered recursively) if
//     it is a key of both o and filter. The result's key order follows
//     whichever of the two objects has fewer keys (o on a tie).
//
// Any other value type is an error.
func filterObject(o Value, filter Value) (Value, error) {
	if filter.Compare(Null{}) == 0 {
		return o, nil
	}
	filteredObj, ok := filter.(Object)
	if !ok {
		return nil, fmt.Errorf("invalid filter value %q, expected an object", filter)
	}
	switch v := o.(type) {
	case String, Number, Boolean, Null:
		return o, nil
	case Array:
		var values Array
		for i, t := range v {
			subFilter := filteredObj.Get(StringTerm(strconv.Itoa(i)))
			if subFilter == nil {
				continue
			}
			filteredValue, err := filterObject(t.Value, subFilter.Value)
			if err != nil {
				return nil, err
			}
			values = append(values, NewTerm(filteredValue))
		}
		return values, nil
	case Set:
		values := NewSet()
		err := v.Iter(func(t *Term) error {
			subFilter := filteredObj.Get(t)
			if subFilter == nil {
				return nil
			}
			filteredValue, err := filterObject(t.Value, subFilter.Value)
			if err != nil {
				return err
			}
			values.Add(NewTerm(filteredValue))
			return nil
		})
		return values, err
	case Object:
		values := NewObject()
		// Iterate over the smaller object and probe the larger one, so the
		// work is proportional to the smaller key set. The original
		// condition was inverted and iterated the larger object.
		iterOverData := v.Len() <= filteredObj.Len()
		iterObj, probeObj := v, filteredObj
		if !iterOverData {
			iterObj, probeObj = filteredObj, v
		}
		err := iterObj.Iter(func(key *Term, iterValue *Term) error {
			probeValue := probeObj.Get(key)
			if probeValue == nil {
				return nil
			}
			dataValue, subFilter := iterValue, probeValue
			if !iterOverData {
				dataValue, subFilter = probeValue, iterValue
			}
			filteredValue, err := filterObject(dataValue.Value, subFilter.Value)
			if err != nil {
				return err
			}
			values.Insert(key, NewTerm(filteredValue))
			return nil
		})
		return values, err
	default:
		return nil, fmt.Errorf("invalid object value type %q", v)
	}
}
// ArrayComprehension represents an array comprehension as defined in the
// language. Its textual form is [Term | Body].
type ArrayComprehension struct {
	Term *Term `json:"term"`
	Body Body  `json:"body"`
}

// ArrayComprehensionTerm returns a Term whose Value is an ArrayComprehension
// built from term and body.
func ArrayComprehensionTerm(term *Term, body Body) *Term {
	ac := &ArrayComprehension{Term: term, Body: body}
	return &Term{Value: ac}
}

// Copy returns a copy of ac whose Body and Term are themselves copied.
func (ac *ArrayComprehension) Copy() *ArrayComprehension {
	dup := new(ArrayComprehension)
	*dup = *ac
	dup.Body = ac.Body.Copy()
	dup.Term = ac.Term.Copy()
	return dup
}

// Equal reports whether ac and other compare as equal.
func (ac *ArrayComprehension) Equal(other Value) bool {
	return ac.Compare(other) == 0
}

// Compare orders ac relative to other, returning <0, 0, or >0. It delegates
// to the package-level Compare.
func (ac *ArrayComprehension) Compare(other Value) int {
	return Compare(ac, other)
}

// Find resolves path against ac. Only the empty path resolves, to ac itself;
// any other path yields errFindNotFound.
func (ac *ArrayComprehension) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return ac, nil
}

// Hash combines the hashes of Term and Body by addition.
func (ac *ArrayComprehension) Hash() int {
	h := ac.Term.Hash()
	h += ac.Body.Hash()
	return h
}

// IsGround reports whether both Term and Body are ground.
func (ac *ArrayComprehension) IsGround() bool {
	if !ac.Term.IsGround() {
		return false
	}
	return ac.Body.IsGround()
}

func (ac *ArrayComprehension) String() string {
	inner := ac.Term.String() + " | " + ac.Body.String()
	return "[" + inner + "]"
}
// ObjectComprehension represents an object comprehension as defined in the
// language. Its textual form is {Key: Value | Body}.
type ObjectComprehension struct {
	Key   *Term `json:"key"`
	Value *Term `json:"value"`
	Body  Body  `json:"body"`
}

// ObjectComprehensionTerm returns a Term whose Value is an
// ObjectComprehension built from key, value and body.
func ObjectComprehensionTerm(key, value *Term, body Body) *Term {
	oc := &ObjectComprehension{Key: key, Value: value, Body: body}
	return &Term{Value: oc}
}

// Copy returns a copy of oc whose Body, Key and Value are each copied.
func (oc *ObjectComprehension) Copy() *ObjectComprehension {
	dup := new(ObjectComprehension)
	*dup = *oc
	dup.Body = oc.Body.Copy()
	dup.Key = oc.Key.Copy()
	dup.Value = oc.Value.Copy()
	return dup
}

// Equal reports whether oc and other compare as equal.
func (oc *ObjectComprehension) Equal(other Value) bool {
	return oc.Compare(other) == 0
}

// Compare orders oc relative to other, returning <0, 0, or >0. It delegates
// to the package-level Compare.
func (oc *ObjectComprehension) Compare(other Value) int {
	return Compare(oc, other)
}

// Find resolves path against oc. Only the empty path resolves, to oc itself;
// any other path yields errFindNotFound.
func (oc *ObjectComprehension) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return oc, nil
}

// Hash combines the hashes of Key, Value and Body by addition.
func (oc *ObjectComprehension) Hash() int {
	h := oc.Key.Hash()
	h += oc.Value.Hash()
	h += oc.Body.Hash()
	return h
}

// IsGround reports whether Key, Value and Body are all ground.
func (oc *ObjectComprehension) IsGround() bool {
	if !oc.Key.IsGround() || !oc.Value.IsGround() {
		return false
	}
	return oc.Body.IsGround()
}

func (oc *ObjectComprehension) String() string {
	pair := oc.Key.String() + ": " + oc.Value.String()
	return "{" + pair + " | " + oc.Body.String() + "}"
}
// SetComprehension represents a set comprehension as defined in the language.
// Its textual form is {Term | Body}.
type SetComprehension struct {
	Term *Term `json:"term"`
	Body Body  `json:"body"`
}

// SetComprehensionTerm returns a Term whose Value is a SetComprehension built
// from term and body.
func SetComprehensionTerm(term *Term, body Body) *Term {
	sc := &SetComprehension{Term: term, Body: body}
	return &Term{Value: sc}
}

// Copy returns a copy of sc whose Body and Term are themselves copied.
func (sc *SetComprehension) Copy() *SetComprehension {
	dup := new(SetComprehension)
	*dup = *sc
	dup.Body = sc.Body.Copy()
	dup.Term = sc.Term.Copy()
	return dup
}

// Equal reports whether sc and other compare as equal.
func (sc *SetComprehension) Equal(other Value) bool {
	return sc.Compare(other) == 0
}

// Compare orders sc relative to other, returning <0, 0, or >0. It delegates
// to the package-level Compare.
func (sc *SetComprehension) Compare(other Value) int {
	return Compare(sc, other)
}

// Find resolves path against sc. Only the empty path resolves, to sc itself;
// any other path yields errFindNotFound.
func (sc *SetComprehension) Find(path Ref) (Value, error) {
	if len(path) != 0 {
		return nil, errFindNotFound
	}
	return sc, nil
}

// Hash combines the hashes of Term and Body by addition.
func (sc *SetComprehension) Hash() int {
	h := sc.Term.Hash()
	h += sc.Body.Hash()
	return h
}

// IsGround reports whether both Term and Body are ground.
func (sc *SetComprehension) IsGround() bool {
	if !sc.Term.IsGround() {
		return false
	}
	return sc.Body.IsGround()
}

func (sc *SetComprehension) String() string {
	inner := sc.Term.String() + " | " + sc.Body.String()
	return "{" + inner + "}"
}
// Call represents a function call in the language. The first element is the
// operator and the remaining elements are the operands.
type Call []*Term

// CallTerm returns a new Term with a Call value defined by terms. The first
// term is the operator and the rest are operands.
func CallTerm(terms ...*Term) *Term {
	return NewTerm(Call(terms))
}

// Copy returns a copy of c in which every term is copied.
func (c Call) Copy() Call {
	return termSliceCopy(c)
}

// Compare compares c to other, returning <0, 0, or >0 if it is less than,
// equal to, or greater than other.
func (c Call) Compare(other Value) int {
	return Compare(c, other)
}

// Find always returns errFindNotFound; a Call cannot be navigated by a path.
func (c Call) Find(Ref) (Value, error) {
	return nil, errFindNotFound
}

// Hash returns the hash code for the Value.
func (c Call) Hash() int {
	return termSliceHash(c)
}

// IsGround returns true if every term of the call is ground.
func (c Call) IsGround() bool {
	return termSliceIsGround(c)
}

// MakeExpr returns a new Expr whose terms are the terms of c followed by
// output. A fresh slice is allocated so the Expr never shares a backing array
// with c. Appending to c directly could write output into c's spare capacity,
// and a second MakeExpr on the same Call would then overwrite the first
// Expr's output term.
func (c Call) MakeExpr(output *Term) *Expr {
	terms := make([]*Term, len(c), len(c)+1)
	copy(terms, c)
	terms = append(terms, output)
	return NewExpr(terms)
}

// String renders c as op(arg1, arg2, ...). An empty Call has no operator and
// renders as "()" instead of panicking.
func (c Call) String() string {
	if len(c) == 0 {
		return "()"
	}
	args := make([]string, len(c)-1)
	for i, arg := range c[1:] {
		args[i] = arg.String()
	}
	return fmt.Sprintf("%v(%v)", c[0], strings.Join(args, ", "))
}
// termSliceCopy returns a new, non-nil slice holding a Copy of each term in a.
func termSliceCopy(a []*Term) []*Term {
	out := make([]*Term, 0, len(a))
	for _, t := range a {
		out = append(out, t.Copy())
	}
	return out
}

// termSliceEqual reports whether a and b have the same length and their terms
// are pairwise Equal.
func termSliceEqual(a, b []*Term) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if !a[i].Equal(b[i]) {
			return false
		}
	}
	return true
}

// termSliceHash returns the sum of the value hashes of the terms in a.
func termSliceHash(a []*Term) int {
	sum := 0
	for i := range a {
		sum += a[i].Value.Hash()
	}
	return sum
}

// termSliceIsGround reports whether every term in a is ground. An empty slice
// is ground.
func termSliceIsGround(a []*Term) bool {
	for i := range a {
		if !a[i].IsGround() {
			return false
		}
	}
	return true
}
// NOTE(tsandall): The unmarshalling errors in these functions are not
// helpful for callers because they do not identify the source of the
// unmarshalling error. Because OPA doesn't accept JSON describing ASTs
// from callers, this is acceptable (for now). If that changes in the future,
// the error messages should be revisited. The current approach focuses
// on the happy path and treats all errors the same. If better error
// reporting is needed, the error paths will need to be fleshed out.
// unmarshalBody converts the decoded JSON array b into a Body. Every element
// must be a JSON object that unmarshalExpr accepts. Otherwise a single generic
// error is returned, and the underlying cause is discarded.
func unmarshalBody(b []interface{}) (Body, error) {
	body := Body{}
	for _, elem := range b {
		fields, isObject := elem.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("ast: unable to unmarshal body")
		}
		expr := &Expr{}
		if err := unmarshalExpr(expr, fields); err != nil {
			return nil, fmt.Errorf("ast: unable to unmarshal body")
		}
		body = append(body, expr)
	}
	return body, nil
}
// unmarshalExpr populates expr from the decoded JSON object v.
//
// Recognised fields:
//   - "negated" (optional): must be a bool when present.
//   - "index" (required): decoded by unmarshalExprIndex.
//   - "terms" (required): a single term object or an array of term objects.
//   - "with" (optional): an array of with-modifier objects. A value of any
//     other type is silently ignored.
func unmarshalExpr(expr *Expr, v map[string]interface{}) error {
	if raw, present := v["negated"]; present {
		negated, isBool := raw.(bool)
		if !isBool {
			return fmt.Errorf("ast: unable to unmarshal negated field with type: %T (expected true or false)", raw)
		}
		expr.Negated = negated
	}

	if err := unmarshalExprIndex(expr, v); err != nil {
		return err
	}

	switch rawTerms := v["terms"].(type) {
	case map[string]interface{}:
		single, err := unmarshalTerm(rawTerms)
		if err != nil {
			return err
		}
		expr.Terms = single
	case []interface{}:
		many, err := unmarshalTermSlice(rawTerms)
		if err != nil {
			return err
		}
		expr.Terms = many
	default:
		return fmt.Errorf(`ast: unable to unmarshal terms field with type: %T (expected {"value": ..., "type": ...} or [{"value": ..., "type": ...}, ...])`, v["terms"])
	}

	rawWith, isList := v["with"].([]interface{})
	if !isList {
		return nil
	}
	withs := make([]*With, len(rawWith))
	for i, w := range rawWith {
		parsed, err := unmarshalWith(w)
		if err != nil {
			return err
		}
		withs[i] = parsed
	}
	expr.With = withs
	return nil
}
// unmarshalExprIndex sets expr.Index from the required "index" field of v.
// The field must be a json.Number (the form encoding/json produces when
// decoding with UseNumber) that parses as an integer. A missing or invalid
// field yields an error naming the field's dynamic type.
func unmarshalExprIndex(expr *Expr, v map[string]interface{}) error {
	if n, isNumber := v["index"].(json.Number); isNumber {
		if idx, err := n.Int64(); err == nil {
			expr.Index = int(idx)
			return nil
		}
	}
	return fmt.Errorf("ast: unable to unmarshal index field with type: %T (expected integer)", v["index"])
}
// unmarshalTerm decodes a term object of the form {"type": ..., "value": ...}
// into a Term that carries only the decoded Value.
func unmarshalTerm(m map[string]interface{}) (*Term, error) {
	value, err := unmarshalValue(m)
	if err != nil {
		return nil, err
	}
	return &Term{Value: value}, nil
}

// unmarshalTermSlice decodes each element of s into a Term. The first element
// that is not a JSON object, or that fails to decode, aborts the conversion.
// The result is non-nil even for an empty s.
func unmarshalTermSlice(s []interface{}) ([]*Term, error) {
	terms := make([]*Term, 0, len(s))
	for _, elem := range s {
		fields, isObject := elem.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("ast: unable to unmarshal term")
		}
		t, err := unmarshalTerm(fields)
		if err != nil {
			return nil, err
		}
		terms = append(terms, t)
	}
	return terms, nil
}

// unmarshalTermSliceValue decodes the "value" field of d, which must be a JSON
// array of term objects.
func unmarshalTermSliceValue(d map[string]interface{}) ([]*Term, error) {
	elems, isList := d["value"].([]interface{})
	if !isList {
		return nil, fmt.Errorf(`ast: unable to unmarshal term (expected {"value": [...], "type": ...} where type is one of: ref, array, or set)`)
	}
	return unmarshalTermSlice(elems)
}
// unmarshalWith decodes a with-modifier object of the form
// {"target": <term>, "value": <term>}. Errors from decoding either term are
// returned unchanged; a non-object input yields a descriptive error.
func unmarshalWith(i interface{}) (*With, error) {
	fields, isObject := i.(map[string]interface{})
	if !isObject {
		return nil, fmt.Errorf(`ast: unable to unmarshal with modifier (expected {"target": {...}, "value": {...}})`)
	}

	// A missing or non-object field becomes a nil map, which is passed to
	// unmarshalTerm as-is.
	rawTarget, _ := fields["target"].(map[string]interface{})
	target, err := unmarshalTerm(rawTarget)
	if err != nil {
		return nil, err
	}

	rawValue, _ := fields["value"].(map[string]interface{})
	value, err := unmarshalTerm(rawValue)
	if err != nil {
		return nil, err
	}

	return &With{Target: target, Value: value}, nil
}
// unmarshalValue decodes a term object d of the form
// {"type": <string>, "value": <any>} into a Value.
//
// Supported types: null, boolean, number (value must be a json.Number),
// string, var, ref, array, set, object (value is an array of [key, value]
// term-object pairs), arraycomprehension, setcomprehension (value is
// {"term": ..., "body": [...]}), objectcomprehension (value is
// {"key": ..., "value": ..., "body": [...]}) and call.
//
// An unrecognised type, or a value whose shape does not match its type,
// yields the generic error "ast: unable to unmarshal term". The specific
// cause, including errors from nested decoding, is discarded.
func unmarshalValue(d map[string]interface{}) (Value, error) {
	v := d["value"]
	// Each case returns on success. On a shape mismatch it either jumps to
	// unmarshal_error or falls out of the switch into the same label.
	switch d["type"] {
	case "null":
		return Null{}, nil
	case "boolean":
		if b, ok := v.(bool); ok {
			return Boolean(b), nil
		}
	case "number":
		if n, ok := v.(json.Number); ok {
			return Number(n), nil
		}
	case "string":
		if s, ok := v.(string); ok {
			return String(s), nil
		}
	case "var":
		if s, ok := v.(string); ok {
			return Var(s), nil
		}
	case "ref":
		if s, err := unmarshalTermSliceValue(d); err == nil {
			return Ref(s), nil
		}
	case "array":
		if s, err := unmarshalTermSliceValue(d); err == nil {
			return Array(s), nil
		}
	case "set":
		if s, err := unmarshalTermSliceValue(d); err == nil {
			set := NewSet()
			for _, x := range s {
				set.Add(x)
			}
			return set, nil
		}
	case "object":
		// The value is a list of two-element [key, value] term arrays.
		if s, ok := v.([]interface{}); ok {
			buf := NewObject()
			for _, x := range s {
				if i, ok := x.([]interface{}); ok && len(i) == 2 {
					p, err := unmarshalTermSlice(i)
					if err == nil {
						buf.Insert(p[0], p[1])
						continue
					}
				}
				goto unmarshal_error
			}
			return buf, nil
		}
	case "arraycomprehension", "setcomprehension":
		// Both share the {"term": ..., "body": [...]} shape; the type string
		// picks which comprehension is built.
		if m, ok := v.(map[string]interface{}); ok {
			t, ok := m["term"].(map[string]interface{})
			if !ok {
				goto unmarshal_error
			}
			term, err := unmarshalTerm(t)
			if err != nil {
				goto unmarshal_error
			}
			b, ok := m["body"].([]interface{})
			if !ok {
				goto unmarshal_error
			}
			body, err := unmarshalBody(b)
			if err != nil {
				goto unmarshal_error
			}
			if d["type"] == "arraycomprehension" {
				return &ArrayComprehension{Term: term, Body: body}, nil
			}
			return &SetComprehension{Term: term, Body: body}, nil
		}
	case "objectcomprehension":
		if m, ok := v.(map[string]interface{}); ok {
			k, ok := m["key"].(map[string]interface{})
			if !ok {
				goto unmarshal_error
			}
			key, err := unmarshalTerm(k)
			if err != nil {
				goto unmarshal_error
			}
			// Note: this v shadows the outer d["value"] within this block.
			v, ok := m["value"].(map[string]interface{})
			if !ok {
				goto unmarshal_error
			}
			value, err := unmarshalTerm(v)
			if err != nil {
				goto unmarshal_error
			}
			b, ok := m["body"].([]interface{})
			if !ok {
				goto unmarshal_error
			}
			body, err := unmarshalBody(b)
			if err != nil {
				goto unmarshal_error
			}
			return &ObjectComprehension{Key: key, Value: value, Body: body}, nil
		}
	case "call":
		if s, err := unmarshalTermSliceValue(d); err == nil {
			return Call(s), nil
		}
	}
unmarshal_error:
	return nil, fmt.Errorf("ast: unable to unmarshal term")
}