216
vendor/github.com/open-policy-agent/opa/topdown/aggregates.go
generated
vendored
Normal file
216
vendor/github.com/open-policy-agent/opa/topdown/aggregates.go
generated
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinCount(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
return ast.IntNumberTerm(len(a)).Value, nil
|
||||
case ast.Object:
|
||||
return ast.IntNumberTerm(a.Len()).Value, nil
|
||||
case ast.Set:
|
||||
return ast.IntNumberTerm(a.Len()).Value, nil
|
||||
case ast.String:
|
||||
return ast.IntNumberTerm(len(a)).Value, nil
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set")
|
||||
}
|
||||
|
||||
func builtinSum(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
sum := big.NewFloat(0)
|
||||
for _, x := range a {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
|
||||
}
|
||||
return builtins.FloatToNumber(sum), nil
|
||||
case ast.Set:
|
||||
sum := big.NewFloat(0)
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
|
||||
return nil
|
||||
})
|
||||
return builtins.FloatToNumber(sum), err
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
|
||||
}
|
||||
|
||||
func builtinProduct(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
product := big.NewFloat(1)
|
||||
for _, x := range a {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
|
||||
}
|
||||
return builtins.FloatToNumber(product), nil
|
||||
case ast.Set:
|
||||
product := big.NewFloat(1)
|
||||
err := a.Iter(func(x *ast.Term) error {
|
||||
n, ok := x.Value.(ast.Number)
|
||||
if !ok {
|
||||
return builtins.NewOperandElementErr(1, a, x.Value, "number")
|
||||
}
|
||||
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
|
||||
return nil
|
||||
})
|
||||
return builtins.FloatToNumber(product), err
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
|
||||
}
|
||||
|
||||
func builtinMax(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
if len(a) == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
var max = ast.Value(ast.Null{})
|
||||
for i := range a {
|
||||
if ast.Compare(max, a[i].Value) <= 0 {
|
||||
max = a[i].Value
|
||||
}
|
||||
}
|
||||
return max, nil
|
||||
case ast.Set:
|
||||
if a.Len() == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) {
|
||||
if ast.Compare(max, elem) <= 0 {
|
||||
return elem, nil
|
||||
}
|
||||
return max, nil
|
||||
})
|
||||
return max.Value, err
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
|
||||
}
|
||||
|
||||
func builtinMin(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
if len(a) == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
min := a[0].Value
|
||||
for i := range a {
|
||||
if ast.Compare(min, a[i].Value) >= 0 {
|
||||
min = a[i].Value
|
||||
}
|
||||
}
|
||||
return min, nil
|
||||
case ast.Set:
|
||||
if a.Len() == 0 {
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) {
|
||||
// The null term is considered to be less than any other term,
|
||||
// so in order for min of a set to make sense, we need to check
|
||||
// for it.
|
||||
if min.Value.Compare(ast.Null{}) == 0 {
|
||||
return elem, nil
|
||||
}
|
||||
|
||||
if ast.Compare(min, elem) >= 0 {
|
||||
return elem, nil
|
||||
}
|
||||
return min, nil
|
||||
})
|
||||
return min.Value, err
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
|
||||
}
|
||||
|
||||
func builtinSort(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Array:
|
||||
return a.Sorted(), nil
|
||||
case ast.Set:
|
||||
return a.Sorted(), nil
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
|
||||
}
|
||||
|
||||
func builtinAll(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Set:
|
||||
res := true
|
||||
match := ast.BooleanTerm(true)
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
if !term.Equal(match) {
|
||||
res = false
|
||||
}
|
||||
})
|
||||
return ast.Boolean(res), nil
|
||||
case ast.Array:
|
||||
res := true
|
||||
match := ast.BooleanTerm(true)
|
||||
for _, term := range val {
|
||||
if !term.Equal(match) {
|
||||
res = false
|
||||
}
|
||||
}
|
||||
return ast.Boolean(res), nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
}
|
||||
|
||||
func builtinAny(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Set:
|
||||
res := false
|
||||
match := ast.BooleanTerm(true)
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
if term.Equal(match) {
|
||||
res = true
|
||||
}
|
||||
})
|
||||
return ast.Boolean(res), nil
|
||||
case ast.Array:
|
||||
res := false
|
||||
match := ast.BooleanTerm(true)
|
||||
for _, term := range val {
|
||||
if term.Equal(match) {
|
||||
res = true
|
||||
}
|
||||
}
|
||||
return ast.Boolean(res), nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount)
|
||||
RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum)
|
||||
RegisterFunctionalBuiltin1(ast.Product.Name, builtinProduct)
|
||||
RegisterFunctionalBuiltin1(ast.Max.Name, builtinMax)
|
||||
RegisterFunctionalBuiltin1(ast.Min.Name, builtinMin)
|
||||
RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort)
|
||||
RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny)
|
||||
RegisterFunctionalBuiltin1(ast.All.Name, builtinAll)
|
||||
}
|
||||
156
vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go
generated
vendored
Normal file
156
vendor/github.com/open-policy-agent/opa/topdown/arithmetic.go
generated
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type arithArity1 func(a *big.Float) (*big.Float, error)
|
||||
type arithArity2 func(a, b *big.Float) (*big.Float, error)
|
||||
|
||||
func arithAbs(a *big.Float) (*big.Float, error) {
|
||||
return a.Abs(a), nil
|
||||
}
|
||||
|
||||
var halfAwayFromZero = big.NewFloat(0.5)
|
||||
|
||||
func arithRound(a *big.Float) (*big.Float, error) {
|
||||
var i *big.Int
|
||||
if a.Signbit() {
|
||||
i, _ = new(big.Float).Sub(a, halfAwayFromZero).Int(nil)
|
||||
} else {
|
||||
i, _ = new(big.Float).Add(a, halfAwayFromZero).Int(nil)
|
||||
}
|
||||
return new(big.Float).SetInt(i), nil
|
||||
}
|
||||
|
||||
func arithPlus(a, b *big.Float) (*big.Float, error) {
|
||||
return new(big.Float).Add(a, b), nil
|
||||
}
|
||||
|
||||
func arithMinus(a, b *big.Float) (*big.Float, error) {
|
||||
return new(big.Float).Sub(a, b), nil
|
||||
}
|
||||
|
||||
func arithMultiply(a, b *big.Float) (*big.Float, error) {
|
||||
return new(big.Float).Mul(a, b), nil
|
||||
}
|
||||
|
||||
func arithDivide(a, b *big.Float) (*big.Float, error) {
|
||||
i, acc := b.Int64()
|
||||
if acc == big.Exact && i == 0 {
|
||||
return nil, fmt.Errorf("divide by zero")
|
||||
}
|
||||
return new(big.Float).Quo(a, b), nil
|
||||
}
|
||||
|
||||
func arithRem(a, b *big.Int) (*big.Int, error) {
|
||||
if b.Int64() == 0 {
|
||||
return nil, fmt.Errorf("modulo by zero")
|
||||
}
|
||||
return new(big.Int).Rem(a, b), nil
|
||||
}
|
||||
|
||||
func builtinArithArity1(fn arithArity1) FunctionalBuiltin1 {
|
||||
return func(a ast.Value) (ast.Value, error) {
|
||||
n, err := builtins.NumberOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := fn(builtins.NumberToFloat(n))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return builtins.FloatToNumber(f), nil
|
||||
}
|
||||
}
|
||||
|
||||
func builtinArithArity2(fn arithArity2) FunctionalBuiltin2 {
|
||||
return func(a, b ast.Value) (ast.Value, error) {
|
||||
n1, err := builtins.NumberOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n2, err := builtins.NumberOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := fn(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return builtins.FloatToNumber(f), nil
|
||||
}
|
||||
}
|
||||
|
||||
func builtinMinus(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
n1, ok1 := a.(ast.Number)
|
||||
n2, ok2 := b.(ast.Number)
|
||||
|
||||
if ok1 && ok2 {
|
||||
f, err := arithMinus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return builtins.FloatToNumber(f), nil
|
||||
}
|
||||
|
||||
s1, ok3 := a.(ast.Set)
|
||||
s2, ok4 := b.(ast.Set)
|
||||
|
||||
if ok3 && ok4 {
|
||||
return s1.Diff(s2), nil
|
||||
}
|
||||
|
||||
if !ok1 && !ok3 {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "number", "set")
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "number", "set")
|
||||
}
|
||||
|
||||
func builtinRem(a, b ast.Value) (ast.Value, error) {
|
||||
n1, ok1 := a.(ast.Number)
|
||||
n2, ok2 := b.(ast.Number)
|
||||
|
||||
if ok1 && ok2 {
|
||||
|
||||
op1, err1 := builtins.NumberToInt(n1)
|
||||
op2, err2 := builtins.NumberToInt(n2)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
return nil, fmt.Errorf("modulo on floating-point number")
|
||||
}
|
||||
|
||||
i, err := arithRem(op1, op2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return builtins.IntToNumber(i), nil
|
||||
}
|
||||
|
||||
if !ok1 {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "number")
|
||||
}
|
||||
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "number")
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.Abs.Name, builtinArithArity1(arithAbs))
|
||||
RegisterFunctionalBuiltin1(ast.Round.Name, builtinArithArity1(arithRound))
|
||||
RegisterFunctionalBuiltin2(ast.Plus.Name, builtinArithArity2(arithPlus))
|
||||
RegisterFunctionalBuiltin2(ast.Minus.Name, builtinMinus)
|
||||
RegisterFunctionalBuiltin2(ast.Multiply.Name, builtinArithArity2(arithMultiply))
|
||||
RegisterFunctionalBuiltin2(ast.Divide.Name, builtinArithArity2(arithDivide))
|
||||
RegisterFunctionalBuiltin2(ast.Rem.Name, builtinRem)
|
||||
}
|
||||
75
vendor/github.com/open-policy-agent/opa/topdown/array.go
generated
vendored
Normal file
75
vendor/github.com/open-policy-agent/opa/topdown/array.go
generated
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinArrayConcat(a, b ast.Value) (ast.Value, error) {
|
||||
arrA, err := builtins.ArrayOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrB, err := builtins.ArrayOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arrC := make(ast.Array, 0, len(arrA)+len(arrB))
|
||||
|
||||
for _, elemA := range arrA {
|
||||
arrC = append(arrC, elemA)
|
||||
}
|
||||
|
||||
for _, elemB := range arrB {
|
||||
arrC = append(arrC, elemB)
|
||||
}
|
||||
|
||||
return arrC, nil
|
||||
}
|
||||
|
||||
func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
|
||||
arr, err := builtins.ArrayOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startIndex, err := builtins.IntOperand(i, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stopIndex, err := builtins.IntOperand(j, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return empty array if bounds cannot be clamped sensibly.
|
||||
if (startIndex >= stopIndex) || (startIndex <= 0 && stopIndex <= 0) {
|
||||
return arr[0:0], nil
|
||||
}
|
||||
|
||||
// Clamp bounds to avoid out-of-range errors.
|
||||
if startIndex < 0 {
|
||||
startIndex = 0
|
||||
}
|
||||
|
||||
if stopIndex > len(arr) {
|
||||
stopIndex = len(arr)
|
||||
}
|
||||
|
||||
arrb := arr[startIndex:stopIndex]
|
||||
|
||||
return arrb, nil
|
||||
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.ArrayConcat.Name, builtinArrayConcat)
|
||||
RegisterFunctionalBuiltin3(ast.ArraySlice.Name, builtinArraySlice)
|
||||
}
|
||||
45
vendor/github.com/open-policy-agent/opa/topdown/binary.go
generated
vendored
Normal file
45
vendor/github.com/open-policy-agent/opa/topdown/binary.go
generated
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinBinaryAnd(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
|
||||
s1, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s2, err := builtins.SetOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s1.Intersect(s2), nil
|
||||
}
|
||||
|
||||
func builtinBinaryOr(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
|
||||
s1, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s2, err := builtins.SetOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s1.Union(s2), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.And.Name, builtinBinaryAnd)
|
||||
RegisterFunctionalBuiltin2(ast.Or.Name, builtinBinaryOr)
|
||||
}
|
||||
387
vendor/github.com/open-policy-agent/opa/topdown/bindings.go
generated
vendored
Normal file
387
vendor/github.com/open-policy-agent/opa/topdown/bindings.go
generated
vendored
Normal file
@@ -0,0 +1,387 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
type undo struct {
|
||||
k *ast.Term
|
||||
u *bindings
|
||||
next *undo
|
||||
}
|
||||
|
||||
func (u *undo) Undo() {
|
||||
if u == nil {
|
||||
// Allow call on zero value of Undo for ease-of-use.
|
||||
return
|
||||
}
|
||||
if u.u == nil {
|
||||
// Call on empty unifier undos a no-op unify operation.
|
||||
return
|
||||
}
|
||||
u.u.delete(u.k)
|
||||
u.next.Undo()
|
||||
}
|
||||
|
||||
type bindings struct {
|
||||
id uint64
|
||||
values bindingsArrayHashmap
|
||||
instr *Instrumentation
|
||||
}
|
||||
|
||||
func newBindings(id uint64, instr *Instrumentation) *bindings {
|
||||
values := newBindingsArrayHashmap()
|
||||
return &bindings{id, values, instr}
|
||||
}
|
||||
|
||||
func (u *bindings) Iter(caller *bindings, iter func(*ast.Term, *ast.Term) error) error {
|
||||
|
||||
var err error
|
||||
|
||||
u.values.Iter(func(k *ast.Term, v value) bool {
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
err = iter(k, u.PlugNamespaced(k, caller))
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *bindings) Namespace(x ast.Node, caller *bindings) {
|
||||
vis := namespacingVisitor{
|
||||
b: u,
|
||||
caller: caller,
|
||||
}
|
||||
ast.NewGenericVisitor(vis.Visit).Walk(x)
|
||||
}
|
||||
|
||||
func (u *bindings) Plug(a *ast.Term) *ast.Term {
|
||||
return u.PlugNamespaced(a, nil)
|
||||
}
|
||||
|
||||
func (u *bindings) PlugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
|
||||
if u != nil {
|
||||
u.instr.startTimer(evalOpPlug)
|
||||
t := u.plugNamespaced(a, caller)
|
||||
u.instr.stopTimer(evalOpPlug)
|
||||
return t
|
||||
}
|
||||
|
||||
return u.plugNamespaced(a, caller)
|
||||
}
|
||||
|
||||
func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
|
||||
switch v := a.Value.(type) {
|
||||
case ast.Var:
|
||||
b, next := u.apply(a)
|
||||
if a != b || u != next {
|
||||
return next.plugNamespaced(b, caller)
|
||||
}
|
||||
return u.namespaceVar(b, caller)
|
||||
case ast.Array:
|
||||
cpy := *a
|
||||
arr := make(ast.Array, len(v))
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = u.plugNamespaced(v[i], caller)
|
||||
}
|
||||
cpy.Value = arr
|
||||
return &cpy
|
||||
case ast.Object:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) {
|
||||
return u.plugNamespaced(k, caller), u.plugNamespaced(v, caller), nil
|
||||
})
|
||||
return &cpy
|
||||
case ast.Set:
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
|
||||
return u.plugNamespaced(x, caller), nil
|
||||
})
|
||||
return &cpy
|
||||
case ast.Ref:
|
||||
cpy := *a
|
||||
ref := make(ast.Ref, len(v))
|
||||
for i := 0; i < len(ref); i++ {
|
||||
ref[i] = u.plugNamespaced(v[i], caller)
|
||||
}
|
||||
cpy.Value = ref
|
||||
return &cpy
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings) *undo {
|
||||
u.values.Put(a, value{
|
||||
u: other,
|
||||
v: b,
|
||||
})
|
||||
return &undo{a, u, nil}
|
||||
}
|
||||
|
||||
func (u *bindings) apply(a *ast.Term) (*ast.Term, *bindings) {
|
||||
// Early exit for non-var terms. Only vars are bound in the binding list,
|
||||
// so the lookup below will always fail for non-var terms. In some cases,
|
||||
// the lookup may be expensive as it has to hash the term (which for large
|
||||
// inputs can be costly.)
|
||||
_, ok := a.Value.(ast.Var)
|
||||
if !ok {
|
||||
return a, u
|
||||
}
|
||||
val, ok := u.get(a)
|
||||
if !ok {
|
||||
return a, u
|
||||
}
|
||||
return val.u.apply(val.v)
|
||||
}
|
||||
|
||||
func (u *bindings) delete(v *ast.Term) {
|
||||
u.values.Delete(v)
|
||||
}
|
||||
|
||||
func (u *bindings) get(v *ast.Term) (value, bool) {
|
||||
if u == nil {
|
||||
return value{}, false
|
||||
}
|
||||
return u.values.Get(v)
|
||||
}
|
||||
|
||||
func (u *bindings) String() string {
|
||||
if u == nil {
|
||||
return "()"
|
||||
}
|
||||
var buf []string
|
||||
u.values.Iter(func(a *ast.Term, b value) bool {
|
||||
buf = append(buf, fmt.Sprintf("%v: %v", a, b))
|
||||
return false
|
||||
})
|
||||
return fmt.Sprintf("({%v}, %v)", strings.Join(buf, ", "), u.id)
|
||||
}
|
||||
|
||||
func (u *bindings) namespaceVar(v *ast.Term, caller *bindings) *ast.Term {
|
||||
name, ok := v.Value.(ast.Var)
|
||||
if !ok {
|
||||
panic("illegal value")
|
||||
}
|
||||
if caller != nil && caller != u {
|
||||
// Root documents (i.e., data, input) should never be namespaced because they
|
||||
// are globally unique.
|
||||
if !ast.RootDocumentNames.Contains(v) {
|
||||
return ast.NewTerm(ast.Var(string(name) + fmt.Sprint(u.id)))
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
type value struct {
|
||||
u *bindings
|
||||
v *ast.Term
|
||||
}
|
||||
|
||||
func (v value) String() string {
|
||||
return fmt.Sprintf("(%v, %d)", v.v, v.u.id)
|
||||
}
|
||||
|
||||
func (v value) equal(other *value) bool {
|
||||
if v.u == other.u {
|
||||
return v.v.Equal(other.v)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type namespacingVisitor struct {
|
||||
b *bindings
|
||||
caller *bindings
|
||||
}
|
||||
|
||||
func (vis namespacingVisitor) Visit(x interface{}) bool {
|
||||
switch x := x.(type) {
|
||||
case *ast.ArrayComprehension:
|
||||
x.Term = vis.namespaceTerm(x.Term)
|
||||
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
|
||||
return true
|
||||
case *ast.SetComprehension:
|
||||
x.Term = vis.namespaceTerm(x.Term)
|
||||
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
|
||||
return true
|
||||
case *ast.ObjectComprehension:
|
||||
x.Key = vis.namespaceTerm(x.Key)
|
||||
x.Value = vis.namespaceTerm(x.Value)
|
||||
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
|
||||
return true
|
||||
case *ast.Expr:
|
||||
switch terms := x.Terms.(type) {
|
||||
case []*ast.Term:
|
||||
for i := 1; i < len(terms); i++ {
|
||||
terms[i] = vis.namespaceTerm(terms[i])
|
||||
}
|
||||
case *ast.Term:
|
||||
x.Terms = vis.namespaceTerm(terms)
|
||||
}
|
||||
for _, w := range x.With {
|
||||
w.Target = vis.namespaceTerm(w.Target)
|
||||
w.Value = vis.namespaceTerm(w.Value)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
|
||||
switch v := a.Value.(type) {
|
||||
case ast.Var:
|
||||
return vis.b.namespaceVar(a, vis.caller)
|
||||
case ast.Array:
|
||||
cpy := *a
|
||||
arr := make(ast.Array, len(v))
|
||||
for i := 0; i < len(arr); i++ {
|
||||
arr[i] = vis.namespaceTerm(v[i])
|
||||
}
|
||||
cpy.Value = arr
|
||||
return &cpy
|
||||
case ast.Object:
|
||||
if a.IsGround() {
|
||||
return a
|
||||
}
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) {
|
||||
return vis.namespaceTerm(k), vis.namespaceTerm(v), nil
|
||||
})
|
||||
return &cpy
|
||||
case ast.Set:
|
||||
cpy := *a
|
||||
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
|
||||
return vis.namespaceTerm(x), nil
|
||||
})
|
||||
return &cpy
|
||||
case ast.Ref:
|
||||
cpy := *a
|
||||
ref := make(ast.Ref, len(v))
|
||||
for i := 0; i < len(ref); i++ {
|
||||
ref[i] = vis.namespaceTerm(v[i])
|
||||
}
|
||||
cpy.Value = ref
|
||||
return &cpy
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
const maxLinearScan = 16
|
||||
|
||||
// bindingsArrayHashMap uses an array with linear scan instead of a hash map for smaller # of entries. Hash maps start to show off their performance advantage only after 16 keys.
|
||||
type bindingsArrayHashmap struct {
|
||||
n int // Entries in the array.
|
||||
a *[maxLinearScan]bindingArrayKeyValue
|
||||
m map[ast.Var]bindingArrayKeyValue
|
||||
}
|
||||
|
||||
type bindingArrayKeyValue struct {
|
||||
key *ast.Term
|
||||
value value
|
||||
}
|
||||
|
||||
func newBindingsArrayHashmap() bindingsArrayHashmap {
|
||||
return bindingsArrayHashmap{}
|
||||
}
|
||||
|
||||
func (b *bindingsArrayHashmap) Put(key *ast.Term, value value) {
|
||||
if b.m == nil {
|
||||
if b.a == nil {
|
||||
b.a = new([maxLinearScan]bindingArrayKeyValue)
|
||||
} else if i := b.find(key); i >= 0 {
|
||||
(*b.a)[i].value = value
|
||||
return
|
||||
}
|
||||
|
||||
if b.n < maxLinearScan {
|
||||
(*b.a)[b.n] = bindingArrayKeyValue{key, value}
|
||||
b.n++
|
||||
return
|
||||
}
|
||||
|
||||
// Array is full, revert to using the hash map instead.
|
||||
|
||||
b.m = make(map[ast.Var]bindingArrayKeyValue, maxLinearScan+1)
|
||||
for _, kv := range *b.a {
|
||||
b.m[kv.key.Value.(ast.Var)] = bindingArrayKeyValue{kv.key, kv.value}
|
||||
}
|
||||
b.m[key.Value.(ast.Var)] = bindingArrayKeyValue{key, value}
|
||||
|
||||
b.n = 0
|
||||
return
|
||||
}
|
||||
|
||||
b.m[key.Value.(ast.Var)] = bindingArrayKeyValue{key, value}
|
||||
}
|
||||
|
||||
func (b *bindingsArrayHashmap) Get(key *ast.Term) (value, bool) {
|
||||
if b.m == nil {
|
||||
if i := b.find(key); i >= 0 {
|
||||
return (*b.a)[i].value, true
|
||||
}
|
||||
|
||||
return value{}, false
|
||||
}
|
||||
|
||||
v, ok := b.m[key.Value.(ast.Var)]
|
||||
if ok {
|
||||
return v.value, true
|
||||
}
|
||||
|
||||
return value{}, false
|
||||
}
|
||||
|
||||
func (b *bindingsArrayHashmap) Delete(key *ast.Term) {
|
||||
if b.m == nil {
|
||||
if i := b.find(key); i >= 0 {
|
||||
n := b.n - 1
|
||||
if i < n {
|
||||
(*b.a)[i] = (*b.a)[n]
|
||||
}
|
||||
|
||||
b.n = n
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
delete(b.m, key.Value.(ast.Var))
|
||||
}
|
||||
|
||||
func (b *bindingsArrayHashmap) Iter(f func(k *ast.Term, v value) bool) {
|
||||
if b.m == nil {
|
||||
for i := 0; i < b.n; i++ {
|
||||
if f((*b.a)[i].key, (*b.a)[i].value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range b.m {
|
||||
if f(v.key, v.value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bindingsArrayHashmap) find(key *ast.Term) int {
|
||||
v := key.Value.(ast.Var)
|
||||
for i := 0; i < b.n; i++ {
|
||||
if (*b.a)[i].key.Value.(ast.Var) == v {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
88
vendor/github.com/open-policy-agent/opa/topdown/bits.go
generated
vendored
Normal file
88
vendor/github.com/open-policy-agent/opa/topdown/bits.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type bitsArity1 func(a *big.Int) (*big.Int, error)
|
||||
type bitsArity2 func(a, b *big.Int) (*big.Int, error)
|
||||
|
||||
func bitsOr(a, b *big.Int) (*big.Int, error) {
|
||||
return new(big.Int).Or(a, b), nil
|
||||
}
|
||||
|
||||
func bitsAnd(a, b *big.Int) (*big.Int, error) {
|
||||
return new(big.Int).And(a, b), nil
|
||||
}
|
||||
|
||||
func bitsNegate(a *big.Int) (*big.Int, error) {
|
||||
return new(big.Int).Not(a), nil
|
||||
}
|
||||
|
||||
func bitsXOr(a, b *big.Int) (*big.Int, error) {
|
||||
return new(big.Int).Xor(a, b), nil
|
||||
}
|
||||
|
||||
func bitsShiftLeft(a, b *big.Int) (*big.Int, error) {
|
||||
if b.Sign() == -1 {
|
||||
return nil, builtins.NewOperandErr(2, "must be an unsigned integer number but got a negative integer")
|
||||
}
|
||||
shift := uint(b.Uint64())
|
||||
return new(big.Int).Lsh(a, shift), nil
|
||||
}
|
||||
|
||||
func bitsShiftRight(a, b *big.Int) (*big.Int, error) {
|
||||
if b.Sign() == -1 {
|
||||
return nil, builtins.NewOperandErr(2, "must be an unsigned integer number but got a negative integer")
|
||||
}
|
||||
shift := uint(b.Uint64())
|
||||
return new(big.Int).Rsh(a, shift), nil
|
||||
}
|
||||
|
||||
func builtinBitsArity1(fn bitsArity1) BuiltinFunc {
|
||||
return func(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
i, err := builtins.BigIntOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
iOut, err := fn(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return iter(ast.NewTerm(builtins.IntToNumber(iOut)))
|
||||
}
|
||||
}
|
||||
|
||||
func builtinBitsArity2(fn bitsArity2) BuiltinFunc {
|
||||
return func(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
i1, err := builtins.BigIntOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i2, err := builtins.BigIntOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
iOut, err := fn(i1, i2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return iter(ast.NewTerm(builtins.IntToNumber(iOut)))
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.BitsOr.Name, builtinBitsArity2(bitsOr))
|
||||
RegisterBuiltinFunc(ast.BitsAnd.Name, builtinBitsArity2(bitsAnd))
|
||||
RegisterBuiltinFunc(ast.BitsNegate.Name, builtinBitsArity1(bitsNegate))
|
||||
RegisterBuiltinFunc(ast.BitsXOr.Name, builtinBitsArity2(bitsXOr))
|
||||
RegisterBuiltinFunc(ast.BitsShiftLeft.Name, builtinBitsArity2(bitsShiftLeft))
|
||||
RegisterBuiltinFunc(ast.BitsShiftRight.Name, builtinBitsArity2(bitsShiftRight))
|
||||
}
|
||||
160
vendor/github.com/open-policy-agent/opa/topdown/builtins.go
generated
vendored
Normal file
160
vendor/github.com/open-policy-agent/opa/topdown/builtins.go
generated
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type (
|
||||
// FunctionalBuiltin1 is deprecated. Use BuiltinFunc instead.
|
||||
FunctionalBuiltin1 func(op1 ast.Value) (output ast.Value, err error)
|
||||
|
||||
// FunctionalBuiltin2 is deprecated. Use BuiltinFunc instead.
|
||||
FunctionalBuiltin2 func(op1, op2 ast.Value) (output ast.Value, err error)
|
||||
|
||||
// FunctionalBuiltin3 is deprecated. Use BuiltinFunc instead.
|
||||
FunctionalBuiltin3 func(op1, op2, op3 ast.Value) (output ast.Value, err error)
|
||||
|
||||
// FunctionalBuiltin4 is deprecated. Use BuiltinFunc instead.
|
||||
FunctionalBuiltin4 func(op1, op2, op3, op4 ast.Value) (output ast.Value, err error)
|
||||
|
||||
// BuiltinContext contains context from the evaluator that may be used by
|
||||
// built-in functions.
|
||||
BuiltinContext struct {
|
||||
Context context.Context // request context that was passed when query started
|
||||
Cancel Cancel // atomic value that signals evaluation to halt
|
||||
Runtime *ast.Term // runtime information on the OPA instance
|
||||
Cache builtins.Cache // built-in function state cache
|
||||
Location *ast.Location // location of built-in call
|
||||
Tracers []Tracer // tracer objects for trace() built-in function
|
||||
QueryID uint64 // identifies query being evaluated
|
||||
ParentID uint64 // identifies parent of query being evaluated
|
||||
}
|
||||
|
||||
// BuiltinFunc defines an interface for implementing built-in functions.
|
||||
// The built-in function is called with the plugged operands from the call
|
||||
// (including the output operands.) The implementation should evaluate the
|
||||
// operands and invoke the iterator for each successful/defined output
|
||||
// value.
|
||||
BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error
|
||||
)
|
||||
|
||||
// RegisterBuiltinFunc adds a new built-in function to the evaluation engine.
|
||||
func RegisterBuiltinFunc(name string, f BuiltinFunc) {
|
||||
builtinFunctions[name] = builtinErrorWrapper(name, f)
|
||||
}
|
||||
|
||||
// RegisterFunctionalBuiltin1 is deprecated use RegisterBuiltinFunc instead.
|
||||
func RegisterFunctionalBuiltin1(name string, fun FunctionalBuiltin1) {
|
||||
builtinFunctions[name] = functionalWrapper1(name, fun)
|
||||
}
|
||||
|
||||
// RegisterFunctionalBuiltin2 is deprecated use RegisterBuiltinFunc instead.
|
||||
func RegisterFunctionalBuiltin2(name string, fun FunctionalBuiltin2) {
|
||||
builtinFunctions[name] = functionalWrapper2(name, fun)
|
||||
}
|
||||
|
||||
// RegisterFunctionalBuiltin3 is deprecated use RegisterBuiltinFunc instead.
|
||||
func RegisterFunctionalBuiltin3(name string, fun FunctionalBuiltin3) {
|
||||
builtinFunctions[name] = functionalWrapper3(name, fun)
|
||||
}
|
||||
|
||||
// RegisterFunctionalBuiltin4 is deprecated use RegisterBuiltinFunc instead.
|
||||
func RegisterFunctionalBuiltin4(name string, fun FunctionalBuiltin4) {
|
||||
builtinFunctions[name] = functionalWrapper4(name, fun)
|
||||
}
|
||||
|
||||
// GetBuiltin returns a built-in function implementation, nil if no built-in found.
|
||||
func GetBuiltin(name string) BuiltinFunc {
|
||||
return builtinFunctions[name]
|
||||
}
|
||||
|
||||
// BuiltinEmpty is deprecated.
|
||||
type BuiltinEmpty struct{}
|
||||
|
||||
func (BuiltinEmpty) Error() string {
|
||||
return "<empty>"
|
||||
}
|
||||
|
||||
var builtinFunctions = map[string]BuiltinFunc{}
|
||||
|
||||
func builtinErrorWrapper(name string, fn BuiltinFunc) BuiltinFunc {
|
||||
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
err := fn(bctx, args, iter)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return handleBuiltinErr(name, bctx.Location, err)
|
||||
}
|
||||
}
|
||||
|
||||
func functionalWrapper1(name string, fn FunctionalBuiltin1) BuiltinFunc {
|
||||
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result, err := fn(args[0].Value)
|
||||
if err == nil {
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
return handleBuiltinErr(name, bctx.Location, err)
|
||||
}
|
||||
}
|
||||
|
||||
func functionalWrapper2(name string, fn FunctionalBuiltin2) BuiltinFunc {
|
||||
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result, err := fn(args[0].Value, args[1].Value)
|
||||
if err == nil {
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
return handleBuiltinErr(name, bctx.Location, err)
|
||||
}
|
||||
}
|
||||
|
||||
func functionalWrapper3(name string, fn FunctionalBuiltin3) BuiltinFunc {
|
||||
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result, err := fn(args[0].Value, args[1].Value, args[2].Value)
|
||||
if err == nil {
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
return handleBuiltinErr(name, bctx.Location, err)
|
||||
}
|
||||
}
|
||||
|
||||
func functionalWrapper4(name string, fn FunctionalBuiltin4) BuiltinFunc {
|
||||
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
result, err := fn(args[0].Value, args[1].Value, args[2].Value, args[3].Value)
|
||||
if err == nil {
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
if _, empty := err.(BuiltinEmpty); empty {
|
||||
return nil
|
||||
}
|
||||
return handleBuiltinErr(name, bctx.Location, err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleBuiltinErr(name string, loc *ast.Location, err error) error {
|
||||
switch err := err.(type) {
|
||||
case BuiltinEmpty:
|
||||
return nil
|
||||
case *Error:
|
||||
return err
|
||||
case builtins.ErrOperand:
|
||||
return &Error{
|
||||
Code: TypeErr,
|
||||
Message: fmt.Sprintf("%v: %v", string(name), err.Error()),
|
||||
Location: loc,
|
||||
}
|
||||
default:
|
||||
return &Error{
|
||||
Code: BuiltinErr,
|
||||
Message: fmt.Sprintf("%v: %v", string(name), err.Error()),
|
||||
Location: loc,
|
||||
}
|
||||
}
|
||||
}
|
||||
235
vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go
generated
vendored
Normal file
235
vendor/github.com/open-policy-agent/opa/topdown/builtins/builtins.go
generated
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package builtins contains utilities for implementing built-in functions.
|
||||
package builtins
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// Cache defines the built-in cache used by the top-down evaluation. The keys
|
||||
// must be comparable and should not be of type string.
|
||||
type Cache map[interface{}]interface{}
|
||||
|
||||
// Put updates the cache for the named built-in.
|
||||
func (c Cache) Put(k, v interface{}) {
|
||||
c[k] = v
|
||||
}
|
||||
|
||||
// Get returns the cached value for k.
|
||||
func (c Cache) Get(k interface{}) (interface{}, bool) {
|
||||
v, ok := c[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// ErrOperand represents an invalid operand has been passed to a built-in
|
||||
// function. Built-ins should return ErrOperand to indicate a type error has
|
||||
// occurred.
|
||||
type ErrOperand string
|
||||
|
||||
func (err ErrOperand) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
// NewOperandErr returns a generic operand error.
|
||||
func NewOperandErr(pos int, f string, a ...interface{}) error {
|
||||
f = fmt.Sprintf("operand %v ", pos) + f
|
||||
return ErrOperand(fmt.Sprintf(f, a...))
|
||||
}
|
||||
|
||||
// NewOperandTypeErr returns an operand error indicating the operand's type was wrong.
|
||||
func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error {
|
||||
|
||||
if len(expected) == 1 {
|
||||
return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got))
|
||||
}
|
||||
|
||||
return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got))
|
||||
}
|
||||
|
||||
// NewOperandElementErr returns an operand error indicating an element in the
|
||||
// composite operand was wrong.
|
||||
func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error {
|
||||
|
||||
tpe := ast.TypeName(composite)
|
||||
|
||||
if len(expected) == 1 {
|
||||
return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got))
|
||||
}
|
||||
|
||||
return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got))
|
||||
}
|
||||
|
||||
// NewOperandEnumErr returns an operand error indicating a value was wrong.
|
||||
func NewOperandEnumErr(pos int, expected ...string) error {
|
||||
|
||||
if len(expected) == 1 {
|
||||
return NewOperandErr(pos, "must be %v", expected[0])
|
||||
}
|
||||
|
||||
return NewOperandErr(pos, "must be one of {%v}", strings.Join(expected, ", "))
|
||||
}
|
||||
|
||||
// IntOperand converts x to an int. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func IntOperand(x ast.Value, pos int) (int, error) {
|
||||
n, ok := x.(ast.Number)
|
||||
if !ok {
|
||||
return 0, NewOperandTypeErr(pos, x, "number")
|
||||
}
|
||||
|
||||
i, ok := n.Int()
|
||||
if !ok {
|
||||
return 0, NewOperandErr(pos, "must be integer number but got floating-point number")
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// BigIntOperand converts x to a big int. If the cast fails, a descriptive error
|
||||
// is returned.
|
||||
func BigIntOperand(x ast.Value, pos int) (*big.Int, error) {
|
||||
n, err := NumberOperand(x, 1)
|
||||
if err != nil {
|
||||
return nil, NewOperandTypeErr(pos, x, "integer")
|
||||
}
|
||||
bi, err := NumberToInt(n)
|
||||
if err != nil {
|
||||
return nil, NewOperandErr(pos, "must be integer number but got floating-point number")
|
||||
}
|
||||
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
// NumberOperand converts x to a number. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func NumberOperand(x ast.Value, pos int) (ast.Number, error) {
|
||||
n, ok := x.(ast.Number)
|
||||
if !ok {
|
||||
return ast.Number(""), NewOperandTypeErr(pos, x, "number")
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SetOperand converts x to a set. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func SetOperand(x ast.Value, pos int) (ast.Set, error) {
|
||||
s, ok := x.(ast.Set)
|
||||
if !ok {
|
||||
return nil, NewOperandTypeErr(pos, x, "set")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// StringOperand converts x to a string. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func StringOperand(x ast.Value, pos int) (ast.String, error) {
|
||||
s, ok := x.(ast.String)
|
||||
if !ok {
|
||||
return ast.String(""), NewOperandTypeErr(pos, x, "string")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// ObjectOperand converts x to an object. If the cast fails, a descriptive
|
||||
// error is returned.
|
||||
func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {
|
||||
o, ok := x.(ast.Object)
|
||||
if !ok {
|
||||
return nil, NewOperandTypeErr(pos, x, "object")
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// ArrayOperand converts x to an array. If the cast fails, a descriptive
|
||||
// error is returned.
|
||||
func ArrayOperand(x ast.Value, pos int) (ast.Array, error) {
|
||||
a, ok := x.(ast.Array)
|
||||
if !ok {
|
||||
return nil, NewOperandTypeErr(pos, x, "array")
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// NumberToFloat converts n to a big float.
|
||||
func NumberToFloat(n ast.Number) *big.Float {
|
||||
r, ok := new(big.Float).SetString(string(n))
|
||||
if !ok {
|
||||
panic("illegal value")
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FloatToNumber converts f to a number.
|
||||
func FloatToNumber(f *big.Float) ast.Number {
|
||||
return ast.Number(f.String())
|
||||
}
|
||||
|
||||
// NumberToInt converts n to a big int.
|
||||
// If n cannot be converted to an big int, an error is returned.
|
||||
func NumberToInt(n ast.Number) (*big.Int, error) {
|
||||
f := NumberToFloat(n)
|
||||
r, accuracy := f.Int(nil)
|
||||
if accuracy != big.Exact {
|
||||
return nil, fmt.Errorf("illegal value")
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// IntToNumber converts i to a number.
|
||||
func IntToNumber(i *big.Int) ast.Number {
|
||||
return ast.Number(i.String())
|
||||
}
|
||||
|
||||
// StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func StringSliceOperand(x ast.Value, pos int) ([]string, error) {
|
||||
a, err := ArrayOperand(x, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f = make([]string, len(a))
|
||||
for k, b := range a {
|
||||
c, ok := b.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, NewOperandElementErr(pos, x, b.Value, "[]string")
|
||||
}
|
||||
|
||||
f[k] = string(c)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is
|
||||
// returned.
|
||||
func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) {
|
||||
a, err := ArrayOperand(x, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var f = make([]rune, len(a))
|
||||
for k, b := range a {
|
||||
c, ok := b.Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, NewOperandElementErr(pos, x, b.Value, "string")
|
||||
}
|
||||
|
||||
d := []rune(string(c))
|
||||
if len(d) != 1 {
|
||||
return nil, NewOperandElementErr(pos, x, b.Value, "rune")
|
||||
}
|
||||
|
||||
f[k] = d[0]
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
166
vendor/github.com/open-policy-agent/opa/topdown/cache.go
generated
vendored
Normal file
166
vendor/github.com/open-policy-agent/opa/topdown/cache.go
generated
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
type virtualCache struct {
|
||||
stack []*virtualCacheElem
|
||||
}
|
||||
|
||||
type virtualCacheElem struct {
|
||||
value *ast.Term
|
||||
children *util.HashMap
|
||||
}
|
||||
|
||||
func newVirtualCache() *virtualCache {
|
||||
cache := &virtualCache{}
|
||||
cache.Push()
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *virtualCache) Push() {
|
||||
c.stack = append(c.stack, newVirtualCacheElem())
|
||||
}
|
||||
|
||||
func (c *virtualCache) Pop() {
|
||||
c.stack = c.stack[:len(c.stack)-1]
|
||||
}
|
||||
|
||||
func (c *virtualCache) Get(ref ast.Ref) *ast.Term {
|
||||
node := c.stack[len(c.stack)-1]
|
||||
for i := 0; i < len(ref); i++ {
|
||||
x, ok := node.children.Get(ref[i])
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
node = x.(*virtualCacheElem)
|
||||
}
|
||||
return node.value
|
||||
}
|
||||
|
||||
func (c *virtualCache) Put(ref ast.Ref, value *ast.Term) {
|
||||
node := c.stack[len(c.stack)-1]
|
||||
for i := 0; i < len(ref); i++ {
|
||||
x, ok := node.children.Get(ref[i])
|
||||
if ok {
|
||||
node = x.(*virtualCacheElem)
|
||||
} else {
|
||||
next := newVirtualCacheElem()
|
||||
node.children.Put(ref[i], next)
|
||||
node = next
|
||||
}
|
||||
}
|
||||
node.value = value
|
||||
}
|
||||
|
||||
func newVirtualCacheElem() *virtualCacheElem {
|
||||
return &virtualCacheElem{children: newVirtualCacheHashMap()}
|
||||
}
|
||||
|
||||
func newVirtualCacheHashMap() *util.HashMap {
|
||||
return util.NewHashMap(func(a, b util.T) bool {
|
||||
return a.(*ast.Term).Equal(b.(*ast.Term))
|
||||
}, func(x util.T) int {
|
||||
return x.(*ast.Term).Hash()
|
||||
})
|
||||
}
|
||||
|
||||
// baseCache implements a trie structure to cache base documents read out of
|
||||
// storage. Values inserted into the cache may contain other values that were
|
||||
// previously inserted. In this case, the previous values are erased from the
|
||||
// structure.
|
||||
type baseCache struct {
|
||||
root *baseCacheElem
|
||||
}
|
||||
|
||||
func newBaseCache() *baseCache {
|
||||
return &baseCache{
|
||||
root: newBaseCacheElem(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *baseCache) Get(ref ast.Ref) ast.Value {
|
||||
node := c.root
|
||||
for i := 0; i < len(ref); i++ {
|
||||
node = node.children[ref[i].Value]
|
||||
if node == nil {
|
||||
return nil
|
||||
} else if node.value != nil {
|
||||
result, err := node.value.Find(ref[i+1:])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *baseCache) Put(ref ast.Ref, value ast.Value) {
|
||||
node := c.root
|
||||
for i := 0; i < len(ref); i++ {
|
||||
if child, ok := node.children[ref[i].Value]; ok {
|
||||
node = child
|
||||
} else {
|
||||
child := newBaseCacheElem()
|
||||
node.children[ref[i].Value] = child
|
||||
node = child
|
||||
}
|
||||
}
|
||||
node.set(value)
|
||||
}
|
||||
|
||||
type baseCacheElem struct {
|
||||
value ast.Value
|
||||
children map[ast.Value]*baseCacheElem
|
||||
}
|
||||
|
||||
func newBaseCacheElem() *baseCacheElem {
|
||||
return &baseCacheElem{
|
||||
children: map[ast.Value]*baseCacheElem{},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *baseCacheElem) set(value ast.Value) {
|
||||
e.value = value
|
||||
e.children = map[ast.Value]*baseCacheElem{}
|
||||
}
|
||||
|
||||
type refStack struct {
|
||||
sl []refStackElem
|
||||
}
|
||||
|
||||
type refStackElem struct {
|
||||
refs []ast.Ref
|
||||
}
|
||||
|
||||
func newRefStack() *refStack {
|
||||
return &refStack{}
|
||||
}
|
||||
|
||||
func (s *refStack) Push(refs []ast.Ref) {
|
||||
s.sl = append(s.sl, refStackElem{refs: refs})
|
||||
}
|
||||
|
||||
func (s *refStack) Pop() {
|
||||
s.sl = s.sl[:len(s.sl)-1]
|
||||
}
|
||||
|
||||
func (s *refStack) Prefixed(ref ast.Ref) bool {
|
||||
if s != nil {
|
||||
for i := len(s.sl) - 1; i >= 0; i-- {
|
||||
for j := range s.sl[i].refs {
|
||||
if ref.HasPrefix(s.sl[i].refs[j]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
33
vendor/github.com/open-policy-agent/opa/topdown/cancel.go
generated
vendored
Normal file
33
vendor/github.com/open-policy-agent/opa/topdown/cancel.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Cancel defines the interface for cancelling topdown queries. Cancel
|
||||
// operations are thread-safe and idempotent.
|
||||
type Cancel interface {
|
||||
Cancel()
|
||||
Cancelled() bool
|
||||
}
|
||||
|
||||
type cancel struct {
|
||||
flag int32
|
||||
}
|
||||
|
||||
// NewCancel returns a new Cancel object.
|
||||
func NewCancel() Cancel {
|
||||
return &cancel{}
|
||||
}
|
||||
|
||||
func (c *cancel) Cancel() {
|
||||
atomic.StoreInt32(&c.flag, 1)
|
||||
}
|
||||
|
||||
func (c *cancel) Cancelled() bool {
|
||||
return atomic.LoadInt32(&c.flag) != 0
|
||||
}
|
||||
113
vendor/github.com/open-policy-agent/opa/topdown/casts.go
generated
vendored
Normal file
113
vendor/github.com/open-policy-agent/opa/topdown/casts.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinToNumber(a ast.Value) (ast.Value, error) {
|
||||
switch a := a.(type) {
|
||||
case ast.Null:
|
||||
return ast.Number("0"), nil
|
||||
case ast.Boolean:
|
||||
if a {
|
||||
return ast.Number("1"), nil
|
||||
}
|
||||
return ast.Number("0"), nil
|
||||
case ast.Number:
|
||||
return a, nil
|
||||
case ast.String:
|
||||
_, err := strconv.ParseFloat(string(a), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.Number(a), nil
|
||||
}
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "null", "boolean", "number", "string")
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToArray(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Array:
|
||||
return val, nil
|
||||
case ast.Set:
|
||||
arr := make(ast.Array, val.Len())
|
||||
i := 0
|
||||
val.Foreach(func(term *ast.Term) {
|
||||
arr[i] = term
|
||||
i++
|
||||
})
|
||||
return arr, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToSet(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Array:
|
||||
return ast.NewSet(val...), nil
|
||||
case ast.Set:
|
||||
return val, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToString(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.String:
|
||||
return val, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "string")
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToBoolean(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Boolean:
|
||||
return val, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "boolean")
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToNull(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Null:
|
||||
return val, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "null")
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated in v0.13.0.
|
||||
func builtinToObject(a ast.Value) (ast.Value, error) {
|
||||
switch val := a.(type) {
|
||||
case ast.Object:
|
||||
return val, nil
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "object")
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.ToNumber.Name, builtinToNumber)
|
||||
RegisterFunctionalBuiltin1(ast.CastArray.Name, builtinToArray)
|
||||
RegisterFunctionalBuiltin1(ast.CastSet.Name, builtinToSet)
|
||||
RegisterFunctionalBuiltin1(ast.CastString.Name, builtinToString)
|
||||
RegisterFunctionalBuiltin1(ast.CastBoolean.Name, builtinToBoolean)
|
||||
RegisterFunctionalBuiltin1(ast.CastNull.Name, builtinToNull)
|
||||
RegisterFunctionalBuiltin1(ast.CastObject.Name, builtinToObject)
|
||||
}
|
||||
157
vendor/github.com/open-policy-agent/opa/topdown/cidr.go
generated
vendored
Normal file
157
vendor/github.com/open-policy-agent/opa/topdown/cidr.go
generated
vendored
Normal file
@@ -0,0 +1,157 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func getNetFromOperand(v ast.Value) (*net.IPNet, error) {
|
||||
subnetStringA, err := builtins.StringOperand(v, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, cidrnet, err := net.ParseCIDR(string(subnetStringA))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cidrnet, nil
|
||||
}
|
||||
|
||||
func getLastIP(cidr *net.IPNet) (net.IP, error) {
|
||||
prefixLen, bits := cidr.Mask.Size()
|
||||
if prefixLen == 0 && bits == 0 {
|
||||
// non-standard mask, see https://golang.org/pkg/net/#IPMask.Size
|
||||
return nil, fmt.Errorf("CIDR mask is in non-standard format")
|
||||
}
|
||||
var lastIP []byte
|
||||
if prefixLen == bits {
|
||||
// Special case for single ip address ranges ex: 192.168.1.1/32
|
||||
// We can just use the starting IP as the last IP
|
||||
lastIP = cidr.IP
|
||||
} else {
|
||||
// Use big.Int's so we can handle ipv6 addresses
|
||||
firstIPInt := new(big.Int)
|
||||
firstIPInt.SetBytes(cidr.IP)
|
||||
hostLen := uint(bits) - uint(prefixLen)
|
||||
lastIPInt := big.NewInt(1)
|
||||
lastIPInt.Lsh(lastIPInt, hostLen)
|
||||
lastIPInt.Sub(lastIPInt, big.NewInt(1))
|
||||
lastIPInt.Or(lastIPInt, firstIPInt)
|
||||
|
||||
ipBytes := lastIPInt.Bytes()
|
||||
lastIP = make([]byte, bits/8)
|
||||
|
||||
// Pack our IP bytes into the end of the return array,
|
||||
// since big.Int.Bytes() removes front zero padding.
|
||||
for i := 1; i <= len(lastIPInt.Bytes()); i++ {
|
||||
lastIP[len(lastIP)-i] = ipBytes[len(ipBytes)-i]
|
||||
}
|
||||
}
|
||||
|
||||
return lastIP, nil
|
||||
}
|
||||
|
||||
func builtinNetCIDRIntersects(a, b ast.Value) (ast.Value, error) {
|
||||
cidrnetA, err := getNetFromOperand(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cidrnetB, err := getNetFromOperand(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If either net contains the others starting IP they are overlapping
|
||||
cidrsOverlap := (cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP))
|
||||
|
||||
return ast.Boolean(cidrsOverlap), nil
|
||||
}
|
||||
|
||||
func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
|
||||
cidrnetA, err := getNetFromOperand(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// b could be either an IP addressor CIDR string, try to parse it as an IP first, fall back to CIDR
|
||||
bStr, err := builtins.StringOperand(b, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip := net.ParseIP(string(bStr))
|
||||
if ip != nil {
|
||||
return ast.Boolean(cidrnetA.Contains(ip)), nil
|
||||
}
|
||||
|
||||
// It wasn't an IP, try and parse it as a CIDR
|
||||
cidrnetB, err := getNetFromOperand(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("not a valid textual representation of an IP address or CIDR: %s", string(bStr))
|
||||
}
|
||||
|
||||
// We can determine if cidr A contains cidr B iff A contains the starting address of B and the last address in B.
|
||||
cidrContained := false
|
||||
if cidrnetA.Contains(cidrnetB.IP) {
|
||||
// Only spend time calculating the last IP if the starting IP is already verified to be in cidr A
|
||||
lastIP, err := getLastIP(cidrnetB)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cidrContained = cidrnetA.Contains(lastIP)
|
||||
}
|
||||
|
||||
return ast.Boolean(cidrContained), nil
|
||||
}
|
||||
|
||||
func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
s, err := builtins.StringOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ip, ipNet, err := net.ParseCIDR(string(s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := ast.NewSet()
|
||||
|
||||
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
|
||||
|
||||
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
|
||||
return &Error{
|
||||
Code: CancelErr,
|
||||
Message: "net.cidr_expand: timed out before generating all IP addresses",
|
||||
}
|
||||
}
|
||||
|
||||
result.Add(ast.StringTerm(ip.String()))
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(result))
|
||||
}
|
||||
|
||||
func incIP(ip net.IP) {
|
||||
for j := len(ip) - 1; j >= 0; j-- {
|
||||
ip[j]++
|
||||
if ip[j] > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDROverlap.Name, builtinNetCIDRContains)
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDRIntersects.Name, builtinNetCIDRIntersects)
|
||||
RegisterFunctionalBuiltin2(ast.NetCIDRContains.Name, builtinNetCIDRContains)
|
||||
RegisterBuiltinFunc(ast.NetCIDRExpand.Name, builtinNetCIDRExpand)
|
||||
}
|
||||
48
vendor/github.com/open-policy-agent/opa/topdown/comparison.go
generated
vendored
Normal file
48
vendor/github.com/open-policy-agent/opa/topdown/comparison.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import "github.com/open-policy-agent/opa/ast"
|
||||
|
||||
type compareFunc func(a, b ast.Value) bool
|
||||
|
||||
func compareGreaterThan(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) > 0
|
||||
}
|
||||
|
||||
func compareGreaterThanEq(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) >= 0
|
||||
}
|
||||
|
||||
func compareLessThan(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) < 0
|
||||
}
|
||||
|
||||
func compareLessThanEq(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) <= 0
|
||||
}
|
||||
|
||||
func compareNotEq(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) != 0
|
||||
}
|
||||
|
||||
func compareEq(a, b ast.Value) bool {
|
||||
return ast.Compare(a, b) == 0
|
||||
}
|
||||
|
||||
func builtinCompare(cmp compareFunc) FunctionalBuiltin2 {
|
||||
return func(a, b ast.Value) (ast.Value, error) {
|
||||
return ast.Boolean(cmp(a, b)), nil
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.GreaterThan.Name, builtinCompare(compareGreaterThan))
|
||||
RegisterFunctionalBuiltin2(ast.GreaterThanEq.Name, builtinCompare(compareGreaterThanEq))
|
||||
RegisterFunctionalBuiltin2(ast.LessThan.Name, builtinCompare(compareLessThan))
|
||||
RegisterFunctionalBuiltin2(ast.LessThanEq.Name, builtinCompare(compareLessThanEq))
|
||||
RegisterFunctionalBuiltin2(ast.NotEqual.Name, builtinCompare(compareNotEq))
|
||||
RegisterFunctionalBuiltin2(ast.Equal.Name, builtinCompare(compareEq))
|
||||
}
|
||||
484
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go
generated
vendored
Normal file
484
vendor/github.com/open-policy-agent/opa/topdown/copypropagation/copypropagation.go
generated
vendored
Normal file
@@ -0,0 +1,484 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package copypropagation
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// CopyPropagator implements a simple copy propagation optimization to remove
|
||||
// intermediate variables in partial evaluation results.
|
||||
//
|
||||
// For example, given the query: input.x > 1 where 'input' is unknown, the
|
||||
// compiled query would become input.x = a; a > 1 which would remain in the
|
||||
// partial evaluation result. The CopyPropagator will remove the variable
|
||||
// assignment so that partial evaluation simply outputs input.x > 1.
|
||||
//
|
||||
// In many cases, copy propagation can remove all variables from the result of
|
||||
// partial evaluation which simplifies evaluation for non-OPA consumers.
|
||||
//
|
||||
// In some cases, copy propagation cannot remove all variables. If the output of
|
||||
// a built-in call is subsequently used as a ref head, the output variable must
|
||||
// be kept. For example. sort(input, x); x[0] == 1. In this case, copy
|
||||
// propagation cannot replace x[0] == 1 with sort(input, x)[0] == 1 as this is
|
||||
// not legal.
|
||||
type CopyPropagator struct {
|
||||
livevars ast.VarSet // vars that must be preserved in the resulting query
|
||||
sorted []ast.Var // sorted copy of vars to ensure deterministic result
|
||||
ensureNonEmptyBody bool
|
||||
}
|
||||
|
||||
// New returns a new CopyPropagator that optimizes queries while preserving vars
|
||||
// in the livevars set.
|
||||
func New(livevars ast.VarSet) *CopyPropagator {
|
||||
|
||||
sorted := make([]ast.Var, 0, len(livevars))
|
||||
for v := range livevars {
|
||||
sorted = append(sorted, v)
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Compare(sorted[j]) < 0
|
||||
})
|
||||
|
||||
return &CopyPropagator{livevars: livevars, sorted: sorted}
|
||||
}
|
||||
|
||||
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
|
||||
func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
|
||||
p.ensureNonEmptyBody = yes
|
||||
return p
|
||||
}
|
||||
|
||||
// Apply executes the copy propagation optimization and returns a new query.
|
||||
func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
|
||||
|
||||
uf, ok := makeDisjointSets(p.livevars, query)
|
||||
if !ok {
|
||||
return query
|
||||
}
|
||||
|
||||
// Compute set of vars that appear in the head of refs in the query. If a var
|
||||
// is dereferenced, we cannot plug it with a constant value so the constant on
|
||||
// the union-find root must be unset (e.g., [1][0] is not legal.)
|
||||
headvars := ast.NewVarSet()
|
||||
ast.WalkRefs(query, func(x ast.Ref) bool {
|
||||
if v, ok := x[0].Value.(ast.Var); ok {
|
||||
if root, ok := uf.Find(v); ok {
|
||||
root.constant = nil
|
||||
headvars.Add(root.key)
|
||||
} else {
|
||||
headvars.Add(v)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
bindings := map[ast.Var]*binding{}
|
||||
|
||||
for _, expr := range query {
|
||||
|
||||
pctx := &plugContext{
|
||||
bindings: bindings,
|
||||
uf: uf,
|
||||
negated: expr.Negated,
|
||||
headvars: headvars,
|
||||
}
|
||||
|
||||
if expr, keep := p.plugBindings(pctx, expr); keep {
|
||||
if p.updateBindings(pctx, expr) {
|
||||
result.Append(expr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run post-processing step on the query to ensure that all live vars are bound
|
||||
// in the result. The plugging that happens above substitutes all vars in the
|
||||
// same set with the root.
|
||||
//
|
||||
// This step should run before the next step to prevent unnecessary bindings
|
||||
// from being added to the result. For example:
|
||||
//
|
||||
// - Given the following result: <empty>
|
||||
// - Given the following bindings: x/input.x and y/input
|
||||
// - Given the following liveset: {x}
|
||||
//
|
||||
// If this step were to run AFTER the following step, the output would be:
|
||||
//
|
||||
// x = input.x; y = input
|
||||
//
|
||||
// Even though y = input is not required.
|
||||
for _, v := range p.sorted {
|
||||
if root, ok := uf.Find(v); ok {
|
||||
if root.constant != nil {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
|
||||
} else if b, ok := bindings[root.key]; ok {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b.v)))
|
||||
} else if root.key != v {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run post-processing step on query to ensure that all killed exprs are
|
||||
// accounted for. If an expr is killed but the binding is never used, the query
|
||||
// must still include the expr. For example, given the query 'input.x = a' and
|
||||
// an empty livevar set, the result must include the ref input.x otherwise the
|
||||
// query could be satisfied without input.x being defined. When exprs are
|
||||
// killed we initialize the binding counter to zero and then increment it each
|
||||
// time the binding is substituted. if the binding was never substituted it
|
||||
// means the binding value must be added back into the query.
|
||||
for _, b := range sortbindings(bindings) {
|
||||
if !b.containedIn(result) {
|
||||
result.Append(ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v)))
|
||||
}
|
||||
}
|
||||
|
||||
if p.ensureNonEmptyBody && len(result) == 0 {
|
||||
result = append(result, ast.NewExpr(ast.BooleanTerm(true)))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// plugBindings applies the binding list and union-find to x. This process
|
||||
// removes as many variables as possible.
|
||||
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.Expr, bool) {
|
||||
|
||||
// Kill single term expressions that are in the binding list. They will be
|
||||
// re-added during post-processing if needed.
|
||||
if term, ok := expr.Terms.(*ast.Term); ok {
|
||||
if v, ok := term.Value.(ast.Var); ok {
|
||||
if root, ok := pctx.uf.Find(v); ok {
|
||||
if _, ok := pctx.bindings[root.key]; ok {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
xform := bindingPlugTransform{
|
||||
pctx: pctx,
|
||||
}
|
||||
|
||||
// Deep copy the expression as it may be mutated during the transform and
|
||||
// the caller running copy propagation may have references to the
|
||||
// expression. Note, the transform does not contain any error paths and
|
||||
// should never return a non-expression value for the root so consider
|
||||
// errors unreachable.
|
||||
x, err := ast.Transform(xform, expr.Copy())
|
||||
|
||||
if expr, ok := x.(*ast.Expr); !ok || err != nil {
|
||||
panic("unreachable")
|
||||
} else {
|
||||
return expr, true
|
||||
}
|
||||
}
|
||||
|
||||
type bindingPlugTransform struct {
|
||||
pctx *plugContext
|
||||
}
|
||||
|
||||
func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
|
||||
switch x := x.(type) {
|
||||
case ast.Var:
|
||||
return t.plugBindingsVar(t.pctx, x), nil
|
||||
case ast.Ref:
|
||||
return t.plugBindingsRef(t.pctx, x), nil
|
||||
default:
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (result ast.Value) {
|
||||
|
||||
result = v
|
||||
|
||||
// Apply union-find to remove redundant variables from input.
|
||||
if root, ok := pctx.uf.Find(v); ok {
|
||||
result = root.Value()
|
||||
}
|
||||
|
||||
// Apply binding list to substitute remaining vars.
|
||||
if v, ok := result.(ast.Var); ok {
|
||||
if b, ok := pctx.bindings[v]; ok {
|
||||
if !pctx.negated || b.v.IsGround() {
|
||||
result = b.v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
|
||||
|
||||
// Apply union-find to remove redundant variables from input.
|
||||
if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok {
|
||||
v[0].Value = root.Value()
|
||||
}
|
||||
|
||||
result := v
|
||||
|
||||
// Refs require special handling. If the head of the ref was killed, then
|
||||
// the rest of the ref must be concatenated with the new base.
|
||||
//
|
||||
// Invariant: ref heads can only be replaced by refs (not calls).
|
||||
if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok {
|
||||
if !pctx.negated || b.v.IsGround() {
|
||||
result = b.v.(ast.Ref).Concat(v[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// updateBindings returns false if the expression can be killed. If the
|
||||
// expression is killed, the binding list is updated to map a var to value.
|
||||
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
|
||||
if pctx.negated || len(expr.With) > 0 {
|
||||
return true
|
||||
}
|
||||
if expr.IsEquality() {
|
||||
a, b := expr.Operand(0), expr.Operand(1)
|
||||
if a.Equal(b) {
|
||||
return false
|
||||
}
|
||||
k, v, keep := p.updateBindingsEq(a, b)
|
||||
if !keep {
|
||||
if v != nil {
|
||||
pctx.bindings[k] = newbinding(k, v)
|
||||
}
|
||||
return false
|
||||
}
|
||||
} else if expr.IsCall() {
|
||||
terms := expr.Terms.([]*ast.Term)
|
||||
output := terms[len(terms)-1]
|
||||
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
|
||||
pctx.bindings[k] = newbinding(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
|
||||
return false
|
||||
}
|
||||
}
|
||||
return !isNoop(expr)
|
||||
}
|
||||
|
||||
func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
|
||||
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
|
||||
if !keep {
|
||||
return k, v, keep
|
||||
}
|
||||
return p.updateBindingsEqAsymmetric(b, a)
|
||||
}
|
||||
|
||||
func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, ast.Value, bool) {
|
||||
k, ok := a.Value.(ast.Var)
|
||||
if !ok || p.livevars.Contains(k) {
|
||||
return "", nil, true
|
||||
}
|
||||
|
||||
switch b.Value.(type) {
|
||||
case ast.Ref, ast.Call:
|
||||
return k, b.Value, false
|
||||
}
|
||||
|
||||
return "", nil, true
|
||||
}
|
||||
|
||||
type plugContext struct {
|
||||
bindings map[ast.Var]*binding
|
||||
uf *unionFind
|
||||
headvars ast.VarSet
|
||||
negated bool
|
||||
}
|
||||
|
||||
type binding struct {
|
||||
k ast.Var
|
||||
v ast.Value
|
||||
}
|
||||
|
||||
func newbinding(k ast.Var, v ast.Value) *binding {
|
||||
return &binding{k: k, v: v}
|
||||
}
|
||||
|
||||
func (b *binding) containedIn(query ast.Body) bool {
|
||||
var stop bool
|
||||
switch v := b.v.(type) {
|
||||
case ast.Ref:
|
||||
ast.WalkRefs(query, func(other ast.Ref) bool {
|
||||
if stop || other.HasPrefix(v) {
|
||||
stop = true
|
||||
return stop
|
||||
}
|
||||
return false
|
||||
})
|
||||
default:
|
||||
ast.WalkTerms(query, func(other *ast.Term) bool {
|
||||
if stop || other.Value.Compare(v) == 0 {
|
||||
stop = true
|
||||
return stop
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
return stop
|
||||
}
|
||||
|
||||
func sortbindings(bindings map[ast.Var]*binding) []*binding {
|
||||
sorted := make([]*binding, 0, len(bindings))
|
||||
for _, b := range bindings {
|
||||
sorted = append(sorted, b)
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].k.Compare(sorted[j].k) < 0
|
||||
})
|
||||
return sorted
|
||||
}
|
||||
|
||||
type unionFind struct {
|
||||
roots map[ast.Var]*unionFindRoot
|
||||
parents map[ast.Var]ast.Var
|
||||
rank rankFunc
|
||||
}
|
||||
|
||||
// makeDisjointSets builds the union-find structure for the query. The structure
|
||||
// is built by processing all of the equality exprs in the query. Sets represent
|
||||
// vars that must be equal to each other. In addition to vars, each set can have
|
||||
// at most one constant. If the query contains expressions that cannot be
|
||||
// satisfied (e.g., because a set has multiple constants) this function returns
|
||||
// false.
|
||||
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
|
||||
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
|
||||
if livevars.Contains(r1.key) {
|
||||
return r1, r2
|
||||
}
|
||||
return r2, r1
|
||||
})
|
||||
for _, expr := range query {
|
||||
if expr.IsEquality() && !expr.Negated && len(expr.With) == 0 {
|
||||
a, b := expr.Operand(0), expr.Operand(1)
|
||||
varA, ok1 := a.Value.(ast.Var)
|
||||
varB, ok2 := b.Value.(ast.Var)
|
||||
if ok1 && ok2 {
|
||||
if _, ok := uf.Merge(varA, varB); !ok {
|
||||
return nil, false
|
||||
}
|
||||
} else if ok1 && ast.IsConstant(b.Value) {
|
||||
root := uf.MakeSet(varA)
|
||||
if root.constant != nil && !root.constant.Equal(b) {
|
||||
return nil, false
|
||||
}
|
||||
root.constant = b
|
||||
} else if ok2 && ast.IsConstant(a.Value) {
|
||||
root := uf.MakeSet(varB)
|
||||
if root.constant != nil && !root.constant.Equal(a) {
|
||||
return nil, false
|
||||
}
|
||||
root.constant = a
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return uf, true
|
||||
}
|
||||
|
||||
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
|
||||
|
||||
func newUnionFind(rank rankFunc) *unionFind {
|
||||
return &unionFind{
|
||||
roots: map[ast.Var]*unionFindRoot{},
|
||||
parents: map[ast.Var]ast.Var{},
|
||||
rank: rank,
|
||||
}
|
||||
}
|
||||
|
||||
func (uf *unionFind) MakeSet(v ast.Var) *unionFindRoot {
|
||||
|
||||
root, ok := uf.Find(v)
|
||||
if ok {
|
||||
return root
|
||||
}
|
||||
|
||||
root = newUnionFindRoot(v)
|
||||
uf.parents[v] = v
|
||||
uf.roots[v] = root
|
||||
return uf.roots[v]
|
||||
}
|
||||
|
||||
func (uf *unionFind) Find(v ast.Var) (*unionFindRoot, bool) {
|
||||
|
||||
parent, ok := uf.parents[v]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if parent == v {
|
||||
return uf.roots[v], true
|
||||
}
|
||||
|
||||
return uf.Find(parent)
|
||||
}
|
||||
|
||||
func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
|
||||
|
||||
r1 := uf.MakeSet(a)
|
||||
r2 := uf.MakeSet(b)
|
||||
|
||||
if r1 != r2 {
|
||||
|
||||
r1, r2 = uf.rank(r1, r2)
|
||||
|
||||
uf.parents[r2.key] = r1.key
|
||||
delete(uf.roots, r2.key)
|
||||
|
||||
// Sets can have at most one constant value associated with them. When
|
||||
// unioning, we must preserve this invariant. If a set has two constants,
|
||||
// there will be no way to prove the query.
|
||||
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
|
||||
return nil, false
|
||||
} else if r1.constant == nil {
|
||||
r1.constant = r2.constant
|
||||
}
|
||||
}
|
||||
|
||||
return r1, true
|
||||
}
|
||||
|
||||
type unionFindRoot struct {
|
||||
key ast.Var
|
||||
constant *ast.Term
|
||||
}
|
||||
|
||||
func newUnionFindRoot(key ast.Var) *unionFindRoot {
|
||||
return &unionFindRoot{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *unionFindRoot) Value() ast.Value {
|
||||
if r.constant != nil {
|
||||
return r.constant.Value
|
||||
}
|
||||
return r.key
|
||||
}
|
||||
|
||||
func isNoop(expr *ast.Expr) bool {
|
||||
|
||||
if !expr.IsCall() {
|
||||
term := expr.Terms.(*ast.Term)
|
||||
if !ast.IsConstant(term.Value) {
|
||||
return false
|
||||
}
|
||||
return !ast.Boolean(false).Equal(term.Value)
|
||||
}
|
||||
|
||||
// A==A can be ignored
|
||||
if expr.Operator().Equal(ast.Equal.Ref()) {
|
||||
return expr.Operand(0).Equal(expr.Operand(1))
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
128
vendor/github.com/open-policy-agent/opa/topdown/crypto.go
generated
vendored
Normal file
128
vendor/github.com/open-policy-agent/opa/topdown/crypto.go
generated
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
func builtinCryptoX509ParseCertificates(a ast.Value) (ast.Value, error) {
|
||||
|
||||
str, err := builtinBase64Decode(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certs, err := x509.ParseCertificates([]byte(str.(ast.String)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(certs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var x interface{}
|
||||
|
||||
if err := util.UnmarshalJSON(bs, &x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.InterfaceToValue(x)
|
||||
}
|
||||
|
||||
func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(h(s)), nil
|
||||
}
|
||||
|
||||
func builtinCryptoMd5(a ast.Value) (ast.Value, error) {
|
||||
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) })
|
||||
}
|
||||
|
||||
func builtinCryptoSha1(a ast.Value) (ast.Value, error) {
|
||||
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha1.Sum([]byte(s))) })
|
||||
}
|
||||
|
||||
func builtinCryptoSha256(a ast.Value) (ast.Value, error) {
|
||||
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) })
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoMd5.Name, builtinCryptoMd5)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoSha1.Name, builtinCryptoSha1)
|
||||
RegisterFunctionalBuiltin1(ast.CryptoSha256.Name, builtinCryptoSha256)
|
||||
}
|
||||
|
||||
// createRootCAs creates a new Cert Pool from scratch or adds to a copy of System Certs
|
||||
func createRootCAs(tlsCACertFile string, tlsCACertEnvVar []byte, tlsUseSystemCerts bool) (*x509.CertPool, error) {
|
||||
|
||||
var newRootCAs *x509.CertPool
|
||||
|
||||
if tlsUseSystemCerts {
|
||||
systemCertPool, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRootCAs = systemCertPool
|
||||
} else {
|
||||
newRootCAs = x509.NewCertPool()
|
||||
}
|
||||
|
||||
if len(tlsCACertFile) > 0 {
|
||||
// Append our cert to the system pool
|
||||
caCert, err := readCertFromFile(tlsCACertFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok := newRootCAs.AppendCertsFromPEM(caCert); !ok {
|
||||
return nil, fmt.Errorf("could not append CA cert from %q", tlsCACertFile)
|
||||
}
|
||||
}
|
||||
|
||||
if len(tlsCACertEnvVar) > 0 {
|
||||
// Append our cert to the system pool
|
||||
if ok := newRootCAs.AppendCertsFromPEM(tlsCACertEnvVar); !ok {
|
||||
return nil, fmt.Errorf("error appending cert from env var %q into system certs", tlsCACertEnvVar)
|
||||
}
|
||||
}
|
||||
|
||||
return newRootCAs, nil
|
||||
}
|
||||
|
||||
// ReadCertFromFile reads a cert from file
|
||||
func readCertFromFile(localCertFile string) ([]byte, error) {
|
||||
// Read in the cert file
|
||||
certPEM, err := ioutil.ReadFile(localCertFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return certPEM, nil
|
||||
}
|
||||
|
||||
// ReadKeyFromFile reads a key from file
|
||||
func readKeyFromFile(localKeyFile string) ([]byte, error) {
|
||||
// Read in the cert file
|
||||
key, err := ioutil.ReadFile(localKeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
10
vendor/github.com/open-policy-agent/opa/topdown/doc.go
generated
vendored
Normal file
10
vendor/github.com/open-policy-agent/opa/topdown/doc.go
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package topdown provides low-level query evaluation support.
|
||||
//
|
||||
// The topdown implementation is a modified version of the standard top-down
|
||||
// evaluation algorithm used in Datalog. References and comprehensions are
|
||||
// evaluated eagerly while all other terms are evaluated lazily.
|
||||
package topdown
|
||||
217
vendor/github.com/open-policy-agent/opa/topdown/encoding.go
generated
vendored
Normal file
217
vendor/github.com/open-policy-agent/opa/topdown/encoding.go
generated
vendored
Normal file
@@ -0,0 +1,217 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
ghodss "github.com/ghodss/yaml"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/util"
|
||||
)
|
||||
|
||||
func builtinJSONMarshal(a ast.Value) (ast.Value, error) {
|
||||
|
||||
asJSON, err := ast.JSON(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := json.Marshal(asJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(string(bs)), nil
|
||||
}
|
||||
|
||||
func builtinJSONUnmarshal(a ast.Value) (ast.Value, error) {
|
||||
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var x interface{}
|
||||
|
||||
if err := util.UnmarshalJSON([]byte(str), &x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.InterfaceToValue(x)
|
||||
}
|
||||
|
||||
func builtinBase64Encode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(base64.StdEncoding.EncodeToString([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinBase64Decode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := base64.StdEncoding.DecodeString(string(str))
|
||||
return ast.String(result), err
|
||||
}
|
||||
|
||||
func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(base64.URLEncoding.EncodeToString([]byte(str))), nil
|
||||
}
|
||||
|
||||
func builtinBase64UrlDecode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := string(str)
|
||||
|
||||
// Some base64url encoders omit the padding at the end, so this case
|
||||
// corrects such representations using the method given in RFC 7515
|
||||
// Appendix C: https://tools.ietf.org/html/rfc7515#appendix-C
|
||||
if !strings.HasSuffix(s, "=") {
|
||||
switch len(s) % 4 {
|
||||
case 0:
|
||||
case 2:
|
||||
s += "=="
|
||||
case 3:
|
||||
s += "="
|
||||
default:
|
||||
return nil, fmt.Errorf("illegal base64url string: %s", s)
|
||||
}
|
||||
}
|
||||
result, err := base64.URLEncoding.DecodeString(s)
|
||||
return ast.String(result), err
|
||||
}
|
||||
|
||||
func builtinURLQueryEncode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(url.QueryEscape(string(str))), nil
|
||||
}
|
||||
|
||||
func builtinURLQueryDecode(a ast.Value) (ast.Value, error) {
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := url.QueryUnescape(string(str))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(s), nil
|
||||
}
|
||||
|
||||
var encodeObjectErr = builtins.NewOperandErr(1, "values must be string, array[string], or set[string]")
|
||||
|
||||
func builtinURLQueryEncodeObject(a ast.Value) (ast.Value, error) {
|
||||
asJSON, err := ast.JSON(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inputs, ok := asJSON.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "object")
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
|
||||
for k, v := range inputs {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
query.Set(k, vv)
|
||||
case []interface{}:
|
||||
for _, val := range vv {
|
||||
strVal, ok := val.(string)
|
||||
if !ok {
|
||||
return nil, encodeObjectErr
|
||||
}
|
||||
query.Add(k, strVal)
|
||||
}
|
||||
default:
|
||||
return nil, encodeObjectErr
|
||||
}
|
||||
}
|
||||
|
||||
return ast.String(query.Encode()), nil
|
||||
}
|
||||
|
||||
func builtinYAMLMarshal(a ast.Value) (ast.Value, error) {
|
||||
|
||||
asJSON, err := ast.JSON(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
if err := encoder.Encode(asJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := ghodss.JSONToYAML(buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(string(bs)), nil
|
||||
}
|
||||
|
||||
func builtinYAMLUnmarshal(a ast.Value) (ast.Value, error) {
|
||||
|
||||
str, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, err := ghodss.YAMLToJSON([]byte(str))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(bs)
|
||||
decoder := util.NewJSONDecoder(buf)
|
||||
var val interface{}
|
||||
err = decoder.Decode(&val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.InterfaceToValue(val)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.JSONMarshal.Name, builtinJSONMarshal)
|
||||
RegisterFunctionalBuiltin1(ast.JSONUnmarshal.Name, builtinJSONUnmarshal)
|
||||
RegisterFunctionalBuiltin1(ast.Base64Encode.Name, builtinBase64Encode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64Decode.Name, builtinBase64Decode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64UrlEncode.Name, builtinBase64UrlEncode)
|
||||
RegisterFunctionalBuiltin1(ast.Base64UrlDecode.Name, builtinBase64UrlDecode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryDecode.Name, builtinURLQueryDecode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryEncode.Name, builtinURLQueryEncode)
|
||||
RegisterFunctionalBuiltin1(ast.URLQueryEncodeObject.Name, builtinURLQueryEncodeObject)
|
||||
RegisterFunctionalBuiltin1(ast.YAMLMarshal.Name, builtinYAMLMarshal)
|
||||
RegisterFunctionalBuiltin1(ast.YAMLUnmarshal.Name, builtinYAMLUnmarshal)
|
||||
}
|
||||
119
vendor/github.com/open-policy-agent/opa/topdown/errors.go
generated
vendored
Normal file
119
vendor/github.com/open-policy-agent/opa/topdown/errors.go
generated
vendored
Normal file
@@ -0,0 +1,119 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// Error is the error type returned by the Eval and Query functions when
|
||||
// an evaluation error occurs.
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Location *ast.Location `json:"location,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
// InternalErr represents an unknown evaluation error.
|
||||
InternalErr string = "eval_internal_error"
|
||||
|
||||
// CancelErr indicates the evaluation process was cancelled.
|
||||
CancelErr string = "eval_cancel_error"
|
||||
|
||||
// ConflictErr indicates a conflict was encountered during evaluation. For
|
||||
// instance, a conflict occurs if a rule produces multiple, differing values
|
||||
// for the same key in an object. Conflict errors indicate the policy does
|
||||
// not account for the data loaded into the policy engine.
|
||||
ConflictErr string = "eval_conflict_error"
|
||||
|
||||
// TypeErr indicates evaluation stopped because an expression was applied to
|
||||
// a value of an inappropriate type.
|
||||
TypeErr string = "eval_type_error"
|
||||
|
||||
// BuiltinErr indicates a built-in function received a semantically invalid
|
||||
// input or encountered some kind of runtime error, e.g., connection
|
||||
// timeout, connection refused, etc.
|
||||
BuiltinErr string = "eval_builtin_error"
|
||||
|
||||
// WithMergeErr indicates that the real and replacement data could not be merged.
|
||||
WithMergeErr string = "eval_with_merge_error"
|
||||
)
|
||||
|
||||
// IsError returns true if the err is an Error.
|
||||
func IsError(err error) bool {
|
||||
_, ok := err.(*Error)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsCancel returns true if err was caused by cancellation.
|
||||
func IsCancel(err error) bool {
|
||||
if e, ok := err.(*Error); ok {
|
||||
return e.Code == CancelErr
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
|
||||
msg := fmt.Sprintf("%v: %v", e.Code, e.Message)
|
||||
|
||||
if e.Location != nil {
|
||||
msg = e.Location.String() + ": " + msg
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func functionConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: ConflictErr,
|
||||
Location: loc,
|
||||
Message: "functions must not produce multiple outputs for same inputs",
|
||||
}
|
||||
}
|
||||
|
||||
func completeDocConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: ConflictErr,
|
||||
Location: loc,
|
||||
Message: "complete rules must not produce multiple outputs",
|
||||
}
|
||||
}
|
||||
|
||||
func objectDocKeyConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: ConflictErr,
|
||||
Location: loc,
|
||||
Message: "object keys must be unique",
|
||||
}
|
||||
}
|
||||
|
||||
func documentConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: ConflictErr,
|
||||
Location: loc,
|
||||
Message: "base and virtual document keys must be disjoint",
|
||||
}
|
||||
}
|
||||
|
||||
func unsupportedBuiltinErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: InternalErr,
|
||||
Location: loc,
|
||||
Message: "unsupported built-in",
|
||||
}
|
||||
}
|
||||
|
||||
func mergeConflictErr(loc *ast.Location) error {
|
||||
return &Error{
|
||||
Code: WithMergeErr,
|
||||
Location: loc,
|
||||
Message: "real and replacement data could not be merged",
|
||||
}
|
||||
}
|
||||
2382
vendor/github.com/open-policy-agent/opa/topdown/eval.go
generated
vendored
Normal file
2382
vendor/github.com/open-policy-agent/opa/topdown/eval.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
65
vendor/github.com/open-policy-agent/opa/topdown/glob.go
generated
vendored
Normal file
65
vendor/github.com/open-policy-agent/opa/topdown/glob.go
generated
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
var globCacheLock = sync.Mutex{}
|
||||
var globCache map[string]glob.Glob
|
||||
|
||||
func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
pattern, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delimiters, err := builtins.RuneSliceOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(delimiters) == 0 {
|
||||
delimiters = []rune{'.'}
|
||||
}
|
||||
|
||||
match, err := builtins.StringOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("%s-%v", pattern, delimiters)
|
||||
|
||||
globCacheLock.Lock()
|
||||
defer globCacheLock.Unlock()
|
||||
p, ok := globCache[id]
|
||||
if !ok {
|
||||
var err error
|
||||
if p, err = glob.Compile(string(pattern), delimiters...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
globCache[id] = p
|
||||
}
|
||||
|
||||
return ast.Boolean(p.Match(string(match))), nil
|
||||
}
|
||||
|
||||
func builtinGlobQuoteMeta(a ast.Value) (ast.Value, error) {
|
||||
pattern, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(glob.QuoteMeta(string(pattern))), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
globCache = map[string]glob.Glob{}
|
||||
RegisterFunctionalBuiltin3(ast.GlobMatch.Name, builtinGlobMatch)
|
||||
RegisterFunctionalBuiltin1(ast.GlobQuoteMeta.Name, builtinGlobQuoteMeta)
|
||||
}
|
||||
466
vendor/github.com/open-policy-agent/opa/topdown/http.go
generated
vendored
Normal file
466
vendor/github.com/open-policy-agent/opa/topdown/http.go
generated
vendored
Normal file
@@ -0,0 +1,466 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/open-policy-agent/opa/internal/version"
|
||||
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
const defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT"
|
||||
|
||||
var defaultHTTPRequestTimeout = time.Second * 5
|
||||
|
||||
var allowedKeyNames = [...]string{
|
||||
"method",
|
||||
"url",
|
||||
"body",
|
||||
"enable_redirect",
|
||||
"force_json_decode",
|
||||
"headers",
|
||||
"raw_body",
|
||||
"tls_use_system_certs",
|
||||
"tls_ca_cert_file",
|
||||
"tls_ca_cert_env_variable",
|
||||
"tls_client_cert_env_variable",
|
||||
"tls_client_key_env_variable",
|
||||
"tls_client_cert_file",
|
||||
"tls_client_key_file",
|
||||
"tls_insecure_skip_verify",
|
||||
"timeout",
|
||||
}
|
||||
var allowedKeys = ast.NewSet()
|
||||
|
||||
var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))
|
||||
|
||||
type httpSendKey string
|
||||
|
||||
// httpSendBuiltinCacheKey is the key in the builtin context cache that
|
||||
// points to the http.send() specific cache resides at.
|
||||
const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY"
|
||||
|
||||
func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
req, err := validateHTTPRequestOperand(args[0], 1)
|
||||
if err != nil {
|
||||
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
||||
}
|
||||
|
||||
// check if cache already has a response for this query
|
||||
resp := checkHTTPSendCache(bctx, req)
|
||||
if resp == nil {
|
||||
var err error
|
||||
resp, err = executeHTTPRequest(bctx, req)
|
||||
if err != nil {
|
||||
return handleHTTPSendErr(bctx, err)
|
||||
}
|
||||
|
||||
// add result to cache
|
||||
insertIntoHTTPSendCache(bctx, req, resp)
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(resp))
|
||||
}
|
||||
|
||||
func init() {
|
||||
createAllowedKeys()
|
||||
initDefaults()
|
||||
RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend)
|
||||
}
|
||||
|
||||
func handleHTTPSendErr(bctx BuiltinContext, err error) error {
|
||||
// Return HTTP client timeout errors in a generic error message to avoid confusion about what happened.
|
||||
// Do not do this if the builtin context was cancelled and is what caused the request to stop.
|
||||
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil {
|
||||
err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL)
|
||||
}
|
||||
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
||||
}
|
||||
|
||||
func initDefaults() {
|
||||
timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv)
|
||||
if timeoutDuration != "" {
|
||||
var err error
|
||||
defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration)
|
||||
if err != nil {
|
||||
// If it is set to something not valid don't let the process continue in a state
|
||||
// that will almost definitely give unexpected results by having it set at 0
|
||||
// which means no timeout..
|
||||
// This environment variable isn't considered part of the public API.
|
||||
// TODO(patrick-east): Remove the environment variable
|
||||
panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) {
|
||||
|
||||
obj, err := builtins.ObjectOperand(term.Value, pos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestKeys := ast.NewSet(obj.Keys()...)
|
||||
|
||||
invalidKeys := requestKeys.Diff(allowedKeys)
|
||||
if invalidKeys.Len() != 0 {
|
||||
return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys)
|
||||
}
|
||||
|
||||
missingKeys := requiredKeys.Diff(requestKeys)
|
||||
if missingKeys.Len() != 0 {
|
||||
return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys)
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
|
||||
}
|
||||
|
||||
// Adds custom headers to a new HTTP request.
|
||||
func addHeaders(req *http.Request, headers map[string]interface{}) (bool, error) {
|
||||
for k, v := range headers {
|
||||
// Type assertion
|
||||
header, ok := v.(string)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid type for headers value %q", v)
|
||||
}
|
||||
|
||||
// If the Host header is given, bump that up to
|
||||
// the request. Otherwise, just collect it in the
|
||||
// headers.
|
||||
k := http.CanonicalHeaderKey(k)
|
||||
switch k {
|
||||
case "Host":
|
||||
req.Host = header
|
||||
default:
|
||||
req.Header.Add(k, header)
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) {
|
||||
var url string
|
||||
var method string
|
||||
var tlsCaCertEnvVar []byte
|
||||
var tlsCaCertFile string
|
||||
var tlsClientKeyEnvVar []byte
|
||||
var tlsClientCertEnvVar []byte
|
||||
var tlsClientCertFile string
|
||||
var tlsClientKeyFile string
|
||||
var body *bytes.Buffer
|
||||
var rawBody *bytes.Buffer
|
||||
var enableRedirect bool
|
||||
var forceJSONDecode bool
|
||||
var tlsUseSystemCerts bool
|
||||
var tlsConfig tls.Config
|
||||
var clientCerts []tls.Certificate
|
||||
var customHeaders map[string]interface{}
|
||||
var tlsInsecureSkipVerify bool
|
||||
var timeout = defaultHTTPRequestTimeout
|
||||
|
||||
for _, val := range obj.Keys() {
|
||||
key, err := ast.JSON(val.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key = key.(string)
|
||||
|
||||
switch key {
|
||||
case "method":
|
||||
method = obj.Get(val).String()
|
||||
method = strings.ToUpper(strings.Trim(method, "\""))
|
||||
case "url":
|
||||
url = obj.Get(val).String()
|
||||
url = strings.Trim(url, "\"")
|
||||
case "enable_redirect":
|
||||
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "force_json_decode":
|
||||
forceJSONDecode, err = strconv.ParseBool(obj.Get(val).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "body":
|
||||
bodyVal := obj.Get(val).Value
|
||||
bodyValInterface, err := ast.JSON(bodyVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyValBytes, err := json.Marshal(bodyValInterface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body = bytes.NewBuffer(bodyValBytes)
|
||||
case "raw_body":
|
||||
s, ok := obj.Get(val).Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("raw_body must be a string")
|
||||
}
|
||||
rawBody = bytes.NewBuffer([]byte(s))
|
||||
case "tls_use_system_certs":
|
||||
tlsUseSystemCerts, err = strconv.ParseBool(obj.Get(val).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "tls_ca_cert_file":
|
||||
tlsCaCertFile = obj.Get(val).String()
|
||||
tlsCaCertFile = strings.Trim(tlsCaCertFile, "\"")
|
||||
case "tls_ca_cert_env_variable":
|
||||
caCertEnv := obj.Get(val).String()
|
||||
caCertEnv = strings.Trim(caCertEnv, "\"")
|
||||
tlsCaCertEnvVar = []byte(os.Getenv(caCertEnv))
|
||||
case "tls_client_cert_env_variable":
|
||||
clientCertEnv := obj.Get(val).String()
|
||||
clientCertEnv = strings.Trim(clientCertEnv, "\"")
|
||||
tlsClientCertEnvVar = []byte(os.Getenv(clientCertEnv))
|
||||
case "tls_client_key_env_variable":
|
||||
clientKeyEnv := obj.Get(val).String()
|
||||
clientKeyEnv = strings.Trim(clientKeyEnv, "\"")
|
||||
tlsClientKeyEnvVar = []byte(os.Getenv(clientKeyEnv))
|
||||
case "tls_client_cert_file":
|
||||
tlsClientCertFile = obj.Get(val).String()
|
||||
tlsClientCertFile = strings.Trim(tlsClientCertFile, "\"")
|
||||
case "tls_client_key_file":
|
||||
tlsClientKeyFile = obj.Get(val).String()
|
||||
tlsClientKeyFile = strings.Trim(tlsClientKeyFile, "\"")
|
||||
case "headers":
|
||||
headersVal := obj.Get(val).Value
|
||||
headersValInterface, err := ast.JSON(headersVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ok bool
|
||||
customHeaders, ok = headersValInterface.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid type for headers key")
|
||||
}
|
||||
case "tls_insecure_skip_verify":
|
||||
tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "timeout":
|
||||
timeout, err = parseTimeout(obj.Get(val).Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid parameter %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
if tlsInsecureSkipVerify {
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: tlsInsecureSkipVerify},
|
||||
}
|
||||
}
|
||||
if tlsClientCertFile != "" && tlsClientKeyFile != "" {
|
||||
clientCertFromFile, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientCerts = append(clientCerts, clientCertFromFile)
|
||||
}
|
||||
|
||||
if len(tlsClientCertEnvVar) > 0 && len(tlsClientKeyEnvVar) > 0 {
|
||||
clientCertFromEnv, err := tls.X509KeyPair(tlsClientCertEnvVar, tlsClientKeyEnvVar)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientCerts = append(clientCerts, clientCertFromEnv)
|
||||
}
|
||||
|
||||
isTLS := false
|
||||
if len(clientCerts) > 0 {
|
||||
isTLS = true
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, clientCerts...)
|
||||
}
|
||||
|
||||
if tlsUseSystemCerts || len(tlsCaCertFile) > 0 || len(tlsCaCertEnvVar) > 0 {
|
||||
isTLS = true
|
||||
connRootCAs, err := createRootCAs(tlsCaCertFile, tlsCaCertEnvVar, tlsUseSystemCerts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig.RootCAs = connRootCAs
|
||||
}
|
||||
|
||||
if isTLS {
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: &tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// check if redirects are enabled
|
||||
if !enableRedirect {
|
||||
client.CheckRedirect = func(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
}
|
||||
|
||||
if rawBody != nil {
|
||||
body = rawBody
|
||||
} else if body == nil {
|
||||
body = bytes.NewBufferString("")
|
||||
}
|
||||
|
||||
// create the http request, use the builtin context's context to ensure
|
||||
// the request is cancelled if evaluation is cancelled.
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req = req.WithContext(bctx.Context)
|
||||
|
||||
// Add custom headers
|
||||
if len(customHeaders) != 0 {
|
||||
if ok, err := addHeaders(req, customHeaders); !ok {
|
||||
return nil, err
|
||||
}
|
||||
// Don't overwrite or append to one that was set in the custom headers
|
||||
if _, hasUA := customHeaders["User-Agent"]; !hasUA {
|
||||
req.Header.Add("User-Agent", version.UserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// execute the http request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// format the http result
|
||||
var resultBody interface{}
|
||||
var resultRawBody []byte
|
||||
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
resultRawBody, err = ioutil.ReadAll(tee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the response body cannot be JSON decoded,
|
||||
// an error will not be returned. Instead the "body" field
|
||||
// in the result will be null.
|
||||
if isContentTypeJSON(resp.Header) || forceJSONDecode {
|
||||
json.NewDecoder(&buf).Decode(&resultBody)
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
result["status"] = resp.Status
|
||||
result["status_code"] = resp.StatusCode
|
||||
result["body"] = resultBody
|
||||
result["raw_body"] = string(resultRawBody)
|
||||
|
||||
resultObj, err := ast.InterfaceToValue(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resultObj, nil
|
||||
}
|
||||
|
||||
func isContentTypeJSON(header http.Header) bool {
|
||||
return strings.Contains(header.Get("Content-Type"), "application/json")
|
||||
}
|
||||
|
||||
// In the BuiltinContext cache we only store a single entry that points to
|
||||
// our ValueMap which is the "real" http.send() cache.
|
||||
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
|
||||
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
|
||||
if !ok {
|
||||
// Initialize if it isn't there
|
||||
cache := ast.NewValueMap()
|
||||
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
|
||||
return cache
|
||||
}
|
||||
|
||||
cache, ok := raw.(*ast.ValueMap)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return cache
|
||||
}
|
||||
|
||||
// checkHTTPSendCache checks for the given key's value in the cache
|
||||
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
|
||||
requestCache := getHTTPSendCache(bctx)
|
||||
if requestCache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return requestCache.Get(key)
|
||||
}
|
||||
|
||||
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
|
||||
requestCache := getHTTPSendCache(bctx)
|
||||
if requestCache == nil {
|
||||
// Should never happen.. if it does just skip caching the value
|
||||
return
|
||||
}
|
||||
requestCache.Put(key, value)
|
||||
}
|
||||
|
||||
func createAllowedKeys() {
|
||||
for _, element := range allowedKeyNames {
|
||||
allowedKeys.Add(ast.StringTerm(element))
|
||||
}
|
||||
}
|
||||
|
||||
func parseTimeout(timeoutVal ast.Value) (time.Duration, error) {
|
||||
var timeout time.Duration
|
||||
switch t := timeoutVal.(type) {
|
||||
case ast.Number:
|
||||
timeoutInt, ok := t.Int64()
|
||||
if !ok {
|
||||
return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal)
|
||||
}
|
||||
return time.Duration(timeoutInt), nil
|
||||
case ast.String:
|
||||
// Support strings without a unit, treat them the same as just a number value (ns)
|
||||
var err error
|
||||
timeoutInt, err := strconv.ParseInt(string(t), 10, 64)
|
||||
if err == nil {
|
||||
return time.Duration(timeoutInt), nil
|
||||
}
|
||||
|
||||
// Try parsing it as a duration (requires a supported units suffix)
|
||||
timeout, err = time.ParseDuration(string(t))
|
||||
if err != nil {
|
||||
return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err)
|
||||
}
|
||||
return timeout, nil
|
||||
default:
|
||||
return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t))
|
||||
}
|
||||
}
|
||||
74
vendor/github.com/open-policy-agent/opa/topdown/input.go
generated
vendored
Normal file
74
vendor/github.com/open-policy-agent/opa/topdown/input.go
generated
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
var errConflictingDoc = fmt.Errorf("conflicting documents")
|
||||
var errBadPath = fmt.Errorf("bad document path")
|
||||
|
||||
func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, error) {
|
||||
|
||||
var result *ast.Term
|
||||
|
||||
if exist != nil {
|
||||
result = exist.Copy()
|
||||
}
|
||||
|
||||
for _, pair := range pairs {
|
||||
|
||||
if err := ast.IsValidImportPath(pair[0].Value); err != nil {
|
||||
return nil, errBadPath
|
||||
}
|
||||
|
||||
target := pair[0].Value.(ast.Ref)
|
||||
|
||||
if len(target) == 1 {
|
||||
result = pair[1]
|
||||
} else if result == nil {
|
||||
result = ast.NewTerm(makeTree(target[1:], pair[1]))
|
||||
} else {
|
||||
node := result
|
||||
done := false
|
||||
for i := 1; i < len(target)-1 && !done; i++ {
|
||||
if child := node.Get(target[i]); child == nil {
|
||||
obj, ok := node.Value.(ast.Object)
|
||||
if !ok {
|
||||
return nil, errConflictingDoc
|
||||
}
|
||||
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
|
||||
done = true
|
||||
} else {
|
||||
node = child
|
||||
}
|
||||
}
|
||||
if !done {
|
||||
obj, ok := node.Value.(ast.Object)
|
||||
if !ok {
|
||||
return nil, errConflictingDoc
|
||||
}
|
||||
obj.Insert(target[len(target)-1], pair[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// makeTree returns an object that represents a document where the value v is
|
||||
// the leaf and elements in k represent intermediate objects.
|
||||
func makeTree(k ast.Ref, v *ast.Term) ast.Object {
|
||||
var obj ast.Object
|
||||
for i := len(k) - 1; i >= 1; i-- {
|
||||
obj = ast.NewObject(ast.Item(k[i], v))
|
||||
v = &ast.Term{Value: obj}
|
||||
}
|
||||
obj = ast.NewObject(ast.Item(k[0], v))
|
||||
return obj
|
||||
}
|
||||
59
vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go
generated
vendored
Normal file
59
vendor/github.com/open-policy-agent/opa/topdown/instrumentation.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import "github.com/open-policy-agent/opa/metrics"
|
||||
|
||||
const (
|
||||
evalOpPlug = "eval_op_plug"
|
||||
evalOpResolve = "eval_op_resolve"
|
||||
evalOpRuleIndex = "eval_op_rule_index"
|
||||
evalOpBuiltinCall = "eval_op_builtin_call"
|
||||
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
|
||||
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
|
||||
evalOpBaseCacheHit = "eval_op_base_cache_hit"
|
||||
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
|
||||
partialOpSaveUnify = "partial_op_save_unify"
|
||||
partialOpSaveSetContains = "partial_op_save_set_contains"
|
||||
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
|
||||
partialOpCopyPropagation = "partial_op_copy_propagation"
|
||||
)
|
||||
|
||||
// Instrumentation implements helper functions to instrument query evaluation
|
||||
// to diagnose performance issues. Instrumentation may be expensive in some
|
||||
// cases, so it is disabled by default.
|
||||
type Instrumentation struct {
|
||||
m metrics.Metrics
|
||||
}
|
||||
|
||||
// NewInstrumentation returns a new Instrumentation object. Performance
|
||||
// diagnostics recorded on this Instrumentation object will stored in m.
|
||||
func NewInstrumentation(m metrics.Metrics) *Instrumentation {
|
||||
return &Instrumentation{
|
||||
m: m,
|
||||
}
|
||||
}
|
||||
|
||||
func (instr *Instrumentation) startTimer(name string) {
|
||||
if instr == nil {
|
||||
return
|
||||
}
|
||||
instr.m.Timer(name).Start()
|
||||
}
|
||||
|
||||
func (instr *Instrumentation) stopTimer(name string) {
|
||||
if instr == nil {
|
||||
return
|
||||
}
|
||||
delta := instr.m.Timer(name).Stop()
|
||||
instr.m.Histogram(name).Update(delta)
|
||||
}
|
||||
|
||||
func (instr *Instrumentation) counterIncr(name string) {
|
||||
if instr == nil {
|
||||
return
|
||||
}
|
||||
instr.m.Counter(name).Incr()
|
||||
}
|
||||
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/LICENSE
generated
vendored
Normal file
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 lestrrat
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
113
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/buffer/buffer.go
generated
vendored
Normal file
113
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/buffer/buffer.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
// Package buffer provides a very thin wrapper around []byte buffer called
|
||||
// `Buffer`, to provide functionalities that are often used within the jwx
|
||||
// related packages
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Buffer wraps `[]byte` and provides functions that are often used in
|
||||
// the jwx related packages. One notable difference is that while
|
||||
// encoding/json marshalls `[]byte` using base64.StdEncoding, this
|
||||
// module uses base64.RawURLEncoding as mandated by the spec
|
||||
type Buffer []byte
|
||||
|
||||
// FromUint creates a `Buffer` from an unsigned int
|
||||
func FromUint(v uint64) Buffer {
|
||||
data := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(data, v)
|
||||
|
||||
i := 0
|
||||
for ; i < len(data); i++ {
|
||||
if data[i] != 0x0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return Buffer(data[i:])
|
||||
}
|
||||
|
||||
// FromBase64 constructs a new Buffer from a base64 encoded data
|
||||
func FromBase64(v []byte) (Buffer, error) {
|
||||
b := Buffer{}
|
||||
if err := b.Base64Decode(v); err != nil {
|
||||
return Buffer(nil), errors.Wrap(err, "failed to decode from base64")
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// FromNData constructs a new Buffer from a "n:data" format
|
||||
// (I made that name up)
|
||||
func FromNData(v []byte) (Buffer, error) {
|
||||
size := binary.BigEndian.Uint32(v)
|
||||
buf := make([]byte, int(size))
|
||||
copy(buf, v[4:4+size])
|
||||
return Buffer(buf), nil
|
||||
}
|
||||
|
||||
// Bytes returns the raw bytes that comprises the Buffer
|
||||
func (b Buffer) Bytes() []byte {
|
||||
return []byte(b)
|
||||
}
|
||||
|
||||
// NData returns Datalen || Data, where Datalen is a 32 bit counter for
|
||||
// the length of the following data, and Data is the octets that comprise
|
||||
// the buffer data
|
||||
func (b Buffer) NData() []byte {
|
||||
buf := make([]byte, 4+b.Len())
|
||||
binary.BigEndian.PutUint32(buf, uint32(b.Len()))
|
||||
|
||||
copy(buf[4:], b.Bytes())
|
||||
return buf
|
||||
}
|
||||
|
||||
// Len returns the number of bytes that the Buffer holds
|
||||
func (b Buffer) Len() int {
|
||||
return len(b)
|
||||
}
|
||||
|
||||
// Base64Encode encodes the contents of the Buffer using base64.RawURLEncoding
|
||||
func (b Buffer) Base64Encode() ([]byte, error) {
|
||||
enc := base64.RawURLEncoding
|
||||
out := make([]byte, enc.EncodedLen(len(b)))
|
||||
enc.Encode(out, b)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Base64Decode decodes the contents of the Buffer using base64.RawURLEncoding
|
||||
func (b *Buffer) Base64Decode(v []byte) error {
|
||||
enc := base64.RawURLEncoding
|
||||
out := make([]byte, enc.DecodedLen(len(v)))
|
||||
n, err := enc.Decode(out, v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to decode from base64")
|
||||
}
|
||||
out = out[:n]
|
||||
*b = Buffer(out)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the buffer into JSON format after encoding the buffer
|
||||
// with base64.RawURLEncoding
|
||||
func (b Buffer) MarshalJSON() ([]byte, error) {
|
||||
v, err := b.Base64Encode()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to encode to base64")
|
||||
}
|
||||
return json.Marshal(string(v))
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals from a JSON string into a Buffer, after decoding it
|
||||
// with base64.RawURLEncoding
|
||||
func (b *Buffer) UnmarshalJSON(data []byte) error {
|
||||
var x string
|
||||
if err := json.Unmarshal(data, &x); err != nil {
|
||||
return errors.Wrap(err, "failed to unmarshal JSON")
|
||||
}
|
||||
return b.Base64Decode([]byte(x))
|
||||
}
|
||||
11
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/elliptic.go
generated
vendored
Normal file
11
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/elliptic.go
generated
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
package jwa
|
||||
|
||||
// EllipticCurveAlgorithm represents the algorithms used for EC keys
|
||||
type EllipticCurveAlgorithm string
|
||||
|
||||
// Supported values for EllipticCurveAlgorithm
|
||||
const (
|
||||
P256 EllipticCurveAlgorithm = "P-256"
|
||||
P384 EllipticCurveAlgorithm = "P-384"
|
||||
P521 EllipticCurveAlgorithm = "P-521"
|
||||
)
|
||||
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/key_type.go
generated
vendored
Normal file
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/key_type.go
generated
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyType represents the key type ("kty") that are supported
|
||||
type KeyType string
|
||||
|
||||
var keyTypeAlg = map[string]struct{}{"EC": {}, "oct": {}, "RSA": {}}
|
||||
|
||||
// Supported values for KeyType
|
||||
const (
|
||||
EC KeyType = "EC" // Elliptic Curve
|
||||
InvalidKeyType KeyType = "" // Invalid KeyType
|
||||
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
|
||||
RSA KeyType = "RSA" // RSA
|
||||
)
|
||||
|
||||
// Accept is used when conversion from values given by
|
||||
// outside sources (such as JSON payloads) is required
|
||||
func (keyType *KeyType) Accept(value interface{}) error {
|
||||
var tmp KeyType
|
||||
switch x := value.(type) {
|
||||
case string:
|
||||
tmp = KeyType(x)
|
||||
case KeyType:
|
||||
tmp = x
|
||||
default:
|
||||
return errors.Errorf(`invalid type for jwa.KeyType: %T`, value)
|
||||
}
|
||||
_, ok := keyTypeAlg[tmp.String()]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown Key Type algorithm")
|
||||
}
|
||||
|
||||
*keyType = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the string representation of a KeyType
|
||||
func (keyType KeyType) String() string {
|
||||
return string(keyType)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
|
||||
func (keyType *KeyType) UnmarshalJSON(data []byte) error {
|
||||
var quote byte = '"'
|
||||
var quoted string
|
||||
if data[0] == quote {
|
||||
var err error
|
||||
quoted, err = strconv.Unquote(string(data))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to process signature algorithm")
|
||||
}
|
||||
} else {
|
||||
quoted = string(data)
|
||||
}
|
||||
_, ok := keyTypeAlg[quoted]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*keyType = KeyType(quoted)
|
||||
return nil
|
||||
}
|
||||
29
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/parameters.go
generated
vendored
Normal file
29
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/parameters.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/buffer"
|
||||
)
|
||||
|
||||
// EllipticCurve provides a indirect type to standard elliptic curve such that we can
|
||||
// use it for unmarshal
|
||||
type EllipticCurve struct {
|
||||
elliptic.Curve
|
||||
}
|
||||
|
||||
// AlgorithmParameters provides a single structure suitable to unmarshaling any JWK
|
||||
type AlgorithmParameters struct {
|
||||
N buffer.Buffer `json:"n,omitempty"`
|
||||
E buffer.Buffer `json:"e,omitempty"`
|
||||
D buffer.Buffer `json:"d,omitempty"`
|
||||
P buffer.Buffer `json:"p,omitempty"`
|
||||
Q buffer.Buffer `json:"q,omitempty"`
|
||||
Dp buffer.Buffer `json:"dp,omitempty"`
|
||||
Dq buffer.Buffer `json:"dq,omitempty"`
|
||||
Qi buffer.Buffer `json:"qi,omitempty"`
|
||||
Crv EllipticCurveAlgorithm `json:"crv,omitempty"`
|
||||
X buffer.Buffer `json:"x,omitempty"`
|
||||
Y buffer.Buffer `json:"y,omitempty"`
|
||||
K buffer.Buffer `json:"k,omitempty"`
|
||||
}
|
||||
76
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/signature.go
generated
vendored
Normal file
76
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwa/signature.go
generated
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package jwa
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1
|
||||
type SignatureAlgorithm string
|
||||
|
||||
var signatureAlg = map[string]struct{}{"ES256": {}, "ES384": {}, "ES512": {}, "HS256": {}, "HS384": {}, "HS512": {}, "PS256": {}, "PS384": {}, "PS512": {}, "RS256": {}, "RS384": {}, "RS512": {}, "none": {}}
|
||||
|
||||
// Supported values for SignatureAlgorithm
|
||||
const (
|
||||
ES256 SignatureAlgorithm = "ES256" // ECDSA using P-256 and SHA-256
|
||||
ES384 SignatureAlgorithm = "ES384" // ECDSA using P-384 and SHA-384
|
||||
ES512 SignatureAlgorithm = "ES512" // ECDSA using P-521 and SHA-512
|
||||
HS256 SignatureAlgorithm = "HS256" // HMAC using SHA-256
|
||||
HS384 SignatureAlgorithm = "HS384" // HMAC using SHA-384
|
||||
HS512 SignatureAlgorithm = "HS512" // HMAC using SHA-512
|
||||
NoSignature SignatureAlgorithm = "none"
|
||||
PS256 SignatureAlgorithm = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
|
||||
PS384 SignatureAlgorithm = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
|
||||
PS512 SignatureAlgorithm = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
|
||||
RS256 SignatureAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
|
||||
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
|
||||
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
|
||||
NoValue SignatureAlgorithm = "" // No value is different from none
|
||||
)
|
||||
|
||||
// Accept is used when conversion from values given by
|
||||
// outside sources (such as JSON payloads) is required
|
||||
func (signature *SignatureAlgorithm) Accept(value interface{}) error {
|
||||
var tmp SignatureAlgorithm
|
||||
switch x := value.(type) {
|
||||
case string:
|
||||
tmp = SignatureAlgorithm(x)
|
||||
case SignatureAlgorithm:
|
||||
tmp = x
|
||||
default:
|
||||
return errors.Errorf(`invalid type for jwa.SignatureAlgorithm: %T`, value)
|
||||
}
|
||||
_, ok := signatureAlg[tmp.String()]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*signature = tmp
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns the string representation of a SignatureAlgorithm
|
||||
func (signature SignatureAlgorithm) String() string {
|
||||
return string(signature)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as Signature Algorithm
|
||||
func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
|
||||
var quote byte = '"'
|
||||
var quoted string
|
||||
if data[0] == quote {
|
||||
var err error
|
||||
quoted, err = strconv.Unquote(string(data))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to process signature algorithm")
|
||||
}
|
||||
} else {
|
||||
quoted = string(data)
|
||||
}
|
||||
_, ok := signatureAlg[quoted]
|
||||
if !ok {
|
||||
return errors.Errorf("Unknown signature algorithm")
|
||||
}
|
||||
*signature = SignatureAlgorithm(quoted)
|
||||
return nil
|
||||
}
|
||||
120
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/ecdsa.go
generated
vendored
Normal file
120
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/ecdsa.go
generated
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newECDSAPublicKey(key *ecdsa.PublicKey) (*ECDSAPublicKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.EC)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
|
||||
return &ECDSAPublicKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newECDSAPrivateKey(key *ecdsa.PrivateKey) (*ECDSAPrivateKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.EC)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
|
||||
return &ECDSAPrivateKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the EC-DSA public key represented by this JWK
|
||||
func (k ECDSAPublicKey) Materialize() (interface{}, error) {
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// Materialize returns the EC-DSA private key represented by this JWK
|
||||
func (k ECDSAPrivateKey) Materialize() (interface{}, error) {
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a ECDSAPublicKey from JWK format
|
||||
func (k *ECDSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
var x, y big.Int
|
||||
|
||||
if keyJSON.X == nil || keyJSON.Y == nil || keyJSON.Crv == "" {
|
||||
return errors.Errorf("Missing mandatory key parameters X, Y or Crv")
|
||||
}
|
||||
|
||||
x.SetBytes(keyJSON.X.Bytes())
|
||||
y.SetBytes(keyJSON.Y.Bytes())
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch keyJSON.Crv {
|
||||
case jwa.P256:
|
||||
curve = elliptic.P256()
|
||||
case jwa.P384:
|
||||
curve = elliptic.P384()
|
||||
case jwa.P521:
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return errors.Errorf(`invalid curve name %s`, keyJSON.Crv)
|
||||
}
|
||||
|
||||
*k = ECDSAPublicKey{
|
||||
StandardHeaders: &keyJSON.StandardHeaders,
|
||||
key: &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: &x,
|
||||
Y: &y,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a ECDSAPrivateKey from JWK format
|
||||
func (k *ECDSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
if keyJSON.D == nil {
|
||||
return errors.Errorf("Missing mandatory key parameter D")
|
||||
}
|
||||
eCDSAPublicKey := &ECDSAPublicKey{}
|
||||
err := eCDSAPublicKey.GenerateKey(keyJSON)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `failed to generate public key`)
|
||||
}
|
||||
dBytes := keyJSON.D.Bytes()
|
||||
// The length of this octet string MUST be ceiling(log-base-2(n)/8)
|
||||
// octets (where n is the order of the curve). This is because the private
|
||||
// key d must be in the interval [1, n-1] so the bitlength of d should be
|
||||
// no larger than the bitlength of n-1. The easiest way to find the octet
|
||||
// length is to take bitlength(n-1), add 7 to force a carry, and shift this
|
||||
// bit sequence right by 3, which is essentially dividing by 8 and adding
|
||||
// 1 if there is any remainder. Thus, the private key value d should be
|
||||
// output to (bitlength(n-1)+7)>>3 octets.
|
||||
n := eCDSAPublicKey.key.Params().N
|
||||
octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
|
||||
if octetLength-len(dBytes) != 0 {
|
||||
return errors.Errorf("Failed to generate private key. Incorrect D value")
|
||||
}
|
||||
privateKey := &ecdsa.PrivateKey{
|
||||
PublicKey: *eCDSAPublicKey.key,
|
||||
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
|
||||
}
|
||||
|
||||
k.key = privateKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
|
||||
return nil
|
||||
}
|
||||
178
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/headers.go
generated
vendored
Normal file
178
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/headers.go
generated
vendored
Normal file
@@ -0,0 +1,178 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Convenience constants for common JWK parameters
|
||||
const (
|
||||
AlgorithmKey = "alg"
|
||||
KeyIDKey = "kid"
|
||||
KeyOpsKey = "key_ops"
|
||||
KeyTypeKey = "kty"
|
||||
KeyUsageKey = "use"
|
||||
PrivateParamsKey = "privateParams"
|
||||
)
|
||||
|
||||
// Headers provides a common interface to all future possible headers
|
||||
type Headers interface {
|
||||
Get(string) (interface{}, bool)
|
||||
Set(string, interface{}) error
|
||||
Walk(func(string, interface{}) error) error
|
||||
GetAlgorithm() jwa.SignatureAlgorithm
|
||||
GetKeyID() string
|
||||
GetKeyOps() KeyOperationList
|
||||
GetKeyType() jwa.KeyType
|
||||
GetKeyUsage() string
|
||||
GetPrivateParams() map[string]interface{}
|
||||
}
|
||||
|
||||
// StandardHeaders stores the common JWK parameters
|
||||
type StandardHeaders struct {
|
||||
Algorithm *jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.4
|
||||
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
KeyOps KeyOperationList `json:"key_ops,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.3
|
||||
KeyType jwa.KeyType `json:"kty,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.1
|
||||
KeyUsage string `json:"use,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.2
|
||||
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
}
|
||||
|
||||
// GetAlgorithm is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
|
||||
if v := h.Algorithm; v != nil {
|
||||
return *v
|
||||
}
|
||||
return jwa.NoValue
|
||||
}
|
||||
|
||||
// GetKeyID is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyID() string {
|
||||
return h.KeyID
|
||||
}
|
||||
|
||||
// GetKeyOps is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyOps() KeyOperationList {
|
||||
return h.KeyOps
|
||||
}
|
||||
|
||||
// GetKeyType is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyType() jwa.KeyType {
|
||||
return h.KeyType
|
||||
}
|
||||
|
||||
// GetKeyUsage is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetKeyUsage() string {
|
||||
return h.KeyUsage
|
||||
}
|
||||
|
||||
// GetPrivateParams is a convenience function to retrieve the corresponding value stored in the StandardHeaders
|
||||
func (h *StandardHeaders) GetPrivateParams() map[string]interface{} {
|
||||
return h.PrivateParams
|
||||
}
|
||||
|
||||
// Get is a general getter function for JWK StandardHeaders structure
|
||||
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
alg := h.GetAlgorithm()
|
||||
if alg != jwa.NoValue {
|
||||
return alg, true
|
||||
}
|
||||
return nil, false
|
||||
case KeyIDKey:
|
||||
v := h.KeyID
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyOpsKey:
|
||||
v := h.KeyOps
|
||||
if v == nil {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyTypeKey:
|
||||
v := h.KeyType
|
||||
if v == jwa.InvalidKeyType {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyUsageKey:
|
||||
v := h.KeyUsage
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case PrivateParamsKey:
|
||||
v := h.PrivateParams
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Set is a general getter function for JWK StandardHeaders structure
|
||||
func (h *StandardHeaders) Set(name string, value interface{}) error {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
var acceptor jwa.SignatureAlgorithm
|
||||
if err := acceptor.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
|
||||
}
|
||||
h.Algorithm = &acceptor
|
||||
return nil
|
||||
case KeyIDKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyID = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", KeyIDKey, value)
|
||||
case KeyOpsKey:
|
||||
if err := h.KeyOps.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, "invalid value for %s key", KeyOpsKey)
|
||||
}
|
||||
return nil
|
||||
case KeyTypeKey:
|
||||
if err := h.KeyType.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, "invalid value for %s key", KeyTypeKey)
|
||||
}
|
||||
return nil
|
||||
case KeyUsageKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyUsage = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", KeyUsageKey, value)
|
||||
case PrivateParamsKey:
|
||||
if v, ok := value.(map[string]interface{}); ok {
|
||||
h.PrivateParams = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf("invalid value for %s key: %T", PrivateParamsKey, value)
|
||||
default:
|
||||
return errors.Errorf(`invalid key: %s`, name)
|
||||
}
|
||||
}
|
||||
|
||||
// Walk iterates over all JWK standard headers fields while applying a function to its value.
|
||||
func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
|
||||
for _, key := range []string{AlgorithmKey, KeyIDKey, KeyOpsKey, KeyTypeKey, KeyUsageKey, PrivateParamsKey} {
|
||||
if v, ok := h.Get(key); ok {
|
||||
if err := f(key, v); err != nil {
|
||||
return errors.Wrapf(err, `walk function returned error for %s`, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range h.PrivateParams {
|
||||
if err := f(k, v); err != nil {
|
||||
return errors.Wrapf(err, `walk function returned error for %s`, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
70
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/interface.go
generated
vendored
Normal file
70
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/interface.go
generated
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Set is a convenience struct to allow generating and parsing
|
||||
// JWK sets as opposed to single JWKs
|
||||
type Set struct {
|
||||
Keys []Key `json:"keys"`
|
||||
}
|
||||
|
||||
// Key defines the minimal interface for each of the
|
||||
// key types. Their use and implementation differ significantly
|
||||
// between each key types, so you should use type assertions
|
||||
// to perform more specific tasks with each key
|
||||
type Key interface {
|
||||
Headers
|
||||
|
||||
// Materialize creates the corresponding key. For example,
|
||||
// RSA types would create *rsa.PublicKey or *rsa.PrivateKey,
|
||||
// EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey,
|
||||
// and OctetSeq types create a []byte key.
|
||||
Materialize() (interface{}, error)
|
||||
GenerateKey(*RawKeyJSON) error
|
||||
}
|
||||
|
||||
// RawKeyJSON is generic type that represents any kind JWK
|
||||
type RawKeyJSON struct {
|
||||
StandardHeaders
|
||||
jwa.AlgorithmParameters
|
||||
}
|
||||
|
||||
// RawKeySetJSON is generic type that represents a JWK Set
|
||||
type RawKeySetJSON struct {
|
||||
Keys []RawKeyJSON `json:"keys"`
|
||||
}
|
||||
|
||||
// RSAPublicKey is a type of JWK generated from RSA public keys
|
||||
type RSAPublicKey struct {
|
||||
*StandardHeaders
|
||||
key *rsa.PublicKey
|
||||
}
|
||||
|
||||
// RSAPrivateKey is a type of JWK generated from RSA private keys
|
||||
type RSAPrivateKey struct {
|
||||
*StandardHeaders
|
||||
key *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// SymmetricKey is a type of JWK generated from symmetric keys
|
||||
type SymmetricKey struct {
|
||||
*StandardHeaders
|
||||
key []byte
|
||||
}
|
||||
|
||||
// ECDSAPublicKey is a type of JWK generated from ECDSA public keys
|
||||
type ECDSAPublicKey struct {
|
||||
*StandardHeaders
|
||||
key *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
// ECDSAPrivateKey is a type of JWK generated from ECDH-ES private keys
|
||||
type ECDSAPrivateKey struct {
|
||||
*StandardHeaders
|
||||
key *ecdsa.PrivateKey
|
||||
}
|
||||
150
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/jwk.go
generated
vendored
Normal file
150
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/jwk.go
generated
vendored
Normal file
@@ -0,0 +1,150 @@
|
||||
// Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// GetPublicKey returns the public key based on the private key type.
|
||||
// For rsa key types *rsa.PublicKey is returned; for ecdsa key types *ecdsa.PublicKey;
|
||||
// for byte slice (raw) keys, the key itself is returned. If the corresponding
|
||||
// public key cannot be deduced, an error is returned
|
||||
func GetPublicKey(key interface{}) (interface{}, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`jwk.New requires a non-nil key`)
|
||||
}
|
||||
|
||||
switch v := key.(type) {
|
||||
// Mental note: although Public() is defined in both types,
|
||||
// you can not coalesce the clauses for rsa.PrivateKey and
|
||||
// ecdsa.PrivateKey, as then `v` becomes interface{}
|
||||
// b/c the compiler cannot deduce the exact type.
|
||||
case *rsa.PrivateKey:
|
||||
return v.Public(), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return v.Public(), nil
|
||||
case []byte:
|
||||
return v, nil
|
||||
default:
|
||||
return nil, errors.Errorf(`invalid key type %T`, key)
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeyTypeFromKey creates a jwk.Key from the given key.
|
||||
func GetKeyTypeFromKey(key interface{}) jwa.KeyType {
|
||||
|
||||
switch key.(type) {
|
||||
case *rsa.PrivateKey, *rsa.PublicKey:
|
||||
return jwa.RSA
|
||||
case *ecdsa.PrivateKey, *ecdsa.PublicKey:
|
||||
return jwa.EC
|
||||
case []byte:
|
||||
return jwa.OctetSeq
|
||||
default:
|
||||
return jwa.InvalidKeyType
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a jwk.Key from the given key.
|
||||
func New(key interface{}) (Key, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`jwk.New requires a non-nil key`)
|
||||
}
|
||||
|
||||
switch v := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return newRSAPrivateKey(v)
|
||||
case *rsa.PublicKey:
|
||||
return newRSAPublicKey(v)
|
||||
case *ecdsa.PrivateKey:
|
||||
return newECDSAPrivateKey(v)
|
||||
case *ecdsa.PublicKey:
|
||||
return newECDSAPublicKey(v)
|
||||
case []byte:
|
||||
return newSymmetricKey(v)
|
||||
default:
|
||||
return nil, errors.Errorf(`invalid key type %T`, key)
|
||||
}
|
||||
}
|
||||
|
||||
func parse(jwkSrc string) (*Set, error) {
|
||||
|
||||
var jwkKeySet Set
|
||||
var jwkKey Key
|
||||
rawKeySetJSON := &RawKeySetJSON{}
|
||||
err := json.Unmarshal([]byte(jwkSrc), rawKeySetJSON)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal JWK Set")
|
||||
}
|
||||
if len(rawKeySetJSON.Keys) == 0 {
|
||||
|
||||
// It might be a single key
|
||||
rawKeyJSON := &RawKeyJSON{}
|
||||
err := json.Unmarshal([]byte(jwkSrc), rawKeyJSON)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to unmarshal JWK")
|
||||
}
|
||||
jwkKey, err = rawKeyJSON.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key")
|
||||
}
|
||||
// Add to set
|
||||
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
|
||||
} else {
|
||||
for i := range rawKeySetJSON.Keys {
|
||||
rawKeyJSON := rawKeySetJSON.Keys[i]
|
||||
jwkKey, err = rawKeyJSON.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key: %s")
|
||||
}
|
||||
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
|
||||
}
|
||||
}
|
||||
return &jwkKeySet, nil
|
||||
}
|
||||
|
||||
// ParseBytes parses JWK from the incoming byte buffer.
|
||||
func ParseBytes(buf []byte) (*Set, error) {
|
||||
return parse(string(buf[:]))
|
||||
}
|
||||
|
||||
// ParseString parses JWK from the incoming string.
|
||||
func ParseString(s string) (*Set, error) {
|
||||
return parse(s)
|
||||
}
|
||||
|
||||
// GenerateKey creates an internal representation of a key from a raw JWK JSON
|
||||
func (r *RawKeyJSON) GenerateKey() (Key, error) {
|
||||
|
||||
var key Key
|
||||
|
||||
switch r.KeyType {
|
||||
case jwa.RSA:
|
||||
if r.D != nil {
|
||||
key = &RSAPrivateKey{}
|
||||
} else {
|
||||
key = &RSAPublicKey{}
|
||||
}
|
||||
case jwa.EC:
|
||||
if r.D != nil {
|
||||
key = &ECDSAPrivateKey{}
|
||||
} else {
|
||||
key = &ECDSAPublicKey{}
|
||||
}
|
||||
case jwa.OctetSeq:
|
||||
key = &SymmetricKey{}
|
||||
default:
|
||||
return nil, errors.Errorf(`Unrecognized key type`)
|
||||
}
|
||||
err := key.GenerateKey(r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to generate key from JWK")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
68
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/key_ops.go
generated
vendored
Normal file
68
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/key_ops.go
generated
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// KeyUsageType is used to denote what this key should be used for
|
||||
type KeyUsageType string
|
||||
|
||||
const (
|
||||
// ForSignature is the value used in the headers to indicate that
|
||||
// this key should be used for signatures
|
||||
ForSignature KeyUsageType = "sig"
|
||||
// ForEncryption is the value used in the headers to indicate that
|
||||
// this key should be used for encryptiong
|
||||
ForEncryption KeyUsageType = "enc"
|
||||
)
|
||||
|
||||
// KeyOperation is used to denote the allowed operations for a Key
|
||||
type KeyOperation string
|
||||
|
||||
// KeyOperationList represents an slice of KeyOperation
|
||||
type KeyOperationList []KeyOperation
|
||||
|
||||
var keyOps = map[string]struct{}{"sign": {}, "verify": {}, "encrypt": {}, "decrypt": {}, "wrapKey": {}, "unwrapKey": {}, "deriveKey": {}, "deriveBits": {}}
|
||||
|
||||
// KeyOperation constants
|
||||
const (
|
||||
KeyOpSign KeyOperation = "sign" // (compute digital signature or MAC)
|
||||
KeyOpVerify = "verify" // (verify digital signature or MAC)
|
||||
KeyOpEncrypt = "encrypt" // (encrypt content)
|
||||
KeyOpDecrypt = "decrypt" // (decrypt content and validate decryption, if applicable)
|
||||
KeyOpWrapKey = "wrapKey" // (encrypt key)
|
||||
KeyOpUnwrapKey = "unwrapKey" // (decrypt key and validate decryption, if applicable)
|
||||
KeyOpDeriveKey = "deriveKey" // (derive key)
|
||||
KeyOpDeriveBits = "deriveBits" // (derive bits not to be used as a key)
|
||||
)
|
||||
|
||||
// Accept determines if Key Operation is valid
|
||||
func (keyOperationList *KeyOperationList) Accept(v interface{}) error {
|
||||
switch x := v.(type) {
|
||||
case KeyOperationList:
|
||||
*keyOperationList = x
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf(`invalid value %T`, v)
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
|
||||
func (keyOperationList *KeyOperationList) UnmarshalJSON(data []byte) error {
|
||||
var tempKeyOperationList []string
|
||||
err := json.Unmarshal(data, &tempKeyOperationList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid key operation")
|
||||
}
|
||||
for _, value := range tempKeyOperationList {
|
||||
_, ok := keyOps[value]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key operation")
|
||||
}
|
||||
*keyOperationList = append(*keyOperationList, KeyOperation(value))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
103
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/rsa.go
generated
vendored
Normal file
103
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/rsa.go
generated
vendored
Normal file
@@ -0,0 +1,103 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newRSAPublicKey(key *rsa.PublicKey) (*RSAPublicKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.RSA)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &RSAPublicKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newRSAPrivateKey(key *rsa.PrivateKey) (*RSAPrivateKey, error) {
|
||||
|
||||
var hdr StandardHeaders
|
||||
err := hdr.Set(KeyTypeKey, jwa.RSA)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &RSAPrivateKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the standard RSA Public Key representation stored in the internal representation
|
||||
func (k *RSAPublicKey) Materialize() (interface{}, error) {
|
||||
if k.key == nil {
|
||||
return nil, errors.New(`key has no rsa.PublicKey associated with it`)
|
||||
}
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// Materialize returns the standard RSA Private Key representation stored in the internal representation
|
||||
func (k *RSAPrivateKey) Materialize() (interface{}, error) {
|
||||
if k.key == nil {
|
||||
return nil, errors.New(`key has no rsa.PrivateKey associated with it`)
|
||||
}
|
||||
return k.key, nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
|
||||
func (k *RSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
if keyJSON.N == nil || keyJSON.E == nil {
|
||||
return errors.Errorf("Missing mandatory key parameters N or E")
|
||||
}
|
||||
rsaPublicKey := &rsa.PublicKey{
|
||||
N: (&big.Int{}).SetBytes(keyJSON.N.Bytes()),
|
||||
E: int((&big.Int{}).SetBytes(keyJSON.E.Bytes()).Int64()),
|
||||
}
|
||||
k.key = rsaPublicKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
|
||||
func (k *RSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
rsaPublicKey := &RSAPublicKey{}
|
||||
err := rsaPublicKey.GenerateKey(keyJSON)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to generate public key")
|
||||
}
|
||||
|
||||
if keyJSON.D == nil || keyJSON.P == nil || keyJSON.Q == nil {
|
||||
return errors.Errorf("Missing mandatory key parameters D, P or Q")
|
||||
}
|
||||
privateKey := &rsa.PrivateKey{
|
||||
PublicKey: *rsaPublicKey.key,
|
||||
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
|
||||
Primes: []*big.Int{
|
||||
(&big.Int{}).SetBytes(keyJSON.P.Bytes()),
|
||||
(&big.Int{}).SetBytes(keyJSON.Q.Bytes()),
|
||||
},
|
||||
}
|
||||
|
||||
if keyJSON.Dp.Len() > 0 {
|
||||
privateKey.Precomputed.Dp = (&big.Int{}).SetBytes(keyJSON.Dp.Bytes())
|
||||
}
|
||||
if keyJSON.Dq.Len() > 0 {
|
||||
privateKey.Precomputed.Dq = (&big.Int{}).SetBytes(keyJSON.Dq.Bytes())
|
||||
}
|
||||
if keyJSON.Qi.Len() > 0 {
|
||||
privateKey.Precomputed.Qinv = (&big.Int{}).SetBytes(keyJSON.Qi.Bytes())
|
||||
}
|
||||
|
||||
k.key = privateKey
|
||||
k.StandardHeaders = &keyJSON.StandardHeaders
|
||||
return nil
|
||||
}
|
||||
41
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/symmetric.go
generated
vendored
Normal file
41
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jwk/symmetric.go
generated
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
func newSymmetricKey(key []byte) (*SymmetricKey, error) {
|
||||
var hdr StandardHeaders
|
||||
|
||||
err := hdr.Set(KeyTypeKey, jwa.OctetSeq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Failed to set Key Type")
|
||||
}
|
||||
return &SymmetricKey{
|
||||
StandardHeaders: &hdr,
|
||||
key: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Materialize returns the octets for this symmetric key.
|
||||
// Since this is a symmetric key, this just calls Octets
|
||||
func (s SymmetricKey) Materialize() (interface{}, error) {
|
||||
return s.Octets(), nil
|
||||
}
|
||||
|
||||
// Octets returns the octets in the key
|
||||
func (s SymmetricKey) Octets() []byte {
|
||||
return s.key
|
||||
}
|
||||
|
||||
// GenerateKey creates a Symmetric key from a RawKeyJSON
|
||||
func (s *SymmetricKey) GenerateKey(keyJSON *RawKeyJSON) error {
|
||||
|
||||
*s = SymmetricKey{
|
||||
StandardHeaders: &keyJSON.StandardHeaders,
|
||||
key: keyJSON.K,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
154
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/headers.go
generated
vendored
Normal file
154
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/headers.go
generated
vendored
Normal file
@@ -0,0 +1,154 @@
|
||||
package jws
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Constants for JWS Common parameters
|
||||
const (
|
||||
AlgorithmKey = "alg"
|
||||
ContentTypeKey = "cty"
|
||||
CriticalKey = "crit"
|
||||
JWKKey = "jwk"
|
||||
JWKSetURLKey = "jku"
|
||||
KeyIDKey = "kid"
|
||||
PrivateParamsKey = "privateParams"
|
||||
TypeKey = "typ"
|
||||
)
|
||||
|
||||
// Headers provides a common interface for common header parameters
|
||||
type Headers interface {
|
||||
Get(string) (interface{}, bool)
|
||||
Set(string, interface{}) error
|
||||
GetAlgorithm() jwa.SignatureAlgorithm
|
||||
}
|
||||
|
||||
// StandardHeaders contains JWS common parameters.
|
||||
type StandardHeaders struct {
|
||||
Algorithm jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.1
|
||||
ContentType string `json:"cty,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.10
|
||||
Critical []string `json:"crit,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.11
|
||||
JWK string `json:"jwk,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.3
|
||||
JWKSetURL string `json:"jku,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.2
|
||||
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
|
||||
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
|
||||
Type string `json:"typ,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
|
||||
}
|
||||
|
||||
// GetAlgorithm returns algorithm
|
||||
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
|
||||
return h.Algorithm
|
||||
}
|
||||
|
||||
// Get is a general getter function for StandardHeaders structure
|
||||
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
v := h.Algorithm
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case ContentTypeKey:
|
||||
v := h.ContentType
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case CriticalKey:
|
||||
v := h.Critical
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case JWKKey:
|
||||
v := h.JWK
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case JWKSetURLKey:
|
||||
v := h.JWKSetURL
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case KeyIDKey:
|
||||
v := h.KeyID
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case PrivateParamsKey:
|
||||
v := h.PrivateParams
|
||||
if len(v) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
case TypeKey:
|
||||
v := h.Type
|
||||
if v == "" {
|
||||
return nil, false
|
||||
}
|
||||
return v, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// Set is a general setter function for StandardHeaders structure
|
||||
func (h *StandardHeaders) Set(name string, value interface{}) error {
|
||||
switch name {
|
||||
case AlgorithmKey:
|
||||
if err := h.Algorithm.Accept(value); err != nil {
|
||||
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
|
||||
}
|
||||
return nil
|
||||
case ContentTypeKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.ContentType = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
|
||||
case CriticalKey:
|
||||
if v, ok := value.([]string); ok {
|
||||
h.Critical = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
|
||||
case JWKKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.JWK = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
|
||||
case JWKSetURLKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.JWKSetURL = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
|
||||
case KeyIDKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.KeyID = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
|
||||
case PrivateParamsKey:
|
||||
if v, ok := value.(map[string]interface{}); ok {
|
||||
h.PrivateParams = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, PrivateParamsKey, value)
|
||||
case TypeKey:
|
||||
if v, ok := value.(string); ok {
|
||||
h.Type = v
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
|
||||
default:
|
||||
return errors.Errorf(`invalid key: %s`, name)
|
||||
}
|
||||
}
|
||||
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/interface.go
generated
vendored
Normal file
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/interface.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package jws
|
||||
|
||||
// Message represents a full JWS encoded message. Flattened serialization
|
||||
// is not supported as a struct, but rather it's represented as a
|
||||
// Message struct with only one `Signature` element.
|
||||
//
|
||||
// Do not expect to use the Message object to verify or construct a
|
||||
// signed payloads with. You should only use this when you want to actually
|
||||
// want to programmatically view the contents for the full JWS Payload.
|
||||
//
|
||||
// To sign and verify, use the appropriate `SignWithOption()` nad `Verify()` functions
|
||||
type Message struct {
|
||||
Payload []byte `json:"payload"`
|
||||
Signatures []*Signature `json:"signatures,omitempty"`
|
||||
}
|
||||
|
||||
// Signature represents the headers and signature of a JWS message
|
||||
type Signature struct {
|
||||
Headers Headers `json:"header,omitempty"` // Unprotected Headers
|
||||
Protected Headers `json:"Protected,omitempty"` // Protected Headers
|
||||
Signature []byte `json:"signature,omitempty"` // GetSignature
|
||||
}
|
||||
210
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/jws.go
generated
vendored
Normal file
210
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/jws.go
generated
vendored
Normal file
@@ -0,0 +1,210 @@
|
||||
// Package jws implements the digital Signature on JSON based data
|
||||
// structures as described in https://tools.ietf.org/html/rfc7515
|
||||
//
|
||||
// If you do not care about the details, the only things that you
|
||||
// would need to use are the following functions:
|
||||
//
|
||||
// jws.SignWithOption(Payload, algorithm, key)
|
||||
// jws.Verify(encodedjws, algorithm, key)
|
||||
//
|
||||
// To sign, simply use `jws.SignWithOption`. `Payload` is a []byte buffer that
|
||||
// contains whatever data you want to sign. `alg` is one of the
|
||||
// jwa.SignatureAlgorithm constants from package jwa. For RSA and
|
||||
// ECDSA family of algorithms, you will need to prepare a private key.
|
||||
// For HMAC family, you just need a []byte value. The `jws.SignWithOption`
|
||||
// function will return the encoded JWS message on success.
|
||||
//
|
||||
// To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
|
||||
// and verify the result using `algorithm` and `key`. Upon successful
|
||||
// verification, the original Payload is returned, so you can work on it.
|
||||
package jws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// SignLiteral generates a Signature for the given Payload and Headers, and serializes
|
||||
// it in compact serialization format. In this format you may NOT use
|
||||
// multiple signers.
|
||||
//
|
||||
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
|
||||
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
signingInput := strings.Join(
|
||||
[]string{
|
||||
encodedHdr,
|
||||
encodedPayload,
|
||||
}, ".",
|
||||
)
|
||||
signer, err := sign.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to create signer`)
|
||||
}
|
||||
signature, err := signer.Sign([]byte(signingInput), key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to sign Payload`)
|
||||
}
|
||||
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
|
||||
compactSerialization := strings.Join(
|
||||
[]string{
|
||||
signingInput,
|
||||
encodedSignature,
|
||||
}, ".",
|
||||
)
|
||||
return []byte(compactSerialization), nil
|
||||
}
|
||||
|
||||
// SignWithOption generates a Signature for the given Payload, and serializes
|
||||
// it in compact serialization format. In this format you may NOT use
|
||||
// multiple signers.
|
||||
//
|
||||
// If you would like to pass custom Headers, use the WithHeaders option.
|
||||
func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, error) {
|
||||
var headers Headers = &StandardHeaders{}
|
||||
|
||||
err := headers.Set(AlgorithmKey, alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to set alg value")
|
||||
}
|
||||
|
||||
hdrBuf, err := json.Marshal(headers)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to marshal Headers`)
|
||||
}
|
||||
return SignLiteral(payload, alg, key, hdrBuf)
|
||||
}
|
||||
|
||||
// Verify checks if the given JWS message is verifiable using `alg` and `key`.
|
||||
// If the verification is successful, `err` is nil, and the content of the
|
||||
// Payload that was signed is returned. If you need more fine-grained
|
||||
// control of the verification process, manually call `Parse`, generate a
|
||||
// verifier, and call `Verify` on the parsed JWS message object.
|
||||
func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}) (ret []byte, err error) {
|
||||
|
||||
verifier, err := verify.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create verifier")
|
||||
}
|
||||
|
||||
buf = bytes.TrimSpace(buf)
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New(`attempt to verify empty buffer`)
|
||||
}
|
||||
|
||||
parts, err := SplitCompact(string(buf[:]))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed extract from compact serialization format`)
|
||||
}
|
||||
|
||||
signingInput := strings.Join(
|
||||
[]string{
|
||||
parts[0],
|
||||
parts[1],
|
||||
}, ".",
|
||||
)
|
||||
|
||||
decodedSignature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode signature")
|
||||
}
|
||||
if err := verifier.Verify([]byte(signingInput), decodedSignature, key); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to verify message")
|
||||
}
|
||||
|
||||
if decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
return decodedPayload, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "Failed to decode Payload")
|
||||
}
|
||||
|
||||
// VerifyWithJWK verifies the JWS message using the specified JWK
|
||||
func VerifyWithJWK(buf []byte, key jwk.Key) (payload []byte, err error) {
|
||||
|
||||
keyVal, err := key.Materialize()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to materialize key")
|
||||
}
|
||||
return Verify(buf, key.GetAlgorithm(), keyVal)
|
||||
}
|
||||
|
||||
// VerifyWithJWKSet verifies the JWS message using JWK key set.
|
||||
// By default it will only pick up keys that have the "use" key
|
||||
// set to either "sig" or "enc", but you can override it by
|
||||
// providing a keyaccept function.
|
||||
func VerifyWithJWKSet(buf []byte, keyset *jwk.Set) (payload []byte, err error) {
|
||||
|
||||
for _, key := range keyset.Keys {
|
||||
payload, err := VerifyWithJWK(buf, key)
|
||||
if err == nil {
|
||||
return payload, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("failed to verify with any of the keys")
|
||||
}
|
||||
|
||||
// ParseByte parses a JWS value serialized via compact serialization and provided as []byte.
|
||||
func ParseByte(jwsCompact []byte) (m *Message, err error) {
|
||||
return parseCompact(string(jwsCompact[:]))
|
||||
}
|
||||
|
||||
// ParseString parses a JWS value serialized via compact serialization and provided as string.
|
||||
func ParseString(s string) (*Message, error) {
|
||||
return parseCompact(s)
|
||||
}
|
||||
|
||||
// SplitCompact splits a JWT and returns its three parts
|
||||
// separately: Protected Headers, Payload and Signature.
|
||||
func SplitCompact(jwsCompact string) ([]string, error) {
|
||||
|
||||
parts := strings.Split(jwsCompact, ".")
|
||||
if len(parts) < 3 {
|
||||
return nil, errors.New("Failed to split compact serialization")
|
||||
}
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// parseCompact parses a JWS value serialized via compact serialization.
|
||||
func parseCompact(str string) (m *Message, err error) {
|
||||
|
||||
var decodedHeader, decodedPayload, decodedSignature []byte
|
||||
parts, err := SplitCompact(str)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `invalid compact serialization format`)
|
||||
}
|
||||
|
||||
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Headers`)
|
||||
}
|
||||
var hdr StandardHeaders
|
||||
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to parse JOSE Headers`)
|
||||
}
|
||||
|
||||
if decodedPayload, err = base64.RawURLEncoding.DecodeString(parts[1]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Payload`)
|
||||
}
|
||||
|
||||
if len(parts) > 2 {
|
||||
if decodedSignature, err = base64.RawURLEncoding.DecodeString(parts[2]); err != nil {
|
||||
return nil, errors.Wrap(err, `failed to decode Signature`)
|
||||
}
|
||||
}
|
||||
|
||||
var msg Message
|
||||
msg.Payload = decodedPayload
|
||||
msg.Signatures = append(msg.Signatures, &Signature{
|
||||
Protected: &hdr,
|
||||
Signature: decodedSignature,
|
||||
})
|
||||
return &msg, nil
|
||||
}
|
||||
26
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/message.go
generated
vendored
Normal file
26
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/message.go
generated
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
package jws
|
||||
|
||||
// PublicHeaders returns the public headers in a JWS
|
||||
func (s Signature) PublicHeaders() Headers {
|
||||
return s.Headers
|
||||
}
|
||||
|
||||
// ProtectedHeaders returns the protected headers in a JWS
|
||||
func (s Signature) ProtectedHeaders() Headers {
|
||||
return s.Protected
|
||||
}
|
||||
|
||||
// GetSignature returns the signature in a JWS
|
||||
func (s Signature) GetSignature() []byte {
|
||||
return s.Signature
|
||||
}
|
||||
|
||||
// GetPayload returns the payload in a JWS
|
||||
func (m Message) GetPayload() []byte {
|
||||
return m.Payload
|
||||
}
|
||||
|
||||
// GetSignatures returns the all signatures in a JWS
|
||||
func (m Message) GetSignatures() []*Signature {
|
||||
return m.Signatures
|
||||
}
|
||||
84
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/ecdsa.go
generated
vendored
Normal file
84
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/ecdsa.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var ecdsaSignFuncs = map[jwa.SignatureAlgorithm]ecdsaSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
|
||||
jwa.ES256: crypto.SHA256,
|
||||
jwa.ES384: crypto.SHA384,
|
||||
jwa.ES512: crypto.SHA512,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
ecdsaSignFuncs[alg] = makeECDSASignFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
|
||||
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
curveBits := key.Curve.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
// Curve bits do not need to be a multiple of 8.
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
|
||||
}
|
||||
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
return out, nil
|
||||
})
|
||||
}
|
||||
|
||||
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSASigner, error) {
|
||||
signfn, ok := ecdsaSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA signer: %s`, alg)
|
||||
}
|
||||
|
||||
return &ECDSASigner{
|
||||
alg: alg,
|
||||
sign: signfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign signs payload with a ECDSA private key
|
||||
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`missing private key while signing payload`)
|
||||
}
|
||||
|
||||
privateKey, ok := key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
|
||||
}
|
||||
|
||||
return s.sign(payload, privateKey)
|
||||
}
|
||||
66
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/hmac.go
generated
vendored
Normal file
66
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/hmac.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"hash"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var hmacSignFuncs = map[jwa.SignatureAlgorithm]hmacSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]func() hash.Hash{
|
||||
jwa.HS256: sha256.New,
|
||||
jwa.HS384: sha512.New384,
|
||||
jwa.HS512: sha512.New,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
hmacSignFuncs[alg] = makeHMACSignFunc(h)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACSigner, error) {
|
||||
signer, ok := hmacSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create HMAC signer: %s`, alg)
|
||||
}
|
||||
|
||||
return &HMACSigner{
|
||||
alg: alg,
|
||||
sign: signer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func makeHMACSignFunc(hfunc func() hash.Hash) hmacSignFunc {
|
||||
return hmacSignFunc(func(payload []byte, key []byte) ([]byte, error) {
|
||||
h := hmac.New(hfunc, key)
|
||||
h.Write(payload)
|
||||
return h.Sum(nil), nil
|
||||
})
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s HMACSigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign signs payload with a Symmetric key
|
||||
func (s HMACSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
hmackey, ok := key.([]byte)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. []byte is required`, key)
|
||||
}
|
||||
|
||||
if len(hmackey) == 0 {
|
||||
return nil, errors.New(`missing key while signing payload`)
|
||||
}
|
||||
|
||||
return s.sign(payload, hmackey)
|
||||
}
|
||||
45
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/interface.go
generated
vendored
Normal file
45
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/interface.go
generated
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// Signer provides a common interface for supported alg signing methods
|
||||
type Signer interface {
|
||||
// Sign creates a signature for the given `payload`.
|
||||
// `key` is the key used for signing the payload, and is usually
|
||||
// the private key type associated with the signature method. For example,
|
||||
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
|
||||
// `*"crypto/rsa".PrivateKey` type.
|
||||
// Check the documentation for each signer for details
|
||||
Sign(payload []byte, key interface{}) ([]byte, error)
|
||||
|
||||
Algorithm() jwa.SignatureAlgorithm
|
||||
}
|
||||
|
||||
type rsaSignFunc func([]byte, *rsa.PrivateKey) ([]byte, error)
|
||||
|
||||
// RSASigner uses crypto/rsa to sign the payloads.
|
||||
type RSASigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign rsaSignFunc
|
||||
}
|
||||
|
||||
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
|
||||
|
||||
// ECDSASigner uses crypto/ecdsa to sign the payloads.
|
||||
type ECDSASigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign ecdsaSignFunc
|
||||
}
|
||||
|
||||
type hmacSignFunc func([]byte, []byte) ([]byte, error)
|
||||
|
||||
// HMACSigner uses crypto/hmac to sign the payloads.
|
||||
type HMACSigner struct {
|
||||
alg jwa.SignatureAlgorithm
|
||||
sign hmacSignFunc
|
||||
}
|
||||
97
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/rsa.go
generated
vendored
Normal file
97
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/rsa.go
generated
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var rsaSignFuncs = map[jwa.SignatureAlgorithm]rsaSignFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]struct {
|
||||
Hash crypto.Hash
|
||||
SignFunc func(crypto.Hash) rsaSignFunc
|
||||
}{
|
||||
jwa.RS256: {
|
||||
Hash: crypto.SHA256,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.RS384: {
|
||||
Hash: crypto.SHA384,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.RS512: {
|
||||
Hash: crypto.SHA512,
|
||||
SignFunc: makeSignPKCS1v15,
|
||||
},
|
||||
jwa.PS256: {
|
||||
Hash: crypto.SHA256,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
jwa.PS384: {
|
||||
Hash: crypto.SHA384,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
jwa.PS512: {
|
||||
Hash: crypto.SHA512,
|
||||
SignFunc: makeSignPSS,
|
||||
},
|
||||
}
|
||||
|
||||
for alg, item := range algs {
|
||||
rsaSignFuncs[alg] = item.SignFunc(item.Hash)
|
||||
}
|
||||
}
|
||||
|
||||
func makeSignPKCS1v15(hash crypto.Hash) rsaSignFunc {
|
||||
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.SignPKCS1v15(rand.Reader, key, hash, h.Sum(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func makeSignPSS(hash crypto.Hash) rsaSignFunc {
|
||||
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.SignPSS(rand.Reader, key, hash, h.Sum(nil), &rsa.PSSOptions{
|
||||
SaltLength: rsa.PSSSaltLengthAuto,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func newRSA(alg jwa.SignatureAlgorithm) (*RSASigner, error) {
|
||||
signfn, ok := rsaSignFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA signer: %s`, alg)
|
||||
}
|
||||
return &RSASigner{
|
||||
alg: alg,
|
||||
sign: signfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Algorithm returns the signer algorithm
|
||||
func (s RSASigner) Algorithm() jwa.SignatureAlgorithm {
|
||||
return s.alg
|
||||
}
|
||||
|
||||
// Sign creates a signature using crypto/rsa. key must be a non-nil instance of
|
||||
// `*"crypto/rsa".PrivateKey`.
|
||||
func (s RSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
||||
if key == nil {
|
||||
return nil, errors.New(`missing private key while signing payload`)
|
||||
}
|
||||
rsakey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`invalid key type %T. *rsa.PrivateKey is required`, key)
|
||||
}
|
||||
|
||||
return s.sign(payload, rsakey)
|
||||
}
|
||||
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/sign.go
generated
vendored
Normal file
21
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign/sign.go
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// New creates a signer that signs payloads using the given signature algorithm.
|
||||
func New(alg jwa.SignatureAlgorithm) (Signer, error) {
|
||||
switch alg {
|
||||
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
|
||||
return newRSA(alg)
|
||||
case jwa.ES256, jwa.ES384, jwa.ES512:
|
||||
return newECDSA(alg)
|
||||
case jwa.HS256, jwa.HS384, jwa.HS512:
|
||||
return newHMAC(alg)
|
||||
default:
|
||||
return nil, errors.Errorf(`unsupported signature algorithm %s`, alg)
|
||||
}
|
||||
}
|
||||
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/ecdsa.go
generated
vendored
Normal file
67
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/ecdsa.go
generated
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"math/big"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
var ecdsaVerifyFuncs = map[jwa.SignatureAlgorithm]ecdsaVerifyFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
|
||||
jwa.ES256: crypto.SHA256,
|
||||
jwa.ES384: crypto.SHA384,
|
||||
jwa.ES512: crypto.SHA512,
|
||||
}
|
||||
|
||||
for alg, h := range algs {
|
||||
ecdsaVerifyFuncs[alg] = makeECDSAVerifyFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func makeECDSAVerifyFunc(hash crypto.Hash) ecdsaVerifyFunc {
|
||||
return ecdsaVerifyFunc(func(payload []byte, signature []byte, key *ecdsa.PublicKey) error {
|
||||
|
||||
r, s := &big.Int{}, &big.Int{}
|
||||
n := len(signature) / 2
|
||||
r.SetBytes(signature[:n])
|
||||
s.SetBytes(signature[n:])
|
||||
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
|
||||
if !ecdsa.Verify(key, h.Sum(nil), r, s) {
|
||||
return errors.New(`failed to verify signature using ecdsa`)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSAVerifier, error) {
|
||||
verifyfn, ok := ecdsaVerifyFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA verifier: %s`, alg)
|
||||
}
|
||||
|
||||
return &ECDSAVerifier{
|
||||
verify: verifyfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify checks whether the signature for a given input and key is correct
|
||||
func (v ECDSAVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
|
||||
if key == nil {
|
||||
return errors.New(`missing public key while verifying payload`)
|
||||
}
|
||||
ecdsakey, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.Errorf(`invalid key type %T. *ecdsa.PublicKey is required`, key)
|
||||
}
|
||||
|
||||
return v.verify(payload, signature, ecdsakey)
|
||||
}
|
||||
33
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/hmac.go
generated
vendored
Normal file
33
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/hmac.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
)
|
||||
|
||||
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACVerifier, error) {
|
||||
|
||||
s, err := sign.New(alg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, `failed to generate HMAC signer`)
|
||||
}
|
||||
return &HMACVerifier{signer: s}, nil
|
||||
}
|
||||
|
||||
// Verify checks whether the signature for a given input and key is correct
|
||||
func (v HMACVerifier) Verify(signingInput, signature []byte, key interface{}) (err error) {
|
||||
|
||||
expected, err := v.signer.Sign(signingInput, key)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, `failed to generated signature`)
|
||||
}
|
||||
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return errors.New(`failed to match hmac signature`)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
39
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/interface.go
generated
vendored
Normal file
39
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/interface.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
|
||||
)
|
||||
|
||||
// Verifier provides a common interface for supported alg verification methods
|
||||
type Verifier interface {
|
||||
// Verify checks whether the payload and signature are valid for
|
||||
// the given key.
|
||||
// `key` is the key used for verifying the payload, and is usually
|
||||
// the public key associated with the signature method. For example,
|
||||
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
|
||||
// `*"crypto/rsa".PublicKey` type.
|
||||
// Check the documentation for each verifier for details
|
||||
Verify(payload []byte, signature []byte, key interface{}) error
|
||||
}
|
||||
|
||||
type rsaVerifyFunc func([]byte, []byte, *rsa.PublicKey) error
|
||||
|
||||
// RSAVerifier implements the Verifier interface
|
||||
type RSAVerifier struct {
|
||||
verify rsaVerifyFunc
|
||||
}
|
||||
|
||||
type ecdsaVerifyFunc func([]byte, []byte, *ecdsa.PublicKey) error
|
||||
|
||||
// ECDSAVerifier implements the Verifier interface
|
||||
type ECDSAVerifier struct {
|
||||
verify ecdsaVerifyFunc
|
||||
}
|
||||
|
||||
// HMACVerifier implements the Verifier interface
|
||||
type HMACVerifier struct {
|
||||
signer sign.Signer
|
||||
}
|
||||
88
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/rsa.go
generated
vendored
Normal file
88
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/rsa.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rsa"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var rsaVerifyFuncs = map[jwa.SignatureAlgorithm]rsaVerifyFunc{}
|
||||
|
||||
func init() {
|
||||
algs := map[jwa.SignatureAlgorithm]struct {
|
||||
Hash crypto.Hash
|
||||
VerifyFunc func(crypto.Hash) rsaVerifyFunc
|
||||
}{
|
||||
jwa.RS256: {
|
||||
Hash: crypto.SHA256,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.RS384: {
|
||||
Hash: crypto.SHA384,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.RS512: {
|
||||
Hash: crypto.SHA512,
|
||||
VerifyFunc: makeVerifyPKCS1v15,
|
||||
},
|
||||
jwa.PS256: {
|
||||
Hash: crypto.SHA256,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
jwa.PS384: {
|
||||
Hash: crypto.SHA384,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
jwa.PS512: {
|
||||
Hash: crypto.SHA512,
|
||||
VerifyFunc: makeVerifyPSS,
|
||||
},
|
||||
}
|
||||
|
||||
for alg, item := range algs {
|
||||
rsaVerifyFuncs[alg] = item.VerifyFunc(item.Hash)
|
||||
}
|
||||
}
|
||||
|
||||
func makeVerifyPKCS1v15(hash crypto.Hash) rsaVerifyFunc {
|
||||
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.VerifyPKCS1v15(key, hash, h.Sum(nil), signature)
|
||||
})
|
||||
}
|
||||
|
||||
func makeVerifyPSS(hash crypto.Hash) rsaVerifyFunc {
|
||||
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return rsa.VerifyPSS(key, hash, h.Sum(nil), signature, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func newRSA(alg jwa.SignatureAlgorithm) (*RSAVerifier, error) {
|
||||
verifyfn, ok := rsaVerifyFuncs[alg]
|
||||
if !ok {
|
||||
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA verifier: %s`, alg)
|
||||
}
|
||||
|
||||
return &RSAVerifier{
|
||||
verify: verifyfn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify checks if a JWS is valid.
|
||||
func (v RSAVerifier) Verify(payload, signature []byte, key interface{}) error {
|
||||
if key == nil {
|
||||
return errors.New(`missing public key while verifying payload`)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return errors.Errorf(`invalid key type %T. *rsa.PublicKey is required`, key)
|
||||
}
|
||||
|
||||
return v.verify(payload, signature, rsaKey)
|
||||
}
|
||||
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/verify.go
generated
vendored
Normal file
22
vendor/github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify/verify.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package verify
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
|
||||
)
|
||||
|
||||
// New creates a new JWS verifier using the specified algorithm
|
||||
// and the public key
|
||||
func New(alg jwa.SignatureAlgorithm) (Verifier, error) {
|
||||
switch alg {
|
||||
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
|
||||
return newRSA(alg)
|
||||
case jwa.ES256, jwa.ES384, jwa.ES512:
|
||||
return newECDSA(alg)
|
||||
case jwa.HS256, jwa.HS384, jwa.HS512:
|
||||
return newHMAC(alg)
|
||||
default:
|
||||
return nil, errors.Errorf(`unsupported signature algorithm: %s`, alg)
|
||||
}
|
||||
}
|
||||
235
vendor/github.com/open-policy-agent/opa/topdown/json.go
generated
vendored
Normal file
235
vendor/github.com/open-policy-agent/opa/topdown/json.go
generated
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
// Copyright 2019 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinJSONRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
// Expect an object and a string or array/set of strings
|
||||
_, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build a list of json pointers to remove
|
||||
paths, err := getJSONPaths(operands[1].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newObj, err := jsonRemove(operands[0], ast.NewTerm(pathsToObject(paths)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if newObj == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return iter(newObj)
|
||||
}
|
||||
|
||||
// jsonRemove returns a new term that is the result of walking
|
||||
// through a and omitting removing any values that are in b but
|
||||
// have ast.Null values (ie leaf nodes for b).
|
||||
func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
|
||||
if b == nil {
|
||||
// The paths diverged, return a
|
||||
return a, nil
|
||||
}
|
||||
|
||||
var bObj ast.Object
|
||||
switch bValue := b.Value.(type) {
|
||||
case ast.Object:
|
||||
bObj = bValue
|
||||
case ast.Null:
|
||||
// Means we hit a leaf node on "b", dont add the value for a
|
||||
return nil, nil
|
||||
default:
|
||||
// The paths diverged, return a
|
||||
return a, nil
|
||||
}
|
||||
|
||||
switch aValue := a.Value.(type) {
|
||||
case ast.String, ast.Number, ast.Boolean, ast.Null:
|
||||
return a, nil
|
||||
case ast.Object:
|
||||
newObj := ast.NewObject()
|
||||
err := aValue.Iter(func(k *ast.Term, v *ast.Term) error {
|
||||
// recurse and add the diff of sub objects as needed
|
||||
diffValue, err := jsonRemove(v, bObj.Get(k))
|
||||
if err != nil || diffValue == nil {
|
||||
return err
|
||||
}
|
||||
newObj.Insert(k, diffValue)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.NewTerm(newObj), nil
|
||||
case ast.Set:
|
||||
newSet := ast.NewSet()
|
||||
err := aValue.Iter(func(v *ast.Term) error {
|
||||
// recurse and add the diff of sub objects as needed
|
||||
diffValue, err := jsonRemove(v, bObj.Get(v))
|
||||
if err != nil || diffValue == nil {
|
||||
return err
|
||||
}
|
||||
newSet.Add(diffValue)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.NewTerm(newSet), nil
|
||||
case ast.Array:
|
||||
// When indexes are removed we shift left to close empty spots in the array
|
||||
// as per the JSON patch spec.
|
||||
var newArray ast.Array
|
||||
for i, v := range aValue {
|
||||
// recurse and add the diff of sub objects as needed
|
||||
// Note: Keys in b will be strings for the index, eg path /a/1/b => {"a": {"1": {"b": null}}}
|
||||
diffValue, err := jsonRemove(v, bObj.Get(ast.StringTerm(strconv.Itoa(i))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if diffValue != nil {
|
||||
newArray = append(newArray, diffValue)
|
||||
}
|
||||
}
|
||||
return ast.NewTerm(newArray), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid value type %T", a)
|
||||
}
|
||||
}
|
||||
|
||||
func builtinJSONFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
// Ensure we have the right parameters, expect an object and a string or array/set of strings
|
||||
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build a list of filter strings
|
||||
filters, err := getJSONPaths(operands[1].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Actually do the filtering
|
||||
filterObj := pathsToObject(filters)
|
||||
r, err := obj.Filter(filterObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(r))
|
||||
}
|
||||
|
||||
func getJSONPaths(operand ast.Value) ([]ast.Ref, error) {
|
||||
var paths []ast.Ref
|
||||
|
||||
switch v := operand.(type) {
|
||||
case ast.Array:
|
||||
for _, f := range v {
|
||||
filter, err := parsePath(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paths = append(paths, filter)
|
||||
}
|
||||
case ast.Set:
|
||||
err := v.Iter(func(f *ast.Term) error {
|
||||
filter, err := parsePath(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paths = append(paths, filter)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(2, v, "set", "array")
|
||||
}
|
||||
|
||||
return paths, nil
|
||||
}
|
||||
|
||||
func parsePath(path *ast.Term) (ast.Ref, error) {
|
||||
// paths can either be a `/` separated json path or
|
||||
// an array or set of values
|
||||
var pathSegments ast.Ref
|
||||
switch p := path.Value.(type) {
|
||||
case ast.String:
|
||||
parts := strings.Split(strings.Trim(string(p), "/"), "/")
|
||||
for _, part := range parts {
|
||||
part = strings.ReplaceAll(strings.ReplaceAll(part, "~1", "/"), "~0", "~")
|
||||
pathSegments = append(pathSegments, ast.StringTerm(part))
|
||||
}
|
||||
case ast.Array:
|
||||
for _, term := range p {
|
||||
pathSegments = append(pathSegments, term)
|
||||
}
|
||||
default:
|
||||
return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p))
|
||||
}
|
||||
|
||||
return pathSegments, nil
|
||||
}
|
||||
|
||||
func pathsToObject(paths []ast.Ref) ast.Object {
|
||||
|
||||
root := ast.NewObject()
|
||||
|
||||
for _, path := range paths {
|
||||
node := root
|
||||
var done bool
|
||||
|
||||
for i := 0; i < len(path)-1 && !done; i++ {
|
||||
|
||||
k := path[i]
|
||||
child := node.Get(k)
|
||||
|
||||
if child == nil {
|
||||
obj := ast.NewObject()
|
||||
node.Insert(k, ast.NewTerm(obj))
|
||||
node = obj
|
||||
continue
|
||||
}
|
||||
|
||||
switch v := child.Value.(type) {
|
||||
case ast.Null:
|
||||
done = true
|
||||
case ast.Object:
|
||||
node = v
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
if !done {
|
||||
node.Insert(path[len(path)-1], ast.NullTerm())
|
||||
}
|
||||
}
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.JSONFilter.Name, builtinJSONFilter)
|
||||
RegisterBuiltinFunc(ast.JSONRemove.Name, builtinJSONRemove)
|
||||
}
|
||||
139
vendor/github.com/open-policy-agent/opa/topdown/object.go
generated
vendored
Normal file
139
vendor/github.com/open-policy-agent/opa/topdown/object.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright 2020 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/types"
|
||||
)
|
||||
|
||||
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
objA, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
objB, err := builtins.ObjectOperand(operands[1].Value, 2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r := mergeWithOverwrite(objA, objB)
|
||||
|
||||
return iter(ast.NewTerm(r))
|
||||
}
|
||||
|
||||
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// Expect an object and an array/set/object of keys
|
||||
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build a set of keys to remove
|
||||
keysToRemove, err := getObjectKeysParam(operands[1].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r := ast.NewObject()
|
||||
obj.Foreach(func(key *ast.Term, value *ast.Term) {
|
||||
if !keysToRemove.Contains(key) {
|
||||
r.Insert(key, value)
|
||||
}
|
||||
})
|
||||
|
||||
return iter(ast.NewTerm(r))
|
||||
}
|
||||
|
||||
func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
// Expect an object and an array/set/object of keys
|
||||
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build a new object from the supplied filter keys
|
||||
keys, err := getObjectKeysParam(operands[1].Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filterObj := ast.NewObject()
|
||||
keys.Foreach(func(key *ast.Term) {
|
||||
filterObj.Insert(key, ast.NullTerm())
|
||||
})
|
||||
|
||||
// Actually do the filtering
|
||||
r, err := obj.Filter(filterObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return iter(ast.NewTerm(r))
|
||||
}
|
||||
|
||||
func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
||||
object, err := builtins.ObjectOperand(operands[0].Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ret := object.Get(operands[1]); ret != nil {
|
||||
return iter(ret)
|
||||
}
|
||||
|
||||
return iter(operands[2])
|
||||
}
|
||||
|
||||
// getObjectKeysParam returns a set of key values
|
||||
// from a supplied ast array, object, set value
|
||||
func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
|
||||
keys := ast.NewSet()
|
||||
|
||||
switch v := arrayOrSet.(type) {
|
||||
case ast.Array:
|
||||
for _, f := range v {
|
||||
keys.Add(f)
|
||||
}
|
||||
case ast.Set:
|
||||
_ = v.Iter(func(f *ast.Term) error {
|
||||
keys.Add(f)
|
||||
return nil
|
||||
})
|
||||
case ast.Object:
|
||||
_ = v.Iter(func(k *ast.Term, _ *ast.Term) error {
|
||||
keys.Add(k)
|
||||
return nil
|
||||
})
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
|
||||
merged, _ := objA.MergeWith(objB, func(v1, v2 *ast.Term) (*ast.Term, bool) {
|
||||
originalValueObj, ok2 := v1.Value.(ast.Object)
|
||||
updateValueObj, ok1 := v2.Value.(ast.Object)
|
||||
if !ok1 || !ok2 {
|
||||
// If we can't merge, stick with the right-hand value
|
||||
return v2, false
|
||||
}
|
||||
|
||||
// Recursively update the existing value
|
||||
merged := mergeWithOverwrite(originalValueObj, updateValueObj)
|
||||
return ast.NewTerm(merged), false
|
||||
})
|
||||
return merged
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
|
||||
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
|
||||
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
|
||||
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
|
||||
}
|
||||
47
vendor/github.com/open-policy-agent/opa/topdown/parse.go
generated
vendored
Normal file
47
vendor/github.com/open-policy-agent/opa/topdown/parse.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinRegoParseModule(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
filename, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
input, err := builtins.StringOperand(b, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
module, err := ast.ParseModule(string(filename), string(input))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(module); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
term, err := ast.ParseTerm(buf.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return term.Value, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.RegoParseModule.Name, builtinRegoParseModule)
|
||||
}
|
||||
153
vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go
generated
vendored
Normal file
153
vendor/github.com/open-policy-agent/opa/topdown/parse_bytes.go
generated
vendored
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
const (
|
||||
none int64 = 1
|
||||
kb = 1000
|
||||
ki = 1024
|
||||
mb = kb * 1000
|
||||
mi = ki * 1024
|
||||
gb = mb * 1000
|
||||
gi = mi * 1024
|
||||
tb = gb * 1000
|
||||
ti = gi * 1024
|
||||
)
|
||||
|
||||
// The rune values for 0..9 as well as the period symbol (for parsing floats)
|
||||
var numRunes = []rune("0123456789.")
|
||||
|
||||
func parseNumBytesError(msg string) error {
|
||||
return fmt.Errorf("%s error: %s", ast.UnitsParseBytes.Name, msg)
|
||||
}
|
||||
|
||||
func errUnitNotRecognized(unit string) error {
|
||||
return parseNumBytesError(fmt.Sprintf("byte unit %s not recognized", unit))
|
||||
}
|
||||
|
||||
var (
|
||||
errNoAmount = parseNumBytesError("no byte amount provided")
|
||||
errIntConv = parseNumBytesError("could not parse byte amount to integer")
|
||||
errIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
|
||||
)
|
||||
|
||||
func builtinNumBytes(a ast.Value) (ast.Value, error) {
|
||||
var m int64
|
||||
|
||||
raw, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := formatString(raw)
|
||||
|
||||
if strings.Contains(s, " ") {
|
||||
return nil, errIncludesSpaces
|
||||
}
|
||||
|
||||
numStr, unitStr := extractNumAndUnit(s)
|
||||
|
||||
if numStr == "" {
|
||||
return nil, errNoAmount
|
||||
}
|
||||
|
||||
switch unitStr {
|
||||
case "":
|
||||
m = none
|
||||
case "kb":
|
||||
m = kb
|
||||
case "kib":
|
||||
m = ki
|
||||
case "mb":
|
||||
m = mb
|
||||
case "mib":
|
||||
m = mi
|
||||
case "gb":
|
||||
m = gb
|
||||
case "gib":
|
||||
m = gi
|
||||
case "tb":
|
||||
m = tb
|
||||
case "tib":
|
||||
m = ti
|
||||
default:
|
||||
return nil, errUnitNotRecognized(unitStr)
|
||||
}
|
||||
|
||||
num, err := strconv.ParseInt(numStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, errIntConv
|
||||
}
|
||||
|
||||
total := num * m
|
||||
|
||||
return builtins.IntToNumber(big.NewInt(total)), nil
|
||||
}
|
||||
|
||||
// Makes the string lower case and removes spaces and quotation marks
|
||||
func formatString(s ast.String) string {
|
||||
str := string(s)
|
||||
lower := strings.ToLower(str)
|
||||
return strings.Replace(lower, "\"", "", -1)
|
||||
}
|
||||
|
||||
// Splits the string into a number string à la "10" or "10.2" and a unit string à la "gb" or "MiB" or "foo". Either
|
||||
// can be an empty string (error handling is provided elsewhere).
|
||||
func extractNumAndUnit(s string) (string, string) {
|
||||
isNum := func(r rune) (isNum bool) {
|
||||
for _, nr := range numRunes {
|
||||
if nr == r {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Returns the index of the first rune that's not a number (or 0 if there are only numbers)
|
||||
getFirstNonNumIdx := func(s string) int {
|
||||
for idx, r := range s {
|
||||
if !isNum(r) {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
firstRuneIsNum := func(s string) bool {
|
||||
return isNum(rune(s[0]))
|
||||
}
|
||||
|
||||
firstNonNumIdx := getFirstNonNumIdx(s)
|
||||
|
||||
// The string contains only a number
|
||||
numOnly := firstNonNumIdx == 0 && firstRuneIsNum(s)
|
||||
|
||||
// The string contains only a unit
|
||||
unitOnly := firstNonNumIdx == 0 && !firstRuneIsNum(s)
|
||||
|
||||
if numOnly {
|
||||
return s, ""
|
||||
} else if unitOnly {
|
||||
return "", s
|
||||
} else {
|
||||
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.UnitsParseBytes.Name, builtinNumBytes)
|
||||
}
|
||||
317
vendor/github.com/open-policy-agent/opa/topdown/query.go
generated
vendored
Normal file
317
vendor/github.com/open-policy-agent/opa/topdown/query.go
generated
vendored
Normal file
@@ -0,0 +1,317 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/metrics"
|
||||
"github.com/open-policy-agent/opa/storage"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/topdown/copypropagation"
|
||||
)
|
||||
|
||||
// QueryResultSet represents a collection of results returned by a query.
|
||||
type QueryResultSet []QueryResult
|
||||
|
||||
// QueryResult represents a single result returned by a query. The result
|
||||
// contains bindings for all variables that appear in the query.
|
||||
type QueryResult map[ast.Var]*ast.Term
|
||||
|
||||
// Query provides a configurable interface for performing query evaluation.
|
||||
type Query struct {
|
||||
cancel Cancel
|
||||
query ast.Body
|
||||
queryCompiler ast.QueryCompiler
|
||||
compiler *ast.Compiler
|
||||
store storage.Store
|
||||
txn storage.Transaction
|
||||
input *ast.Term
|
||||
tracers []Tracer
|
||||
unknowns []*ast.Term
|
||||
partialNamespace string
|
||||
metrics metrics.Metrics
|
||||
instr *Instrumentation
|
||||
disableInlining []ast.Ref
|
||||
genvarprefix string
|
||||
runtime *ast.Term
|
||||
builtins map[string]*Builtin
|
||||
indexing bool
|
||||
}
|
||||
|
||||
// Builtin represents a built-in function that queries can call.
|
||||
type Builtin struct {
|
||||
Decl *ast.Builtin
|
||||
Func BuiltinFunc
|
||||
}
|
||||
|
||||
// NewQuery returns a new Query object that can be run.
|
||||
func NewQuery(query ast.Body) *Query {
|
||||
return &Query{
|
||||
query: query,
|
||||
genvarprefix: ast.WildcardPrefix,
|
||||
indexing: true,
|
||||
}
|
||||
}
|
||||
|
||||
// WithQueryCompiler sets the queryCompiler used for the query.
|
||||
func (q *Query) WithQueryCompiler(queryCompiler ast.QueryCompiler) *Query {
|
||||
q.queryCompiler = queryCompiler
|
||||
return q
|
||||
}
|
||||
|
||||
// WithCompiler sets the compiler to use for the query.
|
||||
func (q *Query) WithCompiler(compiler *ast.Compiler) *Query {
|
||||
q.compiler = compiler
|
||||
return q
|
||||
}
|
||||
|
||||
// WithStore sets the store to use for the query.
|
||||
func (q *Query) WithStore(store storage.Store) *Query {
|
||||
q.store = store
|
||||
return q
|
||||
}
|
||||
|
||||
// WithTransaction sets the transaction to use for the query. All queries
|
||||
// should be performed over a consistent snapshot of the storage layer.
|
||||
func (q *Query) WithTransaction(txn storage.Transaction) *Query {
|
||||
q.txn = txn
|
||||
return q
|
||||
}
|
||||
|
||||
// WithCancel sets the cancellation object to use for the query. Set this if
|
||||
// you need to abort queries based on a deadline. This is optional.
|
||||
func (q *Query) WithCancel(cancel Cancel) *Query {
|
||||
q.cancel = cancel
|
||||
return q
|
||||
}
|
||||
|
||||
// WithInput sets the input object to use for the query. References rooted at
|
||||
// input will be evaluated against this value. This is optional.
|
||||
func (q *Query) WithInput(input *ast.Term) *Query {
|
||||
q.input = input
|
||||
return q
|
||||
}
|
||||
|
||||
// WithTracer adds a query tracer to use during evaluation. This is optional.
|
||||
func (q *Query) WithTracer(tracer Tracer) *Query {
|
||||
q.tracers = append(q.tracers, tracer)
|
||||
return q
|
||||
}
|
||||
|
||||
// WithMetrics sets the metrics collection to add evaluation metrics to. This
|
||||
// is optional.
|
||||
func (q *Query) WithMetrics(m metrics.Metrics) *Query {
|
||||
q.metrics = m
|
||||
return q
|
||||
}
|
||||
|
||||
// WithInstrumentation sets the instrumentation configuration to enable on the
|
||||
// evaluation process. By default, instrumentation is turned off.
|
||||
func (q *Query) WithInstrumentation(instr *Instrumentation) *Query {
|
||||
q.instr = instr
|
||||
return q
|
||||
}
|
||||
|
||||
// WithUnknowns sets the initial set of variables or references to treat as
|
||||
// unknown during query evaluation. This is required for partial evaluation.
|
||||
func (q *Query) WithUnknowns(terms []*ast.Term) *Query {
|
||||
q.unknowns = terms
|
||||
return q
|
||||
}
|
||||
|
||||
// WithPartialNamespace sets the namespace to use for supporting rules
|
||||
// generated as part of the partial evaluation process. The ns value must be a
|
||||
// valid package path component.
|
||||
func (q *Query) WithPartialNamespace(ns string) *Query {
|
||||
q.partialNamespace = ns
|
||||
return q
|
||||
}
|
||||
|
||||
// WithDisableInlining adds a set of paths to the query that should be excluded from
|
||||
// inlining. Inlining during partial evaluation can be expensive in some cases
|
||||
// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive
|
||||
// computation at the cost of generating support rules.
|
||||
func (q *Query) WithDisableInlining(paths []ast.Ref) *Query {
|
||||
q.disableInlining = paths
|
||||
return q
|
||||
}
|
||||
|
||||
// WithRuntime sets the runtime data to execute the query with. The runtime data
|
||||
// can be returned by the `opa.runtime` built-in function.
|
||||
func (q *Query) WithRuntime(runtime *ast.Term) *Query {
|
||||
q.runtime = runtime
|
||||
return q
|
||||
}
|
||||
|
||||
// WithBuiltins adds a set of built-in functions that can be called by the
|
||||
// query.
|
||||
func (q *Query) WithBuiltins(builtins map[string]*Builtin) *Query {
|
||||
q.builtins = builtins
|
||||
return q
|
||||
}
|
||||
|
||||
// WithIndexing will enable or disable using rule indexing for the evaluation
|
||||
// of the query. The default is enabled.
|
||||
func (q *Query) WithIndexing(enabled bool) *Query {
|
||||
q.indexing = enabled
|
||||
return q
|
||||
}
|
||||
|
||||
// PartialRun executes partial evaluation on the query with respect to unknown
|
||||
// values. Partial evaluation attempts to evaluate as much of the query as
|
||||
// possible without requiring values for the unknowns set on the query. The
|
||||
// result of partial evaluation is a new set of queries that can be evaluated
|
||||
// once the unknown value is known. In addition to new queries, partial
|
||||
// evaluation may produce additional support modules that should be used in
|
||||
// conjunction with the partially evaluated queries.
|
||||
func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []*ast.Module, err error) {
|
||||
if q.partialNamespace == "" {
|
||||
q.partialNamespace = "partial" // lazily initialize partial namespace
|
||||
}
|
||||
f := &queryIDFactory{}
|
||||
b := newBindings(0, q.instr)
|
||||
e := &eval{
|
||||
ctx: ctx,
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: b,
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
tracers: q.tracers,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
virtualCache: newVirtualCache(),
|
||||
saveSet: newSaveSet(q.unknowns, b, q.instr),
|
||||
saveStack: newSaveStack(),
|
||||
saveSupport: newSaveSupport(),
|
||||
saveNamespace: ast.StringTerm(q.partialNamespace),
|
||||
genvarprefix: q.genvarprefix,
|
||||
runtime: q.runtime,
|
||||
indexing: q.indexing,
|
||||
}
|
||||
|
||||
if len(q.disableInlining) > 0 {
|
||||
e.disableInlining = [][]ast.Ref{q.disableInlining}
|
||||
}
|
||||
|
||||
e.caller = e
|
||||
q.startTimer(metrics.RegoPartialEval)
|
||||
defer q.stopTimer(metrics.RegoPartialEval)
|
||||
|
||||
livevars := ast.NewVarSet()
|
||||
|
||||
ast.WalkVars(q.query, func(x ast.Var) bool {
|
||||
if !x.IsGenerated() {
|
||||
livevars.Add(x)
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
p := copypropagation.New(livevars)
|
||||
|
||||
err = e.Run(func(e *eval) error {
|
||||
|
||||
// Build output from saved expressions.
|
||||
body := ast.NewBody()
|
||||
|
||||
for _, elem := range e.saveStack.Stack[len(e.saveStack.Stack)-1] {
|
||||
body.Append(elem.Plug(e.bindings))
|
||||
}
|
||||
|
||||
// Include bindings as exprs so that when caller evals the result, they
|
||||
// can obtain values for the vars in their query.
|
||||
bindingExprs := []*ast.Expr{}
|
||||
e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
|
||||
bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Sort binding expressions so that results are deterministic.
|
||||
sort.Slice(bindingExprs, func(i, j int) bool {
|
||||
return bindingExprs[i].Compare(bindingExprs[j]) < 0
|
||||
})
|
||||
|
||||
for i := range bindingExprs {
|
||||
body.Append(bindingExprs[i])
|
||||
}
|
||||
|
||||
partials = append(partials, applyCopyPropagation(p, e.instr, body))
|
||||
return nil
|
||||
})
|
||||
|
||||
support = e.saveSupport.List()
|
||||
|
||||
return partials, support, err
|
||||
}
|
||||
|
||||
// Run is a wrapper around Iter that accumulates query results and returns them
|
||||
// in one shot.
|
||||
func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
|
||||
qrs := QueryResultSet{}
|
||||
return qrs, q.Iter(ctx, func(qr QueryResult) error {
|
||||
qrs = append(qrs, qr)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Iter executes the query and invokes the iter function with query results
|
||||
// produced by evaluating the query.
|
||||
func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
|
||||
f := &queryIDFactory{}
|
||||
e := &eval{
|
||||
ctx: ctx,
|
||||
cancel: q.cancel,
|
||||
query: q.query,
|
||||
queryCompiler: q.queryCompiler,
|
||||
queryIDFact: f,
|
||||
queryID: f.Next(),
|
||||
bindings: newBindings(0, q.instr),
|
||||
compiler: q.compiler,
|
||||
store: q.store,
|
||||
baseCache: newBaseCache(),
|
||||
targetStack: newRefStack(),
|
||||
txn: q.txn,
|
||||
input: q.input,
|
||||
tracers: q.tracers,
|
||||
instr: q.instr,
|
||||
builtins: q.builtins,
|
||||
builtinCache: builtins.Cache{},
|
||||
virtualCache: newVirtualCache(),
|
||||
genvarprefix: q.genvarprefix,
|
||||
runtime: q.runtime,
|
||||
indexing: q.indexing,
|
||||
}
|
||||
e.caller = e
|
||||
q.startTimer(metrics.RegoQueryEval)
|
||||
err := e.Run(func(e *eval) error {
|
||||
qr := QueryResult{}
|
||||
e.bindings.Iter(nil, func(k, v *ast.Term) error {
|
||||
qr[k.Value.(ast.Var)] = v
|
||||
return nil
|
||||
})
|
||||
return iter(qr)
|
||||
})
|
||||
q.stopTimer(metrics.RegoQueryEval)
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *Query) startTimer(name string) {
|
||||
if q.metrics != nil {
|
||||
q.metrics.Timer(name).Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) stopTimer(name string) {
|
||||
if q.metrics != nil {
|
||||
q.metrics.Timer(name).Stop()
|
||||
}
|
||||
}
|
||||
201
vendor/github.com/open-policy-agent/opa/topdown/regex.go
generated
vendored
Normal file
201
vendor/github.com/open-policy-agent/opa/topdown/regex.go
generated
vendored
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/yashtewari/glob-intersection"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
var regexpCacheLock = sync.Mutex{}
|
||||
var regexpCache map[string]*regexp.Regexp
|
||||
|
||||
func builtinRegexMatch(a, b ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s2, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re, err := getRegexp(string(s1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.Boolean(re.Match([]byte(s2))), nil
|
||||
}
|
||||
|
||||
func builtinRegexMatchTemplate(a, b, c, d ast.Value) (ast.Value, error) {
|
||||
pattern, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
match, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start, err := builtins.StringOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
end, err := builtins.StringOperand(d, 4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(start) != 1 {
|
||||
return nil, fmt.Errorf("start delimiter has to be exactly one character long but is %d long", len(start))
|
||||
}
|
||||
if len(end) != 1 {
|
||||
return nil, fmt.Errorf("end delimiter has to be exactly one character long but is %d long", len(start))
|
||||
}
|
||||
re, err := getRegexpTemplate(string(pattern), string(start)[0], string(end)[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.Boolean(re.MatchString(string(match))), nil
|
||||
}
|
||||
|
||||
func builtinRegexSplit(a, b ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s2, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re, err := getRegexp(string(s1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
elems := re.Split(string(s2), -1)
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func getRegexp(pat string) (*regexp.Regexp, error) {
|
||||
regexpCacheLock.Lock()
|
||||
defer regexpCacheLock.Unlock()
|
||||
re, ok := regexpCache[pat]
|
||||
if !ok {
|
||||
var err error
|
||||
re, err = regexp.Compile(string(pat))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
regexpCache[pat] = re
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
|
||||
func getRegexpTemplate(pat string, delimStart, delimEnd byte) (*regexp.Regexp, error) {
|
||||
regexpCacheLock.Lock()
|
||||
defer regexpCacheLock.Unlock()
|
||||
re, ok := regexpCache[pat]
|
||||
if !ok {
|
||||
var err error
|
||||
re, err = compileRegexTemplate(string(pat), delimStart, delimEnd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
regexpCache[pat] = re
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
|
||||
func builtinGlobsMatch(a, b ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s2, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ne, err := gintersect.NonEmpty(string(s1), string(s2))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.Boolean(ne), nil
|
||||
}
|
||||
|
||||
func builtinRegexFind(a, b, c ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s2, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n, err := builtins.IntOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re, err := getRegexp(string(s1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
elems := re.FindAllString(string(s2), n)
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
|
||||
s1, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s2, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
n, err := builtins.IntOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
re, err := getRegexp(string(s1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matches := re.FindAllStringSubmatch(string(s2), n)
|
||||
|
||||
outer := make(ast.Array, len(matches))
|
||||
for i := range outer {
|
||||
inner := make(ast.Array, len(matches[i]))
|
||||
for j := range inner {
|
||||
inner[j] = ast.StringTerm(matches[i][j])
|
||||
}
|
||||
outer[i] = ast.ArrayTerm(inner...)
|
||||
}
|
||||
|
||||
return outer, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
regexpCache = map[string]*regexp.Regexp{}
|
||||
RegisterFunctionalBuiltin2(ast.RegexMatch.Name, builtinRegexMatch)
|
||||
RegisterFunctionalBuiltin2(ast.RegexSplit.Name, builtinRegexSplit)
|
||||
RegisterFunctionalBuiltin2(ast.GlobsMatch.Name, builtinGlobsMatch)
|
||||
RegisterFunctionalBuiltin4(ast.RegexTemplateMatch.Name, builtinRegexMatchTemplate)
|
||||
RegisterFunctionalBuiltin3(ast.RegexFind.Name, builtinRegexFind)
|
||||
RegisterFunctionalBuiltin3(ast.RegexFindAllStringSubmatch.Name, builtinRegexFindAllStringSubmatch)
|
||||
}
|
||||
122
vendor/github.com/open-policy-agent/opa/topdown/regex_template.go
generated
vendored
Normal file
122
vendor/github.com/open-policy-agent/opa/topdown/regex_template.go
generated
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
package topdown
|
||||
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license as follows:
|
||||
|
||||
// Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// This file was forked from https://github.com/gorilla/mux/commit/eac83ba2c004bb75
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// delimiterIndices returns the first level delimiter indices from a string.
|
||||
// It returns an error in case of unbalanced delimiters.
|
||||
func delimiterIndices(s string, delimiterStart, delimiterEnd byte) ([]int, error) {
|
||||
var level, idx int
|
||||
idxs := make([]int, 0)
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case delimiterStart:
|
||||
if level++; level == 1 {
|
||||
idx = i
|
||||
}
|
||||
case delimiterEnd:
|
||||
if level--; level == 0 {
|
||||
idxs = append(idxs, idx, i+1)
|
||||
} else if level < 0 {
|
||||
return nil, fmt.Errorf(`unbalanced braces in %q`, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if level != 0 {
|
||||
return nil, fmt.Errorf(`unbalanced braces in %q`, s)
|
||||
}
|
||||
|
||||
return idxs, nil
|
||||
}
|
||||
|
||||
// compileRegexTemplate parses a template and returns a Regexp.
|
||||
//
|
||||
// You can define your own delimiters. It is e.g. common to use curly braces {} but I recommend using characters
|
||||
// which have no special meaning in Regex, e.g.: <, >
|
||||
//
|
||||
// reg, err := compiler.CompileRegex("foo:bar.baz:<[0-9]{2,10}>", '<', '>')
|
||||
// // if err != nil ...
|
||||
// reg.MatchString("foo:bar.baz:123")
|
||||
func compileRegexTemplate(tpl string, delimiterStart, delimiterEnd byte) (*regexp.Regexp, error) {
|
||||
// Check if it is well-formed.
|
||||
idxs, errBraces := delimiterIndices(tpl, delimiterStart, delimiterEnd)
|
||||
if errBraces != nil {
|
||||
return nil, errBraces
|
||||
}
|
||||
varsR := make([]*regexp.Regexp, len(idxs)/2)
|
||||
pattern := bytes.NewBufferString("")
|
||||
|
||||
// WriteByte's error value is always nil for bytes.Buffer, no need to check it.
|
||||
pattern.WriteByte('^')
|
||||
|
||||
var end int
|
||||
var err error
|
||||
for i := 0; i < len(idxs); i += 2 {
|
||||
// Set all values we are interested in.
|
||||
raw := tpl[end:idxs[i]]
|
||||
end = idxs[i+1]
|
||||
patt := tpl[idxs[i]+1 : end-1]
|
||||
// Build the regexp pattern.
|
||||
varIdx := i / 2
|
||||
fmt.Fprintf(pattern, "%s(%s)", regexp.QuoteMeta(raw), patt)
|
||||
varsR[varIdx], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Add the remaining.
|
||||
raw := tpl[end:]
|
||||
|
||||
// WriteString's error value is always nil for bytes.Buffer, no need to check it.
|
||||
pattern.WriteString(regexp.QuoteMeta(raw))
|
||||
|
||||
// WriteByte's error value is always nil for bytes.Buffer, no need to check it.
|
||||
pattern.WriteByte('$')
|
||||
|
||||
// Compile full regexp.
|
||||
reg, errCompile := regexp.Compile(pattern.String())
|
||||
if errCompile != nil {
|
||||
return nil, errCompile
|
||||
}
|
||||
|
||||
return reg, nil
|
||||
}
|
||||
20
vendor/github.com/open-policy-agent/opa/topdown/runtime.go
generated
vendored
Normal file
20
vendor/github.com/open-policy-agent/opa/topdown/runtime.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import "github.com/open-policy-agent/opa/ast"
|
||||
|
||||
func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
if bctx.Runtime == nil {
|
||||
return iter(ast.ObjectTerm())
|
||||
}
|
||||
|
||||
return iter(bctx.Runtime)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.OPARuntime.Name, builtinOPARuntime)
|
||||
}
|
||||
382
vendor/github.com/open-policy-agent/opa/topdown/save.go
generated
vendored
Normal file
382
vendor/github.com/open-policy-agent/opa/topdown/save.go
generated
vendored
Normal file
@@ -0,0 +1,382 @@
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
// saveSet contains a stack of terms that are considered 'unknown' during
|
||||
// partial evaluation. Only var and ref terms (rooted at one of the root
|
||||
// documents) can be added to the save set. Vars added to the save set are
|
||||
// namespaced by the binding list they are added with. This means the save set
|
||||
// can be shared across queries.
|
||||
type saveSet struct {
|
||||
instr *Instrumentation
|
||||
l *list.List
|
||||
}
|
||||
|
||||
func newSaveSet(ts []*ast.Term, b *bindings, instr *Instrumentation) *saveSet {
|
||||
ss := &saveSet{
|
||||
l: list.New(),
|
||||
instr: instr,
|
||||
}
|
||||
ss.Push(ts, b)
|
||||
return ss
|
||||
}
|
||||
|
||||
func (ss *saveSet) Push(ts []*ast.Term, b *bindings) {
|
||||
ss.l.PushBack(newSaveSetElem(ts, b))
|
||||
}
|
||||
|
||||
func (ss *saveSet) Pop() {
|
||||
ss.l.Remove(ss.l.Back())
|
||||
}
|
||||
|
||||
// Contains returns true if the term t is contained in the save set. Non-var and
|
||||
// non-ref terms are never contained. Ref terms are contained if they share a
|
||||
// prefix with a ref that was added (in either direction).
|
||||
func (ss *saveSet) Contains(t *ast.Term, b *bindings) bool {
|
||||
if ss != nil {
|
||||
ss.instr.startTimer(partialOpSaveSetContains)
|
||||
ret := ss.contains(t, b)
|
||||
ss.instr.stopTimer(partialOpSaveSetContains)
|
||||
return ret
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ss *saveSet) contains(t *ast.Term, b *bindings) bool {
|
||||
for el := ss.l.Back(); el != nil; el = el.Prev() {
|
||||
if el.Value.(*saveSetElem).Contains(t, b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ContainsRecursive retruns true if the term t is or contains a term that is
|
||||
// contained in the save set. This function will close over the binding list
|
||||
// when it encounters vars.
|
||||
func (ss *saveSet) ContainsRecursive(t *ast.Term, b *bindings) bool {
|
||||
if ss != nil {
|
||||
ss.instr.startTimer(partialOpSaveSetContainsRec)
|
||||
ret := ss.containsrec(t, b)
|
||||
ss.instr.stopTimer(partialOpSaveSetContainsRec)
|
||||
return ret
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ss *saveSet) containsrec(t *ast.Term, b *bindings) bool {
|
||||
var found bool
|
||||
ast.WalkTerms(t, func(x *ast.Term) bool {
|
||||
if _, ok := x.Value.(ast.Var); ok {
|
||||
x1, b1 := b.apply(x)
|
||||
if x1 != x || b1 != b {
|
||||
if ss.containsrec(x1, b1) {
|
||||
found = true
|
||||
}
|
||||
} else if ss.contains(x1, b1) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return found
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
func (ss *saveSet) Vars(caller *bindings) ast.VarSet {
|
||||
result := ast.NewVarSet()
|
||||
for x := ss.l.Front(); x != nil; x = x.Next() {
|
||||
elem := x.Value.(*saveSetElem)
|
||||
for _, v := range elem.vars {
|
||||
if v, ok := elem.b.PlugNamespaced(v, caller).Value.(ast.Var); ok {
|
||||
result.Add(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (ss *saveSet) String() string {
|
||||
var buf []string
|
||||
|
||||
for x := ss.l.Front(); x != nil; x = x.Next() {
|
||||
buf = append(buf, x.Value.(*saveSetElem).String())
|
||||
}
|
||||
|
||||
return "(" + strings.Join(buf, " ") + ")"
|
||||
}
|
||||
|
||||
type saveSetElem struct {
|
||||
refs []ast.Ref
|
||||
vars []*ast.Term
|
||||
b *bindings
|
||||
}
|
||||
|
||||
func newSaveSetElem(ts []*ast.Term, b *bindings) *saveSetElem {
|
||||
|
||||
var refs []ast.Ref
|
||||
var vars []*ast.Term
|
||||
|
||||
for _, t := range ts {
|
||||
switch v := t.Value.(type) {
|
||||
case ast.Var:
|
||||
vars = append(vars, t)
|
||||
case ast.Ref:
|
||||
refs = append(refs, v)
|
||||
default:
|
||||
panic("illegal value")
|
||||
}
|
||||
}
|
||||
|
||||
return &saveSetElem{
|
||||
b: b,
|
||||
vars: vars,
|
||||
refs: refs,
|
||||
}
|
||||
}
|
||||
|
||||
func (sse *saveSetElem) Contains(t *ast.Term, b *bindings) bool {
|
||||
switch other := t.Value.(type) {
|
||||
case ast.Var:
|
||||
return sse.containsVar(t, b)
|
||||
case ast.Ref:
|
||||
for _, ref := range sse.refs {
|
||||
if ref.HasPrefix(other) || other.HasPrefix(ref) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return sse.containsVar(other[0], b)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (sse *saveSetElem) String() string {
|
||||
return fmt.Sprintf("(refs: %v, vars: %v, b: %v)", sse.refs, sse.vars, sse.b)
|
||||
}
|
||||
|
||||
func (sse *saveSetElem) containsVar(t *ast.Term, b *bindings) bool {
|
||||
if b == sse.b {
|
||||
for _, v := range sse.vars {
|
||||
if v.Equal(t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// saveStack contains a stack of queries that represent the result of partial
|
||||
// evaluation. When partial evaluation completes, the top of the stack
|
||||
// represents a complete, partially evaluated query that can be saved and
|
||||
// evaluated later.
|
||||
//
|
||||
// The result is stored in a stack so that partial evaluation of a query can be
|
||||
// paused and then resumed in cases where different queries make up the result
|
||||
// of partial evaluation, such as when a rule with a default clause is
|
||||
// partially evaluated. In this case, the partially evaluated rule will be
|
||||
// output in the support module.
|
||||
type saveStack struct {
|
||||
Stack []saveStackQuery
|
||||
}
|
||||
|
||||
func newSaveStack() *saveStack {
|
||||
return &saveStack{
|
||||
Stack: []saveStackQuery{
|
||||
{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *saveStack) PushQuery(query saveStackQuery) {
|
||||
s.Stack = append(s.Stack, query)
|
||||
}
|
||||
|
||||
func (s *saveStack) PopQuery() saveStackQuery {
|
||||
last := s.Stack[len(s.Stack)-1]
|
||||
s.Stack = s.Stack[:len(s.Stack)-1]
|
||||
return last
|
||||
}
|
||||
|
||||
func (s *saveStack) Peek() saveStackQuery {
|
||||
return s.Stack[len(s.Stack)-1]
|
||||
}
|
||||
|
||||
func (s *saveStack) Push(expr *ast.Expr, b1 *bindings, b2 *bindings) {
|
||||
idx := len(s.Stack) - 1
|
||||
s.Stack[idx] = append(s.Stack[idx], saveStackElem{expr, b1, b2})
|
||||
}
|
||||
|
||||
func (s *saveStack) Pop() {
|
||||
idx := len(s.Stack) - 1
|
||||
query := s.Stack[idx]
|
||||
s.Stack[idx] = query[:len(query)-1]
|
||||
}
|
||||
|
||||
type saveStackQuery []saveStackElem
|
||||
|
||||
func (s saveStackQuery) Plug(b *bindings) ast.Body {
|
||||
if len(s) == 0 {
|
||||
return ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))
|
||||
}
|
||||
result := make(ast.Body, len(s))
|
||||
for i := range s {
|
||||
expr := s[i].Plug(b)
|
||||
result.Set(expr, i)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type saveStackElem struct {
|
||||
Expr *ast.Expr
|
||||
B1 *bindings
|
||||
B2 *bindings
|
||||
}
|
||||
|
||||
func (e saveStackElem) Plug(caller *bindings) *ast.Expr {
|
||||
if e.B1 == nil && e.B2 == nil {
|
||||
return e.Expr
|
||||
}
|
||||
expr := e.Expr.Copy()
|
||||
switch terms := expr.Terms.(type) {
|
||||
case []*ast.Term:
|
||||
if expr.IsEquality() {
|
||||
terms[1] = e.B1.PlugNamespaced(terms[1], caller)
|
||||
terms[2] = e.B2.PlugNamespaced(terms[2], caller)
|
||||
} else {
|
||||
for i := 1; i < len(terms); i++ {
|
||||
terms[i] = e.B1.PlugNamespaced(terms[i], caller)
|
||||
}
|
||||
}
|
||||
case *ast.Term:
|
||||
expr.Terms = e.B1.PlugNamespaced(terms, caller)
|
||||
}
|
||||
for i := range expr.With {
|
||||
expr.With[i].Value = e.B1.PlugNamespaced(expr.With[i].Value, caller)
|
||||
}
|
||||
return expr
|
||||
}
|
||||
|
||||
// saveSupport contains additional partially evaluated policies that are part
|
||||
// of the output of partial evaluation.
|
||||
//
|
||||
// The support structure is accumulated as partial evaluation runs and then
|
||||
// considered complete once partial evaluation finishes (but not before). This
|
||||
// differs from partially evaluated queries which are considered complete as
|
||||
// soon as each one finishes.
|
||||
type saveSupport struct {
|
||||
modules map[string]*ast.Module
|
||||
}
|
||||
|
||||
func newSaveSupport() *saveSupport {
|
||||
return &saveSupport{
|
||||
modules: map[string]*ast.Module{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *saveSupport) List() []*ast.Module {
|
||||
result := []*ast.Module{}
|
||||
for _, module := range s.modules {
|
||||
result = append(result, module)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *saveSupport) Exists(path ast.Ref) bool {
|
||||
k := path[:len(path)-1].String()
|
||||
module, ok := s.modules[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
name := ast.Var(path[len(path)-1].Value.(ast.String))
|
||||
for _, rule := range module.Rules {
|
||||
if rule.Head.Name.Equal(name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
|
||||
pkg := path[:len(path)-1]
|
||||
k := pkg.String()
|
||||
module, ok := s.modules[k]
|
||||
if !ok {
|
||||
module = &ast.Module{
|
||||
Package: &ast.Package{
|
||||
Path: pkg,
|
||||
},
|
||||
}
|
||||
s.modules[k] = module
|
||||
}
|
||||
rule.Module = module
|
||||
module.Rules = append(module.Rules, rule)
|
||||
}
|
||||
|
||||
// saveRequired returns true if the statement x will result in some expressions
|
||||
// being saved. This check allows the evaluator to evaluate statements
|
||||
// completely during partial evaluation as long as they do not depend on any
|
||||
// kind of unknown value or statements that would generate saves.
|
||||
func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
|
||||
|
||||
var found bool
|
||||
|
||||
vis := ast.NewGenericVisitor(func(node interface{}) bool {
|
||||
if found {
|
||||
return found
|
||||
}
|
||||
switch node := node.(type) {
|
||||
case *ast.Expr:
|
||||
found = len(node.With) > 0 || ignoreExprDuringPartial(node)
|
||||
case *ast.Term:
|
||||
switch v := node.Value.(type) {
|
||||
case ast.Var:
|
||||
// Variables only need to be tested in the node from call site
|
||||
// because once traversal recurses into a rule existing unknown
|
||||
// variables are out-of-scope.
|
||||
if !rec && ss.ContainsRecursive(node, b) {
|
||||
found = true
|
||||
}
|
||||
case ast.Ref:
|
||||
if ss.Contains(node, b) {
|
||||
found = true
|
||||
} else {
|
||||
for _, rule := range c.GetRulesDynamic(v) {
|
||||
if saveRequired(c, ss, b, rule, true) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return found
|
||||
})
|
||||
|
||||
vis.Walk(x)
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func ignoreExprDuringPartial(expr *ast.Expr) bool {
|
||||
if !expr.IsCall() {
|
||||
return false
|
||||
}
|
||||
|
||||
bi, ok := ast.BuiltinMap[expr.Operator().String()]
|
||||
|
||||
return ok && ignoreDuringPartial(bi)
|
||||
}
|
||||
|
||||
func ignoreDuringPartial(bi *ast.Builtin) bool {
|
||||
for _, ignore := range ast.IgnoreDuringPartialEval {
|
||||
if bi == ignore {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
84
vendor/github.com/open-policy-agent/opa/topdown/sets.go
generated
vendored
Normal file
84
vendor/github.com/open-policy-agent/opa/topdown/sets.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
// Deprecated in v0.4.2 in favour of minus/infix "-" operation.
|
||||
func builtinSetDiff(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
s1, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s2, err := builtins.SetOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s1.Diff(s2), nil
|
||||
}
|
||||
|
||||
// builtinSetIntersection returns the intersection of the given input sets
|
||||
func builtinSetIntersection(a ast.Value) (ast.Value, error) {
|
||||
|
||||
inputSet, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// empty input set
|
||||
if inputSet.Len() == 0 {
|
||||
return ast.NewSet(), nil
|
||||
}
|
||||
|
||||
var result ast.Set
|
||||
|
||||
err = inputSet.Iter(func(x *ast.Term) error {
|
||||
n, err := builtins.SetOperand(x.Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
result = n
|
||||
} else {
|
||||
result = result.Intersect(n)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// builtinSetUnion returns the union of the given input sets
|
||||
func builtinSetUnion(a ast.Value) (ast.Value, error) {
|
||||
|
||||
inputSet, err := builtins.SetOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := ast.NewSet()
|
||||
|
||||
err = inputSet.Iter(func(x *ast.Term) error {
|
||||
n, err := builtins.SetOperand(x.Value, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result = result.Union(n)
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.SetDiff.Name, builtinSetDiff)
|
||||
RegisterFunctionalBuiltin1(ast.Intersection.Name, builtinSetIntersection)
|
||||
RegisterFunctionalBuiltin1(ast.Union.Name, builtinSetUnion)
|
||||
}
|
||||
393
vendor/github.com/open-policy-agent/opa/topdown/strings.go
generated
vendored
Normal file
393
vendor/github.com/open-policy-agent/opa/topdown/strings.go
generated
vendored
Normal file
@@ -0,0 +1,393 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
func builtinFormatInt(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
input, err := builtins.NumberOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base, err := builtins.NumberOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var format string
|
||||
switch base {
|
||||
case ast.Number("2"):
|
||||
format = "%b"
|
||||
case ast.Number("8"):
|
||||
format = "%o"
|
||||
case ast.Number("10"):
|
||||
format = "%d"
|
||||
case ast.Number("16"):
|
||||
format = "%x"
|
||||
default:
|
||||
return nil, builtins.NewOperandEnumErr(2, "2", "8", "10", "16")
|
||||
}
|
||||
|
||||
f := builtins.NumberToFloat(input)
|
||||
i, _ := f.Int(nil)
|
||||
|
||||
return ast.String(fmt.Sprintf(format, i)), nil
|
||||
}
|
||||
|
||||
func builtinConcat(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
join, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
strs := []string{}
|
||||
|
||||
switch b := b.(type) {
|
||||
case ast.Array:
|
||||
for i := range b {
|
||||
s, ok := b[i].Value.(ast.String)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandElementErr(2, b, b[i].Value, "string")
|
||||
}
|
||||
strs = append(strs, string(s))
|
||||
}
|
||||
case ast.Set:
|
||||
err := b.Iter(func(x *ast.Term) error {
|
||||
s, ok := x.Value.(ast.String)
|
||||
if !ok {
|
||||
return builtins.NewOperandElementErr(2, b, x.Value, "string")
|
||||
}
|
||||
strs = append(strs, string(s))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "set", "array")
|
||||
}
|
||||
|
||||
return ast.String(strings.Join(strs, string(join))), nil
|
||||
}
|
||||
|
||||
func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
|
||||
base, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
search, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
index := strings.Index(string(base), string(search))
|
||||
return ast.IntNumberTerm(index).Value, nil
|
||||
}
|
||||
|
||||
func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
|
||||
|
||||
base, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startIndex, err := builtins.IntOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if startIndex >= len(base) {
|
||||
return ast.String(""), nil
|
||||
} else if startIndex < 0 {
|
||||
return nil, fmt.Errorf("negative offset")
|
||||
}
|
||||
|
||||
length, err := builtins.IntOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var s ast.String
|
||||
if length < 0 {
|
||||
s = ast.String(base[startIndex:])
|
||||
} else {
|
||||
upto := startIndex + length
|
||||
if len(base) < upto {
|
||||
upto = len(base)
|
||||
}
|
||||
s = ast.String(base[startIndex:upto])
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func builtinContains(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
substr, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Boolean(strings.Contains(string(s), string(substr))), nil
|
||||
}
|
||||
|
||||
func builtinStartsWith(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefix, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Boolean(strings.HasPrefix(string(s), string(prefix))), nil
|
||||
}
|
||||
|
||||
func builtinEndsWith(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
suffix, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Boolean(strings.HasSuffix(string(s), string(suffix))), nil
|
||||
}
|
||||
|
||||
func builtinLower(a ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.ToLower(string(s))), nil
|
||||
}
|
||||
|
||||
func builtinUpper(a ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.ToUpper(string(s))), nil
|
||||
}
|
||||
|
||||
func builtinSplit(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
elems := strings.Split(string(s), string(d))
|
||||
arr := make(ast.Array, len(elems))
|
||||
for i := range arr {
|
||||
arr[i] = ast.StringTerm(elems[i])
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
old, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
new, err := builtins.StringOperand(c, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.Replace(string(s), string(old), string(new), -1)), nil
|
||||
}
|
||||
|
||||
func builtinReplaceN(a, b ast.Value) (ast.Value, error) {
|
||||
asJSON, err := ast.JSON(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
oldnewObj, ok := asJSON.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandTypeErr(1, a, "object")
|
||||
}
|
||||
|
||||
s, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var oldnewArr []string
|
||||
for k, v := range oldnewObj {
|
||||
strVal, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("non-string value found in pattern object")
|
||||
}
|
||||
oldnewArr = append(oldnewArr, k, strVal)
|
||||
}
|
||||
|
||||
r := strings.NewReplacer(oldnewArr...)
|
||||
replaced := r.Replace(string(s))
|
||||
|
||||
return ast.String(replaced), nil
|
||||
}
|
||||
|
||||
func builtinTrim(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.Trim(string(s), string(c))), nil
|
||||
}
|
||||
|
||||
func builtinTrimLeft(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.TrimLeft(string(s), string(c))), nil
|
||||
}
|
||||
|
||||
func builtinTrimPrefix(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pre, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.TrimPrefix(string(s), string(pre))), nil
|
||||
}
|
||||
|
||||
func builtinTrimRight(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.TrimRight(string(s), string(c))), nil
|
||||
}
|
||||
|
||||
func builtinTrimSuffix(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
suf, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.TrimSuffix(string(s), string(suf))), nil
|
||||
}
|
||||
|
||||
func builtinTrimSpace(a ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.String(strings.TrimSpace(string(s))), nil
|
||||
}
|
||||
|
||||
func builtinSprintf(a, b ast.Value) (ast.Value, error) {
|
||||
s, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
astArr, ok := b.(ast.Array)
|
||||
if !ok {
|
||||
return nil, builtins.NewOperandTypeErr(2, b, "array")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(astArr))
|
||||
|
||||
for i := range astArr {
|
||||
switch v := astArr[i].Value.(type) {
|
||||
case ast.Number:
|
||||
if n, ok := v.Int(); ok {
|
||||
args[i] = n
|
||||
} else if f, ok := v.Float64(); ok {
|
||||
args[i] = f
|
||||
} else {
|
||||
args[i] = v.String()
|
||||
}
|
||||
case ast.String:
|
||||
args[i] = string(v)
|
||||
default:
|
||||
args[i] = astArr[i].String()
|
||||
}
|
||||
}
|
||||
|
||||
return ast.String(fmt.Sprintf(string(s), args...)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin2(ast.FormatInt.Name, builtinFormatInt)
|
||||
RegisterFunctionalBuiltin2(ast.Concat.Name, builtinConcat)
|
||||
RegisterFunctionalBuiltin2(ast.IndexOf.Name, builtinIndexOf)
|
||||
RegisterFunctionalBuiltin3(ast.Substring.Name, builtinSubstring)
|
||||
RegisterFunctionalBuiltin2(ast.Contains.Name, builtinContains)
|
||||
RegisterFunctionalBuiltin2(ast.StartsWith.Name, builtinStartsWith)
|
||||
RegisterFunctionalBuiltin2(ast.EndsWith.Name, builtinEndsWith)
|
||||
RegisterFunctionalBuiltin1(ast.Upper.Name, builtinUpper)
|
||||
RegisterFunctionalBuiltin1(ast.Lower.Name, builtinLower)
|
||||
RegisterFunctionalBuiltin2(ast.Split.Name, builtinSplit)
|
||||
RegisterFunctionalBuiltin3(ast.Replace.Name, builtinReplace)
|
||||
RegisterFunctionalBuiltin2(ast.ReplaceN.Name, builtinReplaceN)
|
||||
RegisterFunctionalBuiltin2(ast.Trim.Name, builtinTrim)
|
||||
RegisterFunctionalBuiltin2(ast.TrimLeft.Name, builtinTrimLeft)
|
||||
RegisterFunctionalBuiltin2(ast.TrimPrefix.Name, builtinTrimPrefix)
|
||||
RegisterFunctionalBuiltin2(ast.TrimRight.Name, builtinTrimRight)
|
||||
RegisterFunctionalBuiltin2(ast.TrimSuffix.Name, builtinTrimSuffix)
|
||||
RegisterFunctionalBuiltin1(ast.TrimSpace.Name, builtinTrimSpace)
|
||||
RegisterFunctionalBuiltin2(ast.Sprintf.Name, builtinSprintf)
|
||||
}
|
||||
203
vendor/github.com/open-policy-agent/opa/topdown/time.go
generated
vendored
Normal file
203
vendor/github.com/open-policy-agent/opa/topdown/time.go
generated
vendored
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
type nowKeyID string
|
||||
|
||||
var nowKey = nowKeyID("time.now_ns")
|
||||
var tzCache map[string]*time.Location
|
||||
var tzCacheMutex *sync.Mutex
|
||||
|
||||
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
exist, ok := bctx.Cache.Get(nowKey)
|
||||
var now *ast.Term
|
||||
|
||||
if !ok {
|
||||
curr := time.Now()
|
||||
now = ast.NewTerm(ast.Number(int64ToJSONNumber(curr.UnixNano())))
|
||||
bctx.Cache.Put(nowKey, now)
|
||||
} else {
|
||||
now = exist.(*ast.Term)
|
||||
}
|
||||
|
||||
return iter(now)
|
||||
}
|
||||
|
||||
func builtinTimeParseNanos(a, b ast.Value) (ast.Value, error) {
|
||||
|
||||
format, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := time.Parse(string(format), string(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
|
||||
}
|
||||
|
||||
func builtinTimeParseRFC3339Nanos(a ast.Value) (ast.Value, error) {
|
||||
|
||||
value, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := time.Parse(time.RFC3339, string(value))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
|
||||
}
|
||||
func builtinParseDurationNanos(a ast.Value) (ast.Value, error) {
|
||||
|
||||
duration, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value, err := time.ParseDuration(string(duration))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.Number(int64ToJSONNumber(int64(value))), nil
|
||||
}
|
||||
|
||||
func builtinDate(a ast.Value) (ast.Value, error) {
|
||||
t, err := tzTime(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
year, month, day := t.Date()
|
||||
result := ast.Array{ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day)}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func builtinClock(a ast.Value) (ast.Value, error) {
|
||||
t, err := tzTime(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hour, minute, second := t.Clock()
|
||||
result := ast.Array{ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second)}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func builtinWeekday(a ast.Value) (ast.Value, error) {
|
||||
t, err := tzTime(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
weekday := t.Weekday().String()
|
||||
return ast.String(weekday), nil
|
||||
}
|
||||
|
||||
func tzTime(a ast.Value) (t time.Time, err error) {
|
||||
var nVal ast.Value
|
||||
loc := time.UTC
|
||||
|
||||
switch va := a.(type) {
|
||||
case ast.Array:
|
||||
|
||||
if len(va) == 0 {
|
||||
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
|
||||
}
|
||||
|
||||
nVal, err = builtins.NumberOperand(va[0].Value, 1)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
if len(va) > 1 {
|
||||
tzVal, err := builtins.StringOperand(va[1].Value, 1)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
tzName := string(tzVal)
|
||||
|
||||
switch tzName {
|
||||
case "", "UTC":
|
||||
// loc is already UTC
|
||||
|
||||
case "Local":
|
||||
loc = time.Local
|
||||
|
||||
default:
|
||||
var ok bool
|
||||
|
||||
tzCacheMutex.Lock()
|
||||
loc, ok = tzCache[tzName]
|
||||
|
||||
if !ok {
|
||||
loc, err = time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
tzCacheMutex.Unlock()
|
||||
return time.Time{}, err
|
||||
}
|
||||
tzCache[tzName] = loc
|
||||
}
|
||||
tzCacheMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
case ast.Number:
|
||||
nVal = a
|
||||
|
||||
default:
|
||||
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
|
||||
}
|
||||
|
||||
value, err := builtins.NumberOperand(nVal, 1)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
f := builtins.NumberToFloat(value)
|
||||
i64, acc := f.Int64()
|
||||
if acc != big.Exact {
|
||||
return time.Time{}, fmt.Errorf("timestamp too big")
|
||||
}
|
||||
|
||||
t = time.Unix(0, i64).In(loc)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func int64ToJSONNumber(i int64) json.Number {
|
||||
return json.Number(strconv.FormatInt(i, 10))
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.NowNanos.Name, builtinTimeNowNanos)
|
||||
RegisterFunctionalBuiltin1(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
|
||||
RegisterFunctionalBuiltin2(ast.ParseNanos.Name, builtinTimeParseNanos)
|
||||
RegisterFunctionalBuiltin1(ast.ParseDurationNanos.Name, builtinParseDurationNanos)
|
||||
RegisterFunctionalBuiltin1(ast.Date.Name, builtinDate)
|
||||
RegisterFunctionalBuiltin1(ast.Clock.Name, builtinClock)
|
||||
RegisterFunctionalBuiltin1(ast.Weekday.Name, builtinWeekday)
|
||||
tzCacheMutex = &sync.Mutex{}
|
||||
tzCache = make(map[string]*time.Location)
|
||||
}
|
||||
967
vendor/github.com/open-policy-agent/opa/topdown/tokens.go
generated
vendored
Normal file
967
vendor/github.com/open-policy-agent/opa/topdown/tokens.go
generated
vendored
Normal file
@@ -0,0 +1,967 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
|
||||
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtEncKey = ast.StringTerm("enc")
|
||||
jwtCtyKey = ast.StringTerm("cty")
|
||||
jwtAlgKey = ast.StringTerm("alg")
|
||||
jwtIssKey = ast.StringTerm("iss")
|
||||
jwtExpKey = ast.StringTerm("exp")
|
||||
jwtNbfKey = ast.StringTerm("nbf")
|
||||
jwtAudKey = ast.StringTerm("aud")
|
||||
)
|
||||
|
||||
// JSONWebToken represent the 3 parts (header, payload & signature) of
|
||||
// a JWT in Base64.
|
||||
type JSONWebToken struct {
|
||||
header string
|
||||
payload string
|
||||
signature string
|
||||
decodedHeader ast.Object
|
||||
}
|
||||
|
||||
// decodeHeader populates the decodedHeader field.
|
||||
func (token *JSONWebToken) decodeHeader() (err error) {
|
||||
var h ast.Value
|
||||
if h, err = builtinBase64UrlDecode(ast.String(token.header)); err != nil {
|
||||
return fmt.Errorf("JWT header had invalid encoding: %v", err)
|
||||
}
|
||||
if token.decodedHeader, err = validateJWTHeader(string(h.(ast.String))); err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Implements JWT decoding/validation based on RFC 7519 Section 7.2:
|
||||
// https://tools.ietf.org/html/rfc7519#section-7.2
|
||||
// It does no data validation, it merely checks that the given string
|
||||
// represents a structurally valid JWT. It supports JWTs using JWS compact
|
||||
// serialization.
|
||||
func builtinJWTDecode(a ast.Value) (ast.Value, error) {
|
||||
token, err := decodeJWT(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = token.decodeHeader(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := builtinBase64UrlDecode(ast.String(token.payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JWT payload had invalid encoding: %v", err)
|
||||
}
|
||||
|
||||
if cty := token.decodedHeader.Get(jwtCtyKey); cty != nil {
|
||||
ctyVal := string(cty.Value.(ast.String))
|
||||
// It is possible for the contents of a token to be another
|
||||
// token as a result of nested signing or encryption. To handle
|
||||
// the case where we are given a token such as this, we check
|
||||
// the content type and recurse on the payload if the content
|
||||
// is "JWT".
|
||||
// When the payload is itself another encoded JWT, then its
|
||||
// contents are quoted (behavior of https://jwt.io/). To fix
|
||||
// this, remove leading and trailing quotes.
|
||||
if ctyVal == "JWT" {
|
||||
p, err = builtinTrim(p, ast.String(`"'`))
|
||||
if err != nil {
|
||||
panic("not reached")
|
||||
}
|
||||
return builtinJWTDecode(p)
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := extractJSONObject(string(p.(ast.String)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := builtinBase64UrlDecode(ast.String(token.signature))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("JWT signature had invalid encoding: %v", err)
|
||||
}
|
||||
sign := hex.EncodeToString([]byte(s.(ast.String)))
|
||||
|
||||
arr := make(ast.Array, 3)
|
||||
arr[0] = ast.NewTerm(token.decodedHeader)
|
||||
arr[1] = ast.NewTerm(payload)
|
||||
arr[2] = ast.StringTerm(sign)
|
||||
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// Implements RS256 JWT signature verification
|
||||
func builtinJWTVerifyRS256(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
return builtinJWTVerifyRSA(a, b, func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error {
|
||||
return rsa.VerifyPKCS1v15(
|
||||
publicKey,
|
||||
crypto.SHA256,
|
||||
digest,
|
||||
signature)
|
||||
})
|
||||
}
|
||||
|
||||
// Implements PS256 JWT signature verification
|
||||
func builtinJWTVerifyPS256(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
return builtinJWTVerifyRSA(a, b, func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error {
|
||||
return rsa.VerifyPSS(
|
||||
publicKey,
|
||||
crypto.SHA256,
|
||||
digest,
|
||||
signature,
|
||||
nil)
|
||||
})
|
||||
}
|
||||
|
||||
// Implements RSA JWT signature verification.
|
||||
func builtinJWTVerifyRSA(a ast.Value, b ast.Value, verify func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error) (ast.Value, error) {
|
||||
return builtinJWTVerify(a, b, func(publicKey interface{}, digest []byte, signature []byte) error {
|
||||
publicKeyRsa, ok := publicKey.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect public key type")
|
||||
}
|
||||
return verify(publicKeyRsa, digest, signature)
|
||||
})
|
||||
}
|
||||
|
||||
// Implements ES256 JWT signature verification.
|
||||
func builtinJWTVerifyES256(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
return builtinJWTVerify(a, b, func(publicKey interface{}, digest []byte, signature []byte) error {
|
||||
publicKeyEcdsa, ok := publicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect public key type")
|
||||
}
|
||||
r, s := &big.Int{}, &big.Int{}
|
||||
n := len(signature) / 2
|
||||
r.SetBytes(signature[:n])
|
||||
s.SetBytes(signature[n:])
|
||||
if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("ECDSA signature verification error")
|
||||
})
|
||||
}
|
||||
|
||||
// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
|
||||
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
|
||||
func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
|
||||
if block, rest := pem.Decode([]byte(certificate)); block != nil {
|
||||
if len(rest) > 0 {
|
||||
return nil, fmt.Errorf("extra data after a PEM certificate block")
|
||||
}
|
||||
|
||||
if block.Type == "CERTIFICATE" {
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse a PEM certificate")
|
||||
}
|
||||
|
||||
return []interface{}{cert.PublicKey}, nil
|
||||
}
|
||||
|
||||
if block.Type == "PUBLIC KEY" {
|
||||
key, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse a PEM public key")
|
||||
}
|
||||
|
||||
return []interface{}{key}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
|
||||
}
|
||||
|
||||
jwks, err := jwk.ParseString(certificate)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse a JWK key (set)")
|
||||
}
|
||||
|
||||
var keys []interface{}
|
||||
for _, k := range jwks.Keys {
|
||||
key, err := k.Materialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Implements JWT signature verification.
|
||||
func builtinJWTVerify(a ast.Value, b ast.Value, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
|
||||
token, err := decodeJWT(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys, err := getKeyFromCertOrJWK(string(s))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signature, err := token.decodeSignature()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate the JWT signature
|
||||
for _, key := range keys {
|
||||
err = verify(key,
|
||||
getInputSHA([]byte(token.header+"."+token.payload)),
|
||||
[]byte(signature))
|
||||
|
||||
if err == nil {
|
||||
return ast.Boolean(true), nil
|
||||
}
|
||||
}
|
||||
|
||||
// None of the keys worked, return false
|
||||
return ast.Boolean(false), nil
|
||||
}
|
||||
|
||||
// Implements HS256 (secret) JWT signature verification
|
||||
func builtinJWTVerifyHS256(a ast.Value, b ast.Value) (ast.Value, error) {
|
||||
// Decode the JSON Web Token
|
||||
token, err := decodeJWT(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process Secret input
|
||||
astSecret, err := builtins.StringOperand(b, 2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secret := string(astSecret)
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
_, err = mac.Write([]byte(token.header + "." + token.payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signature, err := token.decodeSignature()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ast.Boolean(hmac.Equal([]byte(signature), mac.Sum(nil))), nil
|
||||
}
|
||||
|
||||
// -- Full JWT verification and decoding --
|
||||
|
||||
// Verification constraints. See tokens_test.go for unit tests.
|
||||
|
||||
// tokenConstraints holds decoded JWT verification constraints.
|
||||
type tokenConstraints struct {
|
||||
// The set of asymmetric keys we can verify with.
|
||||
keys []interface{}
|
||||
|
||||
// The single symmetric key we will verify with.
|
||||
secret string
|
||||
|
||||
// The algorithm that must be used to verify.
|
||||
// If "", any algorithm is acceptable.
|
||||
alg string
|
||||
|
||||
// The required issuer.
|
||||
// If "", any issuer is acceptable.
|
||||
iss string
|
||||
|
||||
// The required audience.
|
||||
// If "", no audience is acceptable.
|
||||
aud string
|
||||
|
||||
// The time to validate against, or -1 if no constraint set.
|
||||
// (If unset, the current time will be used.)
|
||||
time int64
|
||||
}
|
||||
|
||||
// tokenConstraintHandler is the handler type for JWT verification constraints.
|
||||
type tokenConstraintHandler func(value ast.Value, parameters *tokenConstraints) (err error)
|
||||
|
||||
// tokenConstraintTypes maps known JWT verification constraints to handlers.
|
||||
var tokenConstraintTypes = map[string]tokenConstraintHandler{
|
||||
"cert": tokenConstraintCert,
|
||||
"secret": func(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
return tokenConstraintString("secret", value, &constraints.secret)
|
||||
},
|
||||
"alg": func(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
return tokenConstraintString("alg", value, &constraints.alg)
|
||||
},
|
||||
"iss": func(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
return tokenConstraintString("iss", value, &constraints.iss)
|
||||
},
|
||||
"aud": func(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
return tokenConstraintString("aud", value, &constraints.aud)
|
||||
},
|
||||
"time": tokenConstraintTime,
|
||||
}
|
||||
|
||||
// tokenConstraintCert handles the `cert` constraint.
|
||||
func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
var s ast.String
|
||||
var ok bool
|
||||
if s, ok = value.(ast.String); !ok {
|
||||
return fmt.Errorf("cert constraint: must be a string")
|
||||
}
|
||||
|
||||
constraints.keys, err = getKeyFromCertOrJWK(string(s))
|
||||
return
|
||||
}
|
||||
|
||||
// tokenConstraintTime handles the `time` constraint.
|
||||
func tokenConstraintTime(value ast.Value, constraints *tokenConstraints) (err error) {
|
||||
var time ast.Number
|
||||
var ok bool
|
||||
if time, ok = value.(ast.Number); !ok {
|
||||
err = fmt.Errorf("token time constraint: must be a number")
|
||||
return
|
||||
}
|
||||
var timeFloat float64
|
||||
if timeFloat, err = strconv.ParseFloat(string(time), 64); err != nil {
|
||||
err = fmt.Errorf("token time constraint: %v", err)
|
||||
return
|
||||
}
|
||||
if timeFloat < 0 {
|
||||
err = fmt.Errorf("token time constraint: must not be negative")
|
||||
return
|
||||
}
|
||||
constraints.time = int64(timeFloat)
|
||||
return
|
||||
}
|
||||
|
||||
// tokenConstraintString handles string constraints.
|
||||
func tokenConstraintString(name string, value ast.Value, where *string) (err error) {
|
||||
var av ast.String
|
||||
var ok bool
|
||||
if av, ok = value.(ast.String); !ok {
|
||||
err = fmt.Errorf("%s constraint: must be a string", name)
|
||||
return
|
||||
}
|
||||
*where = string(av)
|
||||
return
|
||||
}
|
||||
|
||||
// parseTokenConstraints parses the constraints argument.
|
||||
func parseTokenConstraints(a ast.Value) (constraints tokenConstraints, err error) {
|
||||
constraints.time = -1
|
||||
var o ast.Object
|
||||
var ok bool
|
||||
if o, ok = a.(ast.Object); !ok {
|
||||
err = fmt.Errorf("token constraints must be object")
|
||||
return
|
||||
}
|
||||
if err = o.Iter(func(k *ast.Term, v *ast.Term) (err error) {
|
||||
var handler tokenConstraintHandler
|
||||
var ok bool
|
||||
name := string(k.Value.(ast.String))
|
||||
if handler, ok = tokenConstraintTypes[name]; ok {
|
||||
if err = handler(v.Value, &constraints); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Anything unknown is rejected.
|
||||
err = fmt.Errorf("unknown token validation constraint: %s", name)
|
||||
return
|
||||
}
|
||||
return
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// validate validates the constraints argument.
|
||||
func (constraints *tokenConstraints) validate() (err error) {
|
||||
keys := 0
|
||||
if constraints.keys != nil {
|
||||
keys++
|
||||
}
|
||||
if constraints.secret != "" {
|
||||
keys++
|
||||
}
|
||||
if keys > 1 {
|
||||
err = fmt.Errorf("duplicate key constraints")
|
||||
return
|
||||
}
|
||||
if keys < 1 {
|
||||
err = fmt.Errorf("no key constraint")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// verify verifies a JWT using the constraints and the algorithm from the header
|
||||
func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature string) error {
|
||||
// Construct the payload
|
||||
plaintext := []byte(header)
|
||||
plaintext = append(plaintext, []byte(".")...)
|
||||
plaintext = append(plaintext, payload...)
|
||||
// Look up the algorithm
|
||||
var ok bool
|
||||
var a tokenAlgorithm
|
||||
a, ok = tokenAlgorithms[alg]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown JWS algorithm: %s", alg)
|
||||
}
|
||||
// If we're configured with asymmetric key(s) then only trust that
|
||||
if constraints.keys != nil {
|
||||
verified := false
|
||||
for _, key := range constraints.keys {
|
||||
err := a.verify(key, a.hash, plaintext, []byte(signature))
|
||||
if err == nil {
|
||||
verified = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !verified {
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if constraints.secret != "" {
|
||||
return a.verify([]byte(constraints.secret), a.hash, plaintext, []byte(signature))
|
||||
}
|
||||
// (*tokenConstraints)validate() should prevent this happening
|
||||
return errors.New("unexpectedly found no keys to trust")
|
||||
}
|
||||
|
||||
// validAudience checks the audience of the JWT.
|
||||
// It returns true if it meets the constraints and false otherwise.
|
||||
func (constraints *tokenConstraints) validAudience(aud ast.Value) (valid bool) {
|
||||
var ok bool
|
||||
var s ast.String
|
||||
if s, ok = aud.(ast.String); ok {
|
||||
return string(s) == constraints.aud
|
||||
}
|
||||
var a ast.Array
|
||||
if a, ok = aud.(ast.Array); ok {
|
||||
for _, t := range a {
|
||||
if s, ok = t.Value.(ast.String); ok {
|
||||
if string(s) == constraints.aud {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// Ill-formed aud claim
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// JWT algorithms
|
||||
|
||||
type tokenVerifyFunction func(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error)
|
||||
type tokenVerifyAsymmetricFunction func(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error)
|
||||
|
||||
// jwtAlgorithm describes a JWS 'alg' value
|
||||
type tokenAlgorithm struct {
|
||||
hash crypto.Hash
|
||||
verify tokenVerifyFunction
|
||||
}
|
||||
|
||||
// tokenAlgorithms is the known JWT algorithms
|
||||
var tokenAlgorithms = map[string]tokenAlgorithm{
|
||||
"RS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"RS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"RS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPKCS)},
|
||||
"PS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPSS)},
|
||||
"PS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPSS)},
|
||||
"PS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPSS)},
|
||||
"ES256": {crypto.SHA256, verifyAsymmetric(verifyECDSA)},
|
||||
"ES384": {crypto.SHA384, verifyAsymmetric(verifyECDSA)},
|
||||
"ES512": {crypto.SHA512, verifyAsymmetric(verifyECDSA)},
|
||||
"HS256": {crypto.SHA256, verifyHMAC},
|
||||
"HS384": {crypto.SHA384, verifyHMAC},
|
||||
"HS512": {crypto.SHA512, verifyHMAC},
|
||||
}
|
||||
|
||||
// errSignatureNotVerified is returned when a signature cannot be verified.
|
||||
var errSignatureNotVerified = errors.New("signature not verified")
|
||||
|
||||
func verifyHMAC(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error) {
|
||||
macKey, ok := key.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect symmetric key type")
|
||||
}
|
||||
mac := hmac.New(hash.New, macKey)
|
||||
if _, err = mac.Write([]byte(payload)); err != nil {
|
||||
return
|
||||
}
|
||||
if !hmac.Equal(signature, mac.Sum([]byte{})) {
|
||||
err = errSignatureNotVerified
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func verifyAsymmetric(verify tokenVerifyAsymmetricFunction) tokenVerifyFunction {
|
||||
return func(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error) {
|
||||
h := hash.New()
|
||||
h.Write(payload)
|
||||
return verify(key, hash, h.Sum([]byte{}), signature)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRSAPKCS(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
|
||||
publicKeyRsa, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect public key type")
|
||||
}
|
||||
if err = rsa.VerifyPKCS1v15(publicKeyRsa, hash, digest, signature); err != nil {
|
||||
err = errSignatureNotVerified
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func verifyRSAPSS(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
|
||||
publicKeyRsa, ok := key.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect public key type")
|
||||
}
|
||||
if err = rsa.VerifyPSS(publicKeyRsa, hash, digest, signature, nil); err != nil {
|
||||
err = errSignatureNotVerified
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func verifyECDSA(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
|
||||
publicKeyEcdsa, ok := key.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("incorrect public key type")
|
||||
}
|
||||
r, s := &big.Int{}, &big.Int{}
|
||||
n := len(signature) / 2
|
||||
r.SetBytes(signature[:n])
|
||||
s.SetBytes(signature[n:])
|
||||
if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
|
||||
return nil
|
||||
}
|
||||
return errSignatureNotVerified
|
||||
}
|
||||
|
||||
// JWT header parsing and parameters. See tokens_test.go for unit tests.
|
||||
|
||||
// tokenHeaderType represents a recognized JWT header field
|
||||
// tokenHeader is a parsed JWT header
|
||||
type tokenHeader struct {
|
||||
alg string
|
||||
kid string
|
||||
typ string
|
||||
cty string
|
||||
crit map[string]bool
|
||||
unknown []string
|
||||
}
|
||||
|
||||
// tokenHeaderHandler handles a JWT header parameters
|
||||
type tokenHeaderHandler func(header *tokenHeader, value ast.Value) (err error)
|
||||
|
||||
// tokenHeaderTypes maps known JWT header parameters to handlers
|
||||
var tokenHeaderTypes = map[string]tokenHeaderHandler{
|
||||
"alg": func(header *tokenHeader, value ast.Value) (err error) {
|
||||
return tokenHeaderString("alg", &header.alg, value)
|
||||
},
|
||||
"kid": func(header *tokenHeader, value ast.Value) (err error) {
|
||||
return tokenHeaderString("kid", &header.kid, value)
|
||||
},
|
||||
"typ": func(header *tokenHeader, value ast.Value) (err error) {
|
||||
return tokenHeaderString("typ", &header.typ, value)
|
||||
},
|
||||
"cty": func(header *tokenHeader, value ast.Value) (err error) {
|
||||
return tokenHeaderString("cty", &header.cty, value)
|
||||
},
|
||||
"crit": tokenHeaderCrit,
|
||||
}
|
||||
|
||||
// tokenHeaderCrit handles the 'crit' header parameter
|
||||
func tokenHeaderCrit(header *tokenHeader, value ast.Value) (err error) {
|
||||
var ok bool
|
||||
var v ast.Array
|
||||
if v, ok = value.(ast.Array); !ok {
|
||||
err = fmt.Errorf("crit: must be a list")
|
||||
return
|
||||
}
|
||||
header.crit = map[string]bool{}
|
||||
for _, t := range v {
|
||||
var tv ast.String
|
||||
if tv, ok = t.Value.(ast.String); !ok {
|
||||
err = fmt.Errorf("crit: must be a list of strings")
|
||||
return
|
||||
}
|
||||
header.crit[string(tv)] = true
|
||||
}
|
||||
if len(header.crit) == 0 {
|
||||
err = fmt.Errorf("crit: must be a nonempty list") // 'MUST NOT' use the empty list
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// tokenHeaderString handles string-format JWT header parameters
|
||||
func tokenHeaderString(name string, where *string, value ast.Value) (err error) {
|
||||
var ok bool
|
||||
var v ast.String
|
||||
if v, ok = value.(ast.String); !ok {
|
||||
err = fmt.Errorf("%s: must be a string", name)
|
||||
return
|
||||
}
|
||||
*where = string(v)
|
||||
return
|
||||
}
|
||||
|
||||
// parseTokenHeader parses the JWT header.
|
||||
func parseTokenHeader(token *JSONWebToken) (header tokenHeader, err error) {
|
||||
header.unknown = []string{}
|
||||
if err = token.decodedHeader.Iter(func(k *ast.Term, v *ast.Term) (err error) {
|
||||
ks := string(k.Value.(ast.String))
|
||||
var ok bool
|
||||
var handler tokenHeaderHandler
|
||||
if handler, ok = tokenHeaderTypes[ks]; ok {
|
||||
if err = handler(&header, v.Value); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
header.unknown = append(header.unknown, ks)
|
||||
}
|
||||
return
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// validTokenHeader returns true if the JOSE header is valid, otherwise false.
|
||||
func (header *tokenHeader) valid() bool {
|
||||
// RFC7515 s4.1.1 alg MUST be present
|
||||
if header.alg == "" {
|
||||
return false
|
||||
}
|
||||
// RFC7515 4.1.11 JWS is invalid if there is a critical parameter that we did not recognize
|
||||
for _, u := range header.unknown {
|
||||
if header.crit[u] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc string) (v ast.Value, err error) {
|
||||
|
||||
keys, err := jwk.ParseString(jwkSrc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := keys.Keys[0].Materialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
|
||||
return nil, fmt.Errorf("JWK derived key type and keyType parameter do not match")
|
||||
}
|
||||
|
||||
standardHeaders := &jws.StandardHeaders{}
|
||||
jwsHeaders := []byte(inputHeaders)
|
||||
err = json.Unmarshal(jwsHeaders, standardHeaders)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
alg := standardHeaders.GetAlgorithm()
|
||||
|
||||
if (standardHeaders.Type == "" || standardHeaders.Type == "JWT") && !json.Valid([]byte(jwsPayload)) {
|
||||
return nil, fmt.Errorf("type is JWT but payload is not JSON")
|
||||
}
|
||||
|
||||
// process payload and sign
|
||||
var jwsCompact []byte
|
||||
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ast.String(jwsCompact[:]), nil
|
||||
|
||||
}
|
||||
|
||||
func builtinJWTEncodeSign(a ast.Value, b ast.Value, c ast.Value) (v ast.Value, err error) {
|
||||
|
||||
jwkSrc := c.String()
|
||||
|
||||
inputHeaders := a.String()
|
||||
|
||||
jwsPayload := b.String()
|
||||
|
||||
return commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc)
|
||||
|
||||
}
|
||||
|
||||
func builtinJWTEncodeSignRaw(a ast.Value, b ast.Value, c ast.Value) (v ast.Value, err error) {
|
||||
|
||||
jwkSrc, err := builtins.StringOperand(c, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inputHeaders, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwsPayload, err := builtins.StringOperand(b, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return commonBuiltinJWTEncodeSign(string(inputHeaders), string(jwsPayload), string(jwkSrc))
|
||||
}
|
||||
|
||||
// Implements full JWT decoding, validation and verification.
|
||||
func builtinJWTDecodeVerify(a ast.Value, b ast.Value) (v ast.Value, err error) {
|
||||
// io.jwt.decode_verify(string, constraints, [valid, header, payload])
|
||||
//
|
||||
// If valid is true then the signature verifies and all constraints are met.
|
||||
// If valid is false then either the signature did not verify or some constrain
|
||||
// was not met.
|
||||
//
|
||||
// Decoding errors etc are returned as errors.
|
||||
arr := make(ast.Array, 3)
|
||||
arr[0] = ast.BooleanTerm(false) // by default, not verified
|
||||
arr[1] = ast.NewTerm(ast.NewObject())
|
||||
arr[2] = ast.NewTerm(ast.NewObject())
|
||||
var constraints tokenConstraints
|
||||
if constraints, err = parseTokenConstraints(b); err != nil {
|
||||
return
|
||||
}
|
||||
if err = constraints.validate(); err != nil {
|
||||
return
|
||||
}
|
||||
var token *JSONWebToken
|
||||
var p ast.Value
|
||||
for {
|
||||
// RFC7519 7.2 #1-2 split into parts
|
||||
if token, err = decodeJWT(a); err != nil {
|
||||
return
|
||||
}
|
||||
// RFC7519 7.2 #3, #4, #6
|
||||
if err = token.decodeHeader(); err != nil {
|
||||
return
|
||||
}
|
||||
// RFC7159 7.2 #5 (and RFC7159 5.2 #5) validate header fields
|
||||
var header tokenHeader
|
||||
if header, err = parseTokenHeader(token); err != nil {
|
||||
return
|
||||
}
|
||||
if !header.valid() {
|
||||
return arr, nil
|
||||
}
|
||||
// Check constraints that impact signature verification.
|
||||
if constraints.alg != "" && constraints.alg != header.alg {
|
||||
return arr, nil
|
||||
}
|
||||
// RFC7159 7.2 #7 verify the signature
|
||||
var signature string
|
||||
if signature, err = token.decodeSignature(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = constraints.verify(header.kid, header.alg, token.header, token.payload, signature); err != nil {
|
||||
if err == errSignatureNotVerified {
|
||||
return arr, nil
|
||||
}
|
||||
return
|
||||
}
|
||||
// RFC7159 7.2 #9-10 decode the payload
|
||||
if p, err = builtinBase64UrlDecode(ast.String(token.payload)); err != nil {
|
||||
return nil, fmt.Errorf("JWT payload had invalid encoding: %v", err)
|
||||
}
|
||||
// RFC7159 7.2 #8 and 5.2 cty
|
||||
if strings.ToUpper(header.cty) == "JWT" {
|
||||
// Nested JWT, go round again
|
||||
a = p
|
||||
continue
|
||||
} else {
|
||||
// Non-nested JWT (or we've reached the bottom of the nesting).
|
||||
break
|
||||
}
|
||||
}
|
||||
var payload ast.Object
|
||||
if payload, err = extractJSONObject(string(p.(ast.String))); err != nil {
|
||||
return
|
||||
}
|
||||
// Check registered claim names against constraints or environment
|
||||
// RFC7159 4.1.1 iss
|
||||
if constraints.iss != "" {
|
||||
if iss := payload.Get(jwtIssKey); iss != nil {
|
||||
issVal := string(iss.Value.(ast.String))
|
||||
if constraints.iss != issVal {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC7159 4.1.3 aud
|
||||
if aud := payload.Get(jwtAudKey); aud != nil {
|
||||
if !constraints.validAudience(aud.Value) {
|
||||
return arr, nil
|
||||
}
|
||||
} else {
|
||||
if constraints.aud != "" {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
// RFC7159 4.1.4 exp
|
||||
if exp := payload.Get(jwtExpKey); exp != nil {
|
||||
var expVal int64
|
||||
if expVal, err = strconv.ParseInt(string(exp.Value.(ast.Number)), 10, 64); err != nil {
|
||||
err = fmt.Errorf("parsing 'exp' JWT claim: %v", err)
|
||||
return
|
||||
}
|
||||
if constraints.time < 0 {
|
||||
constraints.time = time.Now().UnixNano()
|
||||
}
|
||||
// constraints.time is in nanoseconds but expVal is in seconds
|
||||
if constraints.time/1000000000 >= expVal {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
// RFC7159 4.1.5 nbf
|
||||
if nbf := payload.Get(jwtNbfKey); nbf != nil {
|
||||
var nbfVal int64
|
||||
if nbfVal, err = strconv.ParseInt(string(nbf.Value.(ast.Number)), 10, 64); err != nil {
|
||||
err = fmt.Errorf("parsing 'nbf' JWT claim: %v", err)
|
||||
return
|
||||
}
|
||||
if constraints.time < 0 {
|
||||
constraints.time = time.Now().UnixNano()
|
||||
}
|
||||
// constraints.time is in nanoseconds but nbfVal is in seconds
|
||||
if constraints.time/1000000000 < nbfVal {
|
||||
return arr, nil
|
||||
}
|
||||
}
|
||||
// Format the result
|
||||
arr[0] = ast.BooleanTerm(true)
|
||||
arr[1] = ast.NewTerm(token.decodedHeader)
|
||||
arr[2] = ast.NewTerm(payload)
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
// -- Utilities --
|
||||
|
||||
func decodeJWT(a ast.Value) (*JSONWebToken, error) {
|
||||
// Parse the JSON Web Token
|
||||
astEncode, err := builtins.StringOperand(a, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoding := string(astEncode)
|
||||
if !strings.Contains(encoding, ".") {
|
||||
return nil, errors.New("encoded JWT had no period separators")
|
||||
}
|
||||
|
||||
parts := strings.Split(encoding, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("encoded JWT must have 3 sections, found %d", len(parts))
|
||||
}
|
||||
|
||||
return &JSONWebToken{header: parts[0], payload: parts[1], signature: parts[2]}, nil
|
||||
}
|
||||
|
||||
func (token *JSONWebToken) decodeSignature() (string, error) {
|
||||
decodedSignature, err := builtinBase64UrlDecode(ast.String(token.signature))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signatureAst, err := builtins.StringOperand(decodedSignature, 1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(signatureAst), err
|
||||
}
|
||||
|
||||
// Extract, validate and return the JWT header as an ast.Object.
|
||||
func validateJWTHeader(h string) (ast.Object, error) {
|
||||
header, err := extractJSONObject(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad JWT header: %v", err)
|
||||
}
|
||||
|
||||
// There are two kinds of JWT tokens, a JSON Web Signature (JWS) and
|
||||
// a JSON Web Encryption (JWE). The latter is very involved, and we
|
||||
// won't support it for now.
|
||||
// This code checks which kind of JWT we are dealing with according to
|
||||
// RFC 7516 Section 9: https://tools.ietf.org/html/rfc7516#section-9
|
||||
if header.Get(jwtEncKey) != nil {
|
||||
return nil, errors.New("JWT is a JWE object, which is not supported")
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func extractJSONObject(s string) (ast.Object, error) {
|
||||
// XXX: This code relies on undocumented behavior of Go's
|
||||
// json.Unmarshal using the last occurrence of duplicate keys in a JSON
|
||||
// Object. If duplicate keys are present in a JWT, the last must be
|
||||
// used or the token rejected. Since detecting duplicates is tantamount
|
||||
// to parsing it ourselves, we're relying on the Go implementation
|
||||
// using the last occurring instance of the key, which is the behavior
|
||||
// as of Go 1.8.1.
|
||||
v, err := builtinJSONUnmarshal(ast.String(s))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
o, ok := v.(ast.Object)
|
||||
if !ok {
|
||||
return nil, errors.New("decoded JSON type was not an Object")
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// getInputSha returns the SHA256 checksum of the input
|
||||
func getInputSHA(input []byte) (hash []byte) {
|
||||
hasher := sha256.New()
|
||||
hasher.Write(input)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.JWTDecode.Name, builtinJWTDecode)
|
||||
RegisterFunctionalBuiltin2(ast.JWTVerifyRS256.Name, builtinJWTVerifyRS256)
|
||||
RegisterFunctionalBuiltin2(ast.JWTVerifyPS256.Name, builtinJWTVerifyPS256)
|
||||
RegisterFunctionalBuiltin2(ast.JWTVerifyES256.Name, builtinJWTVerifyES256)
|
||||
RegisterFunctionalBuiltin2(ast.JWTVerifyHS256.Name, builtinJWTVerifyHS256)
|
||||
RegisterFunctionalBuiltin2(ast.JWTDecodeVerify.Name, builtinJWTDecodeVerify)
|
||||
RegisterFunctionalBuiltin3(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
|
||||
RegisterFunctionalBuiltin3(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
|
||||
}
|
||||
304
vendor/github.com/open-policy-agent/opa/topdown/trace.go
generated
vendored
Normal file
304
vendor/github.com/open-policy-agent/opa/topdown/trace.go
generated
vendored
Normal file
@@ -0,0 +1,304 @@
|
||||
// Copyright 2016 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"github.com/open-policy-agent/opa/topdown/builtins"
|
||||
)
|
||||
|
||||
// Op defines the types of tracing events.
|
||||
type Op string
|
||||
|
||||
const (
|
||||
// EnterOp is emitted when a new query is about to be evaluated.
|
||||
EnterOp Op = "Enter"
|
||||
|
||||
// ExitOp is emitted when a query has evaluated to true.
|
||||
ExitOp Op = "Exit"
|
||||
|
||||
// EvalOp is emitted when an expression is about to be evaluated.
|
||||
EvalOp Op = "Eval"
|
||||
|
||||
// RedoOp is emitted when an expression, rule, or query is being re-evaluated.
|
||||
RedoOp Op = "Redo"
|
||||
|
||||
// SaveOp is emitted when an expression is saved instead of evaluated
|
||||
// during partial evaluation.
|
||||
SaveOp Op = "Save"
|
||||
|
||||
// FailOp is emitted when an expression evaluates to false.
|
||||
FailOp Op = "Fail"
|
||||
|
||||
// NoteOp is emitted when an expression invokes a tracing built-in function.
|
||||
NoteOp Op = "Note"
|
||||
|
||||
// IndexOp is emitted during an expression evaluation to represent lookup
|
||||
// matches.
|
||||
IndexOp Op = "Index"
|
||||
)
|
||||
|
||||
// VarMetadata provides some user facing information about
|
||||
// a variable in some policy.
|
||||
type VarMetadata struct {
|
||||
Name ast.Var `json:"name"`
|
||||
Location *ast.Location `json:"location"`
|
||||
}
|
||||
|
||||
// Event contains state associated with a tracing event.
|
||||
type Event struct {
|
||||
Op Op // Identifies type of event.
|
||||
Node ast.Node // Contains AST node relevant to the event.
|
||||
Location *ast.Location // The location of the Node this event relates to.
|
||||
QueryID uint64 // Identifies the query this event belongs to.
|
||||
ParentID uint64 // Identifies the parent query this event belongs to.
|
||||
Locals *ast.ValueMap // Contains local variable bindings from the query context.
|
||||
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings.
|
||||
Message string // Contains message for Note events.
|
||||
}
|
||||
|
||||
// HasRule returns true if the Event contains an ast.Rule.
|
||||
func (evt *Event) HasRule() bool {
|
||||
_, ok := evt.Node.(*ast.Rule)
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasBody returns true if the Event contains an ast.Body.
|
||||
func (evt *Event) HasBody() bool {
|
||||
_, ok := evt.Node.(ast.Body)
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasExpr returns true if the Event contains an ast.Expr.
|
||||
func (evt *Event) HasExpr() bool {
|
||||
_, ok := evt.Node.(*ast.Expr)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Equal returns true if this event is equal to the other event.
|
||||
func (evt *Event) Equal(other *Event) bool {
|
||||
if evt.Op != other.Op {
|
||||
return false
|
||||
}
|
||||
if evt.QueryID != other.QueryID {
|
||||
return false
|
||||
}
|
||||
if evt.ParentID != other.ParentID {
|
||||
return false
|
||||
}
|
||||
if !evt.equalNodes(other) {
|
||||
return false
|
||||
}
|
||||
return evt.Locals.Equal(other.Locals)
|
||||
}
|
||||
|
||||
func (evt *Event) String() string {
|
||||
return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID)
|
||||
}
|
||||
|
||||
func (evt *Event) equalNodes(other *Event) bool {
|
||||
switch a := evt.Node.(type) {
|
||||
case ast.Body:
|
||||
if b, ok := other.Node.(ast.Body); ok {
|
||||
return a.Equal(b)
|
||||
}
|
||||
case *ast.Rule:
|
||||
if b, ok := other.Node.(*ast.Rule); ok {
|
||||
return a.Equal(b)
|
||||
}
|
||||
case *ast.Expr:
|
||||
if b, ok := other.Node.(*ast.Expr); ok {
|
||||
return a.Equal(b)
|
||||
}
|
||||
case nil:
|
||||
return other.Node == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Tracer defines the interface for tracing in the top-down evaluation engine.
|
||||
type Tracer interface {
|
||||
Enabled() bool
|
||||
Trace(*Event)
|
||||
}
|
||||
|
||||
// BufferTracer implements the Tracer interface by simply buffering all events
|
||||
// received.
|
||||
type BufferTracer []*Event
|
||||
|
||||
// NewBufferTracer returns a new BufferTracer.
|
||||
func NewBufferTracer() *BufferTracer {
|
||||
return &BufferTracer{}
|
||||
}
|
||||
|
||||
// Enabled always returns true if the BufferTracer is instantiated.
|
||||
func (b *BufferTracer) Enabled() bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Trace adds the event to the buffer.
|
||||
func (b *BufferTracer) Trace(evt *Event) {
|
||||
*b = append(*b, evt)
|
||||
}
|
||||
|
||||
// PrettyTrace pretty prints the trace to the writer.
|
||||
func PrettyTrace(w io.Writer, trace []*Event) {
|
||||
depths := depths{}
|
||||
for _, event := range trace {
|
||||
depth := depths.GetOrSet(event.QueryID, event.ParentID)
|
||||
fmt.Fprintln(w, formatEvent(event, depth))
|
||||
}
|
||||
}
|
||||
|
||||
// PrettyTraceWithLocation prints the trace to the writer and includes location information
|
||||
func PrettyTraceWithLocation(w io.Writer, trace []*Event) {
|
||||
depths := depths{}
|
||||
for _, event := range trace {
|
||||
depth := depths.GetOrSet(event.QueryID, event.ParentID)
|
||||
location := formatLocation(event)
|
||||
fmt.Fprintln(w, fmt.Sprintf("%v %v", location, formatEvent(event, depth)))
|
||||
}
|
||||
}
|
||||
|
||||
func formatEvent(event *Event, depth int) string {
|
||||
padding := formatEventPadding(event, depth)
|
||||
if event.Op == NoteOp {
|
||||
return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message)
|
||||
} else if event.Message != "" {
|
||||
return fmt.Sprintf("%v%v %v %v", padding, event.Op, event.Node, event.Message)
|
||||
} else {
|
||||
switch node := event.Node.(type) {
|
||||
case *ast.Rule:
|
||||
return fmt.Sprintf("%v%v %v", padding, event.Op, node.Path())
|
||||
default:
|
||||
return fmt.Sprintf("%v%v %v", padding, event.Op, rewrite(event).Node)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatEventPadding(event *Event, depth int) string {
|
||||
spaces := formatEventSpaces(event, depth)
|
||||
padding := ""
|
||||
if spaces > 1 {
|
||||
padding += strings.Repeat("| ", spaces-1)
|
||||
}
|
||||
return padding
|
||||
}
|
||||
|
||||
func formatEventSpaces(event *Event, depth int) int {
|
||||
switch event.Op {
|
||||
case EnterOp:
|
||||
return depth
|
||||
case RedoOp:
|
||||
if _, ok := event.Node.(*ast.Expr); !ok {
|
||||
return depth
|
||||
}
|
||||
}
|
||||
return depth + 1
|
||||
}
|
||||
|
||||
func formatLocation(event *Event) string {
|
||||
if event.Op == NoteOp {
|
||||
return fmt.Sprintf("%-19v", "note")
|
||||
}
|
||||
|
||||
location := event.Location
|
||||
if location == nil {
|
||||
return fmt.Sprintf("%-19v", "")
|
||||
}
|
||||
|
||||
if location.File == "" {
|
||||
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", "query", location.Row))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", location.File, location.Row))
|
||||
}
|
||||
|
||||
// depths is a helper for computing the depth of an event. Events within the
|
||||
// same query all have the same depth. The depth of query is
|
||||
// depth(parent(query))+1.
|
||||
type depths map[uint64]int
|
||||
|
||||
func (ds depths) GetOrSet(qid uint64, pqid uint64) int {
|
||||
depth := ds[qid]
|
||||
if depth == 0 {
|
||||
depth = ds[pqid]
|
||||
depth++
|
||||
ds[qid] = depth
|
||||
}
|
||||
return depth
|
||||
}
|
||||
|
||||
func builtinTrace(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
str, err := builtins.StringOperand(args[0].Value, 1)
|
||||
if err != nil {
|
||||
return handleBuiltinErr(ast.Trace.Name, bctx.Location, err)
|
||||
}
|
||||
|
||||
if !traceIsEnabled(bctx.Tracers) {
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
evt := &Event{
|
||||
Op: NoteOp,
|
||||
QueryID: bctx.QueryID,
|
||||
ParentID: bctx.ParentID,
|
||||
Message: string(str),
|
||||
}
|
||||
|
||||
for i := range bctx.Tracers {
|
||||
bctx.Tracers[i].Trace(evt)
|
||||
}
|
||||
|
||||
return iter(ast.BooleanTerm(true))
|
||||
}
|
||||
|
||||
func traceIsEnabled(tracers []Tracer) bool {
|
||||
for i := range tracers {
|
||||
if tracers[i].Enabled() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func rewrite(event *Event) *Event {
|
||||
|
||||
cpy := *event
|
||||
|
||||
var node ast.Node
|
||||
|
||||
switch v := event.Node.(type) {
|
||||
case *ast.Expr:
|
||||
node = v.Copy()
|
||||
case ast.Body:
|
||||
node = v.Copy()
|
||||
case *ast.Rule:
|
||||
node = v.Copy()
|
||||
}
|
||||
|
||||
ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
|
||||
if meta, ok := cpy.LocalMetadata[v]; ok {
|
||||
return meta.Name, nil
|
||||
}
|
||||
return v, nil
|
||||
})
|
||||
|
||||
cpy.Node = node
|
||||
|
||||
return &cpy
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.Trace.Name, builtinTrace)
|
||||
}
|
||||
82
vendor/github.com/open-policy-agent/opa/topdown/type.go
generated
vendored
Normal file
82
vendor/github.com/open-policy-agent/opa/topdown/type.go
generated
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
func builtinIsNumber(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Number:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsString(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.String:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsBoolean(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Boolean:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsArray(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Array:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsSet(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Set:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsObject(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Object:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func builtinIsNull(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Null:
|
||||
return ast.Boolean(true), nil
|
||||
default:
|
||||
return nil, BuiltinEmpty{}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.IsNumber.Name, builtinIsNumber)
|
||||
RegisterFunctionalBuiltin1(ast.IsString.Name, builtinIsString)
|
||||
RegisterFunctionalBuiltin1(ast.IsBoolean.Name, builtinIsBoolean)
|
||||
RegisterFunctionalBuiltin1(ast.IsArray.Name, builtinIsArray)
|
||||
RegisterFunctionalBuiltin1(ast.IsSet.Name, builtinIsSet)
|
||||
RegisterFunctionalBuiltin1(ast.IsObject.Name, builtinIsObject)
|
||||
RegisterFunctionalBuiltin1(ast.IsNull.Name, builtinIsNull)
|
||||
}
|
||||
36
vendor/github.com/open-policy-agent/opa/topdown/type_name.go
generated
vendored
Normal file
36
vendor/github.com/open-policy-agent/opa/topdown/type_name.go
generated
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2018 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
func builtinTypeName(a ast.Value) (ast.Value, error) {
|
||||
switch a.(type) {
|
||||
case ast.Null:
|
||||
return ast.String("null"), nil
|
||||
case ast.Boolean:
|
||||
return ast.String("boolean"), nil
|
||||
case ast.Number:
|
||||
return ast.String("number"), nil
|
||||
case ast.String:
|
||||
return ast.String("string"), nil
|
||||
case ast.Array:
|
||||
return ast.String("array"), nil
|
||||
case ast.Object:
|
||||
return ast.String("object"), nil
|
||||
case ast.Set:
|
||||
return ast.String("set"), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("illegal value")
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterFunctionalBuiltin1(ast.TypeNameBuiltin.Name, builtinTypeName)
|
||||
}
|
||||
84
vendor/github.com/open-policy-agent/opa/topdown/walk.go
generated
vendored
Normal file
84
vendor/github.com/open-policy-agent/opa/topdown/walk.go
generated
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright 2017 The OPA Authors. All rights reserved.
|
||||
// Use of this source code is governed by an Apache2
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package topdown
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
)
|
||||
|
||||
func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
||||
input := args[0]
|
||||
filter := getOutputPath(args)
|
||||
var path ast.Array
|
||||
return walk(filter, path, input, iter)
|
||||
}
|
||||
|
||||
func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
|
||||
|
||||
if len(filter) == 0 {
|
||||
if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(filter) > 0 {
|
||||
key := filter[0]
|
||||
filter = filter[1:]
|
||||
if key.IsGround() {
|
||||
if term := input.Get(key); term != nil {
|
||||
return walk(filter, append(path, key), term, iter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
switch v := input.Value.(type) {
|
||||
case ast.Array:
|
||||
for i := range v {
|
||||
path = append(path, ast.IntNumberTerm(i))
|
||||
if err := walk(filter, path, v[i], iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
}
|
||||
case ast.Object:
|
||||
return v.Iter(func(k, v *ast.Term) error {
|
||||
path = append(path, k)
|
||||
if err := walk(filter, path, v, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
return nil
|
||||
})
|
||||
case ast.Set:
|
||||
return v.Iter(func(elem *ast.Term) error {
|
||||
path = append(path, elem)
|
||||
if err := walk(filter, path, elem, iter); err != nil {
|
||||
return err
|
||||
}
|
||||
path = path[:len(path)-1]
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOutputPath(args []*ast.Term) ast.Array {
|
||||
if len(args) == 2 {
|
||||
if arr, ok := args[1].Value.(ast.Array); ok {
|
||||
if len(arr) == 2 {
|
||||
if path, ok := arr[0].Value.(ast.Array); ok {
|
||||
return path
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterBuiltinFunc(ast.WalkBuiltin.Name, evalWalk)
|
||||
}
|
||||
Reference in New Issue
Block a user