Upgrade dependent version: github.com/open-policy-agent/opa (#5315)
Upgrade dependent version: github.com/open-policy-agent/opa v0.18.0 -> v0.45.0 Signed-off-by: hongzhouzi <hongzhouzi@kubesphere.io> Signed-off-by: hongzhouzi <hongzhouzi@kubesphere.io>
This commit is contained in:
134
vendor/github.com/open-policy-agent/opa/topdown/aggregates.go
generated
vendored
134
vendor/github.com/open-policy-agent/opa/topdown/aggregates.go
generated
vendored
@@ -13,30 +13,31 @@ import (
|
||||
|
||||
func builtinCount(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
return ast.IntNumberTerm(len(a)).Value, nil
|
||||
case *ast.Array:
|
||||
return ast.IntNumberTerm(a.Len()).Value, nil
|
||||
case ast.Object:
|
||||
return ast.IntNumberTerm(a.Len()).Value, nil
|
||||
case ast.Set:
|
||||
return ast.IntNumberTerm(a.Len()).Value, nil
|
||||
case ast.String:
|
||||
return ast.IntNumberTerm(len(a)).Value, nil
|
||||
return ast.IntNumberTerm(len([]rune(a))).Value, nil
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set")
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set", "string")
|
||||
}
|
||||
|
||||
func builtinSum(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
sum := big.NewFloat(0)
|
||||
for _, x := range a {
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
return builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
|
||||
}
|
||||
return builtins.FloatToNumber(sum), nil
|
||||
return nil
|
||||
})
|
||||
return builtins.FloatToNumber(sum), err
|
||||
case ast.Set:
|
||||
sum := big.NewFloat(0)
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
@@ -54,16 +55,17 @@ func builtinSum(a ast.Value) (ast.Value, error) {
|
||||
|
||||
func builtinProduct(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
product := big.NewFloat(1)
|
||||
for _, x := range a {
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
return builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
|
||||
}
|
||||
return builtins.FloatToNumber(product), nil
|
||||
return nil
|
||||
})
|
||||
return builtins.FloatToNumber(product), err
|
||||
case ast.Set:
|
||||
product := big.NewFloat(1)
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
@@ -81,16 +83,16 @@ func builtinProduct(a ast.Value) (ast.Value, error) {
|
||||
|
||||
func builtinMax(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
if len(a) == 0 {
|
||||
case *ast.Array:
|
||||
if a.Len() == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
var max = ast.Value(ast.Null{})
|
||||
for i := range a {
|
||||
if ast.Compare(max, a[i].Value) <= 0 {
|
||||
max = a[i].Value
|
||||
a.Foreach(func(x *ast.Term) {
|
||||
if ast.Compare(max, x.Value) <= 0 {
|
||||
max = x.Value
|
||||
}
|
||||
}
|
||||
})
|
||||
return max, nil
|
||||
case ast.Set:
|
||||
if a.Len() == 0 {
|
||||
@@ -110,16 +112,16 @@ func builtinMax(a ast.Value) (ast.Value, error) {
|
||||
|
||||
func builtinMin(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
if len(a) == 0 {
|
||||
case *ast.Array:
|
||||
if a.Len() == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
min := a[0].Value
|
||||
for i := range a {
|
||||
if ast.Compare(min, a[i].Value) >= 0 {
|
||||
min = a[i].Value
|
||||
min := a.Elem(0).Value
|
||||
a.Foreach(func(x *ast.Term) {
|
||||
if ast.Compare(min, x.Value) >= 0 {
|
||||
min = x.Value
|
||||
}
|
||||
}
|
||||
})
|
||||
return min, nil
|
||||
case ast.Set:
|
||||
if a.Len() == 0 {
|
||||
@@ -146,7 +148,7 @@ func builtinMin(a ast.Value) (ast.Value, error) {
|
||||
|
||||
func builtinSort(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
return a.Sorted(), nil
|
||||
case ast.Set:
|
||||
return a.Sorted(), nil
|
||||
@@ -159,20 +161,24 @@ func builtinAll(a ast.Value) (ast.Value, error) {
|
||||
case ast.Set:
|
||||
res := true
|
||||
match := ast.BooleanTerm(true)
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
if !term.Equal(match) {
|
||||
val.Until(func(term *ast.Term) bool {
|
||||
if !match.Equal(term) {
|
||||
res = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
return ast.Boolean(res), nil
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
res := true
|
||||
match := ast.BooleanTerm(true)
|
||||
for _, term := range val {
|
||||
if !term.Equal(match) {
|
||||
val.Until(func(term *ast.Term) bool {
|
||||
if !match.Equal(term) {
|
||||
res = false
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return ast.Boolean(res), nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
@@ -182,28 +188,64 @@ func builtinAll(a ast.Value) (ast.Value, error) {
|
||||
func builtinAny(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Set:
|
||||
res := false
|
||||
match := ast.BooleanTerm(true)
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
if term.Equal(match) {
|
||||
res = true
|
||||
}
|
||||
})
|
||||
res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true))
|
||||
return ast.Boolean(res), nil
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
res := false
|
||||
match := ast.BooleanTerm(true)
|
||||
for _, term := range val {
|
||||
if term.Equal(match) {
|
||||
val.Until(func(term *ast.Term) bool {
|
||||
if match.Equal(term) {
|
||||
res = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return ast.Boolean(res), nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
}
|
||||
|
||||
func builtinMember(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
containee := args[0]
|
||||
switch c := args[1].Value.(type) {
|
||||
case ast.Set:
|
||||
return iter(ast.BooleanTerm(c.Contains(containee)))
|
||||
case *ast.Array:
|
||||
ret := false
|
||||
c.Until(func(v *ast.Term) bool {
|
||||
if v.Value.Compare(containee.Value) == 0 {
|
||||
ret = true
|
||||
}
|
||||
return ret
|
||||
})
|
||||
return iter(ast.BooleanTerm(ret))
|
||||
case ast.Object:
|
||||
ret := false
|
||||
c.Until(func(_, v *ast.Term) bool {
|
||||
if v.Value.Compare(containee.Value) == 0 {
|
||||
ret = true
|
||||
}
|
||||
return ret
|
||||
})
|
||||
return iter(ast.BooleanTerm(ret))
|
||||
}
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
func builtinMemberWithKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
key, val := args[0], args[1]
|
||||
switch c := args[2].Value.(type) {
|
||||
case interface{ Get(*ast.Term) *ast.Term }:
|
||||
ret := false
|
||||
if act := c.Get(key); act != nil {
|
||||
ret = act.Value.Compare(val.Value) == 0
|
||||
}
|
||||
return iter(ast.BooleanTerm(ret))
|
||||
}
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount)
|
||||
RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum)
|
||||
@@ -213,4 +255,6 @@ func init() {
|
||||
RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort)
|
||||
RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny)
|
||||
RegisterFunctionalBuiltin1(ast.All.Name, builtinAll)
|
||||
RegisterBuiltinFunc(ast.Member.Name, builtinMember)
|
||||
RegisterBuiltinFunc(ast.MemberWithKey.Name, builtinMemberWithKey)
|
||||
}
|
||||
|
||||
30
vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go
generated
vendored
30
vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go
generated
vendored
@@ -32,6 +32,28 @@ func arithRound(a *big.Float) (*big.Float, error) {
|
||||
return new(big.Float).SetInt(i), nil
|
||||
}
|
||||
|
||||
func arithCeil(a *big.Float) (*big.Float, error) {
|
||||
i, _ := a.Int(nil)
|
||||
f := new(big.Float).SetInt(i)
|
||||
|
||||
if f.Signbit() || a.Cmp(f) == 0 {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
return new(big.Float).Add(f, big.NewFloat(1.0)), nil
|
||||
}
|
||||
|
||||
func arithFloor(a *big.Float) (*big.Float, error) {
|
||||
i, _ := a.Int(nil)
|
||||
f := new(big.Float).SetInt(i)
|
||||
|
||||
if !f.Signbit() || a.Cmp(f) == 0 {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
return new(big.Float).Sub(f, big.NewFloat(1.0)), nil
|
||||
}
|
||||
|
||||
func arithPlus(a, b *big.Float) (*big.Float, error) {
|
||||
return new(big.Float).Add(a, b), nil
|
||||
}
|
||||
@@ -115,7 +137,11 @@ func builtinMinus(a, b ast.Value) (ast.Value, error) {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "number", "set")
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "number", "set")
|
||||
if ok2 {
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "set")
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "number")
|
||||
}
|
||||
|
||||
func builtinRem(a, b ast.Value) (ast.Value, error) {
|
||||
@@ -148,6 +174,8 @@ func builtinRem(a, b ast.Value) (ast.Value, error) {
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.Abs.Name, builtinArithArity1(arithAbs))
|
||||
RegisterFunctionalBuiltin1(ast.Round.Name, builtinArithArity1(arithRound))
|
||||
RegisterFunctionalBuiltin1(ast.Ceil.Name, builtinArithArity1(arithCeil))
|
||||
RegisterFunctionalBuiltin1(ast.Floor.Name, builtinArithArity1(arithFloor))
|
||||
RegisterFunctionalBuiltin2(ast.Plus.Name, builtinArithArity2(arithPlus))
|
||||
RegisterFunctionalBuiltin2(ast.Minus.Name, builtinMinus)
|
||||
RegisterFunctionalBuiltin2(ast.Multiply.Name, builtinArithArity2(arithMultiply))
|
||||
|
||||
51
vendor/github.com/open-policy-agent/opa/topdown/array.go
generated
vendored
51
vendor/github.com/open-policy-agent/opa/topdown/array.go
generated
vendored
@@ -20,17 +20,20 @@ func builtinArrayConcat(a, b ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrC := make(ast.Array, 0, len(arrA)+len(arrB))
|
||||
arrC := make([]*ast.Term, arrA.Len()+arrB.Len())
|
||||
|
||||
for _, elemA := range arrA {
|
||||
arrC = append(arrC, elemA)
|
||||
}
|
||||
i := 0
|
||||
arrA.Foreach(func(elemA *ast.Term) {
|
||||
arrC[i] = elemA
|
||||
i++
|
||||
})
|
||||
|
||||
for _, elemB := range arrB {
|
||||
arrC = append(arrC, elemB)
|
||||
}
|
||||
arrB.Foreach(func(elemB *ast.Term) {
|
||||
arrC[i] = elemB
|
||||
i++
|
||||
})
|
||||
|
||||
return arrC, nil
|
||||
return ast.NewArray(arrC...), nil
|
||||
}
|
||||
|
||||
func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
|
||||
@@ -49,27 +52,43 @@ func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return empty array if bounds cannot be clamped sensibly.
|
||||
if (startIndex >= stopIndex) || (startIndex <= 0 && stopIndex <= 0) {
|
||||
return arr[0:0], nil
|
||||
// Clamp stopIndex to avoid out-of-range errors. If negative, clamp to zero.
|
||||
// Otherwise, clamp to length of array.
|
||||
if stopIndex < 0 {
|
||||
stopIndex = 0
|
||||
} else if stopIndex > arr.Len() {
|
||||
stopIndex = arr.Len()
|
||||
}
|
||||
|
||||
// Clamp bounds to avoid out-of-range errors.
|
||||
// Clamp startIndex to avoid out-of-range errors. If negative, clamp to zero.
|
||||
// Otherwise, clamp to stopIndex to avoid to avoid cases like arr[1:0].
|
||||
if startIndex < 0 {
|
||||
startIndex = 0
|
||||
} else if startIndex > stopIndex {
|
||||
startIndex = stopIndex
|
||||
}
|
||||
|
||||
if stopIndex > len(arr) {
|
||||
stopIndex = len(arr)
|
||||
return arr.Slice(startIndex, stopIndex), nil
|
||||
}
|
||||
|
||||
func builtinArrayReverse(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arrb := arr[startIndex:stopIndex]
|
||||
length := arr.Len()
|
||||
reversedArr := make([]*ast.Term, length)
|
||||
|
||||
return arrb, nil
|
||||
for index := 0; index < length; index++ {
|
||||
reversedArr[index] = arr.Elem(length - index - 1)
|
||||
}
|
||||
|
||||
return iter(ast.ArrayTerm(reversedArr...))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.ArrayConcat.Name, builtinArrayConcat)
|
||||
RegisterFunctionalBuiltin3(ast.ArraySlice.Name, builtinArraySlice)
|
||||
RegisterBuiltinFunc(ast.ArrayReverse.Name, builtinArrayReverse)
|
||||
}
|
||||
|
||||
53
vendor/github.com/open-policy-agent/opa/topdown/bindings.go
generated
vendored
53
vendor/github.com/open-policy-agent/opa/topdown/bindings.go
generated
vendored
@@ -12,9 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type undo struct {
|
||||
k *ast.Term
|
||||
u *bindings
|
||||
next *undo
|
||||
k *ast.Term
|
||||
u *bindings
|
||||
}
|
||||
|
||||
func (u *undo) Undo() {
|
||||
@@ -27,7 +26,6 @@ func (u *undo) Undo() {
|
||||
return
|
||||
}
|
||||
u.u.delete(u.k)
|
||||
u.next.Undo()
|
||||
}
|
||||
|
||||
type bindings struct {
|
||||
@@ -88,13 +86,16 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
|
||||
return next.plugNamespaced(b, caller)
|
||||
}
|
||||
return u.namespaceVar(b, caller)
|
||||
case ast.Array:
|
||||
cpy := *a
|
||||
arr := make(ast.Array, len(v))
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = u.plugNamespaced(v[i], caller)
|
||||
case *ast.Array:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy.Value = arr
|
||||
cpy := *a
|
||||
arr := make([]*ast.Term, v.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = u.plugNamespaced(v.Elem(i), caller)
|
||||
}
|
||||
cpy.Value = ast.NewArray(arr...)
|
||||
return &cpy
|
||||
case ast.Object:
|
||||
if a.IsGround() {
|
||||
@@ -106,6 +107,9 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
|
||||
})
|
||||
return &cpy
|
||||
case ast.Set:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
|
||||
return u.plugNamespaced(x, caller), nil
|
||||
@@ -123,19 +127,20 @@ func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
|
||||
return a
|
||||
}
|
||||
|
||||
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings) *undo {
|
||||
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings, und *undo) {
|
||||
u.values.Put(a, value{
|
||||
u: other,
|
||||
v: b,
|
||||
})
|
||||
return &undo{a, u, nil}
|
||||
und.k = a
|
||||
und.u = u
|
||||
}
|
||||
|
||||
func (u *bindings) apply(a *ast.Term) (*ast.Term, *bindings) {
|
||||
// Early exit for non-var terms. Only vars are bound in the binding list,
|
||||
// so the lookup below will always fail for non-var terms. In some cases,
|
||||
// the lookup may be expensive as it has to hash the term (which for large
|
||||
// inputs can be costly.)
|
||||
// inputs can be costly).
|
||||
_, ok := a.Value.(ast.Var)
|
||||
if !ok {
|
||||
return a, u
|
||||
@@ -242,13 +247,16 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
|
||||
switch v := a.Value.(type) {
|
||||
case ast.Var:
|
||||
return vis.b.namespaceVar(a, vis.caller)
|
||||
case ast.Array:
|
||||
cpy := *a
|
||||
arr := make(ast.Array, len(v))
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = vis.namespaceTerm(v[i])
|
||||
case *ast.Array:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy.Value = arr
|
||||
cpy := *a
|
||||
arr := make([]*ast.Term, v.Len())
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = vis.namespaceTerm(v.Elem(i))
|
||||
}
|
||||
cpy.Value = ast.NewArray(arr...)
|
||||
return &cpy
|
||||
case ast.Object:
|
||||
if a.IsGround() {
|
||||
@@ -260,6 +268,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
|
||||
})
|
||||
return &cpy
|
||||
case ast.Set:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
|
||||
return vis.namespaceTerm(x), nil
|
||||
@@ -279,7 +290,9 @@ func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
|
||||
|
||||
const maxLinearScan = 16
|
||||
|
||||
// bindingsArrayHashMap uses an array with linear scan instead of a hash map for smaller # of entries. Hash maps start to show off their performance advantage only after 16 keys.
|
||||
// bindingsArrayHashMap uses an array with linear scan instead
|
||||
// of a hash map for smaller # of entries. Hash maps start to
|
||||
// show off their performance advantage only after 16 keys.
|
||||
type bindingsArrayHashmap struct {
|
||||
n int // Entries in the array.
|
||||
a *[maxLinearScan]bindingArrayKeyValue
|
||||
|
||||
64
vendor/github.com/open-policy-agent/opa/topdown/builtins.go
generated
vendored
64
vendor/github.com/open-policy-agent/opa/topdown/builtins.go
generated
vendored
@@ -6,10 +6,17 @@ package topdown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/metrics"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/topdown/cache"
|
||||
"github.com/open-policy-agent/opa/topdown/print"
|
||||
"github.com/open-policy-agent/opa/tracing"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -28,14 +35,25 @@ type (
|
||||
// BuiltinContext contains context from the evaluator that may be used by
|
||||
// built-in functions.
|
||||
BuiltinContext struct {
|
||||
Context context.Context // request context that was passed when query started
|
||||
Cancel Cancel // atomic value that signals evaluation to halt
|
||||
Runtime *ast.Term // runtime information on the OPA instance
|
||||
Cache builtins.Cache // built-in function state cache
|
||||
Location *ast.Location // location of built-in call
|
||||
Tracers []Tracer // tracer objects for trace() built-in function
|
||||
QueryID uint64 // identifies query being evaluated
|
||||
ParentID uint64 // identifies parent of query being evaluated
|
||||
Context context.Context // request context that was passed when query started
|
||||
Metrics metrics.Metrics // metrics registry for recording built-in specific metrics
|
||||
Seed io.Reader // randomization source
|
||||
Time *ast.Term // wall clock time
|
||||
Cancel Cancel // atomic value that signals evaluation to halt
|
||||
Runtime *ast.Term // runtime information on the OPA instance
|
||||
Cache builtins.Cache // built-in function state cache
|
||||
InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache
|
||||
NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state
|
||||
Location *ast.Location // location of built-in call
|
||||
Tracers []Tracer // Deprecated: Use QueryTracers instead
|
||||
QueryTracers []QueryTracer // tracer objects for trace() built-in function
|
||||
TraceEnabled bool // indicates whether tracing is enabled for the evaluation
|
||||
QueryID uint64 // identifies query being evaluated
|
||||
ParentID uint64 // identifies parent of query being evaluated
|
||||
PrintHook print.Hook // provides callback function to use for printing
|
||||
DistributedTracingOpts tracing.Options // options to be used by distributed tracing.
|
||||
rand *rand.Rand // randomization source for non-security-sensitive operations
|
||||
Capabilities *ast.Capabilities
|
||||
}
|
||||
|
||||
// BuiltinFunc defines an interface for implementing built-in functions.
|
||||
@@ -46,6 +64,25 @@ type (
|
||||
BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error
|
||||
)
|
||||
|
||||
// Rand returns a random number generator based on the Seed for this built-in
|
||||
// context. The random number will be re-used across multiple calls to this
|
||||
// function. If a random number generator cannot be created, an error is
|
||||
// returned.
|
||||
func (bctx *BuiltinContext) Rand() (*rand.Rand, error) {
|
||||
|
||||
if bctx.rand != nil {
|
||||
return bctx.rand, nil
|
||||
}
|
||||
|
||||
seed, err := readInt64(bctx.Seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bctx.rand = rand.New(rand.NewSource(seed))
|
||||
return bctx.rand, nil
|
||||
}
|
||||
|
||||
// RegisterBuiltinFunc adds a new built-in function to the evaluation engine.
|
||||
func RegisterBuiltinFunc(name string, f BuiltinFunc) {
|
||||
builtinFunctions[name] = builtinErrorWrapper(name, f)
|
||||
@@ -142,7 +179,7 @@ func handleBuiltinErr(name string, loc *ast.Location, err error) error {
|
||||
switch err := err.(type) {
|
||||
case BuiltinEmpty:
|
||||
return nil
|
||||
case *Error:
|
||||
case *Error, Halt:
|
||||
return err
|
||||
case builtins.ErrOperand:
|
||||
return &Error{
|
||||
@@ -158,3 +195,12 @@ func handleBuiltinErr(name string, loc *ast.Location, err error) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readInt64(r io.Reader) (int64, error) {
|
||||
bs := make([]byte, 8)
|
||||
n, err := io.ReadFull(r, bs)
|
||||
if n != len(bs) || err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(binary.BigEndian.Uint64(bs)), nil
|
||||
}
|
||||
|
||||
129
vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go
generated
vendored
129
vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go
generated
vendored
@@ -6,11 +6,13 @@
|
||||
package builtins
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
// Cache defines the built-in cache used by the top-down evaluation. The keys
|
||||
@@ -28,6 +30,85 @@ func (c Cache) Get(k interface{}) (interface{}, bool) {
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// We use an ast.Object for the cached keys/values because a naive
|
||||
// map[ast.Value]ast.Value will not correctly detect value equality of
|
||||
// the member keys.
|
||||
type NDBCache map[string]ast.Object
|
||||
|
||||
func (c NDBCache) AsValue() ast.Value {
|
||||
out := ast.NewObject()
|
||||
for bname, obj := range c {
|
||||
out.Insert(ast.StringTerm(bname), ast.NewTerm(obj))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Put updates the cache for the named built-in.
|
||||
// Automatically creates the 2-level hierarchy as needed.
|
||||
func (c NDBCache) Put(name string, k, v ast.Value) {
|
||||
if _, ok := c[name]; !ok {
|
||||
c[name] = ast.NewObject()
|
||||
}
|
||||
c[name].Insert(ast.NewTerm(k), ast.NewTerm(v))
|
||||
}
|
||||
|
||||
// Get returns the cached value for k for the named builtin.
|
||||
func (c NDBCache) Get(name string, k ast.Value) (ast.Value, bool) {
|
||||
if m, ok := c[name]; ok {
|
||||
v := m.Get(ast.NewTerm(k))
|
||||
if v != nil {
|
||||
return v.Value, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Convenience functions for serializing the data structure.
|
||||
func (c NDBCache) MarshalJSON() ([]byte, error) {
|
||||
v, err := ast.JSON(c.AsValue())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func (c *NDBCache) UnmarshalJSON(data []byte) error {
|
||||
out := map[string]ast.Object{}
|
||||
var incoming interface{}
|
||||
|
||||
// Note: We use util.Unmarshal instead of json.Unmarshal to get
|
||||
// correct deserialization of number types.
|
||||
err := util.Unmarshal(data, &incoming)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert interface types back into ast.Value types.
|
||||
nestedObject, err := ast.InterfaceToValue(incoming)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Reconstruct NDBCache from nested ast.Object structure.
|
||||
if source, ok := nestedObject.(ast.Object); ok {
|
||||
err = source.Iter(func(k, v *ast.Term) error {
|
||||
if obj, ok := v.Value.(ast.Object); ok {
|
||||
out[string(k.Value.(ast.String))] = obj
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected Object, got other Value type in conversion")
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
*c = out
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrOperand represents an invalid operand has been passed to a built-in
|
||||
// function. Built-ins should return ErrOperand to indicate a type error has
|
||||
// occurred.
|
||||
@@ -149,10 +230,10 @@ func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {
|
||||
|
||||
// ArrayOperand converts x to an array. If the cast fails, a descriptive
|
||||
// error is returned.
|
||||
func ArrayOperand(x ast.Value, pos int) (ast.Array, error) {
|
||||
a, ok := x.(ast.Array)
|
||||
func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) {
|
||||
a, ok := x.(*ast.Array)
|
||||
if !ok {
|
||||
return nil, NewOperandTypeErr(pos, x, "array")
|
||||
return ast.NewArray(), NewOperandTypeErr(pos, x, "array")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
@@ -168,7 +249,7 @@ func NumberToFloat(n ast.Number) *big.Float {
|
||||
|
||||
// FloatToNumber converts f to a number.
|
||||
func FloatToNumber(f *big.Float) ast.Number {
|
||||
return ast.Number(f.String())
|
||||
return ast.Number(f.Text('g', -1))
|
||||
}
|
||||
|
||||
// NumberToInt converts n to a big int.
|
||||
@@ -189,23 +270,30 @@ func IntToNumber(i *big.Int) ast.Number {
|
||||
|
||||
// StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func StringSliceOperand(x ast.Value, pos int) ([]string, error) {
|
||||
a, err := ArrayOperand(x, pos)
|
||||
if err != nil {
|
||||
func StringSliceOperand(a ast.Value, pos int) ([]string, error) {
|
||||
type iterable interface {
|
||||
Iter(func(*ast.Term) error) error
|
||||
Len() int
|
||||
}
|
||||
|
||||
strs, ok := a.(iterable)
|
||||
if !ok {
|
||||
return nil, NewOperandTypeErr(pos, a, "array", "set")
|
||||
}
|
||||
|
||||
var outStrs = make([]string, 0, strs.Len())
|
||||
if err := strs.Iter(func(x *ast.Term) error {
|
||||
s, ok := x.Value.(ast.String)
|
||||
if !ok {
|
||||
return NewOperandElementErr(pos, a, x.Value, "string")
|
||||
}
|
||||
outStrs = append(outStrs, string(s))
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f = make([]string, len(a))
|
||||
for k, b := range a {
|
||||
c, ok := b.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, NewOperandElementErr(pos, x, b.Value, "[]string")
|
||||
}
|
||||
|
||||
f[k] = string(c)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
return outStrs, nil
|
||||
}
|
||||
|
||||
// RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is
|
||||
@@ -216,8 +304,9 @@ func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f = make([]rune, len(a))
|
||||
for k, b := range a {
|
||||
var f = make([]rune, a.Len())
|
||||
for k := 0; k < a.Len(); k++ {
|
||||
b := a.Elem(k)
|
||||
c, ok := b.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, NewOperandElementErr(pos, x, b.Value, "string")
|
||||
|
||||
125
vendor/github.com/open-policy-agent/opa/topdown/cache.go
generated
vendored
125
vendor/github.com/open-policy-agent/opa/topdown/cache.go
generated
vendored
@@ -164,3 +164,128 @@ func (s *refStack) Prefixed(ref ast.Ref) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type comprehensionCache struct {
|
||||
stack []map[*ast.Term]*comprehensionCacheElem
|
||||
}
|
||||
|
||||
type comprehensionCacheElem struct {
|
||||
value *ast.Term
|
||||
children *util.HashMap
|
||||
}
|
||||
|
||||
func newComprehensionCache() *comprehensionCache {
|
||||
cache := &comprehensionCache{}
|
||||
cache.Push()
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *comprehensionCache) Push() {
|
||||
c.stack = append(c.stack, map[*ast.Term]*comprehensionCacheElem{})
|
||||
}
|
||||
|
||||
func (c *comprehensionCache) Pop() {
|
||||
c.stack = c.stack[:len(c.stack)-1]
|
||||
}
|
||||
|
||||
func (c *comprehensionCache) Elem(t *ast.Term) (*comprehensionCacheElem, bool) {
|
||||
elem, ok := c.stack[len(c.stack)-1][t]
|
||||
return elem, ok
|
||||
}
|
||||
|
||||
func (c *comprehensionCache) Set(t *ast.Term, elem *comprehensionCacheElem) {
|
||||
c.stack[len(c.stack)-1][t] = elem
|
||||
}
|
||||
|
||||
func newComprehensionCacheElem() *comprehensionCacheElem {
|
||||
return &comprehensionCacheElem{children: newComprehensionCacheHashMap()}
|
||||
}
|
||||
|
||||
func (c *comprehensionCacheElem) Get(key []*ast.Term) *ast.Term {
|
||||
node := c
|
||||
for i := 0; i < len(key); i++ {
|
||||
x, ok := node.children.Get(key[i])
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
node = x.(*comprehensionCacheElem)
|
||||
}
|
||||
return node.value
|
||||
}
|
||||
|
||||
func (c *comprehensionCacheElem) Put(key []*ast.Term, value *ast.Term) {
|
||||
node := c
|
||||
for i := 0; i < len(key); i++ {
|
||||
x, ok := node.children.Get(key[i])
|
||||
if ok {
|
||||
node = x.(*comprehensionCacheElem)
|
||||
} else {
|
||||
next := newComprehensionCacheElem()
|
||||
node.children.Put(key[i], next)
|
||||
node = next
|
||||
}
|
||||
}
|
||||
node.value = value
|
||||
}
|
||||
|
||||
func newComprehensionCacheHashMap() *util.HashMap {
|
||||
return util.NewHashMap(func(a, b util.T) bool {
|
||||
return a.(*ast.Term).Equal(b.(*ast.Term))
|
||||
}, func(x util.T) int {
|
||||
return x.(*ast.Term).Hash()
|
||||
})
|
||||
}
|
||||
|
||||
type functionMocksStack struct {
|
||||
stack []*functionMocksElem
|
||||
}
|
||||
|
||||
type functionMocksElem []frame
|
||||
|
||||
type frame map[string]*ast.Term
|
||||
|
||||
func newFunctionMocksStack() *functionMocksStack {
|
||||
stack := &functionMocksStack{}
|
||||
stack.Push()
|
||||
return stack
|
||||
}
|
||||
|
||||
func newFunctionMocksElem() *functionMocksElem {
|
||||
return &functionMocksElem{}
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) Push() {
|
||||
s.stack = append(s.stack, newFunctionMocksElem())
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) Pop() {
|
||||
s.stack = s.stack[:len(s.stack)-1]
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) PopPairs() {
|
||||
current := s.stack[len(s.stack)-1]
|
||||
*current = (*current)[:len(*current)-1]
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) PutPairs(mocks [][2]*ast.Term) {
|
||||
el := frame{}
|
||||
for i := range mocks {
|
||||
el[mocks[i][0].Value.String()] = mocks[i][1]
|
||||
}
|
||||
s.Put(el)
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) Put(el frame) {
|
||||
current := s.stack[len(s.stack)-1]
|
||||
*current = append(*current, el)
|
||||
}
|
||||
|
||||
func (s *functionMocksStack) Get(f ast.Ref) (*ast.Term, bool) {
|
||||
current := *s.stack[len(s.stack)-1]
|
||||
for i := len(current) - 1; i >= 0; i-- {
|
||||
if r, ok := current[i][f.String()]; ok {
|
||||
return r, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
167
vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go
generated
vendored
Normal file
167
vendor/github.com/open-policy-agent/opa/topdown/cache/cache.go
generated
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package cache defines the inter-query cache interface that can cache data across queries
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxSizeBytes = int64(0) // unlimited
|
||||
)
|
||||
|
||||
// Config represents the configuration of the inter-query cache.
|
||||
type Config struct {
|
||||
InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"`
|
||||
}
|
||||
|
||||
// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize.
|
||||
type InterQueryBuiltinCacheConfig struct {
|
||||
MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"`
|
||||
}
|
||||
|
||||
// ParseCachingConfig returns the config for the inter-query cache.
|
||||
func ParseCachingConfig(raw []byte) (*Config, error) {
|
||||
if raw == nil {
|
||||
maxSize := new(int64)
|
||||
*maxSize = defaultMaxSizeBytes
|
||||
return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize}}, nil
|
||||
}
|
||||
|
||||
var config Config
|
||||
|
||||
if err := util.Unmarshal(raw, &config); err == nil {
|
||||
if err = config.validateAndInjectDefaults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (c *Config) validateAndInjectDefaults() error {
|
||||
if c.InterQueryBuiltinCache.MaxSizeBytes == nil {
|
||||
maxSize := new(int64)
|
||||
*maxSize = defaultMaxSizeBytes
|
||||
c.InterQueryBuiltinCache.MaxSizeBytes = maxSize
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InterQueryCacheValue defines the interface for the data that the inter-query cache holds.
|
||||
type InterQueryCacheValue interface {
|
||||
SizeInBytes() int64
|
||||
}
|
||||
|
||||
// InterQueryCache defines the interface for the inter-query cache.
|
||||
type InterQueryCache interface {
|
||||
Get(key ast.Value) (value InterQueryCacheValue, found bool)
|
||||
Insert(key ast.Value, value InterQueryCacheValue) int
|
||||
Delete(key ast.Value)
|
||||
UpdateConfig(config *Config)
|
||||
}
|
||||
|
||||
// NewInterQueryCache returns a new inter-query cache.
|
||||
func NewInterQueryCache(config *Config) InterQueryCache {
|
||||
return &cache{
|
||||
items: map[string]InterQueryCacheValue{},
|
||||
usage: 0,
|
||||
config: config,
|
||||
l: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
type cache struct {
|
||||
items map[string]InterQueryCacheValue
|
||||
usage int64
|
||||
config *Config
|
||||
l *list.List
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
// Insert inserts a key k into the cache with value v.
|
||||
func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) {
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
return c.unsafeInsert(k, v)
|
||||
}
|
||||
|
||||
// Get returns the value in the cache for k.
|
||||
func (c *cache) Get(k ast.Value) (InterQueryCacheValue, bool) {
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
return c.unsafeGet(k)
|
||||
}
|
||||
|
||||
// Delete deletes the value in the cache for k.
|
||||
func (c *cache) Delete(k ast.Value) {
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
c.unsafeDelete(k)
|
||||
}
|
||||
|
||||
func (c *cache) UpdateConfig(config *Config) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
c.config = config
|
||||
}
|
||||
|
||||
func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) {
|
||||
size := v.SizeInBytes()
|
||||
limit := c.maxSizeBytes()
|
||||
|
||||
if limit > 0 {
|
||||
if size > limit {
|
||||
dropped++
|
||||
return dropped
|
||||
}
|
||||
|
||||
for key := c.l.Front(); key != nil && (c.usage+size > limit); key = c.l.Front() {
|
||||
dropKey := key.Value.(ast.Value)
|
||||
c.unsafeDelete(dropKey)
|
||||
c.l.Remove(key)
|
||||
dropped++
|
||||
}
|
||||
}
|
||||
|
||||
c.items[k.String()] = v
|
||||
c.l.PushBack(k)
|
||||
c.usage += size
|
||||
return dropped
|
||||
}
|
||||
|
||||
func (c *cache) unsafeGet(k ast.Value) (InterQueryCacheValue, bool) {
|
||||
value, ok := c.items[k.String()]
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func (c *cache) unsafeDelete(k ast.Value) {
|
||||
value, ok := c.unsafeGet(k)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.usage -= int64(value.SizeInBytes())
|
||||
delete(c.items, k.String())
|
||||
}
|
||||
|
||||
func (c *cache) maxSizeBytes() int64 {
|
||||
if c.config == nil {
|
||||
return defaultMaxSizeBytes
|
||||
}
|
||||
return *c.config.InterQueryBuiltinCache.MaxSizeBytes
|
||||
}
|
||||
14
vendor/github.com/open-policy-agent/opa/topdown/casts.go
generated
vendored
14
vendor/github.com/open-policy-agent/opa/topdown/casts.go
generated
vendored
@@ -35,16 +35,16 @@ func builtinToNumber(a ast.Value) (ast.Value, error) {
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToArray(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
return val, nil
|
||||
case ast.Set:
|
||||
arr := make(ast.Array, val.Len())
|
||||
arr := make([]*ast.Term, val.Len())
|
||||
i := 0
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
arr[i] = term
|
||||
i++
|
||||
})
|
||||
return arr, nil
|
||||
return ast.NewArray(arr...), nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
@@ -53,8 +53,12 @@ func builtinToArray(a ast.Value) (ast.Value, error) {
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToSet(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Array:
|
||||
return ast.NewSet(val...), nil
|
||||
case *ast.Array:
|
||||
s := ast.NewSet()
|
||||
val.Foreach(func(v *ast.Term) {
|
||||
s.Add(v)
|
||||
})
|
||||
return s, nil
|
||||
case ast.Set:
|
||||
return val, nil
|
||||
default:
|
||||
|
||||
258
vendor/github.com/open-policy-agent/opa/topdown/cidr.go
generated
vendored
258
vendor/github.com/open-policy-agent/opa/topdown/cidr.go
generated
vendored
@@ -1,11 +1,15 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sort"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
cidrMerge "github.com/open-policy-agent/opa/internal/cidr/merge"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
@@ -69,7 +73,7 @@ func builtinNetCIDRIntersects(a, b ast.Value) (ast.Value, error) {
|
||||
}
|
||||
|
||||
// If either net contains the others starting IP they are overlapping
|
||||
cidrsOverlap := (cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP))
|
||||
cidrsOverlap := cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP)
|
||||
|
||||
return ast.Boolean(cidrsOverlap), nil
|
||||
}
|
||||
@@ -97,7 +101,8 @@ func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
|
||||
return nil, fmt.Errorf("not a valid textual representation of an IP address or CIDR: %s", string(bStr))
|
||||
}
|
||||
|
||||
// We can determine if cidr A contains cidr B iff A contains the starting address of B and the last address in B.
|
||||
// We can determine if cidr A contains cidr B if and only if A contains
|
||||
// the starting address of B and the last address in B.
|
||||
cidrContained := false
|
||||
if cidrnetA.Contains(cidrnetB.IP) {
|
||||
// Only spend time calculating the last IP if the starting IP is already verified to be in cidr A
|
||||
@@ -111,6 +116,75 @@ func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
|
||||
return ast.Boolean(cidrContained), nil
|
||||
}
|
||||
|
||||
var errNetCIDRContainsMatchElementType = errors.New("element must be string or non-empty array")
|
||||
|
||||
func getCIDRMatchTerm(a *ast.Term) (*ast.Term, error) {
|
||||
switch v := a.Value.(type) {
|
||||
case ast.String:
|
||||
return a, nil
|
||||
case *ast.Array:
|
||||
if v.Len() == 0 {
|
||||
return nil, errNetCIDRContainsMatchElementType
|
||||
}
|
||||
return v.Elem(0), nil
|
||||
default:
|
||||
return nil, errNetCIDRContainsMatchElementType
|
||||
}
|
||||
}
|
||||
|
||||
func evalNetCIDRContainsMatchesOperand(operand int, a *ast.Term, iter func(cidr, index *ast.Term) error) error {
|
||||
switch v := a.Value.(type) {
|
||||
case ast.String:
|
||||
return iter(a, a)
|
||||
case *ast.Array:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
cidr, err := getCIDRMatchTerm(v.Elem(i))
|
||||
if err != nil {
|
||||
return fmt.Errorf("operand %v: %v", operand, err)
|
||||
}
|
||||
if err := iter(cidr, ast.IntNumberTerm(i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case ast.Set:
|
||||
return v.Iter(func(x *ast.Term) error {
|
||||
cidr, err := getCIDRMatchTerm(x)
|
||||
if err != nil {
|
||||
return fmt.Errorf("operand %v: %v", operand, err)
|
||||
}
|
||||
return iter(cidr, x)
|
||||
})
|
||||
case ast.Object:
|
||||
return v.Iter(func(k, v *ast.Term) error {
|
||||
cidr, err := getCIDRMatchTerm(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("operand %v: %v", operand, err)
|
||||
}
|
||||
return iter(cidr, k)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func builtinNetCIDRContainsMatches(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result := ast.NewSet()
|
||||
err := evalNetCIDRContainsMatchesOperand(1, args[0], func(cidr1 *ast.Term, index1 *ast.Term) error {
|
||||
return evalNetCIDRContainsMatchesOperand(2, args[1], func(cidr2 *ast.Term, index2 *ast.Term) error {
|
||||
if v, err := builtinNetCIDRContains(cidr1.Value, cidr2.Value); err != nil {
|
||||
return err
|
||||
} else if vb, ok := v.(ast.Boolean); ok && bool(vb) {
|
||||
result.Add(ast.ArrayTerm(index1, index2))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
if err == nil {
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
s, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
@@ -128,9 +202,11 @@ func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*
|
||||
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
|
||||
|
||||
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
|
||||
return &Error{
|
||||
Code: CancelErr,
|
||||
Message: "net.cidr_expand: timed out before generating all IP addresses",
|
||||
return Halt{
|
||||
Err: &Error{
|
||||
Code: CancelErr,
|
||||
Message: "net.cidr_expand: timed out before generating all IP addresses",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,6 +216,176 @@ func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
type cidrBlockRange struct {
|
||||
First *net.IP
|
||||
Last *net.IP
|
||||
Network *net.IPNet
|
||||
}
|
||||
|
||||
type cidrBlockRanges []*cidrBlockRange
|
||||
|
||||
// Implement Sort interface
|
||||
func (c cidrBlockRanges) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
|
||||
func (c cidrBlockRanges) Swap(i, j int) {
|
||||
c[i], c[j] = c[j], c[i]
|
||||
}
|
||||
|
||||
func (c cidrBlockRanges) Less(i, j int) bool {
|
||||
// Compare last IP.
|
||||
cmp := bytes.Compare(*c[i].Last, *c[j].Last)
|
||||
if cmp < 0 {
|
||||
return true
|
||||
} else if cmp > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Then compare first IP.
|
||||
cmp = bytes.Compare(*c[i].First, *c[i].First)
|
||||
if cmp < 0 {
|
||||
return true
|
||||
} else if cmp > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Ranges are Equal.
|
||||
return false
|
||||
}
|
||||
|
||||
// builtinNetCIDRMerge merges the provided list of IP addresses and subnets into the smallest possible list of CIDRs.
|
||||
// It merges adjacent subnets where possible, those contained within others and also removes any duplicates.
|
||||
// Original Algorithm: https://github.com/netaddr/netaddr.
|
||||
func builtinNetCIDRMerge(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
networks := []*net.IPNet{}
|
||||
|
||||
switch v := operands[0].Value.(type) {
|
||||
case *ast.Array:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
network, err := generateIPNet(v.Elem(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
networks = append(networks, network)
|
||||
}
|
||||
case ast.Set:
|
||||
err := v.Iter(func(x *ast.Term) error {
|
||||
network, err := generateIPNet(x)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
networks = append(networks, network)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return errors.New("operand must be an array")
|
||||
}
|
||||
|
||||
merged := evalNetCIDRMerge(networks)
|
||||
|
||||
result := ast.NewSet()
|
||||
for _, network := range merged {
|
||||
result.Add(ast.StringTerm(network.String()))
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func evalNetCIDRMerge(networks []*net.IPNet) []*net.IPNet {
|
||||
if len(networks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ranges := make(cidrBlockRanges, 0, len(networks))
|
||||
|
||||
// For each CIDR, create an IP range. Sort them and merge when possible.
|
||||
for _, network := range networks {
|
||||
firstIP, lastIP := cidrMerge.GetAddressRange(*network)
|
||||
ranges = append(ranges, &cidrBlockRange{
|
||||
First: &firstIP,
|
||||
Last: &lastIP,
|
||||
Network: network,
|
||||
})
|
||||
}
|
||||
|
||||
// merge CIDRs.
|
||||
merged := mergeCIDRs(ranges)
|
||||
|
||||
// convert ranges into an equivalent list of net.IPNet.
|
||||
result := []*net.IPNet{}
|
||||
|
||||
for _, r := range merged {
|
||||
// Not merged with any other CIDR.
|
||||
if r.Network != nil {
|
||||
result = append(result, r.Network)
|
||||
} else {
|
||||
// Find new network that represents the merged range.
|
||||
rangeCIDRs := cidrMerge.RangeToCIDRs(*r.First, *r.Last)
|
||||
result = append(result, rangeCIDRs...)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func generateIPNet(term *ast.Term) (*net.IPNet, error) {
|
||||
e, ok := term.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("element must be string")
|
||||
}
|
||||
|
||||
// try to parse element as an IP first, fall back to CIDR
|
||||
ip := net.ParseIP(string(e))
|
||||
if ip == nil {
|
||||
_, network, err := net.ParseCIDR(string(e))
|
||||
return network, err
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
return &net.IPNet{
|
||||
IP: ip,
|
||||
Mask: ip.DefaultMask(),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("IPv6 invalid: needs prefix length")
|
||||
}
|
||||
|
||||
func mergeCIDRs(ranges cidrBlockRanges) cidrBlockRanges {
|
||||
sort.Sort(ranges)
|
||||
|
||||
// Merge adjacent CIDRs if possible.
|
||||
for i := len(ranges) - 1; i > 0; i-- {
|
||||
previousIP := cidrMerge.GetPreviousIP(*ranges[i].First)
|
||||
|
||||
// If the previous IP of the current network overlaps
|
||||
// with the last IP of the previous network in the
|
||||
// list, then merge the two ranges together.
|
||||
if bytes.Compare(previousIP, *ranges[i-1].Last) <= 0 {
|
||||
var firstIP *net.IP
|
||||
if bytes.Compare(*ranges[i-1].First, *ranges[i].First) < 0 {
|
||||
firstIP = ranges[i-1].First
|
||||
} else {
|
||||
firstIP = ranges[i].First
|
||||
}
|
||||
|
||||
lastIPRange := make(net.IP, len(*ranges[i].Last))
|
||||
copy(lastIPRange, *ranges[i].Last)
|
||||
|
||||
firstIPRange := make(net.IP, len(*firstIP))
|
||||
copy(firstIPRange, *firstIP)
|
||||
|
||||
ranges[i-1] = &cidrBlockRange{First: &firstIPRange, Last: &lastIPRange, Network: nil}
|
||||
|
||||
// Delete ranges[i] since merged with the previous.
|
||||
ranges = append(ranges[:i], ranges[i+1:]...)
|
||||
}
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
|
||||
func incIP(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
@@ -153,5 +399,7 @@ func init() {
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDROverlap.Name, builtinNetCIDRContains)
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDRIntersects.Name, builtinNetCIDRIntersects)
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDRContains.Name, builtinNetCIDRContains)
|
||||
RegisterBuiltinFunc(ast.NetCIDRContainsMatches.Name, builtinNetCIDRContainsMatches)
|
||||
RegisterBuiltinFunc(ast.NetCIDRExpand.Name, builtinNetCIDRExpand)
|
||||
RegisterBuiltinFunc(ast.NetCIDRMerge.Name, builtinNetCIDRMerge)
|
||||
}
|
||||
|
||||
345
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go
generated
vendored
345
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go
generated
vendored
@@ -5,6 +5,7 @@
|
||||
package copypropagation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
@@ -30,6 +31,19 @@ type CopyPropagator struct {
|
||||
livevars ast.VarSet // vars that must be preserved in the resulting query
|
||||
sorted []ast.Var // sorted copy of vars to ensure deterministic result
|
||||
ensureNonEmptyBody bool
|
||||
compiler *ast.Compiler
|
||||
localvargen *localVarGenerator
|
||||
}
|
||||
|
||||
type localVarGenerator struct {
|
||||
next int
|
||||
}
|
||||
|
||||
func (l *localVarGenerator) Generate() ast.Var {
|
||||
result := ast.Var(fmt.Sprintf("__localcp%d__", l.next))
|
||||
l.next++
|
||||
return result
|
||||
|
||||
}
|
||||
|
||||
// New returns a new CopyPropagator that optimizes queries while preserving vars
|
||||
@@ -45,7 +59,7 @@ func New(livevars ast.VarSet) *CopyPropagator {
|
||||
return sorted[i].Compare(sorted[j]) < 0
|
||||
})
|
||||
|
||||
return &CopyPropagator{livevars: livevars, sorted: sorted}
|
||||
return &CopyPropagator{livevars: livevars, sorted: sorted, localvargen: &localVarGenerator{}}
|
||||
}
|
||||
|
||||
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
|
||||
@@ -54,8 +68,17 @@ func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
|
||||
return p
|
||||
}
|
||||
|
||||
// WithCompiler configures the compiler to read from while processing the query. This
|
||||
// should be the same compiler used to compile the original policy.
|
||||
func (p *CopyPropagator) WithCompiler(c *ast.Compiler) *CopyPropagator {
|
||||
p.compiler = c
|
||||
return p
|
||||
}
|
||||
|
||||
// Apply executes the copy propagation optimization and returns a new query.
|
||||
func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
|
||||
|
||||
result := ast.NewBody()
|
||||
|
||||
uf, ok := makeDisjointSets(p.livevars, query)
|
||||
if !ok {
|
||||
@@ -63,14 +86,15 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
}
|
||||
|
||||
// Compute set of vars that appear in the head of refs in the query. If a var
|
||||
// is dereferenced, we cannot plug it with a constant value so the constant on
|
||||
// the union-find root must be unset (e.g., [1][0] is not legal.)
|
||||
// is dereferenced, we can plug it with a constant value, but it is not always
|
||||
// optimal to do so.
|
||||
// TODO: Improve the algorithm for when we should plug constants/calls/etc
|
||||
headvars := ast.NewVarSet()
|
||||
ast.WalkRefs(query, func(x ast.Ref) bool {
|
||||
if v, ok := x[0].Value.(ast.Var); ok {
|
||||
if root, ok := uf.Find(v); ok {
|
||||
root.constant = nil
|
||||
headvars.Add(root.key)
|
||||
headvars.Add(root.key.(ast.Var))
|
||||
} else {
|
||||
headvars.Add(v)
|
||||
}
|
||||
@@ -78,21 +102,21 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
return false
|
||||
})
|
||||
|
||||
bindings := map[ast.Var]*binding{}
|
||||
removedEqs := ast.NewValueMap()
|
||||
|
||||
for _, expr := range query {
|
||||
|
||||
pctx := &plugContext{
|
||||
bindings: bindings,
|
||||
uf: uf,
|
||||
negated: expr.Negated,
|
||||
headvars: headvars,
|
||||
removedEqs: removedEqs,
|
||||
uf: uf,
|
||||
negated: expr.Negated,
|
||||
headvars: headvars,
|
||||
}
|
||||
|
||||
if expr, keep := p.plugBindings(pctx, expr); keep {
|
||||
if p.updateBindings(pctx, expr) {
|
||||
result.Append(expr)
|
||||
}
|
||||
expr = p.plugBindings(pctx, expr)
|
||||
|
||||
if p.updateBindings(pctx, expr) {
|
||||
result.Append(expr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +128,7 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
// from being added to the result. For example:
|
||||
//
|
||||
// - Given the following result: <empty>
|
||||
// - Given the following bindings: x/input.x and y/input
|
||||
// - Given the following removed equalities: "x = input.x" and "y = input"
|
||||
// - Given the following liveset: {x}
|
||||
//
|
||||
// If this step were to run AFTER the following step, the output would be:
|
||||
@@ -116,8 +140,8 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
if root, ok := uf.Find(v); ok {
|
||||
if root.constant != nil {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
|
||||
} else if b, ok := bindings[root.key]; ok {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b.v)))
|
||||
} else if b := removedEqs.Get(root.key); b != nil {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b)))
|
||||
} else if root.key != v {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
|
||||
}
|
||||
@@ -125,17 +149,46 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
}
|
||||
|
||||
// Run post-processing step on query to ensure that all killed exprs are
|
||||
// accounted for. If an expr is killed but the binding is never used, the query
|
||||
// must still include the expr. For example, given the query 'input.x = a' and
|
||||
// an empty livevar set, the result must include the ref input.x otherwise the
|
||||
// query could be satisfied without input.x being defined. When exprs are
|
||||
// killed we initialize the binding counter to zero and then increment it each
|
||||
// time the binding is substituted. if the binding was never substituted it
|
||||
// means the binding value must be added back into the query.
|
||||
for _, b := range sortbindings(bindings) {
|
||||
if !b.containedIn(result) {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v)))
|
||||
// accounted for. There are several cases we look for:
|
||||
//
|
||||
// * If an expr is killed but the binding is never used, the query
|
||||
// must still include the expr. For example, given the query 'input.x = a' and
|
||||
// an empty livevar set, the result must include the ref input.x otherwise the
|
||||
// query could be satisfied without input.x being defined.
|
||||
//
|
||||
// * If an expr is killed that provided safety to vars which are not
|
||||
// otherwise being made safe by the current result.
|
||||
//
|
||||
// For any of these cases we re-add the removed equality expression
|
||||
// to the current result.
|
||||
|
||||
// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
|
||||
safe := ast.ReservedVars.Copy()
|
||||
safe.Update(p.livevars)
|
||||
safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
|
||||
unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
|
||||
|
||||
for _, b := range sortbindings(removedEqs) {
|
||||
removedEq := ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v))
|
||||
|
||||
providesSafety := false
|
||||
outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
|
||||
diff := unsafe.Diff(outputVars)
|
||||
if len(diff) < len(unsafe) {
|
||||
unsafe = diff
|
||||
providesSafety = true
|
||||
}
|
||||
|
||||
if providesSafety || !containedIn(b.v, result) {
|
||||
result.Append(removedEq)
|
||||
safe.Update(outputVars)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unsafe) > 0 {
|
||||
// NOTE(tsandall): This should be impossible but if it does occur, throw
|
||||
// away the result rather than generating unsafe output.
|
||||
return query
|
||||
}
|
||||
|
||||
if p.ensureNonEmptyBody && len(result) == 0 {
|
||||
@@ -147,19 +200,7 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
|
||||
// plugBindings applies the binding list and union-find to x. This process
|
||||
// removes as many variables as possible.
|
||||
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.Expr, bool) {
|
||||
|
||||
// Kill single term expressions that are in the binding list. They will be
|
||||
// re-added during post-processing if needed.
|
||||
if term, ok := expr.Terms.(*ast.Term); ok {
|
||||
if v, ok := term.Value.(ast.Var); ok {
|
||||
if root, ok := pctx.uf.Find(v); ok {
|
||||
if _, ok := pctx.bindings[root.key]; ok {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) *ast.Expr {
|
||||
|
||||
xform := bindingPlugTransform{
|
||||
pctx: pctx,
|
||||
@@ -175,7 +216,7 @@ func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.E
|
||||
if expr, ok := x.(*ast.Expr); !ok || err != nil {
|
||||
panic("unreachable")
|
||||
} else {
|
||||
return expr, true
|
||||
return expr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,31 +235,40 @@ func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (result ast.Value) {
|
||||
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) ast.Value {
|
||||
|
||||
result = v
|
||||
var result ast.Value = v
|
||||
|
||||
// Apply union-find to remove redundant variables from input.
|
||||
if root, ok := pctx.uf.Find(v); ok {
|
||||
root, ok := pctx.uf.Find(v)
|
||||
if ok {
|
||||
result = root.Value()
|
||||
}
|
||||
|
||||
// Apply binding list to substitute remaining vars.
|
||||
if v, ok := result.(ast.Var); ok {
|
||||
if b, ok := pctx.bindings[v]; ok {
|
||||
if !pctx.negated || b.v.IsGround() {
|
||||
result = b.v
|
||||
}
|
||||
}
|
||||
v, ok = result.(ast.Var)
|
||||
if !ok {
|
||||
return result
|
||||
}
|
||||
b := pctx.removedEqs.Get(v)
|
||||
if b == nil {
|
||||
return result
|
||||
}
|
||||
if pctx.negated && !b.IsGround() {
|
||||
return result
|
||||
}
|
||||
|
||||
return result
|
||||
if r, ok := b.(ast.Ref); ok && r.OutputVars().Contains(v) {
|
||||
return result
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
|
||||
|
||||
// Apply union-find to remove redundant variables from input.
|
||||
if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok {
|
||||
if root, ok := pctx.uf.Find(v[0].Value); ok {
|
||||
v[0].Value = root.Value()
|
||||
}
|
||||
|
||||
@@ -226,11 +276,16 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
|
||||
|
||||
// Refs require special handling. If the head of the ref was killed, then
|
||||
// the rest of the ref must be concatenated with the new base.
|
||||
//
|
||||
// Invariant: ref heads can only be replaced by refs (not calls).
|
||||
if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok {
|
||||
if !pctx.negated || b.v.IsGround() {
|
||||
result = b.v.(ast.Ref).Concat(v[1:])
|
||||
if b := pctx.removedEqs.Get(v[0].Value); b != nil {
|
||||
if !pctx.negated || b.IsGround() {
|
||||
var base ast.Ref
|
||||
switch x := b.(type) {
|
||||
case ast.Ref:
|
||||
base = x
|
||||
default:
|
||||
base = ast.Ref{ast.NewTerm(x)}
|
||||
}
|
||||
result = base.Concat(v[1:])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,32 +295,54 @@ func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.
|
||||
// updateBindings returns false if the expression can be killed. If the
|
||||
// expression is killed, the binding list is updated to map a var to value.
|
||||
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
|
||||
if pctx.negated || len(expr.With) > 0 {
|
||||
switch {
|
||||
case pctx.negated || len(expr.With) > 0:
|
||||
return true
|
||||
}
|
||||
if expr.IsEquality() {
|
||||
|
||||
case expr.IsEquality():
|
||||
a, b := expr.Operand(0), expr.Operand(1)
|
||||
if a.Equal(b) {
|
||||
if p.livevarRef(a) {
|
||||
pctx.removedEqs.Put(p.localvargen.Generate(), a.Value)
|
||||
}
|
||||
return false
|
||||
}
|
||||
k, v, keep := p.updateBindingsEq(a, b)
|
||||
if !keep {
|
||||
if v != nil {
|
||||
pctx.bindings[k] = newbinding(k, v)
|
||||
pctx.removedEqs.Put(k, v)
|
||||
}
|
||||
return false
|
||||
}
|
||||
} else if expr.IsCall() {
|
||||
|
||||
case expr.IsCall():
|
||||
terms := expr.Terms.([]*ast.Term)
|
||||
output := terms[len(terms)-1]
|
||||
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
|
||||
pctx.bindings[k] = newbinding(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
|
||||
return false
|
||||
if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
|
||||
output := terms[len(terms)-1]
|
||||
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
|
||||
pctx.removedEqs.Put(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return !isNoop(expr)
|
||||
}
|
||||
|
||||
func (p *CopyPropagator) livevarRef(a *ast.Term) bool {
|
||||
ref, ok := a.Value.(ast.Ref)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, v := range p.sorted {
|
||||
if ref[0].Value.Compare(v) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
|
||||
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
|
||||
if !keep {
|
||||
@@ -289,26 +366,21 @@ func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, as
|
||||
}
|
||||
|
||||
type plugContext struct {
|
||||
bindings map[ast.Var]*binding
|
||||
uf *unionFind
|
||||
headvars ast.VarSet
|
||||
negated bool
|
||||
removedEqs *ast.ValueMap
|
||||
uf *unionFind
|
||||
headvars ast.VarSet
|
||||
negated bool
|
||||
}
|
||||
|
||||
type binding struct {
|
||||
k ast.Var
|
||||
v ast.Value
|
||||
k, v ast.Value
|
||||
}
|
||||
|
||||
func newbinding(k ast.Var, v ast.Value) *binding {
|
||||
return &binding{k: k, v: v}
|
||||
}
|
||||
|
||||
func (b *binding) containedIn(query ast.Body) bool {
|
||||
func containedIn(value ast.Value, x interface{}) bool {
|
||||
var stop bool
|
||||
switch v := b.v.(type) {
|
||||
switch v := value.(type) {
|
||||
case ast.Ref:
|
||||
ast.WalkRefs(query, func(other ast.Ref) bool {
|
||||
ast.WalkRefs(x, func(other ast.Ref) bool {
|
||||
if stop || other.HasPrefix(v) {
|
||||
stop = true
|
||||
return stop
|
||||
@@ -316,7 +388,7 @@ func (b *binding) containedIn(query ast.Body) bool {
|
||||
return false
|
||||
})
|
||||
default:
|
||||
ast.WalkTerms(query, func(other *ast.Term) bool {
|
||||
ast.WalkTerms(x, func(other *ast.Term) bool {
|
||||
if stop || other.Value.Compare(v) == 0 {
|
||||
stop = true
|
||||
return stop
|
||||
@@ -327,23 +399,18 @@ func (b *binding) containedIn(query ast.Body) bool {
|
||||
return stop
|
||||
}
|
||||
|
||||
func sortbindings(bindings map[ast.Var]*binding) []*binding {
|
||||
sorted := make([]*binding, 0, len(bindings))
|
||||
for _, b := range bindings {
|
||||
sorted = append(sorted, b)
|
||||
}
|
||||
func sortbindings(bindings *ast.ValueMap) []*binding {
|
||||
sorted := make([]*binding, 0, bindings.Len())
|
||||
bindings.Iter(func(k ast.Value, v ast.Value) bool {
|
||||
sorted = append(sorted, &binding{k, v})
|
||||
return false
|
||||
})
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].k.Compare(sorted[j].k) < 0
|
||||
return sorted[i].k.Compare(sorted[j].k) > 0
|
||||
})
|
||||
return sorted
|
||||
}
|
||||
|
||||
type unionFind struct {
|
||||
roots map[ast.Var]*unionFindRoot
|
||||
parents map[ast.Var]ast.Var
|
||||
rank rankFunc
|
||||
}
|
||||
|
||||
// makeDisjointSets builds the union-find structure for the query. The structure
|
||||
// is built by processing all of the equality exprs in the query. Sets represent
|
||||
// vars that must be equal to each other. In addition to vars, each set can have
|
||||
@@ -352,7 +419,7 @@ type unionFind struct {
|
||||
// false.
|
||||
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
|
||||
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
|
||||
if livevars.Contains(r1.key) {
|
||||
if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) {
|
||||
return r1, r2
|
||||
}
|
||||
return r2, r1
|
||||
@@ -362,17 +429,21 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
|
||||
a, b := expr.Operand(0), expr.Operand(1)
|
||||
varA, ok1 := a.Value.(ast.Var)
|
||||
varB, ok2 := b.Value.(ast.Var)
|
||||
if ok1 && ok2 {
|
||||
|
||||
switch {
|
||||
case ok1 && ok2:
|
||||
if _, ok := uf.Merge(varA, varB); !ok {
|
||||
return nil, false
|
||||
}
|
||||
} else if ok1 && ast.IsConstant(b.Value) {
|
||||
|
||||
case ok1 && ast.IsConstant(b.Value):
|
||||
root := uf.MakeSet(varA)
|
||||
if root.constant != nil && !root.constant.Equal(b) {
|
||||
return nil, false
|
||||
}
|
||||
root.constant = b
|
||||
} else if ok2 && ast.IsConstant(a.Value) {
|
||||
|
||||
case ok2 && ast.IsConstant(a.Value):
|
||||
root := uf.MakeSet(varB)
|
||||
if root.constant != nil && !root.constant.Equal(a) {
|
||||
return nil, false
|
||||
@@ -385,89 +456,9 @@ func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
|
||||
return uf, true
|
||||
}
|
||||
|
||||
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
|
||||
|
||||
func newUnionFind(rank rankFunc) *unionFind {
|
||||
return &unionFind{
|
||||
roots: map[ast.Var]*unionFindRoot{},
|
||||
parents: map[ast.Var]ast.Var{},
|
||||
rank: rank,
|
||||
}
|
||||
}
|
||||
|
||||
func (uf *unionFind) MakeSet(v ast.Var) *unionFindRoot {
|
||||
|
||||
root, ok := uf.Find(v)
|
||||
if ok {
|
||||
return root
|
||||
}
|
||||
|
||||
root = newUnionFindRoot(v)
|
||||
uf.parents[v] = v
|
||||
uf.roots[v] = root
|
||||
return uf.roots[v]
|
||||
}
|
||||
|
||||
func (uf *unionFind) Find(v ast.Var) (*unionFindRoot, bool) {
|
||||
|
||||
parent, ok := uf.parents[v]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if parent == v {
|
||||
return uf.roots[v], true
|
||||
}
|
||||
|
||||
return uf.Find(parent)
|
||||
}
|
||||
|
||||
func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
|
||||
|
||||
r1 := uf.MakeSet(a)
|
||||
r2 := uf.MakeSet(b)
|
||||
|
||||
if r1 != r2 {
|
||||
|
||||
r1, r2 = uf.rank(r1, r2)
|
||||
|
||||
uf.parents[r2.key] = r1.key
|
||||
delete(uf.roots, r2.key)
|
||||
|
||||
// Sets can have at most one constant value associated with them. When
|
||||
// unioning, we must preserve this invariant. If a set has two constants,
|
||||
// there will be no way to prove the query.
|
||||
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
|
||||
return nil, false
|
||||
} else if r1.constant == nil {
|
||||
r1.constant = r2.constant
|
||||
}
|
||||
}
|
||||
|
||||
return r1, true
|
||||
}
|
||||
|
||||
type unionFindRoot struct {
|
||||
key ast.Var
|
||||
constant *ast.Term
|
||||
}
|
||||
|
||||
func newUnionFindRoot(key ast.Var) *unionFindRoot {
|
||||
return &unionFindRoot{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *unionFindRoot) Value() ast.Value {
|
||||
if r.constant != nil {
|
||||
return r.constant.Value
|
||||
}
|
||||
return r.key
|
||||
}
|
||||
|
||||
func isNoop(expr *ast.Expr) bool {
|
||||
|
||||
if !expr.IsCall() {
|
||||
if !expr.IsCall() && !expr.IsEvery() {
|
||||
term := expr.Terms.(*ast.Term)
|
||||
if !ast.IsConstant(term.Value) {
|
||||
return false
|
||||
|
||||
135
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/unionfind.go
generated
vendored
Normal file
135
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/unionfind.go
generated
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package copypropagation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
|
||||
|
||||
type unionFind struct {
|
||||
roots *util.HashMap
|
||||
parents *ast.ValueMap
|
||||
rank rankFunc
|
||||
}
|
||||
|
||||
func newUnionFind(rank rankFunc) *unionFind {
|
||||
return &unionFind{
|
||||
roots: util.NewHashMap(func(a util.T, b util.T) bool {
|
||||
return a.(ast.Value).Compare(b.(ast.Value)) == 0
|
||||
}, func(v util.T) int {
|
||||
return v.(ast.Value).Hash()
|
||||
}),
|
||||
parents: ast.NewValueMap(),
|
||||
rank: rank,
|
||||
}
|
||||
}
|
||||
|
||||
func (uf *unionFind) MakeSet(v ast.Value) *unionFindRoot {
|
||||
|
||||
root, ok := uf.Find(v)
|
||||
if ok {
|
||||
return root
|
||||
}
|
||||
|
||||
root = newUnionFindRoot(v)
|
||||
uf.parents.Put(v, v)
|
||||
uf.roots.Put(v, root)
|
||||
return root
|
||||
}
|
||||
|
||||
func (uf *unionFind) Find(v ast.Value) (*unionFindRoot, bool) {
|
||||
|
||||
parent := uf.parents.Get(v)
|
||||
if parent == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if parent.Compare(v) == 0 {
|
||||
r, ok := uf.roots.Get(v)
|
||||
return r.(*unionFindRoot), ok
|
||||
}
|
||||
|
||||
return uf.Find(parent)
|
||||
}
|
||||
|
||||
func (uf *unionFind) Merge(a, b ast.Value) (*unionFindRoot, bool) {
|
||||
|
||||
r1 := uf.MakeSet(a)
|
||||
r2 := uf.MakeSet(b)
|
||||
|
||||
if r1 != r2 {
|
||||
|
||||
r1, r2 = uf.rank(r1, r2)
|
||||
|
||||
uf.parents.Put(r2.key, r1.key)
|
||||
uf.roots.Delete(r2.key)
|
||||
|
||||
// Sets can have at most one constant value associated with them. When
|
||||
// unioning, we must preserve this invariant. If a set has two constants,
|
||||
// there will be no way to prove the query.
|
||||
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
|
||||
return nil, false
|
||||
} else if r1.constant == nil {
|
||||
r1.constant = r2.constant
|
||||
}
|
||||
}
|
||||
|
||||
return r1, true
|
||||
}
|
||||
|
||||
func (uf *unionFind) String() string {
|
||||
o := struct {
|
||||
Roots map[string]interface{}
|
||||
Parents map[string]ast.Value
|
||||
}{
|
||||
map[string]interface{}{},
|
||||
map[string]ast.Value{},
|
||||
}
|
||||
|
||||
uf.roots.Iter(func(k util.T, v util.T) bool {
|
||||
o.Roots[k.(ast.Value).String()] = struct {
|
||||
Constant *ast.Term
|
||||
Key ast.Value
|
||||
}{
|
||||
v.(*unionFindRoot).constant,
|
||||
v.(*unionFindRoot).key,
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
uf.parents.Iter(func(k ast.Value, v ast.Value) bool {
|
||||
o.Parents[k.String()] = v
|
||||
return true
|
||||
})
|
||||
|
||||
return string(util.MustMarshalJSON(o))
|
||||
}
|
||||
|
||||
type unionFindRoot struct {
|
||||
key ast.Value
|
||||
constant *ast.Term
|
||||
}
|
||||
|
||||
func newUnionFindRoot(key ast.Value) *unionFindRoot {
|
||||
return &unionFindRoot{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *unionFindRoot) Value() ast.Value {
|
||||
if r.constant != nil {
|
||||
return r.constant.Value
|
||||
}
|
||||
return r.key
|
||||
}
|
||||
|
||||
func (r *unionFindRoot) String() string {
|
||||
return fmt.Sprintf("{key: %s, constant: %s", r.key, r.constant)
|
||||
}
|
||||
359
vendor/github.com/open-policy-agent/opa/topdown/crypto.go
generated
vendored
359
vendor/github.com/open-policy-agent/opa/topdown/crypto.go
generated
vendored
@@ -5,45 +5,174 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/internal/jwx/jwk"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// blockTypeCertificate indicates this PEM block contains the signed certificate.
|
||||
// Exported for tests.
|
||||
blockTypeCertificate = "CERTIFICATE"
|
||||
// blockTypeCertificateRequest indicates this PEM block contains a certificate
|
||||
// request. Exported for tests.
|
||||
blockTypeCertificateRequest = "CERTIFICATE REQUEST"
|
||||
// blockTypeRSAPrivateKey indicates this PEM block contains a RSA private key.
|
||||
// Exported for tests.
|
||||
blockTypeRSAPrivateKey = "RSA PRIVATE KEY"
|
||||
// blockTypeRSAPrivateKey indicates this PEM block contains a RSA private key.
|
||||
// Exported for tests.
|
||||
blockTypePrivateKey = "PRIVATE KEY"
|
||||
)
|
||||
|
||||
func builtinCryptoX509ParseCertificates(a ast.Value) (ast.Value, error) {
|
||||
|
||||
str, err := builtinBase64Decode(a)
|
||||
input, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certs, err := x509.ParseCertificates([]byte(str.(ast.String)))
|
||||
certs, err := getX509CertsFromString(string(input))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(certs)
|
||||
return ast.InterfaceToValue(certs)
|
||||
}
|
||||
|
||||
func builtinCryptoX509ParseAndVerifyCertificates(
|
||||
_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
a := args[0].Value
|
||||
input, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invalid := ast.ArrayTerm(
|
||||
ast.BooleanTerm(false),
|
||||
ast.NewTerm(ast.NewArray()),
|
||||
)
|
||||
|
||||
certs, err := getX509CertsFromString(string(input))
|
||||
if err != nil {
|
||||
return iter(invalid)
|
||||
}
|
||||
|
||||
verified, err := verifyX509CertificateChain(certs)
|
||||
if err != nil {
|
||||
return iter(invalid)
|
||||
}
|
||||
|
||||
value, err := ast.InterfaceToValue(verified)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valid := ast.ArrayTerm(
|
||||
ast.BooleanTerm(true),
|
||||
ast.NewTerm(value),
|
||||
)
|
||||
|
||||
return iter(valid)
|
||||
}
|
||||
|
||||
func builtinCryptoX509ParseCertificateRequest(a ast.Value) (ast.Value, error) {
|
||||
|
||||
input, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// data to be passed to x509.ParseCertificateRequest
|
||||
bytes := []byte(input)
|
||||
|
||||
// if the input is not a PEM string, attempt to decode b64
|
||||
if str := string(input); !strings.HasPrefix(str, "-----BEGIN CERTIFICATE REQUEST-----") {
|
||||
bytes, err = base64.StdEncoding.DecodeString(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
p, _ := pem.Decode(bytes)
|
||||
if p != nil && p.Type != blockTypeCertificateRequest {
|
||||
return nil, fmt.Errorf("invalid PEM-encoded certificate signing request")
|
||||
}
|
||||
if p != nil {
|
||||
bytes = p.Bytes
|
||||
}
|
||||
|
||||
csr, err := x509.ParseCertificateRequest(bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(csr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var x interface{}
|
||||
|
||||
if err := util.UnmarshalJSON(bs, &x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.InterfaceToValue(x)
|
||||
}
|
||||
|
||||
func builtinCryptoX509ParseRSAPrivateKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
a := args[0].Value
|
||||
input, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// get the raw private key
|
||||
rawKey, err := getRSAPrivateKeyFromString(string(input))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rsaPrivateKey, err := jwk.New(rawKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jsonKey, err := json.Marshal(rsaPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var x interface{}
|
||||
if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := ast.InterfaceToValue(x)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(value))
|
||||
}
|
||||
|
||||
func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -64,47 +193,209 @@ func builtinCryptoSha256(a ast.Value) (ast.Value, error) {
|
||||
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) })
|
||||
}
|
||||
|
||||
func hmacHelper(args []*ast.Term, iter func(*ast.Term) error, h func() hash.Hash) error {
|
||||
a1 := args[0].Value
|
||||
message, err := builtins.StringOperand(a1, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a2 := args[1].Value
|
||||
key, err := builtins.StringOperand(a2, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mac := hmac.New(h, []byte(key))
|
||||
mac.Write([]byte(message))
|
||||
messageDigest := mac.Sum(nil)
|
||||
|
||||
return iter(ast.StringTerm(fmt.Sprintf("%x", messageDigest)))
|
||||
}
|
||||
|
||||
func builtinCryptoHmacMd5(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
return hmacHelper(args, iter, md5.New)
|
||||
}
|
||||
|
||||
func builtinCryptoHmacSha1(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
return hmacHelper(args, iter, sha1.New)
|
||||
}
|
||||
|
||||
func builtinCryptoHmacSha256(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
return hmacHelper(args, iter, sha256.New)
|
||||
}
|
||||
|
||||
func builtinCryptoHmacSha512(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
return hmacHelper(args, iter, sha512.New)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
|
||||
RegisterBuiltinFunc(ast.CryptoX509ParseAndVerifyCertificates.Name, builtinCryptoX509ParseAndVerifyCertificates)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoMd5.Name, builtinCryptoMd5)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoSha1.Name, builtinCryptoSha1)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoSha256.Name, builtinCryptoSha256)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificateRequest.Name, builtinCryptoX509ParseCertificateRequest)
|
||||
RegisterBuiltinFunc(ast.CryptoX509ParseRSAPrivateKey.Name, builtinCryptoX509ParseRSAPrivateKey)
|
||||
RegisterBuiltinFunc(ast.CryptoHmacMd5.Name, builtinCryptoHmacMd5)
|
||||
RegisterBuiltinFunc(ast.CryptoHmacSha1.Name, builtinCryptoHmacSha1)
|
||||
RegisterBuiltinFunc(ast.CryptoHmacSha256.Name, builtinCryptoHmacSha256)
|
||||
RegisterBuiltinFunc(ast.CryptoHmacSha512.Name, builtinCryptoHmacSha512)
|
||||
}
|
||||
|
||||
// createRootCAs creates a new Cert Pool from scratch or adds to a copy of System Certs
|
||||
func createRootCAs(tlsCACertFile string, tlsCACertEnvVar []byte, tlsUseSystemCerts bool) (*x509.CertPool, error) {
|
||||
|
||||
var newRootCAs *x509.CertPool
|
||||
|
||||
if tlsUseSystemCerts {
|
||||
systemCertPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRootCAs = systemCertPool
|
||||
} else {
|
||||
newRootCAs = x509.NewCertPool()
|
||||
func verifyX509CertificateChain(certs []*x509.Certificate) ([]*x509.Certificate, error) {
|
||||
if len(certs) < 2 {
|
||||
return nil, builtins.NewOperandErr(1, "must supply at least two certificates to be able to verify")
|
||||
}
|
||||
|
||||
if len(tlsCACertFile) > 0 {
|
||||
// Append our cert to the system pool
|
||||
caCert, err := readCertFromFile(tlsCACertFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok := newRootCAs.AppendCertsFromPEM(caCert); !ok {
|
||||
return nil, fmt.Errorf("could not append CA cert from %q", tlsCACertFile)
|
||||
}
|
||||
// first cert is the root
|
||||
roots := x509.NewCertPool()
|
||||
roots.AddCert(certs[0])
|
||||
|
||||
// all other certs except the last are intermediates
|
||||
intermediates := x509.NewCertPool()
|
||||
for i := 1; i < len(certs)-1; i++ {
|
||||
intermediates.AddCert(certs[i])
|
||||
}
|
||||
|
||||
if len(tlsCACertEnvVar) > 0 {
|
||||
// Append our cert to the system pool
|
||||
if ok := newRootCAs.AppendCertsFromPEM(tlsCACertEnvVar); !ok {
|
||||
return nil, fmt.Errorf("error appending cert from env var %q into system certs", tlsCACertEnvVar)
|
||||
}
|
||||
// last cert is the leaf
|
||||
leaf := certs[len(certs)-1]
|
||||
|
||||
// verify the cert chain back to the root
|
||||
verifyOpts := x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
Intermediates: intermediates,
|
||||
}
|
||||
chains, err := leaf.Verify(verifyOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newRootCAs, nil
|
||||
return chains[0], nil
|
||||
}
|
||||
|
||||
func getX509CertsFromString(certs string) ([]*x509.Certificate, error) {
|
||||
// if the input is PEM handle that
|
||||
if strings.HasPrefix(certs, "-----BEGIN") {
|
||||
return getX509CertsFromPem([]byte(certs))
|
||||
}
|
||||
|
||||
// assume input is base64 if not PEM
|
||||
b64, err := base64.StdEncoding.DecodeString(certs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// handle if the decoded base64 contains PEM rather than the expected DER
|
||||
if bytes.HasPrefix(b64, []byte("-----BEGIN")) {
|
||||
return getX509CertsFromPem(b64)
|
||||
}
|
||||
|
||||
// otherwise assume the contents are DER
|
||||
return x509.ParseCertificates(b64)
|
||||
}
|
||||
|
||||
func getX509CertsFromPem(pemBlocks []byte) ([]*x509.Certificate, error) {
|
||||
var decodedCerts []byte
|
||||
for len(pemBlocks) > 0 {
|
||||
p, r := pem.Decode(pemBlocks)
|
||||
if p != nil && p.Type != blockTypeCertificate {
|
||||
return nil, fmt.Errorf("PEM block type is '%s', expected %s", p.Type, blockTypeCertificate)
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
break
|
||||
}
|
||||
|
||||
pemBlocks = r
|
||||
decodedCerts = append(decodedCerts, p.Bytes...)
|
||||
}
|
||||
|
||||
return x509.ParseCertificates(decodedCerts)
|
||||
}
|
||||
|
||||
func getRSAPrivateKeyFromString(key string) (interface{}, error) {
|
||||
// if the input is PEM handle that
|
||||
if strings.HasPrefix(key, "-----BEGIN") {
|
||||
return getRSAPrivateKeyFromPEM([]byte(key))
|
||||
}
|
||||
|
||||
// assume input is base64 if not PEM
|
||||
b64, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getRSAPrivateKeyFromPEM(b64)
|
||||
}
|
||||
|
||||
func getRSAPrivateKeyFromPEM(pemBlocks []byte) (interface{}, error) {
|
||||
|
||||
// decode the pem into the Block struct
|
||||
p, _ := pem.Decode(pemBlocks)
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("failed to parse PEM block containing the key")
|
||||
}
|
||||
|
||||
// if the key is in PKCS1 format
|
||||
if p.Type == blockTypeRSAPrivateKey {
|
||||
return x509.ParsePKCS1PrivateKey(p.Bytes)
|
||||
}
|
||||
|
||||
// if the key is in PKCS8 format
|
||||
if p.Type == blockTypePrivateKey {
|
||||
return x509.ParsePKCS8PrivateKey(p.Bytes)
|
||||
}
|
||||
|
||||
// unsupported key format
|
||||
return nil, fmt.Errorf("PEM block type is '%s', expected %s or %s", p.Type, blockTypeRSAPrivateKey,
|
||||
blockTypePrivateKey)
|
||||
|
||||
}
|
||||
|
||||
// addCACertsFromFile adds CA certificates from filePath into the given pool.
|
||||
// If pool is nil, it creates a new x509.CertPool. pool is returned.
|
||||
func addCACertsFromFile(pool *x509.CertPool, filePath string) (*x509.CertPool, error) {
|
||||
if pool == nil {
|
||||
pool = x509.NewCertPool()
|
||||
}
|
||||
|
||||
caCert, err := readCertFromFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok := pool.AppendCertsFromPEM(caCert); !ok {
|
||||
return nil, fmt.Errorf("could not append CA certificates from %q", filePath)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// addCACertsFromBytes adds CA certificates from pemBytes into the given pool.
|
||||
// If pool is nil, it creates a new x509.CertPool. pool is returned.
|
||||
func addCACertsFromBytes(pool *x509.CertPool, pemBytes []byte) (*x509.CertPool, error) {
|
||||
if pool == nil {
|
||||
pool = x509.NewCertPool()
|
||||
}
|
||||
|
||||
if ok := pool.AppendCertsFromPEM(pemBytes); !ok {
|
||||
return nil, fmt.Errorf("could not append certificates")
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// addCACertsFromBytes adds CA certificates from the environment variable named
|
||||
// by envName into the given pool. If pool is nil, it creates a new x509.CertPool.
|
||||
// pool is returned.
|
||||
func addCACertsFromEnv(pool *x509.CertPool, envName string) (*x509.CertPool, error) {
|
||||
pool, err := addCACertsFromBytes(pool, []byte(os.Getenv(envName)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not add CA certificates from envvar %q: %w", envName, err)
|
||||
}
|
||||
|
||||
return pool, err
|
||||
}
|
||||
|
||||
// ReadCertFromFile reads a cert from file
|
||||
|
||||
90
vendor/github.com/open-policy-agent/opa/topdown/encoding.go
generated
vendored
90
vendor/github.com/open-policy-agent/opa/topdown/encoding.go
generated
vendored
@@ -7,6 +7,7 @@ package topdown
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -50,6 +51,16 @@ func builtinJSONUnmarshal(a ast.Value) (ast.Value, error) {
|
||||
return ast.InterfaceToValue(x)
|
||||
}
|
||||
|
||||
func builtinJSONIsValid(a ast.Value) (ast.Value, error) {
|
||||
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
|
||||
return ast.Boolean(json.Valid([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinBase64Encode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -69,6 +80,16 @@ func builtinBase64Decode(a ast.Value) (ast.Value, error) {
|
||||
return ast.String(result), err
|
||||
}
|
||||
|
||||
func builtinBase64IsValid(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
|
||||
_, err = base64.StdEncoding.DecodeString(string(str))
|
||||
return ast.Boolean(err == nil), nil
|
||||
}
|
||||
|
||||
func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -78,6 +99,14 @@ func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
|
||||
return ast.String(base64.URLEncoding.EncodeToString([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinBase64UrlEncodeNoPad(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(base64.RawURLEncoding.EncodeToString([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinBase64UrlDecode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -158,6 +187,29 @@ func builtinURLQueryEncodeObject(a ast.Value) (ast.Value, error) {
|
||||
return ast.String(query.Encode()), nil
|
||||
}
|
||||
|
||||
func builtinURLQueryDecodeObject(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
query, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queryParams, err := url.ParseQuery(string(query))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queryObject := ast.NewObject()
|
||||
for k, v := range queryParams {
|
||||
paramsArray := make([]*ast.Term, len(v))
|
||||
for i, param := range v {
|
||||
paramsArray[i] = ast.StringTerm(param)
|
||||
}
|
||||
queryObject.Insert(ast.StringTerm(k), ast.ArrayTerm(paramsArray...))
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(queryObject))
|
||||
}
|
||||
|
||||
func builtinYAMLMarshal(a ast.Value) (ast.Value, error) {
|
||||
|
||||
asJSON, err := ast.JSON(a)
|
||||
@@ -202,16 +254,54 @@ func builtinYAMLUnmarshal(a ast.Value) (ast.Value, error) {
|
||||
return ast.InterfaceToValue(val)
|
||||
}
|
||||
|
||||
func builtinYAMLIsValid(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
|
||||
var x interface{}
|
||||
err = ghodss.Unmarshal([]byte(str), &x)
|
||||
return ast.Boolean(err == nil), nil
|
||||
}
|
||||
|
||||
func builtinHexEncode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(hex.EncodeToString([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinHexDecode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
val, err := hex.DecodeString(string(str))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(val), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.JSONMarshal.Name, builtinJSONMarshal)
|
||||
RegisterFunctionalBuiltin1(ast.JSONUnmarshal.Name, builtinJSONUnmarshal)
|
||||
RegisterFunctionalBuiltin1(ast.JSONIsValid.Name, builtinJSONIsValid)
|
||||
RegisterFunctionalBuiltin1(ast.Base64Encode.Name, builtinBase64Encode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64Decode.Name, builtinBase64Decode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64IsValid.Name, builtinBase64IsValid)
|
||||
RegisterFunctionalBuiltin1(ast.Base64UrlEncode.Name, builtinBase64UrlEncode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64UrlEncodeNoPad.Name, builtinBase64UrlEncodeNoPad)
|
||||
RegisterFunctionalBuiltin1(ast.Base64UrlDecode.Name, builtinBase64UrlDecode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryDecode.Name, builtinURLQueryDecode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryEncode.Name, builtinURLQueryEncode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryEncodeObject.Name, builtinURLQueryEncodeObject)
|
||||
RegisterBuiltinFunc(ast.URLQueryDecodeObject.Name, builtinURLQueryDecodeObject)
|
||||
RegisterFunctionalBuiltin1(ast.YAMLMarshal.Name, builtinYAMLMarshal)
|
||||
RegisterFunctionalBuiltin1(ast.YAMLUnmarshal.Name, builtinYAMLUnmarshal)
|
||||
RegisterFunctionalBuiltin1(ast.YAMLIsValid.Name, builtinYAMLIsValid)
|
||||
RegisterFunctionalBuiltin1(ast.HexEncode.Name, builtinHexEncode)
|
||||
RegisterFunctionalBuiltin1(ast.HexDecode.Name, builtinHexDecode)
|
||||
}
|
||||
|
||||
46
vendor/github.com/open-policy-agent/opa/topdown/errors.go
generated
vendored
46
vendor/github.com/open-policy-agent/opa/topdown/errors.go
generated
vendored
@@ -5,11 +5,24 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// Halt is a special error type that built-in function implementations return to indicate
|
||||
// that policy evaluation should stop immediately.
|
||||
type Halt struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (h Halt) Error() string {
|
||||
return h.Err.Error()
|
||||
}
|
||||
|
||||
func (h Halt) Unwrap() error { return h.Err }
|
||||
|
||||
// Error is the error type returned by the Eval and Query functions when
|
||||
// an evaluation error occurs.
|
||||
type Error struct {
|
||||
@@ -47,20 +60,27 @@ const (
|
||||
|
||||
// IsError returns true if the err is an Error.
|
||||
func IsError(err error) bool {
|
||||
_, ok := err.(*Error)
|
||||
return ok
|
||||
var e *Error
|
||||
return errors.As(err, &e)
|
||||
}
|
||||
|
||||
// IsCancel returns true if err was caused by cancellation.
|
||||
func IsCancel(err error) bool {
|
||||
if e, ok := err.(*Error); ok {
|
||||
return e.Code == CancelErr
|
||||
return errors.Is(err, &Error{Code: CancelErr})
|
||||
}
|
||||
|
||||
// Is allows matching topdown errors using errors.Is (see IsCancel).
|
||||
func (e *Error) Is(target error) bool {
|
||||
var t *Error
|
||||
if errors.As(target, &t) {
|
||||
return (t.Code == "" || e.Code == t.Code) &&
|
||||
(t.Message == "" || e.Message == t.Message) &&
|
||||
(t.Location == nil || t.Location.Compare(e.Location) == 0)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
|
||||
msg := fmt.Sprintf("%v: %v", e.Code, e.Message)
|
||||
|
||||
if e.Location != nil {
|
||||
@@ -94,14 +114,6 @@ func objectDocKeyConflictErr(loc *ast.Location) error {
|
||||
}
|
||||
}
|
||||
|
||||
func documentConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: ConflictErr,
|
||||
Location: loc,
|
||||
Message: "base and virtual document keys must be disjoint",
|
||||
}
|
||||
}
|
||||
|
||||
func unsupportedBuiltinErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: InternalErr,
|
||||
@@ -117,3 +129,11 @@ func mergeConflictErr(loc *ast.Location) error {
|
||||
Message: "real and replacement data could not be merged",
|
||||
}
|
||||
}
|
||||
|
||||
func internalErr(loc *ast.Location, msg string) error {
|
||||
return &Error{
|
||||
Code: InternalErr,
|
||||
Location: loc,
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
|
||||
1939
vendor/github.com/open-policy-agent/opa/topdown/eval.go
generated
vendored
1939
vendor/github.com/open-policy-agent/opa/topdown/eval.go
generated
vendored
File diff suppressed because it is too large
Load Diff
35
vendor/github.com/open-policy-agent/opa/topdown/glob.go
generated
vendored
35
vendor/github.com/open-policy-agent/opa/topdown/glob.go
generated
vendored
@@ -1,7 +1,7 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
@@ -18,22 +18,34 @@ func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var delimiters []rune
|
||||
switch b.(type) {
|
||||
case ast.Null:
|
||||
delimiters = []rune{}
|
||||
case *ast.Array:
|
||||
delimiters, err = builtins.RuneSliceOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delimiters, err := builtins.RuneSliceOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(delimiters) == 0 {
|
||||
delimiters = []rune{'.'}
|
||||
}
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "array", "null")
|
||||
}
|
||||
|
||||
if len(delimiters) == 0 {
|
||||
delimiters = []rune{'.'}
|
||||
}
|
||||
|
||||
match, err := builtins.StringOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("%s-%v", pattern, delimiters)
|
||||
builder := strings.Builder{}
|
||||
builder.WriteString(string(pattern))
|
||||
builder.WriteRune('-')
|
||||
for _, v := range delimiters {
|
||||
builder.WriteRune(v)
|
||||
}
|
||||
id := builder.String()
|
||||
|
||||
globCacheLock.Lock()
|
||||
defer globCacheLock.Unlock()
|
||||
@@ -46,7 +58,8 @@ func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
globCache[id] = p
|
||||
}
|
||||
|
||||
return ast.Boolean(p.Match(string(match))), nil
|
||||
m := p.Match(string(match))
|
||||
return ast.Boolean(m), nil
|
||||
}
|
||||
|
||||
func builtinGlobQuoteMeta(a ast.Value) (ast.Value, error) {
|
||||
|
||||
462
vendor/github.com/open-policy-agent/opa/topdown/graphql.go
generated
vendored
Normal file
462
vendor/github.com/open-policy-agent/opa/topdown/graphql.go
generated
vendored
Normal file
@@ -0,0 +1,462 @@
|
||||
// Copyright 2022 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
gqlast "github.com/open-policy-agent/opa/internal/gqlparser/ast"
|
||||
gqlparser "github.com/open-policy-agent/opa/internal/gqlparser/parser"
|
||||
gqlvalidator "github.com/open-policy-agent/opa/internal/gqlparser/validator"
|
||||
|
||||
// Side-effecting import. Triggers GraphQL library's validation rule init() functions.
|
||||
_ "github.com/open-policy-agent/opa/internal/gqlparser/validator/rules"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
// Parses a GraphQL schema, and returns the GraphQL AST for the schema.
|
||||
func parseSchema(schema string) (*gqlast.SchemaDocument, error) {
|
||||
// NOTE(philipc): We don't include the "built-in schema defs" from the
|
||||
// underlying graphql parsing library here, because those definitions
|
||||
// generate enormous AST blobs. In the future, if there is demand for
|
||||
// a "full-spec" version of schema ASTs, we may need to provide a
|
||||
// version of this function that includes the built-in schema
|
||||
// definitions.
|
||||
schemaAST, err := gqlparser.ParseSchema(&gqlast.Source{Input: schema})
|
||||
if err != nil {
|
||||
errorParts := strings.SplitN(err.Error(), ":", 4)
|
||||
msg := strings.TrimLeft(errorParts[3], " ")
|
||||
return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2])
|
||||
}
|
||||
return schemaAST, nil
|
||||
}
|
||||
|
||||
// Parses a GraphQL query, and returns the GraphQL AST for the query.
|
||||
func parseQuery(query string) (*gqlast.QueryDocument, error) {
|
||||
queryAST, err := gqlparser.ParseQuery(&gqlast.Source{Input: query})
|
||||
if err != nil {
|
||||
errorParts := strings.SplitN(err.Error(), ":", 4)
|
||||
msg := strings.TrimLeft(errorParts[3], " ")
|
||||
return nil, fmt.Errorf("%s in GraphQL string at location %s:%s", msg, errorParts[1], errorParts[2])
|
||||
}
|
||||
return queryAST, nil
|
||||
}
|
||||
|
||||
// Validates a GraphQL query against a schema, and returns an error.
|
||||
// In this case, we get a wrappered error list type, and pluck out
|
||||
// just the first error message in the list.
|
||||
func validateQuery(schema *gqlast.Schema, query *gqlast.QueryDocument) error {
|
||||
// Validate the query against the schema, erroring if there's an issue.
|
||||
err := gqlvalidator.Validate(schema, query)
|
||||
if err != nil {
|
||||
// We use strings.TrimSuffix to remove the '.' characters that the library
|
||||
// authors include on most of their validation errors. This should be safe,
|
||||
// since variable names in their error messages are usually quoted, and
|
||||
// this affects only the last character(s) in the string.
|
||||
// NOTE(philipc): We know the error location will be in the query string,
|
||||
// because schema validation always happens before this function is called.
|
||||
errorParts := strings.SplitN(err.Error(), ":", 4)
|
||||
msg := strings.TrimSuffix(strings.TrimLeft(errorParts[3], " "), ".\n")
|
||||
return fmt.Errorf("%s in GraphQL query string at location %s:%s", msg, errorParts[1], errorParts[2])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBuiltinSchema() *gqlast.SchemaDocument {
|
||||
schema, err := gqlparser.ParseSchema(gqlvalidator.Prelude)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Error in gqlparser Prelude (should be impossible): %w", err))
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
// NOTE(philipc): This function expects *validated* schema documents, and will break
|
||||
// if it is fed arbitrary structures.
|
||||
func mergeSchemaDocuments(docA *gqlast.SchemaDocument, docB *gqlast.SchemaDocument) *gqlast.SchemaDocument {
|
||||
ast := &gqlast.SchemaDocument{}
|
||||
ast.Merge(docA)
|
||||
ast.Merge(docB)
|
||||
return ast
|
||||
}
|
||||
|
||||
// Converts a SchemaDocument into a gqlast.Schema object that can be used for validation.
|
||||
// It merges in the builtin schema typedefs exactly as gqltop.LoadSchema did internally.
|
||||
func convertSchema(schemaDoc *gqlast.SchemaDocument) (*gqlast.Schema, error) {
|
||||
// Merge builtin schema + schema we were provided.
|
||||
builtinsSchemaDoc := getBuiltinSchema()
|
||||
mergedSchemaDoc := mergeSchemaDocuments(builtinsSchemaDoc, schemaDoc)
|
||||
schema, err := gqlvalidator.ValidateSchemaDocument(mergedSchemaDoc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error in gqlparser SchemaDocument to Schema conversion: %w", err)
|
||||
}
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
// Converts an ast.Object into a gqlast.QueryDocument object.
|
||||
func objectToQueryDocument(value ast.Object) (*gqlast.QueryDocument, error) {
|
||||
// Convert ast.Term to interface{} for JSON encoding below.
|
||||
asJSON, err := ast.JSON(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Marshal to JSON.
|
||||
bs, err := json.Marshal(asJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unmarshal from JSON -> gqlast.QueryDocument.
|
||||
var result gqlast.QueryDocument
|
||||
err = json.Unmarshal(bs, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Converts an ast.Object into a gqlast.SchemaDocument object.
|
||||
func objectToSchemaDocument(value ast.Object) (*gqlast.SchemaDocument, error) {
|
||||
// Convert ast.Term to interface{} for JSON encoding below.
|
||||
asJSON, err := ast.JSON(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Marshal to JSON.
|
||||
bs, err := json.Marshal(asJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unmarshal from JSON -> gqlast.SchemaDocument.
|
||||
var result gqlast.SchemaDocument
|
||||
err = json.Unmarshal(bs, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Recursively traverses an AST that has been run through InterfaceToValue,
|
||||
// and prunes away the fields with null or empty values, and all `Position`
|
||||
// structs.
|
||||
// NOTE(philipc): We currently prune away null values to reduce the level
|
||||
// of clutter in the returned AST objects. In the future, if there is demand
|
||||
// for ASTs that have a more regular/fixed structure, we may need to provide
|
||||
// a "raw" version of the AST, where we still prune away the `Position`
|
||||
// structs, but leave in the null fields.
|
||||
func pruneIrrelevantGraphQLASTNodes(value ast.Value) ast.Value {
|
||||
// We iterate over the Value we've been provided, and recurse down
|
||||
// in the case of complex types, such as Arrays/Objects.
|
||||
// We are guaranteed to only have to deal with standard JSON types,
|
||||
// so this is much less ugly than what we'd need for supporting every
|
||||
// extant ast type!
|
||||
switch x := value.(type) {
|
||||
case *ast.Array:
|
||||
result := ast.NewArray()
|
||||
// Iterate over the array's elements, and do the following:
|
||||
// - Drop any Nulls
|
||||
// - Drop any any empty object/array value (after running the pruner)
|
||||
for i := 0; i < x.Len(); i++ {
|
||||
vTerm := x.Elem(i)
|
||||
switch v := vTerm.Value.(type) {
|
||||
case ast.Null:
|
||||
continue
|
||||
case *ast.Array:
|
||||
// Safe, because we knew the type before going to prune it.
|
||||
va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array)
|
||||
if va.Len() > 0 {
|
||||
result = result.Append(ast.NewTerm(va))
|
||||
}
|
||||
case ast.Object:
|
||||
// Safe, because we knew the type before going to prune it.
|
||||
vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object)
|
||||
if len(vo.Keys()) > 0 {
|
||||
result = result.Append(ast.NewTerm(vo))
|
||||
}
|
||||
default:
|
||||
result = result.Append(vTerm)
|
||||
}
|
||||
}
|
||||
return result
|
||||
case ast.Object:
|
||||
result := ast.NewObject()
|
||||
// Iterate over our object's keys, and do the following:
|
||||
// - Drop "Position".
|
||||
// - Drop any key with a Null value.
|
||||
// - Drop any key with an empty object/array value (after running the pruner)
|
||||
keys := x.Keys()
|
||||
for _, k := range keys {
|
||||
// We drop the "Position" objects because we don't need the
|
||||
// source-backref/location info they provide for policy rules.
|
||||
// Note that keys are ast.Strings.
|
||||
if ast.String("Position").Equal(k.Value) {
|
||||
continue
|
||||
}
|
||||
vTerm := x.Get(k)
|
||||
switch v := vTerm.Value.(type) {
|
||||
case ast.Null:
|
||||
continue
|
||||
case *ast.Array:
|
||||
// Safe, because we knew the type before going to prune it.
|
||||
va := pruneIrrelevantGraphQLASTNodes(v).(*ast.Array)
|
||||
if va.Len() > 0 {
|
||||
result.Insert(k, ast.NewTerm(va))
|
||||
}
|
||||
case ast.Object:
|
||||
// Safe, because we knew the type before going to prune it.
|
||||
vo := pruneIrrelevantGraphQLASTNodes(v).(ast.Object)
|
||||
if len(vo.Keys()) > 0 {
|
||||
result.Insert(k, ast.NewTerm(vo))
|
||||
}
|
||||
default:
|
||||
result.Insert(k, vTerm)
|
||||
}
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return x
|
||||
}
|
||||
}
|
||||
|
||||
// Reports errors from parsing/validation.
|
||||
func builtinGraphQLParse(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
var queryDoc *gqlast.QueryDocument
|
||||
var schemaDoc *gqlast.SchemaDocument
|
||||
var err error
|
||||
|
||||
// Parse/translate query if it's a string/object.
|
||||
switch x := operands[0].Value.(type) {
|
||||
case ast.String:
|
||||
queryDoc, err = parseQuery(string(x))
|
||||
case ast.Object:
|
||||
queryDoc, err = objectToQueryDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return builtins.NewOperandTypeErr(0, x, "string", "object")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse/translate schema if it's a string/object.
|
||||
switch x := operands[1].Value.(type) {
|
||||
case ast.String:
|
||||
schemaDoc, err = parseSchema(string(x))
|
||||
case ast.Object:
|
||||
schemaDoc, err = objectToSchemaDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return builtins.NewOperandTypeErr(1, x, "string", "object")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Transform the ASTs into Objects.
|
||||
queryASTValue, err := ast.InterfaceToValue(queryDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
schemaASTValue, err := ast.InterfaceToValue(schemaDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate the query against the schema, erroring if there's an issue.
|
||||
schema, err := convertSchema(schemaDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateQuery(schema, queryDoc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Recursively remove irrelevant AST structures.
|
||||
queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object))
|
||||
querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object))
|
||||
|
||||
// Construct return value.
|
||||
verified := ast.ArrayTerm(
|
||||
ast.NewTerm(queryResult),
|
||||
ast.NewTerm(querySchema),
|
||||
)
|
||||
|
||||
return iter(verified)
|
||||
}
|
||||
|
||||
// Returns default value when errors occur.
|
||||
func builtinGraphQLParseAndVerify(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
var queryDoc *gqlast.QueryDocument
|
||||
var schemaDoc *gqlast.SchemaDocument
|
||||
var err error
|
||||
|
||||
unverified := ast.ArrayTerm(
|
||||
ast.BooleanTerm(false),
|
||||
ast.NewTerm(ast.NewObject()),
|
||||
ast.NewTerm(ast.NewObject()),
|
||||
)
|
||||
|
||||
// Parse/translate query if it's a string/object.
|
||||
switch x := operands[0].Value.(type) {
|
||||
case ast.String:
|
||||
queryDoc, err = parseQuery(string(x))
|
||||
case ast.Object:
|
||||
queryDoc, err = objectToQueryDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return iter(unverified)
|
||||
}
|
||||
if err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
|
||||
// Parse/translate schema if it's a string/object.
|
||||
switch x := operands[1].Value.(type) {
|
||||
case ast.String:
|
||||
schemaDoc, err = parseSchema(string(x))
|
||||
case ast.Object:
|
||||
schemaDoc, err = objectToSchemaDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return iter(unverified)
|
||||
}
|
||||
if err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
|
||||
// Transform the ASTs into Objects.
|
||||
queryASTValue, err := ast.InterfaceToValue(queryDoc)
|
||||
if err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
schemaASTValue, err := ast.InterfaceToValue(schemaDoc)
|
||||
if err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
|
||||
// Validate the query against the schema, erroring if there's an issue.
|
||||
schema, err := convertSchema(schemaDoc)
|
||||
if err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
if err := validateQuery(schema, queryDoc); err != nil {
|
||||
return iter(unverified)
|
||||
}
|
||||
|
||||
// Recursively remove irrelevant AST structures.
|
||||
queryResult := pruneIrrelevantGraphQLASTNodes(queryASTValue.(ast.Object))
|
||||
querySchema := pruneIrrelevantGraphQLASTNodes(schemaASTValue.(ast.Object))
|
||||
|
||||
// Construct return value.
|
||||
verified := ast.ArrayTerm(
|
||||
ast.BooleanTerm(true),
|
||||
ast.NewTerm(queryResult),
|
||||
ast.NewTerm(querySchema),
|
||||
)
|
||||
|
||||
return iter(verified)
|
||||
}
|
||||
|
||||
func builtinGraphQLParseQuery(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
raw, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the highly-nested AST struct, along with any errors generated.
|
||||
query, err := parseQuery(string(raw))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Transform the AST into an Object.
|
||||
value, err := ast.InterfaceToValue(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Recursively remove irrelevant AST structures.
|
||||
result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object))
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func builtinGraphQLParseSchema(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
raw, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the highly-nested AST struct, along with any errors generated.
|
||||
schema, err := parseSchema(string(raw))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Transform the AST into an Object.
|
||||
value, err := ast.InterfaceToValue(schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Recursively remove irrelevant AST structures.
|
||||
result := pruneIrrelevantGraphQLASTNodes(value.(ast.Object))
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func builtinGraphQLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
var queryDoc *gqlast.QueryDocument
|
||||
var schemaDoc *gqlast.SchemaDocument
|
||||
var err error
|
||||
|
||||
switch x := operands[0].Value.(type) {
|
||||
case ast.String:
|
||||
queryDoc, err = parseQuery(string(x))
|
||||
case ast.Object:
|
||||
queryDoc, err = objectToQueryDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
switch x := operands[1].Value.(type) {
|
||||
case ast.String:
|
||||
schemaDoc, err = parseSchema(string(x))
|
||||
case ast.Object:
|
||||
schemaDoc, err = objectToSchemaDocument(x)
|
||||
default:
|
||||
// Error if wrong type.
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
// Validate the query against the schema, erroring if there's an issue.
|
||||
schema, err := convertSchema(schemaDoc)
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
if err := validateQuery(schema, queryDoc); err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
// If we got this far, the GraphQL query passed validation.
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.GraphQLParse.Name, builtinGraphQLParse)
|
||||
RegisterBuiltinFunc(ast.GraphQLParseAndVerify.Name, builtinGraphQLParseAndVerify)
|
||||
RegisterBuiltinFunc(ast.GraphQLParseQuery.Name, builtinGraphQLParseQuery)
|
||||
RegisterBuiltinFunc(ast.GraphQLParseSchema.Name, builtinGraphQLParseSchema)
|
||||
RegisterBuiltinFunc(ast.GraphQLIsValid.Name, builtinGraphQLIsValid)
|
||||
}
|
||||
1231
vendor/github.com/open-policy-agent/opa/topdown/http.go
generated
vendored
1231
vendor/github.com/open-policy-agent/opa/topdown/http.go
generated
vendored
File diff suppressed because it is too large
Load Diff
8
vendor/github.com/open-policy-agent/opa/topdown/http_fixup.go
generated
vendored
Normal file
8
vendor/github.com/open-policy-agent/opa/topdown/http_fixup.go
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !go1.18 || !darwin
|
||||
// +build !go1.18 !darwin
|
||||
|
||||
package topdown
|
||||
|
||||
func fixupDarwinGo118(x string, _ string) string {
|
||||
return x
|
||||
}
|
||||
13
vendor/github.com/open-policy-agent/opa/topdown/http_fixup_darwin.go
generated
vendored
Normal file
13
vendor/github.com/open-policy-agent/opa/topdown/http_fixup_darwin.go
generated
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build go1.18
|
||||
// +build go1.18
|
||||
|
||||
package topdown
|
||||
|
||||
func fixupDarwinGo118(x, y string) string {
|
||||
switch x {
|
||||
case "x509: certificate signed by unknown authority":
|
||||
return y
|
||||
default:
|
||||
return x
|
||||
}
|
||||
}
|
||||
88
vendor/github.com/open-policy-agent/opa/topdown/input.go
generated
vendored
88
vendor/github.com/open-policy-agent/opa/topdown/input.go
generated
vendored
@@ -10,18 +10,14 @@ import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
var errConflictingDoc = fmt.Errorf("conflicting documents")
|
||||
var errBadPath = fmt.Errorf("bad document path")
|
||||
|
||||
func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, error) {
|
||||
|
||||
var result *ast.Term
|
||||
var init bool
|
||||
|
||||
if exist != nil {
|
||||
result = exist.Copy()
|
||||
}
|
||||
|
||||
for _, pair := range pairs {
|
||||
for i, pair := range pairs {
|
||||
|
||||
if err := ast.IsValidImportPath(pair[0].Value); err != nil {
|
||||
return nil, errBadPath
|
||||
@@ -29,33 +25,55 @@ func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, erro
|
||||
|
||||
target := pair[0].Value.(ast.Ref)
|
||||
|
||||
if len(target) == 1 {
|
||||
result = pair[1]
|
||||
} else if result == nil {
|
||||
result = ast.NewTerm(makeTree(target[1:], pair[1]))
|
||||
} else {
|
||||
node := result
|
||||
done := false
|
||||
for i := 1; i < len(target)-1 && !done; i++ {
|
||||
if child := node.Get(target[i]); child == nil {
|
||||
obj, ok := node.Value.(ast.Object)
|
||||
if !ok {
|
||||
return nil, errConflictingDoc
|
||||
}
|
||||
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
|
||||
done = true
|
||||
} else {
|
||||
node = child
|
||||
}
|
||||
}
|
||||
if !done {
|
||||
obj, ok := node.Value.(ast.Object)
|
||||
if !ok {
|
||||
return nil, errConflictingDoc
|
||||
}
|
||||
obj.Insert(target[len(target)-1], pair[1])
|
||||
// Copy the value if subsequent pairs in the slice would modify it.
|
||||
for j := i + 1; j < len(pairs); j++ {
|
||||
other := pairs[j][0].Value.(ast.Ref)
|
||||
if len(other) > len(target) && other.HasPrefix(target) {
|
||||
pair[1] = pair[1].Copy()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(target) == 1 {
|
||||
result = pair[1]
|
||||
init = true
|
||||
} else {
|
||||
if !init {
|
||||
result = exist.Copy()
|
||||
init = true
|
||||
}
|
||||
if result == nil {
|
||||
result = ast.NewTerm(makeTree(target[1:], pair[1]))
|
||||
} else {
|
||||
node := result
|
||||
done := false
|
||||
for i := 1; i < len(target)-1 && !done; i++ {
|
||||
obj, ok := node.Value.(ast.Object)
|
||||
if !ok {
|
||||
result = ast.NewTerm(makeTree(target[i:], pair[1]))
|
||||
done = true
|
||||
continue
|
||||
}
|
||||
if child := obj.Get(target[i]); !isObject(child) {
|
||||
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
|
||||
done = true
|
||||
} else { // child is object
|
||||
node = child
|
||||
}
|
||||
}
|
||||
if !done {
|
||||
if obj, ok := node.Value.(ast.Object); ok {
|
||||
obj.Insert(target[len(target)-1], pair[1])
|
||||
} else {
|
||||
result = ast.NewTerm(makeTree(target[len(target)-1:], pair[1]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !init {
|
||||
result = exist
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -72,3 +90,11 @@ func makeTree(k ast.Ref, v *ast.Term) ast.Object {
|
||||
obj = ast.NewObject(ast.Item(k[0], v))
|
||||
return obj
|
||||
}
|
||||
|
||||
func isObject(x *ast.Term) bool {
|
||||
if x == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := x.Value.(ast.Object)
|
||||
return ok
|
||||
}
|
||||
|
||||
28
vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go
generated
vendored
28
vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go
generated
vendored
@@ -7,18 +7,22 @@ package topdown
|
||||
import "github.com/open-policy-agent/opa/metrics"
|
||||
|
||||
const (
|
||||
evalOpPlug = "eval_op_plug"
|
||||
evalOpResolve = "eval_op_resolve"
|
||||
evalOpRuleIndex = "eval_op_rule_index"
|
||||
evalOpBuiltinCall = "eval_op_builtin_call"
|
||||
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
|
||||
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
|
||||
evalOpBaseCacheHit = "eval_op_base_cache_hit"
|
||||
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
|
||||
partialOpSaveUnify = "partial_op_save_unify"
|
||||
partialOpSaveSetContains = "partial_op_save_set_contains"
|
||||
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
|
||||
partialOpCopyPropagation = "partial_op_copy_propagation"
|
||||
evalOpPlug = "eval_op_plug"
|
||||
evalOpResolve = "eval_op_resolve"
|
||||
evalOpRuleIndex = "eval_op_rule_index"
|
||||
evalOpBuiltinCall = "eval_op_builtin_call"
|
||||
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
|
||||
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
|
||||
evalOpBaseCacheHit = "eval_op_base_cache_hit"
|
||||
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
|
||||
evalOpComprehensionCacheSkip = "eval_op_comprehension_cache_skip"
|
||||
evalOpComprehensionCacheBuild = "eval_op_comprehension_cache_build"
|
||||
evalOpComprehensionCacheHit = "eval_op_comprehension_cache_hit"
|
||||
evalOpComprehensionCacheMiss = "eval_op_comprehension_cache_miss"
|
||||
partialOpSaveUnify = "partial_op_save_unify"
|
||||
partialOpSaveSetContains = "partial_op_save_set_contains"
|
||||
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
|
||||
partialOpCopyPropagation = "partial_op_copy_propagation"
|
||||
)
|
||||
|
||||
// Instrumentation implements helper functions to instrument query evaluation
|
||||
|
||||
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/LICENSE
generated
vendored
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/LICENSE
generated
vendored
@@ -1,21 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 lestrrat
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
113
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/buffer/buffer.go
generated
vendored
113
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/buffer/buffer.go
generated
vendored
@@ -1,113 +0,0 @@
|
||||
// Package buffer provides a very thin wrapper around []byte buffer called
|
||||
// `Buffer`, to provide functionalities that are often used within the jwx
|
||||
// related packages
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Buffer wraps `[]byte` and provides functions that are often used in
|
||||
// the jwx related packages. One notable difference is that while
|
||||
// encoding/json marshalls `[]byte` using base64.StdEncoding, this
|
||||
// module uses base64.RawURLEncoding as mandated by the spec
|
||||
type Buffer []byte
|
||||
|
||||
// FromUint creates a `Buffer` from an unsigned int
|
||||
func FromUint(v uint64) Buffer {
|
||||
data := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(data, v)
|
||||
|
||||
i := 0
|
||||
for ; i < len(data); i++ {
|
||||
if data[i] != 0x0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return Buffer(data[i:])
|
||||
}
|
||||
|
||||
// FromBase64 constructs a new Buffer from a base64 encoded data
|
||||
func FromBase64(v []byte) (Buffer, error) {
|
||||
b := Buffer{}
|
||||
if err := b.Base64Decode(v); err != nil {
|
||||
return Buffer(nil), errors.Wrap(err, "failed to decode from base64")
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// FromNData constructs a new Buffer from a "n:data" format
|
||||
// (I made that name up)
|
||||
func FromNData(v []byte) (Buffer, error) {
|
||||
size := binary.BigEndian.Uint32(v)
|
||||
buf := make([]byte, int(size))
|
||||
copy(buf, v[4:4+size])
|
||||
return Buffer(buf), nil
|
||||
}
|
||||
|
||||
// Bytes returns the raw bytes that comprises the Buffer
|
||||
func (b Buffer) Bytes() []byte {
|
||||
return []byte(b)
|
||||
}
|
||||
|
||||
// NData returns Datalen || Data, where Datalen is a 32 bit counter for
|
||||
// the length of the following data, and Data is the octets that comprise
|
||||
// the buffer data
|
||||
func (b Buffer) NData() []byte {
|
||||
buf := make([]byte, 4+b.Len())
|
||||
binary.BigEndian.PutUint32(buf, uint32(b.Len()))
|
||||
|
||||
copy(buf[4:], b.Bytes())
|
||||
return buf
|
||||
}
|
||||
|
||||
// Len returns the number of bytes that the Buffer holds
|
||||
func (b Buffer) Len() int {
|
||||
return len(b)
|
||||
}
|
||||
|
||||
// Base64Encode encodes the contents of the Buffer using base64.RawURLEncoding
|
||||
func (b Buffer) Base64Encode() ([]byte, error) {
|
||||
enc := base64.RawURLEncoding
|
||||
out := make([]byte, enc.EncodedLen(len(b)))
|
||||
enc.Encode(out, b)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Base64Decode decodes the contents of the Buffer using base64.RawURLEncoding
|
||||
func (b *Buffer) Base64Decode(v []byte) error {
|
||||
enc := base64.RawURLEncoding
|
||||
out := make([]byte, enc.DecodedLen(len(v)))
|
||||
n, err := enc.Decode(out, v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to decode from base64")
|
||||
}
|
||||
out = out[:n]
|
||||
*b = Buffer(out)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the buffer into JSON format after encoding the buffer
|
||||
// with base64.RawURLEncoding
|
||||
func (b Buffer) MarshalJSON() ([]byte, error) {
|
||||
v, err := b.Base64Encode()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to encode to base64")
|
||||
}
|
||||
return json.Marshal(string(v))
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals from a JSON string into a Buffer, after decoding it
|
||||
// with base64.RawURLEncoding
|
||||
func (b *Buffer) UnmarshalJSON(data []byte) error {
|
||||
var x string
|
||||
if err := json.Unmarshal(data, &x); err != nil {
|
||||
return errors.Wrap(err, "failed to unmarshal JSON")
|
||||
}
|
||||
return b.Base64Decode([]byte(x))
|
||||
}
|
||||
11
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/elliptic.go
generated
vendored
11
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/elliptic.go
generated
vendored
@@ -1,11 +0,0 @@
|
||||
package jwa
|
||||
|
||||
// EllipticCurveAlgorithm represents the algorithms used for EC keys
|
||||
type EllipticCurveAlgorithm string
|
||||
|
||||
// Supported values for EllipticCurveAlgorithm
|
||||
const (
|
||||
P256 EllipticCurveAlgorithm = "P-256"
|
||||
P384 EllipticCurveAlgorithm = "P-384"
|
||||
P521 EllipticCurveAlgorithm = "P-521"
|
||||
)
|
||||
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/key_type.go
generated
vendored
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/key_type.go
generated
vendored
@@ -1,67 +0,0 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyType represents the key type ("kty") that are supported
|
||||
type KeyType string
|
||||
|
||||
var keyTypeAlg = map[string]struct{}{"EC": {}, "oct": {}, "RSA": {}}
|
||||
|
||||
// Supported values for KeyType
|
||||
const (
|
||||
EC KeyType = "EC" // Elliptic Curve
|
||||
InvalidKeyType KeyType = "" // Invalid KeyType
|
||||
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
|
||||
RSA KeyType = "RSA" // RSA
|
||||
)
|
||||
|
||||
// Accept is used when conversion from values given by
|
||||
// outside sources (such as JSON payloads) is required
|
||||
func (keyType *KeyType) Accept(value interface{}) error {
|
||||
var tmp KeyType
|
||||
switch x := value.(type) {
|
||||
case string:
|
||||
tmp = KeyType(x)
|
||||
case KeyType:
|
||||
tmp = x
|
||||
default:
|
||||
return errors.Errorf(`invalid type for jwa.KeyType: %T`, value)
|
||||
}
|
||||
_, ok := keyTypeAlg[tmp.String()]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown Key Type algorithm")
|
||||
}
|
||||
|
||||
*keyType = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the string representation of a KeyType
|
||||
func (keyType KeyType) String() string {
|
||||
return string(keyType)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
|
||||
func (keyType *KeyType) UnmarshalJSON(data []byte) error {
|
||||
var quote byte = '"'
|
||||
var quoted string
|
||||
if data[0] == quote {
|
||||
var err error
|
||||
quoted, err = strconv.Unquote(string(data))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to process signature algorithm")
|
||||
}
|
||||
} else {
|
||||
quoted = string(data)
|
||||
}
|
||||
_, ok := keyTypeAlg[quoted]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*keyType = KeyType(quoted)
|
||||
return nil
|
||||
}
|
||||
29
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/parameters.go
generated
vendored
29
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/parameters.go
generated
vendored
@@ -1,29 +0,0 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/buffer"
|
||||
)
|
||||
|
||||
// EllipticCurve provides a indirect type to standard elliptic curve such that we can
|
||||
// use it for unmarshal
|
||||
type EllipticCurve struct {
|
||||
elliptic.Curve
|
||||
}
|
||||
|
||||
// AlgorithmParameters provides a single structure suitable to unmarshaling any JWK
|
||||
type AlgorithmParameters struct {
|
||||
N buffer.Buffer `json:"n,omitempty"`
|
||||
E buffer.Buffer `json:"e,omitempty"`
|
||||
D buffer.Buffer `json:"d,omitempty"`
|
||||
P buffer.Buffer `json:"p,omitempty"`
|
||||
Q buffer.Buffer `json:"q,omitempty"`
|
||||
Dp buffer.Buffer `json:"dp,omitempty"`
|
||||
Dq buffer.Buffer `json:"dq,omitempty"`
|
||||
Qi buffer.Buffer `json:"qi,omitempty"`
|
||||
Crv EllipticCurveAlgorithm `json:"crv,omitempty"`
|
||||
X buffer.Buffer `json:"x,omitempty"`
|
||||
Y buffer.Buffer `json:"y,omitempty"`
|
||||
K buffer.Buffer `json:"k,omitempty"`
|
||||
}
|
||||
76
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/signature.go
generated
vendored
76
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/signature.go
generated
vendored
@@ -1,76 +0,0 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1
|
||||
type SignatureAlgorithm string
|
||||
|
||||
var signatureAlg = map[string]struct{}{"ES256": {}, "ES384": {}, "ES512": {}, "HS256": {}, "HS384": {}, "HS512": {}, "PS256": {}, "PS384": {}, "PS512": {}, "RS256": {}, "RS384": {}, "RS512": {}, "none": {}}
|
||||
|
||||
// Supported values for SignatureAlgorithm
|
||||
const (
|
||||
ES256 SignatureAlgorithm = "ES256" // ECDSA using P-256 and SHA-256
|
||||
ES384 SignatureAlgorithm = "ES384" // ECDSA using P-384 and SHA-384
|
||||
ES512 SignatureAlgorithm = "ES512" // ECDSA using P-521 and SHA-512
|
||||
HS256 SignatureAlgorithm = "HS256" // HMAC using SHA-256
|
||||
HS384 SignatureAlgorithm = "HS384" // HMAC using SHA-384
|
||||
HS512 SignatureAlgorithm = "HS512" // HMAC using SHA-512
|
||||
NoSignature SignatureAlgorithm = "none"
|
||||
PS256 SignatureAlgorithm = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
|
||||
PS384 SignatureAlgorithm = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
|
||||
PS512 SignatureAlgorithm = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
|
||||
RS256 SignatureAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
|
||||
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
|
||||
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
|
||||
NoValue SignatureAlgorithm = "" // No value is different from none
|
||||
)
|
||||
|
||||
// Accept is used when conversion from values given by
|
||||
// outside sources (such as JSON payloads) is required
|
||||
func (signature *SignatureAlgorithm) Accept(value interface{}) error {
|
||||
var tmp SignatureAlgorithm
|
||||
switch x := value.(type) {
|
||||
case string:
|
||||
tmp = SignatureAlgorithm(x)
|
||||
case SignatureAlgorithm:
|
||||
tmp = x
|
||||
default:
|
||||
return errors.Errorf(`invalid type for jwa.SignatureAlgorithm: %T`, value)
|
||||
}
|
||||
_, ok := signatureAlg[tmp.String()]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*signature = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the string representation of a SignatureAlgorithm
|
||||
func (signature SignatureAlgorithm) String() string {
|
||||
return string(signature)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as Signature Algorithm
|
||||
func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
|
||||
var quote byte = '"'
|
||||
var quoted string
|
||||
if data[0] == quote {
|
||||
var err error
|
||||
quoted, err = strconv.Unquote(string(data))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to process signature algorithm")
|
||||
}
|
||||
} else {
|
||||
quoted = string(data)
|
||||
}
|
||||
_, ok := signatureAlg[quoted]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*signature = SignatureAlgorithm(quoted)
|
||||
return nil
|
||||
}
|
||||
120
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/ecdsa.go
generated
vendored
120
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/ecdsa.go
generated
vendored
@@ -1,120 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newECDSAPublicKey(key *ecdsa.PublicKey) (*ECDSAPublicKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.EC)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
|
||||
return &ECDSAPublicKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newECDSAPrivateKey(key *ecdsa.PrivateKey) (*ECDSAPrivateKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.EC)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
|
||||
return &ECDSAPrivateKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the EC-DSA public key represented by this JWK
|
||||
func (k ECDSAPublicKey) Materialize() (interface{}, error) {
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// Materialize returns the EC-DSA private key represented by this JWK
|
||||
func (k ECDSAPrivateKey) Materialize() (interface{}, error) {
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a ECDSAPublicKey from JWK format
|
||||
func (k *ECDSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
var x, y big.Int
|
||||
|
||||
if keyJSON.X == nil || keyJSON.Y == nil || keyJSON.Crv == "" {
|
||||
return errors.Errorf("Missing mandatory key parameters X, Y or Crv")
|
||||
}
|
||||
|
||||
x.SetBytes(keyJSON.X.Bytes())
|
||||
y.SetBytes(keyJSON.Y.Bytes())
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch keyJSON.Crv {
|
||||
case jwa.P256:
|
||||
curve = elliptic.P256()
|
||||
case jwa.P384:
|
||||
curve = elliptic.P384()
|
||||
case jwa.P521:
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return errors.Errorf(`invalid curve name %s`, keyJSON.Crv)
|
||||
}
|
||||
|
||||
*k = ECDSAPublicKey{
|
||||
StandardHeaders: &keyJSON.StandardHeaders,
|
||||
key: &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: &x,
|
||||
Y: &y,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a ECDSAPrivateKey from JWK format
|
||||
func (k *ECDSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
if keyJSON.D == nil {
|
||||
return errors.Errorf("Missing mandatory key parameter D")
|
||||
}
|
||||
eCDSAPublicKey := &ECDSAPublicKey{}
|
||||
err := eCDSAPublicKey.GenerateKey(keyJSON)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `failed to generate public key`)
|
||||
}
|
||||
dBytes := keyJSON.D.Bytes()
|
||||
// The length of this octet string MUST be ceiling(log-base-2(n)/8)
|
||||
// octets (where n is the order of the curve). This is because the private
|
||||
// key d must be in the interval [1, n-1] so the bitlength of d should be
|
||||
// no larger than the bitlength of n-1. The easiest way to find the octet
|
||||
// length is to take bitlength(n-1), add 7 to force a carry, and shift this
|
||||
// bit sequence right by 3, which is essentially dividing by 8 and adding
|
||||
// 1 if there is any remainder. Thus, the private key value d should be
|
||||
// output to (bitlength(n-1)+7)>>3 octets.
|
||||
n := eCDSAPublicKey.key.Params().N
|
||||
octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
|
||||
if octetLength-len(dBytes) != 0 {
|
||||
return errors.Errorf("Failed to generate private key. Incorrect D value")
|
||||
}
|
||||
privateKey := &ecdsa.PrivateKey{
|
||||
PublicKey: *eCDSAPublicKey.key,
|
||||
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
|
||||
}
|
||||
|
||||
k.key = privateKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
|
||||
return nil
|
||||
}
|
||||
178
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/headers.go
generated
vendored
178
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/headers.go
generated
vendored
@@ -1,178 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Convenience constants for common JWK parameters
|
||||
const (
|
||||
AlgorithmKey = "alg"
|
||||
KeyIDKey = "kid"
|
||||
KeyOpsKey = "key_ops"
|
||||
KeyTypeKey = "kty"
|
||||
KeyUsageKey = "use"
|
||||
PrivateParamsKey = "privateParams"
|
||||
)
|
||||
|
||||
// Headers provides a common interface to all future possible headers
|
||||
type Headers interface {
|
||||
Get(string) (interface{}, bool)
|
||||
Set(string, interface{}) error
|
||||
Walk(func(string, interface{}) error) error
|
||||
GetAlgorithm() jwa.SignatureAlgorithm
|
||||
GetKeyID() string
|
||||
GetKeyOps() KeyOperationList
|
||||
GetKeyType() jwa.KeyType
|
||||
GetKeyUsage() string
|
||||
GetPrivateParams() map[string]interface{}
|
||||
}
|
||||
|
||||
// StandardHeaders stores the common JWK parameters
|
||||
type StandardHeaders struct {
|
||||
Algorithm *jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.4
|
||||
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
KeyOps KeyOperationList `json:"key_ops,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.3
|
||||
KeyType jwa.KeyType `json:"kty,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.1
|
||||
KeyUsage string `json:"use,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.2
|
||||
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
}
|
||||
|
||||
// GetAlgorithm is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
|
||||
if v := h.Algorithm; v != nil {
|
||||
return *v
|
||||
}
|
||||
return jwa.NoValue
|
||||
}
|
||||
|
||||
// GetKeyID is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyID() string {
|
||||
return h.KeyID
|
||||
}
|
||||
|
||||
// GetKeyOps is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyOps() KeyOperationList {
|
||||
return h.KeyOps
|
||||
}
|
||||
|
||||
// GetKeyType is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyType() jwa.KeyType {
|
||||
return h.KeyType
|
||||
}
|
||||
|
||||
// GetKeyUsage is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyUsage() string {
|
||||
return h.KeyUsage
|
||||
}
|
||||
|
||||
// GetPrivateParams is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetPrivateParams() map[string]interface{} {
|
||||
return h.PrivateParams
|
||||
}
|
||||
|
||||
// Get is a general getter function for JWK StandardHeaders structure
|
||||
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
alg := h.GetAlgorithm()
|
||||
if alg != jwa.NoValue {
|
||||
return alg, true
|
||||
}
|
||||
return nil, false
|
||||
case KeyIDKey:
|
||||
v := h.KeyID
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyOpsKey:
|
||||
v := h.KeyOps
|
||||
if v == nil {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyTypeKey:
|
||||
v := h.KeyType
|
||||
if v == jwa.InvalidKeyType {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyUsageKey:
|
||||
v := h.KeyUsage
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case PrivateParamsKey:
|
||||
v := h.PrivateParams
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Set is a general getter function for JWK StandardHeaders structure
|
||||
func (h *StandardHeaders) Set(name string, value interface{}) error {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
var acceptor jwa.SignatureAlgorithm
|
||||
if err := acceptor.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
|
||||
}
|
||||
h.Algorithm = &acceptor
|
||||
return nil
|
||||
case KeyIDKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyID = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", KeyIDKey, value)
|
||||
case KeyOpsKey:
|
||||
if err := h.KeyOps.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, "invalid value for %s key", KeyOpsKey)
|
||||
}
|
||||
return nil
|
||||
case KeyTypeKey:
|
||||
if err := h.KeyType.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, "invalid value for %s key", KeyTypeKey)
|
||||
}
|
||||
return nil
|
||||
case KeyUsageKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyUsage = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", KeyUsageKey, value)
|
||||
case PrivateParamsKey:
|
||||
if v, ok := value.(map[string]interface{}); ok {
|
||||
h.PrivateParams = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", PrivateParamsKey, value)
|
||||
default:
|
||||
return errors.Errorf(`invalid key: %s`, name)
|
||||
}
|
||||
}
|
||||
|
||||
// Walk iterates over all JWK standard headers fields while applying a function to its value.
|
||||
func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
|
||||
for _, key := range []string{AlgorithmKey, KeyIDKey, KeyOpsKey, KeyTypeKey, KeyUsageKey, PrivateParamsKey} {
|
||||
if v, ok := h.Get(key); ok {
|
||||
if err := f(key, v); err != nil {
|
||||
return errors.Wrapf(err, `walk function returned error for %s`, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range h.PrivateParams {
|
||||
if err := f(k, v); err != nil {
|
||||
return errors.Wrapf(err, `walk function returned error for %s`, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
70
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/interface.go
generated
vendored
70
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/interface.go
generated
vendored
@@ -1,70 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Set is a convenience struct to allow generating and parsing
|
||||
// JWK sets as opposed to single JWKs
|
||||
type Set struct {
|
||||
Keys []Key `json:"keys"`
|
||||
}
|
||||
|
||||
// Key defines the minimal interface for each of the
|
||||
// key types. Their use and implementation differ significantly
|
||||
// between each key types, so you should use type assertions
|
||||
// to perform more specific tasks with each key
|
||||
type Key interface {
|
||||
Headers
|
||||
|
||||
// Materialize creates the corresponding key. For example,
|
||||
// RSA types would create *rsa.PublicKey or *rsa.PrivateKey,
|
||||
// EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey,
|
||||
// and OctetSeq types create a []byte key.
|
||||
Materialize() (interface{}, error)
|
||||
GenerateKey(*RawKeyJSON) error
|
||||
}
|
||||
|
||||
// RawKeyJSON is generic type that represents any kind JWK
|
||||
type RawKeyJSON struct {
|
||||
StandardHeaders
|
||||
jwa.AlgorithmParameters
|
||||
}
|
||||
|
||||
// RawKeySetJSON is generic type that represents a JWK Set
|
||||
type RawKeySetJSON struct {
|
||||
Keys []RawKeyJSON `json:"keys"`
|
||||
}
|
||||
|
||||
// RSAPublicKey is a type of JWK generated from RSA public keys
|
||||
type RSAPublicKey struct {
|
||||
*StandardHeaders
|
||||
key *rsa.PublicKey
|
||||
}
|
||||
|
||||
// RSAPrivateKey is a type of JWK generated from RSA private keys
|
||||
type RSAPrivateKey struct {
|
||||
*StandardHeaders
|
||||
key *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// SymmetricKey is a type of JWK generated from symmetric keys
|
||||
type SymmetricKey struct {
|
||||
*StandardHeaders
|
||||
key []byte
|
||||
}
|
||||
|
||||
// ECDSAPublicKey is a type of JWK generated from ECDSA public keys
|
||||
type ECDSAPublicKey struct {
|
||||
*StandardHeaders
|
||||
key *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
// ECDSAPrivateKey is a type of JWK generated from ECDH-ES private keys
|
||||
type ECDSAPrivateKey struct {
|
||||
*StandardHeaders
|
||||
key *ecdsa.PrivateKey
|
||||
}
|
||||
150
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/jwk.go
generated
vendored
150
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/jwk.go
generated
vendored
@@ -1,150 +0,0 @@
|
||||
// Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// GetPublicKey returns the public key based on the private key type.
|
||||
// For rsa key types *rsa.PublicKey is returned; for ecdsa key types *ecdsa.PublicKey;
|
||||
// for byte slice (raw) keys, the key itself is returned. If the corresponding
|
||||
// public key cannot be deduced, an error is returned
|
||||
func GetPublicKey(key interface{}) (interface{}, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`jwk.New requires a non-nil key`)
|
||||
}
|
||||
|
||||
switch v := key.(type) {
|
||||
// Mental note: although Public() is defined in both types,
|
||||
// you can not coalesce the clauses for rsa.PrivateKey and
|
||||
// ecdsa.PrivateKey, as then `v` becomes interface{}
|
||||
// b/c the compiler cannot deduce the exact type.
|
||||
case *rsa.PrivateKey:
|
||||
return v.Public(), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return v.Public(), nil
|
||||
case []byte:
|
||||
return v, nil
|
||||
default:
|
||||
return nil, errors.Errorf(`invalid key type %T`, key)
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeyTypeFromKey creates a jwk.Key from the given key.
|
||||
func GetKeyTypeFromKey(key interface{}) jwa.KeyType {
|
||||
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey, *rsa.PublicKey:
|
||||
return jwa.RSA
|
||||
case *ecdsa.PrivateKey, *ecdsa.PublicKey:
|
||||
return jwa.EC
|
||||
case []byte:
|
||||
return jwa.OctetSeq
|
||||
default:
|
||||
return jwa.InvalidKeyType
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a jwk.Key from the given key.
|
||||
func New(key interface{}) (Key, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`jwk.New requires a non-nil key`)
|
||||
}
|
||||
|
||||
switch v := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return newRSAPrivateKey(v)
|
||||
case *rsa.PublicKey:
|
||||
return newRSAPublicKey(v)
|
||||
case *ecdsa.PrivateKey:
|
||||
return newECDSAPrivateKey(v)
|
||||
case *ecdsa.PublicKey:
|
||||
return newECDSAPublicKey(v)
|
||||
case []byte:
|
||||
return newSymmetricKey(v)
|
||||
default:
|
||||
return nil, errors.Errorf(`invalid key type %T`, key)
|
||||
}
|
||||
}
|
||||
|
||||
func parse(jwkSrc string) (*Set, error) {
|
||||
|
||||
var jwkKeySet Set
|
||||
var jwkKey Key
|
||||
rawKeySetJSON := &RawKeySetJSON{}
|
||||
err := json.Unmarshal([]byte(jwkSrc), rawKeySetJSON)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal JWK Set")
|
||||
}
|
||||
if len(rawKeySetJSON.Keys) == 0 {
|
||||
|
||||
// It might be a single key
|
||||
rawKeyJSON := &RawKeyJSON{}
|
||||
err := json.Unmarshal([]byte(jwkSrc), rawKeyJSON)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal JWK")
|
||||
}
|
||||
jwkKey, err = rawKeyJSON.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key")
|
||||
}
|
||||
// Add to set
|
||||
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
|
||||
} else {
|
||||
for i := range rawKeySetJSON.Keys {
|
||||
rawKeyJSON := rawKeySetJSON.Keys[i]
|
||||
jwkKey, err = rawKeyJSON.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key: %s")
|
||||
}
|
||||
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
|
||||
}
|
||||
}
|
||||
return &jwkKeySet, nil
|
||||
}
|
||||
|
||||
// ParseBytes parses JWK from the incoming byte buffer.
|
||||
func ParseBytes(buf []byte) (*Set, error) {
|
||||
return parse(string(buf[:]))
|
||||
}
|
||||
|
||||
// ParseString parses JWK from the incoming string.
|
||||
func ParseString(s string) (*Set, error) {
|
||||
return parse(s)
|
||||
}
|
||||
|
||||
// GenerateKey creates an internal representation of a key from a raw JWK JSON
|
||||
func (r *RawKeyJSON) GenerateKey() (Key, error) {
|
||||
|
||||
var key Key
|
||||
|
||||
switch r.KeyType {
|
||||
case jwa.RSA:
|
||||
if r.D != nil {
|
||||
key = &RSAPrivateKey{}
|
||||
} else {
|
||||
key = &RSAPublicKey{}
|
||||
}
|
||||
case jwa.EC:
|
||||
if r.D != nil {
|
||||
key = &ECDSAPrivateKey{}
|
||||
} else {
|
||||
key = &ECDSAPublicKey{}
|
||||
}
|
||||
case jwa.OctetSeq:
|
||||
key = &SymmetricKey{}
|
||||
default:
|
||||
return nil, errors.Errorf(`Unrecognized key type`)
|
||||
}
|
||||
err := key.GenerateKey(r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key from JWK")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
68
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/key_ops.go
generated
vendored
68
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/key_ops.go
generated
vendored
@@ -1,68 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyUsageType is used to denote what this key should be used for
|
||||
type KeyUsageType string
|
||||
|
||||
const (
|
||||
// ForSignature is the value used in the headers to indicate that
|
||||
// this key should be used for signatures
|
||||
ForSignature KeyUsageType = "sig"
|
||||
// ForEncryption is the value used in the headers to indicate that
|
||||
// this key should be used for encryptiong
|
||||
ForEncryption KeyUsageType = "enc"
|
||||
)
|
||||
|
||||
// KeyOperation is used to denote the allowed operations for a Key
|
||||
type KeyOperation string
|
||||
|
||||
// KeyOperationList represents an slice of KeyOperation
|
||||
type KeyOperationList []KeyOperation
|
||||
|
||||
var keyOps = map[string]struct{}{"sign": {}, "verify": {}, "encrypt": {}, "decrypt": {}, "wrapKey": {}, "unwrapKey": {}, "deriveKey": {}, "deriveBits": {}}
|
||||
|
||||
// KeyOperation constants
|
||||
const (
|
||||
KeyOpSign KeyOperation = "sign" // (compute digital signature or MAC)
|
||||
KeyOpVerify = "verify" // (verify digital signature or MAC)
|
||||
KeyOpEncrypt = "encrypt" // (encrypt content)
|
||||
KeyOpDecrypt = "decrypt" // (decrypt content and validate decryption, if applicable)
|
||||
KeyOpWrapKey = "wrapKey" // (encrypt key)
|
||||
KeyOpUnwrapKey = "unwrapKey" // (decrypt key and validate decryption, if applicable)
|
||||
KeyOpDeriveKey = "deriveKey" // (derive key)
|
||||
KeyOpDeriveBits = "deriveBits" // (derive bits not to be used as a key)
|
||||
)
|
||||
|
||||
// Accept determines if Key Operation is valid
|
||||
func (keyOperationList *KeyOperationList) Accept(v interface{}) error {
|
||||
switch x := v.(type) {
|
||||
case KeyOperationList:
|
||||
*keyOperationList = x
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf(`invalid value %T`, v)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
|
||||
func (keyOperationList *KeyOperationList) UnmarshalJSON(data []byte) error {
|
||||
var tempKeyOperationList []string
|
||||
err := json.Unmarshal(data, &tempKeyOperationList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid key operation")
|
||||
}
|
||||
for _, value := range tempKeyOperationList {
|
||||
_, ok := keyOps[value]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key operation")
|
||||
}
|
||||
*keyOperationList = append(*keyOperationList, KeyOperation(value))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
103
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/rsa.go
generated
vendored
103
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/rsa.go
generated
vendored
@@ -1,103 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newRSAPublicKey(key *rsa.PublicKey) (*RSAPublicKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.RSA)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &RSAPublicKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newRSAPrivateKey(key *rsa.PrivateKey) (*RSAPrivateKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.RSA)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &RSAPrivateKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the standard RSA Public Key representation stored in the internal representation
|
||||
func (k *RSAPublicKey) Materialize() (interface{}, error) {
|
||||
if k.key == nil {
|
||||
return nil, errors.New(`key has no rsa.PublicKey associated with it`)
|
||||
}
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// Materialize returns the standard RSA Private Key representation stored in the internal representation
|
||||
func (k *RSAPrivateKey) Materialize() (interface{}, error) {
|
||||
if k.key == nil {
|
||||
return nil, errors.New(`key has no rsa.PrivateKey associated with it`)
|
||||
}
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
|
||||
func (k *RSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
if keyJSON.N == nil || keyJSON.E == nil {
|
||||
return errors.Errorf("Missing mandatory key parameters N or E")
|
||||
}
|
||||
rsaPublicKey := &rsa.PublicKey{
|
||||
N: (&big.Int{}).SetBytes(keyJSON.N.Bytes()),
|
||||
E: int((&big.Int{}).SetBytes(keyJSON.E.Bytes()).Int64()),
|
||||
}
|
||||
k.key = rsaPublicKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
|
||||
func (k *RSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
rsaPublicKey := &RSAPublicKey{}
|
||||
err := rsaPublicKey.GenerateKey(keyJSON)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate public key")
|
||||
}
|
||||
|
||||
if keyJSON.D == nil || keyJSON.P == nil || keyJSON.Q == nil {
|
||||
return errors.Errorf("Missing mandatory key parameters D, P or Q")
|
||||
}
|
||||
privateKey := &rsa.PrivateKey{
|
||||
PublicKey: *rsaPublicKey.key,
|
||||
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
|
||||
Primes: []*big.Int{
|
||||
(&big.Int{}).SetBytes(keyJSON.P.Bytes()),
|
||||
(&big.Int{}).SetBytes(keyJSON.Q.Bytes()),
|
||||
},
|
||||
}
|
||||
|
||||
if keyJSON.Dp.Len() > 0 {
|
||||
privateKey.Precomputed.Dp = (&big.Int{}).SetBytes(keyJSON.Dp.Bytes())
|
||||
}
|
||||
if keyJSON.Dq.Len() > 0 {
|
||||
privateKey.Precomputed.Dq = (&big.Int{}).SetBytes(keyJSON.Dq.Bytes())
|
||||
}
|
||||
if keyJSON.Qi.Len() > 0 {
|
||||
privateKey.Precomputed.Qinv = (&big.Int{}).SetBytes(keyJSON.Qi.Bytes())
|
||||
}
|
||||
|
||||
k.key = privateKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
return nil
|
||||
}
|
||||
41
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/symmetric.go
generated
vendored
41
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/symmetric.go
generated
vendored
@@ -1,41 +0,0 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newSymmetricKey(key []byte) (*SymmetricKey, error) {
|
||||
var hdr StandardHeaders
|
||||
|
||||
err := hdr.Set(KeyTypeKey, jwa.OctetSeq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &SymmetricKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the octets for this symmetric key.
|
||||
// Since this is a symmetric key, this just calls Octets
|
||||
func (s SymmetricKey) Materialize() (interface{}, error) {
|
||||
return s.Octets(), nil
|
||||
}
|
||||
|
||||
// Octets returns the octets in the key
|
||||
func (s SymmetricKey) Octets() []byte {
|
||||
return s.key
|
||||
}
|
||||
|
||||
// GenerateKey creates a Symmetric key from a RawKeyJSON
|
||||
func (s *SymmetricKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
*s = SymmetricKey{
|
||||
StandardHeaders: &keyJSON.StandardHeaders,
|
||||
key: keyJSON.K,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
154
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/headers.go
generated
vendored
154
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/headers.go
generated
vendored
@@ -1,154 +0,0 @@
|
||||
package jws
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Constants for JWS Common parameters
|
||||
const (
|
||||
AlgorithmKey = "alg"
|
||||
ContentTypeKey = "cty"
|
||||
CriticalKey = "crit"
|
||||
JWKKey = "jwk"
|
||||
JWKSetURLKey = "jku"
|
||||
KeyIDKey = "kid"
|
||||
PrivateParamsKey = "privateParams"
|
||||
TypeKey = "typ"
|
||||
)
|
||||
|
||||
// Headers provides a common interface for common header parameters
|
||||
type Headers interface {
|
||||
Get(string) (interface{}, bool)
|
||||
Set(string, interface{}) error
|
||||
GetAlgorithm() jwa.SignatureAlgorithm
|
||||
}
|
||||
|
||||
// StandardHeaders contains JWS common parameters.
|
||||
type StandardHeaders struct {
|
||||
Algorithm jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.1
|
||||
ContentType string `json:"cty,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.10
|
||||
Critical []string `json:"crit,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.11
|
||||
JWK string `json:"jwk,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.3
|
||||
JWKSetURL string `json:"jku,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.2
|
||||
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
|
||||
Type string `json:"typ,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
|
||||
}
|
||||
|
||||
// GetAlgorithm returns algorithm
|
||||
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
|
||||
return h.Algorithm
|
||||
}
|
||||
|
||||
// Get is a general getter function for StandardHeaders structure
|
||||
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
v := h.Algorithm
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case ContentTypeKey:
|
||||
v := h.ContentType
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case CriticalKey:
|
||||
v := h.Critical
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case JWKKey:
|
||||
v := h.JWK
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case JWKSetURLKey:
|
||||
v := h.JWKSetURL
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyIDKey:
|
||||
v := h.KeyID
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case PrivateParamsKey:
|
||||
v := h.PrivateParams
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case TypeKey:
|
||||
v := h.Type
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Set is a general setter function for StandardHeaders structure
|
||||
func (h *StandardHeaders) Set(name string, value interface{}) error {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
if err := h.Algorithm.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
|
||||
}
|
||||
return nil
|
||||
case ContentTypeKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.ContentType = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
|
||||
case CriticalKey:
|
||||
if v, ok := value.([]string); ok {
|
||||
h.Critical = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
|
||||
case JWKKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.JWK = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
|
||||
case JWKSetURLKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.JWKSetURL = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
|
||||
case KeyIDKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyID = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
|
||||
case PrivateParamsKey:
|
||||
if v, ok := value.(map[string]interface{}); ok {
|
||||
h.PrivateParams = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, PrivateParamsKey, value)
|
||||
case TypeKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.Type = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
|
||||
default:
|
||||
return errors.Errorf(`invalid key: %s`, name)
|
||||
}
|
||||
}
|
||||
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/interface.go
generated
vendored
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/interface.go
generated
vendored
@@ -1,22 +0,0 @@
|
||||
package jws
|
||||
|
||||
// Message represents a full JWS encoded message. Flattened serialization
|
||||
// is not supported as a struct, but rather it's represented as a
|
||||
// Message struct with only one `Signature` element.
|
||||
//
|
||||
// Do not expect to use the Message object to verify or construct a
|
||||
// signed payloads with. You should only use this when you want to actually
|
||||
// want to programmatically view the contents for the full JWS Payload.
|
||||
//
|
||||
// To sign and verify, use the appropriate `SignWithOption()` nad `Verify()` functions
|
||||
type Message struct {
|
||||
Payload []byte `json:"payload"`
|
||||
Signatures []*Signature `json:"signatures,omitempty"`
|
||||
}
|
||||
|
||||
// Signature represents the headers and signature of a JWS message
|
||||
type Signature struct {
|
||||
Headers Headers `json:"header,omitempty"` // Unprotected Headers
|
||||
Protected Headers `json:"Protected,omitempty"` // Protected Headers
|
||||
Signature []byte `json:"signature,omitempty"` // GetSignature
|
||||
}
|
||||
210
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/jws.go
generated
vendored
210
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/jws.go
generated
vendored
@@ -1,210 +0,0 @@
|
||||
// Package jws implements the digital Signature on JSON based data
|
||||
// structures as described in https://tools.ietf.org/html/rfc7515
|
||||
//
|
||||
// If you do not care about the details, the only things that you
|
||||
// would need to use are the following functions:
|
||||
//
|
||||
// jws.SignWithOption(Payload, algorithm, key)
|
||||
// jws.Verify(encodedjws, algorithm, key)
|
||||
//
|
||||
// To sign, simply use `jws.SignWithOption`. `Payload` is a []byte buffer that
|
||||
// contains whatever data you want to sign. `alg` is one of the
|
||||
// jwa.SignatureAlgorithm constants from package jwa. For RSA and
|
||||
// ECDSA family of algorithms, you will need to prepare a private key.
|
||||
// For HMAC family, you just need a []byte value. The `jws.SignWithOption`
|
||||
// function will return the encoded JWS message on success.
|
||||
//
|
||||
// To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
|
||||
// and verify the result using `algorithm` and `key`. Upon successful
|
||||
// verification, the original Payload is returned, so you can work on it.
|
||||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// SignLiteral generates a Signature for the given Payload and Headers, and serializes
|
||||
// it in compact serialization format. In this format you may NOT use
|
||||
// multiple signers.
|
||||
//
|
||||
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
|
||||
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
signingInput := strings.Join(
|
||||
[]string{
|
||||
encodedHdr,
|
||||
encodedPayload,
|
||||
}, ".",
|
||||
)
|
||||
signer, err := sign.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to create signer`)
|
||||
}
|
||||
signature, err := signer.Sign([]byte(signingInput), key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to sign Payload`)
|
||||
}
|
||||
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
|
||||
compactSerialization := strings.Join(
|
||||
[]string{
|
||||
signingInput,
|
||||
encodedSignature,
|
||||
}, ".",
|
||||
)
|
||||
return []byte(compactSerialization), nil
|
||||
}
|
||||
|
||||
// SignWithOption generates a Signature for the given Payload, and serializes
|
||||
// it in compact serialization format. In this format you may NOT use
|
||||
// multiple signers.
|
||||
//
|
||||
// If you would like to pass custom Headers, use the WithHeaders option.
|
||||
func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, error) {
|
||||
var headers Headers = &StandardHeaders{}
|
||||
|
||||
err := headers.Set(AlgorithmKey, alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to set alg value")
|
||||
}
|
||||
|
||||
hdrBuf, err := json.Marshal(headers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to marshal Headers`)
|
||||
}
|
||||
return SignLiteral(payload, alg, key, hdrBuf)
|
||||
}
|
||||
|
||||
// Verify checks if the given JWS message is verifiable using `alg` and `key`.
|
||||
// If the verification is successful, `err` is nil, and the content of the
|
||||
// Payload that was signed is returned. If you need more fine-grained
|
||||
// control of the verification process, manually call `Parse`, generate a
|
||||
// verifier, and call `Verify` on the parsed JWS message object.
|
||||
func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}) (ret []byte, err error) {
|
||||
|
||||
verifier, err := verify.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create verifier")
|
||||
}
|
||||
|
||||
buf = bytes.TrimSpace(buf)
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New(`attempt to verify empty buffer`)
|
||||
}
|
||||
|
||||
parts, err := SplitCompact(string(buf[:]))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed extract from compact serialization format`)
|
||||
}
|
||||
|
||||
signingInput := strings.Join(
|
||||
[]string{
|
||||
parts[0],
|
||||
parts[1],
|
||||
}, ".",
|
||||
)
|
||||
|
||||
decodedSignature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode signature")
|
||||
}
|
||||
if err := verifier.Verify([]byte(signingInput), decodedSignature, key); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to verify message")
|
||||
}
|
||||
|
||||
if decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
return decodedPayload, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "Failed to decode Payload")
|
||||
}
|
||||
|
||||
// VerifyWithJWK verifies the JWS message using the specified JWK
|
||||
func VerifyWithJWK(buf []byte, key jwk.Key) (payload []byte, err error) {
|
||||
|
||||
keyVal, err := key.Materialize()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to materialize key")
|
||||
}
|
||||
return Verify(buf, key.GetAlgorithm(), keyVal)
|
||||
}
|
||||
|
||||
// VerifyWithJWKSet verifies the JWS message using JWK key set.
|
||||
// By default it will only pick up keys that have the "use" key
|
||||
// set to either "sig" or "enc", but you can override it by
|
||||
// providing a keyaccept function.
|
||||
func VerifyWithJWKSet(buf []byte, keyset *jwk.Set) (payload []byte, err error) {
|
||||
|
||||
for _, key := range keyset.Keys {
|
||||
payload, err := VerifyWithJWK(buf, key)
|
||||
if err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("failed to verify with any of the keys")
|
||||
}
|
||||
|
||||
// ParseByte parses a JWS value serialized via compact serialization and provided as []byte.
|
||||
func ParseByte(jwsCompact []byte) (m *Message, err error) {
|
||||
return parseCompact(string(jwsCompact[:]))
|
||||
}
|
||||
|
||||
// ParseString parses a JWS value serialized via compact serialization and provided as string.
|
||||
func ParseString(s string) (*Message, error) {
|
||||
return parseCompact(s)
|
||||
}
|
||||
|
||||
// SplitCompact splits a JWT and returns its three parts
|
||||
// separately: Protected Headers, Payload and Signature.
|
||||
func SplitCompact(jwsCompact string) ([]string, error) {
|
||||
|
||||
parts := strings.Split(jwsCompact, ".")
|
||||
if len(parts) < 3 {
|
||||
return nil, errors.New("Failed to split compact serialization")
|
||||
}
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// parseCompact parses a JWS value serialized via compact serialization.
|
||||
func parseCompact(str string) (m *Message, err error) {
|
||||
|
||||
var decodedHeader, decodedPayload, decodedSignature []byte
|
||||
parts, err := SplitCompact(str)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `invalid compact serialization format`)
|
||||
}
|
||||
|
||||
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Headers`)
|
||||
}
|
||||
var hdr StandardHeaders
|
||||
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to parse JOSE Headers`)
|
||||
}
|
||||
|
||||
if decodedPayload, err = base64.RawURLEncoding.DecodeString(parts[1]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Payload`)
|
||||
}
|
||||
|
||||
if len(parts) > 2 {
|
||||
if decodedSignature, err = base64.RawURLEncoding.DecodeString(parts[2]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Signature`)
|
||||
}
|
||||
}
|
||||
|
||||
var msg Message
|
||||
msg.Payload = decodedPayload
|
||||
msg.Signatures = append(msg.Signatures, &Signature{
|
||||
Protected: &hdr,
|
||||
Signature: decodedSignature,
|
||||
})
|
||||
return &msg, nil
|
||||
}
|
||||
26
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/message.go
generated
vendored
26
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/message.go
generated
vendored
@@ -1,26 +0,0 @@
|
||||
package jws
|
||||
|
||||
// PublicHeaders returns the public headers in a JWS
|
||||
func (s Signature) PublicHeaders() Headers {
|
||||
return s.Headers
|
||||
}
|
||||
|
||||
// ProtectedHeaders returns the protected headers in a JWS
|
||||
func (s Signature) ProtectedHeaders() Headers {
|
||||
return s.Protected
|
||||
}
|
||||
|
||||
// GetSignature returns the signature in a JWS
|
||||
func (s Signature) GetSignature() []byte {
|
||||
return s.Signature
|
||||
}
|
||||
|
||||
// GetPayload returns the payload in a JWS
|
||||
func (m Message) GetPayload() []byte {
|
||||
return m.Payload
|
||||
}
|
||||
|
||||
// GetSignatures returns the all signatures in a JWS
|
||||
func (m Message) GetSignatures() []*Signature {
|
||||
return m.Signatures
|
||||
}
|
||||
84
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/ecdsa.go
generated
vendored
84
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/ecdsa.go
generated
vendored
@@ -1,84 +0,0 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var ecdsaSignFuncs = map[jwa.SignatureAlgorithm]ecdsaSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
|
||||
jwa.ES256: crypto.SHA256,
|
||||
jwa.ES384: crypto.SHA384,
|
||||
jwa.ES512: crypto.SHA512,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
ecdsaSignFuncs[alg] = makeECDSASignFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
|
||||
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
curveBits := key.Curve.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
// Curve bits do not need to be a multiple of 8.
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
|
||||
}
|
||||
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
|
||||
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSASigner, error) {
|
||||
signfn, ok := ecdsaSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA signer: %s`, alg)
|
||||
}
|
||||
|
||||
return &ECDSASigner{
|
||||
alg: alg,
|
||||
sign: signfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign signs payload with a ECDSA private key
|
||||
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`missing private key while signing payload`)
|
||||
}
|
||||
|
||||
privateKey, ok := key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
|
||||
}
|
||||
|
||||
return s.sign(payload, privateKey)
|
||||
}
|
||||
66
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/hmac.go
generated
vendored
66
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/hmac.go
generated
vendored
@@ -1,66 +0,0 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"hash"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var hmacSignFuncs = map[jwa.SignatureAlgorithm]hmacSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]func() hash.Hash{
|
||||
jwa.HS256: sha256.New,
|
||||
jwa.HS384: sha512.New384,
|
||||
jwa.HS512: sha512.New,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
hmacSignFuncs[alg] = makeHMACSignFunc(h)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACSigner, error) {
|
||||
signer, ok := hmacSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create HMAC signer: %s`, alg)
|
||||
}
|
||||
|
||||
return &HMACSigner{
|
||||
alg: alg,
|
||||
sign: signer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func makeHMACSignFunc(hfunc func() hash.Hash) hmacSignFunc {
|
||||
return hmacSignFunc(func(payload []byte, key []byte) ([]byte, error) {
|
||||
h := hmac.New(hfunc, key)
|
||||
h.Write(payload)
|
||||
return h.Sum(nil), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s HMACSigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign signs payload with a Symmetric key
|
||||
func (s HMACSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
hmackey, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. []byte is required`, key)
|
||||
}
|
||||
|
||||
if len(hmackey) == 0 {
|
||||
return nil, errors.New(`missing key while signing payload`)
|
||||
}
|
||||
|
||||
return s.sign(payload, hmackey)
|
||||
}
|
||||
45
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/interface.go
generated
vendored
45
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/interface.go
generated
vendored
@@ -1,45 +0,0 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Signer provides a common interface for supported alg signing methods
|
||||
type Signer interface {
|
||||
// Sign creates a signature for the given `payload`.
|
||||
// `key` is the key used for signing the payload, and is usually
|
||||
// the private key type associated with the signature method. For example,
|
||||
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
|
||||
// `*"crypto/rsa".PrivateKey` type.
|
||||
// Check the documentation for each signer for details
|
||||
Sign(payload []byte, key interface{}) ([]byte, error)
|
||||
|
||||
Algorithm() jwa.SignatureAlgorithm
|
||||
}
|
||||
|
||||
type rsaSignFunc func([]byte, *rsa.PrivateKey) ([]byte, error)
|
||||
|
||||
// RSASigner uses crypto/rsa to sign the payloads.
|
||||
type RSASigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign rsaSignFunc
|
||||
}
|
||||
|
||||
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
|
||||
|
||||
// ECDSASigner uses crypto/ecdsa to sign the payloads.
|
||||
type ECDSASigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign ecdsaSignFunc
|
||||
}
|
||||
|
||||
type hmacSignFunc func([]byte, []byte) ([]byte, error)
|
||||
|
||||
// HMACSigner uses crypto/hmac to sign the payloads.
|
||||
type HMACSigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign hmacSignFunc
|
||||
}
|
||||
97
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/rsa.go
generated
vendored
97
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/rsa.go
generated
vendored
@@ -1,97 +0,0 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var rsaSignFuncs = map[jwa.SignatureAlgorithm]rsaSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]struct {
|
||||
Hash crypto.Hash
|
||||
SignFunc func(crypto.Hash) rsaSignFunc
|
||||
}{
|
||||
jwa.RS256: {
|
||||
Hash: crypto.SHA256,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.RS384: {
|
||||
Hash: crypto.SHA384,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.RS512: {
|
||||
Hash: crypto.SHA512,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.PS256: {
|
||||
Hash: crypto.SHA256,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
jwa.PS384: {
|
||||
Hash: crypto.SHA384,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
jwa.PS512: {
|
||||
Hash: crypto.SHA512,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
}
|
||||
|
||||
for alg, item := range algs {
|
||||
rsaSignFuncs[alg] = item.SignFunc(item.Hash)
|
||||
}
|
||||
}
|
||||
|
||||
func makeSignPKCS1v15(hash crypto.Hash) rsaSignFunc {
|
||||
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.SignPKCS1v15(rand.Reader, key, hash, h.Sum(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func makeSignPSS(hash crypto.Hash) rsaSignFunc {
|
||||
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.SignPSS(rand.Reader, key, hash, h.Sum(nil), &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func newRSA(alg jwa.SignatureAlgorithm) (*RSASigner, error) {
|
||||
signfn, ok := rsaSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA signer: %s`, alg)
|
||||
}
|
||||
return &RSASigner{
|
||||
alg: alg,
|
||||
sign: signfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s RSASigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign creates a signature using crypto/rsa. key must be a non-nil instance of
|
||||
// `*"crypto/rsa".PrivateKey`.
|
||||
func (s RSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`missing private key while signing payload`)
|
||||
}
|
||||
rsakey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. *rsa.PrivateKey is required`, key)
|
||||
}
|
||||
|
||||
return s.sign(payload, rsakey)
|
||||
}
|
||||
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/sign.go
generated
vendored
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/sign.go
generated
vendored
@@ -1,21 +0,0 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// New creates a signer that signs payloads using the given signature algorithm.
|
||||
func New(alg jwa.SignatureAlgorithm) (Signer, error) {
|
||||
switch alg {
|
||||
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
|
||||
return newRSA(alg)
|
||||
case jwa.ES256, jwa.ES384, jwa.ES512:
|
||||
return newECDSA(alg)
|
||||
case jwa.HS256, jwa.HS384, jwa.HS512:
|
||||
return newHMAC(alg)
|
||||
default:
|
||||
return nil, errors.Errorf(`unsupported signature algorithm %s`, alg)
|
||||
}
|
||||
}
|
||||
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/ecdsa.go
generated
vendored
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/ecdsa.go
generated
vendored
@@ -1,67 +0,0 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
var ecdsaVerifyFuncs = map[jwa.SignatureAlgorithm]ecdsaVerifyFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
|
||||
jwa.ES256: crypto.SHA256,
|
||||
jwa.ES384: crypto.SHA384,
|
||||
jwa.ES512: crypto.SHA512,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
ecdsaVerifyFuncs[alg] = makeECDSAVerifyFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func makeECDSAVerifyFunc(hash crypto.Hash) ecdsaVerifyFunc {
|
||||
return ecdsaVerifyFunc(func(payload []byte, signature []byte, key *ecdsa.PublicKey) error {
|
||||
|
||||
r, s := &big.Int{}, &big.Int{}
|
||||
n := len(signature) / 2
|
||||
r.SetBytes(signature[:n])
|
||||
s.SetBytes(signature[n:])
|
||||
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
|
||||
if !ecdsa.Verify(key, h.Sum(nil), r, s) {
|
||||
return errors.New(`failed to verify signature using ecdsa`)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSAVerifier, error) {
|
||||
verifyfn, ok := ecdsaVerifyFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA verifier: %s`, alg)
|
||||
}
|
||||
|
||||
return &ECDSAVerifier{
|
||||
verify: verifyfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify checks whether the signature for a given input and key is correct
|
||||
func (v ECDSAVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
|
||||
if key == nil {
|
||||
return errors.New(`missing public key while verifying payload`)
|
||||
}
|
||||
ecdsakey, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.Errorf(`invalid key type %T. *ecdsa.PublicKey is required`, key)
|
||||
}
|
||||
|
||||
return v.verify(payload, signature, ecdsakey)
|
||||
}
|
||||
33
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/hmac.go
generated
vendored
33
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/hmac.go
generated
vendored
@@ -1,33 +0,0 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
)
|
||||
|
||||
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACVerifier, error) {
|
||||
|
||||
s, err := sign.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to generate HMAC signer`)
|
||||
}
|
||||
return &HMACVerifier{signer: s}, nil
|
||||
}
|
||||
|
||||
// Verify checks whether the signature for a given input and key is correct
|
||||
func (v HMACVerifier) Verify(signingInput, signature []byte, key interface{}) (err error) {
|
||||
|
||||
expected, err := v.signer.Sign(signingInput, key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `failed to generated signature`)
|
||||
}
|
||||
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return errors.New(`failed to match hmac signature`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
39
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/interface.go
generated
vendored
39
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/interface.go
generated
vendored
@@ -1,39 +0,0 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
)
|
||||
|
||||
// Verifier provides a common interface for supported alg verification methods
|
||||
type Verifier interface {
|
||||
// Verify checks whether the payload and signature are valid for
|
||||
// the given key.
|
||||
// `key` is the key used for verifying the payload, and is usually
|
||||
// the public key associated with the signature method. For example,
|
||||
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
|
||||
// `*"crypto/rsa".PublicKey` type.
|
||||
// Check the documentation for each verifier for details
|
||||
Verify(payload []byte, signature []byte, key interface{}) error
|
||||
}
|
||||
|
||||
type rsaVerifyFunc func([]byte, []byte, *rsa.PublicKey) error
|
||||
|
||||
// RSAVerifier implements the Verifier interface
|
||||
type RSAVerifier struct {
|
||||
verify rsaVerifyFunc
|
||||
}
|
||||
|
||||
type ecdsaVerifyFunc func([]byte, []byte, *ecdsa.PublicKey) error
|
||||
|
||||
// ECDSAVerifier implements the Verifier interface
|
||||
type ECDSAVerifier struct {
|
||||
verify ecdsaVerifyFunc
|
||||
}
|
||||
|
||||
// HMACVerifier implements the Verifier interface
|
||||
type HMACVerifier struct {
|
||||
signer sign.Signer
|
||||
}
|
||||
88
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/rsa.go
generated
vendored
88
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/rsa.go
generated
vendored
@@ -1,88 +0,0 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var rsaVerifyFuncs = map[jwa.SignatureAlgorithm]rsaVerifyFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]struct {
|
||||
Hash crypto.Hash
|
||||
VerifyFunc func(crypto.Hash) rsaVerifyFunc
|
||||
}{
|
||||
jwa.RS256: {
|
||||
Hash: crypto.SHA256,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.RS384: {
|
||||
Hash: crypto.SHA384,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.RS512: {
|
||||
Hash: crypto.SHA512,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.PS256: {
|
||||
Hash: crypto.SHA256,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
jwa.PS384: {
|
||||
Hash: crypto.SHA384,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
jwa.PS512: {
|
||||
Hash: crypto.SHA512,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
}
|
||||
|
||||
for alg, item := range algs {
|
||||
rsaVerifyFuncs[alg] = item.VerifyFunc(item.Hash)
|
||||
}
|
||||
}
|
||||
|
||||
func makeVerifyPKCS1v15(hash crypto.Hash) rsaVerifyFunc {
|
||||
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.VerifyPKCS1v15(key, hash, h.Sum(nil), signature)
|
||||
})
|
||||
}
|
||||
|
||||
func makeVerifyPSS(hash crypto.Hash) rsaVerifyFunc {
|
||||
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.VerifyPSS(key, hash, h.Sum(nil), signature, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func newRSA(alg jwa.SignatureAlgorithm) (*RSAVerifier, error) {
|
||||
verifyfn, ok := rsaVerifyFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA verifier: %s`, alg)
|
||||
}
|
||||
|
||||
return &RSAVerifier{
|
||||
verify: verifyfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify checks if a JWS is valid.
|
||||
func (v RSAVerifier) Verify(payload, signature []byte, key interface{}) error {
|
||||
if key == nil {
|
||||
return errors.New(`missing public key while verifying payload`)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.Errorf(`invalid key type %T. *rsa.PublicKey is required`, key)
|
||||
}
|
||||
|
||||
return v.verify(payload, signature, rsaKey)
|
||||
}
|
||||
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/verify.go
generated
vendored
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/verify.go
generated
vendored
@@ -1,22 +0,0 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// New creates a new JWS verifier using the specified algorithm
|
||||
// and the public key
|
||||
func New(alg jwa.SignatureAlgorithm) (Verifier, error) {
|
||||
switch alg {
|
||||
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
|
||||
return newRSA(alg)
|
||||
case jwa.ES256, jwa.ES384, jwa.ES512:
|
||||
return newECDSA(alg)
|
||||
case jwa.HS256, jwa.HS384, jwa.HS512:
|
||||
return newHMAC(alg)
|
||||
default:
|
||||
return nil, errors.Errorf(`unsupported signature algorithm: %s`, alg)
|
||||
}
|
||||
}
|
||||
416
vendor/github.com/open-policy-agent/opa/topdown/json.go
generated
vendored
416
vendor/github.com/open-policy-agent/opa/topdown/json.go
generated
vendored
@@ -93,11 +93,12 @@ func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
|
||||
return nil, err
|
||||
}
|
||||
return ast.NewTerm(newSet), nil
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
// When indexes are removed we shift left to close empty spots in the array
|
||||
// as per the JSON patch spec.
|
||||
var newArray ast.Array
|
||||
for i, v := range aValue {
|
||||
newArray := ast.NewArray()
|
||||
for i := 0; i < aValue.Len(); i++ {
|
||||
v := aValue.Elem(i)
|
||||
// recurse and add the diff of sub objects as needed
|
||||
// Note: Keys in b will be strings for the index, eg path /a/1/b => {"a": {"1": {"b": null}}}
|
||||
diffValue, err := jsonRemove(v, bObj.Get(ast.StringTerm(strconv.Itoa(i))))
|
||||
@@ -105,7 +106,7 @@ func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
|
||||
return nil, err
|
||||
}
|
||||
if diffValue != nil {
|
||||
newArray = append(newArray, diffValue)
|
||||
newArray = newArray.Append(diffValue)
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(newArray), nil
|
||||
@@ -142,9 +143,9 @@ func getJSONPaths(operand ast.Value) ([]ast.Ref, error) {
|
||||
var paths []ast.Ref
|
||||
|
||||
switch v := operand.(type) {
|
||||
case ast.Array:
|
||||
for _, f := range v {
|
||||
filter, err := parsePath(f)
|
||||
case *ast.Array:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
filter, err := parsePath(v.Elem(i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -175,15 +176,18 @@ func parsePath(path *ast.Term) (ast.Ref, error) {
|
||||
var pathSegments ast.Ref
|
||||
switch p := path.Value.(type) {
|
||||
case ast.String:
|
||||
parts := strings.Split(strings.Trim(string(p), "/"), "/")
|
||||
if p == "" {
|
||||
return ast.Ref{}, nil
|
||||
}
|
||||
parts := strings.Split(strings.TrimLeft(string(p), "/"), "/")
|
||||
for _, part := range parts {
|
||||
part = strings.ReplaceAll(strings.ReplaceAll(part, "~1", "/"), "~0", "~")
|
||||
pathSegments = append(pathSegments, ast.StringTerm(part))
|
||||
}
|
||||
case ast.Array:
|
||||
for _, term := range p {
|
||||
case *ast.Array:
|
||||
p.Foreach(func(term *ast.Term) {
|
||||
pathSegments = append(pathSegments, term)
|
||||
}
|
||||
})
|
||||
default:
|
||||
return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p))
|
||||
}
|
||||
@@ -199,6 +203,12 @@ func pathsToObject(paths []ast.Ref) ast.Object {
|
||||
node := root
|
||||
var done bool
|
||||
|
||||
// If the path is an empty JSON path, skip all further processing.
|
||||
if len(path) == 0 {
|
||||
done = true
|
||||
}
|
||||
|
||||
// Otherwise, we should have 1+ path segments to work with.
|
||||
for i := 0; i < len(path)-1 && !done; i++ {
|
||||
|
||||
k := path[i]
|
||||
@@ -229,7 +239,391 @@ func pathsToObject(paths []ast.Ref) ast.Object {
|
||||
return root
|
||||
}
|
||||
|
||||
// toIndex tries to convert path elements (that may be strings) into indices into
|
||||
// an array.
|
||||
func toIndex(arr *ast.Array, term *ast.Term) (int, error) {
|
||||
i := 0
|
||||
var ok bool
|
||||
switch v := term.Value.(type) {
|
||||
case ast.Number:
|
||||
if i, ok = v.Int(); !ok {
|
||||
return 0, fmt.Errorf("Invalid number type for indexing")
|
||||
}
|
||||
case ast.String:
|
||||
if v == "-" {
|
||||
return arr.Len(), nil
|
||||
}
|
||||
num := ast.Number(v)
|
||||
if i, ok = num.Int(); !ok {
|
||||
return 0, fmt.Errorf("Invalid string for indexing")
|
||||
}
|
||||
if v != "0" && strings.HasPrefix(string(v), "0") {
|
||||
return 0, fmt.Errorf("Leading zeros are not allowed in JSON paths")
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("Invalid type for indexing")
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// patchWorkerris a worker that modifies a direct child of a term located
|
||||
// at the given key. It returns the new term, and optionally a result that
|
||||
// is passed back to the caller.
|
||||
type patchWorker = func(parent, key *ast.Term) (updated, result *ast.Term)
|
||||
|
||||
func jsonPatchTraverse(
|
||||
target *ast.Term,
|
||||
path ast.Ref,
|
||||
worker patchWorker,
|
||||
) (*ast.Term, *ast.Term) {
|
||||
if len(path) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
key := path[0]
|
||||
if len(path) == 1 {
|
||||
return worker(target, key)
|
||||
}
|
||||
|
||||
success := false
|
||||
var updated, result *ast.Term
|
||||
switch parent := target.Value.(type) {
|
||||
case ast.Object:
|
||||
obj := ast.NewObject()
|
||||
parent.Foreach(func(k, v *ast.Term) {
|
||||
if k.Equal(key) {
|
||||
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
|
||||
obj.Insert(k, v)
|
||||
success = true
|
||||
}
|
||||
} else {
|
||||
obj.Insert(k, v)
|
||||
}
|
||||
})
|
||||
updated = ast.NewTerm(obj)
|
||||
|
||||
case *ast.Array:
|
||||
idx, err := toIndex(parent, key)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
arr := ast.NewArray()
|
||||
for i := 0; i < parent.Len(); i++ {
|
||||
v := parent.Elem(i)
|
||||
if idx == i {
|
||||
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
|
||||
arr = arr.Append(v)
|
||||
success = true
|
||||
}
|
||||
} else {
|
||||
arr = arr.Append(v)
|
||||
}
|
||||
}
|
||||
updated = ast.NewTerm(arr)
|
||||
|
||||
case ast.Set:
|
||||
set := ast.NewSet()
|
||||
parent.Foreach(func(k *ast.Term) {
|
||||
if k.Equal(key) {
|
||||
if k, result = jsonPatchTraverse(k, path[1:], worker); k != nil {
|
||||
set.Add(k)
|
||||
success = true
|
||||
}
|
||||
} else {
|
||||
set.Add(k)
|
||||
}
|
||||
})
|
||||
updated = ast.NewTerm(set)
|
||||
}
|
||||
|
||||
if success {
|
||||
return updated, result
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// jsonPatchGet goes one step further than jsonPatchTraverse and returns the
|
||||
// term at the location specified by the path. It is used in functions
|
||||
// where we want to read a value but not manipulate its parent: for example
|
||||
// jsonPatchTest and jsonPatchCopy.
|
||||
//
|
||||
// Because it uses jsonPatchTraverse, it makes shallow copies of the objects
|
||||
// along the path. We could possibly add a signaling mechanism that we didn't
|
||||
// make any changes to avoid this.
|
||||
func jsonPatchGet(target *ast.Term, path ast.Ref) *ast.Term {
|
||||
// Special case: get entire document.
|
||||
if len(path) == 0 {
|
||||
return target
|
||||
}
|
||||
|
||||
_, result := jsonPatchTraverse(target, path, func(parent, key *ast.Term) (*ast.Term, *ast.Term) {
|
||||
switch v := parent.Value.(type) {
|
||||
case ast.Object:
|
||||
return parent, v.Get(key)
|
||||
case *ast.Array:
|
||||
i, err := toIndex(v, key)
|
||||
if err == nil {
|
||||
return parent, v.Elem(i)
|
||||
}
|
||||
case ast.Set:
|
||||
if v.Contains(key) {
|
||||
return parent, key
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
func jsonPatchAdd(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
|
||||
// Special case: replacing root document.
|
||||
if len(path) == 0 {
|
||||
return value
|
||||
}
|
||||
|
||||
target, _ = jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
|
||||
switch original := parent.Value.(type) {
|
||||
case ast.Object:
|
||||
obj := ast.NewObject()
|
||||
original.Foreach(func(k, v *ast.Term) {
|
||||
obj.Insert(k, v)
|
||||
})
|
||||
obj.Insert(key, value)
|
||||
return ast.NewTerm(obj), nil
|
||||
case *ast.Array:
|
||||
idx, err := toIndex(original, key)
|
||||
if err != nil || idx < 0 || idx > original.Len() {
|
||||
return nil, nil
|
||||
}
|
||||
arr := ast.NewArray()
|
||||
for i := 0; i < idx; i++ {
|
||||
arr = arr.Append(original.Elem(i))
|
||||
}
|
||||
arr = arr.Append(value)
|
||||
for i := idx; i < original.Len(); i++ {
|
||||
arr = arr.Append(original.Elem(i))
|
||||
}
|
||||
return ast.NewTerm(arr), nil
|
||||
case ast.Set:
|
||||
if !key.Equal(value) {
|
||||
return nil, nil
|
||||
}
|
||||
set := ast.NewSet()
|
||||
original.Foreach(func(k *ast.Term) {
|
||||
set.Add(k)
|
||||
})
|
||||
set.Add(key)
|
||||
return ast.NewTerm(set), nil
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
func jsonPatchRemove(target *ast.Term, path ast.Ref) (*ast.Term, *ast.Term) {
|
||||
// Special case: replacing root document.
|
||||
if len(path) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
target, removed := jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
|
||||
var removed *ast.Term
|
||||
switch original := parent.Value.(type) {
|
||||
case ast.Object:
|
||||
obj := ast.NewObject()
|
||||
original.Foreach(func(k, v *ast.Term) {
|
||||
if k.Equal(key) {
|
||||
removed = v
|
||||
} else {
|
||||
obj.Insert(k, v)
|
||||
}
|
||||
})
|
||||
return ast.NewTerm(obj), removed
|
||||
case *ast.Array:
|
||||
idx, err := toIndex(original, key)
|
||||
if err != nil || idx < 0 || idx >= original.Len() {
|
||||
return nil, nil
|
||||
}
|
||||
arr := ast.NewArray()
|
||||
for i := 0; i < idx; i++ {
|
||||
arr = arr.Append(original.Elem(i))
|
||||
}
|
||||
removed = original.Elem(idx)
|
||||
for i := idx + 1; i < original.Len(); i++ {
|
||||
arr = arr.Append(original.Elem(i))
|
||||
}
|
||||
return ast.NewTerm(arr), removed
|
||||
case ast.Set:
|
||||
set := ast.NewSet()
|
||||
original.Foreach(func(k *ast.Term) {
|
||||
if k.Equal(key) {
|
||||
removed = k
|
||||
} else {
|
||||
set.Add(k)
|
||||
}
|
||||
})
|
||||
return ast.NewTerm(set), removed
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
if target != nil && removed != nil {
|
||||
return target, removed
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func jsonPatchReplace(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
|
||||
// Special case: replacing the whole document.
|
||||
if len(path) == 0 {
|
||||
return value
|
||||
}
|
||||
|
||||
// Replace is specified as `remove` followed by `add`.
|
||||
if target, _ = jsonPatchRemove(target, path); target == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonPatchAdd(target, path, value)
|
||||
}
|
||||
|
||||
func jsonPatchMove(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
|
||||
// Move is specified as `remove` followed by `add`.
|
||||
target, removed := jsonPatchRemove(target, from)
|
||||
if target == nil || removed == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonPatchAdd(target, path, removed)
|
||||
}
|
||||
|
||||
func jsonPatchCopy(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
|
||||
value := jsonPatchGet(target, from)
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return jsonPatchAdd(target, path, value)
|
||||
}
|
||||
|
||||
func jsonPatchTest(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
|
||||
actual := jsonPatchGet(target, path)
|
||||
if actual == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if actual.Equal(value) {
|
||||
return target
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func builtinJSONPatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// JSON patch supports arrays, objects as well as values as the target.
|
||||
target := ast.NewTerm(operands[0].Value)
|
||||
|
||||
// Expect an array of operations.
|
||||
operations, err := builtins.ArrayOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply operations one by one.
|
||||
for i := 0; i < operations.Len(); i++ {
|
||||
if object, ok := operations.Elem(i).Value.(ast.Object); ok {
|
||||
getAttribute := func(attr string) (*ast.Term, error) {
|
||||
if term := object.Get(ast.StringTerm(attr)); term != nil {
|
||||
return term, nil
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandErr(2, fmt.Sprintf("patch is missing '%s' attribute", attr))
|
||||
}
|
||||
|
||||
getPathAttribute := func(attr string) (ast.Ref, error) {
|
||||
term, err := getAttribute(attr)
|
||||
if err != nil {
|
||||
return ast.Ref{}, err
|
||||
}
|
||||
path, err := parsePath(term)
|
||||
if err != nil {
|
||||
return ast.Ref{}, err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Parse operation.
|
||||
opTerm, err := getAttribute("op")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
op, ok := opTerm.Value.(ast.String)
|
||||
if !ok {
|
||||
return builtins.NewOperandErr(2, "patch attribute 'op' must be a string")
|
||||
}
|
||||
|
||||
// Parse path.
|
||||
path, err := getPathAttribute("path")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch op {
|
||||
case "add":
|
||||
value, err := getAttribute("value")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = jsonPatchAdd(target, path, value)
|
||||
case "remove":
|
||||
target, _ = jsonPatchRemove(target, path)
|
||||
case "replace":
|
||||
value, err := getAttribute("value")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = jsonPatchReplace(target, path, value)
|
||||
case "move":
|
||||
from, err := getPathAttribute("from")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = jsonPatchMove(target, path, from)
|
||||
case "copy":
|
||||
from, err := getPathAttribute("from")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = jsonPatchCopy(target, path, from)
|
||||
case "test":
|
||||
value, err := getAttribute("value")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target = jsonPatchTest(target, path, value)
|
||||
default:
|
||||
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
|
||||
}
|
||||
} else {
|
||||
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
|
||||
}
|
||||
|
||||
// JSON patches should work atomically; and if one of them fails,
|
||||
// we should not try to continue.
|
||||
if target == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return iter(target)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.JSONFilter.Name, builtinJSONFilter)
|
||||
RegisterBuiltinFunc(ast.JSONRemove.Name, builtinJSONRemove)
|
||||
RegisterBuiltinFunc(ast.JSONPatch.Name, builtinJSONPatch)
|
||||
}
|
||||
|
||||
64
vendor/github.com/open-policy-agent/opa/topdown/net.go
generated
vendored
Normal file
64
vendor/github.com/open-policy-agent/opa/topdown/net.go
generated
vendored
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2021 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type lookupIPAddrCacheKey string
|
||||
|
||||
// resolv is the same as net.DefaultResolver -- this is for mocking it out in tests
|
||||
var resolv = &net.Resolver{}
|
||||
|
||||
func builtinLookupIPAddr(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
a, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := string(a)
|
||||
|
||||
err = verifyHost(bctx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key := lookupIPAddrCacheKey(name)
|
||||
if val, ok := bctx.Cache.Get(key); ok {
|
||||
return iter(val.(*ast.Term))
|
||||
}
|
||||
|
||||
addrs, err := resolv.LookupIPAddr(bctx.Context, name)
|
||||
if err != nil {
|
||||
// NOTE(sr): We can't do better than this right now, see https://github.com/golang/go/issues/36208
|
||||
if strings.Contains(err.Error(), "operation was canceled") || strings.Contains(err.Error(), "i/o timeout") {
|
||||
return Halt{
|
||||
Err: &Error{
|
||||
Code: CancelErr,
|
||||
Message: ast.NetLookupIPAddr.Name + ": " + err.Error(),
|
||||
Location: bctx.Location,
|
||||
},
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
ret := ast.NewSet()
|
||||
for _, a := range addrs {
|
||||
ret.Add(ast.StringTerm(a.String()))
|
||||
|
||||
}
|
||||
t := ast.NewTerm(ret)
|
||||
bctx.Cache.Put(key, t)
|
||||
return iter(t)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.NetLookupIPAddr.Name, builtinLookupIPAddr)
|
||||
}
|
||||
99
vendor/github.com/open-policy-agent/opa/topdown/numbers.go
generated
vendored
Normal file
99
vendor/github.com/open-policy-agent/opa/topdown/numbers.go
generated
vendored
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type randIntCachingKey string
|
||||
|
||||
var one = big.NewInt(1)
|
||||
|
||||
func builtinNumbersRange(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
x, err := builtins.BigIntOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
y, err := builtins.BigIntOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := ast.NewArray()
|
||||
cmp := x.Cmp(y)
|
||||
haltErr := Halt{
|
||||
Err: &Error{
|
||||
Code: CancelErr,
|
||||
Message: "numbers.range: timed out before generating all numbers in range",
|
||||
},
|
||||
}
|
||||
|
||||
if cmp <= 0 {
|
||||
for i := new(big.Int).Set(x); i.Cmp(y) <= 0; i = i.Add(i, one) {
|
||||
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
|
||||
return haltErr
|
||||
}
|
||||
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
|
||||
}
|
||||
} else {
|
||||
for i := new(big.Int).Set(x); i.Cmp(y) >= 0; i = i.Sub(i, one) {
|
||||
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
|
||||
return haltErr
|
||||
}
|
||||
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
|
||||
}
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func builtinRandIntn(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
strOp, err := builtins.StringOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
n, err := builtins.IntOperand(args[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
return iter(ast.IntNumberTerm(0))
|
||||
}
|
||||
|
||||
if n < 0 {
|
||||
n = -n
|
||||
}
|
||||
|
||||
var key = randIntCachingKey(fmt.Sprintf("%s-%d", strOp, n))
|
||||
|
||||
if val, ok := bctx.Cache.Get(key); ok {
|
||||
return iter(val.(*ast.Term))
|
||||
}
|
||||
|
||||
r, err := bctx.Rand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result := ast.IntNumberTerm(r.Intn(n))
|
||||
bctx.Cache.Put(key, result)
|
||||
|
||||
return iter(result)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.NumbersRange.Name, builtinNumbersRange)
|
||||
RegisterBuiltinFunc(ast.RandIntn.Name, builtinRandIntn)
|
||||
}
|
||||
94
vendor/github.com/open-policy-agent/opa/topdown/object.go
generated
vendored
94
vendor/github.com/open-policy-agent/opa/topdown/object.go
generated
vendored
@@ -6,8 +6,8 @@ package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/internal/ref"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/types"
|
||||
)
|
||||
|
||||
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
@@ -26,6 +26,38 @@ func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
|
||||
return iter(ast.NewTerm(r))
|
||||
}
|
||||
|
||||
func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Because we need merge-with-overwrite behavior, we can iterate
|
||||
// back-to-front, and get a mostly correct set of key assignments that
|
||||
// give us the "last assignment wins, with merges" behavior we want.
|
||||
// However, if a non-object overwrites an object value anywhere in the
|
||||
// chain of assignments for a key, we have to "freeze" that key to
|
||||
// prevent accidentally picking up nested objects that could merge with
|
||||
// it from earlier in the input array.
|
||||
// Example:
|
||||
// Input: [{"a": {"b": 2}}, {"a": 4}, {"a": {"c": 3}}]
|
||||
// Want Output: {"a": {"c": 3}}
|
||||
result := ast.NewObject()
|
||||
frozenKeys := map[*ast.Term]struct{}{}
|
||||
for i := arr.Len() - 1; i >= 0; i-- {
|
||||
o, ok := arr.Elem(i).Value.(ast.Object)
|
||||
if !ok {
|
||||
return builtins.NewOperandElementErr(1, arr, arr.Elem(i).Value, "object")
|
||||
}
|
||||
mergewithOverwriteInPlace(result, o, frozenKeys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// Expect an object and an array/set/object of keys
|
||||
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
@@ -81,11 +113,29 @@ func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter
|
||||
return err
|
||||
}
|
||||
|
||||
if ret := object.Get(operands[1]); ret != nil {
|
||||
return iter(ret)
|
||||
// if the get key is not an array, attempt to get the top level key for the operand value in the object
|
||||
path, err := builtins.ArrayOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
if ret := object.Get(operands[1]); ret != nil {
|
||||
return iter(ret)
|
||||
}
|
||||
|
||||
return iter(operands[2])
|
||||
}
|
||||
|
||||
return iter(operands[2])
|
||||
// if the path is empty, then we skip selecting nested keys and return the whole object
|
||||
if path.Len() == 0 {
|
||||
return iter(operands[0])
|
||||
}
|
||||
|
||||
// build an ast.Ref from the array and see if it matches within the object
|
||||
pathRef := ref.ArrayPath(path)
|
||||
value, err := object.Find(pathRef)
|
||||
if err != nil {
|
||||
return iter(operands[2])
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(value))
|
||||
}
|
||||
|
||||
// getObjectKeysParam returns a set of key values
|
||||
@@ -94,10 +144,11 @@ func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
|
||||
keys := ast.NewSet()
|
||||
|
||||
switch v := arrayOrSet.(type) {
|
||||
case ast.Array:
|
||||
for _, f := range v {
|
||||
case *ast.Array:
|
||||
_ = v.Iter(func(f *ast.Term) error {
|
||||
keys.Add(f)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
case ast.Set:
|
||||
_ = v.Iter(func(f *ast.Term) error {
|
||||
keys.Add(f)
|
||||
@@ -109,7 +160,7 @@ func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
|
||||
return nil
|
||||
})
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
|
||||
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, "object", "set", "array")
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
@@ -131,8 +182,35 @@ func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
|
||||
return merged
|
||||
}
|
||||
|
||||
// Modifies obj with any new keys from other, and recursively
|
||||
// merges any keys where the values are both objects.
|
||||
func mergewithOverwriteInPlace(obj, other ast.Object, frozenKeys map[*ast.Term]struct{}) {
|
||||
other.Foreach(func(k, v *ast.Term) {
|
||||
v2 := obj.Get(k)
|
||||
// The key didn't exist in other, keep the original value.
|
||||
if v2 == nil {
|
||||
obj.Insert(k, v)
|
||||
return
|
||||
}
|
||||
// The key exists in both. Merge or reject change.
|
||||
updateValueObj, ok2 := v.Value.(ast.Object)
|
||||
originalValueObj, ok1 := v2.Value.(ast.Object)
|
||||
// Both are objects? Merge recursively.
|
||||
if ok1 && ok2 {
|
||||
// Check to make sure that this key isn't frozen before merging.
|
||||
if _, ok := frozenKeys[v2]; !ok {
|
||||
mergewithOverwriteInPlace(originalValueObj, updateValueObj, frozenKeys)
|
||||
}
|
||||
} else {
|
||||
// Else, original value wins. Freeze the key.
|
||||
frozenKeys[v2] = struct{}{}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
|
||||
RegisterBuiltinFunc(ast.ObjectUnionN.Name, builtinObjectUnionN)
|
||||
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
|
||||
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
|
||||
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
|
||||
|
||||
12
vendor/github.com/open-policy-agent/opa/topdown/parse.go
generated
vendored
12
vendor/github.com/open-policy-agent/opa/topdown/parse.go
generated
vendored
@@ -7,6 +7,7 @@ package topdown
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
@@ -42,6 +43,17 @@ func builtinRegoParseModule(a, b ast.Value) (ast.Value, error) {
|
||||
return term.Value, nil
|
||||
}
|
||||
|
||||
func registerRegoMetadataBuiltinFunction(builtin *ast.Builtin) {
|
||||
f := func(BuiltinContext, []*ast.Term, func(*ast.Term) error) error {
|
||||
// The compiler should replace all usage of this function, so the only way to get here is within a query;
|
||||
// which cannot define rules.
|
||||
return fmt.Errorf("the %s function must only be called within the scope of a rule", builtin.Name)
|
||||
}
|
||||
RegisterBuiltinFunc(builtin.Name, f)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.RegoParseModule.Name, builtinRegoParseModule)
|
||||
registerRegoMetadataBuiltinFunction(ast.RegoMetadataChain)
|
||||
registerRegoMetadataBuiltinFunction(ast.RegoMetadataRule)
|
||||
}
|
||||
|
||||
172
vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go
generated
vendored
172
vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go
generated
vendored
@@ -7,147 +7,137 @@ package topdown
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
const (
|
||||
none int64 = 1
|
||||
kb = 1000
|
||||
ki = 1024
|
||||
mb = kb * 1000
|
||||
mi = ki * 1024
|
||||
gb = mb * 1000
|
||||
gi = mi * 1024
|
||||
tb = gb * 1000
|
||||
ti = gi * 1024
|
||||
none uint64 = 1 << (10 * iota)
|
||||
ki
|
||||
mi
|
||||
gi
|
||||
ti
|
||||
pi
|
||||
ei
|
||||
|
||||
kb uint64 = 1000
|
||||
mb = kb * 1000
|
||||
gb = mb * 1000
|
||||
tb = gb * 1000
|
||||
pb = tb * 1000
|
||||
eb = pb * 1000
|
||||
)
|
||||
|
||||
// The rune values for 0..9 as well as the period symbol (for parsing floats)
|
||||
var numRunes = []rune("0123456789.")
|
||||
|
||||
func parseNumBytesError(msg string) error {
|
||||
return fmt.Errorf("%s error: %s", ast.UnitsParseBytes.Name, msg)
|
||||
return fmt.Errorf("%s: %s", ast.UnitsParseBytes.Name, msg)
|
||||
}
|
||||
|
||||
func errUnitNotRecognized(unit string) error {
|
||||
func errBytesUnitNotRecognized(unit string) error {
|
||||
return parseNumBytesError(fmt.Sprintf("byte unit %s not recognized", unit))
|
||||
}
|
||||
|
||||
var (
|
||||
errNoAmount = parseNumBytesError("no byte amount provided")
|
||||
errIntConv = parseNumBytesError("could not parse byte amount to integer")
|
||||
errIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
|
||||
errBytesValueNoAmount = parseNumBytesError("no byte amount provided")
|
||||
errBytesValueNumConv = parseNumBytesError("could not parse byte amount to a number")
|
||||
errBytesValueIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
|
||||
)
|
||||
|
||||
func builtinNumBytes(a ast.Value) (ast.Value, error) {
|
||||
var m int64
|
||||
func builtinNumBytes(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
var m big.Float
|
||||
|
||||
raw, err := builtins.StringOperand(a, 1)
|
||||
raw, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
s := formatString(raw)
|
||||
|
||||
if strings.Contains(s, " ") {
|
||||
return nil, errIncludesSpaces
|
||||
return errBytesValueIncludesSpaces
|
||||
}
|
||||
|
||||
numStr, unitStr := extractNumAndUnit(s)
|
||||
|
||||
if numStr == "" {
|
||||
return nil, errNoAmount
|
||||
num, unit := extractNumAndUnit(s)
|
||||
if num == "" {
|
||||
return errBytesValueNoAmount
|
||||
}
|
||||
|
||||
switch unitStr {
|
||||
switch unit {
|
||||
case "":
|
||||
m = none
|
||||
case "kb":
|
||||
m = kb
|
||||
case "kib":
|
||||
m = ki
|
||||
case "mb":
|
||||
m = mb
|
||||
case "mib":
|
||||
m = mi
|
||||
case "gb":
|
||||
m = gb
|
||||
case "gib":
|
||||
m = gi
|
||||
case "tb":
|
||||
m = tb
|
||||
case "tib":
|
||||
m = ti
|
||||
m.SetUint64(none)
|
||||
case "kb", "k":
|
||||
m.SetUint64(kb)
|
||||
case "kib", "ki":
|
||||
m.SetUint64(ki)
|
||||
case "mb", "m":
|
||||
m.SetUint64(mb)
|
||||
case "mib", "mi":
|
||||
m.SetUint64(mi)
|
||||
case "gb", "g":
|
||||
m.SetUint64(gb)
|
||||
case "gib", "gi":
|
||||
m.SetUint64(gi)
|
||||
case "tb", "t":
|
||||
m.SetUint64(tb)
|
||||
case "tib", "ti":
|
||||
m.SetUint64(ti)
|
||||
case "pb", "p":
|
||||
m.SetUint64(pb)
|
||||
case "pib", "pi":
|
||||
m.SetUint64(pi)
|
||||
case "eb", "e":
|
||||
m.SetUint64(eb)
|
||||
case "eib", "ei":
|
||||
m.SetUint64(ei)
|
||||
default:
|
||||
return nil, errUnitNotRecognized(unitStr)
|
||||
return errBytesUnitNotRecognized(unit)
|
||||
}
|
||||
|
||||
num, err := strconv.ParseInt(numStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, errIntConv
|
||||
numFloat, ok := new(big.Float).SetString(num)
|
||||
if !ok {
|
||||
return errBytesValueNumConv
|
||||
}
|
||||
|
||||
total := num * m
|
||||
|
||||
return builtins.IntToNumber(big.NewInt(total)), nil
|
||||
var total big.Int
|
||||
numFloat.Mul(numFloat, &m).Int(&total)
|
||||
return iter(ast.NewTerm(builtins.IntToNumber(&total)))
|
||||
}
|
||||
|
||||
// Makes the string lower case and removes spaces and quotation marks
|
||||
// Makes the string lower case and removes quotation marks
|
||||
func formatString(s ast.String) string {
|
||||
str := string(s)
|
||||
lower := strings.ToLower(str)
|
||||
return strings.Replace(lower, "\"", "", -1)
|
||||
}
|
||||
|
||||
// Splits the string into a number string à la "10" or "10.2" and a unit string à la "gb" or "MiB" or "foo". Either
|
||||
// can be an empty string (error handling is provided elsewhere).
|
||||
// Splits the string into a number string à la "10" or "10.2" and a unit
|
||||
// string à la "gb" or "MiB" or "foo". Either can be an empty string
|
||||
// (error handling is provided elsewhere).
|
||||
func extractNumAndUnit(s string) (string, string) {
|
||||
isNum := func(r rune) (isNum bool) {
|
||||
for _, nr := range numRunes {
|
||||
if nr == r {
|
||||
return true
|
||||
}
|
||||
isNum := func(r rune) bool {
|
||||
return unicode.IsDigit(r) || r == '.'
|
||||
}
|
||||
|
||||
firstNonNumIdx := -1
|
||||
for idx, r := range s {
|
||||
if !isNum(r) {
|
||||
firstNonNumIdx = idx
|
||||
break
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Returns the index of the first rune that's not a number (or 0 if there are only numbers)
|
||||
getFirstNonNumIdx := func(s string) int {
|
||||
for idx, r := range s {
|
||||
if !isNum(r) {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
firstRuneIsNum := func(s string) bool {
|
||||
return isNum(rune(s[0]))
|
||||
}
|
||||
|
||||
firstNonNumIdx := getFirstNonNumIdx(s)
|
||||
|
||||
// The string contains only a number
|
||||
numOnly := firstNonNumIdx == 0 && firstRuneIsNum(s)
|
||||
|
||||
// The string contains only a unit
|
||||
unitOnly := firstNonNumIdx == 0 && !firstRuneIsNum(s)
|
||||
|
||||
if numOnly {
|
||||
if firstNonNumIdx == -1 { // only digits and '.'
|
||||
return s, ""
|
||||
} else if unitOnly {
|
||||
return "", s
|
||||
} else {
|
||||
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
|
||||
}
|
||||
if firstNonNumIdx == 0 { // only units (starts with non-digit)
|
||||
return "", s
|
||||
}
|
||||
|
||||
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.UnitsParseBytes.Name, builtinNumBytes)
|
||||
RegisterBuiltinFunc(ast.UnitsParseBytes.Name, builtinNumBytes)
|
||||
}
|
||||
|
||||
125
vendor/github.com/open-policy-agent/opa/topdown/parse_units.go
generated
vendored
Normal file
125
vendor/github.com/open-policy-agent/opa/topdown/parse_units.go
generated
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2022 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
// Binary Si unit constants are borrowed from topdown/parse_bytes
|
||||
const siMilli = 0.001
|
||||
const (
|
||||
siK uint64 = 1000
|
||||
siM = siK * 1000
|
||||
siG = siM * 1000
|
||||
siT = siG * 1000
|
||||
siP = siT * 1000
|
||||
siE = siP * 1000
|
||||
)
|
||||
|
||||
func parseUnitsError(msg string) error {
|
||||
return fmt.Errorf("%s: %s", ast.UnitsParse.Name, msg)
|
||||
}
|
||||
|
||||
func errUnitNotRecognized(unit string) error {
|
||||
return parseUnitsError(fmt.Sprintf("unit %s not recognized", unit))
|
||||
}
|
||||
|
||||
var (
|
||||
errNoAmount = parseUnitsError("no amount provided")
|
||||
errNumConv = parseUnitsError("could not parse amount to a number")
|
||||
errIncludesSpaces = parseUnitsError("spaces not allowed in resource strings")
|
||||
)
|
||||
|
||||
// Accepts both normal SI and binary SI units.
|
||||
func builtinUnits(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
var x big.Rat
|
||||
|
||||
raw, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We remove escaped quotes from strings here to retain parity with units.parse_bytes.
|
||||
s := string(raw)
|
||||
s = strings.Replace(s, "\"", "", -1)
|
||||
|
||||
if strings.Contains(s, " ") {
|
||||
return errIncludesSpaces
|
||||
}
|
||||
|
||||
num, unit := extractNumAndUnit(s)
|
||||
if num == "" {
|
||||
return errNoAmount
|
||||
}
|
||||
|
||||
// Unlike in units.parse_bytes, we only lowercase after the first letter,
|
||||
// so that we can distinguish between 'm' and 'M'.
|
||||
if len(unit) > 1 {
|
||||
lower := strings.ToLower(unit[1:])
|
||||
unit = unit[:1] + lower
|
||||
}
|
||||
|
||||
switch unit {
|
||||
case "m":
|
||||
x.SetFloat64(siMilli)
|
||||
case "":
|
||||
x.SetUint64(none)
|
||||
case "k", "K":
|
||||
x.SetUint64(siK)
|
||||
case "ki", "Ki":
|
||||
x.SetUint64(ki)
|
||||
case "M":
|
||||
x.SetUint64(siM)
|
||||
case "mi", "Mi":
|
||||
x.SetUint64(mi)
|
||||
case "g", "G":
|
||||
x.SetUint64(siG)
|
||||
case "gi", "Gi":
|
||||
x.SetUint64(gi)
|
||||
case "t", "T":
|
||||
x.SetUint64(siT)
|
||||
case "ti", "Ti":
|
||||
x.SetUint64(ti)
|
||||
case "p", "P":
|
||||
x.SetUint64(siP)
|
||||
case "pi", "Pi":
|
||||
x.SetUint64(pi)
|
||||
case "e", "E":
|
||||
x.SetUint64(siE)
|
||||
case "ei", "Ei":
|
||||
x.SetUint64(ei)
|
||||
default:
|
||||
return errUnitNotRecognized(unit)
|
||||
}
|
||||
|
||||
numRat, ok := new(big.Rat).SetString(num)
|
||||
if !ok {
|
||||
return errNumConv
|
||||
}
|
||||
|
||||
numRat.Mul(numRat, &x)
|
||||
|
||||
// Cleaner printout when we have a pure integer value.
|
||||
if numRat.IsInt() {
|
||||
return iter(ast.NumberTerm(json.Number(numRat.Num().String())))
|
||||
}
|
||||
|
||||
// When using just big.Float, we had floating-point precision
|
||||
// issues because quantities like 0.001 are not exactly representable.
|
||||
// Rationals (such as big.Rat) do not suffer this problem, but are
|
||||
// more expensive to compute with in general.
|
||||
return iter(ast.NumberTerm(json.Number(numRat.FloatString(10))))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.UnitsParse.Name, builtinUnits)
|
||||
}
|
||||
86
vendor/github.com/open-policy-agent/opa/topdown/print.go
generated
vendored
Normal file
86
vendor/github.com/open-policy-agent/opa/topdown/print.go
generated
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2021 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/topdown/print"
|
||||
)
|
||||
|
||||
func NewPrintHook(w io.Writer) print.Hook {
|
||||
return printHook{w: w}
|
||||
}
|
||||
|
||||
type printHook struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (h printHook) Print(_ print.Context, msg string) error {
|
||||
_, err := fmt.Fprintln(h.w, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
func builtinPrint(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
if bctx.PrintHook == nil {
|
||||
return iter(nil)
|
||||
}
|
||||
|
||||
arr, err := builtins.ArrayOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]string, arr.Len())
|
||||
|
||||
err = builtinPrintCrossProductOperands(bctx, buf, arr, 0, func(buf []string) error {
|
||||
pctx := print.Context{
|
||||
Context: bctx.Context,
|
||||
Location: bctx.Location,
|
||||
}
|
||||
return bctx.PrintHook.Print(pctx, strings.Join(buf, " "))
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(nil)
|
||||
}
|
||||
|
||||
func builtinPrintCrossProductOperands(bctx BuiltinContext, buf []string, operands *ast.Array, i int, f func([]string) error) error {
|
||||
|
||||
if i >= operands.Len() {
|
||||
return f(buf)
|
||||
}
|
||||
|
||||
xs, ok := operands.Elem(i).Value.(ast.Set)
|
||||
if !ok {
|
||||
return Halt{Err: internalErr(bctx.Location, fmt.Sprintf("illegal argument type: %v", ast.TypeName(operands.Elem(i).Value)))}
|
||||
}
|
||||
|
||||
if xs.Len() == 0 {
|
||||
buf[i] = "<undefined>"
|
||||
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
|
||||
}
|
||||
|
||||
return xs.Iter(func(x *ast.Term) error {
|
||||
switch v := x.Value.(type) {
|
||||
case ast.String:
|
||||
buf[i] = string(v)
|
||||
default:
|
||||
buf[i] = v.String()
|
||||
}
|
||||
return builtinPrintCrossProductOperands(bctx, buf, operands, i+1, f)
|
||||
})
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.InternalPrint.Name, builtinPrint)
|
||||
}
|
||||
21
vendor/github.com/open-policy-agent/opa/topdown/print/print.go
generated
vendored
Normal file
21
vendor/github.com/open-policy-agent/opa/topdown/print/print.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package print
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// Context provides the Hook implementation context about the print() call.
|
||||
type Context struct {
|
||||
Context context.Context // request context passed when query executed
|
||||
Location *ast.Location // location of print call
|
||||
}
|
||||
|
||||
// Hook defines the interface that callers can implement to receive print
|
||||
// statement outputs. If the hook returns an error, it will be surfaced if
|
||||
// strict builtin error checking is enabled (otherwise, it will not halt
|
||||
// execution.)
|
||||
type Hook interface {
|
||||
Print(Context, string) error
|
||||
}
|
||||
350
vendor/github.com/open-policy-agent/opa/topdown/query.go
generated
vendored
350
vendor/github.com/open-policy-agent/opa/topdown/query.go
generated
vendored
@@ -2,13 +2,20 @@ package topdown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/metrics"
|
||||
"github.com/open-policy-agent/opa/resolver"
|
||||
"github.com/open-policy-agent/opa/storage"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/topdown/cache"
|
||||
"github.com/open-policy-agent/opa/topdown/copypropagation"
|
||||
"github.com/open-policy-agent/opa/topdown/print"
|
||||
"github.com/open-policy-agent/opa/tracing"
|
||||
)
|
||||
|
||||
// QueryResultSet represents a collection of results returned by a query.
|
||||
@@ -20,23 +27,35 @@ type QueryResult map[ast.Var]*ast.Term
|
||||
|
||||
// Query provides a configurable interface for performing query evaluation.
|
||||
type Query struct {
|
||||
cancel Cancel
|
||||
query ast.Body
|
||||
queryCompiler ast.QueryCompiler
|
||||
compiler *ast.Compiler
|
||||
store storage.Store
|
||||
txn storage.Transaction
|
||||
input *ast.Term
|
||||
tracers []Tracer
|
||||
unknowns []*ast.Term
|
||||
partialNamespace string
|
||||
metrics metrics.Metrics
|
||||
instr *Instrumentation
|
||||
disableInlining []ast.Ref
|
||||
genvarprefix string
|
||||
runtime *ast.Term
|
||||
builtins map[string]*Builtin
|
||||
indexing bool
|
||||
seed io.Reader
|
||||
time time.Time
|
||||
cancel Cancel
|
||||
query ast.Body
|
||||
queryCompiler ast.QueryCompiler
|
||||
compiler *ast.Compiler
|
||||
store storage.Store
|
||||
txn storage.Transaction
|
||||
input *ast.Term
|
||||
external *resolverTrie
|
||||
tracers []QueryTracer
|
||||
plugTraceVars bool
|
||||
unknowns []*ast.Term
|
||||
partialNamespace string
|
||||
skipSaveNamespace bool
|
||||
metrics metrics.Metrics
|
||||
instr *Instrumentation
|
||||
disableInlining []ast.Ref
|
||||
shallowInlining bool
|
||||
genvarprefix string
|
||||
runtime *ast.Term
|
||||
builtins map[string]*Builtin
|
||||
indexing bool
|
||||
earlyExit bool
|
||||
interQueryBuiltinCache cache.InterQueryCache
|
||||
ndBuiltinCache builtins.NDBCache
|
||||
strictBuiltinErrors bool
|
||||
printHook print.Hook
|
||||
tracingOpts tracing.Options
|
||||
}
|
||||
|
||||
// Builtin represents a built-in function that queries can call.
|
||||
@@ -51,6 +70,8 @@ func NewQuery(query ast.Body) *Query {
|
||||
query: query,
|
||||
genvarprefix: ast.WildcardPrefix,
|
||||
indexing: true,
|
||||
earlyExit: true,
|
||||
external: newResolverTrie(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,8 +115,31 @@ func (q *Query) WithInput(input *ast.Term) *Query {
|
||||
}
|
||||
|
||||
// WithTracer adds a query tracer to use during evaluation. This is optional.
|
||||
// Deprecated: Use WithQueryTracer instead.
|
||||
func (q *Query) WithTracer(tracer Tracer) *Query {
|
||||
qt, ok := tracer.(QueryTracer)
|
||||
if !ok {
|
||||
qt = WrapLegacyTracer(tracer)
|
||||
}
|
||||
return q.WithQueryTracer(qt)
|
||||
}
|
||||
|
||||
// WithQueryTracer adds a query tracer to use during evaluation. This is optional.
|
||||
// Disabled QueryTracers will be ignored.
|
||||
func (q *Query) WithQueryTracer(tracer QueryTracer) *Query {
|
||||
if !tracer.Enabled() {
|
||||
return q
|
||||
}
|
||||
|
||||
q.tracers = append(q.tracers, tracer)
|
||||
|
||||
// If *any* of the tracers require local variable metadata we need to
|
||||
// enabled plugging local trace variables.
|
||||
conf := tracer.Config()
|
||||
if conf.PlugLocalVars {
|
||||
q.plugTraceVars = true
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
@@ -128,6 +172,13 @@ func (q *Query) WithPartialNamespace(ns string) *Query {
|
||||
return q
|
||||
}
|
||||
|
||||
// WithSkipPartialNamespace disables namespacing of saved support rules that are generated
|
||||
// from the original policy (rules which are completely synthetic are still namespaced.)
|
||||
func (q *Query) WithSkipPartialNamespace(yes bool) *Query {
|
||||
q.skipSaveNamespace = yes
|
||||
return q
|
||||
}
|
||||
|
||||
// WithDisableInlining adds a set of paths to the query that should be excluded from
|
||||
// inlining. Inlining during partial evaluation can be expensive in some cases
|
||||
// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive
|
||||
@@ -137,6 +188,14 @@ func (q *Query) WithDisableInlining(paths []ast.Ref) *Query {
|
||||
return q
|
||||
}
|
||||
|
||||
// WithShallowInlining disables aggressive inlining performed during partial evaluation.
|
||||
// When shallow inlining is enabled rules that depend (transitively) on unknowns are not inlined.
|
||||
// Only rules/values that are completely known will be inlined.
|
||||
func (q *Query) WithShallowInlining(yes bool) *Query {
|
||||
q.shallowInlining = yes
|
||||
return q
|
||||
}
|
||||
|
||||
// WithRuntime sets the runtime data to execute the query with. The runtime data
|
||||
// can be returned by the `opa.runtime` built-in function.
|
||||
func (q *Query) WithRuntime(runtime *ast.Term) *Query {
|
||||
@@ -158,6 +217,61 @@ func (q *Query) WithIndexing(enabled bool) *Query {
|
||||
return q
|
||||
}
|
||||
|
||||
// WithEarlyExit will enable or disable using 'early exit' for the evaluation
|
||||
// of the query. The default is enabled.
|
||||
func (q *Query) WithEarlyExit(enabled bool) *Query {
|
||||
q.earlyExit = enabled
|
||||
return q
|
||||
}
|
||||
|
||||
// WithSeed sets a reader that will seed randomization required by built-in functions.
|
||||
// If a seed is not provided crypto/rand.Reader is used.
|
||||
func (q *Query) WithSeed(r io.Reader) *Query {
|
||||
q.seed = r
|
||||
return q
|
||||
}
|
||||
|
||||
// WithTime sets the time that will be returned by the time.now_ns() built-in function.
|
||||
func (q *Query) WithTime(x time.Time) *Query {
|
||||
q.time = x
|
||||
return q
|
||||
}
|
||||
|
||||
// WithInterQueryBuiltinCache sets the inter-query cache that built-in functions can utilize.
|
||||
func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query {
|
||||
q.interQueryBuiltinCache = c
|
||||
return q
|
||||
}
|
||||
|
||||
// WithNDBuiltinCache sets the non-deterministic builtin cache.
|
||||
func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query {
|
||||
q.ndBuiltinCache = c
|
||||
return q
|
||||
}
|
||||
|
||||
// WithStrictBuiltinErrors tells the evaluator to treat all built-in function errors as fatal errors.
|
||||
func (q *Query) WithStrictBuiltinErrors(yes bool) *Query {
|
||||
q.strictBuiltinErrors = yes
|
||||
return q
|
||||
}
|
||||
|
||||
// WithResolver configures an external resolver to use for the given ref.
|
||||
func (q *Query) WithResolver(ref ast.Ref, r resolver.Resolver) *Query {
|
||||
q.external.Put(ref, r)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) WithPrintHook(h print.Hook) *Query {
|
||||
q.printHook = h
|
||||
return q
|
||||
}
|
||||
|
||||
// WithDistributedTracingOpts sets the options to be used by distributed tracing.
|
||||
func (q *Query) WithDistributedTracingOpts(tr tracing.Options) *Query {
|
||||
q.tracingOpts = tr
|
||||
return q
|
||||
}
|
||||
|
||||
// PartialRun executes partial evaluation on the query with respect to unknown
|
||||
// values. Partial evaluation attempts to evaluate as much of the query as
|
||||
// possible without requiring values for the unknowns set on the query. The
|
||||
@@ -169,45 +283,79 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
|
||||
if q.partialNamespace == "" {
|
||||
q.partialNamespace = "partial" // lazily initialize partial namespace
|
||||
}
|
||||
if q.seed == nil {
|
||||
q.seed = rand.Reader
|
||||
}
|
||||
if !q.time.IsZero() {
|
||||
q.time = time.Now()
|
||||
}
|
||||
if q.metrics == nil {
|
||||
q.metrics = metrics.New()
|
||||
}
|
||||
f := &queryIDFactory{}
|
||||
b := newBindings(0, q.instr)
|
||||
e := &eval{
|
||||
ctx: ctx,
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: b,
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
tracers: q.tracers,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
virtualCache: newVirtualCache(),
|
||||
saveSet: newSaveSet(q.unknowns, b, q.instr),
|
||||
saveStack: newSaveStack(),
|
||||
saveSupport: newSaveSupport(),
|
||||
saveNamespace: ast.StringTerm(q.partialNamespace),
|
||||
ctx: ctx,
|
||||
metrics: q.metrics,
|
||||
seed: q.seed,
|
||||
time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: b,
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
external: q.external,
|
||||
tracers: q.tracers,
|
||||
traceEnabled: len(q.tracers) > 0,
|
||||
plugTraceVars: q.plugTraceVars,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
functionMocks: newFunctionMocksStack(),
|
||||
interQueryBuiltinCache: q.interQueryBuiltinCache,
|
||||
ndBuiltinCache: q.ndBuiltinCache,
|
||||
virtualCache: newVirtualCache(),
|
||||
comprehensionCache: newComprehensionCache(),
|
||||
saveSet: newSaveSet(q.unknowns, b, q.instr),
|
||||
saveStack: newSaveStack(),
|
||||
saveSupport: newSaveSupport(),
|
||||
saveNamespace: ast.StringTerm(q.partialNamespace),
|
||||
skipSaveNamespace: q.skipSaveNamespace,
|
||||
inliningControl: &inliningControl{
|
||||
shallow: q.shallowInlining,
|
||||
},
|
||||
genvarprefix: q.genvarprefix,
|
||||
runtime: q.runtime,
|
||||
indexing: q.indexing,
|
||||
earlyExit: q.earlyExit,
|
||||
builtinErrors: &builtinErrors{},
|
||||
printHook: q.printHook,
|
||||
}
|
||||
|
||||
if len(q.disableInlining) > 0 {
|
||||
e.disableInlining = [][]ast.Ref{q.disableInlining}
|
||||
e.inliningControl.PushDisable(q.disableInlining, false)
|
||||
}
|
||||
|
||||
e.caller = e
|
||||
q.startTimer(metrics.RegoPartialEval)
|
||||
defer q.stopTimer(metrics.RegoPartialEval)
|
||||
q.metrics.Timer(metrics.RegoPartialEval).Start()
|
||||
defer q.metrics.Timer(metrics.RegoPartialEval).Stop()
|
||||
|
||||
livevars := ast.NewVarSet()
|
||||
for _, t := range q.unknowns {
|
||||
switch v := t.Value.(type) {
|
||||
case ast.Var:
|
||||
livevars.Add(v)
|
||||
case ast.Ref:
|
||||
livevars.Add(v[0].Value.(ast.Var))
|
||||
}
|
||||
}
|
||||
|
||||
ast.WalkVars(q.query, func(x ast.Var) bool {
|
||||
if !x.IsGenerated() {
|
||||
@@ -216,7 +364,7 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
|
||||
return false
|
||||
})
|
||||
|
||||
p := copypropagation.New(livevars)
|
||||
p := copypropagation.New(livevars).WithCompiler(q.compiler)
|
||||
|
||||
err = e.Run(func(e *eval) error {
|
||||
|
||||
@@ -230,10 +378,10 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
|
||||
// Include bindings as exprs so that when caller evals the result, they
|
||||
// can obtain values for the vars in their query.
|
||||
bindingExprs := []*ast.Expr{}
|
||||
e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
|
||||
_ = e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
|
||||
bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b))
|
||||
return nil
|
||||
})
|
||||
}) // cannot return error
|
||||
|
||||
// Sort binding expressions so that results are deterministic.
|
||||
sort.Slice(bindingExprs, func(i, j int) bool {
|
||||
@@ -244,12 +392,32 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []
|
||||
body.Append(bindingExprs[i])
|
||||
}
|
||||
|
||||
partials = append(partials, applyCopyPropagation(p, e.instr, body))
|
||||
// Skip this rule body if it fails to type-check.
|
||||
// Type-checking failure means the rule body will never succeed.
|
||||
if !e.compiler.PassesTypeCheck(body) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !q.shallowInlining {
|
||||
body = applyCopyPropagation(p, e.instr, body)
|
||||
}
|
||||
|
||||
partials = append(partials, body)
|
||||
return nil
|
||||
})
|
||||
|
||||
support = e.saveSupport.List()
|
||||
|
||||
if q.strictBuiltinErrors && len(e.builtinErrors.errs) > 0 {
|
||||
err = e.builtinErrors.errs[0]
|
||||
}
|
||||
|
||||
for i := range support {
|
||||
sort.Slice(support[i].Rules, func(j, k int) bool {
|
||||
return support[i].Rules[j].Compare(support[i].Rules[k]) < 0
|
||||
})
|
||||
}
|
||||
|
||||
return partials, support, err
|
||||
}
|
||||
|
||||
@@ -266,52 +434,68 @@ func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
|
||||
// Iter executes the query and invokes the iter function with query results
|
||||
// produced by evaluating the query.
|
||||
func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
|
||||
if q.seed == nil {
|
||||
q.seed = rand.Reader
|
||||
}
|
||||
if q.time.IsZero() {
|
||||
q.time = time.Now()
|
||||
}
|
||||
if q.metrics == nil {
|
||||
q.metrics = metrics.New()
|
||||
}
|
||||
f := &queryIDFactory{}
|
||||
e := &eval{
|
||||
ctx: ctx,
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: newBindings(0, q.instr),
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
tracers: q.tracers,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
virtualCache: newVirtualCache(),
|
||||
genvarprefix: q.genvarprefix,
|
||||
runtime: q.runtime,
|
||||
indexing: q.indexing,
|
||||
ctx: ctx,
|
||||
metrics: q.metrics,
|
||||
seed: q.seed,
|
||||
time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())),
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: newBindings(0, q.instr),
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
external: q.external,
|
||||
tracers: q.tracers,
|
||||
traceEnabled: len(q.tracers) > 0,
|
||||
plugTraceVars: q.plugTraceVars,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
functionMocks: newFunctionMocksStack(),
|
||||
interQueryBuiltinCache: q.interQueryBuiltinCache,
|
||||
ndBuiltinCache: q.ndBuiltinCache,
|
||||
virtualCache: newVirtualCache(),
|
||||
comprehensionCache: newComprehensionCache(),
|
||||
genvarprefix: q.genvarprefix,
|
||||
runtime: q.runtime,
|
||||
indexing: q.indexing,
|
||||
earlyExit: q.earlyExit,
|
||||
builtinErrors: &builtinErrors{},
|
||||
printHook: q.printHook,
|
||||
tracingOpts: q.tracingOpts,
|
||||
}
|
||||
e.caller = e
|
||||
q.startTimer(metrics.RegoQueryEval)
|
||||
q.metrics.Timer(metrics.RegoQueryEval).Start()
|
||||
err := e.Run(func(e *eval) error {
|
||||
qr := QueryResult{}
|
||||
e.bindings.Iter(nil, func(k, v *ast.Term) error {
|
||||
_ = e.bindings.Iter(nil, func(k, v *ast.Term) error {
|
||||
qr[k.Value.(ast.Var)] = v
|
||||
return nil
|
||||
})
|
||||
}) // cannot return error
|
||||
return iter(qr)
|
||||
})
|
||||
q.stopTimer(metrics.RegoQueryEval)
|
||||
|
||||
if q.strictBuiltinErrors && err == nil && len(e.builtinErrors.errs) > 0 {
|
||||
err = e.builtinErrors.errs[0]
|
||||
}
|
||||
|
||||
q.metrics.Timer(metrics.RegoQueryEval).Stop()
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *Query) startTimer(name string) {
|
||||
if q.metrics != nil {
|
||||
q.metrics.Timer(name).Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) stopTimer(name string) {
|
||||
if q.metrics != nil {
|
||||
q.metrics.Timer(name).Stop()
|
||||
}
|
||||
}
|
||||
|
||||
142
vendor/github.com/open-policy-agent/opa/topdown/reachable.go
generated
vendored
Normal file
142
vendor/github.com/open-policy-agent/opa/topdown/reachable.go
generated
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
// Helper: sets of vertices can be represented as Arrays or Sets.
|
||||
func foreachVertex(collection *ast.Term, f func(*ast.Term)) {
|
||||
switch v := collection.Value.(type) {
|
||||
case ast.Set:
|
||||
v.Foreach(f)
|
||||
case *ast.Array:
|
||||
v.Foreach(f)
|
||||
}
|
||||
}
|
||||
|
||||
// numberOfEdges returns the number of elements of an array or a set (of edges)
|
||||
func numberOfEdges(collection *ast.Term) int {
|
||||
switch v := collection.Value.(type) {
|
||||
case ast.Set:
|
||||
return v.Len()
|
||||
case *ast.Array:
|
||||
return v.Len()
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func builtinReachable(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// Error on wrong types for args.
|
||||
graph, err := builtins.ObjectOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var queue []*ast.Term
|
||||
switch initial := args[1].Value.(type) {
|
||||
case *ast.Array, ast.Set:
|
||||
foreachVertex(ast.NewTerm(initial), func(t *ast.Term) {
|
||||
queue = append(queue, t)
|
||||
})
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(2, initial, "{array, set}")
|
||||
}
|
||||
|
||||
// This is the set of nodes we have reached.
|
||||
reached := ast.NewSet()
|
||||
|
||||
// Keep going as long as we have nodes in the queue.
|
||||
for len(queue) > 0 {
|
||||
// Get the edges for this node. If the node was not in the graph,
|
||||
// `edges` will be `nil` and we can ignore it.
|
||||
node := queue[0]
|
||||
if edges := graph.Get(node); edges != nil {
|
||||
// Add all the newly discovered neighbors.
|
||||
foreachVertex(edges, func(neighbor *ast.Term) {
|
||||
if !reached.Contains(neighbor) {
|
||||
queue = append(queue, neighbor)
|
||||
}
|
||||
})
|
||||
// Mark the node as reached.
|
||||
reached.Add(node)
|
||||
}
|
||||
queue = queue[1:]
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(reached))
|
||||
}
|
||||
|
||||
// pathBuilder is called recursively to build an array of paths that are reachable from the root
|
||||
func pathBuilder(graph ast.Object, root *ast.Term, path []*ast.Term, paths []*ast.Term, reached ast.Set) []*ast.Term {
|
||||
if edges := graph.Get(root); edges != nil {
|
||||
path = append(path, root)
|
||||
|
||||
if numberOfEdges(edges) >= 1 {
|
||||
foreachVertex(edges, func(neighbor *ast.Term) {
|
||||
if reached.Contains(neighbor) {
|
||||
// If we've already reached this node, return current path (avoid infinite recursion)
|
||||
paths = append(paths, path...)
|
||||
} else {
|
||||
reached.Add(root)
|
||||
paths = pathBuilder(graph, neighbor, path, paths, reached)
|
||||
}
|
||||
})
|
||||
} else {
|
||||
paths = append(paths, path...)
|
||||
}
|
||||
} else {
|
||||
// Node is nonexistent (not in graph). Commit the current path (without adding this root)
|
||||
paths = append(paths, path...)
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
||||
|
||||
func builtinReachablePaths(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// Error on wrong types for args.
|
||||
graph, err := builtins.ObjectOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This is a queue that holds all nodes we still need to visit. It is
|
||||
// initialised to the initial set of nodes we start out with.
|
||||
var queue []*ast.Term
|
||||
switch initial := args[1].Value.(type) {
|
||||
case *ast.Array, ast.Set:
|
||||
foreachVertex(ast.NewTerm(initial), func(t *ast.Term) {
|
||||
queue = append(queue, t)
|
||||
})
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(2, initial, "{array, set}")
|
||||
}
|
||||
|
||||
results := ast.NewSet()
|
||||
|
||||
for _, node := range queue {
|
||||
// Find reachable paths from edges in root node in queue and append arrays to the results set
|
||||
if edges := graph.Get(node); edges != nil {
|
||||
if numberOfEdges(edges) >= 1 {
|
||||
foreachVertex(edges, func(neighbor *ast.Term) {
|
||||
paths := pathBuilder(graph, neighbor, []*ast.Term{node}, []*ast.Term{}, ast.NewSet(node))
|
||||
results.Add(ast.ArrayTerm(paths...))
|
||||
})
|
||||
} else {
|
||||
results.Add(ast.ArrayTerm(node))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(results))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.ReachableBuiltin.Name, builtinReachable)
|
||||
RegisterBuiltinFunc(ast.ReachablePathsBuiltin.Name, builtinReachablePaths)
|
||||
}
|
||||
70
vendor/github.com/open-policy-agent/opa/topdown/regex.go
generated
vendored
70
vendor/github.com/open-policy-agent/opa/topdown/regex.go
generated
vendored
@@ -9,7 +9,7 @@ import (
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/yashtewari/glob-intersection"
|
||||
gintersect "github.com/yashtewari/glob-intersection"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
@@ -18,6 +18,21 @@ import (
|
||||
var regexpCacheLock = sync.Mutex{}
|
||||
var regexpCache map[string]*regexp.Regexp
|
||||
|
||||
func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
s, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
_, err = regexp.Compile(string(s))
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
func builtinRegexMatch(a, b ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -79,11 +94,11 @@ func builtinRegexSplit(a, b ast.Value) (ast.Value, error) {
|
||||
}
|
||||
|
||||
elems := re.Split(string(s2), -1)
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr := make([]*ast.Term, len(elems))
|
||||
for i := range elems {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
return ast.NewArray(arr...), nil
|
||||
}
|
||||
|
||||
func getRegexp(pat string) (*regexp.Regexp, error) {
|
||||
@@ -151,11 +166,11 @@ func builtinRegexFind(a, b, c ast.Value) (ast.Value, error) {
|
||||
}
|
||||
|
||||
elems := re.FindAllString(string(s2), n)
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr := make([]*ast.Term, len(elems))
|
||||
for i := range elems {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
return ast.NewArray(arr...), nil
|
||||
}
|
||||
|
||||
func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
@@ -178,24 +193,53 @@ func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
}
|
||||
matches := re.FindAllStringSubmatch(string(s2), n)
|
||||
|
||||
outer := make(ast.Array, len(matches))
|
||||
for i := range outer {
|
||||
inner := make(ast.Array, len(matches[i]))
|
||||
for j := range inner {
|
||||
outer := make([]*ast.Term, len(matches))
|
||||
for i := range matches {
|
||||
inner := make([]*ast.Term, len(matches[i]))
|
||||
for j := range matches[i] {
|
||||
inner[j] = ast.StringTerm(matches[i][j])
|
||||
}
|
||||
outer[i] = ast.ArrayTerm(inner...)
|
||||
outer[i] = ast.NewTerm(ast.NewArray(inner...))
|
||||
}
|
||||
|
||||
return outer, nil
|
||||
return ast.NewArray(outer...), nil
|
||||
}
|
||||
|
||||
func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
base, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pattern, err := builtins.StringOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := builtins.StringOperand(operands[2].Value, 3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
re, err := getRegexp(string(pattern))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := re.ReplaceAllString(string(base), string(value))
|
||||
|
||||
return iter(ast.StringTerm(res))
|
||||
}
|
||||
|
||||
func init() {
|
||||
regexpCache = map[string]*regexp.Regexp{}
|
||||
RegisterBuiltinFunc(ast.RegexIsValid.Name, builtinRegexIsValid)
|
||||
RegisterFunctionalBuiltin2(ast.RegexMatch.Name, builtinRegexMatch)
|
||||
RegisterFunctionalBuiltin2(ast.RegexMatchDeprecated.Name, builtinRegexMatch)
|
||||
RegisterFunctionalBuiltin2(ast.RegexSplit.Name, builtinRegexSplit)
|
||||
RegisterFunctionalBuiltin2(ast.GlobsMatch.Name, builtinGlobsMatch)
|
||||
RegisterFunctionalBuiltin4(ast.RegexTemplateMatch.Name, builtinRegexMatchTemplate)
|
||||
RegisterFunctionalBuiltin3(ast.RegexFind.Name, builtinRegexFind)
|
||||
RegisterFunctionalBuiltin3(ast.RegexFindAllStringSubmatch.Name, builtinRegexFindAllStringSubmatch)
|
||||
RegisterBuiltinFunc(ast.RegexReplace.Name, builtinRegexReplace)
|
||||
}
|
||||
|
||||
107
vendor/github.com/open-policy-agent/opa/topdown/resolver.go
generated
vendored
Normal file
107
vendor/github.com/open-policy-agent/opa/topdown/resolver.go
generated
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/metrics"
|
||||
"github.com/open-policy-agent/opa/resolver"
|
||||
)
|
||||
|
||||
type resolverTrie struct {
|
||||
r resolver.Resolver
|
||||
children map[ast.Value]*resolverTrie
|
||||
}
|
||||
|
||||
func newResolverTrie() *resolverTrie {
|
||||
return &resolverTrie{children: map[ast.Value]*resolverTrie{}}
|
||||
}
|
||||
|
||||
func (t *resolverTrie) Put(ref ast.Ref, r resolver.Resolver) {
|
||||
node := t
|
||||
for _, t := range ref {
|
||||
child, ok := node.children[t.Value]
|
||||
if !ok {
|
||||
child = &resolverTrie{children: map[ast.Value]*resolverTrie{}}
|
||||
node.children[t.Value] = child
|
||||
}
|
||||
node = child
|
||||
}
|
||||
node.r = r
|
||||
}
|
||||
|
||||
func (t *resolverTrie) Resolve(e *eval, ref ast.Ref) (ast.Value, error) {
|
||||
e.metrics.Timer(metrics.RegoExternalResolve).Start()
|
||||
defer e.metrics.Timer(metrics.RegoExternalResolve).Stop()
|
||||
node := t
|
||||
for i, t := range ref {
|
||||
child, ok := node.children[t.Value]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
node = child
|
||||
if node.r != nil {
|
||||
in := resolver.Input{
|
||||
Ref: ref[:i+1],
|
||||
Input: e.input,
|
||||
Metrics: e.metrics,
|
||||
}
|
||||
e.traceWasm(e.query[e.index], &in.Ref)
|
||||
if e.data != nil {
|
||||
return nil, errInScopeWithStmt
|
||||
}
|
||||
result, err := node.r.Eval(e.ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
val, err := result.Value.Find(ref[i+1:])
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
}
|
||||
return node.mktree(e, resolver.Input{
|
||||
Ref: ref,
|
||||
Input: e.input,
|
||||
Metrics: e.metrics,
|
||||
})
|
||||
}
|
||||
|
||||
func (t *resolverTrie) mktree(e *eval, in resolver.Input) (ast.Value, error) {
|
||||
if t.r != nil {
|
||||
e.traceWasm(e.query[e.index], &in.Ref)
|
||||
if e.data != nil {
|
||||
return nil, errInScopeWithStmt
|
||||
}
|
||||
result, err := t.r.Eval(e.ctx, in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return result.Value, nil
|
||||
}
|
||||
obj := ast.NewObject()
|
||||
for k, child := range t.children {
|
||||
v, err := child.mktree(e, resolver.Input{Ref: append(in.Ref, ast.NewTerm(k)), Input: in.Input, Metrics: in.Metrics})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v != nil {
|
||||
obj.Insert(ast.NewTerm(k), ast.NewTerm(v))
|
||||
}
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
var errInScopeWithStmt = &Error{
|
||||
Code: InternalErr,
|
||||
Message: "wasm cannot be executed when 'with' statements are in-scope",
|
||||
}
|
||||
109
vendor/github.com/open-policy-agent/opa/topdown/runtime.go
generated
vendored
109
vendor/github.com/open-policy-agent/opa/topdown/runtime.go
generated
vendored
@@ -4,7 +4,11 @@
|
||||
|
||||
package topdown
|
||||
|
||||
import "github.com/open-policy-agent/opa/ast"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
@@ -12,9 +16,112 @@ func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term)
|
||||
return iter(ast.ObjectTerm())
|
||||
}
|
||||
|
||||
if bctx.Runtime.Get(ast.StringTerm("config")) != nil {
|
||||
iface, err := ast.ValueToInterface(bctx.Runtime.Value, illegalResolver{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if object, ok := iface.(map[string]interface{}); ok {
|
||||
if cfgRaw, ok := object["config"]; ok {
|
||||
if config, ok := cfgRaw.(map[string]interface{}); ok {
|
||||
configPurged, err := activeConfig(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
object["config"] = configPurged
|
||||
value, err := ast.InterfaceToValue(object)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return iter(ast.NewTerm(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return iter(bctx.Runtime)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.OPARuntime.Name, builtinOPARuntime)
|
||||
}
|
||||
|
||||
func activeConfig(config map[string]interface{}) (interface{}, error) {
|
||||
|
||||
if config["services"] != nil {
|
||||
err := removeServiceCredentials(config["services"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if config["keys"] != nil {
|
||||
err := removeCryptoKeys(config["keys"])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func removeServiceCredentials(x interface{}) error {
|
||||
|
||||
switch x := x.(type) {
|
||||
case []interface{}:
|
||||
for _, v := range x {
|
||||
err := removeKey(v, "credentials")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
case map[string]interface{}:
|
||||
for _, v := range x {
|
||||
err := removeKey(v, "credentials")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("illegal service config type: %T", x)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCryptoKeys(x interface{}) error {
|
||||
|
||||
switch x := x.(type) {
|
||||
case map[string]interface{}:
|
||||
for _, v := range x {
|
||||
err := removeKey(v, "key", "private_key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("illegal keys config type: %T", x)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeKey(x interface{}, keys ...string) error {
|
||||
val, ok := x.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("type assertion error")
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
delete(val, key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type illegalResolver struct{}
|
||||
|
||||
func (illegalResolver) Resolve(ref ast.Ref) (interface{}, error) {
|
||||
return nil, fmt.Errorf("illegal value: %v", ref)
|
||||
}
|
||||
|
||||
59
vendor/github.com/open-policy-agent/opa/topdown/save.go
generated
vendored
59
vendor/github.com/open-policy-agent/opa/topdown/save.go
generated
vendored
@@ -57,7 +57,7 @@ func (ss *saveSet) contains(t *ast.Term, b *bindings) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ContainsRecursive retruns true if the term t is or contains a term that is
|
||||
// ContainsRecursive returns true if the term t is or contains a term that is
|
||||
// contained in the save set. This function will close over the binding list
|
||||
// when it encounters vars.
|
||||
func (ss *saveSet) ContainsRecursive(t *ast.Term, b *bindings) bool {
|
||||
@@ -279,7 +279,7 @@ func newSaveSupport() *saveSupport {
|
||||
}
|
||||
|
||||
func (s *saveSupport) List() []*ast.Module {
|
||||
result := []*ast.Module{}
|
||||
result := make([]*ast.Module, 0, len(s.modules))
|
||||
for _, module := range s.modules {
|
||||
result = append(result, module)
|
||||
}
|
||||
@@ -321,7 +321,7 @@ func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
|
||||
// being saved. This check allows the evaluator to evaluate statements
|
||||
// completely during partial evaluation as long as they do not depend on any
|
||||
// kind of unknown value or statements that would generate saves.
|
||||
func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
|
||||
func saveRequired(c *ast.Compiler, ic *inliningControl, icIgnoreInternal bool, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
|
||||
|
||||
var found bool
|
||||
|
||||
@@ -344,9 +344,11 @@ func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec
|
||||
case ast.Ref:
|
||||
if ss.Contains(node, b) {
|
||||
found = true
|
||||
} else if ic.Disabled(v.ConstantPrefix(), icIgnoreInternal) {
|
||||
found = true
|
||||
} else {
|
||||
for _, rule := range c.GetRulesDynamic(v) {
|
||||
if saveRequired(c, ss, b, rule, true) {
|
||||
for _, rule := range c.GetRulesDynamicWithOpts(v, ast.RulesOptions{IncludeHiddenModules: false}) {
|
||||
if saveRequired(c, ic, icIgnoreInternal, ss, b, rule, true) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -373,10 +375,57 @@ func ignoreExprDuringPartial(expr *ast.Expr) bool {
|
||||
}
|
||||
|
||||
func ignoreDuringPartial(bi *ast.Builtin) bool {
|
||||
// Note(philipc): We keep this legacy check around to avoid breaking
|
||||
// existing library users.
|
||||
//nolint:staticcheck // We specifically ignore our own linter warning here.
|
||||
for _, ignore := range ast.IgnoreDuringPartialEval {
|
||||
if bi == ignore {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Otherwise, ensure all non-deterministic builtins are thrown out.
|
||||
return bi.Nondeterministic
|
||||
}
|
||||
|
||||
type inliningControl struct {
|
||||
shallow bool
|
||||
disable []disableInliningFrame
|
||||
}
|
||||
|
||||
type disableInliningFrame struct {
|
||||
internal bool
|
||||
refs []ast.Ref
|
||||
}
|
||||
|
||||
func (i *inliningControl) PushDisable(refs []ast.Ref, internal bool) {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
i.disable = append(i.disable, disableInliningFrame{
|
||||
internal: internal,
|
||||
refs: refs,
|
||||
})
|
||||
}
|
||||
|
||||
func (i *inliningControl) PopDisable() {
|
||||
if i == nil {
|
||||
return
|
||||
}
|
||||
i.disable = i.disable[:len(i.disable)-1]
|
||||
}
|
||||
|
||||
func (i *inliningControl) Disabled(ref ast.Ref, ignoreInternal bool) bool {
|
||||
if i == nil {
|
||||
return false
|
||||
}
|
||||
for _, frame := range i.disable {
|
||||
if !frame.internal || !ignoreInternal {
|
||||
for _, other := range frame.refs {
|
||||
if other.HasPrefix(ref) || ref.HasPrefix(other) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
59
vendor/github.com/open-policy-agent/opa/topdown/semver.go
generated
vendored
Normal file
59
vendor/github.com/open-policy-agent/opa/topdown/semver.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/internal/semver"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinSemVerCompare(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
versionStringA, err := builtins.StringOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
versionStringB, err := builtins.StringOperand(args[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
versionA, err := semver.NewVersion(string(versionStringA))
|
||||
if err != nil {
|
||||
return fmt.Errorf("operand 1: string %s is not a valid SemVer", versionStringA)
|
||||
}
|
||||
versionB, err := semver.NewVersion(string(versionStringB))
|
||||
if err != nil {
|
||||
return fmt.Errorf("operand 2: string %s is not a valid SemVer", versionStringB)
|
||||
}
|
||||
|
||||
result := versionA.Compare(*versionB)
|
||||
|
||||
return iter(ast.IntNumberTerm(result))
|
||||
}
|
||||
|
||||
func builtinSemVerIsValid(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
versionString, err := builtins.StringOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return iter(ast.BooleanTerm(false))
|
||||
}
|
||||
|
||||
result := true
|
||||
|
||||
_, err = semver.NewVersion(string(versionString))
|
||||
if err != nil {
|
||||
result = false
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(result))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.SemVerCompare.Name, builtinSemVerCompare)
|
||||
RegisterBuiltinFunc(ast.SemVerIsValid.Name, builtinSemVerIsValid)
|
||||
}
|
||||
12
vendor/github.com/open-policy-agent/opa/topdown/sets.go
generated
vendored
12
vendor/github.com/open-policy-agent/opa/topdown/sets.go
generated
vendored
@@ -58,22 +58,26 @@ func builtinSetIntersection(a ast.Value) (ast.Value, error) {
|
||||
|
||||
// builtinSetUnion returns the union of the given input sets
|
||||
func builtinSetUnion(a ast.Value) (ast.Value, error) {
|
||||
// The set union logic here is duplicated and manually inlined on
|
||||
// purpose. By lifting this logic up a level, and not doing pairwise
|
||||
// set unions, we avoid a number of heap allocations. This improves
|
||||
// performance dramatically over the naive approach.
|
||||
result := ast.NewSet()
|
||||
|
||||
inputSet, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := ast.NewSet()
|
||||
|
||||
err = inputSet.Iter(func(x *ast.Term) error {
|
||||
n, err := builtins.SetOperand(x.Value, 1)
|
||||
item, err := builtins.SetOperand(x.Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result = result.Union(n)
|
||||
item.Foreach(result.Add)
|
||||
return nil
|
||||
})
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
|
||||
257
vendor/github.com/open-policy-agent/opa/topdown/strings.go
generated
vendored
257
vendor/github.com/open-policy-agent/opa/topdown/strings.go
generated
vendored
@@ -5,14 +5,113 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/tchap/go-patricia/v2/patricia"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinAnyPrefixMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
a, b := operands[0].Value, operands[1].Value
|
||||
|
||||
var strs []string
|
||||
switch a := a.(type) {
|
||||
case ast.String:
|
||||
strs = []string{string(a)}
|
||||
case *ast.Array, ast.Set:
|
||||
var err error
|
||||
strs, err = builtins.StringSliceOperand(a, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(1, a, "string", "set", "array")
|
||||
}
|
||||
|
||||
var prefixes []string
|
||||
switch b := b.(type) {
|
||||
case ast.String:
|
||||
prefixes = []string{string(b)}
|
||||
case *ast.Array, ast.Set:
|
||||
var err error
|
||||
prefixes, err = builtins.StringSliceOperand(b, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(2, b, "string", "set", "array")
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(anyStartsWithAny(strs, prefixes)))
|
||||
}
|
||||
|
||||
func builtinAnySuffixMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
a, b := operands[0].Value, operands[1].Value
|
||||
|
||||
var strsReversed []string
|
||||
switch a := a.(type) {
|
||||
case ast.String:
|
||||
strsReversed = []string{reverseString(string(a))}
|
||||
case *ast.Array, ast.Set:
|
||||
strs, err := builtins.StringSliceOperand(a, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
strsReversed = make([]string, len(strs))
|
||||
for i := range strs {
|
||||
strsReversed[i] = reverseString(strs[i])
|
||||
}
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(1, a, "string", "set", "array")
|
||||
}
|
||||
|
||||
var suffixesReversed []string
|
||||
switch b := b.(type) {
|
||||
case ast.String:
|
||||
suffixesReversed = []string{reverseString(string(b))}
|
||||
case *ast.Array, ast.Set:
|
||||
suffixes, err := builtins.StringSliceOperand(b, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
suffixesReversed = make([]string, len(suffixes))
|
||||
for i := range suffixes {
|
||||
suffixesReversed[i] = reverseString(suffixes[i])
|
||||
}
|
||||
default:
|
||||
return builtins.NewOperandTypeErr(2, b, "string", "set", "array")
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(anyStartsWithAny(strsReversed, suffixesReversed)))
|
||||
}
|
||||
|
||||
func anyStartsWithAny(strs []string, prefixes []string) bool {
|
||||
if len(strs) == 0 || len(prefixes) == 0 {
|
||||
return false
|
||||
}
|
||||
if len(strs) == 1 && len(prefixes) == 1 {
|
||||
return strings.HasPrefix(strs[0], prefixes[0])
|
||||
}
|
||||
|
||||
trie := patricia.NewTrie()
|
||||
for i := 0; i < len(strs); i++ {
|
||||
trie.Insert([]byte(strs[i]), true)
|
||||
}
|
||||
|
||||
for i := 0; i < len(prefixes); i++ {
|
||||
if trie.MatchSubtree([]byte(prefixes[i])) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func builtinFormatInt(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
input, err := builtins.NumberOperand(a, 1)
|
||||
@@ -55,13 +154,17 @@ func builtinConcat(a, b ast.Value) (ast.Value, error) {
|
||||
strs := []string{}
|
||||
|
||||
switch b := b.(type) {
|
||||
case ast.Array:
|
||||
for i := range b {
|
||||
s, ok := b[i].Value.(ast.String)
|
||||
case *ast.Array:
|
||||
err := b.Iter(func(x *ast.Term) error {
|
||||
s, ok := x.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(2, b, b[i].Value, "string")
|
||||
return builtins.NewOperandElementErr(2, b, x.Value, "string")
|
||||
}
|
||||
strs = append(strs, string(s))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case ast.Set:
|
||||
err := b.Iter(func(x *ast.Term) error {
|
||||
@@ -82,6 +185,18 @@ func builtinConcat(a, b ast.Value) (ast.Value, error) {
|
||||
return ast.String(strings.Join(strs, string(join))), nil
|
||||
}
|
||||
|
||||
func runesEqual(a, b []rune) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
|
||||
base, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
@@ -92,9 +207,57 @@ func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(string(search)) == 0 {
|
||||
return nil, fmt.Errorf("empty search character")
|
||||
}
|
||||
|
||||
index := strings.Index(string(base), string(search))
|
||||
return ast.IntNumberTerm(index).Value, nil
|
||||
baseRunes := []rune(string(base))
|
||||
searchRunes := []rune(string(search))
|
||||
searchLen := len(searchRunes)
|
||||
|
||||
for i, r := range baseRunes {
|
||||
if len(baseRunes) >= i+searchLen {
|
||||
if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) {
|
||||
return ast.IntNumberTerm(i).Value, nil
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ast.IntNumberTerm(-1).Value, nil
|
||||
}
|
||||
|
||||
func builtinIndexOfN(a, b ast.Value) (ast.Value, error) {
|
||||
base, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
search, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(string(search)) == 0 {
|
||||
return nil, fmt.Errorf("empty search character")
|
||||
}
|
||||
|
||||
baseRunes := []rune(string(base))
|
||||
searchRunes := []rune(string(search))
|
||||
searchLen := len(searchRunes)
|
||||
|
||||
var arr []*ast.Term
|
||||
for i, r := range baseRunes {
|
||||
if len(baseRunes) >= i+searchLen {
|
||||
if r == searchRunes[0] && runesEqual(baseRunes[i:i+searchLen], searchRunes) {
|
||||
arr = append(arr, ast.IntNumberTerm(i))
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ast.NewArray(arr...), nil
|
||||
}
|
||||
|
||||
func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
|
||||
@@ -103,11 +266,12 @@ func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
runes := []rune(base)
|
||||
|
||||
startIndex, err := builtins.IntOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if startIndex >= len(base) {
|
||||
} else if startIndex >= len(runes) {
|
||||
return ast.String(""), nil
|
||||
} else if startIndex < 0 {
|
||||
return nil, fmt.Errorf("negative offset")
|
||||
@@ -120,13 +284,13 @@ func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
|
||||
|
||||
var s ast.String
|
||||
if length < 0 {
|
||||
s = ast.String(base[startIndex:])
|
||||
s = ast.String(runes[startIndex:])
|
||||
} else {
|
||||
upto := startIndex + length
|
||||
if len(base) < upto {
|
||||
upto = len(base)
|
||||
if len(runes) < upto {
|
||||
upto = len(runes)
|
||||
}
|
||||
s = ast.String(base[startIndex:upto])
|
||||
s = ast.String(runes[startIndex:upto])
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@@ -202,11 +366,11 @@ func builtinSplit(a, b ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
elems := strings.Split(string(s), string(d))
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr := make([]*ast.Term, len(elems))
|
||||
for i := range elems {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
return ast.NewArray(arr...), nil
|
||||
}
|
||||
|
||||
func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
|
||||
@@ -229,27 +393,33 @@ func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
|
||||
}
|
||||
|
||||
func builtinReplaceN(a, b ast.Value) (ast.Value, error) {
|
||||
asJSON, err := ast.JSON(a)
|
||||
patterns, err := builtins.ObjectOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldnewObj, ok := asJSON.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "object")
|
||||
}
|
||||
keys := patterns.Keys()
|
||||
sort.Slice(keys, func(i, j int) bool { return ast.Compare(keys[i].Value, keys[j].Value) < 0 })
|
||||
|
||||
s, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var oldnewArr []string
|
||||
for k, v := range oldnewObj {
|
||||
strVal, ok := v.(string)
|
||||
oldnewArr := make([]string, 0, len(keys)*2)
|
||||
for _, k := range keys {
|
||||
keyVal, ok := k.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, errors.New("non-string value found in pattern object")
|
||||
return nil, builtins.NewOperandErr(1, "non-string key found in pattern object")
|
||||
}
|
||||
oldnewArr = append(oldnewArr, k, strVal)
|
||||
val := patterns.Get(k) // cannot be nil
|
||||
strVal, ok := val.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandErr(1, "non-string value found in pattern object")
|
||||
}
|
||||
oldnewArr = append(oldnewArr, string(keyVal), string(strVal))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := strings.NewReplacer(oldnewArr...)
|
||||
@@ -343,18 +513,20 @@ func builtinSprintf(a, b ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
astArr, ok := b.(ast.Array)
|
||||
astArr, ok := b.(*ast.Array)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "array")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(astArr))
|
||||
args := make([]interface{}, astArr.Len())
|
||||
|
||||
for i := range astArr {
|
||||
switch v := astArr[i].Value.(type) {
|
||||
for i := range args {
|
||||
switch v := astArr.Elem(i).Value.(type) {
|
||||
case ast.Number:
|
||||
if n, ok := v.Int(); ok {
|
||||
args[i] = n
|
||||
} else if b, ok := new(big.Int).SetString(v.String(), 10); ok {
|
||||
args[i] = b
|
||||
} else if f, ok := v.Float64(); ok {
|
||||
args[i] = f
|
||||
} else {
|
||||
@@ -363,17 +535,39 @@ func builtinSprintf(a, b ast.Value) (ast.Value, error) {
|
||||
case ast.String:
|
||||
args[i] = string(v)
|
||||
default:
|
||||
args[i] = astArr[i].String()
|
||||
args[i] = astArr.Elem(i).String()
|
||||
}
|
||||
}
|
||||
|
||||
return ast.String(fmt.Sprintf(string(s), args...)), nil
|
||||
}
|
||||
|
||||
func builtinReverse(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
s, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(ast.StringTerm(reverseString(string(s))))
|
||||
}
|
||||
|
||||
func reverseString(str string) string {
|
||||
sRunes := []rune(string(str))
|
||||
length := len(sRunes)
|
||||
reversedRunes := make([]rune, length)
|
||||
|
||||
for index, r := range sRunes {
|
||||
reversedRunes[length-index-1] = r
|
||||
}
|
||||
|
||||
return string(reversedRunes)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.FormatInt.Name, builtinFormatInt)
|
||||
RegisterFunctionalBuiltin2(ast.Concat.Name, builtinConcat)
|
||||
RegisterFunctionalBuiltin2(ast.IndexOf.Name, builtinIndexOf)
|
||||
RegisterFunctionalBuiltin2(ast.IndexOfN.Name, builtinIndexOfN)
|
||||
RegisterFunctionalBuiltin3(ast.Substring.Name, builtinSubstring)
|
||||
RegisterFunctionalBuiltin2(ast.Contains.Name, builtinContains)
|
||||
RegisterFunctionalBuiltin2(ast.StartsWith.Name, builtinStartsWith)
|
||||
@@ -390,4 +584,7 @@ func init() {
|
||||
RegisterFunctionalBuiltin2(ast.TrimSuffix.Name, builtinTrimSuffix)
|
||||
RegisterFunctionalBuiltin1(ast.TrimSpace.Name, builtinTrimSpace)
|
||||
RegisterFunctionalBuiltin2(ast.Sprintf.Name, builtinSprintf)
|
||||
RegisterBuiltinFunc(ast.AnyPrefixMatch.Name, builtinAnyPrefixMatch)
|
||||
RegisterBuiltinFunc(ast.AnySuffixMatch.Name, builtinAnySuffixMatch)
|
||||
RegisterBuiltinFunc(ast.StringReverse.Name, builtinReverse)
|
||||
}
|
||||
|
||||
263
vendor/github.com/open-policy-agent/opa/topdown/subset.go
generated
vendored
Normal file
263
vendor/github.com/open-policy-agent/opa/topdown/subset.go
generated
vendored
Normal file
@@ -0,0 +1,263 @@
|
||||
// Copyright 2022 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func bothObjects(t1, t2 *ast.Term) (bool, ast.Object, ast.Object) {
|
||||
if (t1 == nil) || (t2 == nil) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
obj1, ok := t1.Value.(ast.Object)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
obj2, ok := t2.Value.(ast.Object)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
return true, obj1, obj2
|
||||
}
|
||||
|
||||
func bothSets(t1, t2 *ast.Term) (bool, ast.Set, ast.Set) {
|
||||
if (t1 == nil) || (t2 == nil) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
set1, ok := t1.Value.(ast.Set)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
set2, ok := t2.Value.(ast.Set)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
return true, set1, set2
|
||||
}
|
||||
|
||||
func bothArrays(t1, t2 *ast.Term) (bool, *ast.Array, *ast.Array) {
|
||||
if (t1 == nil) || (t2 == nil) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
array1, ok := t1.Value.(*ast.Array)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
array2, ok := t2.Value.(*ast.Array)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
return true, array1, array2
|
||||
}
|
||||
|
||||
func arraySet(t1, t2 *ast.Term) (bool, *ast.Array, ast.Set) {
|
||||
if (t1 == nil) || (t2 == nil) {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
array, ok := t1.Value.(*ast.Array)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
set, ok := t2.Value.(ast.Set)
|
||||
if !ok {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
return true, array, set
|
||||
}
|
||||
|
||||
// objectSubset implements the subset operation on a pair of objects.
//
// It reports whether every key of sub exists in super with an "equal or
// subset" value. This function will try to recursively apply the subset
// operation where it can, such as if both super and sub have an object,
// set, or array as the value associated with a key.
func objectSubset(super ast.Object, sub ast.Object) bool {
	var superTerm *ast.Term
	isSubset := true

	sub.Until(func(key, subTerm *ast.Term) bool {
		// This really wants to be a for loop, hence the somewhat
		// weird internal structure. However, using Until() in this
		// way is a performance optimization, as it avoids performing
		// any key hashing on the sub-object.

		superTerm = super.Get(key)

		// subTerm can't be nil because we got it from Until(), so
		// we only need to verify that super is non-nil.
		if superTerm == nil {
			isSubset = false
			return true // break, not a subset
		}

		if subTerm.Equal(superTerm) {
			return false // continue
		}

		// If both of the terms are objects then we want to apply
		// the subset operation recursively, otherwise we just compare
		// them normally. If only one term is an object, then we
		// do a normal comparison which will come up false.
		if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok {
			if !objectSubset(superObj, subObj) {
				isSubset = false
				return true // break, not a subset
			}

			return false // continue
		}

		if ok, superSet, subSet := bothSets(superTerm, subTerm); ok {
			if !setSubset(superSet, subSet) {
				isSubset = false
				return true // break, not a subset
			}

			return false // continue
		}

		if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok {
			if !arraySubset(superArray, subArray) {
				isSubset = false
				return true // break, not a subset
			}

			return false // continue
		}

		// We have already checked for exact equality, as well as for
		// all of the types of nested subsets we care about, so if we
		// get here it means this isn't a subset.
		isSubset = false
		return true // break, not a subset
	})

	return isSubset
}
|
||||
|
||||
// setSubset implements the subset operation on sets.
|
||||
//
|
||||
// Unlike in the object case, this is not recursive, we just compare values
|
||||
// using ast.Set.Contains() because we have no well defined way to "match up"
|
||||
// objects that are in different sets.
|
||||
func setSubset(super ast.Set, sub ast.Set) bool {
|
||||
isSubset := true
|
||||
sub.Until(func(t *ast.Term) bool {
|
||||
if !super.Contains(t) {
|
||||
isSubset = false
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
return isSubset
|
||||
}
|
||||
|
||||
// arraySubset implements the subset operation on arrays.
|
||||
//
|
||||
// This is defined to mean that the entire "sub" array must appear in
|
||||
// the "super" array. For the same rationale as setSubset(), we do not attempt
|
||||
// to recurse into values.
|
||||
func arraySubset(super, sub *ast.Array) bool {
|
||||
// Notice that this is essentially string search. The naive approach
|
||||
// used here is O(n^2). This should probably be rewritten later to use
|
||||
// Boyer-Moore or something.
|
||||
|
||||
if sub.Len() > super.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
if sub.Equal(super) {
|
||||
return true
|
||||
}
|
||||
|
||||
superCursor := 0
|
||||
subCursor := 0
|
||||
for {
|
||||
if subCursor == sub.Len() {
|
||||
return true
|
||||
}
|
||||
|
||||
if superCursor == super.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
subElem := sub.Elem(subCursor)
|
||||
superElem := sub.Elem(superCursor + subCursor)
|
||||
if superElem == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if superElem.Value.Compare(subElem.Value) == 0 {
|
||||
subCursor++
|
||||
} else {
|
||||
superCursor++
|
||||
subCursor = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// arraySetSubset implements the subset operation on array and set.
|
||||
//
|
||||
// This is defined to mean that the entire "sub" set must appear in
|
||||
// the "super" array with no consideration of ordering.
|
||||
// For the same rationale as setSubset(), we do not attempt
|
||||
// to recurse into values.
|
||||
func arraySetSubset(super *ast.Array, sub ast.Set) bool {
|
||||
unmatched := sub.Len()
|
||||
return super.Until(func(t *ast.Term) bool {
|
||||
if sub.Contains(t) {
|
||||
unmatched--
|
||||
}
|
||||
if unmatched == 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// builtinObjectSubset implements the object.subset built-in. It dispatches on
// the operand types: object/object, set/set, and array/array are compared
// with their respective subset helpers, and array/set checks that every set
// element appears in the array. Any other type combination is an operand
// error.
func builtinObjectSubset(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	superTerm := operands[0]
	subTerm := operands[1]

	if ok, superObj, subObj := bothObjects(superTerm, subTerm); ok {
		// Both operands are objects.
		return iter(ast.BooleanTerm(objectSubset(superObj, subObj)))
	}

	if ok, superSet, subSet := bothSets(superTerm, subTerm); ok {
		// Both operands are sets.
		return iter(ast.BooleanTerm(setSubset(superSet, subSet)))
	}

	if ok, superArray, subArray := bothArrays(superTerm, subTerm); ok {
		// Both operands are arrays.
		return iter(ast.BooleanTerm(arraySubset(superArray, subArray)))
	}

	if ok, superArray, subSet := arraySet(superTerm, subTerm); ok {
		// Super operand is array and sub operand is set
		return iter(ast.BooleanTerm(arraySetSubset(superArray, subSet)))
	}

	return builtins.ErrOperand("both arguments object.subset must be of the same type or array and set")
}
|
||||
|
||||
// init registers the object.subset built-in with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.ObjectSubset.Name, builtinObjectSubset)
}
|
||||
160
vendor/github.com/open-policy-agent/opa/topdown/time.go
generated
vendored
160
vendor/github.com/open-policy-agent/opa/topdown/time.go
generated
vendored
@@ -7,6 +7,7 @@ package topdown
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -16,61 +17,61 @@ import (
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type nowKeyID string
|
||||
|
||||
var nowKey = nowKeyID("time.now_ns")
|
||||
var tzCache map[string]*time.Location
|
||||
var tzCacheMutex *sync.Mutex
|
||||
|
||||
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// 1677-09-21T00:12:43.145224192-00:00
|
||||
var minDateAllowedForNsConversion = time.Unix(0, math.MinInt64)
|
||||
|
||||
exist, ok := bctx.Cache.Get(nowKey)
|
||||
var now *ast.Term
|
||||
// 2262-04-11T23:47:16.854775807-00:00
|
||||
var maxDateAllowedForNsConversion = time.Unix(0, math.MaxInt64)
|
||||
|
||||
if !ok {
|
||||
curr := time.Now()
|
||||
now = ast.NewTerm(ast.Number(int64ToJSONNumber(curr.UnixNano())))
|
||||
bctx.Cache.Put(nowKey, now)
|
||||
} else {
|
||||
now = exist.(*ast.Term)
|
||||
func toSafeUnixNano(t time.Time, iter func(*ast.Term) error) error {
|
||||
if t.Before(minDateAllowedForNsConversion) || t.After(maxDateAllowedForNsConversion) {
|
||||
return fmt.Errorf("time outside of valid range")
|
||||
}
|
||||
|
||||
return iter(now)
|
||||
return iter(ast.NewTerm(ast.Number(int64ToJSONNumber(t.UnixNano()))))
|
||||
}
|
||||
|
||||
func builtinTimeParseNanos(a, b ast.Value) (ast.Value, error) {
|
||||
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
|
||||
return iter(bctx.Time)
|
||||
}
|
||||
|
||||
func builtinTimeParseNanos(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
a := operands[0].Value
|
||||
format, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
b := operands[1].Value
|
||||
value, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := time.Parse(string(format), string(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
|
||||
return toSafeUnixNano(result, iter)
|
||||
}
|
||||
|
||||
func builtinTimeParseRFC3339Nanos(a ast.Value) (ast.Value, error) {
|
||||
|
||||
func builtinTimeParseRFC3339Nanos(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
a := operands[0].Value
|
||||
value, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := time.Parse(time.RFC3339, string(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
|
||||
return toSafeUnixNano(result, iter)
|
||||
}
|
||||
func builtinParseDurationNanos(a ast.Value) (ast.Value, error) {
|
||||
|
||||
@@ -91,7 +92,7 @@ func builtinDate(a ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
year, month, day := t.Date()
|
||||
result := ast.Array{ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day)}
|
||||
result := ast.NewArray(ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -101,7 +102,7 @@ func builtinClock(a ast.Value) (ast.Value, error) {
|
||||
return nil, err
|
||||
}
|
||||
hour, minute, second := t.Clock()
|
||||
result := ast.Array{ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second)}
|
||||
result := ast.NewArray(ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -114,24 +115,115 @@ func builtinWeekday(a ast.Value) (ast.Value, error) {
|
||||
return ast.String(weekday), nil
|
||||
}
|
||||
|
||||
func builtinAddDate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
t, err := tzTime(operands[0].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
years, err := builtins.IntOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
months, err := builtins.IntOperand(operands[2].Value, 3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
days, err := builtins.IntOperand(operands[3].Value, 4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := t.AddDate(years, months, days)
|
||||
|
||||
return toSafeUnixNano(result, iter)
|
||||
}
|
||||
|
||||
func builtinDiff(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
t1, err := tzTime(operands[0].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t2, err := tzTime(operands[1].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The following implementation of this function is taken
|
||||
// from https://github.com/icza/gox licensed under Apache 2.0.
|
||||
// The only modification made is to variable names.
|
||||
//
|
||||
// For details, see https://stackoverflow.com/a/36531443/1705598
|
||||
//
|
||||
// Copyright 2021 icza
|
||||
// BEGIN REDISTRIBUTION FROM APACHE 2.0 LICENSED PROJECT
|
||||
if t1.Location() != t2.Location() {
|
||||
t2 = t2.In(t1.Location())
|
||||
}
|
||||
if t1.After(t2) {
|
||||
t1, t2 = t2, t1
|
||||
}
|
||||
y1, M1, d1 := t1.Date()
|
||||
y2, M2, d2 := t2.Date()
|
||||
|
||||
h1, m1, s1 := t1.Clock()
|
||||
h2, m2, s2 := t2.Clock()
|
||||
|
||||
year := y2 - y1
|
||||
month := int(M2 - M1)
|
||||
day := d2 - d1
|
||||
hour := h2 - h1
|
||||
min := m2 - m1
|
||||
sec := s2 - s1
|
||||
|
||||
// Normalize negative values
|
||||
if sec < 0 {
|
||||
sec += 60
|
||||
min--
|
||||
}
|
||||
if min < 0 {
|
||||
min += 60
|
||||
hour--
|
||||
}
|
||||
if hour < 0 {
|
||||
hour += 24
|
||||
day--
|
||||
}
|
||||
if day < 0 {
|
||||
// Days in month:
|
||||
t := time.Date(y1, M1, 32, 0, 0, 0, 0, time.UTC)
|
||||
day += 32 - t.Day()
|
||||
month--
|
||||
}
|
||||
if month < 0 {
|
||||
month += 12
|
||||
year--
|
||||
}
|
||||
// END REDISTRIBUTION FROM APACHE 2.0 LICENSED PROJECT
|
||||
|
||||
return iter(ast.ArrayTerm(ast.IntNumberTerm(year), ast.IntNumberTerm(month), ast.IntNumberTerm(day),
|
||||
ast.IntNumberTerm(hour), ast.IntNumberTerm(min), ast.IntNumberTerm(sec)))
|
||||
}
|
||||
|
||||
func tzTime(a ast.Value) (t time.Time, err error) {
|
||||
var nVal ast.Value
|
||||
loc := time.UTC
|
||||
|
||||
switch va := a.(type) {
|
||||
case ast.Array:
|
||||
|
||||
if len(va) == 0 {
|
||||
case *ast.Array:
|
||||
if va.Len() == 0 {
|
||||
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
|
||||
}
|
||||
|
||||
nVal, err = builtins.NumberOperand(va[0].Value, 1)
|
||||
nVal, err = builtins.NumberOperand(va.Elem(0).Value, 1)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if len(va) > 1 {
|
||||
tzVal, err := builtins.StringOperand(va[1].Value, 1)
|
||||
if va.Len() > 1 {
|
||||
tzVal, err := builtins.StringOperand(va.Elem(1).Value, 1)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
@@ -192,12 +284,14 @@ func int64ToJSONNumber(i int64) json.Number {
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.NowNanos.Name, builtinTimeNowNanos)
|
||||
RegisterFunctionalBuiltin1(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
|
||||
RegisterFunctionalBuiltin2(ast.ParseNanos.Name, builtinTimeParseNanos)
|
||||
RegisterBuiltinFunc(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
|
||||
RegisterBuiltinFunc(ast.ParseNanos.Name, builtinTimeParseNanos)
|
||||
RegisterFunctionalBuiltin1(ast.ParseDurationNanos.Name, builtinParseDurationNanos)
|
||||
RegisterFunctionalBuiltin1(ast.Date.Name, builtinDate)
|
||||
RegisterFunctionalBuiltin1(ast.Clock.Name, builtinClock)
|
||||
RegisterFunctionalBuiltin1(ast.Weekday.Name, builtinWeekday)
|
||||
RegisterBuiltinFunc(ast.AddDate.Name, builtinAddDate)
|
||||
RegisterBuiltinFunc(ast.Diff.Name, builtinDiff)
|
||||
tzCacheMutex = &sync.Mutex{}
|
||||
tzCache = make(map[string]*time.Location)
|
||||
}
|
||||
|
||||
866
vendor/github.com/open-policy-agent/opa/topdown/tokens.go
generated
vendored
866
vendor/github.com/open-policy-agent/opa/topdown/tokens.go
generated
vendored
File diff suppressed because it is too large
Load Diff
229
vendor/github.com/open-policy-agent/opa/topdown/trace.go
generated
vendored
229
vendor/github.com/open-policy-agent/opa/topdown/trace.go
generated
vendored
@@ -9,10 +9,18 @@ import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
iStrs "github.com/open-policy-agent/opa/internal/strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
const (
|
||||
minLocationWidth = 5 // len("query")
|
||||
maxIdealLocationWidth = 64
|
||||
locationPadding = 4
|
||||
)
|
||||
|
||||
// Op defines the types of tracing events.
|
||||
type Op string
|
||||
|
||||
@@ -36,12 +44,20 @@ const (
|
||||
// FailOp is emitted when an expression evaluates to false.
|
||||
FailOp Op = "Fail"
|
||||
|
||||
// DuplicateOp is emitted when a query has produced a duplicate value. The search
|
||||
// will stop at the point where the duplicate was emitted and backtrack.
|
||||
DuplicateOp Op = "Duplicate"
|
||||
|
||||
// NoteOp is emitted when an expression invokes a tracing built-in function.
|
||||
NoteOp Op = "Note"
|
||||
|
||||
// IndexOp is emitted during an expression evaluation to represent lookup
|
||||
// matches.
|
||||
IndexOp Op = "Index"
|
||||
|
||||
// WasmOp is emitted when resolving a ref using an external
|
||||
// Resolver.
|
||||
WasmOp Op = "Wasm"
|
||||
)
|
||||
|
||||
// VarMetadata provides some user facing information about
|
||||
@@ -58,9 +74,13 @@ type Event struct {
|
||||
Location *ast.Location // The location of the Node this event relates to.
|
||||
QueryID uint64 // Identifies the query this event belongs to.
|
||||
ParentID uint64 // Identifies the parent query this event belongs to.
|
||||
Locals *ast.ValueMap // Contains local variable bindings from the query context.
|
||||
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings.
|
||||
Locals *ast.ValueMap // Contains local variable bindings from the query context. Nil if variables were not included in the trace event.
|
||||
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings. Nil if variables were not included in the trace event.
|
||||
Message string // Contains message for Note events.
|
||||
Ref *ast.Ref // Identifies the subject ref for the event. Only applies to Index and Wasm operations.
|
||||
|
||||
input *ast.Term
|
||||
bindings *bindings
|
||||
}
|
||||
|
||||
// HasRule returns true if the Event contains an ast.Rule.
|
||||
@@ -102,6 +122,17 @@ func (evt *Event) String() string {
|
||||
return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID)
|
||||
}
|
||||
|
||||
// Input returns the input object as it was at the event.
|
||||
func (evt *Event) Input() *ast.Term {
|
||||
return evt.input
|
||||
}
|
||||
|
||||
// Plug plugs event bindings into the provided ast.Term. Because bindings are mutable, this only makes sense to do when
|
||||
// the event is emitted rather than on recorded trace events as the bindings are going to be different by then.
|
||||
func (evt *Event) Plug(term *ast.Term) *ast.Term {
|
||||
return evt.bindings.Plug(term)
|
||||
}
|
||||
|
||||
func (evt *Event) equalNodes(other *Event) bool {
|
||||
switch a := evt.Node.(type) {
|
||||
case ast.Body:
|
||||
@@ -123,13 +154,53 @@ func (evt *Event) equalNodes(other *Event) bool {
|
||||
}
|
||||
|
||||
// Tracer defines the interface for tracing in the top-down evaluation engine.
|
||||
// Deprecated: Use QueryTracer instead.
|
||||
type Tracer interface {
|
||||
Enabled() bool
|
||||
Trace(*Event)
|
||||
}
|
||||
|
||||
// BufferTracer implements the Tracer interface by simply buffering all events
|
||||
// received.
|
||||
// QueryTracer defines the interface for tracing in the top-down evaluation engine.
|
||||
// The implementation can provide additional configuration to modify the tracing
|
||||
// behavior for query evaluations.
|
||||
type QueryTracer interface {
|
||||
Enabled() bool
|
||||
TraceEvent(Event)
|
||||
Config() TraceConfig
|
||||
}
|
||||
|
||||
// TraceConfig defines some common configuration for Tracer implementations
|
||||
type TraceConfig struct {
|
||||
PlugLocalVars bool // Indicate whether to plug local variable bindings before calling into the tracer.
|
||||
}
|
||||
|
||||
// legacyTracer Implements the QueryTracer interface by wrapping an older Tracer instance.
|
||||
type legacyTracer struct {
|
||||
t Tracer
|
||||
}
|
||||
|
||||
func (l *legacyTracer) Enabled() bool {
|
||||
return l.t.Enabled()
|
||||
}
|
||||
|
||||
func (l *legacyTracer) Config() TraceConfig {
|
||||
return TraceConfig{
|
||||
PlugLocalVars: true, // For backwards compatibility old tracers will plug local variables
|
||||
}
|
||||
}
|
||||
|
||||
func (l *legacyTracer) TraceEvent(evt Event) {
|
||||
l.t.Trace(&evt)
|
||||
}
|
||||
|
||||
// WrapLegacyTracer will create a new QueryTracer which wraps an
|
||||
// older Tracer instance.
|
||||
func WrapLegacyTracer(tracer Tracer) QueryTracer {
|
||||
return &legacyTracer{t: tracer}
|
||||
}
|
||||
|
||||
// BufferTracer implements the Tracer and QueryTracer interface by
|
||||
// simply buffering all events received.
|
||||
type BufferTracer []*Event
|
||||
|
||||
// NewBufferTracer returns a new BufferTracer.
|
||||
@@ -139,17 +210,25 @@ func NewBufferTracer() *BufferTracer {
|
||||
|
||||
// Enabled always returns true if the BufferTracer is instantiated.
|
||||
func (b *BufferTracer) Enabled() bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return b != nil
|
||||
}
|
||||
|
||||
// Trace adds the event to the buffer.
|
||||
// Deprecated: Use TraceEvent instead.
|
||||
func (b *BufferTracer) Trace(evt *Event) {
|
||||
*b = append(*b, evt)
|
||||
}
|
||||
|
||||
// TraceEvent adds the event to the buffer.
|
||||
func (b *BufferTracer) TraceEvent(evt Event) {
|
||||
*b = append(*b, &evt)
|
||||
}
|
||||
|
||||
// Config returns the Tracers standard configuration
|
||||
func (b *BufferTracer) Config() TraceConfig {
|
||||
return TraceConfig{PlugLocalVars: true}
|
||||
}
|
||||
|
||||
// PrettyTrace pretty prints the trace to the writer.
|
||||
func PrettyTrace(w io.Writer, trace []*Event) {
|
||||
depths := depths{}
|
||||
@@ -162,10 +241,16 @@ func PrettyTrace(w io.Writer, trace []*Event) {
|
||||
// PrettyTraceWithLocation prints the trace to the writer and includes location information
|
||||
func PrettyTraceWithLocation(w io.Writer, trace []*Event) {
|
||||
depths := depths{}
|
||||
|
||||
filePathAliases, longest := getShortenedFileNames(trace)
|
||||
|
||||
// Always include some padding between the trace and location
|
||||
locationWidth := longest + locationPadding
|
||||
|
||||
for _, event := range trace {
|
||||
depth := depths.GetOrSet(event.QueryID, event.ParentID)
|
||||
location := formatLocation(event)
|
||||
fmt.Fprintln(w, fmt.Sprintf("%v %v", location, formatEvent(event, depth)))
|
||||
location := formatLocation(event, filePathAliases)
|
||||
fmt.Fprintf(w, "%-*s %s\n", locationWidth, location, formatEvent(event, depth))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,25 +258,34 @@ func formatEvent(event *Event, depth int) string {
|
||||
padding := formatEventPadding(event, depth)
|
||||
if event.Op == NoteOp {
|
||||
return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message)
|
||||
} else if event.Message != "" {
|
||||
return fmt.Sprintf("%v%v %v %v", padding, event.Op, event.Node, event.Message)
|
||||
} else {
|
||||
switch node := event.Node.(type) {
|
||||
case *ast.Rule:
|
||||
return fmt.Sprintf("%v%v %v", padding, event.Op, node.Path())
|
||||
default:
|
||||
return fmt.Sprintf("%v%v %v", padding, event.Op, rewrite(event).Node)
|
||||
}
|
||||
}
|
||||
|
||||
var details interface{}
|
||||
if node, ok := event.Node.(*ast.Rule); ok {
|
||||
details = node.Path()
|
||||
} else if event.Ref != nil {
|
||||
details = event.Ref
|
||||
} else {
|
||||
details = rewrite(event).Node
|
||||
}
|
||||
|
||||
template := "%v%v %v"
|
||||
opts := []interface{}{padding, event.Op, details}
|
||||
|
||||
if event.Message != "" {
|
||||
template += " %v"
|
||||
opts = append(opts, event.Message)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(template, opts...)
|
||||
}
|
||||
|
||||
func formatEventPadding(event *Event, depth int) string {
|
||||
spaces := formatEventSpaces(event, depth)
|
||||
padding := ""
|
||||
if spaces > 1 {
|
||||
padding += strings.Repeat("| ", spaces-1)
|
||||
return strings.Repeat("| ", spaces-1)
|
||||
}
|
||||
return padding
|
||||
return ""
|
||||
}
|
||||
|
||||
func formatEventSpaces(event *Event, depth int) int {
|
||||
@@ -206,21 +300,67 @@ func formatEventSpaces(event *Event, depth int) int {
|
||||
return depth + 1
|
||||
}
|
||||
|
||||
func formatLocation(event *Event) string {
|
||||
if event.Op == NoteOp {
|
||||
return fmt.Sprintf("%-19v", "note")
|
||||
// getShortenedFileNames will return a map of file paths to shortened aliases
|
||||
// that were found in the trace. It also returns the longest location expected
|
||||
func getShortenedFileNames(trace []*Event) (map[string]string, int) {
|
||||
// Get a deduplicated list of all file paths
|
||||
// and the longest file path size
|
||||
fpAliases := map[string]string{}
|
||||
var canShorten []string
|
||||
longestLocation := 0
|
||||
for _, event := range trace {
|
||||
if event.Location != nil {
|
||||
if event.Location.File != "" {
|
||||
// length of "<name>:<row>"
|
||||
curLen := len(event.Location.File) + numDigits10(event.Location.Row) + 1
|
||||
if curLen > longestLocation {
|
||||
longestLocation = curLen
|
||||
}
|
||||
|
||||
if _, ok := fpAliases[event.Location.File]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
canShorten = append(canShorten, event.Location.File)
|
||||
|
||||
// Default to just alias their full path
|
||||
fpAliases[event.Location.File] = event.Location.File
|
||||
} else {
|
||||
// length of "<min width>:<row>"
|
||||
curLen := minLocationWidth + numDigits10(event.Location.Row) + 1
|
||||
if curLen > longestLocation {
|
||||
longestLocation = curLen
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(canShorten) > 0 && longestLocation > maxIdealLocationWidth {
|
||||
fpAliases, longestLocation = iStrs.TruncateFilePaths(maxIdealLocationWidth, longestLocation, canShorten...)
|
||||
}
|
||||
|
||||
return fpAliases, longestLocation
|
||||
}
|
||||
|
||||
func numDigits10(n int) int {
|
||||
if n < 10 {
|
||||
return 1
|
||||
}
|
||||
return numDigits10(n/10) + 1
|
||||
}
|
||||
|
||||
func formatLocation(event *Event, fileAliases map[string]string) string {
|
||||
|
||||
location := event.Location
|
||||
if location == nil {
|
||||
return fmt.Sprintf("%-19v", "")
|
||||
return ""
|
||||
}
|
||||
|
||||
if location.File == "" {
|
||||
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", "query", location.Row))
|
||||
return fmt.Sprintf("query:%v", location.Row)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", location.File, location.Row))
|
||||
return fmt.Sprintf("%v:%v", fileAliases[location.File], location.Row)
|
||||
}
|
||||
|
||||
// depths is a helper for computing the depth of an event. Events within the
|
||||
@@ -245,33 +385,25 @@ func builtinTrace(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) er
|
||||
return handleBuiltinErr(ast.Trace.Name, bctx.Location, err)
|
||||
}
|
||||
|
||||
if !traceIsEnabled(bctx.Tracers) {
|
||||
if !bctx.TraceEnabled {
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
evt := &Event{
|
||||
evt := Event{
|
||||
Op: NoteOp,
|
||||
Location: bctx.Location,
|
||||
QueryID: bctx.QueryID,
|
||||
ParentID: bctx.ParentID,
|
||||
Message: string(str),
|
||||
}
|
||||
|
||||
for i := range bctx.Tracers {
|
||||
bctx.Tracers[i].Trace(evt)
|
||||
for i := range bctx.QueryTracers {
|
||||
bctx.QueryTracers[i].TraceEvent(evt)
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
func traceIsEnabled(tracers []Tracer) bool {
|
||||
for i := range tracers {
|
||||
if tracers[i].Enabled() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func rewrite(event *Event) *Event {
|
||||
|
||||
cpy := *event
|
||||
@@ -280,14 +412,25 @@ func rewrite(event *Event) *Event {
|
||||
|
||||
switch v := event.Node.(type) {
|
||||
case *ast.Expr:
|
||||
node = v.Copy()
|
||||
expr := v.Copy()
|
||||
|
||||
// Hide generated local vars in 'key' position that have not been
|
||||
// rewritten.
|
||||
if ev, ok := v.Terms.(*ast.Every); ok {
|
||||
if kv, ok := ev.Key.Value.(ast.Var); ok {
|
||||
if rw, ok := cpy.LocalMetadata[kv]; !ok || rw.Name.IsGenerated() {
|
||||
expr.Terms.(*ast.Every).Key = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
node = expr
|
||||
case ast.Body:
|
||||
node = v.Copy()
|
||||
case *ast.Rule:
|
||||
node = v.Copy()
|
||||
}
|
||||
|
||||
ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
|
||||
_, _ = ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
|
||||
if meta, ok := cpy.LocalMetadata[v]; ok {
|
||||
return meta.Name, nil
|
||||
}
|
||||
|
||||
18
vendor/github.com/open-policy-agent/opa/topdown/type.go
generated
vendored
18
vendor/github.com/open-policy-agent/opa/topdown/type.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Copyright 2022 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
@@ -13,7 +13,7 @@ func builtinIsNumber(a ast.Value) (ast.Value, error) {
|
||||
case ast.Number:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func builtinIsString(a ast.Value) (ast.Value, error) {
|
||||
case ast.String:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,16 +31,16 @@ func builtinIsBoolean(a ast.Value) (ast.Value, error) {
|
||||
case ast.Boolean:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsArray(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func builtinIsSet(a ast.Value) (ast.Value, error) {
|
||||
case ast.Set:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func builtinIsObject(a ast.Value) (ast.Value, error) {
|
||||
case ast.Object:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ func builtinIsNull(a ast.Value) (ast.Value, error) {
|
||||
case ast.Null:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
2
vendor/github.com/open-policy-agent/opa/topdown/type_name.go
generated
vendored
2
vendor/github.com/open-policy-agent/opa/topdown/type_name.go
generated
vendored
@@ -20,7 +20,7 @@ func builtinTypeName(a ast.Value) (ast.Value, error) {
|
||||
return ast.String("number"), nil
|
||||
case ast.String:
|
||||
return ast.String("string"), nil
|
||||
case ast.Array:
|
||||
case *ast.Array:
|
||||
return ast.String("array"), nil
|
||||
case ast.Object:
|
||||
return ast.String("object"), nil
|
||||
|
||||
36
vendor/github.com/open-policy-agent/opa/topdown/uuid.go
generated
vendored
Normal file
36
vendor/github.com/open-policy-agent/opa/topdown/uuid.go
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/internal/uuid"
|
||||
)
|
||||
|
||||
type uuidCachingKey string
|
||||
|
||||
func builtinUUIDRFC4122(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
var key = uuidCachingKey(args[0].Value.String())
|
||||
|
||||
val, ok := bctx.Cache.Get(key)
|
||||
if ok {
|
||||
return iter(val.(*ast.Term))
|
||||
}
|
||||
|
||||
s, err := uuid.New(bctx.Seed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := ast.NewTerm(ast.String(s))
|
||||
bctx.Cache.Put(key, result)
|
||||
|
||||
return iter(result)
|
||||
}
|
||||
|
||||
// init registers the uuid.rfc4122 built-in with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.UUIDRFC4122.Name, builtinUUIDRFC4122)
}
|
||||
58
vendor/github.com/open-policy-agent/opa/topdown/walk.go
generated
vendored
58
vendor/github.com/open-policy-agent/opa/topdown/walk.go
generated
vendored
@@ -8,57 +8,61 @@ import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
func evalWalk(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
input := args[0]
|
||||
filter := getOutputPath(args)
|
||||
var path ast.Array
|
||||
return walk(filter, path, input, iter)
|
||||
return walk(filter, nil, input, iter)
|
||||
}
|
||||
|
||||
func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
|
||||
func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
if len(filter) == 0 {
|
||||
if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil {
|
||||
if filter == nil || filter.Len() == 0 {
|
||||
if path == nil {
|
||||
path = ast.NewArray()
|
||||
}
|
||||
|
||||
if err := iter(ast.ArrayTerm(ast.NewTerm(path.Copy()), input)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(filter) > 0 {
|
||||
key := filter[0]
|
||||
filter = filter[1:]
|
||||
if filter != nil && filter.Len() > 0 {
|
||||
key := filter.Elem(0)
|
||||
filter = filter.Slice(1, -1)
|
||||
if key.IsGround() {
|
||||
if term := input.Get(key); term != nil {
|
||||
return walk(filter, append(path, key), term, iter)
|
||||
path = pathAppend(path, key)
|
||||
return walk(filter, path, term, iter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch v := input.Value.(type) {
|
||||
case ast.Array:
|
||||
for i := range v {
|
||||
path = append(path, ast.IntNumberTerm(i))
|
||||
if err := walk(filter, path, v[i], iter); err != nil {
|
||||
case *ast.Array:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
path = pathAppend(path, ast.IntNumberTerm(i))
|
||||
if err := walk(filter, path, v.Elem(i), iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
path = path.Slice(0, path.Len()-1)
|
||||
}
|
||||
case ast.Object:
|
||||
return v.Iter(func(k, v *ast.Term) error {
|
||||
path = append(path, k)
|
||||
path = pathAppend(path, k)
|
||||
if err := walk(filter, path, v, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
path = path.Slice(0, path.Len()-1)
|
||||
return nil
|
||||
})
|
||||
case ast.Set:
|
||||
return v.Iter(func(elem *ast.Term) error {
|
||||
path = append(path, elem)
|
||||
path = pathAppend(path, elem)
|
||||
if err := walk(filter, path, elem, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
path = path.Slice(0, path.Len()-1)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -66,11 +70,19 @@ func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOutputPath(args []*ast.Term) ast.Array {
|
||||
func pathAppend(path *ast.Array, key *ast.Term) *ast.Array {
|
||||
if path == nil {
|
||||
return ast.NewArray(key)
|
||||
}
|
||||
|
||||
return path.Append(key)
|
||||
}
|
||||
|
||||
func getOutputPath(args []*ast.Term) *ast.Array {
|
||||
if len(args) == 2 {
|
||||
if arr, ok := args[1].Value.(ast.Array); ok {
|
||||
if len(arr) == 2 {
|
||||
if path, ok := arr[0].Value.(ast.Array); ok {
|
||||
if arr, ok := args[1].Value.(*ast.Array); ok {
|
||||
if arr.Len() == 2 {
|
||||
if path, ok := arr.Elem(0).Value.(*ast.Array); ok {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user