Upgrade dependent version: github.com/open-policy-agent/opa (#5315)

Upgrade dependent version: github.com/open-policy-agent/opa v0.18.0 -> v0.45.0

Signed-off-by: hongzhouzi <hongzhouzi@kubesphere.io>

Signed-off-by: hongzhouzi <hongzhouzi@kubesphere.io>
This commit is contained in:
hongzhouzi
2022-10-31 10:58:55 +08:00
committed by GitHub
parent 668fca1773
commit ef03b1e3df
363 changed files with 277341 additions and 13544 deletions

View File

@@ -13,30 +13,31 @@ import (
func builtinCount(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
return ast.IntNumberTerm(len(a)).Value, nil
case *ast.Array:
return ast.IntNumberTerm(a.Len()).Value, nil
case ast.Object:
return ast.IntNumberTerm(a.Len()).Value, nil
case ast.Set:
return ast.IntNumberTerm(a.Len()).Value, nil
case ast.String:
return ast.IntNumberTerm(len(a)).Value, nil
return ast.IntNumberTerm(len([]rune(a))).Value, nil
}
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set")
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set", "string")
}
func builtinSum(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
case *ast.Array:
sum := big.NewFloat(0)
for _, x := range a {
err := a.Iter(func(x *ast.Term) error {
n, ok := x.Value.(ast.Number)
if !ok {
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
return builtins.NewOperandElementErr(1, a, x.Value, "number")
}
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
}
return builtins.FloatToNumber(sum), nil
return nil
})
return builtins.FloatToNumber(sum), err
case ast.Set:
sum := big.NewFloat(0)
err := a.Iter(func(x *ast.Term) error {
@@ -54,16 +55,17 @@ func builtinSum(a ast.Value) (ast.Value, error) {
func builtinProduct(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
case *ast.Array:
product := big.NewFloat(1)
for _, x := range a {
err := a.Iter(func(x *ast.Term) error {
n, ok := x.Value.(ast.Number)
if !ok {
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
return builtins.NewOperandElementErr(1, a, x.Value, "number")
}
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
}
return builtins.FloatToNumber(product), nil
return nil
})
return builtins.FloatToNumber(product), err
case ast.Set:
product := big.NewFloat(1)
err := a.Iter(func(x *ast.Term) error {
@@ -81,16 +83,16 @@ func builtinProduct(a ast.Value) (ast.Value, error) {
func builtinMax(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
if len(a) == 0 {
case *ast.Array:
if a.Len() == 0 {
return nil, BuiltinEmpty{}
}
var max = ast.Value(ast.Null{})
for i := range a {
if ast.Compare(max, a[i].Value) <= 0 {
max = a[i].Value
a.Foreach(func(x *ast.Term) {
if ast.Compare(max, x.Value) <= 0 {
max = x.Value
}
}
})
return max, nil
case ast.Set:
if a.Len() == 0 {
@@ -110,16 +112,16 @@ func builtinMax(a ast.Value) (ast.Value, error) {
func builtinMin(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
if len(a) == 0 {
case *ast.Array:
if a.Len() == 0 {
return nil, BuiltinEmpty{}
}
min := a[0].Value
for i := range a {
if ast.Compare(min, a[i].Value) >= 0 {
min = a[i].Value
min := a.Elem(0).Value
a.Foreach(func(x *ast.Term) {
if ast.Compare(min, x.Value) >= 0 {
min = x.Value
}
}
})
return min, nil
case ast.Set:
if a.Len() == 0 {
@@ -146,7 +148,7 @@ func builtinMin(a ast.Value) (ast.Value, error) {
func builtinSort(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
case *ast.Array:
return a.Sorted(), nil
case ast.Set:
return a.Sorted(), nil
@@ -159,20 +161,24 @@ func builtinAll(a ast.Value) (ast.Value, error) {
case ast.Set:
res := true
match := ast.BooleanTerm(true)
val.Foreach(func(term *ast.Term) {
if !term.Equal(match) {
val.Until(func(term *ast.Term) bool {
if !match.Equal(term) {
res = false
return true
}
return false
})
return ast.Boolean(res), nil
case ast.Array:
case *ast.Array:
res := true
match := ast.BooleanTerm(true)
for _, term := range val {
if !term.Equal(match) {
val.Until(func(term *ast.Term) bool {
if !match.Equal(term) {
res = false
return true
}
}
return false
})
return ast.Boolean(res), nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
@@ -182,28 +188,64 @@ func builtinAll(a ast.Value) (ast.Value, error) {
func builtinAny(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Set:
res := false
match := ast.BooleanTerm(true)
val.Foreach(func(term *ast.Term) {
if term.Equal(match) {
res = true
}
})
res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true))
return ast.Boolean(res), nil
case ast.Array:
case *ast.Array:
res := false
match := ast.BooleanTerm(true)
for _, term := range val {
if term.Equal(match) {
val.Until(func(term *ast.Term) bool {
if match.Equal(term) {
res = true
return true
}
}
return false
})
return ast.Boolean(res), nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
}
func builtinMember(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
containee := args[0]
switch c := args[1].Value.(type) {
case ast.Set:
return iter(ast.BooleanTerm(c.Contains(containee)))
case *ast.Array:
ret := false
c.Until(func(v *ast.Term) bool {
if v.Value.Compare(containee.Value) == 0 {
ret = true
}
return ret
})
return iter(ast.BooleanTerm(ret))
case ast.Object:
ret := false
c.Until(func(_, v *ast.Term) bool {
if v.Value.Compare(containee.Value) == 0 {
ret = true
}
return ret
})
return iter(ast.BooleanTerm(ret))
}
return iter(ast.BooleanTerm(false))
}
func builtinMemberWithKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
key, val := args[0], args[1]
switch c := args[2].Value.(type) {
case interface{ Get(*ast.Term) *ast.Term }:
ret := false
if act := c.Get(key); act != nil {
ret = act.Value.Compare(val.Value) == 0
}
return iter(ast.BooleanTerm(ret))
}
return iter(ast.BooleanTerm(false))
}
func init() {
RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount)
RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum)
@@ -213,4 +255,6 @@ func init() {
RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort)
RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny)
RegisterFunctionalBuiltin1(ast.All.Name, builtinAll)
RegisterBuiltinFunc(ast.Member.Name, builtinMember)
RegisterBuiltinFunc(ast.MemberWithKey.Name, builtinMemberWithKey)
}

View File

@@ -32,6 +32,28 @@ func arithRound(a *big.Float) (*big.Float, error) {
return new(big.Float).SetInt(i), nil
}
func arithCeil(a *big.Float) (*big.Float, error) {
i, _ := a.Int(nil)
f := new(big.Float).SetInt(i)
if f.Signbit() || a.Cmp(f) == 0 {
return f, nil
}
return new(big.Float).Add(f, big.NewFloat(1.0)), nil
}
func arithFloor(a *big.Float) (*big.Float, error) {
i, _ := a.Int(nil)
f := new(big.Float).SetInt(i)
if !f.Signbit() || a.Cmp(f) == 0 {
return f, nil
}
return new(big.Float).Sub(f, big.NewFloat(1.0)), nil
}
func arithPlus(a, b *big.Float) (*big.Float, error) {
return new(big.Float).Add(a, b), nil
}
@@ -115,7 +137,11 @@ func builtinMinus(a, b ast.Value) (ast.Value, error) {
return nil, builtins.NewOperandTypeErr(1, a, "number", "set")
}
return nil, builtins.NewOperandTypeErr(2, b, "number", "set")
if ok2 {
return nil, builtins.NewOperandTypeErr(2, b, "set")
}
return nil, builtins.NewOperandTypeErr(2, b, "number")
}
func builtinRem(a, b ast.Value) (ast.Value, error) {
@@ -148,6 +174,8 @@ func builtinRem(a, b ast.Value) (ast.Value, error) {
func init() {
RegisterFunctionalBuiltin1(ast.Abs.Name, builtinArithArity1(arithAbs))
RegisterFunctionalBuiltin1(ast.Round.Name, builtinArithArity1(arithRound))
RegisterFunctionalBuiltin1(ast.Ceil.Name, builtinArithArity1(arithCeil))
RegisterFunctionalBuiltin1(ast.Floor.Name, builtinArithArity1(arithFloor))
RegisterFunctionalBuiltin2(ast.Plus.Name, builtinArithArity2(arithPlus))
RegisterFunctionalBuiltin2(ast.Minus.Name, builtinMinus)
RegisterFunctionalBuiltin2(ast.Multiply.Name, builtinArithArity2(arithMultiply))

View File

@@ -20,17 +20,20 @@ func builtinArrayConcat(a, b ast.Value) (ast.Value, error) {
return nil, err
}
arrC := make(ast.Array, 0, len(arrA)+len(arrB))
arrC := make([]*ast.Term, arrA.Len()+arrB.Len())
for _, elemA := range arrA {
arrC = append(arrC, elemA)
}
i := 0
arrA.Foreach(func(elemA *ast.Term) {
arrC[i] = elemA
i++
})
for _, elemB := range arrB {
arrC = append(arrC, elemB)
}
arrB.Foreach(func(elemB *ast.Term) {
arrC[i] = elemB
i++
})
return arrC, nil
return ast.NewArray(arrC...), nil
}
func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
@@ -49,27 +52,43 @@ func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
return nil, err
}
// Return empty array if bounds cannot be clamped sensibly.
if (startIndex >= stopIndex) || (startIndex <= 0 && stopIndex <= 0) {
return arr[0:0], nil
// Clamp stopIndex to avoid out-of-range errors. If negative, clamp to zero.
// Otherwise, clamp to length of array.
if stopIndex < 0 {
stopIndex = 0
} else if stopIndex > arr.Len() {
stopIndex = arr.Len()
}
// Clamp bounds to avoid out-of-range errors.
// Clamp startIndex to avoid out-of-range errors. If negative, clamp to zero.
// Otherwise, clamp to stopIndex to avoid to avoid cases like arr[1:0].
if startIndex < 0 {
startIndex = 0
} else if startIndex > stopIndex {
startIndex = stopIndex
}
if stopIndex > len(arr) {
stopIndex = len(arr)
return arr.Slice(startIndex, stopIndex), nil
}
func builtinArrayReverse(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
if err != nil {
return err
}
arrb := arr[startIndex:stopIndex]
length := arr.Len()
reversedArr := make([]*ast.Term, length)
return arrb, nil
for index := 0; index < length; index++ {
reversedArr[index] = arr.Elem(length - index - 1)
}
return iter(ast.ArrayTerm(reversedArr...))
}
func init() {
RegisterFunctionalBuiltin2(ast.ArrayConcat.Name, builtinArrayConcat)
RegisterFunctionalBuiltin3(ast.ArraySlice.Name, builtinArraySlice)
RegisterBuiltinFunc(ast.ArrayReverse.Name, builtinArrayReverse)
}

View File

@@ -12,9 +12,8 @@ import (
)
type undo struct {
k *ast.Term
u *bindings
next *undo
k *ast.Term
u *bindings
}
func (u *undo) Undo() {
@@ -27,7 +26,6 @@ func (u *undo) Undo() {
return
}
u.u.delete(u.k)
u.next.Undo()
}
type bindings struct {
@@ -88,13 +86,16 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
return next.plugNamespaced(b, caller)
}
return u.namespaceVar(b, caller)
case ast.Array:
cpy := *a
arr := make(ast.Array, len(v))
for i := 0; i < len(arr); i++ {
arr[i] = u.plugNamespaced(v[i], caller)
case *ast.Array:
if a.IsGround() {
return a
}
cpy.Value = arr
cpy := *a
arr := make([]*ast.Term, v.Len())
for i := 0; i < len(arr); i++ {
arr[i] = u.plugNamespaced(v.Elem(i), caller)
}
cpy.Value = ast.NewArray(arr...)
return &cpy
case ast.Object:
if a.IsGround() {
@@ -106,6 +107,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
})
return &cpy
case ast.Set:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return u.plugNamespaced(x, caller), nil
@@ -123,19 +127,20 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
return a
}
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings) *undo {
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings, und *undo) {
u.values.Put(a, value{
u: other,
v: b,
})
return &undo{a, u, nil}
und.k = a
und.u = u
}
func (u *bindings) apply(a *ast.Term) (*ast.Term, *bindings) {
// Early exit for non-var terms. Only vars are bound in the binding list,
// so the lookup below will always fail for non-var terms. In some cases,
// the lookup may be expensive as it has to hash the term (which for large
// inputs can be costly.)
// inputs can be costly).
_, ok := a.Value.(ast.Var)
if !ok {
return a, u
@@ -242,13 +247,16 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
switch v := a.Value.(type) {
case ast.Var:
return vis.b.namespaceVar(a, vis.caller)
case ast.Array:
cpy := *a
arr := make(ast.Array, len(v))
for i := 0; i < len(arr); i++ {
arr[i] = vis.namespaceTerm(v[i])
case *ast.Array:
if a.IsGround() {
return a
}
cpy.Value = arr
cpy := *a
arr := make([]*ast.Term, v.Len())
for i := 0; i < len(arr); i++ {
arr[i] = vis.namespaceTerm(v.Elem(i))
}
cpy.Value = ast.NewArray(arr...)
return &cpy
case ast.Object:
if a.IsGround() {
@@ -260,6 +268,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
})
return &cpy
case ast.Set:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return vis.namespaceTerm(x), nil
@@ -279,7 +290,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
const maxLinearScan = 16
// bindingsArrayHashMap uses an array with linear scan instead of a hash map for smaller # of entries. Hash maps start to show off their performance advantage only after 16 keys.
// bindingsArrayHashMap uses an array with linear scan instead
// of a hash map for smaller # of entries. Hash maps start to
// show off their performance advantage only after 16 keys.
type bindingsArrayHashmap struct {
n int // Entries in the array.
a *[maxLinearScan]bindingArrayKeyValue

View File

@@ -6,10 +6,17 @@ package topdown
import (
"context"
"encoding/binary"
"fmt"
"io"
"math/rand"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/tracing"
)
type (
@@ -28,14 +35,25 @@ type (
// BuiltinContext contains context from the evaluator that may be used by
// built-in functions.
BuiltinContext struct {
Context context.Context // request context that was passed when query started
Cancel Cancel // atomic value that signals evaluation to halt
Runtime *ast.Term // runtime information on the OPA instance
Cache builtins.Cache // built-in function state cache
Location *ast.Location // location of built-in call
Tracers []Tracer // tracer objects for trace() built-in function
QueryID uint64 // identifies query being evaluated
ParentID uint64 // identifies parent of query being evaluated
Context context.Context // request context that was passed when query started
Metrics metrics.Metrics // metrics registry for recording built-in specific metrics
Seed io.Reader // randomization source
Time *ast.Term // wall clock time
Cancel Cancel // atomic value that signals evaluation to halt
Runtime *ast.Term // runtime information on the OPA instance
Cache builtins.Cache // built-in function state cache
InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache
NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state
Location *ast.Location // location of built-in call
Tracers []Tracer // Deprecated: Use QueryTracers instead
QueryTracers []QueryTracer // tracer objects for trace() built-in function
TraceEnabled bool // indicates whether tracing is enabled for the evaluation
QueryID uint64 // identifies query being evaluated
ParentID uint64 // identifies parent of query being evaluated
PrintHook print.Hook // provides callback function to use for printing
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
rand *rand.Rand // randomization source for non-security-sensitive operations
Capabilities *ast.Capabilities
}
// BuiltinFunc defines an interface for implementing built-in functions.
@@ -46,6 +64,25 @@ type (
BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error
)
// Rand returns a random number generator based on the Seed for this built-in
// context. The random number will be re-used across multiple calls to this
// function. If a random number generator cannot be created, an error is
// returned.
func (bctx *BuiltinContext) Rand() (*rand.Rand, error) {
if bctx.rand != nil {
return bctx.rand, nil
}
seed, err := readInt64(bctx.Seed)
if err != nil {
return nil, err
}
bctx.rand = rand.New(rand.NewSource(seed))
return bctx.rand, nil
}
// RegisterBuiltinFunc adds a new built-in function to the evaluation engine.
func RegisterBuiltinFunc(name string, f BuiltinFunc) {
builtinFunctions[name] = builtinErrorWrapper(name, f)
@@ -142,7 +179,7 @@ func handleBuiltinErr(name string, loc *ast.Location, err error) error {
switch err := err.(type) {
case BuiltinEmpty:
return nil
case *Error:
case *Error, Halt:
return err
case builtins.ErrOperand:
return &Error{
@@ -158,3 +195,12 @@ func handleBuiltinErr(name string, loc *ast.Location, err error) error {
}
}
}
func readInt64(r io.Reader) (int64, error) {
bs := make([]byte, 8)
n, err := io.ReadFull(r, bs)
if n != len(bs) || err != nil {
return 0, err
}
return int64(binary.BigEndian.Uint64(bs)), nil
}

View File

@@ -6,11 +6,13 @@
package builtins
import (
"encoding/json"
"fmt"
"math/big"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/util"
)
// Cache defines the built-in cache used by the top-down evaluation. The keys
@@ -28,6 +30,85 @@ func (c Cache) Get(k interface{}) (interface{}, bool) {
return v, ok
}
// We use an ast.Object for the cached keys/values because a naive
// map[ast.Value]ast.Value will not correctly detect value equality of
// the member keys.
type NDBCache map[string]ast.Object
func (c NDBCache) AsValue() ast.Value {
out := ast.NewObject()
for bname, obj := range c {
out.Insert(ast.StringTerm(bname), ast.NewTerm(obj))
}
return out
}
// Put updates the cache for the named built-in.
// Automatically creates the 2-level hierarchy as needed.
func (c NDBCache) Put(name string, k, v ast.Value) {
if _, ok := c[name]; !ok {
c[name] = ast.NewObject()
}
c[name].Insert(ast.NewTerm(k), ast.NewTerm(v))
}
// Get returns the cached value for k for the named builtin.
func (c NDBCache) Get(name string, k ast.Value) (ast.Value, bool) {
if m, ok := c[name]; ok {
v := m.Get(ast.NewTerm(k))
if v != nil {
return v.Value, true
}
return nil, false
}
return nil, false
}
// Convenience functions for serializing the data structure.
func (c NDBCache) MarshalJSON() ([]byte, error) {
v, err := ast.JSON(c.AsValue())
if err != nil {
return nil, err
}
return json.Marshal(v)
}
func (c *NDBCache) UnmarshalJSON(data []byte) error {
out := map[string]ast.Object{}
var incoming interface{}
// Note: We use util.Unmarshal instead of json.Unmarshal to get
// correct deserialization of number types.
err := util.Unmarshal(data, &incoming)
if err != nil {
return err
}
// Convert interface types back into ast.Value types.
nestedObject, err := ast.InterfaceToValue(incoming)
if err != nil {
return err
}
// Reconstruct NDBCache from nested ast.Object structure.
if source, ok := nestedObject.(ast.Object); ok {
err = source.Iter(func(k, v *ast.Term) error {
if obj, ok := v.Value.(ast.Object); ok {
out[string(k.Value.(ast.String))] = obj
return nil
}
return fmt.Errorf("expected Object, got other Value type in conversion")
})
if err != nil {
return err
}
}
*c = out
return nil
}
// ErrOperand represents an invalid operand has been passed to a built-in
// function. Built-ins should return ErrOperand to indicate a type error has
// occurred.
@@ -149,10 +230,10 @@ func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {
// ArrayOperand converts x to an array. If the cast fails, a descriptive
// error is returned.
func ArrayOperand(x ast.Value, pos int) (ast.Array, error) {
a, ok := x.(ast.Array)
func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) {
a, ok := x.(*ast.Array)
if !ok {
return nil, NewOperandTypeErr(pos, x, "array")
return ast.NewArray(), NewOperandTypeErr(pos, x, "array")
}
return a, nil
}
@@ -168,7 +249,7 @@ func NumberToFloat(n ast.Number) *big.Float {
// FloatToNumber converts f to a number.
func FloatToNumber(f *big.Float) ast.Number {
return ast.Number(f.String())
return ast.Number(f.Text('g', -1))
}
// NumberToInt converts n to a big int.
@@ -189,23 +270,30 @@ func IntToNumber(i *big.Int) ast.Number {
// StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is
// returned.
func StringSliceOperand(x ast.Value, pos int) ([]string, error) {
a, err := ArrayOperand(x, pos)
if err != nil {
func StringSliceOperand(a ast.Value, pos int) ([]string, error) {
type iterable interface {
Iter(func(*ast.Term) error) error
Len() int
}
strs, ok := a.(iterable)
if !ok {
return nil, NewOperandTypeErr(pos, a, "array", "set")
}
var outStrs = make([]string, 0, strs.Len())
if err := strs.Iter(func(x *ast.Term) error {
s, ok := x.Value.(ast.String)
if !ok {
return NewOperandElementErr(pos, a, x.Value, "string")
}
outStrs = append(outStrs, string(s))
return nil
}); err != nil {
return nil, err
}
var f = make([]string, len(a))
for k, b := range a {
c, ok := b.Value.(ast.String)
if !ok {
return nil, NewOperandElementErr(pos, x, b.Value, "[]string")
}
f[k] = string(c)
}
return f, nil
return outStrs, nil
}
// RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is
@@ -216,8 +304,9 @@ func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) {
return nil, err
}
var f = make([]rune, len(a))
for k, b := range a {
var f = make([]rune, a.Len())
for k := 0; k < a.Len(); k++ {
b := a.Elem(k)
c, ok := b.Value.(ast.String)
if !ok {
return nil, NewOperandElementErr(pos, x, b.Value, "string")

View File

@@ -164,3 +164,128 @@ func (s *refStack) Prefixed(ref ast.Ref) bool {
}
return false
}
type comprehensionCache struct {
stack []map[*ast.Term]*comprehensionCacheElem
}
type comprehensionCacheElem struct {
value *ast.Term
children *util.HashMap
}
func newComprehensionCache() *comprehensionCache {
cache := &comprehensionCache{}
cache.Push()
return cache
}
func (c *comprehensionCache) Push() {
c.stack = append(c.stack, map[*ast.Term]*comprehensionCacheElem{})
}
func (c *comprehensionCache) Pop() {
c.stack = c.stack[:len(c.stack)-1]
}
func (c *comprehensionCache) Elem(t *ast.Term) (*comprehensionCacheElem, bool) {
elem, ok := c.stack[len(c.stack)-1][t]
return elem, ok
}
func (c *comprehensionCache) Set(t *ast.Term, elem *comprehensionCacheElem) {
c.stack[len(c.stack)-1][t] = elem
}
func newComprehensionCacheElem() *comprehensionCacheElem {
return &comprehensionCacheElem{children: newComprehensionCacheHashMap()}
}
func (c *comprehensionCacheElem) Get(key []*ast.Term) *ast.Term {
node := c
for i := 0; i < len(key); i++ {
x, ok := node.children.Get(key[i])
if !ok {
return nil
}
node = x.(*comprehensionCacheElem)
}
return node.value
}
func (c *comprehensionCacheElem) Put(key []*ast.Term, value *ast.Term) {
node := c
for i := 0; i < len(key); i++ {
x, ok := node.children.Get(key[i])
if ok {
node = x.(*comprehensionCacheElem)
} else {
next := newComprehensionCacheElem()
node.children.Put(key[i], next)
node = next
}
}
node.value = value
}
func newComprehensionCacheHashMap() *util.HashMap {
return util.NewHashMap(func(a, b util.T) bool {
return a.(*ast.Term).Equal(b.(*ast.Term))
}, func(x util.T) int {
return x.(*ast.Term).Hash()
})
}
type functionMocksStack struct {
stack []*functionMocksElem
}
type functionMocksElem []frame
type frame map[string]*ast.Term
func newFunctionMocksStack() *functionMocksStack {
stack := &functionMocksStack{}
stack.Push()
return stack
}
func newFunctionMocksElem() *functionMocksElem {
return &functionMocksElem{}
}
func (s *functionMocksStack) Push() {
s.stack = append(s.stack, newFunctionMocksElem())
}
func (s *functionMocksStack) Pop() {
s.stack = s.stack[:len(s.stack)-1]
}
func (s *functionMocksStack) PopPairs() {
current := s.stack[len(s.stack)-1]
*current = (*current)[:len(*current)-1]
}
func (s *functionMocksStack) PutPairs(mocks [][2]*ast.Term) {
el := frame{}
for i := range mocks {
el[mocks[i][0].Value.String()] = mocks[i][1]
}
s.Put(el)
}
func (s *functionMocksStack) Put(el frame) {
current := s.stack[len(s.stack)-1]
*current = append(*current, el)
}
func (s *functionMocksStack) Get(f ast.Ref) (*ast.Term, bool) {
current := *s.stack[len(s.stack)-1]
for i := len(current) - 1; i >= 0; i-- {
if r, ok := current[i][f.String()]; ok {
return r, true
}
}
return nil, false
}

View File

@@ -0,0 +1,167 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// Package cache defines the inter-query cache interface that can cache data across queries
package cache
import (
"container/list"
"github.com/open-policy-agent/opa/ast"
"sync"
"github.com/open-policy-agent/opa/util"
)
const (
defaultMaxSizeBytes = int64(0) // unlimited
)
// Config represents the configuration of the inter-query cache.
type Config struct {
InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"`
}
// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize.
type InterQueryBuiltinCacheConfig struct {
MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"`
}
// ParseCachingConfig returns the config for the inter-query cache.
func ParseCachingConfig(raw []byte) (*Config, error) {
if raw == nil {
maxSize := new(int64)
*maxSize = defaultMaxSizeBytes
return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize}}, nil
}
var config Config
if err := util.Unmarshal(raw, &config); err == nil {
if err = config.validateAndInjectDefaults(); err != nil {
return nil, err
}
} else {
return nil, err
}
return &config, nil
}
func (c *Config) validateAndInjectDefaults() error {
if c.InterQueryBuiltinCache.MaxSizeBytes == nil {
maxSize := new(int64)
*maxSize = defaultMaxSizeBytes
c.InterQueryBuiltinCache.MaxSizeBytes = maxSize
}
return nil
}
// InterQueryCacheValue defines the interface for the data that the inter-query cache holds.
type InterQueryCacheValue interface {
SizeInBytes() int64
}
// InterQueryCache defines the interface for the inter-query cache.
type InterQueryCache interface {
Get(key ast.Value) (value InterQueryCacheValue, found bool)
Insert(key ast.Value, value InterQueryCacheValue) int
Delete(key ast.Value)
UpdateConfig(config *Config)
}
// NewInterQueryCache returns a new inter-query cache.
func NewInterQueryCache(config *Config) InterQueryCache {
return &cache{
items: map[string]InterQueryCacheValue{},
usage: 0,
config: config,
l: list.New(),
}
}
type cache struct {
items map[string]InterQueryCacheValue
usage int64
config *Config
l *list.List
mtx sync.Mutex
}
// Insert inserts a key k into the cache with value v.
func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) {
c.mtx.Lock()
defer c.mtx.Unlock()
return c.unsafeInsert(k, v)
}
// Get returns the value in the cache for k.
func (c *cache) Get(k ast.Value) (InterQueryCacheValue, bool) {
c.mtx.Lock()
defer c.mtx.Unlock()
return c.unsafeGet(k)
}
// Delete deletes the value in the cache for k.
func (c *cache) Delete(k ast.Value) {
c.mtx.Lock()
defer c.mtx.Unlock()
c.unsafeDelete(k)
}
func (c *cache) UpdateConfig(config *Config) {
if config == nil {
return
}
c.mtx.Lock()
defer c.mtx.Unlock()
c.config = config
}
func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) {
size := v.SizeInBytes()
limit := c.maxSizeBytes()
if limit > 0 {
if size > limit {
dropped++
return dropped
}
for key := c.l.Front(); key != nil && (c.usage+size > limit); key = c.l.Front() {
dropKey := key.Value.(ast.Value)
c.unsafeDelete(dropKey)
c.l.Remove(key)
dropped++
}
}
c.items[k.String()] = v
c.l.PushBack(k)
c.usage += size
return dropped
}
func (c *cache) unsafeGet(k ast.Value) (InterQueryCacheValue, bool) {
value, ok := c.items[k.String()]
return value, ok
}
func (c *cache) unsafeDelete(k ast.Value) {
value, ok := c.unsafeGet(k)
if !ok {
return
}
c.usage -= int64(value.SizeInBytes())
delete(c.items, k.String())
}
func (c *cache) maxSizeBytes() int64 {
if c.config == nil {
return defaultMaxSizeBytes
}
return *c.config.InterQueryBuiltinCache.MaxSizeBytes
}

View File

@@ -35,16 +35,16 @@ func builtinToNumber(a ast.Value) (ast.Value, error) {
// Deprecated in v0.13.0.
func builtinToArray(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Array:
case *ast.Array:
return val, nil
case ast.Set:
arr := make(ast.Array, val.Len())
arr := make([]*ast.Term, val.Len())
i := 0
val.Foreach(func(term *ast.Term) {
arr[i] = term
i++
})
return arr, nil
return ast.NewArray(arr...), nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
@@ -53,8 +53,12 @@ func builtinToArray(a ast.Value) (ast.Value, error) {
// Deprecated in v0.13.0.
func builtinToSet(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Array:
return ast.NewSet(val...), nil
case *ast.Array:
s := ast.NewSet()
val.Foreach(func(v *ast.Term) {
s.Add(v)
})
return s, nil
case ast.Set:
return val, nil
default:

View File

@@ -1,11 +1,15 @@
package topdown
import (
"bytes"
"errors"
"fmt"
"math/big"
"net"
"sort"
"github.com/open-policy-agent/opa/ast"
cidrMerge "github.com/open-policy-agent/opa/internal/cidr/merge"
"github.com/open-policy-agent/opa/topdown/builtins"
)
@@ -69,7 +73,7 @@ func builtinNetCIDRIntersects(a, b ast.Value) (ast.Value, error) {
}
// If either net contains the others starting IP they are overlapping
cidrsOverlap := (cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP))
cidrsOverlap := cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP)
return ast.Boolean(cidrsOverlap), nil
}
@@ -97,7 +101,8 @@ func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
return nil, fmt.Errorf("not a valid textual representation of an IP address or CIDR: %s", string(bStr))
}
// We can determine if cidr A contains cidr B iff A contains the starting address of B and the last address in B.
// We can determine if cidr A contains cidr B if and only if A contains
// the starting address of B and the last address in B.
cidrContained := false
if cidrnetA.Contains(cidrnetB.IP) {
// Only spend time calculating the last IP if the starting IP is already verified to be in cidr A
@@ -111,6 +116,75 @@ func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
return ast.Boolean(cidrContained), nil
}
var errNetCIDRContainsMatchElementType = errors.New("element must be string or non-empty array")
func getCIDRMatchTerm(a *ast.Term) (*ast.Term, error) {
switch v := a.Value.(type) {
case ast.String:
return a, nil
case *ast.Array:
if v.Len() == 0 {
return nil, errNetCIDRContainsMatchElementType
}
return v.Elem(0), nil
default:
return nil, errNetCIDRContainsMatchElementType
}
}
func evalNetCIDRContainsMatchesOperand(operand int, a *ast.Term, iter func(cidr, index *ast.Term) error) error {
switch v := a.Value.(type) {
case ast.String:
return iter(a, a)
case *ast.Array:
for i := 0; i < v.Len(); i++ {
cidr, err := getCIDRMatchTerm(v.Elem(i))
if err != nil {
return fmt.Errorf("operand %v: %v", operand, err)
}
if err := iter(cidr, ast.IntNumberTerm(i)); err != nil {
return err
}
}
return nil
case ast.Set:
return v.Iter(func(x *ast.Term) error {
cidr, err := getCIDRMatchTerm(x)
if err != nil {
return fmt.Errorf("operand %v: %v", operand, err)
}
return iter(cidr, x)
})
case ast.Object:
return v.Iter(func(k, v *ast.Term) error {
cidr, err := getCIDRMatchTerm(v)
if err != nil {
return fmt.Errorf("operand %v: %v", operand, err)
}
return iter(cidr, k)
})
}
return nil
}
func builtinNetCIDRContainsMatches(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
result := ast.NewSet()
err := evalNetCIDRContainsMatchesOperand(1, args[0], func(cidr1 *ast.Term, index1 *ast.Term) error {
return evalNetCIDRContainsMatchesOperand(2, args[1], func(cidr2 *ast.Term, index2 *ast.Term) error {
if v, err := builtinNetCIDRContains(cidr1.Value, cidr2.Value); err != nil {
return err
} else if vb, ok := v.(ast.Boolean); ok && bool(vb) {
result.Add(ast.ArrayTerm(index1, index2))
}
return nil
})
})
if err == nil {
return iter(ast.NewTerm(result))
}
return err
}
func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
@@ -128,9 +202,11 @@ func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return &Error{
Code: CancelErr,
Message: "net.cidr_expand: timed out before generating all IP addresses",
return Halt{
Err: &Error{
Code: CancelErr,
Message: "net.cidr_expand: timed out before generating all IP addresses",
},
}
}
@@ -140,6 +216,176 @@ func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*
return iter(ast.NewTerm(result))
}
type cidrBlockRange struct {
First *net.IP
Last *net.IP
Network *net.IPNet
}
type cidrBlockRanges []*cidrBlockRange
// Implement Sort interface
func (c cidrBlockRanges) Len() int {
return len(c)
}
func (c cidrBlockRanges) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
}
func (c cidrBlockRanges) Less(i, j int) bool {
// Compare last IP.
cmp := bytes.Compare(*c[i].Last, *c[j].Last)
if cmp < 0 {
return true
} else if cmp > 0 {
return false
}
// Then compare first IP.
cmp = bytes.Compare(*c[i].First, *c[i].First)
if cmp < 0 {
return true
} else if cmp > 0 {
return false
}
// Ranges are Equal.
return false
}
// builtinNetCIDRMerge merges the provided list of IP addresses and subnets into the smallest possible list of CIDRs.
// It merges adjacent subnets where possible, those contained within others and also removes any duplicates.
// Original Algorithm: https://github.com/netaddr/netaddr.
func builtinNetCIDRMerge(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
networks := []*net.IPNet{}
switch v := operands[0].Value.(type) {
case *ast.Array:
for i := 0; i < v.Len(); i++ {
network, err := generateIPNet(v.Elem(i))
if err != nil {
return err
}
networks = append(networks, network)
}
case ast.Set:
err := v.Iter(func(x *ast.Term) error {
network, err := generateIPNet(x)
if err != nil {
return err
}
networks = append(networks, network)
return nil
})
if err != nil {
return err
}
default:
return errors.New("operand must be an array")
}
merged := evalNetCIDRMerge(networks)
result := ast.NewSet()
for _, network := range merged {
result.Add(ast.StringTerm(network.String()))
}
return iter(ast.NewTerm(result))
}
func evalNetCIDRMerge(networks []*net.IPNet) []*net.IPNet {
if len(networks) == 0 {
return nil
}
ranges := make(cidrBlockRanges, 0, len(networks))
// For each CIDR, create an IP range. Sort them and merge when possible.
for _, network := range networks {
firstIP, lastIP := cidrMerge.GetAddressRange(*network)
ranges = append(ranges, &cidrBlockRange{
First: &firstIP,
Last: &lastIP,
Network: network,
})
}
// merge CIDRs.
merged := mergeCIDRs(ranges)
// convert ranges into an equivalent list of net.IPNet.
result := []*net.IPNet{}
for _, r := range merged {
// Not merged with any other CIDR.
if r.Network != nil {
result = append(result, r.Network)
} else {
// Find new network that represents the merged range.
rangeCIDRs := cidrMerge.RangeToCIDRs(*r.First, *r.Last)
result = append(result, rangeCIDRs...)
}
}
return result
}
func generateIPNet(term *ast.Term) (*net.IPNet, error) {
e, ok := term.Value.(ast.String)
if !ok {
return nil, errors.New("element must be string")
}
// try to parse element as an IP first, fall back to CIDR
ip := net.ParseIP(string(e))
if ip == nil {
_, network, err := net.ParseCIDR(string(e))
return network, err
}
if ip.To4() != nil {
return &net.IPNet{
IP: ip,
Mask: ip.DefaultMask(),
}, nil
}
return nil, errors.New("IPv6 invalid: needs prefix length")
}
func mergeCIDRs(ranges cidrBlockRanges) cidrBlockRanges {
sort.Sort(ranges)
// Merge adjacent CIDRs if possible.
for i := len(ranges) - 1; i > 0; i-- {
previousIP := cidrMerge.GetPreviousIP(*ranges[i].First)
// If the previous IP of the current network overlaps
// with the last IP of the previous network in the
// list, then merge the two ranges together.
if bytes.Compare(previousIP, *ranges[i-1].Last) <= 0 {
var firstIP *net.IP
if bytes.Compare(*ranges[i-1].First, *ranges[i].First) < 0 {
firstIP = ranges[i-1].First
} else {
firstIP = ranges[i].First
}
lastIPRange := make(net.IP, len(*ranges[i].Last))
copy(lastIPRange, *ranges[i].Last)
firstIPRange := make(net.IP, len(*firstIP))
copy(firstIPRange, *firstIP)
ranges[i-1] = &cidrBlockRange{First: &firstIPRange, Last: &lastIPRange, Network: nil}
// Delete ranges[i] since merged with the previous.
ranges = append(ranges[:i], ranges[i+1:]...)
}
}
return ranges
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
@@ -153,5 +399,7 @@ func init() {
RegisterFunctionalBuiltin2(ast.NetCIDROverlap.Name, builtinNetCIDRContains)
RegisterFunctionalBuiltin2(ast.NetCIDRIntersects.Name, builtinNetCIDRIntersects)
RegisterFunctionalBuiltin2(ast.NetCIDRContains.Name, builtinNetCIDRContains)
RegisterBuiltinFunc(ast.NetCIDRContainsMatches.Name, builtinNetCIDRContainsMatches)
RegisterBuiltinFunc(ast.NetCIDRExpand.Name, builtinNetCIDRExpand)
RegisterBuiltinFunc(ast.NetCIDRMerge.Name, builtinNetCIDRMerge)
}

View File

@@ -5,6 +5,7 @@
package copypropagation
import (
"fmt"
"sort"
"github.com/open-policy-agent/opa/ast"
@@ -30,6 +31,19 @@ type CopyPropagator struct {
livevars ast.VarSet // vars that must be preserved in the resulting query
sorted []ast.Var // sorted copy of vars to ensure deterministic result
ensureNonEmptyBody bool
compiler *ast.Compiler
localvargen *localVarGenerator
}
type localVarGenerator struct {
next int
}
func (l *localVarGenerator) Generate() ast.Var {
result := ast.Var(fmt.Sprintf("__localcp%d__", l.next))
l.next++
return result
}
// New returns a new CopyPropagator that optimizes queries while preserving vars
@@ -45,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator {
return sorted[i].Compare(sorted[j]) < 0
})
return &CopyPropagator{livevars: livevars, sorted: sorted}
return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}}
}
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
@@ -54,8 +68,17 @@ func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
return p
}
// WithCompiler configures the compiler to read from while processing the query. This
// should be the same compiler used to compile the original policy.
func (p *CopyPropagator) WithCompiler(c *ast.Compiler) *CopyPropagator {
p.compiler = c
return p
}
// Apply executes the copy propagation optimization and returns a new query.
func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
result := ast.NewBody()
uf, ok := makeDisjointSets(p.livevars, query)
if !ok {
@@ -63,14 +86,15 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
}
// Compute set of vars that appear in the head of refs in the query. If a var
// is dereferenced, we cannot plug it with a constant value so the constant on
// the union-find root must be unset (e.g., [1][0] is not legal.)
// is dereferenced, we can plug it with a constant value, but it is not always
// optimal to do so.
// TODO: Improve the algorithm for when we should plug constants/calls/etc
headvars := ast.NewVarSet()
ast.WalkRefs(query, func(x ast.Ref) bool {
if v, ok := x[0].Value.(ast.Var); ok {
if root, ok := uf.Find(v); ok {
root.constant = nil
headvars.Add(root.key)
headvars.Add(root.key.(ast.Var))
} else {
headvars.Add(v)
}
@@ -78,21 +102,21 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
return false
})
bindings := map[ast.Var]*binding{}
removedEqs := ast.NewValueMap()
for _, expr := range query {
pctx := &plugContext{
bindings: bindings,
uf: uf,
negated: expr.Negated,
headvars: headvars,
removedEqs: removedEqs,
uf: uf,
negated: expr.Negated,
headvars: headvars,
}
if expr, keep := p.plugBindings(pctx, expr); keep {
if p.updateBindings(pctx, expr) {
result.Append(expr)
}
expr = p.plugBindings(pctx, expr)
if p.updateBindings(pctx, expr) {
result.Append(expr)
}
}
@@ -104,7 +128,7 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
// from being added to the result. For example:
//
// - Given the following result: <empty>
// - Given the following bindings: x/input.x and y/input
// - Given the following removed equalities: "x = input.x" and "y = input"
// - Given the following liveset: {x}
//
// If this step were to run AFTER the following step, the output would be:
@@ -116,8 +140,8 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
if root, ok := uf.Find(v); ok {
if root.constant != nil {
result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
} else if b, ok := bindings[root.key]; ok {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b.v)))
} else if b := removedEqs.Get(root.key); b != nil {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b)))
} else if root.key != v {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
}
@@ -125,17 +149,46 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
}
// Run post-processing step on query to ensure that all killed exprs are
// accounted for. If an expr is killed but the binding is never used, the query
// must still include the expr. For example, given the query 'input.x = a' and
// an empty livevar set, the result must include the ref input.x otherwise the
// query could be satisfied without input.x being defined. When exprs are
// killed we initialize the binding counter to zero and then increment it each
// time the binding is substituted. if the binding was never substituted it
// means the binding value must be added back into the query.
for _, b := range sortbindings(bindings) {
if !b.containedIn(result) {
result.Append(ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v)))
// accounted for. There are several cases we look for:
//
// * If an expr is killed but the binding is never used, the query
// must still include the expr. For example, given the query 'input.x = a' and
// an empty livevar set, the result must include the ref input.x otherwise the
// query could be satisfied without input.x being defined.
//
// * If an expr is killed that provided safety to vars which are not
// otherwise being made safe by the current result.
//
// For any of these cases we re-add the removed equality expression
// to the current result.
// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
safe := ast.ReservedVars.Copy()
safe.Update(p.livevars)
safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
for _, b := range sortbindings(removedEqs) {
removedEq := ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v))
providesSafety := false
outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
diff := unsafe.Diff(outputVars)
if len(diff) < len(unsafe) {
unsafe = diff
providesSafety = true
}
if providesSafety || !containedIn(b.v, result) {
result.Append(removedEq)
safe.Update(outputVars)
}
}
if len(unsafe) > 0 {
// NOTE(tsandall): This should be impossible but if it does occur, throw
// away the result rather than generating unsafe output.
return query
}
if p.ensureNonEmptyBody && len(result) == 0 {
@@ -147,19 +200,7 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
// plugBindings applies the binding list and union-find to x. This process
// removes as many variables as possible.
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.Expr, bool) {
// Kill single term expressions that are in the binding list. They will be
// re-added during post-processing if needed.
if term, ok := expr.Terms.(*ast.Term); ok {
if v, ok := term.Value.(ast.Var); ok {
if root, ok := pctx.uf.Find(v); ok {
if _, ok := pctx.bindings[root.key]; ok {
return nil, false
}
}
}
}
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) *ast.Expr {
xform := bindingPlugTransform{
pctx: pctx,
@@ -175,7 +216,7 @@ func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.E
if expr, ok := x.(*ast.Expr); !ok || err != nil {
panic("unreachable")
} else {
return expr, true
return expr
}
}
@@ -194,31 +235,40 @@ func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
}
}
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (result ast.Value) {
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) ast.Value {
result = v
var result ast.Value = v
// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v); ok {
root, ok := pctx.uf.Find(v)
if ok {
result = root.Value()
}
// Apply binding list to substitute remaining vars.
if v, ok := result.(ast.Var); ok {
if b, ok := pctx.bindings[v]; ok {
if !pctx.negated || b.v.IsGround() {
result = b.v
}
}
v, ok = result.(ast.Var)
if !ok {
return result
}
b := pctx.removedEqs.Get(v)
if b == nil {
return result
}
if pctx.negated && !b.IsGround() {
return result
}
return result
if r, ok := b.(ast.Ref); ok && r.OutputVars().Contains(v) {
return result
}
return b
}
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok {
if root, ok := pctx.uf.Find(v[0].Value); ok {
v[0].Value = root.Value()
}
@@ -226,11 +276,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
// Refs require special handling. If the head of the ref was killed, then
// the rest of the ref must be concatenated with the new base.
//
// Invariant: ref heads can only be replaced by refs (not calls).
if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok {
if !pctx.negated || b.v.IsGround() {
result = b.v.(ast.Ref).Concat(v[1:])
if b := pctx.removedEqs.Get(v[0].Value); b != nil {
if !pctx.negated || b.IsGround() {
var base ast.Ref
switch x := b.(type) {
case ast.Ref:
base = x
default:
base = ast.Ref{ast.NewTerm(x)}
}
result = base.Concat(v[1:])
}
}
@@ -240,32 +295,54 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
// updateBindings returns false if the expression can be killed. If the
// expression is killed, the binding list is updated to map a var to value.
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
if pctx.negated || len(expr.With) > 0 {
switch {
case pctx.negated || len(expr.With) > 0:
return true
}
if expr.IsEquality() {
case expr.IsEquality():
a, b := expr.Operand(0), expr.Operand(1)
if a.Equal(b) {
if p.livevarRef(a) {
pctx.removedEqs.Put(p.localvargen.Generate(), a.Value)
}
return false
}
k, v, keep := p.updateBindingsEq(a, b)
if !keep {
if v != nil {
pctx.bindings[k] = newbinding(k, v)
pctx.removedEqs.Put(k, v)
}
return false
}
} else if expr.IsCall() {
case expr.IsCall():
terms := expr.Terms.([]*ast.Term)
output := terms[len(terms)-1]
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
pctx.bindings[k] = newbinding(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
return false
if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
output := terms[len(terms)-1]
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
pctx.removedEqs.Put(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
return false
}
}
}
return !isNoop(expr)
}
func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
ref, ok := a.Value.(ast.Ref)
if !ok {
return false
}
for _, v := range p.sorted {
if ref[0].Value.Compare(v) == 0 {
return true
}
}
return false
}
func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
if !keep {
@@ -289,26 +366,21 @@ func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, as
}
type plugContext struct {
bindings map[ast.Var]*binding
uf *unionFind
headvars ast.VarSet
negated bool
removedEqs *ast.ValueMap
uf *unionFind
headvars ast.VarSet
negated bool
}
type binding struct {
k ast.Var
v ast.Value
k, v ast.Value
}
func newbinding(k ast.Var, v ast.Value) *binding {
return &binding{k: k, v: v}
}
func (b *binding) containedIn(query ast.Body) bool {
func containedIn(value ast.Value, x interface{}) bool {
var stop bool
switch v := b.v.(type) {
switch v := value.(type) {
case ast.Ref:
ast.WalkRefs(query, func(other ast.Ref) bool {
ast.WalkRefs(x, func(other ast.Ref) bool {
if stop || other.HasPrefix(v) {
stop = true
return stop
@@ -316,7 +388,7 @@ func (b *binding) containedIn(query ast.Body) bool {
return false
})
default:
ast.WalkTerms(query, func(other *ast.Term) bool {
ast.WalkTerms(x, func(other *ast.Term) bool {
if stop || other.Value.Compare(v) == 0 {
stop = true
return stop
@@ -327,23 +399,18 @@ func (b *binding) containedIn(query ast.Body) bool {
return stop
}
func sortbindings(bindings map[ast.Var]*binding) []*binding {
sorted := make([]*binding, 0, len(bindings))
for _, b := range bindings {
sorted = append(sorted, b)
}
func sortbindings(bindings *ast.ValueMap) []*binding {
sorted := make([]*binding, 0, bindings.Len())
bindings.Iter(func(k ast.Value, v ast.Value) bool {
sorted = append(sorted, &binding{k, v})
return false
})
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].k.Compare(sorted[j].k) < 0
return sorted[i].k.Compare(sorted[j].k) > 0
})
return sorted
}
type unionFind struct {
roots map[ast.Var]*unionFindRoot
parents map[ast.Var]ast.Var
rank rankFunc
}
// makeDisjointSets builds the union-find structure for the query. The structure
// is built by processing all of the equality exprs in the query. Sets represent
// vars that must be equal to each other. In addition to vars, each set can have
@@ -352,7 +419,7 @@ type unionFind struct {
// false.
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
if livevars.Contains(r1.key) {
if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) {
return r1, r2
}
return r2, r1
@@ -362,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
a, b := expr.Operand(0), expr.Operand(1)
varA, ok1 := a.Value.(ast.Var)
varB, ok2 := b.Value.(ast.Var)
if ok1 && ok2 {
switch {
case ok1 && ok2:
if _, ok := uf.Merge(varA, varB); !ok {
return nil, false
}
} else if ok1 && ast.IsConstant(b.Value) {
case ok1 && ast.IsConstant(b.Value):
root := uf.MakeSet(varA)
if root.constant != nil && !root.constant.Equal(b) {
return nil, false
}
root.constant = b
} else if ok2 && ast.IsConstant(a.Value) {
case ok2 && ast.IsConstant(a.Value):
root := uf.MakeSet(varB)
if root.constant != nil && !root.constant.Equal(a) {
return nil, false
@@ -385,89 +456,9 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
return uf, true
}
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
func newUnionFind(rank rankFunc) *unionFind {
return &unionFind{
roots: map[ast.Var]*unionFindRoot{},
parents: map[ast.Var]ast.Var{},
rank: rank,
}
}
func (uf *unionFind) MakeSet(v ast.Var) *unionFindRoot {
root, ok := uf.Find(v)
if ok {
return root
}
root = newUnionFindRoot(v)
uf.parents[v] = v
uf.roots[v] = root
return uf.roots[v]
}
func (uf *unionFind) Find(v ast.Var) (*unionFindRoot, bool) {
parent, ok := uf.parents[v]
if !ok {
return nil, false
}
if parent == v {
return uf.roots[v], true
}
return uf.Find(parent)
}
func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
r1 := uf.MakeSet(a)
r2 := uf.MakeSet(b)
if r1 != r2 {
r1, r2 = uf.rank(r1, r2)
uf.parents[r2.key] = r1.key
delete(uf.roots, r2.key)
// Sets can have at most one constant value associated with them. When
// unioning, we must preserve this invariant. If a set has two constants,
// there will be no way to prove the query.
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
return nil, false
} else if r1.constant == nil {
r1.constant = r2.constant
}
}
return r1, true
}
type unionFindRoot struct {
key ast.Var
constant *ast.Term
}
func newUnionFindRoot(key ast.Var) *unionFindRoot {
return &unionFindRoot{
key: key,
}
}
func (r *unionFindRoot) Value() ast.Value {
if r.constant != nil {
return r.constant.Value
}
return r.key
}
func isNoop(expr *ast.Expr) bool {
if !expr.IsCall() {
if !expr.IsCall() && !expr.IsEvery() {
term := expr.Terms.(*ast.Term)
if !ast.IsConstant(term.Value) {
return false

View File

@@ -0,0 +1,135 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package copypropagation
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/util"
)
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
type unionFind struct {
roots *util.HashMap
parents *ast.ValueMap
rank rankFunc
}
func newUnionFind(rank rankFunc) *unionFind {
return &unionFind{
roots: util.NewHashMap(func(a util.T, b util.T) bool {
return a.(ast.Value).Compare(b.(ast.Value)) == 0
}, func(v util.T) int {
return v.(ast.Value).Hash()
}),
parents: ast.NewValueMap(),
rank: rank,
}
}
func (uf *unionFind) MakeSet(v ast.Value) *unionFindRoot {
root, ok := uf.Find(v)
if ok {
return root
}
root = newUnionFindRoot(v)
uf.parents.Put(v, v)
uf.roots.Put(v, root)
return root
}
func (uf *unionFind) Find(v ast.Value) (*unionFindRoot, bool) {
parent := uf.parents.Get(v)
if parent == nil {
return nil, false
}
if parent.Compare(v) == 0 {
r, ok := uf.roots.Get(v)
return r.(*unionFindRoot), ok
}
return uf.Find(parent)
}
func (uf *unionFind) Merge(a, b ast.Value) (*unionFindRoot, bool) {
r1 := uf.MakeSet(a)
r2 := uf.MakeSet(b)
if r1 != r2 {
r1, r2 = uf.rank(r1, r2)
uf.parents.Put(r2.key, r1.key)
uf.roots.Delete(r2.key)
// Sets can have at most one constant value associated with them. When
// unioning, we must preserve this invariant. If a set has two constants,
// there will be no way to prove the query.
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
return nil, false
} else if r1.constant == nil {
r1.constant = r2.constant
}
}
return r1, true
}
func (uf *unionFind) String() string {
o := struct {
Roots map[string]interface{}
Parents map[string]ast.Value
}{
map[string]interface{}{},
map[string]ast.Value{},
}
uf.roots.Iter(func(k util.T, v util.T) bool {
o.Roots[k.(ast.Value).String()] = struct {
Constant *ast.Term
Key ast.Value
}{
v.(*unionFindRoot).constant,
v.(*unionFindRoot).key,
}
return true
})
uf.parents.Iter(func(k ast.Value, v ast.Value) bool {
o.Parents[k.String()] = v
return true
})
return string(util.MustMarshalJSON(o))
}
type unionFindRoot struct {
key ast.Value
constant *ast.Term
}
func newUnionFindRoot(key ast.Value) *unionFindRoot {
return &unionFindRoot{
key: key,
}
}
func (r *unionFindRoot) Value() ast.Value {
if r.constant != nil {
return r.constant.Value
}
return r.key
}
func (r *unionFindRoot) String() string {
return fmt.Sprintf("{key: %s, constant: %s", r.key, r.constant)
}

View File

@@ -5,45 +5,174 @@
package topdown
import (
"bytes"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"hash"
"io/ioutil"
"os"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/util"
)
const (
// blockTypeCertificate indicates this PEM block contains the signed certificate.
// Exported for tests.
blockTypeCertificate = "CERTIFICATE"
// blockTypeCertificateRequest indicates this PEM block contains a certificate
// request. Exported for tests.
blockTypeCertificateRequest = "CERTIFICATE REQUEST"
// blockTypeRSAPrivateKey indicates this PEM block contains a RSA private key.
// Exported for tests.
blockTypeRSAPrivateKey = "RSA PRIVATE KEY"
// blockTypeRSAPrivateKey indicates this PEM block contains a RSA private key.
// Exported for tests.
blockTypePrivateKey = "PRIVATE KEY"
)
func builtinCryptoX509ParseCertificates(a ast.Value) (ast.Value, error) {
str, err := builtinBase64Decode(a)
input, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
certs, err := x509.ParseCertificates([]byte(str.(ast.String)))
certs, err := getX509CertsFromString(string(input))
if err != nil {
return nil, err
}
bs, err := json.Marshal(certs)
return ast.InterfaceToValue(certs)
}
func builtinCryptoX509ParseAndVerifyCertificates(
_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
a := args[0].Value
input, err := builtins.StringOperand(a, 1)
if err != nil {
return err
}
invalid := ast.ArrayTerm(
ast.BooleanTerm(false),
ast.NewTerm(ast.NewArray()),
)
certs, err := getX509CertsFromString(string(input))
if err != nil {
return iter(invalid)
}
verified, err := verifyX509CertificateChain(certs)
if err != nil {
return iter(invalid)
}
value, err := ast.InterfaceToValue(verified)
if err != nil {
return err
}
valid := ast.ArrayTerm(
ast.BooleanTerm(true),
ast.NewTerm(value),
)
return iter(valid)
}
func builtinCryptoX509ParseCertificateRequest(a ast.Value) (ast.Value, error) {
input, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
// data to be passed to x509.ParseCertificateRequest
bytes := []byte(input)
// if the input is not a PEM string, attempt to decode b64
if str := string(input); !strings.HasPrefix(str, "-----BEGIN CERTIFICATE REQUEST-----") {
bytes, err = base64.StdEncoding.DecodeString(str)
if err != nil {
return nil, err
}
}
p, _ := pem.Decode(bytes)
if p != nil && p.Type != blockTypeCertificateRequest {
return nil, fmt.Errorf("invalid PEM-encoded certificate signing request")
}
if p != nil {
bytes = p.Bytes
}
csr, err := x509.ParseCertificateRequest(bytes)
if err != nil {
return nil, err
}
bs, err := json.Marshal(csr)
if err != nil {
return nil, err
}
var x interface{}
if err := util.UnmarshalJSON(bs, &x); err != nil {
return nil, err
}
return ast.InterfaceToValue(x)
}
func builtinCryptoX509ParseRSAPrivateKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
a := args[0].Value
input, err := builtins.StringOperand(a, 1)
if err != nil {
return err
}
// get the raw private key
rawKey, err := getRSAPrivateKeyFromString(string(input))
if err != nil {
return err
}
rsaPrivateKey, err := jwk.New(rawKey)
if err != nil {
return err
}
jsonKey, err := json.Marshal(rsaPrivateKey)
if err != nil {
return err
}
var x interface{}
if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
return err
}
value, err := ast.InterfaceToValue(x)
if err != nil {
return err
}
return iter(ast.NewTerm(value))
}
func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -64,47 +193,209 @@ func builtinCryptoSha256(a ast.Value) (ast.Value, error) {
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) })
}
func hmacHelper(args []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error {
a1 := args[0].Value
message, err := builtins.StringOperand(a1, 1)
if err != nil {
return err
}
a2 := args[1].Value
key, err := builtins.StringOperand(a2, 2)
if err != nil {
return err
}
mac := hmac.New(h, []byte(key))
mac.Write([]byte(message))
messageDigest := mac.Sum(nil)
return iter(ast.StringTerm(fmt.Sprintf("%x", messageDigest)))
}
func builtinCryptoHmacMd5(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
return hmacHelper(args, iter, md5.New)
}
func builtinCryptoHmacSha1(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
return hmacHelper(args, iter, sha1.New)
}
func builtinCryptoHmacSha256(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
return hmacHelper(args, iter, sha256.New)
}
func builtinCryptoHmacSha512(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
return hmacHelper(args, iter, sha512.New)
}
func init() {
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
RegisterBuiltinFunc(ast.CryptoX509ParseAndVerifyCertificates.Name, builtinCryptoX509ParseAndVerifyCertificates)
RegisterFunctionalBuiltin1(ast.CryptoMd5.Name, builtinCryptoMd5)
RegisterFunctionalBuiltin1(ast.CryptoSha1.Name, builtinCryptoSha1)
RegisterFunctionalBuiltin1(ast.CryptoSha256.Name, builtinCryptoSha256)
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificateRequest.Name, builtinCryptoX509ParseCertificateRequest)
RegisterBuiltinFunc(ast.CryptoX509ParseRSAPrivateKey.Name, builtinCryptoX509ParseRSAPrivateKey)
RegisterBuiltinFunc(ast.CryptoHmacMd5.Name, builtinCryptoHmacMd5)
RegisterBuiltinFunc(ast.CryptoHmacSha1.Name, builtinCryptoHmacSha1)
RegisterBuiltinFunc(ast.CryptoHmacSha256.Name, builtinCryptoHmacSha256)
RegisterBuiltinFunc(ast.CryptoHmacSha512.Name, builtinCryptoHmacSha512)
}
// createRootCAs creates a new Cert Pool from scratch or adds to a copy of System Certs
func createRootCAs(tlsCACertFile string, tlsCACertEnvVar []byte, tlsUseSystemCerts bool) (*x509.CertPool, error) {
var newRootCAs *x509.CertPool
if tlsUseSystemCerts {
systemCertPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
newRootCAs = systemCertPool
} else {
newRootCAs = x509.NewCertPool()
func verifyX509CertificateChain(certs []*x509.Certificate) ([]*x509.Certificate, error) {
if len(certs) < 2 {
return nil, builtins.NewOperandErr(1, "must supply at least two certificates to be able to verify")
}
if len(tlsCACertFile) > 0 {
// Append our cert to the system pool
caCert, err := readCertFromFile(tlsCACertFile)
if err != nil {
return nil, err
}
if ok := newRootCAs.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("could not append CA cert from %q", tlsCACertFile)
}
// first cert is the root
roots := x509.NewCertPool()
roots.AddCert(certs[0])
// all other certs except the last are intermediates
intermediates := x509.NewCertPool()
for i := 1; i < len(certs)-1; i++ {
intermediates.AddCert(certs[i])
}
if len(tlsCACertEnvVar) > 0 {
// Append our cert to the system pool
if ok := newRootCAs.AppendCertsFromPEM(tlsCACertEnvVar); !ok {
return nil, fmt.Errorf("error appending cert from env var %q into system certs", tlsCACertEnvVar)
}
// last cert is the leaf
leaf := certs[len(certs)-1]
// verify the cert chain back to the root
verifyOpts := x509.VerifyOptions{
Roots: roots,
Intermediates: intermediates,
}
chains, err := leaf.Verify(verifyOpts)
if err != nil {
return nil, err
}
return newRootCAs, nil
return chains[0], nil
}
func getX509CertsFromString(certs string) ([]*x509.Certificate, error) {
// if the input is PEM handle that
if strings.HasPrefix(certs, "-----BEGIN") {
return getX509CertsFromPem([]byte(certs))
}
// assume input is base64 if not PEM
b64, err := base64.StdEncoding.DecodeString(certs)
if err != nil {
return nil, err
}
// handle if the decoded base64 contains PEM rather than the expected DER
if bytes.HasPrefix(b64, []byte("-----BEGIN")) {
return getX509CertsFromPem(b64)
}
// otherwise assume the contents are DER
return x509.ParseCertificates(b64)
}
func getX509CertsFromPem(pemBlocks []byte) ([]*x509.Certificate, error) {
var decodedCerts []byte
for len(pemBlocks) > 0 {
p, r := pem.Decode(pemBlocks)
if p != nil && p.Type != blockTypeCertificate {
return nil, fmt.Errorf("PEM block type is '%s', expected %s", p.Type, blockTypeCertificate)
}
if p == nil {
break
}
pemBlocks = r
decodedCerts = append(decodedCerts, p.Bytes...)
}
return x509.ParseCertificates(decodedCerts)
}
func getRSAPrivateKeyFromString(key string) (interface{}, error) {
// if the input is PEM handle that
if strings.HasPrefix(key, "-----BEGIN") {
return getRSAPrivateKeyFromPEM([]byte(key))
}
// assume input is base64 if not PEM
b64, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
return getRSAPrivateKeyFromPEM(b64)
}
func getRSAPrivateKeyFromPEM(pemBlocks []byte) (interface{}, error) {
// decode the pem into the Block struct
p, _ := pem.Decode(pemBlocks)
if p == nil {
return nil, fmt.Errorf("failed to parse PEM block containing the key")
}
// if the key is in PKCS1 format
if p.Type == blockTypeRSAPrivateKey {
return x509.ParsePKCS1PrivateKey(p.Bytes)
}
// if the key is in PKCS8 format
if p.Type == blockTypePrivateKey {
return x509.ParsePKCS8PrivateKey(p.Bytes)
}
// unsupported key format
return nil, fmt.Errorf("PEM block type is '%s', expected %s or %s", p.Type, blockTypeRSAPrivateKey,
blockTypePrivateKey)
}
// addCACertsFromFile adds CA certificates from filePath into the given pool.
// If pool is nil, it creates a new x509.CertPool. pool is returned.
func addCACertsFromFile(pool *x509.CertPool, filePath string) (*x509.CertPool, error) {
if pool == nil {
pool = x509.NewCertPool()
}
caCert, err := readCertFromFile(filePath)
if err != nil {
return nil, err
}
if ok := pool.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("could not append CA certificates from %q", filePath)
}
return pool, nil
}
// addCACertsFromBytes adds CA certificates from pemBytes into the given pool.
// If pool is nil, it creates a new x509.CertPool. pool is returned.
func addCACertsFromBytes(pool *x509.CertPool, pemBytes []byte) (*x509.CertPool, error) {
if pool == nil {
pool = x509.NewCertPool()
}
if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
return nil, fmt.Errorf("could not append certificates")
}
return pool, nil
}
// addCACertsFromBytes adds CA certificates from the environment variable named
// by envName into the given pool. If pool is nil, it creates a new x509.CertPool.
// pool is returned.
func addCACertsFromEnv(pool *x509.CertPool, envName string) (*x509.CertPool, error) {
pool, err := addCACertsFromBytes(pool, []byte(os.Getenv(envName)))
if err != nil {
return nil, fmt.Errorf("could not add CA certificates from envvar %q: %w", envName, err)
}
return pool, err
}
// ReadCertFromFile reads a cert from file

View File

@@ -7,6 +7,7 @@ package topdown
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
@@ -50,6 +51,16 @@ func builtinJSONUnmarshal(a ast.Value) (ast.Value, error) {
return ast.InterfaceToValue(x)
}
func builtinJSONIsValid(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return ast.Boolean(false), nil
}
return ast.Boolean(json.Valid([]byte(str))), nil
}
func builtinBase64Encode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -69,6 +80,16 @@ func builtinBase64Decode(a ast.Value) (ast.Value, error) {
return ast.String(result), err
}
func builtinBase64IsValid(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return ast.Boolean(false), nil
}
_, err = base64.StdEncoding.DecodeString(string(str))
return ast.Boolean(err == nil), nil
}
func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -78,6 +99,14 @@ func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
return ast.String(base64.URLEncoding.EncodeToString([]byte(str))), nil
}
func builtinBase64UrlEncodeNoPad(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(base64.RawURLEncoding.EncodeToString([]byte(str))), nil
}
func builtinBase64UrlDecode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -158,6 +187,29 @@ func builtinURLQueryEncodeObject(a ast.Value) (ast.Value, error) {
return ast.String(query.Encode()), nil
}
func builtinURLQueryDecodeObject(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
query, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
queryParams, err := url.ParseQuery(string(query))
if err != nil {
return err
}
queryObject := ast.NewObject()
for k, v := range queryParams {
paramsArray := make([]*ast.Term, len(v))
for i, param := range v {
paramsArray[i] = ast.StringTerm(param)
}
queryObject.Insert(ast.StringTerm(k), ast.ArrayTerm(paramsArray...))
}
return iter(ast.NewTerm(queryObject))
}
func builtinYAMLMarshal(a ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
@@ -202,16 +254,54 @@ func builtinYAMLUnmarshal(a ast.Value) (ast.Value, error) {
return ast.InterfaceToValue(val)
}
func builtinYAMLIsValid(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return ast.Boolean(false), nil
}
var x interface{}
err = ghodss.Unmarshal([]byte(str), &x)
return ast.Boolean(err == nil), nil
}
func builtinHexEncode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(hex.EncodeToString([]byte(str))), nil
}
func builtinHexDecode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
val, err := hex.DecodeString(string(str))
if err != nil {
return nil, err
}
return ast.String(val), nil
}
func init() {
RegisterFunctionalBuiltin1(ast.JSONMarshal.Name, builtinJSONMarshal)
RegisterFunctionalBuiltin1(ast.JSONUnmarshal.Name, builtinJSONUnmarshal)
RegisterFunctionalBuiltin1(ast.JSONIsValid.Name, builtinJSONIsValid)
RegisterFunctionalBuiltin1(ast.Base64Encode.Name, builtinBase64Encode)
RegisterFunctionalBuiltin1(ast.Base64Decode.Name, builtinBase64Decode)
RegisterFunctionalBuiltin1(ast.Base64IsValid.Name, builtinBase64IsValid)
RegisterFunctionalBuiltin1(ast.Base64UrlEncode.Name, builtinBase64UrlEncode)
RegisterFunctionalBuiltin1(ast.Base64UrlEncodeNoPad.Name, builtinBase64UrlEncodeNoPad)
RegisterFunctionalBuiltin1(ast.Base64UrlDecode.Name, builtinBase64UrlDecode)
RegisterFunctionalBuiltin1(ast.URLQueryDecode.Name, builtinURLQueryDecode)
RegisterFunctionalBuiltin1(ast.URLQueryEncode.Name, builtinURLQueryEncode)
RegisterFunctionalBuiltin1(ast.URLQueryEncodeObject.Name, builtinURLQueryEncodeObject)
RegisterBuiltinFunc(ast.URLQueryDecodeObject.Name, builtinURLQueryDecodeObject)
RegisterFunctionalBuiltin1(ast.YAMLMarshal.Name, builtinYAMLMarshal)
RegisterFunctionalBuiltin1(ast.YAMLUnmarshal.Name, builtinYAMLUnmarshal)
RegisterFunctionalBuiltin1(ast.YAMLIsValid.Name, builtinYAMLIsValid)
RegisterFunctionalBuiltin1(ast.HexEncode.Name, builtinHexEncode)
RegisterFunctionalBuiltin1(ast.HexDecode.Name, builtinHexDecode)
}

View File

@@ -5,11 +5,24 @@
package topdown
import (
"errors"
"fmt"
"github.com/open-policy-agent/opa/ast"
)
// Halt is a special error type that built-in function implementations return to indicate
// that policy evaluation should stop immediately.
type Halt struct {
Err error
}
func (h Halt) Error() string {
return h.Err.Error()
}
func (h Halt) Unwrap() error { return h.Err }
// Error is the error type returned by the Eval and Query functions when
// an evaluation error occurs.
type Error struct {
@@ -47,20 +60,27 @@ const (
// IsError returns true if the err is an Error.
func IsError(err error) bool {
_, ok := err.(*Error)
return ok
var e *Error
return errors.As(err, &e)
}
// IsCancel returns true if err was caused by cancellation.
func IsCancel(err error) bool {
if e, ok := err.(*Error); ok {
return e.Code == CancelErr
return errors.Is(err, &Error{Code: CancelErr})
}
// Is allows matching topdown errors using errors.Is (see IsCancel).
func (e *Error) Is(target error) bool {
var t *Error
if errors.As(target, &t) {
return (t.Code == "" || e.Code == t.Code) &&
(t.Message == "" || e.Message == t.Message) &&
(t.Location == nil || t.Location.Compare(e.Location) == 0)
}
return false
}
func (e *Error) Error() string {
msg := fmt.Sprintf("%v: %v", e.Code, e.Message)
if e.Location != nil {
@@ -94,14 +114,6 @@ func objectDocKeyConflictErr(loc *ast.Location) error {
}
}
func documentConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,
Location: loc,
Message: "base and virtual document keys must be disjoint",
}
}
func unsupportedBuiltinErr(loc *ast.Location) error {
return &Error{
Code: InternalErr,
@@ -117,3 +129,11 @@ func mergeConflictErr(loc *ast.Location) error {
Message: "real and replacement data could not be merged",
}
}
func internalErr(loc *ast.Location, msg string) error {
return &Error{
Code: InternalErr,
Location: loc,
Message: msg,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,7 @@
package topdown
import (
"fmt"
"strings"
"sync"
"github.com/gobwas/glob"
@@ -18,22 +18,34 @@ func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
if err != nil {
return nil, err
}
var delimiters []rune
switch b.(type) {
case ast.Null:
delimiters = []rune{}
case *ast.Array:
delimiters, err = builtins.RuneSliceOperand(b, 2)
if err != nil {
return nil, err
}
delimiters, err := builtins.RuneSliceOperand(b, 2)
if err != nil {
return nil, err
if len(delimiters) == 0 {
delimiters = []rune{'.'}
}
default:
return nil, builtins.NewOperandTypeErr(2, b, "array", "null")
}
if len(delimiters) == 0 {
delimiters = []rune{'.'}
}
match, err := builtins.StringOperand(c, 3)
if err != nil {
return nil, err
}
id := fmt.Sprintf("%s-%v", pattern, delimiters)
builder := strings.Builder{}
builder.WriteString(string(pattern))
builder.WriteRune('-')
for _, v := range delimiters {
builder.WriteRune(v)
}
id := builder.String()
globCacheLock.Lock()
defer globCacheLock.Unlock()
@@ -46,7 +58,8 @@ func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
globCache[id] = p
}
return ast.Boolean(p.Match(string(match))), nil
m := p.Match(string(match))
return ast.Boolean(m), nil
}
func builtinGlobQuoteMeta(a ast.Value) (ast.Value, error) {

View File

@@ -0,0 +1,462 @@
// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"encoding/json"
"fmt"
"strings"
gqlast "github.com/open-policy-agent/opa/internal/gqlparser/ast"
gqlparser "github.com/open-policy-agent/opa/internal/gqlparser/parser"
gqlvalidator "github.com/open-policy-agent/opa/internal/gqlparser/validator"
// Side-effecting import. Triggers GraphQL library's validation rule init() functions.
_ "github.com/open-policy-agent/opa/internal/gqlparser/validator/rules"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Parses a GraphQL schema, and returns the GraphQL AST for the schema.
func parseSchema(schema string) (*gqlast.SchemaDocument, error) {
// NOTE(philipc): We don't include the "built-in schema defs" from the
// underlying graphql parsing library here, because those definitions
// generate enormous AST blobs. In the future, if there is demand for
// a "full-spec" version of schema ASTs, we may need to provide a
// version of this function that includes the built-in schema
// definitions.
schemaAST, err := gqlparser.ParseSchema(&gqlast.Source{Input: schema})
if err != nil {
errorParts := strings.SplitN(err.Error(), ":", 4)
msg := strings.TrimLeft(errorParts[3], " ")
return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2])
}
return schemaAST, nil
}
// Parses a GraphQL query, and returns the GraphQL AST for the query.
func parseQuery(query string) (*gqlast.QueryDocument, error) {
queryAST, err := gqlparser.ParseQuery(&gqlast.Source{Input: query})
if err != nil {
errorParts := strings.SplitN(err.Error(), ":", 4)
msg := strings.TrimLeft(errorParts[3], " ")
return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2])
}
return queryAST, nil
}
// Validates a GraphQL query against a schema, and returns an error.
// In this case, we get a wrappered error list type, and pluck out
// just the first error message in the list.
func validateQuery(schema *gqlast.Schema, query *gqlast.QueryDocument) error {
// Validate the query against the schema, erroring if there's an issue.
err := gqlvalidator.Validate(schema, query)
if err != nil {
// We use strings.TrimSuffix to remove the '.' characters that the library
// authors include on most of their validation errors. This should be safe,
// since variable names in their error messages are usually quoted, and
// this affects only the last character(s) in the string.
// NOTE(philipc): We know the error location will be in the query string,
// because schema validation always happens before this function is called.
errorParts := strings.SplitN(err.Error(), ":", 4)
msg := strings.TrimSuffix(strings.TrimLeft(errorParts[3], " "), ".\n")
return fmt.Errorf("%s in GraphQL query string at location %s:%s", msg, errorParts[1], errorParts[2])
}
return nil
}
func getBuiltinSchema() *gqlast.SchemaDocument {
schema, err := gqlparser.ParseSchema(gqlvalidator.Prelude)
if err != nil {
panic(fmt.Errorf("Error in gqlparser Prelude (should be impossible): %w", err))
}
return schema
}
// NOTE(philipc): This function expects *validated* schema documents, and will break
// if it is fed arbitrary structures.
func mergeSchemaDocuments(docA *gqlast.SchemaDocument, docB *gqlast.SchemaDocument) *gqlast.SchemaDocument {
ast := &gqlast.SchemaDocument{}
ast.Merge(docA)
ast.Merge(docB)
return ast
}
// Converts a SchemaDocument into a gqlast.Schema object that can be used for validation.
// It merges in the builtin schema typedefs exactly as gqltop.LoadSchema did internally.
func convertSchema(schemaDoc *gqlast.SchemaDocument) (*gqlast.Schema, error) {
// Merge builtin schema + schema we were provided.
builtinsSchemaDoc := getBuiltinSchema()
mergedSchemaDoc := mergeSchemaDocuments(builtinsSchemaDoc, schemaDoc)
schema, err := gqlvalidator.ValidateSchemaDocument(mergedSchemaDoc)
if err != nil {
return nil, fmt.Errorf("Error in gqlparser SchemaDocument to Schema conversion: %w", err)
}
return schema, nil
}
// Converts an ast.Object into a gqlast.QueryDocument object.
func objectToQueryDocument(value ast.Object) (*gqlast.QueryDocument, error) {
// Convert ast.Term to interface{} for JSON encoding below.
asJSON, err := ast.JSON(value)
if err != nil {
return nil, err
}
// Marshal to JSON.
bs, err := json.Marshal(asJSON)
if err != nil {
return nil, err
}
// Unmarshal from JSON -> gqlast.QueryDocument.
var result gqlast.QueryDocument
err = json.Unmarshal(bs, &result)
if err != nil {
return nil, err
}
return &result, nil
}
// Converts an ast.Object into a gqlast.SchemaDocument object.
func objectToSchemaDocument(value ast.Object) (*gqlast.SchemaDocument, error) {
// Convert ast.Term to interface{} for JSON encoding below.
asJSON, err := ast.JSON(value)
if err != nil {
return nil, err
}
// Marshal to JSON.
bs, err := json.Marshal(asJSON)
if err != nil {
return nil, err
}
// Unmarshal from JSON -> gqlast.SchemaDocument.
var result gqlast.SchemaDocument
err = json.Unmarshal(bs, &result)
if err != nil {
return nil, err
}
return &result, nil
}
// Recursively traverses an AST that has been run through InterfaceToValue,
// and prunes away the fields with null or empty values, and all `Position`
// structs.
// NOTE(philipc): We currently prune away null values to reduce the level
// of clutter in the returned AST objects. In the future, if there is demand
// for ASTs that have a more regular/fixed structure, we may need to provide
// a "raw" version of the AST, where we still prune away the `Position`
// structs, but leave in the null fields.
func pruneIrrelevantGraphQLASTNodes(value ast.Value) ast.Value {
// We iterate over the Value we've been provided, and recurse down
// in the case of complex types, such as Arrays/Objects.
// We are guaranteed to only have to deal with standard JSON types,
// so this is much less ugly than what we'd need for supporting every
// extant ast type!
switch x := value.(type) {
case *ast.Array:
result := ast.NewArray()
// Iterate over the array's elements, and do the following:
// - Drop any Nulls
// - Drop any any empty object/array value (after running the pruner)
for i := 0; i < x.Len(); i++ {
vTerm := x.Elem(i)
switch v := vTerm.Value.(type) {
case ast.Null:
continue
case *ast.Array:
// Safe, because we knew the type before going to prune it.
va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array)
if va.Len() > 0 {
result = result.Append(ast.NewTerm(va))
}
case ast.Object:
// Safe, because we knew the type before going to prune it.
vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object)
if len(vo.Keys()) > 0 {
result = result.Append(ast.NewTerm(vo))
}
default:
result = result.Append(vTerm)
}
}
return result
case ast.Object:
result := ast.NewObject()
// Iterate over our object's keys, and do the following:
// - Drop "Position".
// - Drop any key with a Null value.
// - Drop any key with an empty object/array value (after running the pruner)
keys := x.Keys()
for _, k := range keys {
// We drop the "Position" objects because we don't need the
// source-backref/location info they provide for policy rules.
// Note that keys are ast.Strings.
if ast.String("Position").Equal(k.Value) {
continue
}
vTerm := x.Get(k)
switch v := vTerm.Value.(type) {
case ast.Null:
continue
case *ast.Array:
// Safe, because we knew the type before going to prune it.
va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array)
if va.Len() > 0 {
result.Insert(k, ast.NewTerm(va))
}
case ast.Object:
// Safe, because we knew the type before going to prune it.
vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object)
if len(vo.Keys()) > 0 {
result.Insert(k, ast.NewTerm(vo))
}
default:
result.Insert(k, vTerm)
}
}
return result
default:
return x
}
}
// Reports errors from parsing/validation.
func builtinGraphQLParse(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var queryDoc *gqlast.QueryDocument
var schemaDoc *gqlast.SchemaDocument
var err error
// Parse/translate query if it's a string/object.
switch x := operands[0].Value.(type) {
case ast.String:
queryDoc, err = parseQuery(string(x))
case ast.Object:
queryDoc, err = objectToQueryDocument(x)
default:
// Error if wrong type.
return builtins.NewOperandTypeErr(0, x, "string", "object")
}
if err != nil {
return err
}
// Parse/translate schema if it's a string/object.
switch x := operands[1].Value.(type) {
case ast.String:
schemaDoc, err = parseSchema(string(x))
case ast.Object:
schemaDoc, err = objectToSchemaDocument(x)
default:
// Error if wrong type.
return builtins.NewOperandTypeErr(1, x, "string", "object")
}
if err != nil {
return err
}
// Transform the ASTs into Objects.
queryASTValue, err := ast.InterfaceToValue(queryDoc)
if err != nil {
return err
}
schemaASTValue, err := ast.InterfaceToValue(schemaDoc)
if err != nil {
return err
}
// Validate the query against the schema, erroring if there's an issue.
schema, err := convertSchema(schemaDoc)
if err != nil {
return err
}
if err := validateQuery(schema, queryDoc); err != nil {
return err
}
// Recursively remove irrelevant AST structures.
queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object))
querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object))
// Construct return value.
verified := ast.ArrayTerm(
ast.NewTerm(queryResult),
ast.NewTerm(querySchema),
)
return iter(verified)
}
// Returns default value when errors occur.
func builtinGraphQLParseAndVerify(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var queryDoc *gqlast.QueryDocument
var schemaDoc *gqlast.SchemaDocument
var err error
unverified := ast.ArrayTerm(
ast.BooleanTerm(false),
ast.NewTerm(ast.NewObject()),
ast.NewTerm(ast.NewObject()),
)
// Parse/translate query if it's a string/object.
switch x := operands[0].Value.(type) {
case ast.String:
queryDoc, err = parseQuery(string(x))
case ast.Object:
queryDoc, err = objectToQueryDocument(x)
default:
// Error if wrong type.
return iter(unverified)
}
if err != nil {
return iter(unverified)
}
// Parse/translate schema if it's a string/object.
switch x := operands[1].Value.(type) {
case ast.String:
schemaDoc, err = parseSchema(string(x))
case ast.Object:
schemaDoc, err = objectToSchemaDocument(x)
default:
// Error if wrong type.
return iter(unverified)
}
if err != nil {
return iter(unverified)
}
// Transform the ASTs into Objects.
queryASTValue, err := ast.InterfaceToValue(queryDoc)
if err != nil {
return iter(unverified)
}
schemaASTValue, err := ast.InterfaceToValue(schemaDoc)
if err != nil {
return iter(unverified)
}
// Validate the query against the schema, erroring if there's an issue.
schema, err := convertSchema(schemaDoc)
if err != nil {
return iter(unverified)
}
if err := validateQuery(schema, queryDoc); err != nil {
return iter(unverified)
}
// Recursively remove irrelevant AST structures.
queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object))
querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object))
// Construct return value.
verified := ast.ArrayTerm(
ast.BooleanTerm(true),
ast.NewTerm(queryResult),
ast.NewTerm(querySchema),
)
return iter(verified)
}
func builtinGraphQLParseQuery(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
raw, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Get the highly-nested AST struct, along with any errors generated.
query, err := parseQuery(string(raw))
if err != nil {
return err
}
// Transform the AST into an Object.
value, err := ast.InterfaceToValue(query)
if err != nil {
return err
}
// Recursively remove irrelevant AST structures.
result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object))
return iter(ast.NewTerm(result))
}
func builtinGraphQLParseSchema(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
raw, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Get the highly-nested AST struct, along with any errors generated.
schema, err := parseSchema(string(raw))
if err != nil {
return err
}
// Transform the AST into an Object.
value, err := ast.InterfaceToValue(schema)
if err != nil {
return err
}
// Recursively remove irrelevant AST structures.
result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object))
return iter(ast.NewTerm(result))
}
func builtinGraphQLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var queryDoc *gqlast.QueryDocument
var schemaDoc *gqlast.SchemaDocument
var err error
switch x := operands[0].Value.(type) {
case ast.String:
queryDoc, err = parseQuery(string(x))
case ast.Object:
queryDoc, err = objectToQueryDocument(x)
default:
// Error if wrong type.
return iter(ast.BooleanTerm(false))
}
if err != nil {
return iter(ast.BooleanTerm(false))
}
switch x := operands[1].Value.(type) {
case ast.String:
schemaDoc, err = parseSchema(string(x))
case ast.Object:
schemaDoc, err = objectToSchemaDocument(x)
default:
// Error if wrong type.
return iter(ast.BooleanTerm(false))
}
if err != nil {
return iter(ast.BooleanTerm(false))
}
// Validate the query against the schema, erroring if there's an issue.
schema, err := convertSchema(schemaDoc)
if err != nil {
return iter(ast.BooleanTerm(false))
}
if err := validateQuery(schema, queryDoc); err != nil {
return iter(ast.BooleanTerm(false))
}
// If we got this far, the GraphQL query passed validation.
return iter(ast.BooleanTerm(true))
}
func init() {
RegisterBuiltinFunc(ast.GraphQLParse.Name, builtinGraphQLParse)
RegisterBuiltinFunc(ast.GraphQLParseAndVerify.Name, builtinGraphQLParseAndVerify)
RegisterBuiltinFunc(ast.GraphQLParseQuery.Name, builtinGraphQLParseQuery)
RegisterBuiltinFunc(ast.GraphQLParseSchema.Name, builtinGraphQLParseSchema)
RegisterBuiltinFunc(ast.GraphQLIsValid.Name, builtinGraphQLIsValid)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
//go:build !go1.18 || !darwin
// +build !go1.18 !darwin
package topdown
func fixupDarwinGo118(x string, _ string) string {
return x
}

View File

@@ -0,0 +1,13 @@
//go:build go1.18
// +build go1.18
package topdown
func fixupDarwinGo118(x, y string) string {
switch x {
case "x509: certificate signed by unknown authority":
return y
default:
return x
}
}

View File

@@ -10,18 +10,14 @@ import (
"github.com/open-policy-agent/opa/ast"
)
var errConflictingDoc = fmt.Errorf("conflicting documents")
var errBadPath = fmt.Errorf("bad document path")
func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, error) {
var result *ast.Term
var init bool
if exist != nil {
result = exist.Copy()
}
for _, pair := range pairs {
for i, pair := range pairs {
if err := ast.IsValidImportPath(pair[0].Value); err != nil {
return nil, errBadPath
@@ -29,33 +25,55 @@ func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, erro
target := pair[0].Value.(ast.Ref)
if len(target) == 1 {
result = pair[1]
} else if result == nil {
result = ast.NewTerm(makeTree(target[1:], pair[1]))
} else {
node := result
done := false
for i := 1; i < len(target)-1 && !done; i++ {
if child := node.Get(target[i]); child == nil {
obj, ok := node.Value.(ast.Object)
if !ok {
return nil, errConflictingDoc
}
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
done = true
} else {
node = child
}
}
if !done {
obj, ok := node.Value.(ast.Object)
if !ok {
return nil, errConflictingDoc
}
obj.Insert(target[len(target)-1], pair[1])
// Copy the value if subsequent pairs in the slice would modify it.
for j := i + 1; j < len(pairs); j++ {
other := pairs[j][0].Value.(ast.Ref)
if len(other) > len(target) && other.HasPrefix(target) {
pair[1] = pair[1].Copy()
break
}
}
if len(target) == 1 {
result = pair[1]
init = true
} else {
if !init {
result = exist.Copy()
init = true
}
if result == nil {
result = ast.NewTerm(makeTree(target[1:], pair[1]))
} else {
node := result
done := false
for i := 1; i < len(target)-1 && !done; i++ {
obj, ok := node.Value.(ast.Object)
if !ok {
result = ast.NewTerm(makeTree(target[i:], pair[1]))
done = true
continue
}
if child := obj.Get(target[i]); !isObject(child) {
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
done = true
} else { // child is object
node = child
}
}
if !done {
if obj, ok := node.Value.(ast.Object); ok {
obj.Insert(target[len(target)-1], pair[1])
} else {
result = ast.NewTerm(makeTree(target[len(target)-1:], pair[1]))
}
}
}
}
}
if !init {
result = exist
}
return result, nil
@@ -72,3 +90,11 @@ func makeTree(k ast.Ref, v *ast.Term) ast.Object {
obj = ast.NewObject(ast.Item(k[0], v))
return obj
}
func isObject(x *ast.Term) bool {
if x == nil {
return false
}
_, ok := x.Value.(ast.Object)
return ok
}

View File

@@ -7,18 +7,22 @@ package topdown
import "github.com/open-policy-agent/opa/metrics"
const (
evalOpPlug = "eval_op_plug"
evalOpResolve = "eval_op_resolve"
evalOpRuleIndex = "eval_op_rule_index"
evalOpBuiltinCall = "eval_op_builtin_call"
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
evalOpBaseCacheHit = "eval_op_base_cache_hit"
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
partialOpSaveUnify = "partial_op_save_unify"
partialOpSaveSetContains = "partial_op_save_set_contains"
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
partialOpCopyPropagation = "partial_op_copy_propagation"
evalOpPlug = "eval_op_plug"
evalOpResolve = "eval_op_resolve"
evalOpRuleIndex = "eval_op_rule_index"
evalOpBuiltinCall = "eval_op_builtin_call"
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
evalOpBaseCacheHit = "eval_op_base_cache_hit"
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
evalOpComprehensionCacheSkip = "eval_op_comprehension_cache_skip"
evalOpComprehensionCacheBuild = "eval_op_comprehension_cache_build"
evalOpComprehensionCacheHit = "eval_op_comprehension_cache_hit"
evalOpComprehensionCacheMiss = "eval_op_comprehension_cache_miss"
partialOpSaveUnify = "partial_op_save_unify"
partialOpSaveSetContains = "partial_op_save_set_contains"
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
partialOpCopyPropagation = "partial_op_copy_propagation"
)
// Instrumentation implements helper functions to instrument query evaluation

View File

@@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2015 lestrrat
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,113 +0,0 @@
// Package buffer provides a very thin wrapper around []byte buffer called
// `Buffer`, to provide functionalities that are often used within the jwx
// related packages
package buffer
import (
"encoding/base64"
"encoding/binary"
"encoding/json"
"github.com/pkg/errors"
)
// Buffer wraps `[]byte` and provides functions that are often used in
// the jwx related packages. One notable difference is that while
// encoding/json marshalls `[]byte` using base64.StdEncoding, this
// module uses base64.RawURLEncoding as mandated by the spec
type Buffer []byte
// FromUint creates a `Buffer` from an unsigned int
func FromUint(v uint64) Buffer {
data := make([]byte, 8)
binary.BigEndian.PutUint64(data, v)
i := 0
for ; i < len(data); i++ {
if data[i] != 0x0 {
break
}
}
return Buffer(data[i:])
}
// FromBase64 constructs a new Buffer from a base64 encoded data
func FromBase64(v []byte) (Buffer, error) {
b := Buffer{}
if err := b.Base64Decode(v); err != nil {
return Buffer(nil), errors.Wrap(err, "failed to decode from base64")
}
return b, nil
}
// FromNData constructs a new Buffer from a "n:data" format
// (I made that name up)
func FromNData(v []byte) (Buffer, error) {
size := binary.BigEndian.Uint32(v)
buf := make([]byte, int(size))
copy(buf, v[4:4+size])
return Buffer(buf), nil
}
// Bytes returns the raw bytes that comprises the Buffer
func (b Buffer) Bytes() []byte {
return []byte(b)
}
// NData returns Datalen || Data, where Datalen is a 32 bit counter for
// the length of the following data, and Data is the octets that comprise
// the buffer data
func (b Buffer) NData() []byte {
buf := make([]byte, 4+b.Len())
binary.BigEndian.PutUint32(buf, uint32(b.Len()))
copy(buf[4:], b.Bytes())
return buf
}
// Len returns the number of bytes that the Buffer holds
func (b Buffer) Len() int {
return len(b)
}
// Base64Encode encodes the contents of the Buffer using base64.RawURLEncoding
func (b Buffer) Base64Encode() ([]byte, error) {
enc := base64.RawURLEncoding
out := make([]byte, enc.EncodedLen(len(b)))
enc.Encode(out, b)
return out, nil
}
// Base64Decode decodes the contents of the Buffer using base64.RawURLEncoding
func (b *Buffer) Base64Decode(v []byte) error {
enc := base64.RawURLEncoding
out := make([]byte, enc.DecodedLen(len(v)))
n, err := enc.Decode(out, v)
if err != nil {
return errors.Wrap(err, "failed to decode from base64")
}
out = out[:n]
*b = Buffer(out)
return nil
}
// MarshalJSON marshals the buffer into JSON format after encoding the buffer
// with base64.RawURLEncoding
func (b Buffer) MarshalJSON() ([]byte, error) {
v, err := b.Base64Encode()
if err != nil {
return nil, errors.Wrap(err, "failed to encode to base64")
}
return json.Marshal(string(v))
}
// UnmarshalJSON unmarshals from a JSON string into a Buffer, after decoding it
// with base64.RawURLEncoding
func (b *Buffer) UnmarshalJSON(data []byte) error {
var x string
if err := json.Unmarshal(data, &x); err != nil {
return errors.Wrap(err, "failed to unmarshal JSON")
}
return b.Base64Decode([]byte(x))
}

View File

@@ -1,11 +0,0 @@
package jwa
// EllipticCurveAlgorithm represents the algorithms used for EC keys
type EllipticCurveAlgorithm string
// Supported values for EllipticCurveAlgorithm
const (
P256 EllipticCurveAlgorithm = "P-256"
P384 EllipticCurveAlgorithm = "P-384"
P521 EllipticCurveAlgorithm = "P-521"
)

View File

@@ -1,67 +0,0 @@
package jwa
import (
"strconv"
"github.com/pkg/errors"
)
// KeyType represents the key type ("kty") that are supported
type KeyType string
var keyTypeAlg = map[string]struct{}{"EC": {}, "oct": {}, "RSA": {}}
// Supported values for KeyType
const (
EC KeyType = "EC" // Elliptic Curve
InvalidKeyType KeyType = "" // Invalid KeyType
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
RSA KeyType = "RSA" // RSA
)
// Accept is used when conversion from values given by
// outside sources (such as JSON payloads) is required
func (keyType *KeyType) Accept(value interface{}) error {
var tmp KeyType
switch x := value.(type) {
case string:
tmp = KeyType(x)
case KeyType:
tmp = x
default:
return errors.Errorf(`invalid type for jwa.KeyType: %T`, value)
}
_, ok := keyTypeAlg[tmp.String()]
if !ok {
return errors.Errorf("Unknown Key Type algorithm")
}
*keyType = tmp
return nil
}
// String returns the string representation of a KeyType
func (keyType KeyType) String() string {
return string(keyType)
}
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
func (keyType *KeyType) UnmarshalJSON(data []byte) error {
var quote byte = '"'
var quoted string
if data[0] == quote {
var err error
quoted, err = strconv.Unquote(string(data))
if err != nil {
return errors.Wrap(err, "Failed to process signature algorithm")
}
} else {
quoted = string(data)
}
_, ok := keyTypeAlg[quoted]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*keyType = KeyType(quoted)
return nil
}

View File

@@ -1,29 +0,0 @@
package jwa
import (
"crypto/elliptic"
"github.com/open-policy-agent/opa/topdown/internal/jwx/buffer"
)
// EllipticCurve provides a indirect type to standard elliptic curve such that we can
// use it for unmarshal
type EllipticCurve struct {
elliptic.Curve
}
// AlgorithmParameters provides a single structure suitable to unmarshaling any JWK
type AlgorithmParameters struct {
N buffer.Buffer `json:"n,omitempty"`
E buffer.Buffer `json:"e,omitempty"`
D buffer.Buffer `json:"d,omitempty"`
P buffer.Buffer `json:"p,omitempty"`
Q buffer.Buffer `json:"q,omitempty"`
Dp buffer.Buffer `json:"dp,omitempty"`
Dq buffer.Buffer `json:"dq,omitempty"`
Qi buffer.Buffer `json:"qi,omitempty"`
Crv EllipticCurveAlgorithm `json:"crv,omitempty"`
X buffer.Buffer `json:"x,omitempty"`
Y buffer.Buffer `json:"y,omitempty"`
K buffer.Buffer `json:"k,omitempty"`
}

View File

@@ -1,76 +0,0 @@
package jwa
import (
"strconv"
"github.com/pkg/errors"
)
// SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1
type SignatureAlgorithm string
var signatureAlg = map[string]struct{}{"ES256": {}, "ES384": {}, "ES512": {}, "HS256": {}, "HS384": {}, "HS512": {}, "PS256": {}, "PS384": {}, "PS512": {}, "RS256": {}, "RS384": {}, "RS512": {}, "none": {}}
// Supported values for SignatureAlgorithm
const (
ES256 SignatureAlgorithm = "ES256" // ECDSA using P-256 and SHA-256
ES384 SignatureAlgorithm = "ES384" // ECDSA using P-384 and SHA-384
ES512 SignatureAlgorithm = "ES512" // ECDSA using P-521 and SHA-512
HS256 SignatureAlgorithm = "HS256" // HMAC using SHA-256
HS384 SignatureAlgorithm = "HS384" // HMAC using SHA-384
HS512 SignatureAlgorithm = "HS512" // HMAC using SHA-512
NoSignature SignatureAlgorithm = "none"
PS256 SignatureAlgorithm = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
PS384 SignatureAlgorithm = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
PS512 SignatureAlgorithm = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
RS256 SignatureAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
NoValue SignatureAlgorithm = "" // No value is different from none
)
// Accept is used when conversion from values given by
// outside sources (such as JSON payloads) is required
func (signature *SignatureAlgorithm) Accept(value interface{}) error {
var tmp SignatureAlgorithm
switch x := value.(type) {
case string:
tmp = SignatureAlgorithm(x)
case SignatureAlgorithm:
tmp = x
default:
return errors.Errorf(`invalid type for jwa.SignatureAlgorithm: %T`, value)
}
_, ok := signatureAlg[tmp.String()]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*signature = tmp
return nil
}
// String returns the string representation of a SignatureAlgorithm
func (signature SignatureAlgorithm) String() string {
return string(signature)
}
// UnmarshalJSON unmarshals and checks data as Signature Algorithm
func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
var quote byte = '"'
var quoted string
if data[0] == quote {
var err error
quoted, err = strconv.Unquote(string(data))
if err != nil {
return errors.Wrap(err, "Failed to process signature algorithm")
}
} else {
quoted = string(data)
}
_, ok := signatureAlg[quoted]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*signature = SignatureAlgorithm(quoted)
return nil
}

View File

@@ -1,120 +0,0 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newECDSAPublicKey(key *ecdsa.PublicKey) (*ECDSAPublicKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.EC)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &ECDSAPublicKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
func newECDSAPrivateKey(key *ecdsa.PrivateKey) (*ECDSAPrivateKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.EC)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &ECDSAPrivateKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the EC-DSA public key represented by this JWK
func (k ECDSAPublicKey) Materialize() (interface{}, error) {
return k.key, nil
}
// Materialize returns the EC-DSA private key represented by this JWK
func (k ECDSAPrivateKey) Materialize() (interface{}, error) {
return k.key, nil
}
// GenerateKey creates a ECDSAPublicKey from JWK format
func (k *ECDSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
var x, y big.Int
if keyJSON.X == nil || keyJSON.Y == nil || keyJSON.Crv == "" {
return errors.Errorf("Missing mandatory key parameters X, Y or Crv")
}
x.SetBytes(keyJSON.X.Bytes())
y.SetBytes(keyJSON.Y.Bytes())
var curve elliptic.Curve
switch keyJSON.Crv {
case jwa.P256:
curve = elliptic.P256()
case jwa.P384:
curve = elliptic.P384()
case jwa.P521:
curve = elliptic.P521()
default:
return errors.Errorf(`invalid curve name %s`, keyJSON.Crv)
}
*k = ECDSAPublicKey{
StandardHeaders: &keyJSON.StandardHeaders,
key: &ecdsa.PublicKey{
Curve: curve,
X: &x,
Y: &y,
},
}
return nil
}
// GenerateKey creates a ECDSAPrivateKey from JWK format
func (k *ECDSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
if keyJSON.D == nil {
return errors.Errorf("Missing mandatory key parameter D")
}
eCDSAPublicKey := &ECDSAPublicKey{}
err := eCDSAPublicKey.GenerateKey(keyJSON)
if err != nil {
return errors.Wrap(err, `failed to generate public key`)
}
dBytes := keyJSON.D.Bytes()
// The length of this octet string MUST be ceiling(log-base-2(n)/8)
// octets (where n is the order of the curve). This is because the private
// key d must be in the interval [1, n-1] so the bitlength of d should be
// no larger than the bitlength of n-1. The easiest way to find the octet
// length is to take bitlength(n-1), add 7 to force a carry, and shift this
// bit sequence right by 3, which is essentially dividing by 8 and adding
// 1 if there is any remainder. Thus, the private key value d should be
// output to (bitlength(n-1)+7)>>3 octets.
n := eCDSAPublicKey.key.Params().N
octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
if octetLength-len(dBytes) != 0 {
return errors.Errorf("Failed to generate private key. Incorrect D value")
}
privateKey := &ecdsa.PrivateKey{
PublicKey: *eCDSAPublicKey.key,
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
}
k.key = privateKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}

View File

@@ -1,178 +0,0 @@
package jwk
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Convenience constants for common JWK parameters
const (
AlgorithmKey = "alg"
KeyIDKey = "kid"
KeyOpsKey = "key_ops"
KeyTypeKey = "kty"
KeyUsageKey = "use"
PrivateParamsKey = "privateParams"
)
// Headers provides a common interface to all future possible headers
type Headers interface {
Get(string) (interface{}, bool)
Set(string, interface{}) error
Walk(func(string, interface{}) error) error
GetAlgorithm() jwa.SignatureAlgorithm
GetKeyID() string
GetKeyOps() KeyOperationList
GetKeyType() jwa.KeyType
GetKeyUsage() string
GetPrivateParams() map[string]interface{}
}
// StandardHeaders stores the common JWK parameters
type StandardHeaders struct {
Algorithm *jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.4
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
KeyOps KeyOperationList `json:"key_ops,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.3
KeyType jwa.KeyType `json:"kty,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.1
KeyUsage string `json:"use,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.2
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
}
// GetAlgorithm is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
if v := h.Algorithm; v != nil {
return *v
}
return jwa.NoValue
}
// GetKeyID is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyID() string {
return h.KeyID
}
// GetKeyOps is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyOps() KeyOperationList {
return h.KeyOps
}
// GetKeyType is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyType() jwa.KeyType {
return h.KeyType
}
// GetKeyUsage is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyUsage() string {
return h.KeyUsage
}
// GetPrivateParams is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetPrivateParams() map[string]interface{} {
return h.PrivateParams
}
// Get is a general getter function for JWK StandardHeaders structure
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
switch name {
case AlgorithmKey:
alg := h.GetAlgorithm()
if alg != jwa.NoValue {
return alg, true
}
return nil, false
case KeyIDKey:
v := h.KeyID
if v == "" {
return nil, false
}
return v, true
case KeyOpsKey:
v := h.KeyOps
if v == nil {
return nil, false
}
return v, true
case KeyTypeKey:
v := h.KeyType
if v == jwa.InvalidKeyType {
return nil, false
}
return v, true
case KeyUsageKey:
v := h.KeyUsage
if v == "" {
return nil, false
}
return v, true
case PrivateParamsKey:
v := h.PrivateParams
if len(v) == 0 {
return nil, false
}
return v, true
default:
return nil, false
}
}
// Set is a general getter function for JWK StandardHeaders structure
func (h *StandardHeaders) Set(name string, value interface{}) error {
switch name {
case AlgorithmKey:
var acceptor jwa.SignatureAlgorithm
if err := acceptor.Accept(value); err != nil {
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
}
h.Algorithm = &acceptor
return nil
case KeyIDKey:
if v, ok := value.(string); ok {
h.KeyID = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", KeyIDKey, value)
case KeyOpsKey:
if err := h.KeyOps.Accept(value); err != nil {
return errors.Wrapf(err, "invalid value for %s key", KeyOpsKey)
}
return nil
case KeyTypeKey:
if err := h.KeyType.Accept(value); err != nil {
return errors.Wrapf(err, "invalid value for %s key", KeyTypeKey)
}
return nil
case KeyUsageKey:
if v, ok := value.(string); ok {
h.KeyUsage = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", KeyUsageKey, value)
case PrivateParamsKey:
if v, ok := value.(map[string]interface{}); ok {
h.PrivateParams = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", PrivateParamsKey, value)
default:
return errors.Errorf(`invalid key: %s`, name)
}
}
// Walk iterates over all JWK standard headers fields while applying a function to its value.
func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
for _, key := range []string{AlgorithmKey, KeyIDKey, KeyOpsKey, KeyTypeKey, KeyUsageKey, PrivateParamsKey} {
if v, ok := h.Get(key); ok {
if err := f(key, v); err != nil {
return errors.Wrapf(err, `walk function returned error for %s`, key)
}
}
}
for k, v := range h.PrivateParams {
if err := f(k, v); err != nil {
return errors.Wrapf(err, `walk function returned error for %s`, k)
}
}
return nil
}

View File

@@ -1,70 +0,0 @@
package jwk
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Set is a convenience struct to allow generating and parsing
// JWK sets as opposed to single JWKs
type Set struct {
Keys []Key `json:"keys"`
}
// Key defines the minimal interface for each of the
// key types. Their use and implementation differ significantly
// between each key types, so you should use type assertions
// to perform more specific tasks with each key
type Key interface {
Headers
// Materialize creates the corresponding key. For example,
// RSA types would create *rsa.PublicKey or *rsa.PrivateKey,
// EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey,
// and OctetSeq types create a []byte key.
Materialize() (interface{}, error)
GenerateKey(*RawKeyJSON) error
}
// RawKeyJSON is generic type that represents any kind JWK
type RawKeyJSON struct {
StandardHeaders
jwa.AlgorithmParameters
}
// RawKeySetJSON is generic type that represents a JWK Set
type RawKeySetJSON struct {
Keys []RawKeyJSON `json:"keys"`
}
// RSAPublicKey is a type of JWK generated from RSA public keys
type RSAPublicKey struct {
*StandardHeaders
key *rsa.PublicKey
}
// RSAPrivateKey is a type of JWK generated from RSA private keys
type RSAPrivateKey struct {
*StandardHeaders
key *rsa.PrivateKey
}
// SymmetricKey is a type of JWK generated from symmetric keys
type SymmetricKey struct {
*StandardHeaders
key []byte
}
// ECDSAPublicKey is a type of JWK generated from ECDSA public keys
type ECDSAPublicKey struct {
*StandardHeaders
key *ecdsa.PublicKey
}
// ECDSAPrivateKey is a type of JWK generated from ECDH-ES private keys
type ECDSAPrivateKey struct {
*StandardHeaders
key *ecdsa.PrivateKey
}

View File

@@ -1,150 +0,0 @@
// Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
package jwk
import (
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// GetPublicKey returns the public key based on the private key type.
// For rsa key types *rsa.PublicKey is returned; for ecdsa key types *ecdsa.PublicKey;
// for byte slice (raw) keys, the key itself is returned. If the corresponding
// public key cannot be deduced, an error is returned
func GetPublicKey(key interface{}) (interface{}, error) {
if key == nil {
return nil, errors.New(`jwk.New requires a non-nil key`)
}
switch v := key.(type) {
// Mental note: although Public() is defined in both types,
// you can not coalesce the clauses for rsa.PrivateKey and
// ecdsa.PrivateKey, as then `v` becomes interface{}
// b/c the compiler cannot deduce the exact type.
case *rsa.PrivateKey:
return v.Public(), nil
case *ecdsa.PrivateKey:
return v.Public(), nil
case []byte:
return v, nil
default:
return nil, errors.Errorf(`invalid key type %T`, key)
}
}
// GetKeyTypeFromKey creates a jwk.Key from the given key.
func GetKeyTypeFromKey(key interface{}) jwa.KeyType {
switch key.(type) {
case *rsa.PrivateKey, *rsa.PublicKey:
return jwa.RSA
case *ecdsa.PrivateKey, *ecdsa.PublicKey:
return jwa.EC
case []byte:
return jwa.OctetSeq
default:
return jwa.InvalidKeyType
}
}
// New creates a jwk.Key from the given key.
func New(key interface{}) (Key, error) {
if key == nil {
return nil, errors.New(`jwk.New requires a non-nil key`)
}
switch v := key.(type) {
case *rsa.PrivateKey:
return newRSAPrivateKey(v)
case *rsa.PublicKey:
return newRSAPublicKey(v)
case *ecdsa.PrivateKey:
return newECDSAPrivateKey(v)
case *ecdsa.PublicKey:
return newECDSAPublicKey(v)
case []byte:
return newSymmetricKey(v)
default:
return nil, errors.Errorf(`invalid key type %T`, key)
}
}
func parse(jwkSrc string) (*Set, error) {
var jwkKeySet Set
var jwkKey Key
rawKeySetJSON := &RawKeySetJSON{}
err := json.Unmarshal([]byte(jwkSrc), rawKeySetJSON)
if err != nil {
return nil, errors.Wrap(err, "Failed to unmarshal JWK Set")
}
if len(rawKeySetJSON.Keys) == 0 {
// It might be a single key
rawKeyJSON := &RawKeyJSON{}
err := json.Unmarshal([]byte(jwkSrc), rawKeyJSON)
if err != nil {
return nil, errors.Wrap(err, "Failed to unmarshal JWK")
}
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key")
}
// Add to set
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
} else {
for i := range rawKeySetJSON.Keys {
rawKeyJSON := rawKeySetJSON.Keys[i]
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key: %s")
}
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
}
}
return &jwkKeySet, nil
}
// ParseBytes parses JWK from the incoming byte buffer.
func ParseBytes(buf []byte) (*Set, error) {
return parse(string(buf[:]))
}
// ParseString parses JWK from the incoming string.
func ParseString(s string) (*Set, error) {
return parse(s)
}
// GenerateKey creates an internal representation of a key from a raw JWK JSON
func (r *RawKeyJSON) GenerateKey() (Key, error) {
var key Key
switch r.KeyType {
case jwa.RSA:
if r.D != nil {
key = &RSAPrivateKey{}
} else {
key = &RSAPublicKey{}
}
case jwa.EC:
if r.D != nil {
key = &ECDSAPrivateKey{}
} else {
key = &ECDSAPublicKey{}
}
case jwa.OctetSeq:
key = &SymmetricKey{}
default:
return nil, errors.Errorf(`Unrecognized key type`)
}
err := key.GenerateKey(r)
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key from JWK")
}
return key, nil
}

View File

@@ -1,68 +0,0 @@
package jwk
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
// KeyUsageType is used to denote what this key should be used for
type KeyUsageType string
const (
// ForSignature is the value used in the headers to indicate that
// this key should be used for signatures
ForSignature KeyUsageType = "sig"
// ForEncryption is the value used in the headers to indicate that
// this key should be used for encryptiong
ForEncryption KeyUsageType = "enc"
)
// KeyOperation is used to denote the allowed operations for a Key
type KeyOperation string
// KeyOperationList represents an slice of KeyOperation
type KeyOperationList []KeyOperation
var keyOps = map[string]struct{}{"sign": {}, "verify": {}, "encrypt": {}, "decrypt": {}, "wrapKey": {}, "unwrapKey": {}, "deriveKey": {}, "deriveBits": {}}
// KeyOperation constants
const (
KeyOpSign KeyOperation = "sign" // (compute digital signature or MAC)
KeyOpVerify = "verify" // (verify digital signature or MAC)
KeyOpEncrypt = "encrypt" // (encrypt content)
KeyOpDecrypt = "decrypt" // (decrypt content and validate decryption, if applicable)
KeyOpWrapKey = "wrapKey" // (encrypt key)
KeyOpUnwrapKey = "unwrapKey" // (decrypt key and validate decryption, if applicable)
KeyOpDeriveKey = "deriveKey" // (derive key)
KeyOpDeriveBits = "deriveBits" // (derive bits not to be used as a key)
)
// Accept determines if Key Operation is valid
func (keyOperationList *KeyOperationList) Accept(v interface{}) error {
switch x := v.(type) {
case KeyOperationList:
*keyOperationList = x
return nil
default:
return errors.Errorf(`invalid value %T`, v)
}
}
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
func (keyOperationList *KeyOperationList) UnmarshalJSON(data []byte) error {
var tempKeyOperationList []string
err := json.Unmarshal(data, &tempKeyOperationList)
if err != nil {
return fmt.Errorf("invalid key operation")
}
for _, value := range tempKeyOperationList {
_, ok := keyOps[value]
if !ok {
return fmt.Errorf("unknown key operation")
}
*keyOperationList = append(*keyOperationList, KeyOperation(value))
}
return nil
}

View File

@@ -1,103 +0,0 @@
package jwk
import (
"crypto/rsa"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newRSAPublicKey(key *rsa.PublicKey) (*RSAPublicKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.RSA)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &RSAPublicKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
func newRSAPrivateKey(key *rsa.PrivateKey) (*RSAPrivateKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.RSA)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &RSAPrivateKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the standard RSA Public Key representation stored in the internal representation
func (k *RSAPublicKey) Materialize() (interface{}, error) {
if k.key == nil {
return nil, errors.New(`key has no rsa.PublicKey associated with it`)
}
return k.key, nil
}
// Materialize returns the standard RSA Private Key representation stored in the internal representation
func (k *RSAPrivateKey) Materialize() (interface{}, error) {
if k.key == nil {
return nil, errors.New(`key has no rsa.PrivateKey associated with it`)
}
return k.key, nil
}
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
func (k *RSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
if keyJSON.N == nil || keyJSON.E == nil {
return errors.Errorf("Missing mandatory key parameters N or E")
}
rsaPublicKey := &rsa.PublicKey{
N: (&big.Int{}).SetBytes(keyJSON.N.Bytes()),
E: int((&big.Int{}).SetBytes(keyJSON.E.Bytes()).Int64()),
}
k.key = rsaPublicKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
func (k *RSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
rsaPublicKey := &RSAPublicKey{}
err := rsaPublicKey.GenerateKey(keyJSON)
if err != nil {
return errors.Wrap(err, "failed to generate public key")
}
if keyJSON.D == nil || keyJSON.P == nil || keyJSON.Q == nil {
return errors.Errorf("Missing mandatory key parameters D, P or Q")
}
privateKey := &rsa.PrivateKey{
PublicKey: *rsaPublicKey.key,
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
Primes: []*big.Int{
(&big.Int{}).SetBytes(keyJSON.P.Bytes()),
(&big.Int{}).SetBytes(keyJSON.Q.Bytes()),
},
}
if keyJSON.Dp.Len() > 0 {
privateKey.Precomputed.Dp = (&big.Int{}).SetBytes(keyJSON.Dp.Bytes())
}
if keyJSON.Dq.Len() > 0 {
privateKey.Precomputed.Dq = (&big.Int{}).SetBytes(keyJSON.Dq.Bytes())
}
if keyJSON.Qi.Len() > 0 {
privateKey.Precomputed.Qinv = (&big.Int{}).SetBytes(keyJSON.Qi.Bytes())
}
k.key = privateKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}

View File

@@ -1,41 +0,0 @@
package jwk
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newSymmetricKey(key []byte) (*SymmetricKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.OctetSeq)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &SymmetricKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the octets for this symmetric key.
// Since this is a symmetric key, this just calls Octets
func (s SymmetricKey) Materialize() (interface{}, error) {
return s.Octets(), nil
}
// Octets returns the octets in the key
func (s SymmetricKey) Octets() []byte {
return s.key
}
// GenerateKey creates a Symmetric key from a RawKeyJSON
func (s *SymmetricKey) GenerateKey(keyJSON *RawKeyJSON) error {
*s = SymmetricKey{
StandardHeaders: &keyJSON.StandardHeaders,
key: keyJSON.K,
}
return nil
}

View File

@@ -1,154 +0,0 @@
package jws
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Constants for JWS Common parameters
const (
AlgorithmKey = "alg"
ContentTypeKey = "cty"
CriticalKey = "crit"
JWKKey = "jwk"
JWKSetURLKey = "jku"
KeyIDKey = "kid"
PrivateParamsKey = "privateParams"
TypeKey = "typ"
)
// Headers provides a common interface for common header parameters
type Headers interface {
Get(string) (interface{}, bool)
Set(string, interface{}) error
GetAlgorithm() jwa.SignatureAlgorithm
}
// StandardHeaders contains JWS common parameters.
type StandardHeaders struct {
Algorithm jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.1
ContentType string `json:"cty,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.10
Critical []string `json:"crit,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.11
JWK string `json:"jwk,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.3
JWKSetURL string `json:"jku,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.2
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
Type string `json:"typ,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
}
// GetAlgorithm returns algorithm
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
return h.Algorithm
}
// Get is a general getter function for StandardHeaders structure
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
switch name {
case AlgorithmKey:
v := h.Algorithm
if v == "" {
return nil, false
}
return v, true
case ContentTypeKey:
v := h.ContentType
if v == "" {
return nil, false
}
return v, true
case CriticalKey:
v := h.Critical
if len(v) == 0 {
return nil, false
}
return v, true
case JWKKey:
v := h.JWK
if v == "" {
return nil, false
}
return v, true
case JWKSetURLKey:
v := h.JWKSetURL
if v == "" {
return nil, false
}
return v, true
case KeyIDKey:
v := h.KeyID
if v == "" {
return nil, false
}
return v, true
case PrivateParamsKey:
v := h.PrivateParams
if len(v) == 0 {
return nil, false
}
return v, true
case TypeKey:
v := h.Type
if v == "" {
return nil, false
}
return v, true
default:
return nil, false
}
}
// Set is a general setter function for StandardHeaders structure
func (h *StandardHeaders) Set(name string, value interface{}) error {
switch name {
case AlgorithmKey:
if err := h.Algorithm.Accept(value); err != nil {
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
}
return nil
case ContentTypeKey:
if v, ok := value.(string); ok {
h.ContentType = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
case CriticalKey:
if v, ok := value.([]string); ok {
h.Critical = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
case JWKKey:
if v, ok := value.(string); ok {
h.JWK = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
case JWKSetURLKey:
if v, ok := value.(string); ok {
h.JWKSetURL = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
case KeyIDKey:
if v, ok := value.(string); ok {
h.KeyID = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
case PrivateParamsKey:
if v, ok := value.(map[string]interface{}); ok {
h.PrivateParams = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, PrivateParamsKey, value)
case TypeKey:
if v, ok := value.(string); ok {
h.Type = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
default:
return errors.Errorf(`invalid key: %s`, name)
}
}

View File

@@ -1,22 +0,0 @@
package jws
// Message represents a full JWS encoded message. Flattened serialization
// is not supported as a struct, but rather it's represented as a
// Message struct with only one `Signature` element.
//
// Do not expect to use the Message object to verify or construct a
// signed payloads with. You should only use this when you want to actually
// want to programmatically view the contents for the full JWS Payload.
//
// To sign and verify, use the appropriate `SignWithOption()` nad `Verify()` functions
type Message struct {
Payload []byte `json:"payload"`
Signatures []*Signature `json:"signatures,omitempty"`
}
// Signature represents the headers and signature of a JWS message
type Signature struct {
Headers Headers `json:"header,omitempty"` // Unprotected Headers
Protected Headers `json:"Protected,omitempty"` // Protected Headers
Signature []byte `json:"signature,omitempty"` // GetSignature
}

View File

@@ -1,210 +0,0 @@
// Package jws implements the digital Signature on JSON based data
// structures as described in https://tools.ietf.org/html/rfc7515
//
// If you do not care about the details, the only things that you
// would need to use are the following functions:
//
// jws.SignWithOption(Payload, algorithm, key)
// jws.Verify(encodedjws, algorithm, key)
//
// To sign, simply use `jws.SignWithOption`. `Payload` is a []byte buffer that
// contains whatever data you want to sign. `alg` is one of the
// jwa.SignatureAlgorithm constants from package jwa. For RSA and
// ECDSA family of algorithms, you will need to prepare a private key.
// For HMAC family, you just need a []byte value. The `jws.SignWithOption`
// function will return the encoded JWS message on success.
//
// To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
// and verify the result using `algorithm` and `key`. Upon successful
// verification, the original Payload is returned, so you can work on it.
package jws
import (
"bytes"
"encoding/base64"
"encoding/json"
"strings"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify"
"github.com/pkg/errors"
)
// SignLiteral generates a Signature for the given Payload and Headers, and serializes
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signingInput := strings.Join(
[]string{
encodedHdr,
encodedPayload,
}, ".",
)
signer, err := sign.New(alg)
if err != nil {
return nil, errors.Wrap(err, `failed to create signer`)
}
signature, err := signer.Sign([]byte(signingInput), key)
if err != nil {
return nil, errors.Wrap(err, `failed to sign Payload`)
}
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
compactSerialization := strings.Join(
[]string{
signingInput,
encodedSignature,
}, ".",
)
return []byte(compactSerialization), nil
}
// SignWithOption generates a Signature for the given Payload, and serializes
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
// If you would like to pass custom Headers, use the WithHeaders option.
func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, error) {
var headers Headers = &StandardHeaders{}
err := headers.Set(AlgorithmKey, alg)
if err != nil {
return nil, errors.Wrap(err, "Failed to set alg value")
}
hdrBuf, err := json.Marshal(headers)
if err != nil {
return nil, errors.Wrap(err, `failed to marshal Headers`)
}
return SignLiteral(payload, alg, key, hdrBuf)
}
// Verify checks if the given JWS message is verifiable using `alg` and `key`.
// If the verification is successful, `err` is nil, and the content of the
// Payload that was signed is returned. If you need more fine-grained
// control of the verification process, manually call `Parse`, generate a
// verifier, and call `Verify` on the parsed JWS message object.
func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}) (ret []byte, err error) {
verifier, err := verify.New(alg)
if err != nil {
return nil, errors.Wrap(err, "failed to create verifier")
}
buf = bytes.TrimSpace(buf)
if len(buf) == 0 {
return nil, errors.New(`attempt to verify empty buffer`)
}
parts, err := SplitCompact(string(buf[:]))
if err != nil {
return nil, errors.Wrap(err, `failed extract from compact serialization format`)
}
signingInput := strings.Join(
[]string{
parts[0],
parts[1],
}, ".",
)
decodedSignature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, errors.Wrap(err, "Failed to decode signature")
}
if err := verifier.Verify([]byte(signingInput), decodedSignature, key); err != nil {
return nil, errors.Wrap(err, "Failed to verify message")
}
if decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
return decodedPayload, nil
}
return nil, errors.Wrap(err, "Failed to decode Payload")
}
// VerifyWithJWK verifies the JWS message using the specified JWK
func VerifyWithJWK(buf []byte, key jwk.Key) (payload []byte, err error) {
keyVal, err := key.Materialize()
if err != nil {
return nil, errors.Wrap(err, "Failed to materialize key")
}
return Verify(buf, key.GetAlgorithm(), keyVal)
}
// VerifyWithJWKSet verifies the JWS message using JWK key set.
// By default it will only pick up keys that have the "use" key
// set to either "sig" or "enc", but you can override it by
// providing a keyaccept function.
func VerifyWithJWKSet(buf []byte, keyset *jwk.Set) (payload []byte, err error) {
for _, key := range keyset.Keys {
payload, err := VerifyWithJWK(buf, key)
if err == nil {
return payload, nil
}
}
return nil, errors.New("failed to verify with any of the keys")
}
// ParseByte parses a JWS value serialized via compact serialization and provided as []byte.
func ParseByte(jwsCompact []byte) (m *Message, err error) {
return parseCompact(string(jwsCompact[:]))
}
// ParseString parses a JWS value serialized via compact serialization and provided as string.
func ParseString(s string) (*Message, error) {
return parseCompact(s)
}
// SplitCompact splits a JWT and returns its three parts
// separately: Protected Headers, Payload and Signature.
func SplitCompact(jwsCompact string) ([]string, error) {
parts := strings.Split(jwsCompact, ".")
if len(parts) < 3 {
return nil, errors.New("Failed to split compact serialization")
}
return parts, nil
}
// parseCompact parses a JWS value serialized via compact serialization.
func parseCompact(str string) (m *Message, err error) {
var decodedHeader, decodedPayload, decodedSignature []byte
parts, err := SplitCompact(str)
if err != nil {
return nil, errors.Wrap(err, `invalid compact serialization format`)
}
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
return nil, errors.Wrap(err, `failed to decode Headers`)
}
var hdr StandardHeaders
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
return nil, errors.Wrap(err, `failed to parse JOSE Headers`)
}
if decodedPayload, err = base64.RawURLEncoding.DecodeString(parts[1]); err != nil {
return nil, errors.Wrap(err, `failed to decode Payload`)
}
if len(parts) > 2 {
if decodedSignature, err = base64.RawURLEncoding.DecodeString(parts[2]); err != nil {
return nil, errors.Wrap(err, `failed to decode Signature`)
}
}
var msg Message
msg.Payload = decodedPayload
msg.Signatures = append(msg.Signatures, &Signature{
Protected: &hdr,
Signature: decodedSignature,
})
return &msg, nil
}

View File

@@ -1,26 +0,0 @@
package jws
// PublicHeaders returns the public headers in a JWS
func (s Signature) PublicHeaders() Headers {
return s.Headers
}
// ProtectedHeaders returns the protected headers in a JWS
func (s Signature) ProtectedHeaders() Headers {
return s.Protected
}
// GetSignature returns the signature in a JWS
func (s Signature) GetSignature() []byte {
return s.Signature
}
// GetPayload returns the payload in a JWS
func (m Message) GetPayload() []byte {
return m.Payload
}
// GetSignatures returns the all signatures in a JWS
func (m Message) GetSignatures() []*Signature {
return m.Signatures
}

View File

@@ -1,84 +0,0 @@
package sign
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var ecdsaSignFuncs = map[jwa.SignatureAlgorithm]ecdsaSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
jwa.ES256: crypto.SHA256,
jwa.ES384: crypto.SHA384,
jwa.ES512: crypto.SHA512,
}
for alg, h := range algs {
ecdsaSignFuncs[alg] = makeECDSASignFunc(h)
}
}
func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
// Curve bits do not need to be a multiple of 8.
if curveBits%8 > 0 {
keyBytes++
}
h := hash.New()
h.Write(payload)
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
if err != nil {
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return out, nil
})
}
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSASigner, error) {
signfn, ok := ecdsaSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA signer: %s`, alg)
}
return &ECDSASigner{
alg: alg,
sign: signfn,
}, nil
}
// Algorithm returns the signer algorithm
func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
privateKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
}
return s.sign(payload, privateKey)
}

View File

@@ -1,66 +0,0 @@
package sign
import (
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"hash"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var hmacSignFuncs = map[jwa.SignatureAlgorithm]hmacSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]func() hash.Hash{
jwa.HS256: sha256.New,
jwa.HS384: sha512.New384,
jwa.HS512: sha512.New,
}
for alg, h := range algs {
hmacSignFuncs[alg] = makeHMACSignFunc(h)
}
}
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACSigner, error) {
signer, ok := hmacSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create HMAC signer: %s`, alg)
}
return &HMACSigner{
alg: alg,
sign: signer,
}, nil
}
func makeHMACSignFunc(hfunc func() hash.Hash) hmacSignFunc {
return hmacSignFunc(func(payload []byte, key []byte) ([]byte, error) {
h := hmac.New(hfunc, key)
h.Write(payload)
return h.Sum(nil), nil
})
}
// Algorithm returns the signer algorithm
func (s HMACSigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign signs payload with a Symmetric key
func (s HMACSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
hmackey, ok := key.([]byte)
if !ok {
return nil, errors.Errorf(`invalid key type %T. []byte is required`, key)
}
if len(hmackey) == 0 {
return nil, errors.New(`missing key while signing payload`)
}
return s.sign(payload, hmackey)
}

View File

@@ -1,45 +0,0 @@
package sign
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Signer provides a common interface for supported alg signing methods
type Signer interface {
// Sign creates a signature for the given `payload`.
// `key` is the key used for signing the payload, and is usually
// the private key type associated with the signature method. For example,
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
// `*"crypto/rsa".PrivateKey` type.
// Check the documentation for each signer for details
Sign(payload []byte, key interface{}) ([]byte, error)
Algorithm() jwa.SignatureAlgorithm
}
type rsaSignFunc func([]byte, *rsa.PrivateKey) ([]byte, error)
// RSASigner uses crypto/rsa to sign the payloads.
type RSASigner struct {
alg jwa.SignatureAlgorithm
sign rsaSignFunc
}
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
// ECDSASigner uses crypto/ecdsa to sign the payloads.
type ECDSASigner struct {
alg jwa.SignatureAlgorithm
sign ecdsaSignFunc
}
type hmacSignFunc func([]byte, []byte) ([]byte, error)
// HMACSigner uses crypto/hmac to sign the payloads.
type HMACSigner struct {
alg jwa.SignatureAlgorithm
sign hmacSignFunc
}

View File

@@ -1,97 +0,0 @@
package sign
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var rsaSignFuncs = map[jwa.SignatureAlgorithm]rsaSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]struct {
Hash crypto.Hash
SignFunc func(crypto.Hash) rsaSignFunc
}{
jwa.RS256: {
Hash: crypto.SHA256,
SignFunc: makeSignPKCS1v15,
},
jwa.RS384: {
Hash: crypto.SHA384,
SignFunc: makeSignPKCS1v15,
},
jwa.RS512: {
Hash: crypto.SHA512,
SignFunc: makeSignPKCS1v15,
},
jwa.PS256: {
Hash: crypto.SHA256,
SignFunc: makeSignPSS,
},
jwa.PS384: {
Hash: crypto.SHA384,
SignFunc: makeSignPSS,
},
jwa.PS512: {
Hash: crypto.SHA512,
SignFunc: makeSignPSS,
},
}
for alg, item := range algs {
rsaSignFuncs[alg] = item.SignFunc(item.Hash)
}
}
func makeSignPKCS1v15(hash crypto.Hash) rsaSignFunc {
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
h := hash.New()
h.Write(payload)
return rsa.SignPKCS1v15(rand.Reader, key, hash, h.Sum(nil))
})
}
func makeSignPSS(hash crypto.Hash) rsaSignFunc {
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
h := hash.New()
h.Write(payload)
return rsa.SignPSS(rand.Reader, key, hash, h.Sum(nil), &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
})
})
}
func newRSA(alg jwa.SignatureAlgorithm) (*RSASigner, error) {
signfn, ok := rsaSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA signer: %s`, alg)
}
return &RSASigner{
alg: alg,
sign: signfn,
}, nil
}
// Algorithm returns the signer algorithm
func (s RSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign creates a signature using crypto/rsa. key must be a non-nil instance of
// `*"crypto/rsa".PrivateKey`.
func (s RSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
rsakey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, errors.Errorf(`invalid key type %T. *rsa.PrivateKey is required`, key)
}
return s.sign(payload, rsakey)
}

View File

@@ -1,21 +0,0 @@
package sign
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// New creates a signer that signs payloads using the given signature algorithm.
func New(alg jwa.SignatureAlgorithm) (Signer, error) {
switch alg {
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
return newRSA(alg)
case jwa.ES256, jwa.ES384, jwa.ES512:
return newECDSA(alg)
case jwa.HS256, jwa.HS384, jwa.HS512:
return newHMAC(alg)
default:
return nil, errors.Errorf(`unsupported signature algorithm %s`, alg)
}
}

View File

@@ -1,67 +0,0 @@
package verify
import (
"crypto"
"crypto/ecdsa"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
var ecdsaVerifyFuncs = map[jwa.SignatureAlgorithm]ecdsaVerifyFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
jwa.ES256: crypto.SHA256,
jwa.ES384: crypto.SHA384,
jwa.ES512: crypto.SHA512,
}
for alg, h := range algs {
ecdsaVerifyFuncs[alg] = makeECDSAVerifyFunc(h)
}
}
func makeECDSAVerifyFunc(hash crypto.Hash) ecdsaVerifyFunc {
return ecdsaVerifyFunc(func(payload []byte, signature []byte, key *ecdsa.PublicKey) error {
r, s := &big.Int{}, &big.Int{}
n := len(signature) / 2
r.SetBytes(signature[:n])
s.SetBytes(signature[n:])
h := hash.New()
h.Write(payload)
if !ecdsa.Verify(key, h.Sum(nil), r, s) {
return errors.New(`failed to verify signature using ecdsa`)
}
return nil
})
}
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSAVerifier, error) {
verifyfn, ok := ecdsaVerifyFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA verifier: %s`, alg)
}
return &ECDSAVerifier{
verify: verifyfn,
}, nil
}
// Verify checks whether the signature for a given input and key is correct
func (v ECDSAVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
if key == nil {
return errors.New(`missing public key while verifying payload`)
}
ecdsakey, ok := key.(*ecdsa.PublicKey)
if !ok {
return errors.Errorf(`invalid key type %T. *ecdsa.PublicKey is required`, key)
}
return v.verify(payload, signature, ecdsakey)
}

View File

@@ -1,33 +0,0 @@
package verify
import (
"crypto/hmac"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
)
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACVerifier, error) {
s, err := sign.New(alg)
if err != nil {
return nil, errors.Wrap(err, `failed to generate HMAC signer`)
}
return &HMACVerifier{signer: s}, nil
}
// Verify checks whether the signature for a given input and key is correct
func (v HMACVerifier) Verify(signingInput, signature []byte, key interface{}) (err error) {
expected, err := v.signer.Sign(signingInput, key)
if err != nil {
return errors.Wrap(err, `failed to generated signature`)
}
if !hmac.Equal(signature, expected) {
return errors.New(`failed to match hmac signature`)
}
return nil
}

View File

@@ -1,39 +0,0 @@
package verify
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
)
// Verifier provides a common interface for supported alg verification methods
type Verifier interface {
// Verify checks whether the payload and signature are valid for
// the given key.
// `key` is the key used for verifying the payload, and is usually
// the public key associated with the signature method. For example,
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
// `*"crypto/rsa".PublicKey` type.
// Check the documentation for each verifier for details
Verify(payload []byte, signature []byte, key interface{}) error
}
type rsaVerifyFunc func([]byte, []byte, *rsa.PublicKey) error
// RSAVerifier implements the Verifier interface
type RSAVerifier struct {
verify rsaVerifyFunc
}
type ecdsaVerifyFunc func([]byte, []byte, *ecdsa.PublicKey) error
// ECDSAVerifier implements the Verifier interface
type ECDSAVerifier struct {
verify ecdsaVerifyFunc
}
// HMACVerifier implements the Verifier interface
type HMACVerifier struct {
signer sign.Signer
}

View File

@@ -1,88 +0,0 @@
package verify
import (
"crypto"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var rsaVerifyFuncs = map[jwa.SignatureAlgorithm]rsaVerifyFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]struct {
Hash crypto.Hash
VerifyFunc func(crypto.Hash) rsaVerifyFunc
}{
jwa.RS256: {
Hash: crypto.SHA256,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.RS384: {
Hash: crypto.SHA384,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.RS512: {
Hash: crypto.SHA512,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.PS256: {
Hash: crypto.SHA256,
VerifyFunc: makeVerifyPSS,
},
jwa.PS384: {
Hash: crypto.SHA384,
VerifyFunc: makeVerifyPSS,
},
jwa.PS512: {
Hash: crypto.SHA512,
VerifyFunc: makeVerifyPSS,
},
}
for alg, item := range algs {
rsaVerifyFuncs[alg] = item.VerifyFunc(item.Hash)
}
}
func makeVerifyPKCS1v15(hash crypto.Hash) rsaVerifyFunc {
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
h := hash.New()
h.Write(payload)
return rsa.VerifyPKCS1v15(key, hash, h.Sum(nil), signature)
})
}
func makeVerifyPSS(hash crypto.Hash) rsaVerifyFunc {
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
h := hash.New()
h.Write(payload)
return rsa.VerifyPSS(key, hash, h.Sum(nil), signature, nil)
})
}
func newRSA(alg jwa.SignatureAlgorithm) (*RSAVerifier, error) {
verifyfn, ok := rsaVerifyFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA verifier: %s`, alg)
}
return &RSAVerifier{
verify: verifyfn,
}, nil
}
// Verify checks if a JWS is valid.
func (v RSAVerifier) Verify(payload, signature []byte, key interface{}) error {
if key == nil {
return errors.New(`missing public key while verifying payload`)
}
rsaKey, ok := key.(*rsa.PublicKey)
if !ok {
return errors.Errorf(`invalid key type %T. *rsa.PublicKey is required`, key)
}
return v.verify(payload, signature, rsaKey)
}

View File

@@ -1,22 +0,0 @@
package verify
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// New creates a new JWS verifier using the specified algorithm
// and the public key
func New(alg jwa.SignatureAlgorithm) (Verifier, error) {
switch alg {
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
return newRSA(alg)
case jwa.ES256, jwa.ES384, jwa.ES512:
return newECDSA(alg)
case jwa.HS256, jwa.HS384, jwa.HS512:
return newHMAC(alg)
default:
return nil, errors.Errorf(`unsupported signature algorithm: %s`, alg)
}
}

View File

@@ -93,11 +93,12 @@ func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
return nil, err
}
return ast.NewTerm(newSet), nil
case ast.Array:
case *ast.Array:
// When indexes are removed we shift left to close empty spots in the array
// as per the JSON patch spec.
var newArray ast.Array
for i, v := range aValue {
newArray := ast.NewArray()
for i := 0; i < aValue.Len(); i++ {
v := aValue.Elem(i)
// recurse and add the diff of sub objects as needed
// Note: Keys in b will be strings for the index, eg path /a/1/b => {"a": {"1": {"b": null}}}
diffValue, err := jsonRemove(v, bObj.Get(ast.StringTerm(strconv.Itoa(i))))
@@ -105,7 +106,7 @@ func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
return nil, err
}
if diffValue != nil {
newArray = append(newArray, diffValue)
newArray = newArray.Append(diffValue)
}
}
return ast.NewTerm(newArray), nil
@@ -142,9 +143,9 @@ func getJSONPaths(operand ast.Value) ([]ast.Ref, error) {
var paths []ast.Ref
switch v := operand.(type) {
case ast.Array:
for _, f := range v {
filter, err := parsePath(f)
case *ast.Array:
for i := 0; i < v.Len(); i++ {
filter, err := parsePath(v.Elem(i))
if err != nil {
return nil, err
}
@@ -175,15 +176,18 @@ func parsePath(path *ast.Term) (ast.Ref, error) {
var pathSegments ast.Ref
switch p := path.Value.(type) {
case ast.String:
parts := strings.Split(strings.Trim(string(p), "/"), "/")
if p == "" {
return ast.Ref{}, nil
}
parts := strings.Split(strings.TrimLeft(string(p), "/"), "/")
for _, part := range parts {
part = strings.ReplaceAll(strings.ReplaceAll(part, "~1", "/"), "~0", "~")
pathSegments = append(pathSegments, ast.StringTerm(part))
}
case ast.Array:
for _, term := range p {
case *ast.Array:
p.Foreach(func(term *ast.Term) {
pathSegments = append(pathSegments, term)
}
})
default:
return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p))
}
@@ -199,6 +203,12 @@ func pathsToObject(paths []ast.Ref) ast.Object {
node := root
var done bool
// If the path is an empty JSON path, skip all further processing.
if len(path) == 0 {
done = true
}
// Otherwise, we should have 1+ path segments to work with.
for i := 0; i < len(path)-1 && !done; i++ {
k := path[i]
@@ -229,7 +239,391 @@ func pathsToObject(paths []ast.Ref) ast.Object {
return root
}
// toIndex tries to convert path elements (that may be strings) into indices into
// an array.
func toIndex(arr *ast.Array, term *ast.Term) (int, error) {
i := 0
var ok bool
switch v := term.Value.(type) {
case ast.Number:
if i, ok = v.Int(); !ok {
return 0, fmt.Errorf("Invalid number type for indexing")
}
case ast.String:
if v == "-" {
return arr.Len(), nil
}
num := ast.Number(v)
if i, ok = num.Int(); !ok {
return 0, fmt.Errorf("Invalid string for indexing")
}
if v != "0" && strings.HasPrefix(string(v), "0") {
return 0, fmt.Errorf("Leading zeros are not allowed in JSON paths")
}
default:
return 0, fmt.Errorf("Invalid type for indexing")
}
return i, nil
}
// patchWorkerris a worker that modifies a direct child of a term located
// at the given key. It returns the new term, and optionally a result that
// is passed back to the caller.
type patchWorker = func(parent, key *ast.Term) (updated, result *ast.Term)
func jsonPatchTraverse(
target *ast.Term,
path ast.Ref,
worker patchWorker,
) (*ast.Term, *ast.Term) {
if len(path) < 1 {
return nil, nil
}
key := path[0]
if len(path) == 1 {
return worker(target, key)
}
success := false
var updated, result *ast.Term
switch parent := target.Value.(type) {
case ast.Object:
obj := ast.NewObject()
parent.Foreach(func(k, v *ast.Term) {
if k.Equal(key) {
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
obj.Insert(k, v)
success = true
}
} else {
obj.Insert(k, v)
}
})
updated = ast.NewTerm(obj)
case *ast.Array:
idx, err := toIndex(parent, key)
if err != nil {
return nil, nil
}
arr := ast.NewArray()
for i := 0; i < parent.Len(); i++ {
v := parent.Elem(i)
if idx == i {
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
arr = arr.Append(v)
success = true
}
} else {
arr = arr.Append(v)
}
}
updated = ast.NewTerm(arr)
case ast.Set:
set := ast.NewSet()
parent.Foreach(func(k *ast.Term) {
if k.Equal(key) {
if k, result = jsonPatchTraverse(k, path[1:], worker); k != nil {
set.Add(k)
success = true
}
} else {
set.Add(k)
}
})
updated = ast.NewTerm(set)
}
if success {
return updated, result
}
return nil, nil
}
// jsonPatchGet goes one step further than jsonPatchTraverse and returns the
// term at the location specified by the path. It is used in functions
// where we want to read a value but not manipulate its parent: for example
// jsonPatchTest and jsonPatchCopy.
//
// Because it uses jsonPatchTraverse, it makes shallow copies of the objects
// along the path. We could possibly add a signaling mechanism that we didn't
// make any changes to avoid this.
func jsonPatchGet(target *ast.Term, path ast.Ref) *ast.Term {
// Special case: get entire document.
if len(path) == 0 {
return target
}
_, result := jsonPatchTraverse(target, path, func(parent, key *ast.Term) (*ast.Term, *ast.Term) {
switch v := parent.Value.(type) {
case ast.Object:
return parent, v.Get(key)
case *ast.Array:
i, err := toIndex(v, key)
if err == nil {
return parent, v.Elem(i)
}
case ast.Set:
if v.Contains(key) {
return parent, key
}
}
return nil, nil
})
return result
}
func jsonPatchAdd(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
// Special case: replacing root document.
if len(path) == 0 {
return value
}
target, _ = jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
switch original := parent.Value.(type) {
case ast.Object:
obj := ast.NewObject()
original.Foreach(func(k, v *ast.Term) {
obj.Insert(k, v)
})
obj.Insert(key, value)
return ast.NewTerm(obj), nil
case *ast.Array:
idx, err := toIndex(original, key)
if err != nil || idx < 0 || idx > original.Len() {
return nil, nil
}
arr := ast.NewArray()
for i := 0; i < idx; i++ {
arr = arr.Append(original.Elem(i))
}
arr = arr.Append(value)
for i := idx; i < original.Len(); i++ {
arr = arr.Append(original.Elem(i))
}
return ast.NewTerm(arr), nil
case ast.Set:
if !key.Equal(value) {
return nil, nil
}
set := ast.NewSet()
original.Foreach(func(k *ast.Term) {
set.Add(k)
})
set.Add(key)
return ast.NewTerm(set), nil
}
return nil, nil
})
return target
}
func jsonPatchRemove(target *ast.Term, path ast.Ref) (*ast.Term, *ast.Term) {
// Special case: replacing root document.
if len(path) == 0 {
return nil, nil
}
target, removed := jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
var removed *ast.Term
switch original := parent.Value.(type) {
case ast.Object:
obj := ast.NewObject()
original.Foreach(func(k, v *ast.Term) {
if k.Equal(key) {
removed = v
} else {
obj.Insert(k, v)
}
})
return ast.NewTerm(obj), removed
case *ast.Array:
idx, err := toIndex(original, key)
if err != nil || idx < 0 || idx >= original.Len() {
return nil, nil
}
arr := ast.NewArray()
for i := 0; i < idx; i++ {
arr = arr.Append(original.Elem(i))
}
removed = original.Elem(idx)
for i := idx + 1; i < original.Len(); i++ {
arr = arr.Append(original.Elem(i))
}
return ast.NewTerm(arr), removed
case ast.Set:
set := ast.NewSet()
original.Foreach(func(k *ast.Term) {
if k.Equal(key) {
removed = k
} else {
set.Add(k)
}
})
return ast.NewTerm(set), removed
}
return nil, nil
})
if target != nil && removed != nil {
return target, removed
}
return nil, nil
}
func jsonPatchReplace(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
// Special case: replacing the whole document.
if len(path) == 0 {
return value
}
// Replace is specified as `remove` followed by `add`.
if target, _ = jsonPatchRemove(target, path); target == nil {
return nil
}
return jsonPatchAdd(target, path, value)
}
func jsonPatchMove(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
// Move is specified as `remove` followed by `add`.
target, removed := jsonPatchRemove(target, from)
if target == nil || removed == nil {
return nil
}
return jsonPatchAdd(target, path, removed)
}
func jsonPatchCopy(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
value := jsonPatchGet(target, from)
if value == nil {
return nil
}
return jsonPatchAdd(target, path, value)
}
func jsonPatchTest(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
actual := jsonPatchGet(target, path)
if actual == nil {
return nil
}
if actual.Equal(value) {
return target
}
return nil
}
func builtinJSONPatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// JSON patch supports arrays, objects as well as values as the target.
target := ast.NewTerm(operands[0].Value)
// Expect an array of operations.
operations, err := builtins.ArrayOperand(operands[1].Value, 2)
if err != nil {
return err
}
// Apply operations one by one.
for i := 0; i < operations.Len(); i++ {
if object, ok := operations.Elem(i).Value.(ast.Object); ok {
getAttribute := func(attr string) (*ast.Term, error) {
if term := object.Get(ast.StringTerm(attr)); term != nil {
return term, nil
}
return nil, builtins.NewOperandErr(2, fmt.Sprintf("patch is missing '%s' attribute", attr))
}
getPathAttribute := func(attr string) (ast.Ref, error) {
term, err := getAttribute(attr)
if err != nil {
return ast.Ref{}, err
}
path, err := parsePath(term)
if err != nil {
return ast.Ref{}, err
}
return path, nil
}
// Parse operation.
opTerm, err := getAttribute("op")
if err != nil {
return err
}
op, ok := opTerm.Value.(ast.String)
if !ok {
return builtins.NewOperandErr(2, "patch attribute 'op' must be a string")
}
// Parse path.
path, err := getPathAttribute("path")
if err != nil {
return err
}
switch op {
case "add":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchAdd(target, path, value)
case "remove":
target, _ = jsonPatchRemove(target, path)
case "replace":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchReplace(target, path, value)
case "move":
from, err := getPathAttribute("from")
if err != nil {
return err
}
target = jsonPatchMove(target, path, from)
case "copy":
from, err := getPathAttribute("from")
if err != nil {
return err
}
target = jsonPatchCopy(target, path, from)
case "test":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchTest(target, path, value)
default:
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
}
} else {
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
}
// JSON patches should work atomically; and if one of them fails,
// we should not try to continue.
if target == nil {
return nil
}
}
return iter(target)
}
func init() {
RegisterBuiltinFunc(ast.JSONFilter.Name, builtinJSONFilter)
RegisterBuiltinFunc(ast.JSONRemove.Name, builtinJSONRemove)
RegisterBuiltinFunc(ast.JSONPatch.Name, builtinJSONPatch)
}

64
vendor/github.com/open-policy-agent/opa/topdown/net.go generated vendored Normal file
View File

@@ -0,0 +1,64 @@
// Copyright 2021 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"net"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type lookupIPAddrCacheKey string
// resolv is the same as net.DefaultResolver -- this is for mocking it out in tests
var resolv = &net.Resolver{}
func builtinLookupIPAddr(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
name := string(a)
err = verifyHost(bctx, name)
if err != nil {
return err
}
key := lookupIPAddrCacheKey(name)
if val, ok := bctx.Cache.Get(key); ok {
return iter(val.(*ast.Term))
}
addrs, err := resolv.LookupIPAddr(bctx.Context, name)
if err != nil {
// NOTE(sr): We can't do better than this right now, see https://github.com/golang/go/issues/36208
if strings.Contains(err.Error(), "operation was canceled") || strings.Contains(err.Error(), "i/o timeout") {
return Halt{
Err: &Error{
Code: CancelErr,
Message: ast.NetLookupIPAddr.Name + ": " + err.Error(),
Location: bctx.Location,
},
}
}
return err
}
ret := ast.NewSet()
for _, a := range addrs {
ret.Add(ast.StringTerm(a.String()))
}
t := ast.NewTerm(ret)
bctx.Cache.Put(key, t)
return iter(t)
}
func init() {
RegisterBuiltinFunc(ast.NetLookupIPAddr.Name, builtinLookupIPAddr)
}

View File

@@ -0,0 +1,99 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"math/big"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type randIntCachingKey string
var one = big.NewInt(1)
func builtinNumbersRange(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
x, err := builtins.BigIntOperand(operands[0].Value, 1)
if err != nil {
return err
}
y, err := builtins.BigIntOperand(operands[1].Value, 2)
if err != nil {
return err
}
result := ast.NewArray()
cmp := x.Cmp(y)
haltErr := Halt{
Err: &Error{
Code: CancelErr,
Message: "numbers.range: timed out before generating all numbers in range",
},
}
if cmp <= 0 {
for i := new(big.Int).Set(x); i.Cmp(y) <= 0; i = i.Add(i, one) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return haltErr
}
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
}
} else {
for i := new(big.Int).Set(x); i.Cmp(y) >= 0; i = i.Sub(i, one) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return haltErr
}
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
}
}
return iter(ast.NewTerm(result))
}
func builtinRandIntn(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
strOp, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return err
}
n, err := builtins.IntOperand(args[1].Value, 2)
if err != nil {
return err
}
if n == 0 {
return iter(ast.IntNumberTerm(0))
}
if n < 0 {
n = -n
}
var key = randIntCachingKey(fmt.Sprintf("%s-%d", strOp, n))
if val, ok := bctx.Cache.Get(key); ok {
return iter(val.(*ast.Term))
}
r, err := bctx.Rand()
if err != nil {
return err
}
result := ast.IntNumberTerm(r.Intn(n))
bctx.Cache.Put(key, result)
return iter(result)
}
func init() {
RegisterBuiltinFunc(ast.NumbersRange.Name, builtinNumbersRange)
RegisterBuiltinFunc(ast.RandIntn.Name, builtinRandIntn)
}

View File

@@ -6,8 +6,8 @@ package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/ref"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/types"
)
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -26,6 +26,38 @@ func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
return iter(ast.NewTerm(r))
}
func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Because we need merge-with-overwrite behavior, we can iterate
// back-to-front, and get a mostly correct set of key assignments that
// give us the "last assignment wins, with merges" behavior we want.
// However, if a non-object overwrites an object value anywhere in the
// chain of assignments for a key, we have to "freeze" that key to
// prevent accidentally picking up nested objects that could merge with
// it from earlier in the input array.
// Example:
// Input: [{"a": {"b": 2}}, {"a": 4}, {"a": {"c": 3}}]
// Want Output: {"a": {"c": 3}}
result := ast.NewObject()
frozenKeys := map[*ast.Term]struct{}{}
for i := arr.Len() - 1; i >= 0; i-- {
o, ok := arr.Elem(i).Value.(ast.Object)
if !ok {
return builtins.NewOperandElementErr(1, arr, arr.Elem(i).Value, "object")
}
mergewithOverwriteInPlace(result, o, frozenKeys)
if err != nil {
return err
}
}
return iter(ast.NewTerm(result))
}
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Expect an object and an array/set/object of keys
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
@@ -81,11 +113,29 @@ func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter
return err
}
if ret := object.Get(operands[1]); ret != nil {
return iter(ret)
// if the get key is not an array, attempt to get the top level key for the operand value in the object
path, err := builtins.ArrayOperand(operands[1].Value, 2)
if err != nil {
if ret := object.Get(operands[1]); ret != nil {
return iter(ret)
}
return iter(operands[2])
}
return iter(operands[2])
// if the path is empty, then we skip selecting nested keys and return the whole object
if path.Len() == 0 {
return iter(operands[0])
}
// build an ast.Ref from the array and see if it matches within the object
pathRef := ref.ArrayPath(path)
value, err := object.Find(pathRef)
if err != nil {
return iter(operands[2])
}
return iter(ast.NewTerm(value))
}
// getObjectKeysParam returns a set of key values
@@ -94,10 +144,11 @@ func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
keys := ast.NewSet()
switch v := arrayOrSet.(type) {
case ast.Array:
for _, f := range v {
case *ast.Array:
_ = v.Iter(func(f *ast.Term) error {
keys.Add(f)
}
return nil
})
case ast.Set:
_ = v.Iter(func(f *ast.Term) error {
keys.Add(f)
@@ -109,7 +160,7 @@ func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
return nil
})
default:
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, "object", "set", "array")
}
return keys, nil
@@ -131,8 +182,35 @@ func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
return merged
}
// Modifies obj with any new keys from other, and recursively
// merges any keys where the values are both objects.
func mergewithOverwriteInPlace(obj, other ast.Object, frozenKeys map[*ast.Term]struct{}) {
other.Foreach(func(k, v *ast.Term) {
v2 := obj.Get(k)
// The key didn't exist in other, keep the original value.
if v2 == nil {
obj.Insert(k, v)
return
}
// The key exists in both. Merge or reject change.
updateValueObj, ok2 := v.Value.(ast.Object)
originalValueObj, ok1 := v2.Value.(ast.Object)
// Both are objects? Merge recursively.
if ok1 && ok2 {
// Check to make sure that this key isn't frozen before merging.
if _, ok := frozenKeys[v2]; !ok {
mergewithOverwriteInPlace(originalValueObj, updateValueObj, frozenKeys)
}
} else {
// Else, original value wins. Freeze the key.
frozenKeys[v2] = struct{}{}
}
})
}
func init() {
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
RegisterBuiltinFunc(ast.ObjectUnionN.Name, builtinObjectUnionN)
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)

View File

@@ -7,6 +7,7 @@ package topdown
import (
"bytes"
"encoding/json"
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -42,6 +43,17 @@ func builtinRegoParseModule(a, b ast.Value) (ast.Value, error) {
return term.Value, nil
}
func registerRegoMetadataBuiltinFunction(builtin *ast.Builtin) {
f := func(BuiltinContext, []*ast.Term, func(*ast.Term) error) error {
// The compiler should replace all usage of this function, so the only way to get here is within a query;
// which cannot define rules.
return fmt.Errorf("the %s function must only be called within the scope of a rule", builtin.Name)
}
RegisterBuiltinFunc(builtin.Name, f)
}
func init() {
RegisterFunctionalBuiltin2(ast.RegoParseModule.Name, builtinRegoParseModule)
registerRegoMetadataBuiltinFunction(ast.RegoMetadataChain)
registerRegoMetadataBuiltinFunction(ast.RegoMetadataRule)
}

View File

@@ -7,147 +7,137 @@ package topdown
import (
"fmt"
"math/big"
"strconv"
"strings"
"unicode"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
const (
none int64 = 1
kb = 1000
ki = 1024
mb = kb * 1000
mi = ki * 1024
gb = mb * 1000
gi = mi * 1024
tb = gb * 1000
ti = gi * 1024
none uint64 = 1 << (10 * iota)
ki
mi
gi
ti
pi
ei
kb uint64 = 1000
mb = kb * 1000
gb = mb * 1000
tb = gb * 1000
pb = tb * 1000
eb = pb * 1000
)
// The rune values for 0..9 as well as the period symbol (for parsing floats)
var numRunes = []rune("0123456789.")
func parseNumBytesError(msg string) error {
return fmt.Errorf("%s error: %s", ast.UnitsParseBytes.Name, msg)
return fmt.Errorf("%s: %s", ast.UnitsParseBytes.Name, msg)
}
func errUnitNotRecognized(unit string) error {
func errBytesUnitNotRecognized(unit string) error {
return parseNumBytesError(fmt.Sprintf("byte unit %s not recognized", unit))
}
var (
errNoAmount = parseNumBytesError("no byte amount provided")
errIntConv = parseNumBytesError("could not parse byte amount to integer")
errIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
errBytesValueNoAmount = parseNumBytesError("no byte amount provided")
errBytesValueNumConv = parseNumBytesError("could not parse byte amount to a number")
errBytesValueIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
)
func builtinNumBytes(a ast.Value) (ast.Value, error) {
var m int64
func builtinNumBytes(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var m big.Float
raw, err := builtins.StringOperand(a, 1)
raw, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return nil, err
return err
}
s := formatString(raw)
if strings.Contains(s, " ") {
return nil, errIncludesSpaces
return errBytesValueIncludesSpaces
}
numStr, unitStr := extractNumAndUnit(s)
if numStr == "" {
return nil, errNoAmount
num, unit := extractNumAndUnit(s)
if num == "" {
return errBytesValueNoAmount
}
switch unitStr {
switch unit {
case "":
m = none
case "kb":
m = kb
case "kib":
m = ki
case "mb":
m = mb
case "mib":
m = mi
case "gb":
m = gb
case "gib":
m = gi
case "tb":
m = tb
case "tib":
m = ti
m.SetUint64(none)
case "kb", "k":
m.SetUint64(kb)
case "kib", "ki":
m.SetUint64(ki)
case "mb", "m":
m.SetUint64(mb)
case "mib", "mi":
m.SetUint64(mi)
case "gb", "g":
m.SetUint64(gb)
case "gib", "gi":
m.SetUint64(gi)
case "tb", "t":
m.SetUint64(tb)
case "tib", "ti":
m.SetUint64(ti)
case "pb", "p":
m.SetUint64(pb)
case "pib", "pi":
m.SetUint64(pi)
case "eb", "e":
m.SetUint64(eb)
case "eib", "ei":
m.SetUint64(ei)
default:
return nil, errUnitNotRecognized(unitStr)
return errBytesUnitNotRecognized(unit)
}
num, err := strconv.ParseInt(numStr, 10, 64)
if err != nil {
return nil, errIntConv
numFloat, ok := new(big.Float).SetString(num)
if !ok {
return errBytesValueNumConv
}
total := num * m
return builtins.IntToNumber(big.NewInt(total)), nil
var total big.Int
numFloat.Mul(numFloat, &m).Int(&total)
return iter(ast.NewTerm(builtins.IntToNumber(&total)))
}
// Makes the string lower case and removes spaces and quotation marks
// Makes the string lower case and removes quotation marks
func formatString(s ast.String) string {
str := string(s)
lower := strings.ToLower(str)
return strings.Replace(lower, "\"", "", -1)
}
// Splits the string into a number string à la "10" or "10.2" and a unit string à la "gb" or "MiB" or "foo". Either
// can be an empty string (error handling is provided elsewhere).
// Splits the string into a number string à la "10" or "10.2" and a unit
// string à la "gb" or "MiB" or "foo". Either can be an empty string
// (error handling is provided elsewhere).
func extractNumAndUnit(s string) (string, string) {
isNum := func(r rune) (isNum bool) {
for _, nr := range numRunes {
if nr == r {
return true
}
isNum := func(r rune) bool {
return unicode.IsDigit(r) || r == '.'
}
firstNonNumIdx := -1
for idx, r := range s {
if !isNum(r) {
firstNonNumIdx = idx
break
}
return false
}
// Returns the index of the first rune that's not a number (or 0 if there are only numbers)
getFirstNonNumIdx := func(s string) int {
for idx, r := range s {
if !isNum(r) {
return idx
}
}
return 0
}
firstRuneIsNum := func(s string) bool {
return isNum(rune(s[0]))
}
firstNonNumIdx := getFirstNonNumIdx(s)
// The string contains only a number
numOnly := firstNonNumIdx == 0 && firstRuneIsNum(s)
// The string contains only a unit
unitOnly := firstNonNumIdx == 0 && !firstRuneIsNum(s)
if numOnly {
if firstNonNumIdx == -1 { // only digits and '.'
return s, ""
} else if unitOnly {
return "", s
} else {
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
}
if firstNonNumIdx == 0 { // only units (starts with non-digit)
return "", s
}
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
}
func init() {
RegisterFunctionalBuiltin1(ast.UnitsParseBytes.Name, builtinNumBytes)
RegisterBuiltinFunc(ast.UnitsParseBytes.Name, builtinNumBytes)
}

View File

@@ -0,0 +1,125 @@
// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"encoding/json"
"fmt"
"math/big"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Binary Si unit constants are borrowed from topdown/parse_bytes
const siMilli = 0.001
const (
siK uint64 = 1000
siM = siK * 1000
siG = siM * 1000
siT = siG * 1000
siP = siT * 1000
siE = siP * 1000
)
func parseUnitsError(msg string) error {
return fmt.Errorf("%s: %s", ast.UnitsParse.Name, msg)
}
func errUnitNotRecognized(unit string) error {
return parseUnitsError(fmt.Sprintf("unit %s not recognized", unit))
}
var (
errNoAmount = parseUnitsError("no amount provided")
errNumConv = parseUnitsError("could not parse amount to a number")
errIncludesSpaces = parseUnitsError("spaces not allowed in resource strings")
)
// Accepts both normal SI and binary SI units.
func builtinUnits(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var x big.Rat
raw, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
// We remove escaped quotes from strings here to retain parity with units.parse_bytes.
s := string(raw)
s = strings.Replace(s, "\"", "", -1)
if strings.Contains(s, " ") {
return errIncludesSpaces
}
num, unit := extractNumAndUnit(s)
if num == "" {
return errNoAmount
}
// Unlike in units.parse_bytes, we only lowercase after the first letter,
// so that we can distinguish between 'm' and 'M'.
if len(unit) > 1 {
lower := strings.ToLower(unit[1:])
unit = unit[:1] + lower
}
switch unit {
case "m":
x.SetFloat64(siMilli)
case "":
x.SetUint64(none)
case "k", "K":
x.SetUint64(siK)
case "ki", "Ki":
x.SetUint64(ki)
case "M":
x.SetUint64(siM)
case "mi", "Mi":
x.SetUint64(mi)
case "g", "G":
x.SetUint64(siG)
case "gi", "Gi":
x.SetUint64(gi)
case "t", "T":
x.SetUint64(siT)
case "ti", "Ti":
x.SetUint64(ti)
case "p", "P":
x.SetUint64(siP)
case "pi", "Pi":
x.SetUint64(pi)
case "e", "E":
x.SetUint64(siE)
case "ei", "Ei":
x.SetUint64(ei)
default:
return errUnitNotRecognized(unit)
}
numRat, ok := new(big.Rat).SetString(num)
if !ok {
return errNumConv
}
numRat.Mul(numRat, &x)
// Cleaner printout when we have a pure integer value.
if numRat.IsInt() {
return iter(ast.NumberTerm(json.Number(numRat.Num().String())))
}
// When using just big.Float, we had floating-point precision
// issues because quantities like 0.001 are not exactly representable.
// Rationals (such as big.Rat) do not suffer this problem, but are
// more expensive to compute with in general.
return iter(ast.NumberTerm(json.Number(numRat.FloatString(10))))
}
func init() {
RegisterBuiltinFunc(ast.UnitsParse.Name, builtinUnits)
}

View File

@@ -0,0 +1,86 @@
// Copyright 2021 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"io"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/print"
)
func NewPrintHook(w io.Writer) print.Hook {
return printHook{w: w}
}
type printHook struct {
w io.Writer
}
func (h printHook) Print(_ print.Context, msg string) error {
_, err := fmt.Fprintln(h.w, msg)
return err
}
func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
if bctx.PrintHook == nil {
return iter(nil)
}
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
if err != nil {
return err
}
buf := make([]string, arr.Len())
err = builtinPrintCrossProductOperands(bctx, buf, arr, 0, func(buf []string) error {
pctx := print.Context{
Context: bctx.Context,
Location: bctx.Location,
}
return bctx.PrintHook.Print(pctx, strings.Join(buf, " "))
})
if err != nil {
return err
}
return iter(nil)
}
func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operands *ast.Array, i int, f func([]string) error) error {
if i >= operands.Len() {
return f(buf)
}
xs, ok := operands.Elem(i).Value.(ast.Set)
if !ok {
return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))}
}
if xs.Len() == 0 {
buf[i] = "<undefined>"
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
}
return xs.Iter(func(x *ast.Term) error {
switch v := x.Value.(type) {
case ast.String:
buf[i] = string(v)
default:
buf[i] = v.String()
}
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
})
}
func init() {
RegisterBuiltinFunc(ast.InternalPrint.Name, builtinPrint)
}

View File

@@ -0,0 +1,21 @@
package print
import (
"context"
"github.com/open-policy-agent/opa/ast"
)
// Context provides the Hook implementation context about the print() call.
type Context struct {
Context context.Context // request context passed when query executed
Location *ast.Location // location of print call
}
// Hook defines the interface that callers can implement to receive print
// statement outputs. If the hook returns an error, it will be surfaced if
// strict builtin error checking is enabled (otherwise, it will not halt
// execution.)
type Hook interface {
Print(Context, string) error
}

View File

@@ -2,13 +2,20 @@ package topdown
import (
"context"
"crypto/rand"
"io"
"sort"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/resolver"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/copypropagation"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/tracing"
)
// QueryResultSet represents a collection of results returned by a query.
@@ -20,23 +27,35 @@ type QueryResult map[ast.Var]*ast.Term
// Query provides a configurable interface for performing query evaluation.
type Query struct {
cancel Cancel
query ast.Body
queryCompiler ast.QueryCompiler
compiler *ast.Compiler
store storage.Store
txn storage.Transaction
input *ast.Term
tracers []Tracer
unknowns []*ast.Term
partialNamespace string
metrics metrics.Metrics
instr *Instrumentation
disableInlining []ast.Ref
genvarprefix string
runtime *ast.Term
builtins map[string]*Builtin
indexing bool
seed io.Reader
time time.Time
cancel Cancel
query ast.Body
queryCompiler ast.QueryCompiler
compiler *ast.Compiler
store storage.Store
txn storage.Transaction
input *ast.Term
external *resolverTrie
tracers []QueryTracer
plugTraceVars bool
unknowns []*ast.Term
partialNamespace string
skipSaveNamespace bool
metrics metrics.Metrics
instr *Instrumentation
disableInlining []ast.Ref
shallowInlining bool
genvarprefix string
runtime *ast.Term
builtins map[string]*Builtin
indexing bool
earlyExit bool
interQueryBuiltinCache cache.InterQueryCache
ndBuiltinCache builtins.NDBCache
strictBuiltinErrors bool
printHook print.Hook
tracingOpts tracing.Options
}
// Builtin represents a built-in function that queries can call.
@@ -51,6 +70,8 @@ func NewQuery(query ast.Body) *Query {
query: query,
genvarprefix: ast.WildcardPrefix,
indexing: true,
earlyExit: true,
external: newResolverTrie(),
}
}
@@ -94,8 +115,31 @@ func (q *Query) WithInput(input *ast.Term) *Query {
}
// WithTracer adds a query tracer to use during evaluation. This is optional.
// Deprecated: Use WithQueryTracer instead.
func (q *Query) WithTracer(tracer Tracer) *Query {
qt, ok := tracer.(QueryTracer)
if !ok {
qt = WrapLegacyTracer(tracer)
}
return q.WithQueryTracer(qt)
}
// WithQueryTracer adds a query tracer to use during evaluation. This is optional.
// Disabled QueryTracers will be ignored.
func (q *Query) WithQueryTracer(tracer QueryTracer) *Query {
if !tracer.Enabled() {
return q
}
q.tracers = append(q.tracers, tracer)
// If *any* of the tracers require local variable metadata we need to
// enabled plugging local trace variables.
conf := tracer.Config()
if conf.PlugLocalVars {
q.plugTraceVars = true
}
return q
}
@@ -128,6 +172,13 @@ func (q *Query) WithPartialNamespace(ns string) *Query {
return q
}
// WithSkipPartialNamespace disables namespacing of saved support rules that are generated
// from the original policy (rules which are completely synthetic are still namespaced.)
func (q *Query) WithSkipPartialNamespace(yes bool) *Query {
q.skipSaveNamespace = yes
return q
}
// WithDisableInlining adds a set of paths to the query that should be excluded from
// inlining. Inlining during partial evaluation can be expensive in some cases
// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive
@@ -137,6 +188,14 @@ func (q *Query) WithDisableInlining(paths []ast.Ref) *Query {
return q
}
// WithShallowInlining disables aggressive inlining performed during partial evaluation.
// When shallow inlining is enabled rules that depend (transitively) on unknowns are not inlined.
// Only rules/values that are completely known will be inlined.
func (q *Query) WithShallowInlining(yes bool) *Query {
q.shallowInlining = yes
return q
}
// WithRuntime sets the runtime data to execute the query with. The runtime data
// can be returned by the `opa.runtime` built-in function.
func (q *Query) WithRuntime(runtime *ast.Term) *Query {
@@ -158,6 +217,61 @@ func (q *Query) WithIndexing(enabled bool) *Query {
return q
}
// WithEarlyExit will enable or disable using 'early exit' for the evaluation
// of the query. The default is enabled.
func (q *Query) WithEarlyExit(enabled bool) *Query {
q.earlyExit = enabled
return q
}
// WithSeed sets a reader that will seed randomization required by built-in functions.
// If a seed is not provided crypto/rand.Reader is used.
func (q *Query) WithSeed(r io.Reader) *Query {
q.seed = r
return q
}
// WithTime sets the time that will be returned by the time.now_ns() built-in function.
func (q *Query) WithTime(x time.Time) *Query {
q.time = x
return q
}
// WithInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize.
func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query {
q.interQueryBuiltinCache = c
return q
}
// WithNDBuiltinCache sets the non-deterministic builtin cache.
func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query {
q.ndBuiltinCache = c
return q
}
// WithStrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors.
func (q *Query) WithStrictBuiltinErrors(yes bool) *Query {
q.strictBuiltinErrors = yes
return q
}
// WithResolver configures an external resolver to use for the given ref.
func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query {
q.external.Put(ref, r)
return q
}
func (q *Query) WithPrintHook(h print.Hook) *Query {
q.printHook = h
return q
}
// WithDistributedTracingOpts sets the options to be used by distributed tracing.
func (q *Query) WithDistributedTracingOpts(tr tracing.Options) *Query {
q.tracingOpts = tr
return q
}
// PartialRun executes partial evaluation on the query with respect to unknown
// values. Partial evaluation attempts to evaluate as much of the query as
// possible without requiring values for the unknowns set on the query. The
@@ -169,45 +283,79 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
if q.partialNamespace == "" {
q.partialNamespace = "partial" // lazily initialize partial namespace
}
if q.seed == nil {
q.seed = rand.Reader
}
if !q.time.IsZero() {
q.time = time.Now()
}
if q.metrics == nil {
q.metrics = metrics.New()
}
f := &queryIDFactory{}
b := newBindings(0, q.instr)
e := &eval{
ctx: ctx,
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: b,
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
tracers: q.tracers,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
virtualCache: newVirtualCache(),
saveSet: newSaveSet(q.unknowns, b, q.instr),
saveStack: newSaveStack(),
saveSupport: newSaveSupport(),
saveNamespace: ast.StringTerm(q.partialNamespace),
ctx: ctx,
metrics: q.metrics,
seed: q.seed,
time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: b,
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
external: q.external,
tracers: q.tracers,
traceEnabled: len(q.tracers) > 0,
plugTraceVars: q.plugTraceVars,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
functionMocks: newFunctionMocksStack(),
interQueryBuiltinCache: q.interQueryBuiltinCache,
ndBuiltinCache: q.ndBuiltinCache,
virtualCache: newVirtualCache(),
comprehensionCache: newComprehensionCache(),
saveSet: newSaveSet(q.unknowns, b, q.instr),
saveStack: newSaveStack(),
saveSupport: newSaveSupport(),
saveNamespace: ast.StringTerm(q.partialNamespace),
skipSaveNamespace: q.skipSaveNamespace,
inliningControl: &inliningControl{
shallow: q.shallowInlining,
},
genvarprefix: q.genvarprefix,
runtime: q.runtime,
indexing: q.indexing,
earlyExit: q.earlyExit,
builtinErrors: &builtinErrors{},
printHook: q.printHook,
}
if len(q.disableInlining) > 0 {
e.disableInlining = [][]ast.Ref{q.disableInlining}
e.inliningControl.PushDisable(q.disableInlining, false)
}
e.caller = e
q.startTimer(metrics.RegoPartialEval)
defer q.stopTimer(metrics.RegoPartialEval)
q.metrics.Timer(metrics.RegoPartialEval).Start()
defer q.metrics.Timer(metrics.RegoPartialEval).Stop()
livevars := ast.NewVarSet()
for _, t := range q.unknowns {
switch v := t.Value.(type) {
case ast.Var:
livevars.Add(v)
case ast.Ref:
livevars.Add(v[0].Value.(ast.Var))
}
}
ast.WalkVars(q.query, func(x ast.Var) bool {
if !x.IsGenerated() {
@@ -216,7 +364,7 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
return false
})
p := copypropagation.New(livevars)
p := copypropagation.New(livevars).WithCompiler(q.compiler)
err = e.Run(func(e *eval) error {
@@ -230,10 +378,10 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
// Include bindings as exprs so that when caller evals the result, they
// can obtain values for the vars in their query.
bindingExprs := []*ast.Expr{}
e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
_ = e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b))
return nil
})
}) // cannot return error
// Sort binding expressions so that results are deterministic.
sort.Slice(bindingExprs, func(i, j int) bool {
@@ -244,12 +392,32 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
body.Append(bindingExprs[i])
}
partials = append(partials, applyCopyPropagation(p, e.instr, body))
// Skip this rule body if it fails to type-check.
// Type-checking failure means the rule body will never succeed.
if !e.compiler.PassesTypeCheck(body) {
return nil
}
if !q.shallowInlining {
body = applyCopyPropagation(p, e.instr, body)
}
partials = append(partials, body)
return nil
})
support = e.saveSupport.List()
if q.strictBuiltinErrors && len(e.builtinErrors.errs) > 0 {
err = e.builtinErrors.errs[0]
}
for i := range support {
sort.Slice(support[i].Rules, func(j, k int) bool {
return support[i].Rules[j].Compare(support[i].Rules[k]) < 0
})
}
return partials, support, err
}
@@ -266,52 +434,68 @@ func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
// Iter executes the query and invokes the iter function with query results
// produced by evaluating the query.
func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
if q.seed == nil {
q.seed = rand.Reader
}
if q.time.IsZero() {
q.time = time.Now()
}
if q.metrics == nil {
q.metrics = metrics.New()
}
f := &queryIDFactory{}
e := &eval{
ctx: ctx,
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: newBindings(0, q.instr),
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
tracers: q.tracers,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
virtualCache: newVirtualCache(),
genvarprefix: q.genvarprefix,
runtime: q.runtime,
indexing: q.indexing,
ctx: ctx,
metrics: q.metrics,
seed: q.seed,
time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: newBindings(0, q.instr),
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
external: q.external,
tracers: q.tracers,
traceEnabled: len(q.tracers) > 0,
plugTraceVars: q.plugTraceVars,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
functionMocks: newFunctionMocksStack(),
interQueryBuiltinCache: q.interQueryBuiltinCache,
ndBuiltinCache: q.ndBuiltinCache,
virtualCache: newVirtualCache(),
comprehensionCache: newComprehensionCache(),
genvarprefix: q.genvarprefix,
runtime: q.runtime,
indexing: q.indexing,
earlyExit: q.earlyExit,
builtinErrors: &builtinErrors{},
printHook: q.printHook,
tracingOpts: q.tracingOpts,
}
e.caller = e
q.startTimer(metrics.RegoQueryEval)
q.metrics.Timer(metrics.RegoQueryEval).Start()
err := e.Run(func(e *eval) error {
qr := QueryResult{}
e.bindings.Iter(nil, func(k, v *ast.Term) error {
_ = e.bindings.Iter(nil, func(k, v *ast.Term) error {
qr[k.Value.(ast.Var)] = v
return nil
})
}) // cannot return error
return iter(qr)
})
q.stopTimer(metrics.RegoQueryEval)
if q.strictBuiltinErrors && err == nil && len(e.builtinErrors.errs) > 0 {
err = e.builtinErrors.errs[0]
}
q.metrics.Timer(metrics.RegoQueryEval).Stop()
return err
}
func (q *Query) startTimer(name string) {
if q.metrics != nil {
q.metrics.Timer(name).Start()
}
}
func (q *Query) stopTimer(name string) {
if q.metrics != nil {
q.metrics.Timer(name).Stop()
}
}

View File

@@ -0,0 +1,142 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Helper: sets of vertices can be represented as Arrays or Sets.
func foreachVertex(collection *ast.Term, f func(*ast.Term)) {
switch v := collection.Value.(type) {
case ast.Set:
v.Foreach(f)
case *ast.Array:
v.Foreach(f)
}
}
// numberOfEdges returns the number of elements of an array or a set (of edges)
func numberOfEdges(collection *ast.Term) int {
switch v := collection.Value.(type) {
case ast.Set:
return v.Len()
case *ast.Array:
return v.Len()
}
return 0
}
func builtinReachable(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
// Error on wrong types for args.
graph, err := builtins.ObjectOperand(args[0].Value, 1)
if err != nil {
return err
}
var queue []*ast.Term
switch initial := args[1].Value.(type) {
case *ast.Array, ast.Set:
foreachVertex(ast.NewTerm(initial), func(t *ast.Term) {
queue = append(queue, t)
})
default:
return builtins.NewOperandTypeErr(2, initial, "{array, set}")
}
// This is the set of nodes we have reached.
reached := ast.NewSet()
// Keep going as long as we have nodes in the queue.
for len(queue) > 0 {
// Get the edges for this node. If the node was not in the graph,
// `edges` will be `nil` and we can ignore it.
node := queue[0]
if edges := graph.Get(node); edges != nil {
// Add all the newly discovered neighbors.
foreachVertex(edges, func(neighbor *ast.Term) {
if !reached.Contains(neighbor) {
queue = append(queue, neighbor)
}
})
// Mark the node as reached.
reached.Add(node)
}
queue = queue[1:]
}
return iter(ast.NewTerm(reached))
}
// pathBuilder is called recursively to build an array of paths that are reachable from the root
func pathBuilder(graph ast.Object, root *ast.Term, path []*ast.Term, paths []*ast.Term, reached ast.Set) []*ast.Term {
if edges := graph.Get(root); edges != nil {
path = append(path, root)
if numberOfEdges(edges) >= 1 {
foreachVertex(edges, func(neighbor *ast.Term) {
if reached.Contains(neighbor) {
// If we've already reached this node, return current path (avoid infinite recursion)
paths = append(paths, path...)
} else {
reached.Add(root)
paths = pathBuilder(graph, neighbor, path, paths, reached)
}
})
} else {
paths = append(paths, path...)
}
} else {
// Node is nonexistent (not in graph). Commit the current path (without adding this root)
paths = append(paths, path...)
}
return paths
}
func builtinReachablePaths(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
// Error on wrong types for args.
graph, err := builtins.ObjectOperand(args[0].Value, 1)
if err != nil {
return err
}
// This is a queue that holds all nodes we still need to visit. It is
// initialised to the initial set of nodes we start out with.
var queue []*ast.Term
switch initial := args[1].Value.(type) {
case *ast.Array, ast.Set:
foreachVertex(ast.NewTerm(initial), func(t *ast.Term) {
queue = append(queue, t)
})
default:
return builtins.NewOperandTypeErr(2, initial, "{array, set}")
}
results := ast.NewSet()
for _, node := range queue {
// Find reachable paths from edges in root node in queue and append arrays to the results set
if edges := graph.Get(node); edges != nil {
if numberOfEdges(edges) >= 1 {
foreachVertex(edges, func(neighbor *ast.Term) {
paths := pathBuilder(graph, neighbor, []*ast.Term{node}, []*ast.Term{}, ast.NewSet(node))
results.Add(ast.ArrayTerm(paths...))
})
} else {
results.Add(ast.ArrayTerm(node))
}
}
}
return iter(ast.NewTerm(results))
}
func init() {
RegisterBuiltinFunc(ast.ReachableBuiltin.Name, builtinReachable)
RegisterBuiltinFunc(ast.ReachablePathsBuiltin.Name, builtinReachablePaths)
}

View File

@@ -9,7 +9,7 @@ import (
"regexp"
"sync"
"github.com/yashtewari/glob-intersection"
gintersect "github.com/yashtewari/glob-intersection"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -18,6 +18,21 @@ import (
var regexpCacheLock = sync.Mutex{}
var regexpCache map[string]*regexp.Regexp
func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return iter(ast.BooleanTerm(false))
}
_, err = regexp.Compile(string(s))
if err != nil {
return iter(ast.BooleanTerm(false))
}
return iter(ast.BooleanTerm(true))
}
func builtinRegexMatch(a, b ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -79,11 +94,11 @@ func builtinRegexSplit(a, b ast.Value) (ast.Value, error) {
}
elems := re.Split(string(s2), -1)
arr := make(ast.Array, len(elems))
for i := range arr {
arr := make([]*ast.Term, len(elems))
for i := range elems {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
return ast.NewArray(arr...), nil
}
func getRegexp(pat string) (*regexp.Regexp, error) {
@@ -151,11 +166,11 @@ func builtinRegexFind(a, b, c ast.Value) (ast.Value, error) {
}
elems := re.FindAllString(string(s2), n)
arr := make(ast.Array, len(elems))
for i := range arr {
arr := make([]*ast.Term, len(elems))
for i := range elems {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
return ast.NewArray(arr...), nil
}
func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
@@ -178,24 +193,53 @@ func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
}
matches := re.FindAllStringSubmatch(string(s2), n)
outer := make(ast.Array, len(matches))
for i := range outer {
inner := make(ast.Array, len(matches[i]))
for j := range inner {
outer := make([]*ast.Term, len(matches))
for i := range matches {
inner := make([]*ast.Term, len(matches[i]))
for j := range matches[i] {
inner[j] = ast.StringTerm(matches[i][j])
}
outer[i] = ast.ArrayTerm(inner...)
outer[i] = ast.NewTerm(ast.NewArray(inner...))
}
return outer, nil
return ast.NewArray(outer...), nil
}
func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
base, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
pattern, err := builtins.StringOperand(operands[1].Value, 2)
if err != nil {
return err
}
value, err := builtins.StringOperand(operands[2].Value, 3)
if err != nil {
return err
}
re, err := getRegexp(string(pattern))
if err != nil {
return err
}
res := re.ReplaceAllString(string(base), string(value))
return iter(ast.StringTerm(res))
}
func init() {
regexpCache = map[string]*regexp.Regexp{}
RegisterBuiltinFunc(ast.RegexIsValid.Name, builtinRegexIsValid)
RegisterFunctionalBuiltin2(ast.RegexMatch.Name, builtinRegexMatch)
RegisterFunctionalBuiltin2(ast.RegexMatchDeprecated.Name, builtinRegexMatch)
RegisterFunctionalBuiltin2(ast.RegexSplit.Name, builtinRegexSplit)
RegisterFunctionalBuiltin2(ast.GlobsMatch.Name, builtinGlobsMatch)
RegisterFunctionalBuiltin4(ast.RegexTemplateMatch.Name, builtinRegexMatchTemplate)
RegisterFunctionalBuiltin3(ast.RegexFind.Name, builtinRegexFind)
RegisterFunctionalBuiltin3(ast.RegexFindAllStringSubmatch.Name, builtinRegexFindAllStringSubmatch)
RegisterBuiltinFunc(ast.RegexReplace.Name, builtinRegexReplace)
}

View File

@@ -0,0 +1,107 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/resolver"
)
type resolverTrie struct {
r resolver.Resolver
children map[ast.Value]*resolverTrie
}
func newResolverTrie() *resolverTrie {
return &resolverTrie{children: map[ast.Value]*resolverTrie{}}
}
func (t *resolverTrie) Put(ref ast.Ref, r resolver.Resolver) {
node := t
for _, t := range ref {
child, ok := node.children[t.Value]
if !ok {
child = &resolverTrie{children: map[ast.Value]*resolverTrie{}}
node.children[t.Value] = child
}
node = child
}
node.r = r
}
func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) {
e.metrics.Timer(metrics.RegoExternalResolve).Start()
defer e.metrics.Timer(metrics.RegoExternalResolve).Stop()
node := t
for i, t := range ref {
child, ok := node.children[t.Value]
if !ok {
return nil, nil
}
node = child
if node.r != nil {
in := resolver.Input{
Ref: ref[:i+1],
Input: e.input,
Metrics: e.metrics,
}
e.traceWasm(e.query[e.index], &in.Ref)
if e.data != nil {
return nil, errInScopeWithStmt
}
result, err := node.r.Eval(e.ctx, in)
if err != nil {
return nil, err
}
if result.Value == nil {
return nil, nil
}
val, err := result.Value.Find(ref[i+1:])
if err != nil {
return nil, nil
}
return val, nil
}
}
return node.mktree(e, resolver.Input{
Ref: ref,
Input: e.input,
Metrics: e.metrics,
})
}
func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) {
if t.r != nil {
e.traceWasm(e.query[e.index], &in.Ref)
if e.data != nil {
return nil, errInScopeWithStmt
}
result, err := t.r.Eval(e.ctx, in)
if err != nil {
return nil, err
}
if result.Value == nil {
return nil, nil
}
return result.Value, nil
}
obj := ast.NewObject()
for k, child := range t.children {
v, err := child.mktree(e, resolver.Input{Ref: append(in.Ref, ast.NewTerm(k)), Input: in.Input, Metrics: in.Metrics})
if err != nil {
return nil, err
}
if v != nil {
obj.Insert(ast.NewTerm(k), ast.NewTerm(v))
}
}
return obj, nil
}
var errInScopeWithStmt = &Error{
Code: InternalErr,
Message: "wasm cannot be executed when 'with' statements are in-scope",
}

View File

@@ -4,7 +4,11 @@
package topdown
import "github.com/open-policy-agent/opa/ast"
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
)
func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
@@ -12,9 +16,112 @@ func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term)
return iter(ast.ObjectTerm())
}
if bctx.Runtime.Get(ast.StringTerm("config")) != nil {
iface, err := ast.ValueToInterface(bctx.Runtime.Value, illegalResolver{})
if err != nil {
return err
}
if object, ok := iface.(map[string]interface{}); ok {
if cfgRaw, ok := object["config"]; ok {
if config, ok := cfgRaw.(map[string]interface{}); ok {
configPurged, err := activeConfig(config)
if err != nil {
return err
}
object["config"] = configPurged
value, err := ast.InterfaceToValue(object)
if err != nil {
return err
}
return iter(ast.NewTerm(value))
}
}
}
}
return iter(bctx.Runtime)
}
func init() {
RegisterBuiltinFunc(ast.OPARuntime.Name, builtinOPARuntime)
}
func activeConfig(config map[string]interface{}) (interface{}, error) {
if config["services"] != nil {
err := removeServiceCredentials(config["services"])
if err != nil {
return nil, err
}
}
if config["keys"] != nil {
err := removeCryptoKeys(config["keys"])
if err != nil {
return nil, err
}
}
return config, nil
}
func removeServiceCredentials(x interface{}) error {
switch x := x.(type) {
case []interface{}:
for _, v := range x {
err := removeKey(v, "credentials")
if err != nil {
return err
}
}
case map[string]interface{}:
for _, v := range x {
err := removeKey(v, "credentials")
if err != nil {
return err
}
}
default:
return fmt.Errorf("illegal service config type: %T", x)
}
return nil
}
func removeCryptoKeys(x interface{}) error {
switch x := x.(type) {
case map[string]interface{}:
for _, v := range x {
err := removeKey(v, "key", "private_key")
if err != nil {
return err
}
}
default:
return fmt.Errorf("illegal keys config type: %T", x)
}
return nil
}
func removeKey(x interface{}, keys ...string) error {
val, ok := x.(map[string]interface{})
if !ok {
return fmt.Errorf("type assertion error")
}
for _, key := range keys {
delete(val, key)
}
return nil
}
type illegalResolver struct{}
func (illegalResolver) Resolve(ref ast.Ref) (interface{}, error) {
return nil, fmt.Errorf("illegal value: %v", ref)
}

View File

@@ -57,7 +57,7 @@ func (ss *saveSet) contains(t *ast.Term, b *bindings) bool {
return false
}
// ContainsRecursive retruns true if the term t is or contains a term that is
// ContainsRecursive returns true if the term t is or contains a term that is
// contained in the save set. This function will close over the binding list
// when it encounters vars.
func (ss *saveSet) ContainsRecursive(t *ast.Term, b *bindings) bool {
@@ -279,7 +279,7 @@ func newSaveSupport() *saveSupport {
}
func (s *saveSupport) List() []*ast.Module {
result := []*ast.Module{}
result := make([]*ast.Module, 0, len(s.modules))
for _, module := range s.modules {
result = append(result, module)
}
@@ -321,7 +321,7 @@ func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
// being saved. This check allows the evaluator to evaluate statements
// completely during partial evaluation as long as they do not depend on any
// kind of unknown value or statements that would generate saves.
func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
func saveRequired(c *ast.Compiler, ic *inliningControl, icIgnoreInternal bool, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
var found bool
@@ -344,9 +344,11 @@ func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec
case ast.Ref:
if ss.Contains(node, b) {
found = true
} else if ic.Disabled(v.ConstantPrefix(), icIgnoreInternal) {
found = true
} else {
for _, rule := range c.GetRulesDynamic(v) {
if saveRequired(c, ss, b, rule, true) {
for _, rule := range c.GetRulesDynamicWithOpts(v, ast.RulesOptions{IncludeHiddenModules: false}) {
if saveRequired(c, ic, icIgnoreInternal, ss, b, rule, true) {
found = true
break
}
@@ -373,10 +375,57 @@ func ignoreExprDuringPartial(expr *ast.Expr) bool {
}
func ignoreDuringPartial(bi *ast.Builtin) bool {
// Note(philipc): We keep this legacy check around to avoid breaking
// existing library users.
//nolint:staticcheck // We specifically ignore our own linter warning here.
for _, ignore := range ast.IgnoreDuringPartialEval {
if bi == ignore {
return true
}
}
// Otherwise, ensure all non-deterministic builtins are thrown out.
return bi.Nondeterministic
}
type inliningControl struct {
shallow bool
disable []disableInliningFrame
}
type disableInliningFrame struct {
internal bool
refs []ast.Ref
}
func (i *inliningControl) PushDisable(refs []ast.Ref, internal bool) {
if i == nil {
return
}
i.disable = append(i.disable, disableInliningFrame{
internal: internal,
refs: refs,
})
}
func (i *inliningControl) PopDisable() {
if i == nil {
return
}
i.disable = i.disable[:len(i.disable)-1]
}
func (i *inliningControl) Disabled(ref ast.Ref, ignoreInternal bool) bool {
if i == nil {
return false
}
for _, frame := range i.disable {
if !frame.internal || !ignoreInternal {
for _, other := range frame.refs {
if other.HasPrefix(ref) || ref.HasPrefix(other) {
return true
}
}
}
}
return false
}

View File

@@ -0,0 +1,59 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/semver"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinSemVerCompare(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
versionStringA, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return err
}
versionStringB, err := builtins.StringOperand(args[1].Value, 2)
if err != nil {
return err
}
versionA, err := semver.NewVersion(string(versionStringA))
if err != nil {
return fmt.Errorf("operand 1: string %s is not a valid SemVer", versionStringA)
}
versionB, err := semver.NewVersion(string(versionStringB))
if err != nil {
return fmt.Errorf("operand 2: string %s is not a valid SemVer", versionStringB)
}
result := versionA.Compare(*versionB)
return iter(ast.IntNumberTerm(result))
}
func builtinSemVerIsValid(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
versionString, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return iter(ast.BooleanTerm(false))
}
result := true
_, err = semver.NewVersion(string(versionString))
if err != nil {
result = false
}
return iter(ast.BooleanTerm(result))
}
func init() {
RegisterBuiltinFunc(ast.SemVerCompare.Name, builtinSemVerCompare)
RegisterBuiltinFunc(ast.SemVerIsValid.Name, builtinSemVerIsValid)
}

View File

@@ -58,22 +58,26 @@ func builtinSetIntersection(a ast.Value) (ast.Value, error) {
// builtinSetUnion returns the union of the given input sets
func builtinSetUnion(a ast.Value) (ast.Value, error) {
// The set union logic here is duplicated and manually inlined on
// purpose. By lifting this logic up a level, and not doing pairwise
// set unions, we avoid a number of heap allocations. This improves
// performance dramatically over the naive approach.
result := ast.NewSet()
inputSet, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
result := ast.NewSet()
err = inputSet.Iter(func(x *ast.Term) error {
n, err := builtins.SetOperand(x.Value, 1)
item, err := builtins.SetOperand(x.Value, 1)
if err != nil {
return err
}
result = result.Union(n)
item.Foreach(result.Add)
return nil
})
return result, err
}

View File

@@ -5,14 +5,113 @@
package topdown
import (
"errors"
"fmt"
"math/big"
"sort"
"strings"
"github.com/tchap/go-patricia/v2/patricia"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinAnyPrefixMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a, b := operands[0].Value, operands[1].Value
var strs []string
switch a := a.(type) {
case ast.String:
strs = []string{string(a)}
case *ast.Array, ast.Set:
var err error
strs, err = builtins.StringSliceOperand(a, 1)
if err != nil {
return err
}
default:
return builtins.NewOperandTypeErr(1, a, "string", "set", "array")
}
var prefixes []string
switch b := b.(type) {
case ast.String:
prefixes = []string{string(b)}
case *ast.Array, ast.Set:
var err error
prefixes, err = builtins.StringSliceOperand(b, 2)
if err != nil {
return err
}
default:
return builtins.NewOperandTypeErr(2, b, "string", "set", "array")
}
return iter(ast.BooleanTerm(anyStartsWithAny(strs, prefixes)))
}
func builtinAnySuffixMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a, b := operands[0].Value, operands[1].Value
var strsReversed []string
switch a := a.(type) {
case ast.String:
strsReversed = []string{reverseString(string(a))}
case *ast.Array, ast.Set:
strs, err := builtins.StringSliceOperand(a, 1)
if err != nil {
return err
}
strsReversed = make([]string, len(strs))
for i := range strs {
strsReversed[i] = reverseString(strs[i])
}
default:
return builtins.NewOperandTypeErr(1, a, "string", "set", "array")
}
var suffixesReversed []string
switch b := b.(type) {
case ast.String:
suffixesReversed = []string{reverseString(string(b))}
case *ast.Array, ast.Set:
suffixes, err := builtins.StringSliceOperand(b, 2)
if err != nil {
return err
}
suffixesReversed = make([]string, len(suffixes))
for i := range suffixes {
suffixesReversed[i] = reverseString(suffixes[i])
}
default:
return builtins.NewOperandTypeErr(2, b, "string", "set", "array")
}
return iter(ast.BooleanTerm(anyStartsWithAny(strsReversed, suffixesReversed)))
}
func anyStartsWithAny(strs []string, prefixes []string) bool {
if len(strs) == 0 || len(prefixes) == 0 {
return false
}
if len(strs) == 1 && len(prefixes) == 1 {
return strings.HasPrefix(strs[0], prefixes[0])
}
trie := patricia.NewTrie()
for i := 0; i < len(strs); i++ {
trie.Insert([]byte(strs[i]), true)
}
for i := 0; i < len(prefixes); i++ {
if trie.MatchSubtree([]byte(prefixes[i])) {
return true
}
}
return false
}
func builtinFormatInt(a, b ast.Value) (ast.Value, error) {
input, err := builtins.NumberOperand(a, 1)
@@ -55,13 +154,17 @@ func builtinConcat(a, b ast.Value) (ast.Value, error) {
strs := []string{}
switch b := b.(type) {
case ast.Array:
for i := range b {
s, ok := b[i].Value.(ast.String)
case *ast.Array:
err := b.Iter(func(x *ast.Term) error {
s, ok := x.Value.(ast.String)
if !ok {
return nil, builtins.NewOperandElementErr(2, b, b[i].Value, "string")
return builtins.NewOperandElementErr(2, b, x.Value, "string")
}
strs = append(strs, string(s))
return nil
})
if err != nil {
return nil, err
}
case ast.Set:
err := b.Iter(func(x *ast.Term) error {
@@ -82,6 +185,18 @@ func builtinConcat(a, b ast.Value) (ast.Value, error) {
return ast.String(strings.Join(strs, string(join))), nil
}
func runesEqual(a, b []rune) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
base, err := builtins.StringOperand(a, 1)
if err != nil {
@@ -92,9 +207,57 @@ func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
if err != nil {
return nil, err
}
if len(string(search)) == 0 {
return nil, fmt.Errorf("empty search character")
}
index := strings.Index(string(base), string(search))
return ast.IntNumberTerm(index).Value, nil
baseRunes := []rune(string(base))
searchRunes := []rune(string(search))
searchLen := len(searchRunes)
for i, r := range baseRunes {
if len(baseRunes) >= i+searchLen {
if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) {
return ast.IntNumberTerm(i).Value, nil
}
} else {
break
}
}
return ast.IntNumberTerm(-1).Value, nil
}
func builtinIndexOfN(a, b ast.Value) (ast.Value, error) {
base, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
search, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
if len(string(search)) == 0 {
return nil, fmt.Errorf("empty search character")
}
baseRunes := []rune(string(base))
searchRunes := []rune(string(search))
searchLen := len(searchRunes)
var arr []*ast.Term
for i, r := range baseRunes {
if len(baseRunes) >= i+searchLen {
if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) {
arr = append(arr, ast.IntNumberTerm(i))
}
} else {
break
}
}
return ast.NewArray(arr...), nil
}
func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
@@ -103,11 +266,12 @@ func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
if err != nil {
return nil, err
}
runes := []rune(base)
startIndex, err := builtins.IntOperand(b, 2)
if err != nil {
return nil, err
} else if startIndex >= len(base) {
} else if startIndex >= len(runes) {
return ast.String(""), nil
} else if startIndex < 0 {
return nil, fmt.Errorf("negative offset")
@@ -120,13 +284,13 @@ func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
var s ast.String
if length < 0 {
s = ast.String(base[startIndex:])
s = ast.String(runes[startIndex:])
} else {
upto := startIndex + length
if len(base) < upto {
upto = len(base)
if len(runes) < upto {
upto = len(runes)
}
s = ast.String(base[startIndex:upto])
s = ast.String(runes[startIndex:upto])
}
return s, nil
@@ -202,11 +366,11 @@ func builtinSplit(a, b ast.Value) (ast.Value, error) {
return nil, err
}
elems := strings.Split(string(s), string(d))
arr := make(ast.Array, len(elems))
for i := range arr {
arr := make([]*ast.Term, len(elems))
for i := range elems {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
return ast.NewArray(arr...), nil
}
func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
@@ -229,27 +393,33 @@ func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
}
func builtinReplaceN(a, b ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
patterns, err := builtins.ObjectOperand(a, 1)
if err != nil {
return nil, err
}
oldnewObj, ok := asJSON.(map[string]interface{})
if !ok {
return nil, builtins.NewOperandTypeErr(1, a, "object")
}
keys := patterns.Keys()
sort.Slice(keys, func(i, j int) bool { return ast.Compare(keys[i].Value, keys[j].Value) < 0 })
s, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
var oldnewArr []string
for k, v := range oldnewObj {
strVal, ok := v.(string)
oldnewArr := make([]string, 0, len(keys)*2)
for _, k := range keys {
keyVal, ok := k.Value.(ast.String)
if !ok {
return nil, errors.New("non-string value found in pattern object")
return nil, builtins.NewOperandErr(1, "non-string key found in pattern object")
}
oldnewArr = append(oldnewArr, k, strVal)
val := patterns.Get(k) // cannot be nil
strVal, ok := val.Value.(ast.String)
if !ok {
return nil, builtins.NewOperandErr(1, "non-string value found in pattern object")
}
oldnewArr = append(oldnewArr, string(keyVal), string(strVal))
}
if err != nil {
return nil, err
}
r := strings.NewReplacer(oldnewArr...)
@@ -343,18 +513,20 @@ func builtinSprintf(a, b ast.Value) (ast.Value, error) {
return nil, err
}
astArr, ok := b.(ast.Array)
astArr, ok := b.(*ast.Array)
if !ok {
return nil, builtins.NewOperandTypeErr(2, b, "array")
}
args := make([]interface{}, len(astArr))
args := make([]interface{}, astArr.Len())
for i := range astArr {
switch v := astArr[i].Value.(type) {
for i := range args {
switch v := astArr.Elem(i).Value.(type) {
case ast.Number:
if n, ok := v.Int(); ok {
args[i] = n
} else if b, ok := new(big.Int).SetString(v.String(), 10); ok {
args[i] = b
} else if f, ok := v.Float64(); ok {
args[i] = f
} else {
@@ -363,17 +535,39 @@ func builtinSprintf(a, b ast.Value) (ast.Value, error) {
case ast.String:
args[i] = string(v)
default:
args[i] = astArr[i].String()
args[i] = astArr.Elem(i).String()
}
}
return ast.String(fmt.Sprintf(string(s), args...)), nil
}
func builtinReverse(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
return iter(ast.StringTerm(reverseString(string(s))))
}
func reverseString(str string) string {
sRunes := []rune(string(str))
length := len(sRunes)
reversedRunes := make([]rune, length)
for index, r := range sRunes {
reversedRunes[length-index-1] = r
}
return string(reversedRunes)
}
func init() {
RegisterFunctionalBuiltin2(ast.FormatInt.Name, builtinFormatInt)
RegisterFunctionalBuiltin2(ast.Concat.Name, builtinConcat)
RegisterFunctionalBuiltin2(ast.IndexOf.Name, builtinIndexOf)
RegisterFunctionalBuiltin2(ast.IndexOfN.Name, builtinIndexOfN)
RegisterFunctionalBuiltin3(ast.Substring.Name, builtinSubstring)
RegisterFunctionalBuiltin2(ast.Contains.Name, builtinContains)
RegisterFunctionalBuiltin2(ast.StartsWith.Name, builtinStartsWith)
@@ -390,4 +584,7 @@ func init() {
RegisterFunctionalBuiltin2(ast.TrimSuffix.Name, builtinTrimSuffix)
RegisterFunctionalBuiltin1(ast.TrimSpace.Name, builtinTrimSpace)
RegisterFunctionalBuiltin2(ast.Sprintf.Name, builtinSprintf)
RegisterBuiltinFunc(ast.AnyPrefixMatch.Name, builtinAnyPrefixMatch)
RegisterBuiltinFunc(ast.AnySuffixMatch.Name, builtinAnySuffixMatch)
RegisterBuiltinFunc(ast.StringReverse.Name, builtinReverse)
}

View File

@@ -0,0 +1,263 @@
// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func bothObjects(t1, t2 *ast.Term) (bool, ast.Object, ast.Object) {
if (t1 == nil) || (t2 == nil) {
return false, nil, nil
}
obj1, ok := t1.Value.(ast.Object)
if !ok {
return false, nil, nil
}
obj2, ok := t2.Value.(ast.Object)
if !ok {
return false, nil, nil
}
return true, obj1, obj2
}
func bothSets(t1, t2 *ast.Term) (bool, ast.Set, ast.Set) {
if (t1 == nil) || (t2 == nil) {
return false, nil, nil
}
set1, ok := t1.Value.(ast.Set)
if !ok {
return false, nil, nil
}
set2, ok := t2.Value.(ast.Set)
if !ok {
return false, nil, nil
}
return true, set1, set2
}
func bothArrays(t1, t2 *ast.Term) (bool, *ast.Array, *ast.Array) {
if (t1 == nil) || (t2 == nil) {
return false, nil, nil
}
array1, ok := t1.Value.(*ast.Array)
if !ok {
return false, nil, nil
}
array2, ok := t2.Value.(*ast.Array)
if !ok {
return false, nil, nil
}
return true, array1, array2
}
func arraySet(t1, t2 *ast.Term) (bool, *ast.Array, ast.Set) {
if (t1 == nil) || (t2 == nil) {
return false, nil, nil
}
array, ok := t1.Value.(*ast.Array)
if !ok {
return false, nil, nil
}
set, ok := t2.Value.(ast.Set)
if !ok {
return false, nil, nil
}
return true, array, set
}
// objectSubset implements the subset operation on a pair of objects.
//
// This function will try to recursively apply the subset operation where it
// can, such as if both super and sub have an object or set as the value
// associated with a key.
func objectSubset(super ast.Object, sub ast.Object) bool {
var superTerm *ast.Term
isSubset := true
sub.Until(func(key, subTerm *ast.Term) bool {
// This really wants to be a for loop, hence the somewhat
// weird internal structure. However, using Until() in this
// was is a performance optimization, as it avoids performing
// any key hashing on the sub-object.
superTerm = super.Get(key)
// subTerm is can't be nil because we got it from Until(), so
// we only need to verify that super is non-nil.
if superTerm == nil {
isSubset = false
return true // break, not a subset
}
if subTerm.Equal(superTerm) {
return false // continue
}
// If both of the terms are objects then we want to apply
// the subset operation recursively, otherwise we just compare
// them normally. If only one term is an object, then we
// do a normal comparison which will come up false.
if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok {
if !objectSubset(superObj, subObj) {
isSubset = false
return true // break, not a subset
}
return false // continue
}
if ok, superSet, subSet := bothSets(superTerm, subTerm); ok {
if !setSubset(superSet, subSet) {
isSubset = false
return true // break, not a subset
}
return false // continue
}
if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok {
if !arraySubset(superArray, subArray) {
isSubset = false
return true // break, not a subset
}
return false // continue
}
// We have already checked for exact equality, as well as for
// all of the types of nested subsets we care about, so if we
// get here it means this isn't a subset.
isSubset = false
return true // break, not a subset
})
return isSubset
}
// setSubset implements the subset operation on sets.
//
// Unlike in the object case, this is not recursive, we just compare values
// using ast.Set.Contains() because we have no well defined way to "match up"
// objects that are in different sets.
func setSubset(super ast.Set, sub ast.Set) bool {
isSubset := true
sub.Until(func(t *ast.Term) bool {
if !super.Contains(t) {
isSubset = false
return true
}
return false
})
return isSubset
}
// arraySubset implements the subset operation on arrays.
//
// This is defined to mean that the entire "sub" array must appear in
// the "super" array. For the same rationale as setSubset(), we do not attempt
// to recurse into values.
func arraySubset(super, sub *ast.Array) bool {
// Notice that this is essentially string search. The naive approach
// used here is O(n^2). This should probably be rewritten later to use
// Boyer-Moore or something.
if sub.Len() > super.Len() {
return false
}
if sub.Equal(super) {
return true
}
superCursor := 0
subCursor := 0
for {
if subCursor == sub.Len() {
return true
}
if superCursor == super.Len() {
return false
}
subElem := sub.Elem(subCursor)
superElem := sub.Elem(superCursor + subCursor)
if superElem == nil {
return false
}
if superElem.Value.Compare(subElem.Value) == 0 {
subCursor++
} else {
superCursor++
subCursor = 0
}
}
}
// arraySetSubset implements the subset operation on array and set.
//
// This is defined to mean that the entire "sub" set must appear in
// the "super" array with no consideration of ordering.
// For the same rationale as setSubset(), we do not attempt
// to recurse into values.
func arraySetSubset(super *ast.Array, sub ast.Set) bool {
unmatched := sub.Len()
return super.Until(func(t *ast.Term) bool {
if sub.Contains(t) {
unmatched--
}
if unmatched == 0 {
return true
}
return false
})
}
func builtinObjectSubset(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
superTerm := operands[0]
subTerm := operands[1]
if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok {
// Both operands are objects.
return iter(ast.BooleanTerm(objectSubset(superObj, subObj)))
}
if ok, superSet, subSet := bothSets(superTerm, subTerm); ok {
// Both operands are sets.
return iter(ast.BooleanTerm(setSubset(superSet, subSet)))
}
if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok {
// Both operands are sets.
return iter(ast.BooleanTerm(arraySubset(superArray, subArray)))
}
if ok, superArray, subSet := arraySet(superTerm, subTerm); ok {
// Super operand is array and sub operand is set
return iter(ast.BooleanTerm(arraySetSubset(superArray, subSet)))
}
return builtins.ErrOperand("both arguments object.subset must be of the same type or array and set")
}
func init() {
RegisterBuiltinFunc(ast.ObjectSubset.Name, builtinObjectSubset)
}

View File

@@ -7,6 +7,7 @@ package topdown
import (
"encoding/json"
"fmt"
"math"
"math/big"
"strconv"
"sync"
@@ -16,61 +17,61 @@ import (
"github.com/open-policy-agent/opa/topdown/builtins"
)
type nowKeyID string
var nowKey = nowKeyID("time.now_ns")
var tzCache map[string]*time.Location
var tzCacheMutex *sync.Mutex
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
// 1677-09-21T00:12:43.145224192-00:00
var minDateAllowedForNsConversion = time.Unix(0, math.MinInt64)
exist, ok := bctx.Cache.Get(nowKey)
var now *ast.Term
// 2262-04-11T23:47:16.854775807-00:00
var maxDateAllowedForNsConversion = time.Unix(0, math.MaxInt64)
if !ok {
curr := time.Now()
now = ast.NewTerm(ast.Number(int64ToJSONNumber(curr.UnixNano())))
bctx.Cache.Put(nowKey, now)
} else {
now = exist.(*ast.Term)
func toSafeUnixNano(t time.Time, iter func(*ast.Term) error) error {
if t.Before(minDateAllowedForNsConversion) || t.After(maxDateAllowedForNsConversion) {
return fmt.Errorf("time outside of valid range")
}
return iter(now)
return iter(ast.NewTerm(ast.Number(int64ToJSONNumber(t.UnixNano()))))
}
func builtinTimeParseNanos(a, b ast.Value) (ast.Value, error) {
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
return iter(bctx.Time)
}
func builtinTimeParseNanos(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a := operands[0].Value
format, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
return err
}
b := operands[1].Value
value, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
return err
}
result, err := time.Parse(string(format), string(value))
if err != nil {
return nil, err
return err
}
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
return toSafeUnixNano(result, iter)
}
func builtinTimeParseRFC3339Nanos(a ast.Value) (ast.Value, error) {
func builtinTimeParseRFC3339Nanos(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a := operands[0].Value
value, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
return err
}
result, err := time.Parse(time.RFC3339, string(value))
if err != nil {
return nil, err
return err
}
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
return toSafeUnixNano(result, iter)
}
func builtinParseDurationNanos(a ast.Value) (ast.Value, error) {
@@ -91,7 +92,7 @@ func builtinDate(a ast.Value) (ast.Value, error) {
return nil, err
}
year, month, day := t.Date()
result := ast.Array{ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day)}
result := ast.NewArray(ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day))
return result, nil
}
@@ -101,7 +102,7 @@ func builtinClock(a ast.Value) (ast.Value, error) {
return nil, err
}
hour, minute, second := t.Clock()
result := ast.Array{ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second)}
result := ast.NewArray(ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second))
return result, nil
}
@@ -114,24 +115,115 @@ func builtinWeekday(a ast.Value) (ast.Value, error) {
return ast.String(weekday), nil
}
func builtinAddDate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
t, err := tzTime(operands[0].Value)
if err != nil {
return err
}
years, err := builtins.IntOperand(operands[1].Value, 2)
if err != nil {
return err
}
months, err := builtins.IntOperand(operands[2].Value, 3)
if err != nil {
return err
}
days, err := builtins.IntOperand(operands[3].Value, 4)
if err != nil {
return err
}
result := t.AddDate(years, months, days)
return toSafeUnixNano(result, iter)
}
func builtinDiff(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
t1, err := tzTime(operands[0].Value)
if err != nil {
return err
}
t2, err := tzTime(operands[1].Value)
if err != nil {
return err
}
// The following implementation of this function is taken
// from https://github.com/icza/gox licensed under Apache 2.0.
// The only modification made is to variable names.
//
// For details, see https://stackoverflow.com/a/36531443/1705598
//
// Copyright 2021 icza
// BEGIN REDISTRIBUTION FROM APACHE 2.0 LICENSED PROJECT
if t1.Location() != t2.Location() {
t2 = t2.In(t1.Location())
}
if t1.After(t2) {
t1, t2 = t2, t1
}
y1, M1, d1 := t1.Date()
y2, M2, d2 := t2.Date()
h1, m1, s1 := t1.Clock()
h2, m2, s2 := t2.Clock()
year := y2 - y1
month := int(M2 - M1)
day := d2 - d1
hour := h2 - h1
min := m2 - m1
sec := s2 - s1
// Normalize negative values
if sec < 0 {
sec += 60
min--
}
if min < 0 {
min += 60
hour--
}
if hour < 0 {
hour += 24
day--
}
if day < 0 {
// Days in month:
t := time.Date(y1, M1, 32, 0, 0, 0, 0, time.UTC)
day += 32 - t.Day()
month--
}
if month < 0 {
month += 12
year--
}
// END REDISTRIBUTION FROM APACHE 2.0 LICENSED PROJECT
return iter(ast.ArrayTerm(ast.IntNumberTerm(year), ast.IntNumberTerm(month), ast.IntNumberTerm(day),
ast.IntNumberTerm(hour), ast.IntNumberTerm(min), ast.IntNumberTerm(sec)))
}
func tzTime(a ast.Value) (t time.Time, err error) {
var nVal ast.Value
loc := time.UTC
switch va := a.(type) {
case ast.Array:
if len(va) == 0 {
case *ast.Array:
if va.Len() == 0 {
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
}
nVal, err = builtins.NumberOperand(va[0].Value, 1)
nVal, err = builtins.NumberOperand(va.Elem(0).Value, 1)
if err != nil {
return time.Time{}, err
}
if len(va) > 1 {
tzVal, err := builtins.StringOperand(va[1].Value, 1)
if va.Len() > 1 {
tzVal, err := builtins.StringOperand(va.Elem(1).Value, 1)
if err != nil {
return time.Time{}, err
}
@@ -192,12 +284,14 @@ func int64ToJSONNumber(i int64) json.Number {
func init() {
RegisterBuiltinFunc(ast.NowNanos.Name, builtinTimeNowNanos)
RegisterFunctionalBuiltin1(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
RegisterFunctionalBuiltin2(ast.ParseNanos.Name, builtinTimeParseNanos)
RegisterBuiltinFunc(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
RegisterBuiltinFunc(ast.ParseNanos.Name, builtinTimeParseNanos)
RegisterFunctionalBuiltin1(ast.ParseDurationNanos.Name, builtinParseDurationNanos)
RegisterFunctionalBuiltin1(ast.Date.Name, builtinDate)
RegisterFunctionalBuiltin1(ast.Clock.Name, builtinClock)
RegisterFunctionalBuiltin1(ast.Weekday.Name, builtinWeekday)
RegisterBuiltinFunc(ast.AddDate.Name, builtinAddDate)
RegisterBuiltinFunc(ast.Diff.Name, builtinDiff)
tzCacheMutex = &sync.Mutex{}
tzCache = make(map[string]*time.Location)
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,10 +9,18 @@ import (
"io"
"strings"
iStrs "github.com/open-policy-agent/opa/internal/strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
const (
minLocationWidth = 5 // len("query")
maxIdealLocationWidth = 64
locationPadding = 4
)
// Op defines the types of tracing events.
type Op string
@@ -36,12 +44,20 @@ const (
// FailOp is emitted when an expression evaluates to false.
FailOp Op = "Fail"
// DuplicateOp is emitted when a query has produced a duplicate value. The search
// will stop at the point where the duplicate was emitted and backtrack.
DuplicateOp Op = "Duplicate"
// NoteOp is emitted when an expression invokes a tracing built-in function.
NoteOp Op = "Note"
// IndexOp is emitted during an expression evaluation to represent lookup
// matches.
IndexOp Op = "Index"
// WasmOp is emitted when resolving a ref using an external
// Resolver.
WasmOp Op = "Wasm"
)
// VarMetadata provides some user facing information about
@@ -58,9 +74,13 @@ type Event struct {
Location *ast.Location // The location of the Node this event relates to.
QueryID uint64 // Identifies the query this event belongs to.
ParentID uint64 // Identifies the parent query this event belongs to.
Locals *ast.ValueMap // Contains local variable bindings from the query context.
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings.
Locals *ast.ValueMap // Contains local variable bindings from the query context. Nil if variables were not included in the trace event.
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings. Nil if variables were not included in the trace event.
Message string // Contains message for Note events.
Ref *ast.Ref // Identifies the subject ref for the event. Only applies to Index and Wasm operations.
input *ast.Term
bindings *bindings
}
// HasRule returns true if the Event contains an ast.Rule.
@@ -102,6 +122,17 @@ func (evt *Event) String() string {
return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID)
}
// Input returns the input object as it was at the event.
func (evt *Event) Input() *ast.Term {
return evt.input
}
// Plug plugs event bindings into the provided ast.Term. Because bindings are mutable, this only makes sense to do when
// the event is emitted rather than on recorded trace events as the bindings are going to be different by then.
func (evt *Event) Plug(term *ast.Term) *ast.Term {
return evt.bindings.Plug(term)
}
func (evt *Event) equalNodes(other *Event) bool {
switch a := evt.Node.(type) {
case ast.Body:
@@ -123,13 +154,53 @@ func (evt *Event) equalNodes(other *Event) bool {
}
// Tracer defines the interface for tracing in the top-down evaluation engine.
// Deprecated: Use QueryTracer instead.
type Tracer interface {
Enabled() bool
Trace(*Event)
}
// BufferTracer implements the Tracer interface by simply buffering all events
// received.
// QueryTracer defines the interface for tracing in the top-down evaluation engine.
// The implementation can provide additional configuration to modify the tracing
// behavior for query evaluations.
type QueryTracer interface {
Enabled() bool
TraceEvent(Event)
Config() TraceConfig
}
// TraceConfig defines some common configuration for Tracer implementations
type TraceConfig struct {
PlugLocalVars bool // Indicate whether to plug local variable bindings before calling into the tracer.
}
// legacyTracer Implements the QueryTracer interface by wrapping an older Tracer instance.
type legacyTracer struct {
t Tracer
}
func (l *legacyTracer) Enabled() bool {
return l.t.Enabled()
}
func (l *legacyTracer) Config() TraceConfig {
return TraceConfig{
PlugLocalVars: true, // For backwards compatibility old tracers will plug local variables
}
}
func (l *legacyTracer) TraceEvent(evt Event) {
l.t.Trace(&evt)
}
// WrapLegacyTracer will create a new QueryTracer which wraps an
// older Tracer instance.
func WrapLegacyTracer(tracer Tracer) QueryTracer {
return &legacyTracer{t: tracer}
}
// BufferTracer implements the Tracer and QueryTracer interface by
// simply buffering all events received.
type BufferTracer []*Event
// NewBufferTracer returns a new BufferTracer.
@@ -139,17 +210,25 @@ func NewBufferTracer() *BufferTracer {
// Enabled always returns true if the BufferTracer is instantiated.
func (b *BufferTracer) Enabled() bool {
if b == nil {
return false
}
return true
return b != nil
}
// Trace adds the event to the buffer.
// Deprecated: Use TraceEvent instead.
func (b *BufferTracer) Trace(evt *Event) {
*b = append(*b, evt)
}
// TraceEvent adds the event to the buffer.
func (b *BufferTracer) TraceEvent(evt Event) {
*b = append(*b, &evt)
}
// Config returns the Tracers standard configuration
func (b *BufferTracer) Config() TraceConfig {
return TraceConfig{PlugLocalVars: true}
}
// PrettyTrace pretty prints the trace to the writer.
func PrettyTrace(w io.Writer, trace []*Event) {
depths := depths{}
@@ -162,10 +241,16 @@ func PrettyTrace(w io.Writer, trace []*Event) {
// PrettyTraceWithLocation prints the trace to the writer and includes location information
func PrettyTraceWithLocation(w io.Writer, trace []*Event) {
depths := depths{}
filePathAliases, longest := getShortenedFileNames(trace)
// Always include some padding between the trace and location
locationWidth := longest + locationPadding
for _, event := range trace {
depth := depths.GetOrSet(event.QueryID, event.ParentID)
location := formatLocation(event)
fmt.Fprintln(w, fmt.Sprintf("%v %v", location, formatEvent(event, depth)))
location := formatLocation(event, filePathAliases)
fmt.Fprintf(w, "%-*s %s\n", locationWidth, location, formatEvent(event, depth))
}
}
@@ -173,25 +258,34 @@ func formatEvent(event *Event, depth int) string {
padding := formatEventPadding(event, depth)
if event.Op == NoteOp {
return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message)
} else if event.Message != "" {
return fmt.Sprintf("%v%v %v %v", padding, event.Op, event.Node, event.Message)
} else {
switch node := event.Node.(type) {
case *ast.Rule:
return fmt.Sprintf("%v%v %v", padding, event.Op, node.Path())
default:
return fmt.Sprintf("%v%v %v", padding, event.Op, rewrite(event).Node)
}
}
var details interface{}
if node, ok := event.Node.(*ast.Rule); ok {
details = node.Path()
} else if event.Ref != nil {
details = event.Ref
} else {
details = rewrite(event).Node
}
template := "%v%v %v"
opts := []interface{}{padding, event.Op, details}
if event.Message != "" {
template += " %v"
opts = append(opts, event.Message)
}
return fmt.Sprintf(template, opts...)
}
func formatEventPadding(event *Event, depth int) string {
spaces := formatEventSpaces(event, depth)
padding := ""
if spaces > 1 {
padding += strings.Repeat("| ", spaces-1)
return strings.Repeat("| ", spaces-1)
}
return padding
return ""
}
func formatEventSpaces(event *Event, depth int) int {
@@ -206,21 +300,67 @@ func formatEventSpaces(event *Event, depth int) int {
return depth + 1
}
func formatLocation(event *Event) string {
if event.Op == NoteOp {
return fmt.Sprintf("%-19v", "note")
// getShortenedFileNames will return a map of file paths to shortened aliases
// that were found in the trace. It also returns the longest location expected
func getShortenedFileNames(trace []*Event) (map[string]string, int) {
// Get a deduplicated list of all file paths
// and the longest file path size
fpAliases := map[string]string{}
var canShorten []string
longestLocation := 0
for _, event := range trace {
if event.Location != nil {
if event.Location.File != "" {
// length of "<name>:<row>"
curLen := len(event.Location.File) + numDigits10(event.Location.Row) + 1
if curLen > longestLocation {
longestLocation = curLen
}
if _, ok := fpAliases[event.Location.File]; ok {
continue
}
canShorten = append(canShorten, event.Location.File)
// Default to just alias their full path
fpAliases[event.Location.File] = event.Location.File
} else {
// length of "<min width>:<row>"
curLen := minLocationWidth + numDigits10(event.Location.Row) + 1
if curLen > longestLocation {
longestLocation = curLen
}
}
}
}
if len(canShorten) > 0 && longestLocation > maxIdealLocationWidth {
fpAliases, longestLocation = iStrs.TruncateFilePaths(maxIdealLocationWidth, longestLocation, canShorten...)
}
return fpAliases, longestLocation
}
func numDigits10(n int) int {
if n < 10 {
return 1
}
return numDigits10(n/10) + 1
}
func formatLocation(event *Event, fileAliases map[string]string) string {
location := event.Location
if location == nil {
return fmt.Sprintf("%-19v", "")
return ""
}
if location.File == "" {
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", "query", location.Row))
return fmt.Sprintf("query:%v", location.Row)
}
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", location.File, location.Row))
return fmt.Sprintf("%v:%v", fileAliases[location.File], location.Row)
}
// depths is a helper for computing the depth of an event. Events within the
@@ -245,33 +385,25 @@ func builtinTrace(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) er
return handleBuiltinErr(ast.Trace.Name, bctx.Location, err)
}
if !traceIsEnabled(bctx.Tracers) {
if !bctx.TraceEnabled {
return iter(ast.BooleanTerm(true))
}
evt := &Event{
evt := Event{
Op: NoteOp,
Location: bctx.Location,
QueryID: bctx.QueryID,
ParentID: bctx.ParentID,
Message: string(str),
}
for i := range bctx.Tracers {
bctx.Tracers[i].Trace(evt)
for i := range bctx.QueryTracers {
bctx.QueryTracers[i].TraceEvent(evt)
}
return iter(ast.BooleanTerm(true))
}
func traceIsEnabled(tracers []Tracer) bool {
for i := range tracers {
if tracers[i].Enabled() {
return true
}
}
return false
}
func rewrite(event *Event) *Event {
cpy := *event
@@ -280,14 +412,25 @@ func rewrite(event *Event) *Event {
switch v := event.Node.(type) {
case *ast.Expr:
node = v.Copy()
expr := v.Copy()
// Hide generated local vars in 'key' position that have not been
// rewritten.
if ev, ok := v.Terms.(*ast.Every); ok {
if kv, ok := ev.Key.Value.(ast.Var); ok {
if rw, ok := cpy.LocalMetadata[kv]; !ok || rw.Name.IsGenerated() {
expr.Terms.(*ast.Every).Key = nil
}
}
}
node = expr
case ast.Body:
node = v.Copy()
case *ast.Rule:
node = v.Copy()
}
ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
_, _ = ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
if meta, ok := cpy.LocalMetadata[v]; ok {
return meta.Name, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
@@ -13,7 +13,7 @@ func builtinIsNumber(a ast.Value) (ast.Value, error) {
case ast.Number:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
@@ -22,7 +22,7 @@ func builtinIsString(a ast.Value) (ast.Value, error) {
case ast.String:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
@@ -31,16 +31,16 @@ func builtinIsBoolean(a ast.Value) (ast.Value, error) {
case ast.Boolean:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
func builtinIsArray(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Array:
case *ast.Array:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
@@ -49,7 +49,7 @@ func builtinIsSet(a ast.Value) (ast.Value, error) {
case ast.Set:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
@@ -58,7 +58,7 @@ func builtinIsObject(a ast.Value) (ast.Value, error) {
case ast.Object:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}
@@ -67,7 +67,7 @@ func builtinIsNull(a ast.Value) (ast.Value, error) {
case ast.Null:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
return ast.Boolean(false), nil
}
}

View File

@@ -20,7 +20,7 @@ func builtinTypeName(a ast.Value) (ast.Value, error) {
return ast.String("number"), nil
case ast.String:
return ast.String("string"), nil
case ast.Array:
case *ast.Array:
return ast.String("array"), nil
case ast.Object:
return ast.String("object"), nil

View File

@@ -0,0 +1,36 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/uuid"
)
type uuidCachingKey string
func builtinUUIDRFC4122(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
var key = uuidCachingKey(args[0].Value.String())
val, ok := bctx.Cache.Get(key)
if ok {
return iter(val.(*ast.Term))
}
s, err := uuid.New(bctx.Seed)
if err != nil {
return err
}
result := ast.NewTerm(ast.String(s))
bctx.Cache.Put(key, result)
return iter(result)
}
func init() {
RegisterBuiltinFunc(ast.UUIDRFC4122.Name, builtinUUIDRFC4122)
}

View File

@@ -8,57 +8,61 @@ import (
"github.com/open-policy-agent/opa/ast"
)
func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
func evalWalk(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
input := args[0]
filter := getOutputPath(args)
var path ast.Array
return walk(filter, path, input, iter)
return walk(filter, nil, input, iter)
}
func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
if len(filter) == 0 {
if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil {
if filter == nil || filter.Len() == 0 {
if path == nil {
path = ast.NewArray()
}
if err := iter(ast.ArrayTerm(ast.NewTerm(path.Copy()), input)); err != nil {
return err
}
}
if len(filter) > 0 {
key := filter[0]
filter = filter[1:]
if filter != nil && filter.Len() > 0 {
key := filter.Elem(0)
filter = filter.Slice(1, -1)
if key.IsGround() {
if term := input.Get(key); term != nil {
return walk(filter, append(path, key), term, iter)
path = pathAppend(path, key)
return walk(filter, path, term, iter)
}
return nil
}
}
switch v := input.Value.(type) {
case ast.Array:
for i := range v {
path = append(path, ast.IntNumberTerm(i))
if err := walk(filter, path, v[i], iter); err != nil {
case *ast.Array:
for i := 0; i < v.Len(); i++ {
path = pathAppend(path, ast.IntNumberTerm(i))
if err := walk(filter, path, v.Elem(i), iter); err != nil {
return err
}
path = path[:len(path)-1]
path = path.Slice(0, path.Len()-1)
}
case ast.Object:
return v.Iter(func(k, v *ast.Term) error {
path = append(path, k)
path = pathAppend(path, k)
if err := walk(filter, path, v, iter); err != nil {
return err
}
path = path[:len(path)-1]
path = path.Slice(0, path.Len()-1)
return nil
})
case ast.Set:
return v.Iter(func(elem *ast.Term) error {
path = append(path, elem)
path = pathAppend(path, elem)
if err := walk(filter, path, elem, iter); err != nil {
return err
}
path = path[:len(path)-1]
path = path.Slice(0, path.Len()-1)
return nil
})
}
@@ -66,11 +70,19 @@ func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) e
return nil
}
func getOutputPath(args []*ast.Term) ast.Array {
func pathAppend(path *ast.Array, key *ast.Term) *ast.Array {
if path == nil {
return ast.NewArray(key)
}
return path.Append(key)
}
func getOutputPath(args []*ast.Term) *ast.Array {
if len(args) == 2 {
if arr, ok := args[1].Value.(ast.Array); ok {
if len(arr) == 2 {
if path, ok := arr[0].Value.(ast.Array); ok {
if arr, ok := args[1].Value.(*ast.Array); ok {
if arr.Len() == 2 {
if path, ok := arr.Elem(0).Value.(*ast.Array); ok {
return path
}
}