Signed-off-by: hongming <talonwan@yunify.com>
This commit is contained in:
hongming
2020-03-19 22:44:05 +08:00
parent 23f6be88c6
commit 9769357005
332 changed files with 69808 additions and 4129 deletions

View File

@@ -0,0 +1,216 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"math/big"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinCount(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
return ast.IntNumberTerm(len(a)).Value, nil
case ast.Object:
return ast.IntNumberTerm(a.Len()).Value, nil
case ast.Set:
return ast.IntNumberTerm(a.Len()).Value, nil
case ast.String:
return ast.IntNumberTerm(len(a)).Value, nil
}
return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set")
}
func builtinSum(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
sum := big.NewFloat(0)
for _, x := range a {
n, ok := x.Value.(ast.Number)
if !ok {
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
}
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
}
return builtins.FloatToNumber(sum), nil
case ast.Set:
sum := big.NewFloat(0)
err := a.Iter(func(x *ast.Term) error {
n, ok := x.Value.(ast.Number)
if !ok {
return builtins.NewOperandElementErr(1, a, x.Value, "number")
}
sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
return nil
})
return builtins.FloatToNumber(sum), err
}
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
}
func builtinProduct(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
product := big.NewFloat(1)
for _, x := range a {
n, ok := x.Value.(ast.Number)
if !ok {
return nil, builtins.NewOperandElementErr(1, a, x.Value, "number")
}
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
}
return builtins.FloatToNumber(product), nil
case ast.Set:
product := big.NewFloat(1)
err := a.Iter(func(x *ast.Term) error {
n, ok := x.Value.(ast.Number)
if !ok {
return builtins.NewOperandElementErr(1, a, x.Value, "number")
}
product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
return nil
})
return builtins.FloatToNumber(product), err
}
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
}
func builtinMax(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
if len(a) == 0 {
return nil, BuiltinEmpty{}
}
var max = ast.Value(ast.Null{})
for i := range a {
if ast.Compare(max, a[i].Value) <= 0 {
max = a[i].Value
}
}
return max, nil
case ast.Set:
if a.Len() == 0 {
return nil, BuiltinEmpty{}
}
max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) {
if ast.Compare(max, elem) <= 0 {
return elem, nil
}
return max, nil
})
return max.Value, err
}
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
}
func builtinMin(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
if len(a) == 0 {
return nil, BuiltinEmpty{}
}
min := a[0].Value
for i := range a {
if ast.Compare(min, a[i].Value) >= 0 {
min = a[i].Value
}
}
return min, nil
case ast.Set:
if a.Len() == 0 {
return nil, BuiltinEmpty{}
}
min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) {
// The null term is considered to be less than any other term,
// so in order for min of a set to make sense, we need to check
// for it.
if min.Value.Compare(ast.Null{}) == 0 {
return elem, nil
}
if ast.Compare(min, elem) >= 0 {
return elem, nil
}
return min, nil
})
return min.Value, err
}
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
}
func builtinSort(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Array:
return a.Sorted(), nil
case ast.Set:
return a.Sorted(), nil
}
return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
}
func builtinAll(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Set:
res := true
match := ast.BooleanTerm(true)
val.Foreach(func(term *ast.Term) {
if !term.Equal(match) {
res = false
}
})
return ast.Boolean(res), nil
case ast.Array:
res := true
match := ast.BooleanTerm(true)
for _, term := range val {
if !term.Equal(match) {
res = false
}
}
return ast.Boolean(res), nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
}
func builtinAny(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Set:
res := false
match := ast.BooleanTerm(true)
val.Foreach(func(term *ast.Term) {
if term.Equal(match) {
res = true
}
})
return ast.Boolean(res), nil
case ast.Array:
res := false
match := ast.BooleanTerm(true)
for _, term := range val {
if term.Equal(match) {
res = true
}
}
return ast.Boolean(res), nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
}
func init() {
RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount)
RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum)
RegisterFunctionalBuiltin1(ast.Product.Name, builtinProduct)
RegisterFunctionalBuiltin1(ast.Max.Name, builtinMax)
RegisterFunctionalBuiltin1(ast.Min.Name, builtinMin)
RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort)
RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny)
RegisterFunctionalBuiltin1(ast.All.Name, builtinAll)
}

View File

@@ -0,0 +1,156 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"math/big"
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type arithArity1 func(a *big.Float) (*big.Float, error)
type arithArity2 func(a, b *big.Float) (*big.Float, error)
func arithAbs(a *big.Float) (*big.Float, error) {
return a.Abs(a), nil
}
var halfAwayFromZero = big.NewFloat(0.5)
func arithRound(a *big.Float) (*big.Float, error) {
var i *big.Int
if a.Signbit() {
i, _ = new(big.Float).Sub(a, halfAwayFromZero).Int(nil)
} else {
i, _ = new(big.Float).Add(a, halfAwayFromZero).Int(nil)
}
return new(big.Float).SetInt(i), nil
}
func arithPlus(a, b *big.Float) (*big.Float, error) {
return new(big.Float).Add(a, b), nil
}
func arithMinus(a, b *big.Float) (*big.Float, error) {
return new(big.Float).Sub(a, b), nil
}
func arithMultiply(a, b *big.Float) (*big.Float, error) {
return new(big.Float).Mul(a, b), nil
}
func arithDivide(a, b *big.Float) (*big.Float, error) {
i, acc := b.Int64()
if acc == big.Exact && i == 0 {
return nil, fmt.Errorf("divide by zero")
}
return new(big.Float).Quo(a, b), nil
}
func arithRem(a, b *big.Int) (*big.Int, error) {
if b.Int64() == 0 {
return nil, fmt.Errorf("modulo by zero")
}
return new(big.Int).Rem(a, b), nil
}
func builtinArithArity1(fn arithArity1) FunctionalBuiltin1 {
return func(a ast.Value) (ast.Value, error) {
n, err := builtins.NumberOperand(a, 1)
if err != nil {
return nil, err
}
f, err := fn(builtins.NumberToFloat(n))
if err != nil {
return nil, err
}
return builtins.FloatToNumber(f), nil
}
}
func builtinArithArity2(fn arithArity2) FunctionalBuiltin2 {
return func(a, b ast.Value) (ast.Value, error) {
n1, err := builtins.NumberOperand(a, 1)
if err != nil {
return nil, err
}
n2, err := builtins.NumberOperand(b, 2)
if err != nil {
return nil, err
}
f, err := fn(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
if err != nil {
return nil, err
}
return builtins.FloatToNumber(f), nil
}
}
func builtinMinus(a, b ast.Value) (ast.Value, error) {
n1, ok1 := a.(ast.Number)
n2, ok2 := b.(ast.Number)
if ok1 && ok2 {
f, err := arithMinus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
if err != nil {
return nil, err
}
return builtins.FloatToNumber(f), nil
}
s1, ok3 := a.(ast.Set)
s2, ok4 := b.(ast.Set)
if ok3 && ok4 {
return s1.Diff(s2), nil
}
if !ok1 && !ok3 {
return nil, builtins.NewOperandTypeErr(1, a, "number", "set")
}
return nil, builtins.NewOperandTypeErr(2, b, "number", "set")
}
func builtinRem(a, b ast.Value) (ast.Value, error) {
n1, ok1 := a.(ast.Number)
n2, ok2 := b.(ast.Number)
if ok1 && ok2 {
op1, err1 := builtins.NumberToInt(n1)
op2, err2 := builtins.NumberToInt(n2)
if err1 != nil || err2 != nil {
return nil, fmt.Errorf("modulo on floating-point number")
}
i, err := arithRem(op1, op2)
if err != nil {
return nil, err
}
return builtins.IntToNumber(i), nil
}
if !ok1 {
return nil, builtins.NewOperandTypeErr(1, a, "number")
}
return nil, builtins.NewOperandTypeErr(2, b, "number")
}
func init() {
RegisterFunctionalBuiltin1(ast.Abs.Name, builtinArithArity1(arithAbs))
RegisterFunctionalBuiltin1(ast.Round.Name, builtinArithArity1(arithRound))
RegisterFunctionalBuiltin2(ast.Plus.Name, builtinArithArity2(arithPlus))
RegisterFunctionalBuiltin2(ast.Minus.Name, builtinMinus)
RegisterFunctionalBuiltin2(ast.Multiply.Name, builtinArithArity2(arithMultiply))
RegisterFunctionalBuiltin2(ast.Divide.Name, builtinArithArity2(arithDivide))
RegisterFunctionalBuiltin2(ast.Rem.Name, builtinRem)
}

View File

@@ -0,0 +1,75 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinArrayConcat(a, b ast.Value) (ast.Value, error) {
arrA, err := builtins.ArrayOperand(a, 1)
if err != nil {
return nil, err
}
arrB, err := builtins.ArrayOperand(b, 2)
if err != nil {
return nil, err
}
arrC := make(ast.Array, 0, len(arrA)+len(arrB))
for _, elemA := range arrA {
arrC = append(arrC, elemA)
}
for _, elemB := range arrB {
arrC = append(arrC, elemB)
}
return arrC, nil
}
func builtinArraySlice(a, i, j ast.Value) (ast.Value, error) {
arr, err := builtins.ArrayOperand(a, 1)
if err != nil {
return nil, err
}
startIndex, err := builtins.IntOperand(i, 2)
if err != nil {
return nil, err
}
stopIndex, err := builtins.IntOperand(j, 3)
if err != nil {
return nil, err
}
// Return empty array if bounds cannot be clamped sensibly.
if (startIndex >= stopIndex) || (startIndex <= 0 && stopIndex <= 0) {
return arr[0:0], nil
}
// Clamp bounds to avoid out-of-range errors.
if startIndex < 0 {
startIndex = 0
}
if stopIndex > len(arr) {
stopIndex = len(arr)
}
arrb := arr[startIndex:stopIndex]
return arrb, nil
}
func init() {
RegisterFunctionalBuiltin2(ast.ArrayConcat.Name, builtinArrayConcat)
RegisterFunctionalBuiltin3(ast.ArraySlice.Name, builtinArraySlice)
}

View File

@@ -0,0 +1,45 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinBinaryAnd(a ast.Value, b ast.Value) (ast.Value, error) {
s1, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.SetOperand(b, 2)
if err != nil {
return nil, err
}
return s1.Intersect(s2), nil
}
func builtinBinaryOr(a ast.Value, b ast.Value) (ast.Value, error) {
s1, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.SetOperand(b, 2)
if err != nil {
return nil, err
}
return s1.Union(s2), nil
}
func init() {
RegisterFunctionalBuiltin2(ast.And.Name, builtinBinaryAnd)
RegisterFunctionalBuiltin2(ast.Or.Name, builtinBinaryOr)
}

View File

@@ -0,0 +1,387 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
)
type undo struct {
k *ast.Term
u *bindings
next *undo
}
func (u *undo) Undo() {
if u == nil {
// Allow call on zero value of Undo for ease-of-use.
return
}
if u.u == nil {
// Call on empty unifier undos a no-op unify operation.
return
}
u.u.delete(u.k)
u.next.Undo()
}
type bindings struct {
id uint64
values bindingsArrayHashmap
instr *Instrumentation
}
func newBindings(id uint64, instr *Instrumentation) *bindings {
values := newBindingsArrayHashmap()
return &bindings{id, values, instr}
}
func (u *bindings) Iter(caller *bindings, iter func(*ast.Term, *ast.Term) error) error {
var err error
u.values.Iter(func(k *ast.Term, v value) bool {
if err != nil {
return true
}
err = iter(k, u.PlugNamespaced(k, caller))
return false
})
return err
}
func (u *bindings) Namespace(x ast.Node, caller *bindings) {
vis := namespacingVisitor{
b: u,
caller: caller,
}
ast.NewGenericVisitor(vis.Visit).Walk(x)
}
func (u *bindings) Plug(a *ast.Term) *ast.Term {
return u.PlugNamespaced(a, nil)
}
func (u *bindings) PlugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
if u != nil {
u.instr.startTimer(evalOpPlug)
t := u.plugNamespaced(a, caller)
u.instr.stopTimer(evalOpPlug)
return t
}
return u.plugNamespaced(a, caller)
}
func (u *bindings) plugNamespaced(a *ast.Term, caller *bindings) *ast.Term {
switch v := a.Value.(type) {
case ast.Var:
b, next := u.apply(a)
if a != b || u != next {
return next.plugNamespaced(b, caller)
}
return u.namespaceVar(b, caller)
case ast.Array:
cpy := *a
arr := make(ast.Array, len(v))
for i := 0; i < len(arr); i++ {
arr[i] = u.plugNamespaced(v[i], caller)
}
cpy.Value = arr
return &cpy
case ast.Object:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) {
return u.plugNamespaced(k, caller), u.plugNamespaced(v, caller), nil
})
return &cpy
case ast.Set:
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return u.plugNamespaced(x, caller), nil
})
return &cpy
case ast.Ref:
cpy := *a
ref := make(ast.Ref, len(v))
for i := 0; i < len(ref); i++ {
ref[i] = u.plugNamespaced(v[i], caller)
}
cpy.Value = ref
return &cpy
}
return a
}
func (u *bindings) bind(a *ast.Term, b *ast.Term, other *bindings) *undo {
u.values.Put(a, value{
u: other,
v: b,
})
return &undo{a, u, nil}
}
func (u *bindings) apply(a *ast.Term) (*ast.Term, *bindings) {
// Early exit for non-var terms. Only vars are bound in the binding list,
// so the lookup below will always fail for non-var terms. In some cases,
// the lookup may be expensive as it has to hash the term (which for large
// inputs can be costly.)
_, ok := a.Value.(ast.Var)
if !ok {
return a, u
}
val, ok := u.get(a)
if !ok {
return a, u
}
return val.u.apply(val.v)
}
func (u *bindings) delete(v *ast.Term) {
u.values.Delete(v)
}
func (u *bindings) get(v *ast.Term) (value, bool) {
if u == nil {
return value{}, false
}
return u.values.Get(v)
}
func (u *bindings) String() string {
if u == nil {
return "()"
}
var buf []string
u.values.Iter(func(a *ast.Term, b value) bool {
buf = append(buf, fmt.Sprintf("%v: %v", a, b))
return false
})
return fmt.Sprintf("({%v}, %v)", strings.Join(buf, ", "), u.id)
}
func (u *bindings) namespaceVar(v *ast.Term, caller *bindings) *ast.Term {
name, ok := v.Value.(ast.Var)
if !ok {
panic("illegal value")
}
if caller != nil && caller != u {
// Root documents (i.e., data, input) should never be namespaced because they
// are globally unique.
if !ast.RootDocumentNames.Contains(v) {
return ast.NewTerm(ast.Var(string(name) + fmt.Sprint(u.id)))
}
}
return v
}
type value struct {
u *bindings
v *ast.Term
}
func (v value) String() string {
return fmt.Sprintf("(%v, %d)", v.v, v.u.id)
}
func (v value) equal(other *value) bool {
if v.u == other.u {
return v.v.Equal(other.v)
}
return false
}
type namespacingVisitor struct {
b *bindings
caller *bindings
}
func (vis namespacingVisitor) Visit(x interface{}) bool {
switch x := x.(type) {
case *ast.ArrayComprehension:
x.Term = vis.namespaceTerm(x.Term)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
return true
case *ast.SetComprehension:
x.Term = vis.namespaceTerm(x.Term)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
return true
case *ast.ObjectComprehension:
x.Key = vis.namespaceTerm(x.Key)
x.Value = vis.namespaceTerm(x.Value)
ast.NewGenericVisitor(vis.Visit).Walk(x.Body)
return true
case *ast.Expr:
switch terms := x.Terms.(type) {
case []*ast.Term:
for i := 1; i < len(terms); i++ {
terms[i] = vis.namespaceTerm(terms[i])
}
case *ast.Term:
x.Terms = vis.namespaceTerm(terms)
}
for _, w := range x.With {
w.Target = vis.namespaceTerm(w.Target)
w.Value = vis.namespaceTerm(w.Value)
}
}
return false
}
func (vis namespacingVisitor) namespaceTerm(a *ast.Term) *ast.Term {
switch v := a.Value.(type) {
case ast.Var:
return vis.b.namespaceVar(a, vis.caller)
case ast.Array:
cpy := *a
arr := make(ast.Array, len(v))
for i := 0; i < len(arr); i++ {
arr[i] = vis.namespaceTerm(v[i])
}
cpy.Value = arr
return &cpy
case ast.Object:
if a.IsGround() {
return a
}
cpy := *a
cpy.Value, _ = v.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) {
return vis.namespaceTerm(k), vis.namespaceTerm(v), nil
})
return &cpy
case ast.Set:
cpy := *a
cpy.Value, _ = v.Map(func(x *ast.Term) (*ast.Term, error) {
return vis.namespaceTerm(x), nil
})
return &cpy
case ast.Ref:
cpy := *a
ref := make(ast.Ref, len(v))
for i := 0; i < len(ref); i++ {
ref[i] = vis.namespaceTerm(v[i])
}
cpy.Value = ref
return &cpy
}
return a
}
const maxLinearScan = 16
// bindingsArrayHashMap uses an array with linear scan instead of a hash map for smaller # of entries. Hash maps start to show off their performance advantage only after 16 keys.
type bindingsArrayHashmap struct {
n int // Entries in the array.
a *[maxLinearScan]bindingArrayKeyValue
m map[ast.Var]bindingArrayKeyValue
}
type bindingArrayKeyValue struct {
key *ast.Term
value value
}
func newBindingsArrayHashmap() bindingsArrayHashmap {
return bindingsArrayHashmap{}
}
func (b *bindingsArrayHashmap) Put(key *ast.Term, value value) {
if b.m == nil {
if b.a == nil {
b.a = new([maxLinearScan]bindingArrayKeyValue)
} else if i := b.find(key); i >= 0 {
(*b.a)[i].value = value
return
}
if b.n < maxLinearScan {
(*b.a)[b.n] = bindingArrayKeyValue{key, value}
b.n++
return
}
// Array is full, revert to using the hash map instead.
b.m = make(map[ast.Var]bindingArrayKeyValue, maxLinearScan+1)
for _, kv := range *b.a {
b.m[kv.key.Value.(ast.Var)] = bindingArrayKeyValue{kv.key, kv.value}
}
b.m[key.Value.(ast.Var)] = bindingArrayKeyValue{key, value}
b.n = 0
return
}
b.m[key.Value.(ast.Var)] = bindingArrayKeyValue{key, value}
}
func (b *bindingsArrayHashmap) Get(key *ast.Term) (value, bool) {
if b.m == nil {
if i := b.find(key); i >= 0 {
return (*b.a)[i].value, true
}
return value{}, false
}
v, ok := b.m[key.Value.(ast.Var)]
if ok {
return v.value, true
}
return value{}, false
}
func (b *bindingsArrayHashmap) Delete(key *ast.Term) {
if b.m == nil {
if i := b.find(key); i >= 0 {
n := b.n - 1
if i < n {
(*b.a)[i] = (*b.a)[n]
}
b.n = n
}
return
}
delete(b.m, key.Value.(ast.Var))
}
func (b *bindingsArrayHashmap) Iter(f func(k *ast.Term, v value) bool) {
if b.m == nil {
for i := 0; i < b.n; i++ {
if f((*b.a)[i].key, (*b.a)[i].value) {
return
}
}
return
}
for _, v := range b.m {
if f(v.key, v.value) {
return
}
}
}
func (b *bindingsArrayHashmap) find(key *ast.Term) int {
v := key.Value.(ast.Var)
for i := 0; i < b.n; i++ {
if (*b.a)[i].key.Value.(ast.Var) == v {
return i
}
}
return -1
}

View File

@@ -0,0 +1,88 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"math/big"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type bitsArity1 func(a *big.Int) (*big.Int, error)
type bitsArity2 func(a, b *big.Int) (*big.Int, error)
func bitsOr(a, b *big.Int) (*big.Int, error) {
return new(big.Int).Or(a, b), nil
}
func bitsAnd(a, b *big.Int) (*big.Int, error) {
return new(big.Int).And(a, b), nil
}
func bitsNegate(a *big.Int) (*big.Int, error) {
return new(big.Int).Not(a), nil
}
func bitsXOr(a, b *big.Int) (*big.Int, error) {
return new(big.Int).Xor(a, b), nil
}
func bitsShiftLeft(a, b *big.Int) (*big.Int, error) {
if b.Sign() == -1 {
return nil, builtins.NewOperandErr(2, "must be an unsigned integer number but got a negative integer")
}
shift := uint(b.Uint64())
return new(big.Int).Lsh(a, shift), nil
}
func bitsShiftRight(a, b *big.Int) (*big.Int, error) {
if b.Sign() == -1 {
return nil, builtins.NewOperandErr(2, "must be an unsigned integer number but got a negative integer")
}
shift := uint(b.Uint64())
return new(big.Int).Rsh(a, shift), nil
}
func builtinBitsArity1(fn bitsArity1) BuiltinFunc {
return func(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
i, err := builtins.BigIntOperand(operands[0].Value, 1)
if err != nil {
return err
}
iOut, err := fn(i)
if err != nil {
return err
}
return iter(ast.NewTerm(builtins.IntToNumber(iOut)))
}
}
func builtinBitsArity2(fn bitsArity2) BuiltinFunc {
return func(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
i1, err := builtins.BigIntOperand(operands[0].Value, 1)
if err != nil {
return err
}
i2, err := builtins.BigIntOperand(operands[1].Value, 2)
if err != nil {
return err
}
iOut, err := fn(i1, i2)
if err != nil {
return err
}
return iter(ast.NewTerm(builtins.IntToNumber(iOut)))
}
}
func init() {
RegisterBuiltinFunc(ast.BitsOr.Name, builtinBitsArity2(bitsOr))
RegisterBuiltinFunc(ast.BitsAnd.Name, builtinBitsArity2(bitsAnd))
RegisterBuiltinFunc(ast.BitsNegate.Name, builtinBitsArity1(bitsNegate))
RegisterBuiltinFunc(ast.BitsXOr.Name, builtinBitsArity2(bitsXOr))
RegisterBuiltinFunc(ast.BitsShiftLeft.Name, builtinBitsArity2(bitsShiftLeft))
RegisterBuiltinFunc(ast.BitsShiftRight.Name, builtinBitsArity2(bitsShiftRight))
}

View File

@@ -0,0 +1,160 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"context"
"fmt"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type (
// FunctionalBuiltin1 is deprecated. Use BuiltinFunc instead.
FunctionalBuiltin1 func(op1 ast.Value) (output ast.Value, err error)
// FunctionalBuiltin2 is deprecated. Use BuiltinFunc instead.
FunctionalBuiltin2 func(op1, op2 ast.Value) (output ast.Value, err error)
// FunctionalBuiltin3 is deprecated. Use BuiltinFunc instead.
FunctionalBuiltin3 func(op1, op2, op3 ast.Value) (output ast.Value, err error)
// FunctionalBuiltin4 is deprecated. Use BuiltinFunc instead.
FunctionalBuiltin4 func(op1, op2, op3, op4 ast.Value) (output ast.Value, err error)
// BuiltinContext contains context from the evaluator that may be used by
// built-in functions.
BuiltinContext struct {
Context context.Context // request context that was passed when query started
Cancel Cancel // atomic value that signals evaluation to halt
Runtime *ast.Term // runtime information on the OPA instance
Cache builtins.Cache // built-in function state cache
Location *ast.Location // location of built-in call
Tracers []Tracer // tracer objects for trace() built-in function
QueryID uint64 // identifies query being evaluated
ParentID uint64 // identifies parent of query being evaluated
}
// BuiltinFunc defines an interface for implementing built-in functions.
// The built-in function is called with the plugged operands from the call
// (including the output operands.) The implementation should evaluate the
// operands and invoke the iterator for each successful/defined output
// value.
BuiltinFunc func(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error
)
// RegisterBuiltinFunc adds a new built-in function to the evaluation engine.
func RegisterBuiltinFunc(name string, f BuiltinFunc) {
builtinFunctions[name] = builtinErrorWrapper(name, f)
}
// RegisterFunctionalBuiltin1 is deprecated use RegisterBuiltinFunc instead.
func RegisterFunctionalBuiltin1(name string, fun FunctionalBuiltin1) {
builtinFunctions[name] = functionalWrapper1(name, fun)
}
// RegisterFunctionalBuiltin2 is deprecated use RegisterBuiltinFunc instead.
func RegisterFunctionalBuiltin2(name string, fun FunctionalBuiltin2) {
builtinFunctions[name] = functionalWrapper2(name, fun)
}
// RegisterFunctionalBuiltin3 is deprecated use RegisterBuiltinFunc instead.
func RegisterFunctionalBuiltin3(name string, fun FunctionalBuiltin3) {
builtinFunctions[name] = functionalWrapper3(name, fun)
}
// RegisterFunctionalBuiltin4 is deprecated use RegisterBuiltinFunc instead.
func RegisterFunctionalBuiltin4(name string, fun FunctionalBuiltin4) {
builtinFunctions[name] = functionalWrapper4(name, fun)
}
// GetBuiltin returns a built-in function implementation, nil if no built-in found.
func GetBuiltin(name string) BuiltinFunc {
return builtinFunctions[name]
}
// BuiltinEmpty is deprecated.
type BuiltinEmpty struct{}
func (BuiltinEmpty) Error() string {
return "<empty>"
}
var builtinFunctions = map[string]BuiltinFunc{}
func builtinErrorWrapper(name string, fn BuiltinFunc) BuiltinFunc {
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
err := fn(bctx, args, iter)
if err == nil {
return nil
}
return handleBuiltinErr(name, bctx.Location, err)
}
}
func functionalWrapper1(name string, fn FunctionalBuiltin1) BuiltinFunc {
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
result, err := fn(args[0].Value)
if err == nil {
return iter(ast.NewTerm(result))
}
return handleBuiltinErr(name, bctx.Location, err)
}
}
func functionalWrapper2(name string, fn FunctionalBuiltin2) BuiltinFunc {
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
result, err := fn(args[0].Value, args[1].Value)
if err == nil {
return iter(ast.NewTerm(result))
}
return handleBuiltinErr(name, bctx.Location, err)
}
}
func functionalWrapper3(name string, fn FunctionalBuiltin3) BuiltinFunc {
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
result, err := fn(args[0].Value, args[1].Value, args[2].Value)
if err == nil {
return iter(ast.NewTerm(result))
}
return handleBuiltinErr(name, bctx.Location, err)
}
}
func functionalWrapper4(name string, fn FunctionalBuiltin4) BuiltinFunc {
return func(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
result, err := fn(args[0].Value, args[1].Value, args[2].Value, args[3].Value)
if err == nil {
return iter(ast.NewTerm(result))
}
if _, empty := err.(BuiltinEmpty); empty {
return nil
}
return handleBuiltinErr(name, bctx.Location, err)
}
}
func handleBuiltinErr(name string, loc *ast.Location, err error) error {
switch err := err.(type) {
case BuiltinEmpty:
return nil
case *Error:
return err
case builtins.ErrOperand:
return &Error{
Code: TypeErr,
Message: fmt.Sprintf("%v: %v", string(name), err.Error()),
Location: loc,
}
default:
return &Error{
Code: BuiltinErr,
Message: fmt.Sprintf("%v: %v", string(name), err.Error()),
Location: loc,
}
}
}

View File

@@ -0,0 +1,235 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// Package builtins contains utilities for implementing built-in functions.
package builtins
import (
"fmt"
"math/big"
"strings"
"github.com/open-policy-agent/opa/ast"
)
// Cache defines the built-in cache used by the top-down evaluation. The keys
// must be comparable and should not be of type string.
type Cache map[interface{}]interface{}
// Put updates the cache for the named built-in.
func (c Cache) Put(k, v interface{}) {
c[k] = v
}
// Get returns the cached value for k.
func (c Cache) Get(k interface{}) (interface{}, bool) {
v, ok := c[k]
return v, ok
}
// ErrOperand represents an invalid operand has been passed to a built-in
// function. Built-ins should return ErrOperand to indicate a type error has
// occurred.
type ErrOperand string
func (err ErrOperand) Error() string {
return string(err)
}
// NewOperandErr returns a generic operand error.
func NewOperandErr(pos int, f string, a ...interface{}) error {
f = fmt.Sprintf("operand %v ", pos) + f
return ErrOperand(fmt.Sprintf(f, a...))
}
// NewOperandTypeErr returns an operand error indicating the operand's type was wrong.
func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error {
if len(expected) == 1 {
return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.TypeName(got))
}
return NewOperandErr(pos, "must be one of {%v} but got %v", strings.Join(expected, ", "), ast.TypeName(got))
}
// NewOperandElementErr returns an operand error indicating an element in the
// composite operand was wrong.
func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error {
tpe := ast.TypeName(composite)
if len(expected) == 1 {
return NewOperandErr(pos, "must be %v of %vs but got %v containing %v", tpe, expected[0], tpe, ast.TypeName(got))
}
return NewOperandErr(pos, "must be %v of (any of) {%v} but got %v containing %v", tpe, strings.Join(expected, ", "), tpe, ast.TypeName(got))
}
// NewOperandEnumErr returns an operand error indicating a value was wrong.
func NewOperandEnumErr(pos int, expected ...string) error {
if len(expected) == 1 {
return NewOperandErr(pos, "must be %v", expected[0])
}
return NewOperandErr(pos, "must be one of {%v}", strings.Join(expected, ", "))
}
// IntOperand converts x to an int. If the cast fails, a descriptive error is
// returned.
func IntOperand(x ast.Value, pos int) (int, error) {
n, ok := x.(ast.Number)
if !ok {
return 0, NewOperandTypeErr(pos, x, "number")
}
i, ok := n.Int()
if !ok {
return 0, NewOperandErr(pos, "must be integer number but got floating-point number")
}
return i, nil
}
// BigIntOperand converts x to a big int. If the cast fails, a descriptive error
// is returned.
func BigIntOperand(x ast.Value, pos int) (*big.Int, error) {
n, err := NumberOperand(x, 1)
if err != nil {
return nil, NewOperandTypeErr(pos, x, "integer")
}
bi, err := NumberToInt(n)
if err != nil {
return nil, NewOperandErr(pos, "must be integer number but got floating-point number")
}
return bi, nil
}
// NumberOperand converts x to a number. If the cast fails, a descriptive error is
// returned.
func NumberOperand(x ast.Value, pos int) (ast.Number, error) {
n, ok := x.(ast.Number)
if !ok {
return ast.Number(""), NewOperandTypeErr(pos, x, "number")
}
return n, nil
}
// SetOperand converts x to a set. If the cast fails, a descriptive error is
// returned.
func SetOperand(x ast.Value, pos int) (ast.Set, error) {
s, ok := x.(ast.Set)
if !ok {
return nil, NewOperandTypeErr(pos, x, "set")
}
return s, nil
}
// StringOperand converts x to a string. If the cast fails, a descriptive error is
// returned.
func StringOperand(x ast.Value, pos int) (ast.String, error) {
s, ok := x.(ast.String)
if !ok {
return ast.String(""), NewOperandTypeErr(pos, x, "string")
}
return s, nil
}
// ObjectOperand converts x to an object. If the cast fails, a descriptive
// error is returned.
func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {
o, ok := x.(ast.Object)
if !ok {
return nil, NewOperandTypeErr(pos, x, "object")
}
return o, nil
}
// ArrayOperand converts x to an array. If the cast fails, a descriptive
// error is returned.
func ArrayOperand(x ast.Value, pos int) (ast.Array, error) {
a, ok := x.(ast.Array)
if !ok {
return nil, NewOperandTypeErr(pos, x, "array")
}
return a, nil
}
// NumberToFloat converts n to a big float.
func NumberToFloat(n ast.Number) *big.Float {
r, ok := new(big.Float).SetString(string(n))
if !ok {
panic("illegal value")
}
return r
}
// FloatToNumber converts f to a number.
func FloatToNumber(f *big.Float) ast.Number {
return ast.Number(f.String())
}
// NumberToInt converts n to a big int.
// If n cannot be converted to an big int, an error is returned.
func NumberToInt(n ast.Number) (*big.Int, error) {
f := NumberToFloat(n)
r, accuracy := f.Int(nil)
if accuracy != big.Exact {
return nil, fmt.Errorf("illegal value")
}
return r, nil
}
// IntToNumber converts i to a number.
func IntToNumber(i *big.Int) ast.Number {
return ast.Number(i.String())
}
// StringSliceOperand converts x to a []string. If the cast fails, a descriptive error is
// returned.
func StringSliceOperand(x ast.Value, pos int) ([]string, error) {
a, err := ArrayOperand(x, pos)
if err != nil {
return nil, err
}
var f = make([]string, len(a))
for k, b := range a {
c, ok := b.Value.(ast.String)
if !ok {
return nil, NewOperandElementErr(pos, x, b.Value, "[]string")
}
f[k] = string(c)
}
return f, nil
}
// RuneSliceOperand converts x to a []rune. If the cast fails, a descriptive error is
// returned.
func RuneSliceOperand(x ast.Value, pos int) ([]rune, error) {
a, err := ArrayOperand(x, pos)
if err != nil {
return nil, err
}
var f = make([]rune, len(a))
for k, b := range a {
c, ok := b.Value.(ast.String)
if !ok {
return nil, NewOperandElementErr(pos, x, b.Value, "string")
}
d := []rune(string(c))
if len(d) != 1 {
return nil, NewOperandElementErr(pos, x, b.Value, "rune")
}
f[k] = d[0]
}
return f, nil
}

View File

@@ -0,0 +1,166 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/util"
)
type virtualCache struct {
stack []*virtualCacheElem
}
type virtualCacheElem struct {
value *ast.Term
children *util.HashMap
}
func newVirtualCache() *virtualCache {
cache := &virtualCache{}
cache.Push()
return cache
}
func (c *virtualCache) Push() {
c.stack = append(c.stack, newVirtualCacheElem())
}
func (c *virtualCache) Pop() {
c.stack = c.stack[:len(c.stack)-1]
}
func (c *virtualCache) Get(ref ast.Ref) *ast.Term {
node := c.stack[len(c.stack)-1]
for i := 0; i < len(ref); i++ {
x, ok := node.children.Get(ref[i])
if !ok {
return nil
}
node = x.(*virtualCacheElem)
}
return node.value
}
func (c *virtualCache) Put(ref ast.Ref, value *ast.Term) {
node := c.stack[len(c.stack)-1]
for i := 0; i < len(ref); i++ {
x, ok := node.children.Get(ref[i])
if ok {
node = x.(*virtualCacheElem)
} else {
next := newVirtualCacheElem()
node.children.Put(ref[i], next)
node = next
}
}
node.value = value
}
func newVirtualCacheElem() *virtualCacheElem {
return &virtualCacheElem{children: newVirtualCacheHashMap()}
}
func newVirtualCacheHashMap() *util.HashMap {
return util.NewHashMap(func(a, b util.T) bool {
return a.(*ast.Term).Equal(b.(*ast.Term))
}, func(x util.T) int {
return x.(*ast.Term).Hash()
})
}
// baseCache implements a trie structure to cache base documents read out of
// storage. Values inserted into the cache may contain other values that were
// previously inserted. In this case, the previous values are erased from the
// structure.
type baseCache struct {
root *baseCacheElem
}
func newBaseCache() *baseCache {
return &baseCache{
root: newBaseCacheElem(),
}
}
func (c *baseCache) Get(ref ast.Ref) ast.Value {
node := c.root
for i := 0; i < len(ref); i++ {
node = node.children[ref[i].Value]
if node == nil {
return nil
} else if node.value != nil {
result, err := node.value.Find(ref[i+1:])
if err != nil {
return nil
}
return result
}
}
return nil
}
func (c *baseCache) Put(ref ast.Ref, value ast.Value) {
node := c.root
for i := 0; i < len(ref); i++ {
if child, ok := node.children[ref[i].Value]; ok {
node = child
} else {
child := newBaseCacheElem()
node.children[ref[i].Value] = child
node = child
}
}
node.set(value)
}
type baseCacheElem struct {
value ast.Value
children map[ast.Value]*baseCacheElem
}
func newBaseCacheElem() *baseCacheElem {
return &baseCacheElem{
children: map[ast.Value]*baseCacheElem{},
}
}
func (e *baseCacheElem) set(value ast.Value) {
e.value = value
e.children = map[ast.Value]*baseCacheElem{}
}
type refStack struct {
sl []refStackElem
}
type refStackElem struct {
refs []ast.Ref
}
func newRefStack() *refStack {
return &refStack{}
}
func (s *refStack) Push(refs []ast.Ref) {
s.sl = append(s.sl, refStackElem{refs: refs})
}
func (s *refStack) Pop() {
s.sl = s.sl[:len(s.sl)-1]
}
func (s *refStack) Prefixed(ref ast.Ref) bool {
if s != nil {
for i := len(s.sl) - 1; i >= 0; i-- {
for j := range s.sl[i].refs {
if ref.HasPrefix(s.sl[i].refs[j]) {
return true
}
}
}
}
return false
}

View File

@@ -0,0 +1,33 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"sync/atomic"
)
// Cancel defines the interface for cancelling topdown queries. Cancel
// operations are thread-safe and idempotent.
type Cancel interface {
Cancel()
Cancelled() bool
}
type cancel struct {
flag int32
}
// NewCancel returns a new Cancel object.
func NewCancel() Cancel {
return &cancel{}
}
func (c *cancel) Cancel() {
atomic.StoreInt32(&c.flag, 1)
}
func (c *cancel) Cancelled() bool {
return atomic.LoadInt32(&c.flag) != 0
}

View File

@@ -0,0 +1,113 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"strconv"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinToNumber(a ast.Value) (ast.Value, error) {
switch a := a.(type) {
case ast.Null:
return ast.Number("0"), nil
case ast.Boolean:
if a {
return ast.Number("1"), nil
}
return ast.Number("0"), nil
case ast.Number:
return a, nil
case ast.String:
_, err := strconv.ParseFloat(string(a), 64)
if err != nil {
return nil, err
}
return ast.Number(a), nil
}
return nil, builtins.NewOperandTypeErr(1, a, "null", "boolean", "number", "string")
}
// Deprecated in v0.13.0.
func builtinToArray(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Array:
return val, nil
case ast.Set:
arr := make(ast.Array, val.Len())
i := 0
val.Foreach(func(term *ast.Term) {
arr[i] = term
i++
})
return arr, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
}
// Deprecated in v0.13.0.
func builtinToSet(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Array:
return ast.NewSet(val...), nil
case ast.Set:
return val, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
}
}
// Deprecated in v0.13.0.
func builtinToString(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.String:
return val, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "string")
}
}
// Deprecated in v0.13.0.
func builtinToBoolean(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Boolean:
return val, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "boolean")
}
}
// Deprecated in v0.13.0.
func builtinToNull(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Null:
return val, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "null")
}
}
// Deprecated in v0.13.0.
func builtinToObject(a ast.Value) (ast.Value, error) {
switch val := a.(type) {
case ast.Object:
return val, nil
default:
return nil, builtins.NewOperandTypeErr(1, a, "object")
}
}
func init() {
RegisterFunctionalBuiltin1(ast.ToNumber.Name, builtinToNumber)
RegisterFunctionalBuiltin1(ast.CastArray.Name, builtinToArray)
RegisterFunctionalBuiltin1(ast.CastSet.Name, builtinToSet)
RegisterFunctionalBuiltin1(ast.CastString.Name, builtinToString)
RegisterFunctionalBuiltin1(ast.CastBoolean.Name, builtinToBoolean)
RegisterFunctionalBuiltin1(ast.CastNull.Name, builtinToNull)
RegisterFunctionalBuiltin1(ast.CastObject.Name, builtinToObject)
}

157
vendor/github.com/open-policy-agent/opa/topdown/cidr.go generated vendored Normal file
View File

@@ -0,0 +1,157 @@
package topdown
import (
"fmt"
"math/big"
"net"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func getNetFromOperand(v ast.Value) (*net.IPNet, error) {
subnetStringA, err := builtins.StringOperand(v, 1)
if err != nil {
return nil, err
}
_, cidrnet, err := net.ParseCIDR(string(subnetStringA))
if err != nil {
return nil, err
}
return cidrnet, nil
}
func getLastIP(cidr *net.IPNet) (net.IP, error) {
prefixLen, bits := cidr.Mask.Size()
if prefixLen == 0 && bits == 0 {
// non-standard mask, see https://golang.org/pkg/net/#IPMask.Size
return nil, fmt.Errorf("CIDR mask is in non-standard format")
}
var lastIP []byte
if prefixLen == bits {
// Special case for single ip address ranges ex: 192.168.1.1/32
// We can just use the starting IP as the last IP
lastIP = cidr.IP
} else {
// Use big.Int's so we can handle ipv6 addresses
firstIPInt := new(big.Int)
firstIPInt.SetBytes(cidr.IP)
hostLen := uint(bits) - uint(prefixLen)
lastIPInt := big.NewInt(1)
lastIPInt.Lsh(lastIPInt, hostLen)
lastIPInt.Sub(lastIPInt, big.NewInt(1))
lastIPInt.Or(lastIPInt, firstIPInt)
ipBytes := lastIPInt.Bytes()
lastIP = make([]byte, bits/8)
// Pack our IP bytes into the end of the return array,
// since big.Int.Bytes() removes front zero padding.
for i := 1; i <= len(lastIPInt.Bytes()); i++ {
lastIP[len(lastIP)-i] = ipBytes[len(ipBytes)-i]
}
}
return lastIP, nil
}
func builtinNetCIDRIntersects(a, b ast.Value) (ast.Value, error) {
cidrnetA, err := getNetFromOperand(a)
if err != nil {
return nil, err
}
cidrnetB, err := getNetFromOperand(b)
if err != nil {
return nil, err
}
// If either net contains the others starting IP they are overlapping
cidrsOverlap := (cidrnetA.Contains(cidrnetB.IP) || cidrnetB.Contains(cidrnetA.IP))
return ast.Boolean(cidrsOverlap), nil
}
func builtinNetCIDRContains(a, b ast.Value) (ast.Value, error) {
cidrnetA, err := getNetFromOperand(a)
if err != nil {
return nil, err
}
// b could be either an IP addressor CIDR string, try to parse it as an IP first, fall back to CIDR
bStr, err := builtins.StringOperand(b, 1)
if err != nil {
return nil, err
}
ip := net.ParseIP(string(bStr))
if ip != nil {
return ast.Boolean(cidrnetA.Contains(ip)), nil
}
// It wasn't an IP, try and parse it as a CIDR
cidrnetB, err := getNetFromOperand(b)
if err != nil {
return nil, fmt.Errorf("not a valid textual representation of an IP address or CIDR: %s", string(bStr))
}
// We can determine if cidr A contains cidr B iff A contains the starting address of B and the last address in B.
cidrContained := false
if cidrnetA.Contains(cidrnetB.IP) {
// Only spend time calculating the last IP if the starting IP is already verified to be in cidr A
lastIP, err := getLastIP(cidrnetB)
if err != nil {
return nil, err
}
cidrContained = cidrnetA.Contains(lastIP)
}
return ast.Boolean(cidrContained), nil
}
func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
s, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
return err
}
ip, ipNet, err := net.ParseCIDR(string(s))
if err != nil {
return err
}
result := ast.NewSet()
for ip := ip.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return &Error{
Code: CancelErr,
Message: "net.cidr_expand: timed out before generating all IP addresses",
}
}
result.Add(ast.StringTerm(ip.String()))
}
return iter(ast.NewTerm(result))
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
func init() {
RegisterFunctionalBuiltin2(ast.NetCIDROverlap.Name, builtinNetCIDRContains)
RegisterFunctionalBuiltin2(ast.NetCIDRIntersects.Name, builtinNetCIDRIntersects)
RegisterFunctionalBuiltin2(ast.NetCIDRContains.Name, builtinNetCIDRContains)
RegisterBuiltinFunc(ast.NetCIDRExpand.Name, builtinNetCIDRExpand)
}

View File

@@ -0,0 +1,48 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import "github.com/open-policy-agent/opa/ast"
type compareFunc func(a, b ast.Value) bool
func compareGreaterThan(a, b ast.Value) bool {
return ast.Compare(a, b) > 0
}
func compareGreaterThanEq(a, b ast.Value) bool {
return ast.Compare(a, b) >= 0
}
func compareLessThan(a, b ast.Value) bool {
return ast.Compare(a, b) < 0
}
func compareLessThanEq(a, b ast.Value) bool {
return ast.Compare(a, b) <= 0
}
func compareNotEq(a, b ast.Value) bool {
return ast.Compare(a, b) != 0
}
func compareEq(a, b ast.Value) bool {
return ast.Compare(a, b) == 0
}
func builtinCompare(cmp compareFunc) FunctionalBuiltin2 {
return func(a, b ast.Value) (ast.Value, error) {
return ast.Boolean(cmp(a, b)), nil
}
}
func init() {
RegisterFunctionalBuiltin2(ast.GreaterThan.Name, builtinCompare(compareGreaterThan))
RegisterFunctionalBuiltin2(ast.GreaterThanEq.Name, builtinCompare(compareGreaterThanEq))
RegisterFunctionalBuiltin2(ast.LessThan.Name, builtinCompare(compareLessThan))
RegisterFunctionalBuiltin2(ast.LessThanEq.Name, builtinCompare(compareLessThanEq))
RegisterFunctionalBuiltin2(ast.NotEqual.Name, builtinCompare(compareNotEq))
RegisterFunctionalBuiltin2(ast.Equal.Name, builtinCompare(compareEq))
}

View File

@@ -0,0 +1,484 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package copypropagation
import (
"sort"
"github.com/open-policy-agent/opa/ast"
)
// CopyPropagator implements a simple copy propagation optimization to remove
// intermediate variables in partial evaluation results.
//
// For example, given the query: input.x > 1 where 'input' is unknown, the
// compiled query would become input.x = a; a > 1 which would remain in the
// partial evaluation result. The CopyPropagator will remove the variable
// assignment so that partial evaluation simply outputs input.x > 1.
//
// In many cases, copy propagation can remove all variables from the result of
// partial evaluation which simplifies evaluation for non-OPA consumers.
//
// In some cases, copy propagation cannot remove all variables. If the output of
// a built-in call is subsequently used as a ref head, the output variable must
// be kept. For example. sort(input, x); x[0] == 1. In this case, copy
// propagation cannot replace x[0] == 1 with sort(input, x)[0] == 1 as this is
// not legal.
type CopyPropagator struct {
livevars ast.VarSet // vars that must be preserved in the resulting query
sorted []ast.Var // sorted copy of vars to ensure deterministic result
ensureNonEmptyBody bool
}
// New returns a new CopyPropagator that optimizes queries while preserving vars
// in the livevars set.
func New(livevars ast.VarSet) *CopyPropagator {
sorted := make([]ast.Var, 0, len(livevars))
for v := range livevars {
sorted = append(sorted, v)
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Compare(sorted[j]) < 0
})
return &CopyPropagator{livevars: livevars, sorted: sorted}
}
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
p.ensureNonEmptyBody = yes
return p
}
// Apply executes the copy propagation optimization and returns a new query.
func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
uf, ok := makeDisjointSets(p.livevars, query)
if !ok {
return query
}
// Compute set of vars that appear in the head of refs in the query. If a var
// is dereferenced, we cannot plug it with a constant value so the constant on
// the union-find root must be unset (e.g., [1][0] is not legal.)
headvars := ast.NewVarSet()
ast.WalkRefs(query, func(x ast.Ref) bool {
if v, ok := x[0].Value.(ast.Var); ok {
if root, ok := uf.Find(v); ok {
root.constant = nil
headvars.Add(root.key)
} else {
headvars.Add(v)
}
}
return false
})
bindings := map[ast.Var]*binding{}
for _, expr := range query {
pctx := &plugContext{
bindings: bindings,
uf: uf,
negated: expr.Negated,
headvars: headvars,
}
if expr, keep := p.plugBindings(pctx, expr); keep {
if p.updateBindings(pctx, expr) {
result.Append(expr)
}
}
}
// Run post-processing step on the query to ensure that all live vars are bound
// in the result. The plugging that happens above substitutes all vars in the
// same set with the root.
//
// This step should run before the next step to prevent unnecessary bindings
// from being added to the result. For example:
//
// - Given the following result: <empty>
// - Given the following bindings: x/input.x and y/input
// - Given the following liveset: {x}
//
// If this step were to run AFTER the following step, the output would be:
//
// x = input.x; y = input
//
// Even though y = input is not required.
for _, v := range p.sorted {
if root, ok := uf.Find(v); ok {
if root.constant != nil {
result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
} else if b, ok := bindings[root.key]; ok {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b.v)))
} else if root.key != v {
result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
}
}
}
// Run post-processing step on query to ensure that all killed exprs are
// accounted for. If an expr is killed but the binding is never used, the query
// must still include the expr. For example, given the query 'input.x = a' and
// an empty livevar set, the result must include the ref input.x otherwise the
// query could be satisfied without input.x being defined. When exprs are
// killed we initialize the binding counter to zero and then increment it each
// time the binding is substituted. if the binding was never substituted it
// means the binding value must be added back into the query.
for _, b := range sortbindings(bindings) {
if !b.containedIn(result) {
result.Append(ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v)))
}
}
if p.ensureNonEmptyBody && len(result) == 0 {
result = append(result, ast.NewExpr(ast.BooleanTerm(true)))
}
return result
}
// plugBindings applies the binding list and union-find to x. This process
// removes as many variables as possible.
func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) (*ast.Expr, bool) {
// Kill single term expressions that are in the binding list. They will be
// re-added during post-processing if needed.
if term, ok := expr.Terms.(*ast.Term); ok {
if v, ok := term.Value.(ast.Var); ok {
if root, ok := pctx.uf.Find(v); ok {
if _, ok := pctx.bindings[root.key]; ok {
return nil, false
}
}
}
}
xform := bindingPlugTransform{
pctx: pctx,
}
// Deep copy the expression as it may be mutated during the transform and
// the caller running copy propagation may have references to the
// expression. Note, the transform does not contain any error paths and
// should never return a non-expression value for the root so consider
// errors unreachable.
x, err := ast.Transform(xform, expr.Copy())
if expr, ok := x.(*ast.Expr); !ok || err != nil {
panic("unreachable")
} else {
return expr, true
}
}
type bindingPlugTransform struct {
pctx *plugContext
}
func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
switch x := x.(type) {
case ast.Var:
return t.plugBindingsVar(t.pctx, x), nil
case ast.Ref:
return t.plugBindingsRef(t.pctx, x), nil
default:
return x, nil
}
}
func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (result ast.Value) {
result = v
// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v); ok {
result = root.Value()
}
// Apply binding list to substitute remaining vars.
if v, ok := result.(ast.Var); ok {
if b, ok := pctx.bindings[v]; ok {
if !pctx.negated || b.v.IsGround() {
result = b.v
}
}
}
return result
}
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok {
v[0].Value = root.Value()
}
result := v
// Refs require special handling. If the head of the ref was killed, then
// the rest of the ref must be concatenated with the new base.
//
// Invariant: ref heads can only be replaced by refs (not calls).
if b, ok := pctx.bindings[v[0].Value.(ast.Var)]; ok {
if !pctx.negated || b.v.IsGround() {
result = b.v.(ast.Ref).Concat(v[1:])
}
}
return result
}
// updateBindings returns false if the expression can be killed. If the
// expression is killed, the binding list is updated to map a var to value.
func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
if pctx.negated || len(expr.With) > 0 {
return true
}
if expr.IsEquality() {
a, b := expr.Operand(0), expr.Operand(1)
if a.Equal(b) {
return false
}
k, v, keep := p.updateBindingsEq(a, b)
if !keep {
if v != nil {
pctx.bindings[k] = newbinding(k, v)
}
return false
}
} else if expr.IsCall() {
terms := expr.Terms.([]*ast.Term)
output := terms[len(terms)-1]
if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
pctx.bindings[k] = newbinding(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
return false
}
}
return !isNoop(expr)
}
func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, v, keep := p.updateBindingsEqAsymmetric(a, b)
if !keep {
return k, v, keep
}
return p.updateBindingsEqAsymmetric(b, a)
}
func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, ast.Value, bool) {
k, ok := a.Value.(ast.Var)
if !ok || p.livevars.Contains(k) {
return "", nil, true
}
switch b.Value.(type) {
case ast.Ref, ast.Call:
return k, b.Value, false
}
return "", nil, true
}
type plugContext struct {
bindings map[ast.Var]*binding
uf *unionFind
headvars ast.VarSet
negated bool
}
type binding struct {
k ast.Var
v ast.Value
}
func newbinding(k ast.Var, v ast.Value) *binding {
return &binding{k: k, v: v}
}
func (b *binding) containedIn(query ast.Body) bool {
var stop bool
switch v := b.v.(type) {
case ast.Ref:
ast.WalkRefs(query, func(other ast.Ref) bool {
if stop || other.HasPrefix(v) {
stop = true
return stop
}
return false
})
default:
ast.WalkTerms(query, func(other *ast.Term) bool {
if stop || other.Value.Compare(v) == 0 {
stop = true
return stop
}
return false
})
}
return stop
}
func sortbindings(bindings map[ast.Var]*binding) []*binding {
sorted := make([]*binding, 0, len(bindings))
for _, b := range bindings {
sorted = append(sorted, b)
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].k.Compare(sorted[j].k) < 0
})
return sorted
}
type unionFind struct {
roots map[ast.Var]*unionFindRoot
parents map[ast.Var]ast.Var
rank rankFunc
}
// makeDisjointSets builds the union-find structure for the query. The structure
// is built by processing all of the equality exprs in the query. Sets represent
// vars that must be equal to each other. In addition to vars, each set can have
// at most one constant. If the query contains expressions that cannot be
// satisfied (e.g., because a set has multiple constants) this function returns
// false.
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
if livevars.Contains(r1.key) {
return r1, r2
}
return r2, r1
})
for _, expr := range query {
if expr.IsEquality() && !expr.Negated && len(expr.With) == 0 {
a, b := expr.Operand(0), expr.Operand(1)
varA, ok1 := a.Value.(ast.Var)
varB, ok2 := b.Value.(ast.Var)
if ok1 && ok2 {
if _, ok := uf.Merge(varA, varB); !ok {
return nil, false
}
} else if ok1 && ast.IsConstant(b.Value) {
root := uf.MakeSet(varA)
if root.constant != nil && !root.constant.Equal(b) {
return nil, false
}
root.constant = b
} else if ok2 && ast.IsConstant(a.Value) {
root := uf.MakeSet(varB)
if root.constant != nil && !root.constant.Equal(a) {
return nil, false
}
root.constant = a
}
}
}
return uf, true
}
type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
func newUnionFind(rank rankFunc) *unionFind {
return &unionFind{
roots: map[ast.Var]*unionFindRoot{},
parents: map[ast.Var]ast.Var{},
rank: rank,
}
}
func (uf *unionFind) MakeSet(v ast.Var) *unionFindRoot {
root, ok := uf.Find(v)
if ok {
return root
}
root = newUnionFindRoot(v)
uf.parents[v] = v
uf.roots[v] = root
return uf.roots[v]
}
func (uf *unionFind) Find(v ast.Var) (*unionFindRoot, bool) {
parent, ok := uf.parents[v]
if !ok {
return nil, false
}
if parent == v {
return uf.roots[v], true
}
return uf.Find(parent)
}
func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
r1 := uf.MakeSet(a)
r2 := uf.MakeSet(b)
if r1 != r2 {
r1, r2 = uf.rank(r1, r2)
uf.parents[r2.key] = r1.key
delete(uf.roots, r2.key)
// Sets can have at most one constant value associated with them. When
// unioning, we must preserve this invariant. If a set has two constants,
// there will be no way to prove the query.
if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
return nil, false
} else if r1.constant == nil {
r1.constant = r2.constant
}
}
return r1, true
}
type unionFindRoot struct {
key ast.Var
constant *ast.Term
}
func newUnionFindRoot(key ast.Var) *unionFindRoot {
return &unionFindRoot{
key: key,
}
}
func (r *unionFindRoot) Value() ast.Value {
if r.constant != nil {
return r.constant.Value
}
return r.key
}
func isNoop(expr *ast.Expr) bool {
if !expr.IsCall() {
term := expr.Terms.(*ast.Term)
if !ast.IsConstant(term.Value) {
return false
}
return !ast.Boolean(false).Equal(term.Value)
}
// A==A can be ignored
if expr.Operator().Equal(ast.Equal.Ref()) {
return expr.Operand(0).Equal(expr.Operand(1))
}
return false
}

View File

@@ -0,0 +1,128 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/x509"
"encoding/json"
"fmt"
"io/ioutil"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/util"
)
func builtinCryptoX509ParseCertificates(a ast.Value) (ast.Value, error) {
str, err := builtinBase64Decode(a)
if err != nil {
return nil, err
}
certs, err := x509.ParseCertificates([]byte(str.(ast.String)))
if err != nil {
return nil, err
}
bs, err := json.Marshal(certs)
if err != nil {
return nil, err
}
var x interface{}
if err := util.UnmarshalJSON(bs, &x); err != nil {
return nil, err
}
return ast.InterfaceToValue(x)
}
func hashHelper(a ast.Value, h func(ast.String) string) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(h(s)), nil
}
func builtinCryptoMd5(a ast.Value) (ast.Value, error) {
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", md5.Sum([]byte(s))) })
}
func builtinCryptoSha1(a ast.Value) (ast.Value, error) {
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha1.Sum([]byte(s))) })
}
func builtinCryptoSha256(a ast.Value) (ast.Value, error) {
return hashHelper(a, func(s ast.String) string { return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) })
}
func init() {
RegisterFunctionalBuiltin1(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
RegisterFunctionalBuiltin1(ast.CryptoMd5.Name, builtinCryptoMd5)
RegisterFunctionalBuiltin1(ast.CryptoSha1.Name, builtinCryptoSha1)
RegisterFunctionalBuiltin1(ast.CryptoSha256.Name, builtinCryptoSha256)
}
// createRootCAs creates a new Cert Pool from scratch or adds to a copy of System Certs
func createRootCAs(tlsCACertFile string, tlsCACertEnvVar []byte, tlsUseSystemCerts bool) (*x509.CertPool, error) {
var newRootCAs *x509.CertPool
if tlsUseSystemCerts {
systemCertPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
newRootCAs = systemCertPool
} else {
newRootCAs = x509.NewCertPool()
}
if len(tlsCACertFile) > 0 {
// Append our cert to the system pool
caCert, err := readCertFromFile(tlsCACertFile)
if err != nil {
return nil, err
}
if ok := newRootCAs.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("could not append CA cert from %q", tlsCACertFile)
}
}
if len(tlsCACertEnvVar) > 0 {
// Append our cert to the system pool
if ok := newRootCAs.AppendCertsFromPEM(tlsCACertEnvVar); !ok {
return nil, fmt.Errorf("error appending cert from env var %q into system certs", tlsCACertEnvVar)
}
}
return newRootCAs, nil
}
// ReadCertFromFile reads a cert from file
func readCertFromFile(localCertFile string) ([]byte, error) {
// Read in the cert file
certPEM, err := ioutil.ReadFile(localCertFile)
if err != nil {
return nil, err
}
return certPEM, nil
}
// ReadKeyFromFile reads a key from file
func readKeyFromFile(localKeyFile string) ([]byte, error) {
// Read in the cert file
key, err := ioutil.ReadFile(localKeyFile)
if err != nil {
return nil, err
}
return key, nil
}

10
vendor/github.com/open-policy-agent/opa/topdown/doc.go generated vendored Normal file
View File

@@ -0,0 +1,10 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// Package topdown provides low-level query evaluation support.
//
// The topdown implementation is a modified version of the standard top-down
// evaluation algorithm used in Datalog. References and comprehensions are
// evaluated eagerly while all other terms are evaluated lazily.
package topdown

View File

@@ -0,0 +1,217 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
ghodss "github.com/ghodss/yaml"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/util"
)
func builtinJSONMarshal(a ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
if err != nil {
return nil, err
}
bs, err := json.Marshal(asJSON)
if err != nil {
return nil, err
}
return ast.String(string(bs)), nil
}
func builtinJSONUnmarshal(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
var x interface{}
if err := util.UnmarshalJSON([]byte(str), &x); err != nil {
return nil, err
}
return ast.InterfaceToValue(x)
}
func builtinBase64Encode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(base64.StdEncoding.EncodeToString([]byte(str))), nil
}
func builtinBase64Decode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
result, err := base64.StdEncoding.DecodeString(string(str))
return ast.String(result), err
}
func builtinBase64UrlEncode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(base64.URLEncoding.EncodeToString([]byte(str))), nil
}
func builtinBase64UrlDecode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s := string(str)
// Some base64url encoders omit the padding at the end, so this case
// corrects such representations using the method given in RFC 7515
// Appendix C: https://tools.ietf.org/html/rfc7515#appendix-C
if !strings.HasSuffix(s, "=") {
switch len(s) % 4 {
case 0:
case 2:
s += "=="
case 3:
s += "="
default:
return nil, fmt.Errorf("illegal base64url string: %s", s)
}
}
result, err := base64.URLEncoding.DecodeString(s)
return ast.String(result), err
}
func builtinURLQueryEncode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(url.QueryEscape(string(str))), nil
}
func builtinURLQueryDecode(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s, err := url.QueryUnescape(string(str))
if err != nil {
return nil, err
}
return ast.String(s), nil
}
var encodeObjectErr = builtins.NewOperandErr(1, "values must be string, array[string], or set[string]")
func builtinURLQueryEncodeObject(a ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
if err != nil {
return nil, err
}
inputs, ok := asJSON.(map[string]interface{})
if !ok {
return nil, builtins.NewOperandTypeErr(1, a, "object")
}
query := url.Values{}
for k, v := range inputs {
switch vv := v.(type) {
case string:
query.Set(k, vv)
case []interface{}:
for _, val := range vv {
strVal, ok := val.(string)
if !ok {
return nil, encodeObjectErr
}
query.Add(k, strVal)
}
default:
return nil, encodeObjectErr
}
}
return ast.String(query.Encode()), nil
}
func builtinYAMLMarshal(a ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
if err != nil {
return nil, err
}
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
if err := encoder.Encode(asJSON); err != nil {
return nil, err
}
bs, err := ghodss.JSONToYAML(buf.Bytes())
if err != nil {
return nil, err
}
return ast.String(string(bs)), nil
}
func builtinYAMLUnmarshal(a ast.Value) (ast.Value, error) {
str, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
bs, err := ghodss.YAMLToJSON([]byte(str))
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(bs)
decoder := util.NewJSONDecoder(buf)
var val interface{}
err = decoder.Decode(&val)
if err != nil {
return nil, err
}
return ast.InterfaceToValue(val)
}
func init() {
RegisterFunctionalBuiltin1(ast.JSONMarshal.Name, builtinJSONMarshal)
RegisterFunctionalBuiltin1(ast.JSONUnmarshal.Name, builtinJSONUnmarshal)
RegisterFunctionalBuiltin1(ast.Base64Encode.Name, builtinBase64Encode)
RegisterFunctionalBuiltin1(ast.Base64Decode.Name, builtinBase64Decode)
RegisterFunctionalBuiltin1(ast.Base64UrlEncode.Name, builtinBase64UrlEncode)
RegisterFunctionalBuiltin1(ast.Base64UrlDecode.Name, builtinBase64UrlDecode)
RegisterFunctionalBuiltin1(ast.URLQueryDecode.Name, builtinURLQueryDecode)
RegisterFunctionalBuiltin1(ast.URLQueryEncode.Name, builtinURLQueryEncode)
RegisterFunctionalBuiltin1(ast.URLQueryEncodeObject.Name, builtinURLQueryEncodeObject)
RegisterFunctionalBuiltin1(ast.YAMLMarshal.Name, builtinYAMLMarshal)
RegisterFunctionalBuiltin1(ast.YAMLUnmarshal.Name, builtinYAMLUnmarshal)
}

View File

@@ -0,0 +1,119 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
)
// Error is the error type returned by the Eval and Query functions when
// an evaluation error occurs.
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
Location *ast.Location `json:"location,omitempty"`
}
const (
// InternalErr represents an unknown evaluation error.
InternalErr string = "eval_internal_error"
// CancelErr indicates the evaluation process was cancelled.
CancelErr string = "eval_cancel_error"
// ConflictErr indicates a conflict was encountered during evaluation. For
// instance, a conflict occurs if a rule produces multiple, differing values
// for the same key in an object. Conflict errors indicate the policy does
// not account for the data loaded into the policy engine.
ConflictErr string = "eval_conflict_error"
// TypeErr indicates evaluation stopped because an expression was applied to
// a value of an inappropriate type.
TypeErr string = "eval_type_error"
// BuiltinErr indicates a built-in function received a semantically invalid
// input or encountered some kind of runtime error, e.g., connection
// timeout, connection refused, etc.
BuiltinErr string = "eval_builtin_error"
// WithMergeErr indicates that the real and replacement data could not be merged.
WithMergeErr string = "eval_with_merge_error"
)
// IsError returns true if the err is an Error.
func IsError(err error) bool {
_, ok := err.(*Error)
return ok
}
// IsCancel returns true if err was caused by cancellation.
func IsCancel(err error) bool {
if e, ok := err.(*Error); ok {
return e.Code == CancelErr
}
return false
}
func (e *Error) Error() string {
msg := fmt.Sprintf("%v: %v", e.Code, e.Message)
if e.Location != nil {
msg = e.Location.String() + ": " + msg
}
return msg
}
func functionConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,
Location: loc,
Message: "functions must not produce multiple outputs for same inputs",
}
}
func completeDocConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,
Location: loc,
Message: "complete rules must not produce multiple outputs",
}
}
func objectDocKeyConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,
Location: loc,
Message: "object keys must be unique",
}
}
func documentConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,
Location: loc,
Message: "base and virtual document keys must be disjoint",
}
}
func unsupportedBuiltinErr(loc *ast.Location) error {
return &Error{
Code: InternalErr,
Location: loc,
Message: "unsupported built-in",
}
}
func mergeConflictErr(loc *ast.Location) error {
return &Error{
Code: WithMergeErr,
Location: loc,
Message: "real and replacement data could not be merged",
}
}

2382
vendor/github.com/open-policy-agent/opa/topdown/eval.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
package topdown
import (
"fmt"
"sync"
"github.com/gobwas/glob"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
var globCacheLock = sync.Mutex{}
var globCache map[string]glob.Glob
func builtinGlobMatch(a, b, c ast.Value) (ast.Value, error) {
pattern, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
delimiters, err := builtins.RuneSliceOperand(b, 2)
if err != nil {
return nil, err
}
if len(delimiters) == 0 {
delimiters = []rune{'.'}
}
match, err := builtins.StringOperand(c, 3)
if err != nil {
return nil, err
}
id := fmt.Sprintf("%s-%v", pattern, delimiters)
globCacheLock.Lock()
defer globCacheLock.Unlock()
p, ok := globCache[id]
if !ok {
var err error
if p, err = glob.Compile(string(pattern), delimiters...); err != nil {
return nil, err
}
globCache[id] = p
}
return ast.Boolean(p.Match(string(match))), nil
}
func builtinGlobQuoteMeta(a ast.Value) (ast.Value, error) {
pattern, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(glob.QuoteMeta(string(pattern))), nil
}
func init() {
globCache = map[string]glob.Glob{}
RegisterFunctionalBuiltin3(ast.GlobMatch.Name, builtinGlobMatch)
RegisterFunctionalBuiltin1(ast.GlobQuoteMeta.Name, builtinGlobQuoteMeta)
}

466
vendor/github.com/open-policy-agent/opa/topdown/http.go generated vendored Normal file
View File

@@ -0,0 +1,466 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/url"
"strconv"
"github.com/open-policy-agent/opa/internal/version"
"net/http"
"os"
"strings"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
const defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT"
var defaultHTTPRequestTimeout = time.Second * 5
var allowedKeyNames = [...]string{
"method",
"url",
"body",
"enable_redirect",
"force_json_decode",
"headers",
"raw_body",
"tls_use_system_certs",
"tls_ca_cert_file",
"tls_ca_cert_env_variable",
"tls_client_cert_env_variable",
"tls_client_key_env_variable",
"tls_client_cert_file",
"tls_client_key_file",
"tls_insecure_skip_verify",
"timeout",
}
var allowedKeys = ast.NewSet()
var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))
type httpSendKey string
// httpSendBuiltinCacheKey is the key in the builtin context cache that
// points to the http.send() specific cache resides at.
const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY"
func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
req, err := validateHTTPRequestOperand(args[0], 1)
if err != nil {
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
}
// check if cache already has a response for this query
resp := checkHTTPSendCache(bctx, req)
if resp == nil {
var err error
resp, err = executeHTTPRequest(bctx, req)
if err != nil {
return handleHTTPSendErr(bctx, err)
}
// add result to cache
insertIntoHTTPSendCache(bctx, req, resp)
}
return iter(ast.NewTerm(resp))
}
func init() {
createAllowedKeys()
initDefaults()
RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend)
}
func handleHTTPSendErr(bctx BuiltinContext, err error) error {
// Return HTTP client timeout errors in a generic error message to avoid confusion about what happened.
// Do not do this if the builtin context was cancelled and is what caused the request to stop.
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil {
err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL)
}
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
}
func initDefaults() {
timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv)
if timeoutDuration != "" {
var err error
defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration)
if err != nil {
// If it is set to something not valid don't let the process continue in a state
// that will almost definitely give unexpected results by having it set at 0
// which means no timeout..
// This environment variable isn't considered part of the public API.
// TODO(patrick-east): Remove the environment variable
panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err))
}
}
}
func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) {
obj, err := builtins.ObjectOperand(term.Value, pos)
if err != nil {
return nil, err
}
requestKeys := ast.NewSet(obj.Keys()...)
invalidKeys := requestKeys.Diff(allowedKeys)
if invalidKeys.Len() != 0 {
return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys)
}
missingKeys := requiredKeys.Diff(requestKeys)
if missingKeys.Len() != 0 {
return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys)
}
return obj, nil
}
// Adds custom headers to a new HTTP request.
func addHeaders(req *http.Request, headers map[string]interface{}) (bool, error) {
for k, v := range headers {
// Type assertion
header, ok := v.(string)
if !ok {
return false, fmt.Errorf("invalid type for headers value %q", v)
}
// If the Host header is given, bump that up to
// the request. Otherwise, just collect it in the
// headers.
k := http.CanonicalHeaderKey(k)
switch k {
case "Host":
req.Host = header
default:
req.Header.Add(k, header)
}
}
return true, nil
}
func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) {
var url string
var method string
var tlsCaCertEnvVar []byte
var tlsCaCertFile string
var tlsClientKeyEnvVar []byte
var tlsClientCertEnvVar []byte
var tlsClientCertFile string
var tlsClientKeyFile string
var body *bytes.Buffer
var rawBody *bytes.Buffer
var enableRedirect bool
var forceJSONDecode bool
var tlsUseSystemCerts bool
var tlsConfig tls.Config
var clientCerts []tls.Certificate
var customHeaders map[string]interface{}
var tlsInsecureSkipVerify bool
var timeout = defaultHTTPRequestTimeout
for _, val := range obj.Keys() {
key, err := ast.JSON(val.Value)
if err != nil {
return nil, err
}
key = key.(string)
switch key {
case "method":
method = obj.Get(val).String()
method = strings.ToUpper(strings.Trim(method, "\""))
case "url":
url = obj.Get(val).String()
url = strings.Trim(url, "\"")
case "enable_redirect":
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
return nil, err
}
case "force_json_decode":
forceJSONDecode, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
return nil, err
}
case "body":
bodyVal := obj.Get(val).Value
bodyValInterface, err := ast.JSON(bodyVal)
if err != nil {
return nil, err
}
bodyValBytes, err := json.Marshal(bodyValInterface)
if err != nil {
return nil, err
}
body = bytes.NewBuffer(bodyValBytes)
case "raw_body":
s, ok := obj.Get(val).Value.(ast.String)
if !ok {
return nil, fmt.Errorf("raw_body must be a string")
}
rawBody = bytes.NewBuffer([]byte(s))
case "tls_use_system_certs":
tlsUseSystemCerts, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
return nil, err
}
case "tls_ca_cert_file":
tlsCaCertFile = obj.Get(val).String()
tlsCaCertFile = strings.Trim(tlsCaCertFile, "\"")
case "tls_ca_cert_env_variable":
caCertEnv := obj.Get(val).String()
caCertEnv = strings.Trim(caCertEnv, "\"")
tlsCaCertEnvVar = []byte(os.Getenv(caCertEnv))
case "tls_client_cert_env_variable":
clientCertEnv := obj.Get(val).String()
clientCertEnv = strings.Trim(clientCertEnv, "\"")
tlsClientCertEnvVar = []byte(os.Getenv(clientCertEnv))
case "tls_client_key_env_variable":
clientKeyEnv := obj.Get(val).String()
clientKeyEnv = strings.Trim(clientKeyEnv, "\"")
tlsClientKeyEnvVar = []byte(os.Getenv(clientKeyEnv))
case "tls_client_cert_file":
tlsClientCertFile = obj.Get(val).String()
tlsClientCertFile = strings.Trim(tlsClientCertFile, "\"")
case "tls_client_key_file":
tlsClientKeyFile = obj.Get(val).String()
tlsClientKeyFile = strings.Trim(tlsClientKeyFile, "\"")
case "headers":
headersVal := obj.Get(val).Value
headersValInterface, err := ast.JSON(headersVal)
if err != nil {
return nil, err
}
var ok bool
customHeaders, ok = headersValInterface.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("invalid type for headers key")
}
case "tls_insecure_skip_verify":
tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String())
if err != nil {
return nil, err
}
case "timeout":
timeout, err = parseTimeout(obj.Get(val).Value)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("invalid parameter %q", key)
}
}
client := &http.Client{
Timeout: timeout,
}
if tlsInsecureSkipVerify {
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: tlsInsecureSkipVerify},
}
}
if tlsClientCertFile != "" && tlsClientKeyFile != "" {
clientCertFromFile, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile)
if err != nil {
return nil, err
}
clientCerts = append(clientCerts, clientCertFromFile)
}
if len(tlsClientCertEnvVar) > 0 && len(tlsClientKeyEnvVar) > 0 {
clientCertFromEnv, err := tls.X509KeyPair(tlsClientCertEnvVar, tlsClientKeyEnvVar)
if err != nil {
return nil, err
}
clientCerts = append(clientCerts, clientCertFromEnv)
}
isTLS := false
if len(clientCerts) > 0 {
isTLS = true
tlsConfig.Certificates = append(tlsConfig.Certificates, clientCerts...)
}
if tlsUseSystemCerts || len(tlsCaCertFile) > 0 || len(tlsCaCertEnvVar) > 0 {
isTLS = true
connRootCAs, err := createRootCAs(tlsCaCertFile, tlsCaCertEnvVar, tlsUseSystemCerts)
if err != nil {
return nil, err
}
tlsConfig.RootCAs = connRootCAs
}
if isTLS {
client.Transport = &http.Transport{
TLSClientConfig: &tlsConfig,
}
}
// check if redirects are enabled
if !enableRedirect {
client.CheckRedirect = func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
}
}
if rawBody != nil {
body = rawBody
} else if body == nil {
body = bytes.NewBufferString("")
}
// create the http request, use the builtin context's context to ensure
// the request is cancelled if evaluation is cancelled.
req, err := http.NewRequest(method, url, body)
if err != nil {
return nil, err
}
req = req.WithContext(bctx.Context)
// Add custom headers
if len(customHeaders) != 0 {
if ok, err := addHeaders(req, customHeaders); !ok {
return nil, err
}
// Don't overwrite or append to one that was set in the custom headers
if _, hasUA := customHeaders["User-Agent"]; !hasUA {
req.Header.Add("User-Agent", version.UserAgent)
}
}
// execute the http request
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// format the http result
var resultBody interface{}
var resultRawBody []byte
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
resultRawBody, err = ioutil.ReadAll(tee)
if err != nil {
return nil, err
}
// If the response body cannot be JSON decoded,
// an error will not be returned. Instead the "body" field
// in the result will be null.
if isContentTypeJSON(resp.Header) || forceJSONDecode {
json.NewDecoder(&buf).Decode(&resultBody)
}
result := make(map[string]interface{})
result["status"] = resp.Status
result["status_code"] = resp.StatusCode
result["body"] = resultBody
result["raw_body"] = string(resultRawBody)
resultObj, err := ast.InterfaceToValue(result)
if err != nil {
return nil, err
}
return resultObj, nil
}
func isContentTypeJSON(header http.Header) bool {
return strings.Contains(header.Get("Content-Type"), "application/json")
}
// In the BuiltinContext cache we only store a single entry that points to
// our ValueMap which is the "real" http.send() cache.
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
if !ok {
// Initialize if it isn't there
cache := ast.NewValueMap()
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
return cache
}
cache, ok := raw.(*ast.ValueMap)
if !ok {
return nil
}
return cache
}
// checkHTTPSendCache checks for the given key's value in the cache
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
return nil
}
return requestCache.Get(key)
}
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
// Should never happen.. if it does just skip caching the value
return
}
requestCache.Put(key, value)
}
func createAllowedKeys() {
for _, element := range allowedKeyNames {
allowedKeys.Add(ast.StringTerm(element))
}
}
func parseTimeout(timeoutVal ast.Value) (time.Duration, error) {
var timeout time.Duration
switch t := timeoutVal.(type) {
case ast.Number:
timeoutInt, ok := t.Int64()
if !ok {
return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal)
}
return time.Duration(timeoutInt), nil
case ast.String:
// Support strings without a unit, treat them the same as just a number value (ns)
var err error
timeoutInt, err := strconv.ParseInt(string(t), 10, 64)
if err == nil {
return time.Duration(timeoutInt), nil
}
// Try parsing it as a duration (requires a supported units suffix)
timeout, err = time.ParseDuration(string(t))
if err != nil {
return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err)
}
return timeout, nil
default:
return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t))
}
}

View File

@@ -0,0 +1,74 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
)
var errConflictingDoc = fmt.Errorf("conflicting documents")
var errBadPath = fmt.Errorf("bad document path")
func mergeTermWithValues(exist *ast.Term, pairs [][2]*ast.Term) (*ast.Term, error) {
var result *ast.Term
if exist != nil {
result = exist.Copy()
}
for _, pair := range pairs {
if err := ast.IsValidImportPath(pair[0].Value); err != nil {
return nil, errBadPath
}
target := pair[0].Value.(ast.Ref)
if len(target) == 1 {
result = pair[1]
} else if result == nil {
result = ast.NewTerm(makeTree(target[1:], pair[1]))
} else {
node := result
done := false
for i := 1; i < len(target)-1 && !done; i++ {
if child := node.Get(target[i]); child == nil {
obj, ok := node.Value.(ast.Object)
if !ok {
return nil, errConflictingDoc
}
obj.Insert(target[i], ast.NewTerm(makeTree(target[i+1:], pair[1])))
done = true
} else {
node = child
}
}
if !done {
obj, ok := node.Value.(ast.Object)
if !ok {
return nil, errConflictingDoc
}
obj.Insert(target[len(target)-1], pair[1])
}
}
}
return result, nil
}
// makeTree returns an object that represents a document where the value v is
// the leaf and elements in k represent intermediate objects.
func makeTree(k ast.Ref, v *ast.Term) ast.Object {
var obj ast.Object
for i := len(k) - 1; i >= 1; i-- {
obj = ast.NewObject(ast.Item(k[i], v))
v = &ast.Term{Value: obj}
}
obj = ast.NewObject(ast.Item(k[0], v))
return obj
}

View File

@@ -0,0 +1,59 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import "github.com/open-policy-agent/opa/metrics"
const (
evalOpPlug = "eval_op_plug"
evalOpResolve = "eval_op_resolve"
evalOpRuleIndex = "eval_op_rule_index"
evalOpBuiltinCall = "eval_op_builtin_call"
evalOpVirtualCacheHit = "eval_op_virtual_cache_hit"
evalOpVirtualCacheMiss = "eval_op_virtual_cache_miss"
evalOpBaseCacheHit = "eval_op_base_cache_hit"
evalOpBaseCacheMiss = "eval_op_base_cache_miss"
partialOpSaveUnify = "partial_op_save_unify"
partialOpSaveSetContains = "partial_op_save_set_contains"
partialOpSaveSetContainsRec = "partial_op_save_set_contains_rec"
partialOpCopyPropagation = "partial_op_copy_propagation"
)
// Instrumentation implements helper functions to instrument query evaluation
// to diagnose performance issues. Instrumentation may be expensive in some
// cases, so it is disabled by default.
type Instrumentation struct {
m metrics.Metrics
}
// NewInstrumentation returns a new Instrumentation object. Performance
// diagnostics recorded on this Instrumentation object will stored in m.
func NewInstrumentation(m metrics.Metrics) *Instrumentation {
return &Instrumentation{
m: m,
}
}
func (instr *Instrumentation) startTimer(name string) {
if instr == nil {
return
}
instr.m.Timer(name).Start()
}
func (instr *Instrumentation) stopTimer(name string) {
if instr == nil {
return
}
delta := instr.m.Timer(name).Stop()
instr.m.Histogram(name).Update(delta)
}
func (instr *Instrumentation) counterIncr(name string) {
if instr == nil {
return
}
instr.m.Counter(name).Incr()
}

View File

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2015 lestrrat
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,113 @@
// Package buffer provides a very thin wrapper around []byte buffer called
// `Buffer`, to provide functionalities that are often used within the jwx
// related packages
package buffer
import (
"encoding/base64"
"encoding/binary"
"encoding/json"
"github.com/pkg/errors"
)
// Buffer wraps `[]byte` and provides functions that are often used in
// the jwx related packages. One notable difference is that while
// encoding/json marshalls `[]byte` using base64.StdEncoding, this
// module uses base64.RawURLEncoding as mandated by the spec
type Buffer []byte
// FromUint creates a `Buffer` from an unsigned int
func FromUint(v uint64) Buffer {
data := make([]byte, 8)
binary.BigEndian.PutUint64(data, v)
i := 0
for ; i < len(data); i++ {
if data[i] != 0x0 {
break
}
}
return Buffer(data[i:])
}
// FromBase64 constructs a new Buffer from a base64 encoded data
func FromBase64(v []byte) (Buffer, error) {
b := Buffer{}
if err := b.Base64Decode(v); err != nil {
return Buffer(nil), errors.Wrap(err, "failed to decode from base64")
}
return b, nil
}
// FromNData constructs a new Buffer from a "n:data" format
// (I made that name up)
func FromNData(v []byte) (Buffer, error) {
size := binary.BigEndian.Uint32(v)
buf := make([]byte, int(size))
copy(buf, v[4:4+size])
return Buffer(buf), nil
}
// Bytes returns the raw bytes that comprises the Buffer
func (b Buffer) Bytes() []byte {
return []byte(b)
}
// NData returns Datalen || Data, where Datalen is a 32 bit counter for
// the length of the following data, and Data is the octets that comprise
// the buffer data
func (b Buffer) NData() []byte {
buf := make([]byte, 4+b.Len())
binary.BigEndian.PutUint32(buf, uint32(b.Len()))
copy(buf[4:], b.Bytes())
return buf
}
// Len returns the number of bytes that the Buffer holds
func (b Buffer) Len() int {
return len(b)
}
// Base64Encode encodes the contents of the Buffer using base64.RawURLEncoding
func (b Buffer) Base64Encode() ([]byte, error) {
enc := base64.RawURLEncoding
out := make([]byte, enc.EncodedLen(len(b)))
enc.Encode(out, b)
return out, nil
}
// Base64Decode decodes the contents of the Buffer using base64.RawURLEncoding
func (b *Buffer) Base64Decode(v []byte) error {
enc := base64.RawURLEncoding
out := make([]byte, enc.DecodedLen(len(v)))
n, err := enc.Decode(out, v)
if err != nil {
return errors.Wrap(err, "failed to decode from base64")
}
out = out[:n]
*b = Buffer(out)
return nil
}
// MarshalJSON marshals the buffer into JSON format after encoding the buffer
// with base64.RawURLEncoding
func (b Buffer) MarshalJSON() ([]byte, error) {
v, err := b.Base64Encode()
if err != nil {
return nil, errors.Wrap(err, "failed to encode to base64")
}
return json.Marshal(string(v))
}
// UnmarshalJSON unmarshals from a JSON string into a Buffer, after decoding it
// with base64.RawURLEncoding
func (b *Buffer) UnmarshalJSON(data []byte) error {
var x string
if err := json.Unmarshal(data, &x); err != nil {
return errors.Wrap(err, "failed to unmarshal JSON")
}
return b.Base64Decode([]byte(x))
}

View File

@@ -0,0 +1,11 @@
package jwa
// EllipticCurveAlgorithm represents the algorithms used for EC keys
type EllipticCurveAlgorithm string
// Supported values for EllipticCurveAlgorithm
const (
P256 EllipticCurveAlgorithm = "P-256"
P384 EllipticCurveAlgorithm = "P-384"
P521 EllipticCurveAlgorithm = "P-521"
)

View File

@@ -0,0 +1,67 @@
package jwa
import (
"strconv"
"github.com/pkg/errors"
)
// KeyType represents the key type ("kty") that are supported
type KeyType string
var keyTypeAlg = map[string]struct{}{"EC": {}, "oct": {}, "RSA": {}}
// Supported values for KeyType
const (
EC KeyType = "EC" // Elliptic Curve
InvalidKeyType KeyType = "" // Invalid KeyType
OctetSeq KeyType = "oct" // Octet sequence (used to represent symmetric keys)
RSA KeyType = "RSA" // RSA
)
// Accept is used when conversion from values given by
// outside sources (such as JSON payloads) is required
func (keyType *KeyType) Accept(value interface{}) error {
var tmp KeyType
switch x := value.(type) {
case string:
tmp = KeyType(x)
case KeyType:
tmp = x
default:
return errors.Errorf(`invalid type for jwa.KeyType: %T`, value)
}
_, ok := keyTypeAlg[tmp.String()]
if !ok {
return errors.Errorf("Unknown Key Type algorithm")
}
*keyType = tmp
return nil
}
// String returns the string representation of a KeyType
func (keyType KeyType) String() string {
return string(keyType)
}
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
func (keyType *KeyType) UnmarshalJSON(data []byte) error {
var quote byte = '"'
var quoted string
if data[0] == quote {
var err error
quoted, err = strconv.Unquote(string(data))
if err != nil {
return errors.Wrap(err, "Failed to process signature algorithm")
}
} else {
quoted = string(data)
}
_, ok := keyTypeAlg[quoted]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*keyType = KeyType(quoted)
return nil
}

View File

@@ -0,0 +1,29 @@
package jwa
import (
"crypto/elliptic"
"github.com/open-policy-agent/opa/topdown/internal/jwx/buffer"
)
// EllipticCurve provides a indirect type to standard elliptic curve such that we can
// use it for unmarshal
type EllipticCurve struct {
elliptic.Curve
}
// AlgorithmParameters provides a single structure suitable to unmarshaling any JWK
type AlgorithmParameters struct {
N buffer.Buffer `json:"n,omitempty"`
E buffer.Buffer `json:"e,omitempty"`
D buffer.Buffer `json:"d,omitempty"`
P buffer.Buffer `json:"p,omitempty"`
Q buffer.Buffer `json:"q,omitempty"`
Dp buffer.Buffer `json:"dp,omitempty"`
Dq buffer.Buffer `json:"dq,omitempty"`
Qi buffer.Buffer `json:"qi,omitempty"`
Crv EllipticCurveAlgorithm `json:"crv,omitempty"`
X buffer.Buffer `json:"x,omitempty"`
Y buffer.Buffer `json:"y,omitempty"`
K buffer.Buffer `json:"k,omitempty"`
}

View File

@@ -0,0 +1,76 @@
package jwa
import (
"strconv"
"github.com/pkg/errors"
)
// SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1
type SignatureAlgorithm string
var signatureAlg = map[string]struct{}{"ES256": {}, "ES384": {}, "ES512": {}, "HS256": {}, "HS384": {}, "HS512": {}, "PS256": {}, "PS384": {}, "PS512": {}, "RS256": {}, "RS384": {}, "RS512": {}, "none": {}}
// Supported values for SignatureAlgorithm
const (
ES256 SignatureAlgorithm = "ES256" // ECDSA using P-256 and SHA-256
ES384 SignatureAlgorithm = "ES384" // ECDSA using P-384 and SHA-384
ES512 SignatureAlgorithm = "ES512" // ECDSA using P-521 and SHA-512
HS256 SignatureAlgorithm = "HS256" // HMAC using SHA-256
HS384 SignatureAlgorithm = "HS384" // HMAC using SHA-384
HS512 SignatureAlgorithm = "HS512" // HMAC using SHA-512
NoSignature SignatureAlgorithm = "none"
PS256 SignatureAlgorithm = "PS256" // RSASSA-PSS using SHA256 and MGF1-SHA256
PS384 SignatureAlgorithm = "PS384" // RSASSA-PSS using SHA384 and MGF1-SHA384
PS512 SignatureAlgorithm = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
RS256 SignatureAlgorithm = "RS256" // RSASSA-PKCS-v1.5 using SHA-256
RS384 SignatureAlgorithm = "RS384" // RSASSA-PKCS-v1.5 using SHA-384
RS512 SignatureAlgorithm = "RS512" // RSASSA-PKCS-v1.5 using SHA-512
NoValue SignatureAlgorithm = "" // No value is different from none
)
// Accept is used when conversion from values given by
// outside sources (such as JSON payloads) is required
func (signature *SignatureAlgorithm) Accept(value interface{}) error {
var tmp SignatureAlgorithm
switch x := value.(type) {
case string:
tmp = SignatureAlgorithm(x)
case SignatureAlgorithm:
tmp = x
default:
return errors.Errorf(`invalid type for jwa.SignatureAlgorithm: %T`, value)
}
_, ok := signatureAlg[tmp.String()]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*signature = tmp
return nil
}
// String returns the string representation of a SignatureAlgorithm
func (signature SignatureAlgorithm) String() string {
return string(signature)
}
// UnmarshalJSON unmarshals and checks data as Signature Algorithm
func (signature *SignatureAlgorithm) UnmarshalJSON(data []byte) error {
var quote byte = '"'
var quoted string
if data[0] == quote {
var err error
quoted, err = strconv.Unquote(string(data))
if err != nil {
return errors.Wrap(err, "Failed to process signature algorithm")
}
} else {
quoted = string(data)
}
_, ok := signatureAlg[quoted]
if !ok {
return errors.Errorf("Unknown signature algorithm")
}
*signature = SignatureAlgorithm(quoted)
return nil
}

View File

@@ -0,0 +1,120 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newECDSAPublicKey(key *ecdsa.PublicKey) (*ECDSAPublicKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.EC)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &ECDSAPublicKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
func newECDSAPrivateKey(key *ecdsa.PrivateKey) (*ECDSAPrivateKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.EC)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &ECDSAPrivateKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the EC-DSA public key represented by this JWK
func (k ECDSAPublicKey) Materialize() (interface{}, error) {
return k.key, nil
}
// Materialize returns the EC-DSA private key represented by this JWK
func (k ECDSAPrivateKey) Materialize() (interface{}, error) {
return k.key, nil
}
// GenerateKey creates a ECDSAPublicKey from JWK format
func (k *ECDSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
var x, y big.Int
if keyJSON.X == nil || keyJSON.Y == nil || keyJSON.Crv == "" {
return errors.Errorf("Missing mandatory key parameters X, Y or Crv")
}
x.SetBytes(keyJSON.X.Bytes())
y.SetBytes(keyJSON.Y.Bytes())
var curve elliptic.Curve
switch keyJSON.Crv {
case jwa.P256:
curve = elliptic.P256()
case jwa.P384:
curve = elliptic.P384()
case jwa.P521:
curve = elliptic.P521()
default:
return errors.Errorf(`invalid curve name %s`, keyJSON.Crv)
}
*k = ECDSAPublicKey{
StandardHeaders: &keyJSON.StandardHeaders,
key: &ecdsa.PublicKey{
Curve: curve,
X: &x,
Y: &y,
},
}
return nil
}
// GenerateKey creates a ECDSAPrivateKey from JWK format
func (k *ECDSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
if keyJSON.D == nil {
return errors.Errorf("Missing mandatory key parameter D")
}
eCDSAPublicKey := &ECDSAPublicKey{}
err := eCDSAPublicKey.GenerateKey(keyJSON)
if err != nil {
return errors.Wrap(err, `failed to generate public key`)
}
dBytes := keyJSON.D.Bytes()
// The length of this octet string MUST be ceiling(log-base-2(n)/8)
// octets (where n is the order of the curve). This is because the private
// key d must be in the interval [1, n-1] so the bitlength of d should be
// no larger than the bitlength of n-1. The easiest way to find the octet
// length is to take bitlength(n-1), add 7 to force a carry, and shift this
// bit sequence right by 3, which is essentially dividing by 8 and adding
// 1 if there is any remainder. Thus, the private key value d should be
// output to (bitlength(n-1)+7)>>3 octets.
n := eCDSAPublicKey.key.Params().N
octetLength := (new(big.Int).Sub(n, big.NewInt(1)).BitLen() + 7) >> 3
if octetLength-len(dBytes) != 0 {
return errors.Errorf("Failed to generate private key. Incorrect D value")
}
privateKey := &ecdsa.PrivateKey{
PublicKey: *eCDSAPublicKey.key,
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
}
k.key = privateKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}

View File

@@ -0,0 +1,178 @@
package jwk
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Convenience constants for common JWK parameters
const (
AlgorithmKey = "alg"
KeyIDKey = "kid"
KeyOpsKey = "key_ops"
KeyTypeKey = "kty"
KeyUsageKey = "use"
PrivateParamsKey = "privateParams"
)
// Headers provides a common interface to all future possible headers
type Headers interface {
Get(string) (interface{}, bool)
Set(string, interface{}) error
Walk(func(string, interface{}) error) error
GetAlgorithm() jwa.SignatureAlgorithm
GetKeyID() string
GetKeyOps() KeyOperationList
GetKeyType() jwa.KeyType
GetKeyUsage() string
GetPrivateParams() map[string]interface{}
}
// StandardHeaders stores the common JWK parameters
type StandardHeaders struct {
Algorithm *jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.4
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
KeyOps KeyOperationList `json:"key_ops,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.3
KeyType jwa.KeyType `json:"kty,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.1
KeyUsage string `json:"use,omitempty"` // https://tools.ietf.org/html/rfc7517#section-4.2
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
}
// GetAlgorithm is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
if v := h.Algorithm; v != nil {
return *v
}
return jwa.NoValue
}
// GetKeyID is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyID() string {
return h.KeyID
}
// GetKeyOps is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyOps() KeyOperationList {
return h.KeyOps
}
// GetKeyType is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyType() jwa.KeyType {
return h.KeyType
}
// GetKeyUsage is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetKeyUsage() string {
return h.KeyUsage
}
// GetPrivateParams is a convenience function to retrieve the corresponding value stored in the StandardHeaders
func (h *StandardHeaders) GetPrivateParams() map[string]interface{} {
return h.PrivateParams
}
// Get is a general getter function for JWK StandardHeaders structure
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
switch name {
case AlgorithmKey:
alg := h.GetAlgorithm()
if alg != jwa.NoValue {
return alg, true
}
return nil, false
case KeyIDKey:
v := h.KeyID
if v == "" {
return nil, false
}
return v, true
case KeyOpsKey:
v := h.KeyOps
if v == nil {
return nil, false
}
return v, true
case KeyTypeKey:
v := h.KeyType
if v == jwa.InvalidKeyType {
return nil, false
}
return v, true
case KeyUsageKey:
v := h.KeyUsage
if v == "" {
return nil, false
}
return v, true
case PrivateParamsKey:
v := h.PrivateParams
if len(v) == 0 {
return nil, false
}
return v, true
default:
return nil, false
}
}
// Set is a general getter function for JWK StandardHeaders structure
func (h *StandardHeaders) Set(name string, value interface{}) error {
switch name {
case AlgorithmKey:
var acceptor jwa.SignatureAlgorithm
if err := acceptor.Accept(value); err != nil {
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
}
h.Algorithm = &acceptor
return nil
case KeyIDKey:
if v, ok := value.(string); ok {
h.KeyID = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", KeyIDKey, value)
case KeyOpsKey:
if err := h.KeyOps.Accept(value); err != nil {
return errors.Wrapf(err, "invalid value for %s key", KeyOpsKey)
}
return nil
case KeyTypeKey:
if err := h.KeyType.Accept(value); err != nil {
return errors.Wrapf(err, "invalid value for %s key", KeyTypeKey)
}
return nil
case KeyUsageKey:
if v, ok := value.(string); ok {
h.KeyUsage = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", KeyUsageKey, value)
case PrivateParamsKey:
if v, ok := value.(map[string]interface{}); ok {
h.PrivateParams = v
return nil
}
return errors.Errorf("invalid value for %s key: %T", PrivateParamsKey, value)
default:
return errors.Errorf(`invalid key: %s`, name)
}
}
// Walk iterates over all JWK standard headers fields while applying a function to its value.
func (h StandardHeaders) Walk(f func(string, interface{}) error) error {
for _, key := range []string{AlgorithmKey, KeyIDKey, KeyOpsKey, KeyTypeKey, KeyUsageKey, PrivateParamsKey} {
if v, ok := h.Get(key); ok {
if err := f(key, v); err != nil {
return errors.Wrapf(err, `walk function returned error for %s`, key)
}
}
}
for k, v := range h.PrivateParams {
if err := f(k, v); err != nil {
return errors.Wrapf(err, `walk function returned error for %s`, k)
}
}
return nil
}

View File

@@ -0,0 +1,70 @@
package jwk
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Set is a convenience struct to allow generating and parsing
// JWK sets as opposed to single JWKs
type Set struct {
Keys []Key `json:"keys"`
}
// Key defines the minimal interface for each of the
// key types. Their use and implementation differ significantly
// between each key types, so you should use type assertions
// to perform more specific tasks with each key
type Key interface {
Headers
// Materialize creates the corresponding key. For example,
// RSA types would create *rsa.PublicKey or *rsa.PrivateKey,
// EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey,
// and OctetSeq types create a []byte key.
Materialize() (interface{}, error)
GenerateKey(*RawKeyJSON) error
}
// RawKeyJSON is generic type that represents any kind JWK
type RawKeyJSON struct {
StandardHeaders
jwa.AlgorithmParameters
}
// RawKeySetJSON is generic type that represents a JWK Set
type RawKeySetJSON struct {
Keys []RawKeyJSON `json:"keys"`
}
// RSAPublicKey is a type of JWK generated from RSA public keys
type RSAPublicKey struct {
*StandardHeaders
key *rsa.PublicKey
}
// RSAPrivateKey is a type of JWK generated from RSA private keys
type RSAPrivateKey struct {
*StandardHeaders
key *rsa.PrivateKey
}
// SymmetricKey is a type of JWK generated from symmetric keys
type SymmetricKey struct {
*StandardHeaders
key []byte
}
// ECDSAPublicKey is a type of JWK generated from ECDSA public keys
type ECDSAPublicKey struct {
*StandardHeaders
key *ecdsa.PublicKey
}
// ECDSAPrivateKey is a type of JWK generated from ECDH-ES private keys
type ECDSAPrivateKey struct {
*StandardHeaders
key *ecdsa.PrivateKey
}

View File

@@ -0,0 +1,150 @@
// Package jwk implements JWK as described in https://tools.ietf.org/html/rfc7517
package jwk
import (
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// GetPublicKey returns the public key based on the private key type.
// For rsa key types *rsa.PublicKey is returned; for ecdsa key types *ecdsa.PublicKey;
// for byte slice (raw) keys, the key itself is returned. If the corresponding
// public key cannot be deduced, an error is returned
func GetPublicKey(key interface{}) (interface{}, error) {
if key == nil {
return nil, errors.New(`jwk.New requires a non-nil key`)
}
switch v := key.(type) {
// Mental note: although Public() is defined in both types,
// you can not coalesce the clauses for rsa.PrivateKey and
// ecdsa.PrivateKey, as then `v` becomes interface{}
// b/c the compiler cannot deduce the exact type.
case *rsa.PrivateKey:
return v.Public(), nil
case *ecdsa.PrivateKey:
return v.Public(), nil
case []byte:
return v, nil
default:
return nil, errors.Errorf(`invalid key type %T`, key)
}
}
// GetKeyTypeFromKey creates a jwk.Key from the given key.
func GetKeyTypeFromKey(key interface{}) jwa.KeyType {
switch key.(type) {
case *rsa.PrivateKey, *rsa.PublicKey:
return jwa.RSA
case *ecdsa.PrivateKey, *ecdsa.PublicKey:
return jwa.EC
case []byte:
return jwa.OctetSeq
default:
return jwa.InvalidKeyType
}
}
// New creates a jwk.Key from the given key.
func New(key interface{}) (Key, error) {
if key == nil {
return nil, errors.New(`jwk.New requires a non-nil key`)
}
switch v := key.(type) {
case *rsa.PrivateKey:
return newRSAPrivateKey(v)
case *rsa.PublicKey:
return newRSAPublicKey(v)
case *ecdsa.PrivateKey:
return newECDSAPrivateKey(v)
case *ecdsa.PublicKey:
return newECDSAPublicKey(v)
case []byte:
return newSymmetricKey(v)
default:
return nil, errors.Errorf(`invalid key type %T`, key)
}
}
func parse(jwkSrc string) (*Set, error) {
var jwkKeySet Set
var jwkKey Key
rawKeySetJSON := &RawKeySetJSON{}
err := json.Unmarshal([]byte(jwkSrc), rawKeySetJSON)
if err != nil {
return nil, errors.Wrap(err, "Failed to unmarshal JWK Set")
}
if len(rawKeySetJSON.Keys) == 0 {
// It might be a single key
rawKeyJSON := &RawKeyJSON{}
err := json.Unmarshal([]byte(jwkSrc), rawKeyJSON)
if err != nil {
return nil, errors.Wrap(err, "Failed to unmarshal JWK")
}
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key")
}
// Add to set
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
} else {
for i := range rawKeySetJSON.Keys {
rawKeyJSON := rawKeySetJSON.Keys[i]
jwkKey, err = rawKeyJSON.GenerateKey()
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key: %s")
}
jwkKeySet.Keys = append(jwkKeySet.Keys, jwkKey)
}
}
return &jwkKeySet, nil
}
// ParseBytes parses JWK from the incoming byte buffer.
func ParseBytes(buf []byte) (*Set, error) {
return parse(string(buf[:]))
}
// ParseString parses JWK from the incoming string.
func ParseString(s string) (*Set, error) {
return parse(s)
}
// GenerateKey creates an internal representation of a key from a raw JWK JSON
func (r *RawKeyJSON) GenerateKey() (Key, error) {
var key Key
switch r.KeyType {
case jwa.RSA:
if r.D != nil {
key = &RSAPrivateKey{}
} else {
key = &RSAPublicKey{}
}
case jwa.EC:
if r.D != nil {
key = &ECDSAPrivateKey{}
} else {
key = &ECDSAPublicKey{}
}
case jwa.OctetSeq:
key = &SymmetricKey{}
default:
return nil, errors.Errorf(`Unrecognized key type`)
}
err := key.GenerateKey(r)
if err != nil {
return nil, errors.Wrap(err, "Failed to generate key from JWK")
}
return key, nil
}

View File

@@ -0,0 +1,68 @@
package jwk
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
// KeyUsageType is used to denote what this key should be used for
type KeyUsageType string
const (
// ForSignature is the value used in the headers to indicate that
// this key should be used for signatures
ForSignature KeyUsageType = "sig"
// ForEncryption is the value used in the headers to indicate that
// this key should be used for encryptiong
ForEncryption KeyUsageType = "enc"
)
// KeyOperation is used to denote the allowed operations for a Key
type KeyOperation string
// KeyOperationList represents an slice of KeyOperation
type KeyOperationList []KeyOperation
var keyOps = map[string]struct{}{"sign": {}, "verify": {}, "encrypt": {}, "decrypt": {}, "wrapKey": {}, "unwrapKey": {}, "deriveKey": {}, "deriveBits": {}}
// KeyOperation constants
const (
KeyOpSign KeyOperation = "sign" // (compute digital signature or MAC)
KeyOpVerify = "verify" // (verify digital signature or MAC)
KeyOpEncrypt = "encrypt" // (encrypt content)
KeyOpDecrypt = "decrypt" // (decrypt content and validate decryption, if applicable)
KeyOpWrapKey = "wrapKey" // (encrypt key)
KeyOpUnwrapKey = "unwrapKey" // (decrypt key and validate decryption, if applicable)
KeyOpDeriveKey = "deriveKey" // (derive key)
KeyOpDeriveBits = "deriveBits" // (derive bits not to be used as a key)
)
// Accept determines if Key Operation is valid
func (keyOperationList *KeyOperationList) Accept(v interface{}) error {
switch x := v.(type) {
case KeyOperationList:
*keyOperationList = x
return nil
default:
return errors.Errorf(`invalid value %T`, v)
}
}
// UnmarshalJSON unmarshals and checks data as KeyType Algorithm
func (keyOperationList *KeyOperationList) UnmarshalJSON(data []byte) error {
var tempKeyOperationList []string
err := json.Unmarshal(data, &tempKeyOperationList)
if err != nil {
return fmt.Errorf("invalid key operation")
}
for _, value := range tempKeyOperationList {
_, ok := keyOps[value]
if !ok {
return fmt.Errorf("unknown key operation")
}
*keyOperationList = append(*keyOperationList, KeyOperation(value))
}
return nil
}

View File

@@ -0,0 +1,103 @@
package jwk
import (
"crypto/rsa"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newRSAPublicKey(key *rsa.PublicKey) (*RSAPublicKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.RSA)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &RSAPublicKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
func newRSAPrivateKey(key *rsa.PrivateKey) (*RSAPrivateKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.RSA)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &RSAPrivateKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the standard RSA Public Key representation stored in the internal representation
func (k *RSAPublicKey) Materialize() (interface{}, error) {
if k.key == nil {
return nil, errors.New(`key has no rsa.PublicKey associated with it`)
}
return k.key, nil
}
// Materialize returns the standard RSA Private Key representation stored in the internal representation
func (k *RSAPrivateKey) Materialize() (interface{}, error) {
if k.key == nil {
return nil, errors.New(`key has no rsa.PrivateKey associated with it`)
}
return k.key, nil
}
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
func (k *RSAPublicKey) GenerateKey(keyJSON *RawKeyJSON) error {
if keyJSON.N == nil || keyJSON.E == nil {
return errors.Errorf("Missing mandatory key parameters N or E")
}
rsaPublicKey := &rsa.PublicKey{
N: (&big.Int{}).SetBytes(keyJSON.N.Bytes()),
E: int((&big.Int{}).SetBytes(keyJSON.E.Bytes()).Int64()),
}
k.key = rsaPublicKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}
// GenerateKey creates a RSAPublicKey from a RawKeyJSON
func (k *RSAPrivateKey) GenerateKey(keyJSON *RawKeyJSON) error {
rsaPublicKey := &RSAPublicKey{}
err := rsaPublicKey.GenerateKey(keyJSON)
if err != nil {
return errors.Wrap(err, "failed to generate public key")
}
if keyJSON.D == nil || keyJSON.P == nil || keyJSON.Q == nil {
return errors.Errorf("Missing mandatory key parameters D, P or Q")
}
privateKey := &rsa.PrivateKey{
PublicKey: *rsaPublicKey.key,
D: (&big.Int{}).SetBytes(keyJSON.D.Bytes()),
Primes: []*big.Int{
(&big.Int{}).SetBytes(keyJSON.P.Bytes()),
(&big.Int{}).SetBytes(keyJSON.Q.Bytes()),
},
}
if keyJSON.Dp.Len() > 0 {
privateKey.Precomputed.Dp = (&big.Int{}).SetBytes(keyJSON.Dp.Bytes())
}
if keyJSON.Dq.Len() > 0 {
privateKey.Precomputed.Dq = (&big.Int{}).SetBytes(keyJSON.Dq.Bytes())
}
if keyJSON.Qi.Len() > 0 {
privateKey.Precomputed.Qinv = (&big.Int{}).SetBytes(keyJSON.Qi.Bytes())
}
k.key = privateKey
k.StandardHeaders = &keyJSON.StandardHeaders
return nil
}

View File

@@ -0,0 +1,41 @@
package jwk
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
func newSymmetricKey(key []byte) (*SymmetricKey, error) {
var hdr StandardHeaders
err := hdr.Set(KeyTypeKey, jwa.OctetSeq)
if err != nil {
return nil, errors.Wrapf(err, "Failed to set Key Type")
}
return &SymmetricKey{
StandardHeaders: &hdr,
key: key,
}, nil
}
// Materialize returns the octets for this symmetric key.
// Since this is a symmetric key, this just calls Octets
func (s SymmetricKey) Materialize() (interface{}, error) {
return s.Octets(), nil
}
// Octets returns the octets in the key
func (s SymmetricKey) Octets() []byte {
return s.key
}
// GenerateKey creates a Symmetric key from a RawKeyJSON
func (s *SymmetricKey) GenerateKey(keyJSON *RawKeyJSON) error {
*s = SymmetricKey{
StandardHeaders: &keyJSON.StandardHeaders,
key: keyJSON.K,
}
return nil
}

View File

@@ -0,0 +1,154 @@
package jws
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Constants for JWS Common parameters
const (
AlgorithmKey = "alg"
ContentTypeKey = "cty"
CriticalKey = "crit"
JWKKey = "jwk"
JWKSetURLKey = "jku"
KeyIDKey = "kid"
PrivateParamsKey = "privateParams"
TypeKey = "typ"
)
// Headers provides a common interface for common header parameters
type Headers interface {
Get(string) (interface{}, bool)
Set(string, interface{}) error
GetAlgorithm() jwa.SignatureAlgorithm
}
// StandardHeaders contains JWS common parameters.
type StandardHeaders struct {
Algorithm jwa.SignatureAlgorithm `json:"alg,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.1
ContentType string `json:"cty,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.10
Critical []string `json:"crit,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.11
JWK string `json:"jwk,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.3
JWKSetURL string `json:"jku,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.2
KeyID string `json:"kid,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.4
PrivateParams map[string]interface{} `json:"privateParams,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
Type string `json:"typ,omitempty"` // https://tools.ietf.org/html/rfc7515#section-4.1.9
}
// GetAlgorithm returns algorithm
func (h *StandardHeaders) GetAlgorithm() jwa.SignatureAlgorithm {
return h.Algorithm
}
// Get is a general getter function for StandardHeaders structure
func (h *StandardHeaders) Get(name string) (interface{}, bool) {
switch name {
case AlgorithmKey:
v := h.Algorithm
if v == "" {
return nil, false
}
return v, true
case ContentTypeKey:
v := h.ContentType
if v == "" {
return nil, false
}
return v, true
case CriticalKey:
v := h.Critical
if len(v) == 0 {
return nil, false
}
return v, true
case JWKKey:
v := h.JWK
if v == "" {
return nil, false
}
return v, true
case JWKSetURLKey:
v := h.JWKSetURL
if v == "" {
return nil, false
}
return v, true
case KeyIDKey:
v := h.KeyID
if v == "" {
return nil, false
}
return v, true
case PrivateParamsKey:
v := h.PrivateParams
if len(v) == 0 {
return nil, false
}
return v, true
case TypeKey:
v := h.Type
if v == "" {
return nil, false
}
return v, true
default:
return nil, false
}
}
// Set is a general setter function for StandardHeaders structure
func (h *StandardHeaders) Set(name string, value interface{}) error {
switch name {
case AlgorithmKey:
if err := h.Algorithm.Accept(value); err != nil {
return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
}
return nil
case ContentTypeKey:
if v, ok := value.(string); ok {
h.ContentType = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
case CriticalKey:
if v, ok := value.([]string); ok {
h.Critical = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
case JWKKey:
if v, ok := value.(string); ok {
h.JWK = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
case JWKSetURLKey:
if v, ok := value.(string); ok {
h.JWKSetURL = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
case KeyIDKey:
if v, ok := value.(string); ok {
h.KeyID = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
case PrivateParamsKey:
if v, ok := value.(map[string]interface{}); ok {
h.PrivateParams = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, PrivateParamsKey, value)
case TypeKey:
if v, ok := value.(string); ok {
h.Type = v
return nil
}
return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
default:
return errors.Errorf(`invalid key: %s`, name)
}
}

View File

@@ -0,0 +1,22 @@
package jws
// Message represents a full JWS encoded message. Flattened serialization
// is not supported as a struct, but rather it's represented as a
// Message struct with only one `Signature` element.
//
// Do not expect to use the Message object to verify or construct a
// signed payloads with. You should only use this when you want to actually
// want to programmatically view the contents for the full JWS Payload.
//
// To sign and verify, use the appropriate `SignWithOption()` nad `Verify()` functions
type Message struct {
Payload []byte `json:"payload"`
Signatures []*Signature `json:"signatures,omitempty"`
}
// Signature represents the headers and signature of a JWS message
type Signature struct {
Headers Headers `json:"header,omitempty"` // Unprotected Headers
Protected Headers `json:"Protected,omitempty"` // Protected Headers
Signature []byte `json:"signature,omitempty"` // GetSignature
}

View File

@@ -0,0 +1,210 @@
// Package jws implements the digital Signature on JSON based data
// structures as described in https://tools.ietf.org/html/rfc7515
//
// If you do not care about the details, the only things that you
// would need to use are the following functions:
//
// jws.SignWithOption(Payload, algorithm, key)
// jws.Verify(encodedjws, algorithm, key)
//
// To sign, simply use `jws.SignWithOption`. `Payload` is a []byte buffer that
// contains whatever data you want to sign. `alg` is one of the
// jwa.SignatureAlgorithm constants from package jwa. For RSA and
// ECDSA family of algorithms, you will need to prepare a private key.
// For HMAC family, you just need a []byte value. The `jws.SignWithOption`
// function will return the encoded JWS message on success.
//
// To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
// and verify the result using `algorithm` and `key`. Upon successful
// verification, the original Payload is returned, so you can work on it.
package jws
import (
"bytes"
"encoding/base64"
"encoding/json"
"strings"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/verify"
"github.com/pkg/errors"
)
// SignLiteral generates a Signature for the given Payload and Headers, and serializes
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signingInput := strings.Join(
[]string{
encodedHdr,
encodedPayload,
}, ".",
)
signer, err := sign.New(alg)
if err != nil {
return nil, errors.Wrap(err, `failed to create signer`)
}
signature, err := signer.Sign([]byte(signingInput), key)
if err != nil {
return nil, errors.Wrap(err, `failed to sign Payload`)
}
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
compactSerialization := strings.Join(
[]string{
signingInput,
encodedSignature,
}, ".",
)
return []byte(compactSerialization), nil
}
// SignWithOption generates a Signature for the given Payload, and serializes
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
// If you would like to pass custom Headers, use the WithHeaders option.
func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, error) {
var headers Headers = &StandardHeaders{}
err := headers.Set(AlgorithmKey, alg)
if err != nil {
return nil, errors.Wrap(err, "Failed to set alg value")
}
hdrBuf, err := json.Marshal(headers)
if err != nil {
return nil, errors.Wrap(err, `failed to marshal Headers`)
}
return SignLiteral(payload, alg, key, hdrBuf)
}
// Verify checks if the given JWS message is verifiable using `alg` and `key`.
// If the verification is successful, `err` is nil, and the content of the
// Payload that was signed is returned. If you need more fine-grained
// control of the verification process, manually call `Parse`, generate a
// verifier, and call `Verify` on the parsed JWS message object.
func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}) (ret []byte, err error) {
verifier, err := verify.New(alg)
if err != nil {
return nil, errors.Wrap(err, "failed to create verifier")
}
buf = bytes.TrimSpace(buf)
if len(buf) == 0 {
return nil, errors.New(`attempt to verify empty buffer`)
}
parts, err := SplitCompact(string(buf[:]))
if err != nil {
return nil, errors.Wrap(err, `failed extract from compact serialization format`)
}
signingInput := strings.Join(
[]string{
parts[0],
parts[1],
}, ".",
)
decodedSignature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, errors.Wrap(err, "Failed to decode signature")
}
if err := verifier.Verify([]byte(signingInput), decodedSignature, key); err != nil {
return nil, errors.Wrap(err, "Failed to verify message")
}
if decodedPayload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
return decodedPayload, nil
}
return nil, errors.Wrap(err, "Failed to decode Payload")
}
// VerifyWithJWK verifies the JWS message using the specified JWK
func VerifyWithJWK(buf []byte, key jwk.Key) (payload []byte, err error) {
keyVal, err := key.Materialize()
if err != nil {
return nil, errors.Wrap(err, "Failed to materialize key")
}
return Verify(buf, key.GetAlgorithm(), keyVal)
}
// VerifyWithJWKSet verifies the JWS message using JWK key set.
// By default it will only pick up keys that have the "use" key
// set to either "sig" or "enc", but you can override it by
// providing a keyaccept function.
func VerifyWithJWKSet(buf []byte, keyset *jwk.Set) (payload []byte, err error) {
for _, key := range keyset.Keys {
payload, err := VerifyWithJWK(buf, key)
if err == nil {
return payload, nil
}
}
return nil, errors.New("failed to verify with any of the keys")
}
// ParseByte parses a JWS value serialized via compact serialization and provided as []byte.
func ParseByte(jwsCompact []byte) (m *Message, err error) {
return parseCompact(string(jwsCompact[:]))
}
// ParseString parses a JWS value serialized via compact serialization and provided as string.
func ParseString(s string) (*Message, error) {
return parseCompact(s)
}
// SplitCompact splits a JWT and returns its three parts
// separately: Protected Headers, Payload and Signature.
func SplitCompact(jwsCompact string) ([]string, error) {
parts := strings.Split(jwsCompact, ".")
if len(parts) < 3 {
return nil, errors.New("Failed to split compact serialization")
}
return parts, nil
}
// parseCompact parses a JWS value serialized via compact serialization.
func parseCompact(str string) (m *Message, err error) {
var decodedHeader, decodedPayload, decodedSignature []byte
parts, err := SplitCompact(str)
if err != nil {
return nil, errors.Wrap(err, `invalid compact serialization format`)
}
if decodedHeader, err = base64.RawURLEncoding.DecodeString(parts[0]); err != nil {
return nil, errors.Wrap(err, `failed to decode Headers`)
}
var hdr StandardHeaders
if err := json.Unmarshal(decodedHeader, &hdr); err != nil {
return nil, errors.Wrap(err, `failed to parse JOSE Headers`)
}
if decodedPayload, err = base64.RawURLEncoding.DecodeString(parts[1]); err != nil {
return nil, errors.Wrap(err, `failed to decode Payload`)
}
if len(parts) > 2 {
if decodedSignature, err = base64.RawURLEncoding.DecodeString(parts[2]); err != nil {
return nil, errors.Wrap(err, `failed to decode Signature`)
}
}
var msg Message
msg.Payload = decodedPayload
msg.Signatures = append(msg.Signatures, &Signature{
Protected: &hdr,
Signature: decodedSignature,
})
return &msg, nil
}

View File

@@ -0,0 +1,26 @@
package jws
// PublicHeaders returns the public headers in a JWS
func (s Signature) PublicHeaders() Headers {
return s.Headers
}
// ProtectedHeaders returns the protected headers in a JWS
func (s Signature) ProtectedHeaders() Headers {
return s.Protected
}
// GetSignature returns the signature in a JWS
func (s Signature) GetSignature() []byte {
return s.Signature
}
// GetPayload returns the payload in a JWS
func (m Message) GetPayload() []byte {
return m.Payload
}
// GetSignatures returns the all signatures in a JWS
func (m Message) GetSignatures() []*Signature {
return m.Signatures
}

View File

@@ -0,0 +1,84 @@
package sign
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var ecdsaSignFuncs = map[jwa.SignatureAlgorithm]ecdsaSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
jwa.ES256: crypto.SHA256,
jwa.ES384: crypto.SHA384,
jwa.ES512: crypto.SHA512,
}
for alg, h := range algs {
ecdsaSignFuncs[alg] = makeECDSASignFunc(h)
}
}
func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
// Curve bits do not need to be a multiple of 8.
if curveBits%8 > 0 {
keyBytes++
}
h := hash.New()
h.Write(payload)
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
if err != nil {
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
out := append(rBytesPadded, sBytesPadded...)
return out, nil
})
}
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSASigner, error) {
signfn, ok := ecdsaSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA signer: %s`, alg)
}
return &ECDSASigner{
alg: alg,
sign: signfn,
}, nil
}
// Algorithm returns the signer algorithm
func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
privateKey, ok := key.(*ecdsa.PrivateKey)
if !ok {
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
}
return s.sign(payload, privateKey)
}

View File

@@ -0,0 +1,66 @@
package sign
import (
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"hash"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var hmacSignFuncs = map[jwa.SignatureAlgorithm]hmacSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]func() hash.Hash{
jwa.HS256: sha256.New,
jwa.HS384: sha512.New384,
jwa.HS512: sha512.New,
}
for alg, h := range algs {
hmacSignFuncs[alg] = makeHMACSignFunc(h)
}
}
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACSigner, error) {
signer, ok := hmacSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create HMAC signer: %s`, alg)
}
return &HMACSigner{
alg: alg,
sign: signer,
}, nil
}
func makeHMACSignFunc(hfunc func() hash.Hash) hmacSignFunc {
return hmacSignFunc(func(payload []byte, key []byte) ([]byte, error) {
h := hmac.New(hfunc, key)
h.Write(payload)
return h.Sum(nil), nil
})
}
// Algorithm returns the signer algorithm
func (s HMACSigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign signs payload with a Symmetric key
func (s HMACSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
hmackey, ok := key.([]byte)
if !ok {
return nil, errors.Errorf(`invalid key type %T. []byte is required`, key)
}
if len(hmackey) == 0 {
return nil, errors.New(`missing key while signing payload`)
}
return s.sign(payload, hmackey)
}

View File

@@ -0,0 +1,45 @@
package sign
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// Signer provides a common interface for supported alg signing methods
type Signer interface {
// Sign creates a signature for the given `payload`.
// `key` is the key used for signing the payload, and is usually
// the private key type associated with the signature method. For example,
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
// `*"crypto/rsa".PrivateKey` type.
// Check the documentation for each signer for details
Sign(payload []byte, key interface{}) ([]byte, error)
Algorithm() jwa.SignatureAlgorithm
}
type rsaSignFunc func([]byte, *rsa.PrivateKey) ([]byte, error)
// RSASigner uses crypto/rsa to sign the payloads.
type RSASigner struct {
alg jwa.SignatureAlgorithm
sign rsaSignFunc
}
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
// ECDSASigner uses crypto/ecdsa to sign the payloads.
type ECDSASigner struct {
alg jwa.SignatureAlgorithm
sign ecdsaSignFunc
}
type hmacSignFunc func([]byte, []byte) ([]byte, error)
// HMACSigner uses crypto/hmac to sign the payloads.
type HMACSigner struct {
alg jwa.SignatureAlgorithm
sign hmacSignFunc
}

View File

@@ -0,0 +1,97 @@
package sign
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var rsaSignFuncs = map[jwa.SignatureAlgorithm]rsaSignFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]struct {
Hash crypto.Hash
SignFunc func(crypto.Hash) rsaSignFunc
}{
jwa.RS256: {
Hash: crypto.SHA256,
SignFunc: makeSignPKCS1v15,
},
jwa.RS384: {
Hash: crypto.SHA384,
SignFunc: makeSignPKCS1v15,
},
jwa.RS512: {
Hash: crypto.SHA512,
SignFunc: makeSignPKCS1v15,
},
jwa.PS256: {
Hash: crypto.SHA256,
SignFunc: makeSignPSS,
},
jwa.PS384: {
Hash: crypto.SHA384,
SignFunc: makeSignPSS,
},
jwa.PS512: {
Hash: crypto.SHA512,
SignFunc: makeSignPSS,
},
}
for alg, item := range algs {
rsaSignFuncs[alg] = item.SignFunc(item.Hash)
}
}
func makeSignPKCS1v15(hash crypto.Hash) rsaSignFunc {
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
h := hash.New()
h.Write(payload)
return rsa.SignPKCS1v15(rand.Reader, key, hash, h.Sum(nil))
})
}
func makeSignPSS(hash crypto.Hash) rsaSignFunc {
return rsaSignFunc(func(payload []byte, key *rsa.PrivateKey) ([]byte, error) {
h := hash.New()
h.Write(payload)
return rsa.SignPSS(rand.Reader, key, hash, h.Sum(nil), &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
})
})
}
func newRSA(alg jwa.SignatureAlgorithm) (*RSASigner, error) {
signfn, ok := rsaSignFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA signer: %s`, alg)
}
return &RSASigner{
alg: alg,
sign: signfn,
}, nil
}
// Algorithm returns the signer algorithm
func (s RSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}
// Sign creates a signature using crypto/rsa. key must be a non-nil instance of
// `*"crypto/rsa".PrivateKey`.
func (s RSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
rsakey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, errors.Errorf(`invalid key type %T. *rsa.PrivateKey is required`, key)
}
return s.sign(payload, rsakey)
}

View File

@@ -0,0 +1,21 @@
package sign
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// New creates a signer that signs payloads using the given signature algorithm.
func New(alg jwa.SignatureAlgorithm) (Signer, error) {
switch alg {
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
return newRSA(alg)
case jwa.ES256, jwa.ES384, jwa.ES512:
return newECDSA(alg)
case jwa.HS256, jwa.HS384, jwa.HS512:
return newHMAC(alg)
default:
return nil, errors.Errorf(`unsupported signature algorithm %s`, alg)
}
}

View File

@@ -0,0 +1,67 @@
package verify
import (
"crypto"
"crypto/ecdsa"
"math/big"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
var ecdsaVerifyFuncs = map[jwa.SignatureAlgorithm]ecdsaVerifyFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]crypto.Hash{
jwa.ES256: crypto.SHA256,
jwa.ES384: crypto.SHA384,
jwa.ES512: crypto.SHA512,
}
for alg, h := range algs {
ecdsaVerifyFuncs[alg] = makeECDSAVerifyFunc(h)
}
}
func makeECDSAVerifyFunc(hash crypto.Hash) ecdsaVerifyFunc {
return ecdsaVerifyFunc(func(payload []byte, signature []byte, key *ecdsa.PublicKey) error {
r, s := &big.Int{}, &big.Int{}
n := len(signature) / 2
r.SetBytes(signature[:n])
s.SetBytes(signature[n:])
h := hash.New()
h.Write(payload)
if !ecdsa.Verify(key, h.Sum(nil), r, s) {
return errors.New(`failed to verify signature using ecdsa`)
}
return nil
})
}
func newECDSA(alg jwa.SignatureAlgorithm) (*ECDSAVerifier, error) {
verifyfn, ok := ecdsaVerifyFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create ECDSA verifier: %s`, alg)
}
return &ECDSAVerifier{
verify: verifyfn,
}, nil
}
// Verify checks whether the signature for a given input and key is correct
func (v ECDSAVerifier) Verify(payload []byte, signature []byte, key interface{}) error {
if key == nil {
return errors.New(`missing public key while verifying payload`)
}
ecdsakey, ok := key.(*ecdsa.PublicKey)
if !ok {
return errors.Errorf(`invalid key type %T. *ecdsa.PublicKey is required`, key)
}
return v.verify(payload, signature, ecdsakey)
}

View File

@@ -0,0 +1,33 @@
package verify
import (
"crypto/hmac"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
)
func newHMAC(alg jwa.SignatureAlgorithm) (*HMACVerifier, error) {
s, err := sign.New(alg)
if err != nil {
return nil, errors.Wrap(err, `failed to generate HMAC signer`)
}
return &HMACVerifier{signer: s}, nil
}
// Verify checks whether the signature for a given input and key is correct
func (v HMACVerifier) Verify(signingInput, signature []byte, key interface{}) (err error) {
expected, err := v.signer.Sign(signingInput, key)
if err != nil {
return errors.Wrap(err, `failed to generated signature`)
}
if !hmac.Equal(signature, expected) {
return errors.New(`failed to match hmac signature`)
}
return nil
}

View File

@@ -0,0 +1,39 @@
package verify
import (
"crypto/ecdsa"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws/sign"
)
// Verifier provides a common interface for supported alg verification methods
type Verifier interface {
// Verify checks whether the payload and signature are valid for
// the given key.
// `key` is the key used for verifying the payload, and is usually
// the public key associated with the signature method. For example,
// for `jwa.RSXXX` and `jwa.PSXXX` types, you need to pass the
// `*"crypto/rsa".PublicKey` type.
// Check the documentation for each verifier for details
Verify(payload []byte, signature []byte, key interface{}) error
}
type rsaVerifyFunc func([]byte, []byte, *rsa.PublicKey) error
// RSAVerifier implements the Verifier interface
type RSAVerifier struct {
verify rsaVerifyFunc
}
type ecdsaVerifyFunc func([]byte, []byte, *ecdsa.PublicKey) error
// ECDSAVerifier implements the Verifier interface
type ECDSAVerifier struct {
verify ecdsaVerifyFunc
}
// HMACVerifier implements the Verifier interface
type HMACVerifier struct {
signer sign.Signer
}

View File

@@ -0,0 +1,88 @@
package verify
import (
"crypto"
"crypto/rsa"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
"github.com/pkg/errors"
)
var rsaVerifyFuncs = map[jwa.SignatureAlgorithm]rsaVerifyFunc{}
func init() {
algs := map[jwa.SignatureAlgorithm]struct {
Hash crypto.Hash
VerifyFunc func(crypto.Hash) rsaVerifyFunc
}{
jwa.RS256: {
Hash: crypto.SHA256,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.RS384: {
Hash: crypto.SHA384,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.RS512: {
Hash: crypto.SHA512,
VerifyFunc: makeVerifyPKCS1v15,
},
jwa.PS256: {
Hash: crypto.SHA256,
VerifyFunc: makeVerifyPSS,
},
jwa.PS384: {
Hash: crypto.SHA384,
VerifyFunc: makeVerifyPSS,
},
jwa.PS512: {
Hash: crypto.SHA512,
VerifyFunc: makeVerifyPSS,
},
}
for alg, item := range algs {
rsaVerifyFuncs[alg] = item.VerifyFunc(item.Hash)
}
}
func makeVerifyPKCS1v15(hash crypto.Hash) rsaVerifyFunc {
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
h := hash.New()
h.Write(payload)
return rsa.VerifyPKCS1v15(key, hash, h.Sum(nil), signature)
})
}
func makeVerifyPSS(hash crypto.Hash) rsaVerifyFunc {
return rsaVerifyFunc(func(payload, signature []byte, key *rsa.PublicKey) error {
h := hash.New()
h.Write(payload)
return rsa.VerifyPSS(key, hash, h.Sum(nil), signature, nil)
})
}
func newRSA(alg jwa.SignatureAlgorithm) (*RSAVerifier, error) {
verifyfn, ok := rsaVerifyFuncs[alg]
if !ok {
return nil, errors.Errorf(`unsupported algorithm while trying to create RSA verifier: %s`, alg)
}
return &RSAVerifier{
verify: verifyfn,
}, nil
}
// Verify checks if a JWS is valid.
func (v RSAVerifier) Verify(payload, signature []byte, key interface{}) error {
if key == nil {
return errors.New(`missing public key while verifying payload`)
}
rsaKey, ok := key.(*rsa.PublicKey)
if !ok {
return errors.Errorf(`invalid key type %T. *rsa.PublicKey is required`, key)
}
return v.verify(payload, signature, rsaKey)
}

View File

@@ -0,0 +1,22 @@
package verify
import (
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwa"
)
// New creates a new JWS verifier using the specified algorithm
// and the public key
func New(alg jwa.SignatureAlgorithm) (Verifier, error) {
switch alg {
case jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512:
return newRSA(alg)
case jwa.ES256, jwa.ES384, jwa.ES512:
return newECDSA(alg)
case jwa.HS256, jwa.HS384, jwa.HS512:
return newHMAC(alg)
default:
return nil, errors.Errorf(`unsupported signature algorithm: %s`, alg)
}
}

235
vendor/github.com/open-policy-agent/opa/topdown/json.go generated vendored Normal file
View File

@@ -0,0 +1,235 @@
// Copyright 2019 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinJSONRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Expect an object and a string or array/set of strings
_, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Build a list of json pointers to remove
paths, err := getJSONPaths(operands[1].Value)
if err != nil {
return err
}
newObj, err := jsonRemove(operands[0], ast.NewTerm(pathsToObject(paths)))
if err != nil {
return err
}
if newObj == nil {
return nil
}
return iter(newObj)
}
// jsonRemove returns a new term that is the result of walking
// through a and omitting removing any values that are in b but
// have ast.Null values (ie leaf nodes for b).
func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
if b == nil {
// The paths diverged, return a
return a, nil
}
var bObj ast.Object
switch bValue := b.Value.(type) {
case ast.Object:
bObj = bValue
case ast.Null:
// Means we hit a leaf node on "b", dont add the value for a
return nil, nil
default:
// The paths diverged, return a
return a, nil
}
switch aValue := a.Value.(type) {
case ast.String, ast.Number, ast.Boolean, ast.Null:
return a, nil
case ast.Object:
newObj := ast.NewObject()
err := aValue.Iter(func(k *ast.Term, v *ast.Term) error {
// recurse and add the diff of sub objects as needed
diffValue, err := jsonRemove(v, bObj.Get(k))
if err != nil || diffValue == nil {
return err
}
newObj.Insert(k, diffValue)
return nil
})
if err != nil {
return nil, err
}
return ast.NewTerm(newObj), nil
case ast.Set:
newSet := ast.NewSet()
err := aValue.Iter(func(v *ast.Term) error {
// recurse and add the diff of sub objects as needed
diffValue, err := jsonRemove(v, bObj.Get(v))
if err != nil || diffValue == nil {
return err
}
newSet.Add(diffValue)
return nil
})
if err != nil {
return nil, err
}
return ast.NewTerm(newSet), nil
case ast.Array:
// When indexes are removed we shift left to close empty spots in the array
// as per the JSON patch spec.
var newArray ast.Array
for i, v := range aValue {
// recurse and add the diff of sub objects as needed
// Note: Keys in b will be strings for the index, eg path /a/1/b => {"a": {"1": {"b": null}}}
diffValue, err := jsonRemove(v, bObj.Get(ast.StringTerm(strconv.Itoa(i))))
if err != nil {
return nil, err
}
if diffValue != nil {
newArray = append(newArray, diffValue)
}
}
return ast.NewTerm(newArray), nil
default:
return nil, fmt.Errorf("invalid value type %T", a)
}
}
func builtinJSONFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Ensure we have the right parameters, expect an object and a string or array/set of strings
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Build a list of filter strings
filters, err := getJSONPaths(operands[1].Value)
if err != nil {
return err
}
// Actually do the filtering
filterObj := pathsToObject(filters)
r, err := obj.Filter(filterObj)
if err != nil {
return err
}
return iter(ast.NewTerm(r))
}
func getJSONPaths(operand ast.Value) ([]ast.Ref, error) {
var paths []ast.Ref
switch v := operand.(type) {
case ast.Array:
for _, f := range v {
filter, err := parsePath(f)
if err != nil {
return nil, err
}
paths = append(paths, filter)
}
case ast.Set:
err := v.Iter(func(f *ast.Term) error {
filter, err := parsePath(f)
if err != nil {
return err
}
paths = append(paths, filter)
return nil
})
if err != nil {
return nil, err
}
default:
return nil, builtins.NewOperandTypeErr(2, v, "set", "array")
}
return paths, nil
}
func parsePath(path *ast.Term) (ast.Ref, error) {
// paths can either be a `/` separated json path or
// an array or set of values
var pathSegments ast.Ref
switch p := path.Value.(type) {
case ast.String:
parts := strings.Split(strings.Trim(string(p), "/"), "/")
for _, part := range parts {
part = strings.ReplaceAll(strings.ReplaceAll(part, "~1", "/"), "~0", "~")
pathSegments = append(pathSegments, ast.StringTerm(part))
}
case ast.Array:
for _, term := range p {
pathSegments = append(pathSegments, term)
}
default:
return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p))
}
return pathSegments, nil
}
func pathsToObject(paths []ast.Ref) ast.Object {
root := ast.NewObject()
for _, path := range paths {
node := root
var done bool
for i := 0; i < len(path)-1 && !done; i++ {
k := path[i]
child := node.Get(k)
if child == nil {
obj := ast.NewObject()
node.Insert(k, ast.NewTerm(obj))
node = obj
continue
}
switch v := child.Value.(type) {
case ast.Null:
done = true
case ast.Object:
node = v
default:
panic("unreachable")
}
}
if !done {
node.Insert(path[len(path)-1], ast.NullTerm())
}
}
return root
}
func init() {
RegisterBuiltinFunc(ast.JSONFilter.Name, builtinJSONFilter)
RegisterBuiltinFunc(ast.JSONRemove.Name, builtinJSONRemove)
}

View File

@@ -0,0 +1,139 @@
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/types"
)
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
objA, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
objB, err := builtins.ObjectOperand(operands[1].Value, 2)
if err != nil {
return err
}
r := mergeWithOverwrite(objA, objB)
return iter(ast.NewTerm(r))
}
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Expect an object and an array/set/object of keys
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Build a set of keys to remove
keysToRemove, err := getObjectKeysParam(operands[1].Value)
if err != nil {
return err
}
r := ast.NewObject()
obj.Foreach(func(key *ast.Term, value *ast.Term) {
if !keysToRemove.Contains(key) {
r.Insert(key, value)
}
})
return iter(ast.NewTerm(r))
}
func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Expect an object and an array/set/object of keys
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
// Build a new object from the supplied filter keys
keys, err := getObjectKeysParam(operands[1].Value)
if err != nil {
return err
}
filterObj := ast.NewObject()
keys.Foreach(func(key *ast.Term) {
filterObj.Insert(key, ast.NullTerm())
})
// Actually do the filtering
r, err := obj.Filter(filterObj)
if err != nil {
return err
}
return iter(ast.NewTerm(r))
}
func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
object, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
return err
}
if ret := object.Get(operands[1]); ret != nil {
return iter(ret)
}
return iter(operands[2])
}
// getObjectKeysParam returns a set of key values
// from a supplied ast array, object, set value
func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
keys := ast.NewSet()
switch v := arrayOrSet.(type) {
case ast.Array:
for _, f := range v {
keys.Add(f)
}
case ast.Set:
_ = v.Iter(func(f *ast.Term) error {
keys.Add(f)
return nil
})
case ast.Object:
_ = v.Iter(func(k *ast.Term, _ *ast.Term) error {
keys.Add(k)
return nil
})
default:
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
}
return keys, nil
}
func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
merged, _ := objA.MergeWith(objB, func(v1, v2 *ast.Term) (*ast.Term, bool) {
originalValueObj, ok2 := v1.Value.(ast.Object)
updateValueObj, ok1 := v2.Value.(ast.Object)
if !ok1 || !ok2 {
// If we can't merge, stick with the right-hand value
return v2, false
}
// Recursively update the existing value
merged := mergeWithOverwrite(originalValueObj, updateValueObj)
return ast.NewTerm(merged), false
})
return merged
}
func init() {
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
}

View File

@@ -0,0 +1,47 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"bytes"
"encoding/json"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinRegoParseModule(a, b ast.Value) (ast.Value, error) {
filename, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
input, err := builtins.StringOperand(b, 1)
if err != nil {
return nil, err
}
module, err := ast.ParseModule(string(filename), string(input))
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(module); err != nil {
return nil, err
}
term, err := ast.ParseTerm(buf.String())
if err != nil {
return nil, err
}
return term.Value, nil
}
func init() {
RegisterFunctionalBuiltin2(ast.RegoParseModule.Name, builtinRegoParseModule)
}

View File

@@ -0,0 +1,153 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"math/big"
"strconv"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
const (
none int64 = 1
kb = 1000
ki = 1024
mb = kb * 1000
mi = ki * 1024
gb = mb * 1000
gi = mi * 1024
tb = gb * 1000
ti = gi * 1024
)
// The rune values for 0..9 as well as the period symbol (for parsing floats)
var numRunes = []rune("0123456789.")
func parseNumBytesError(msg string) error {
return fmt.Errorf("%s error: %s", ast.UnitsParseBytes.Name, msg)
}
func errUnitNotRecognized(unit string) error {
return parseNumBytesError(fmt.Sprintf("byte unit %s not recognized", unit))
}
var (
errNoAmount = parseNumBytesError("no byte amount provided")
errIntConv = parseNumBytesError("could not parse byte amount to integer")
errIncludesSpaces = parseNumBytesError("spaces not allowed in resource strings")
)
func builtinNumBytes(a ast.Value) (ast.Value, error) {
var m int64
raw, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s := formatString(raw)
if strings.Contains(s, " ") {
return nil, errIncludesSpaces
}
numStr, unitStr := extractNumAndUnit(s)
if numStr == "" {
return nil, errNoAmount
}
switch unitStr {
case "":
m = none
case "kb":
m = kb
case "kib":
m = ki
case "mb":
m = mb
case "mib":
m = mi
case "gb":
m = gb
case "gib":
m = gi
case "tb":
m = tb
case "tib":
m = ti
default:
return nil, errUnitNotRecognized(unitStr)
}
num, err := strconv.ParseInt(numStr, 10, 64)
if err != nil {
return nil, errIntConv
}
total := num * m
return builtins.IntToNumber(big.NewInt(total)), nil
}
// Makes the string lower case and removes spaces and quotation marks
func formatString(s ast.String) string {
str := string(s)
lower := strings.ToLower(str)
return strings.Replace(lower, "\"", "", -1)
}
// Splits the string into a number string à la "10" or "10.2" and a unit string à la "gb" or "MiB" or "foo". Either
// can be an empty string (error handling is provided elsewhere).
func extractNumAndUnit(s string) (string, string) {
isNum := func(r rune) (isNum bool) {
for _, nr := range numRunes {
if nr == r {
return true
}
}
return false
}
// Returns the index of the first rune that's not a number (or 0 if there are only numbers)
getFirstNonNumIdx := func(s string) int {
for idx, r := range s {
if !isNum(r) {
return idx
}
}
return 0
}
firstRuneIsNum := func(s string) bool {
return isNum(rune(s[0]))
}
firstNonNumIdx := getFirstNonNumIdx(s)
// The string contains only a number
numOnly := firstNonNumIdx == 0 && firstRuneIsNum(s)
// The string contains only a unit
unitOnly := firstNonNumIdx == 0 && !firstRuneIsNum(s)
if numOnly {
return s, ""
} else if unitOnly {
return "", s
} else {
return s[0:firstNonNumIdx], s[firstNonNumIdx:]
}
}
func init() {
RegisterFunctionalBuiltin1(ast.UnitsParseBytes.Name, builtinNumBytes)
}

View File

@@ -0,0 +1,317 @@
package topdown
import (
"context"
"sort"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/metrics"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/copypropagation"
)
// QueryResultSet represents a collection of results returned by a query.
type QueryResultSet []QueryResult
// QueryResult represents a single result returned by a query. The result
// contains bindings for all variables that appear in the query.
type QueryResult map[ast.Var]*ast.Term
// Query provides a configurable interface for performing query evaluation.
type Query struct {
cancel Cancel
query ast.Body
queryCompiler ast.QueryCompiler
compiler *ast.Compiler
store storage.Store
txn storage.Transaction
input *ast.Term
tracers []Tracer
unknowns []*ast.Term
partialNamespace string
metrics metrics.Metrics
instr *Instrumentation
disableInlining []ast.Ref
genvarprefix string
runtime *ast.Term
builtins map[string]*Builtin
indexing bool
}
// Builtin represents a built-in function that queries can call.
type Builtin struct {
Decl *ast.Builtin
Func BuiltinFunc
}
// NewQuery returns a new Query object that can be run.
func NewQuery(query ast.Body) *Query {
return &Query{
query: query,
genvarprefix: ast.WildcardPrefix,
indexing: true,
}
}
// WithQueryCompiler sets the queryCompiler used for the query.
func (q *Query) WithQueryCompiler(queryCompiler ast.QueryCompiler) *Query {
q.queryCompiler = queryCompiler
return q
}
// WithCompiler sets the compiler to use for the query.
func (q *Query) WithCompiler(compiler *ast.Compiler) *Query {
q.compiler = compiler
return q
}
// WithStore sets the store to use for the query.
func (q *Query) WithStore(store storage.Store) *Query {
q.store = store
return q
}
// WithTransaction sets the transaction to use for the query. All queries
// should be performed over a consistent snapshot of the storage layer.
func (q *Query) WithTransaction(txn storage.Transaction) *Query {
q.txn = txn
return q
}
// WithCancel sets the cancellation object to use for the query. Set this if
// you need to abort queries based on a deadline. This is optional.
func (q *Query) WithCancel(cancel Cancel) *Query {
q.cancel = cancel
return q
}
// WithInput sets the input object to use for the query. References rooted at
// input will be evaluated against this value. This is optional.
func (q *Query) WithInput(input *ast.Term) *Query {
q.input = input
return q
}
// WithTracer adds a query tracer to use during evaluation. This is optional.
func (q *Query) WithTracer(tracer Tracer) *Query {
q.tracers = append(q.tracers, tracer)
return q
}
// WithMetrics sets the metrics collection to add evaluation metrics to. This
// is optional.
func (q *Query) WithMetrics(m metrics.Metrics) *Query {
q.metrics = m
return q
}
// WithInstrumentation sets the instrumentation configuration to enable on the
// evaluation process. By default, instrumentation is turned off.
func (q *Query) WithInstrumentation(instr *Instrumentation) *Query {
q.instr = instr
return q
}
// WithUnknowns sets the initial set of variables or references to treat as
// unknown during query evaluation. This is required for partial evaluation.
func (q *Query) WithUnknowns(terms []*ast.Term) *Query {
q.unknowns = terms
return q
}
// WithPartialNamespace sets the namespace to use for supporting rules
// generated as part of the partial evaluation process. The ns value must be a
// valid package path component.
func (q *Query) WithPartialNamespace(ns string) *Query {
q.partialNamespace = ns
return q
}
// WithDisableInlining adds a set of paths to the query that should be excluded from
// inlining. Inlining during partial evaluation can be expensive in some cases
// (e.g., when a cross-product is computed.) Disabling inlining avoids expensive
// computation at the cost of generating support rules.
func (q *Query) WithDisableInlining(paths []ast.Ref) *Query {
q.disableInlining = paths
return q
}
// WithRuntime sets the runtime data to execute the query with. The runtime data
// can be returned by the `opa.runtime` built-in function.
func (q *Query) WithRuntime(runtime *ast.Term) *Query {
q.runtime = runtime
return q
}
// WithBuiltins adds a set of built-in functions that can be called by the
// query.
func (q *Query) WithBuiltins(builtins map[string]*Builtin) *Query {
q.builtins = builtins
return q
}
// WithIndexing will enable or disable using rule indexing for the evaluation
// of the query. The default is enabled.
func (q *Query) WithIndexing(enabled bool) *Query {
q.indexing = enabled
return q
}
// PartialRun executes partial evaluation on the query with respect to unknown
// values. Partial evaluation attempts to evaluate as much of the query as
// possible without requiring values for the unknowns set on the query. The
// result of partial evaluation is a new set of queries that can be evaluated
// once the unknown value is known. In addition to new queries, partial
// evaluation may produce additional support modules that should be used in
// conjunction with the partially evaluated queries.
func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support []*ast.Module, err error) {
if q.partialNamespace == "" {
q.partialNamespace = "partial" // lazily initialize partial namespace
}
f := &queryIDFactory{}
b := newBindings(0, q.instr)
e := &eval{
ctx: ctx,
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: b,
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
tracers: q.tracers,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
virtualCache: newVirtualCache(),
saveSet: newSaveSet(q.unknowns, b, q.instr),
saveStack: newSaveStack(),
saveSupport: newSaveSupport(),
saveNamespace: ast.StringTerm(q.partialNamespace),
genvarprefix: q.genvarprefix,
runtime: q.runtime,
indexing: q.indexing,
}
if len(q.disableInlining) > 0 {
e.disableInlining = [][]ast.Ref{q.disableInlining}
}
e.caller = e
q.startTimer(metrics.RegoPartialEval)
defer q.stopTimer(metrics.RegoPartialEval)
livevars := ast.NewVarSet()
ast.WalkVars(q.query, func(x ast.Var) bool {
if !x.IsGenerated() {
livevars.Add(x)
}
return false
})
p := copypropagation.New(livevars)
err = e.Run(func(e *eval) error {
// Build output from saved expressions.
body := ast.NewBody()
for _, elem := range e.saveStack.Stack[len(e.saveStack.Stack)-1] {
body.Append(elem.Plug(e.bindings))
}
// Include bindings as exprs so that when caller evals the result, they
// can obtain values for the vars in their query.
bindingExprs := []*ast.Expr{}
e.bindings.Iter(e.bindings, func(a, b *ast.Term) error {
bindingExprs = append(bindingExprs, ast.Equality.Expr(a, b))
return nil
})
// Sort binding expressions so that results are deterministic.
sort.Slice(bindingExprs, func(i, j int) bool {
return bindingExprs[i].Compare(bindingExprs[j]) < 0
})
for i := range bindingExprs {
body.Append(bindingExprs[i])
}
partials = append(partials, applyCopyPropagation(p, e.instr, body))
return nil
})
support = e.saveSupport.List()
return partials, support, err
}
// Run is a wrapper around Iter that accumulates query results and returns them
// in one shot.
func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
qrs := QueryResultSet{}
return qrs, q.Iter(ctx, func(qr QueryResult) error {
qrs = append(qrs, qr)
return nil
})
}
// Iter executes the query and invokes the iter function with query results
// produced by evaluating the query.
func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
f := &queryIDFactory{}
e := &eval{
ctx: ctx,
cancel: q.cancel,
query: q.query,
queryCompiler: q.queryCompiler,
queryIDFact: f,
queryID: f.Next(),
bindings: newBindings(0, q.instr),
compiler: q.compiler,
store: q.store,
baseCache: newBaseCache(),
targetStack: newRefStack(),
txn: q.txn,
input: q.input,
tracers: q.tracers,
instr: q.instr,
builtins: q.builtins,
builtinCache: builtins.Cache{},
virtualCache: newVirtualCache(),
genvarprefix: q.genvarprefix,
runtime: q.runtime,
indexing: q.indexing,
}
e.caller = e
q.startTimer(metrics.RegoQueryEval)
err := e.Run(func(e *eval) error {
qr := QueryResult{}
e.bindings.Iter(nil, func(k, v *ast.Term) error {
qr[k.Value.(ast.Var)] = v
return nil
})
return iter(qr)
})
q.stopTimer(metrics.RegoQueryEval)
return err
}
func (q *Query) startTimer(name string) {
if q.metrics != nil {
q.metrics.Timer(name).Start()
}
}
func (q *Query) stopTimer(name string) {
if q.metrics != nil {
q.metrics.Timer(name).Stop()
}
}

View File

@@ -0,0 +1,201 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"regexp"
"sync"
"github.com/yashtewari/glob-intersection"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
var regexpCacheLock = sync.Mutex{}
var regexpCache map[string]*regexp.Regexp
func builtinRegexMatch(a, b ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
re, err := getRegexp(string(s1))
if err != nil {
return nil, err
}
return ast.Boolean(re.Match([]byte(s2))), nil
}
func builtinRegexMatchTemplate(a, b, c, d ast.Value) (ast.Value, error) {
pattern, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
match, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
start, err := builtins.StringOperand(c, 3)
if err != nil {
return nil, err
}
end, err := builtins.StringOperand(d, 4)
if err != nil {
return nil, err
}
if len(start) != 1 {
return nil, fmt.Errorf("start delimiter has to be exactly one character long but is %d long", len(start))
}
if len(end) != 1 {
return nil, fmt.Errorf("end delimiter has to be exactly one character long but is %d long", len(start))
}
re, err := getRegexpTemplate(string(pattern), string(start)[0], string(end)[0])
if err != nil {
return nil, err
}
return ast.Boolean(re.MatchString(string(match))), nil
}
func builtinRegexSplit(a, b ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
re, err := getRegexp(string(s1))
if err != nil {
return nil, err
}
elems := re.Split(string(s2), -1)
arr := make(ast.Array, len(elems))
for i := range arr {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
}
func getRegexp(pat string) (*regexp.Regexp, error) {
regexpCacheLock.Lock()
defer regexpCacheLock.Unlock()
re, ok := regexpCache[pat]
if !ok {
var err error
re, err = regexp.Compile(string(pat))
if err != nil {
return nil, err
}
regexpCache[pat] = re
}
return re, nil
}
func getRegexpTemplate(pat string, delimStart, delimEnd byte) (*regexp.Regexp, error) {
regexpCacheLock.Lock()
defer regexpCacheLock.Unlock()
re, ok := regexpCache[pat]
if !ok {
var err error
re, err = compileRegexTemplate(string(pat), delimStart, delimEnd)
if err != nil {
return nil, err
}
regexpCache[pat] = re
}
return re, nil
}
func builtinGlobsMatch(a, b ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
ne, err := gintersect.NonEmpty(string(s1), string(s2))
if err != nil {
return nil, err
}
return ast.Boolean(ne), nil
}
func builtinRegexFind(a, b, c ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
n, err := builtins.IntOperand(c, 3)
if err != nil {
return nil, err
}
re, err := getRegexp(string(s1))
if err != nil {
return nil, err
}
elems := re.FindAllString(string(s2), n)
arr := make(ast.Array, len(elems))
for i := range arr {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
}
func builtinRegexFindAllStringSubmatch(a, b, c ast.Value) (ast.Value, error) {
s1, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
n, err := builtins.IntOperand(c, 3)
if err != nil {
return nil, err
}
re, err := getRegexp(string(s1))
if err != nil {
return nil, err
}
matches := re.FindAllStringSubmatch(string(s2), n)
outer := make(ast.Array, len(matches))
for i := range outer {
inner := make(ast.Array, len(matches[i]))
for j := range inner {
inner[j] = ast.StringTerm(matches[i][j])
}
outer[i] = ast.ArrayTerm(inner...)
}
return outer, nil
}
func init() {
regexpCache = map[string]*regexp.Regexp{}
RegisterFunctionalBuiltin2(ast.RegexMatch.Name, builtinRegexMatch)
RegisterFunctionalBuiltin2(ast.RegexSplit.Name, builtinRegexSplit)
RegisterFunctionalBuiltin2(ast.GlobsMatch.Name, builtinGlobsMatch)
RegisterFunctionalBuiltin4(ast.RegexTemplateMatch.Name, builtinRegexMatchTemplate)
RegisterFunctionalBuiltin3(ast.RegexFind.Name, builtinRegexFind)
RegisterFunctionalBuiltin3(ast.RegexFindAllStringSubmatch.Name, builtinRegexFindAllStringSubmatch)
}

View File

@@ -0,0 +1,122 @@
package topdown
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license as follows:
// Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// This file was forked from https://github.com/gorilla/mux/commit/eac83ba2c004bb75
import (
"bytes"
"fmt"
"regexp"
)
// delimiterIndices returns the first level delimiter indices from a string.
// It returns an error in case of unbalanced delimiters.
func delimiterIndices(s string, delimiterStart, delimiterEnd byte) ([]int, error) {
var level, idx int
idxs := make([]int, 0)
for i := 0; i < len(s); i++ {
switch s[i] {
case delimiterStart:
if level++; level == 1 {
idx = i
}
case delimiterEnd:
if level--; level == 0 {
idxs = append(idxs, idx, i+1)
} else if level < 0 {
return nil, fmt.Errorf(`unbalanced braces in %q`, s)
}
}
}
if level != 0 {
return nil, fmt.Errorf(`unbalanced braces in %q`, s)
}
return idxs, nil
}
// compileRegexTemplate parses a template and returns a Regexp.
//
// You can define your own delimiters. It is e.g. common to use curly braces {} but I recommend using characters
// which have no special meaning in Regex, e.g.: <, >
//
// reg, err := compiler.CompileRegex("foo:bar.baz:<[0-9]{2,10}>", '<', '>')
// // if err != nil ...
// reg.MatchString("foo:bar.baz:123")
func compileRegexTemplate(tpl string, delimiterStart, delimiterEnd byte) (*regexp.Regexp, error) {
// Check if it is well-formed.
idxs, errBraces := delimiterIndices(tpl, delimiterStart, delimiterEnd)
if errBraces != nil {
return nil, errBraces
}
varsR := make([]*regexp.Regexp, len(idxs)/2)
pattern := bytes.NewBufferString("")
// WriteByte's error value is always nil for bytes.Buffer, no need to check it.
pattern.WriteByte('^')
var end int
var err error
for i := 0; i < len(idxs); i += 2 {
// Set all values we are interested in.
raw := tpl[end:idxs[i]]
end = idxs[i+1]
patt := tpl[idxs[i]+1 : end-1]
// Build the regexp pattern.
varIdx := i / 2
fmt.Fprintf(pattern, "%s(%s)", regexp.QuoteMeta(raw), patt)
varsR[varIdx], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
if err != nil {
return nil, err
}
}
// Add the remaining.
raw := tpl[end:]
// WriteString's error value is always nil for bytes.Buffer, no need to check it.
pattern.WriteString(regexp.QuoteMeta(raw))
// WriteByte's error value is always nil for bytes.Buffer, no need to check it.
pattern.WriteByte('$')
// Compile full regexp.
reg, errCompile := regexp.Compile(pattern.String())
if errCompile != nil {
return nil, errCompile
}
return reg, nil
}

View File

@@ -0,0 +1,20 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import "github.com/open-policy-agent/opa/ast"
func builtinOPARuntime(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
if bctx.Runtime == nil {
return iter(ast.ObjectTerm())
}
return iter(bctx.Runtime)
}
func init() {
RegisterBuiltinFunc(ast.OPARuntime.Name, builtinOPARuntime)
}

382
vendor/github.com/open-policy-agent/opa/topdown/save.go generated vendored Normal file
View File

@@ -0,0 +1,382 @@
package topdown
import (
"container/list"
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
)
// saveSet contains a stack of terms that are considered 'unknown' during
// partial evaluation. Only var and ref terms (rooted at one of the root
// documents) can be added to the save set. Vars added to the save set are
// namespaced by the binding list they are added with. This means the save set
// can be shared across queries.
type saveSet struct {
instr *Instrumentation
l *list.List
}
func newSaveSet(ts []*ast.Term, b *bindings, instr *Instrumentation) *saveSet {
ss := &saveSet{
l: list.New(),
instr: instr,
}
ss.Push(ts, b)
return ss
}
func (ss *saveSet) Push(ts []*ast.Term, b *bindings) {
ss.l.PushBack(newSaveSetElem(ts, b))
}
func (ss *saveSet) Pop() {
ss.l.Remove(ss.l.Back())
}
// Contains returns true if the term t is contained in the save set. Non-var and
// non-ref terms are never contained. Ref terms are contained if they share a
// prefix with a ref that was added (in either direction).
func (ss *saveSet) Contains(t *ast.Term, b *bindings) bool {
if ss != nil {
ss.instr.startTimer(partialOpSaveSetContains)
ret := ss.contains(t, b)
ss.instr.stopTimer(partialOpSaveSetContains)
return ret
}
return false
}
func (ss *saveSet) contains(t *ast.Term, b *bindings) bool {
for el := ss.l.Back(); el != nil; el = el.Prev() {
if el.Value.(*saveSetElem).Contains(t, b) {
return true
}
}
return false
}
// ContainsRecursive retruns true if the term t is or contains a term that is
// contained in the save set. This function will close over the binding list
// when it encounters vars.
func (ss *saveSet) ContainsRecursive(t *ast.Term, b *bindings) bool {
if ss != nil {
ss.instr.startTimer(partialOpSaveSetContainsRec)
ret := ss.containsrec(t, b)
ss.instr.stopTimer(partialOpSaveSetContainsRec)
return ret
}
return false
}
func (ss *saveSet) containsrec(t *ast.Term, b *bindings) bool {
var found bool
ast.WalkTerms(t, func(x *ast.Term) bool {
if _, ok := x.Value.(ast.Var); ok {
x1, b1 := b.apply(x)
if x1 != x || b1 != b {
if ss.containsrec(x1, b1) {
found = true
}
} else if ss.contains(x1, b1) {
found = true
}
}
return found
})
return found
}
func (ss *saveSet) Vars(caller *bindings) ast.VarSet {
result := ast.NewVarSet()
for x := ss.l.Front(); x != nil; x = x.Next() {
elem := x.Value.(*saveSetElem)
for _, v := range elem.vars {
if v, ok := elem.b.PlugNamespaced(v, caller).Value.(ast.Var); ok {
result.Add(v)
}
}
}
return result
}
func (ss *saveSet) String() string {
var buf []string
for x := ss.l.Front(); x != nil; x = x.Next() {
buf = append(buf, x.Value.(*saveSetElem).String())
}
return "(" + strings.Join(buf, " ") + ")"
}
type saveSetElem struct {
refs []ast.Ref
vars []*ast.Term
b *bindings
}
func newSaveSetElem(ts []*ast.Term, b *bindings) *saveSetElem {
var refs []ast.Ref
var vars []*ast.Term
for _, t := range ts {
switch v := t.Value.(type) {
case ast.Var:
vars = append(vars, t)
case ast.Ref:
refs = append(refs, v)
default:
panic("illegal value")
}
}
return &saveSetElem{
b: b,
vars: vars,
refs: refs,
}
}
func (sse *saveSetElem) Contains(t *ast.Term, b *bindings) bool {
switch other := t.Value.(type) {
case ast.Var:
return sse.containsVar(t, b)
case ast.Ref:
for _, ref := range sse.refs {
if ref.HasPrefix(other) || other.HasPrefix(ref) {
return true
}
}
return sse.containsVar(other[0], b)
}
return false
}
func (sse *saveSetElem) String() string {
return fmt.Sprintf("(refs: %v, vars: %v, b: %v)", sse.refs, sse.vars, sse.b)
}
func (sse *saveSetElem) containsVar(t *ast.Term, b *bindings) bool {
if b == sse.b {
for _, v := range sse.vars {
if v.Equal(t) {
return true
}
}
}
return false
}
// saveStack contains a stack of queries that represent the result of partial
// evaluation. When partial evaluation completes, the top of the stack
// represents a complete, partially evaluated query that can be saved and
// evaluated later.
//
// The result is stored in a stack so that partial evaluation of a query can be
// paused and then resumed in cases where different queries make up the result
// of partial evaluation, such as when a rule with a default clause is
// partially evaluated. In this case, the partially evaluated rule will be
// output in the support module.
type saveStack struct {
Stack []saveStackQuery
}
func newSaveStack() *saveStack {
return &saveStack{
Stack: []saveStackQuery{
{},
},
}
}
func (s *saveStack) PushQuery(query saveStackQuery) {
s.Stack = append(s.Stack, query)
}
func (s *saveStack) PopQuery() saveStackQuery {
last := s.Stack[len(s.Stack)-1]
s.Stack = s.Stack[:len(s.Stack)-1]
return last
}
func (s *saveStack) Peek() saveStackQuery {
return s.Stack[len(s.Stack)-1]
}
func (s *saveStack) Push(expr *ast.Expr, b1 *bindings, b2 *bindings) {
idx := len(s.Stack) - 1
s.Stack[idx] = append(s.Stack[idx], saveStackElem{expr, b1, b2})
}
func (s *saveStack) Pop() {
idx := len(s.Stack) - 1
query := s.Stack[idx]
s.Stack[idx] = query[:len(query)-1]
}
type saveStackQuery []saveStackElem
func (s saveStackQuery) Plug(b *bindings) ast.Body {
if len(s) == 0 {
return ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))
}
result := make(ast.Body, len(s))
for i := range s {
expr := s[i].Plug(b)
result.Set(expr, i)
}
return result
}
type saveStackElem struct {
Expr *ast.Expr
B1 *bindings
B2 *bindings
}
func (e saveStackElem) Plug(caller *bindings) *ast.Expr {
if e.B1 == nil && e.B2 == nil {
return e.Expr
}
expr := e.Expr.Copy()
switch terms := expr.Terms.(type) {
case []*ast.Term:
if expr.IsEquality() {
terms[1] = e.B1.PlugNamespaced(terms[1], caller)
terms[2] = e.B2.PlugNamespaced(terms[2], caller)
} else {
for i := 1; i < len(terms); i++ {
terms[i] = e.B1.PlugNamespaced(terms[i], caller)
}
}
case *ast.Term:
expr.Terms = e.B1.PlugNamespaced(terms, caller)
}
for i := range expr.With {
expr.With[i].Value = e.B1.PlugNamespaced(expr.With[i].Value, caller)
}
return expr
}
// saveSupport contains additional partially evaluated policies that are part
// of the output of partial evaluation.
//
// The support structure is accumulated as partial evaluation runs and then
// considered complete once partial evaluation finishes (but not before). This
// differs from partially evaluated queries which are considered complete as
// soon as each one finishes.
type saveSupport struct {
modules map[string]*ast.Module
}
func newSaveSupport() *saveSupport {
return &saveSupport{
modules: map[string]*ast.Module{},
}
}
func (s *saveSupport) List() []*ast.Module {
result := []*ast.Module{}
for _, module := range s.modules {
result = append(result, module)
}
return result
}
func (s *saveSupport) Exists(path ast.Ref) bool {
k := path[:len(path)-1].String()
module, ok := s.modules[k]
if !ok {
return false
}
name := ast.Var(path[len(path)-1].Value.(ast.String))
for _, rule := range module.Rules {
if rule.Head.Name.Equal(name) {
return true
}
}
return false
}
func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
pkg := path[:len(path)-1]
k := pkg.String()
module, ok := s.modules[k]
if !ok {
module = &ast.Module{
Package: &ast.Package{
Path: pkg,
},
}
s.modules[k] = module
}
rule.Module = module
module.Rules = append(module.Rules, rule)
}
// saveRequired returns true if the statement x will result in some expressions
// being saved. This check allows the evaluator to evaluate statements
// completely during partial evaluation as long as they do not depend on any
// kind of unknown value or statements that would generate saves.
func saveRequired(c *ast.Compiler, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
var found bool
vis := ast.NewGenericVisitor(func(node interface{}) bool {
if found {
return found
}
switch node := node.(type) {
case *ast.Expr:
found = len(node.With) > 0 || ignoreExprDuringPartial(node)
case *ast.Term:
switch v := node.Value.(type) {
case ast.Var:
// Variables only need to be tested in the node from call site
// because once traversal recurses into a rule existing unknown
// variables are out-of-scope.
if !rec && ss.ContainsRecursive(node, b) {
found = true
}
case ast.Ref:
if ss.Contains(node, b) {
found = true
} else {
for _, rule := range c.GetRulesDynamic(v) {
if saveRequired(c, ss, b, rule, true) {
found = true
break
}
}
}
}
}
return found
})
vis.Walk(x)
return found
}
func ignoreExprDuringPartial(expr *ast.Expr) bool {
if !expr.IsCall() {
return false
}
bi, ok := ast.BuiltinMap[expr.Operator().String()]
return ok && ignoreDuringPartial(bi)
}
func ignoreDuringPartial(bi *ast.Builtin) bool {
for _, ignore := range ast.IgnoreDuringPartialEval {
if bi == ignore {
return true
}
}
return false
}

View File

@@ -0,0 +1,84 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Deprecated in v0.4.2 in favour of minus/infix "-" operation.
func builtinSetDiff(a, b ast.Value) (ast.Value, error) {
s1, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
s2, err := builtins.SetOperand(b, 2)
if err != nil {
return nil, err
}
return s1.Diff(s2), nil
}
// builtinSetIntersection returns the intersection of the given input sets
func builtinSetIntersection(a ast.Value) (ast.Value, error) {
inputSet, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
// empty input set
if inputSet.Len() == 0 {
return ast.NewSet(), nil
}
var result ast.Set
err = inputSet.Iter(func(x *ast.Term) error {
n, err := builtins.SetOperand(x.Value, 1)
if err != nil {
return err
}
if result == nil {
result = n
} else {
result = result.Intersect(n)
}
return nil
})
return result, err
}
// builtinSetUnion returns the union of the given input sets
func builtinSetUnion(a ast.Value) (ast.Value, error) {
inputSet, err := builtins.SetOperand(a, 1)
if err != nil {
return nil, err
}
result := ast.NewSet()
err = inputSet.Iter(func(x *ast.Term) error {
n, err := builtins.SetOperand(x.Value, 1)
if err != nil {
return err
}
result = result.Union(n)
return nil
})
return result, err
}
func init() {
RegisterFunctionalBuiltin2(ast.SetDiff.Name, builtinSetDiff)
RegisterFunctionalBuiltin1(ast.Intersection.Name, builtinSetIntersection)
RegisterFunctionalBuiltin1(ast.Union.Name, builtinSetUnion)
}

View File

@@ -0,0 +1,393 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"errors"
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
func builtinFormatInt(a, b ast.Value) (ast.Value, error) {
input, err := builtins.NumberOperand(a, 1)
if err != nil {
return nil, err
}
base, err := builtins.NumberOperand(b, 2)
if err != nil {
return nil, err
}
var format string
switch base {
case ast.Number("2"):
format = "%b"
case ast.Number("8"):
format = "%o"
case ast.Number("10"):
format = "%d"
case ast.Number("16"):
format = "%x"
default:
return nil, builtins.NewOperandEnumErr(2, "2", "8", "10", "16")
}
f := builtins.NumberToFloat(input)
i, _ := f.Int(nil)
return ast.String(fmt.Sprintf(format, i)), nil
}
func builtinConcat(a, b ast.Value) (ast.Value, error) {
join, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
strs := []string{}
switch b := b.(type) {
case ast.Array:
for i := range b {
s, ok := b[i].Value.(ast.String)
if !ok {
return nil, builtins.NewOperandElementErr(2, b, b[i].Value, "string")
}
strs = append(strs, string(s))
}
case ast.Set:
err := b.Iter(func(x *ast.Term) error {
s, ok := x.Value.(ast.String)
if !ok {
return builtins.NewOperandElementErr(2, b, x.Value, "string")
}
strs = append(strs, string(s))
return nil
})
if err != nil {
return nil, err
}
default:
return nil, builtins.NewOperandTypeErr(2, b, "set", "array")
}
return ast.String(strings.Join(strs, string(join))), nil
}
func builtinIndexOf(a, b ast.Value) (ast.Value, error) {
base, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
search, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
index := strings.Index(string(base), string(search))
return ast.IntNumberTerm(index).Value, nil
}
func builtinSubstring(a, b, c ast.Value) (ast.Value, error) {
base, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
startIndex, err := builtins.IntOperand(b, 2)
if err != nil {
return nil, err
} else if startIndex >= len(base) {
return ast.String(""), nil
} else if startIndex < 0 {
return nil, fmt.Errorf("negative offset")
}
length, err := builtins.IntOperand(c, 3)
if err != nil {
return nil, err
}
var s ast.String
if length < 0 {
s = ast.String(base[startIndex:])
} else {
upto := startIndex + length
if len(base) < upto {
upto = len(base)
}
s = ast.String(base[startIndex:upto])
}
return s, nil
}
func builtinContains(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
substr, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.Boolean(strings.Contains(string(s), string(substr))), nil
}
func builtinStartsWith(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
prefix, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.Boolean(strings.HasPrefix(string(s), string(prefix))), nil
}
func builtinEndsWith(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
suffix, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.Boolean(strings.HasSuffix(string(s), string(suffix))), nil
}
func builtinLower(a ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(strings.ToLower(string(s))), nil
}
func builtinUpper(a ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(strings.ToUpper(string(s))), nil
}
func builtinSplit(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
d, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
elems := strings.Split(string(s), string(d))
arr := make(ast.Array, len(elems))
for i := range arr {
arr[i] = ast.StringTerm(elems[i])
}
return arr, nil
}
func builtinReplace(a, b, c ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
old, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
new, err := builtins.StringOperand(c, 3)
if err != nil {
return nil, err
}
return ast.String(strings.Replace(string(s), string(old), string(new), -1)), nil
}
func builtinReplaceN(a, b ast.Value) (ast.Value, error) {
asJSON, err := ast.JSON(a)
if err != nil {
return nil, err
}
oldnewObj, ok := asJSON.(map[string]interface{})
if !ok {
return nil, builtins.NewOperandTypeErr(1, a, "object")
}
s, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
var oldnewArr []string
for k, v := range oldnewObj {
strVal, ok := v.(string)
if !ok {
return nil, errors.New("non-string value found in pattern object")
}
oldnewArr = append(oldnewArr, k, strVal)
}
r := strings.NewReplacer(oldnewArr...)
replaced := r.Replace(string(s))
return ast.String(replaced), nil
}
func builtinTrim(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
c, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.String(strings.Trim(string(s), string(c))), nil
}
func builtinTrimLeft(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
c, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.String(strings.TrimLeft(string(s), string(c))), nil
}
func builtinTrimPrefix(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
pre, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.String(strings.TrimPrefix(string(s), string(pre))), nil
}
func builtinTrimRight(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
c, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.String(strings.TrimRight(string(s), string(c))), nil
}
func builtinTrimSuffix(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
suf, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
return ast.String(strings.TrimSuffix(string(s), string(suf))), nil
}
func builtinTrimSpace(a ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
return ast.String(strings.TrimSpace(string(s))), nil
}
func builtinSprintf(a, b ast.Value) (ast.Value, error) {
s, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
astArr, ok := b.(ast.Array)
if !ok {
return nil, builtins.NewOperandTypeErr(2, b, "array")
}
args := make([]interface{}, len(astArr))
for i := range astArr {
switch v := astArr[i].Value.(type) {
case ast.Number:
if n, ok := v.Int(); ok {
args[i] = n
} else if f, ok := v.Float64(); ok {
args[i] = f
} else {
args[i] = v.String()
}
case ast.String:
args[i] = string(v)
default:
args[i] = astArr[i].String()
}
}
return ast.String(fmt.Sprintf(string(s), args...)), nil
}
func init() {
RegisterFunctionalBuiltin2(ast.FormatInt.Name, builtinFormatInt)
RegisterFunctionalBuiltin2(ast.Concat.Name, builtinConcat)
RegisterFunctionalBuiltin2(ast.IndexOf.Name, builtinIndexOf)
RegisterFunctionalBuiltin3(ast.Substring.Name, builtinSubstring)
RegisterFunctionalBuiltin2(ast.Contains.Name, builtinContains)
RegisterFunctionalBuiltin2(ast.StartsWith.Name, builtinStartsWith)
RegisterFunctionalBuiltin2(ast.EndsWith.Name, builtinEndsWith)
RegisterFunctionalBuiltin1(ast.Upper.Name, builtinUpper)
RegisterFunctionalBuiltin1(ast.Lower.Name, builtinLower)
RegisterFunctionalBuiltin2(ast.Split.Name, builtinSplit)
RegisterFunctionalBuiltin3(ast.Replace.Name, builtinReplace)
RegisterFunctionalBuiltin2(ast.ReplaceN.Name, builtinReplaceN)
RegisterFunctionalBuiltin2(ast.Trim.Name, builtinTrim)
RegisterFunctionalBuiltin2(ast.TrimLeft.Name, builtinTrimLeft)
RegisterFunctionalBuiltin2(ast.TrimPrefix.Name, builtinTrimPrefix)
RegisterFunctionalBuiltin2(ast.TrimRight.Name, builtinTrimRight)
RegisterFunctionalBuiltin2(ast.TrimSuffix.Name, builtinTrimSuffix)
RegisterFunctionalBuiltin1(ast.TrimSpace.Name, builtinTrimSpace)
RegisterFunctionalBuiltin2(ast.Sprintf.Name, builtinSprintf)
}

203
vendor/github.com/open-policy-agent/opa/topdown/time.go generated vendored Normal file
View File

@@ -0,0 +1,203 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"encoding/json"
"fmt"
"math/big"
"strconv"
"sync"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type nowKeyID string
var nowKey = nowKeyID("time.now_ns")
var tzCache map[string]*time.Location
var tzCacheMutex *sync.Mutex
func builtinTimeNowNanos(bctx BuiltinContext, _ []*ast.Term, iter func(*ast.Term) error) error {
exist, ok := bctx.Cache.Get(nowKey)
var now *ast.Term
if !ok {
curr := time.Now()
now = ast.NewTerm(ast.Number(int64ToJSONNumber(curr.UnixNano())))
bctx.Cache.Put(nowKey, now)
} else {
now = exist.(*ast.Term)
}
return iter(now)
}
func builtinTimeParseNanos(a, b ast.Value) (ast.Value, error) {
format, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
value, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
result, err := time.Parse(string(format), string(value))
if err != nil {
return nil, err
}
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
}
func builtinTimeParseRFC3339Nanos(a ast.Value) (ast.Value, error) {
value, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
result, err := time.Parse(time.RFC3339, string(value))
if err != nil {
return nil, err
}
return ast.Number(int64ToJSONNumber(result.UnixNano())), nil
}
func builtinParseDurationNanos(a ast.Value) (ast.Value, error) {
duration, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
value, err := time.ParseDuration(string(duration))
if err != nil {
return nil, err
}
return ast.Number(int64ToJSONNumber(int64(value))), nil
}
func builtinDate(a ast.Value) (ast.Value, error) {
t, err := tzTime(a)
if err != nil {
return nil, err
}
year, month, day := t.Date()
result := ast.Array{ast.IntNumberTerm(year), ast.IntNumberTerm(int(month)), ast.IntNumberTerm(day)}
return result, nil
}
func builtinClock(a ast.Value) (ast.Value, error) {
t, err := tzTime(a)
if err != nil {
return nil, err
}
hour, minute, second := t.Clock()
result := ast.Array{ast.IntNumberTerm(hour), ast.IntNumberTerm(minute), ast.IntNumberTerm(second)}
return result, nil
}
func builtinWeekday(a ast.Value) (ast.Value, error) {
t, err := tzTime(a)
if err != nil {
return nil, err
}
weekday := t.Weekday().String()
return ast.String(weekday), nil
}
func tzTime(a ast.Value) (t time.Time, err error) {
var nVal ast.Value
loc := time.UTC
switch va := a.(type) {
case ast.Array:
if len(va) == 0 {
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
}
nVal, err = builtins.NumberOperand(va[0].Value, 1)
if err != nil {
return time.Time{}, err
}
if len(va) > 1 {
tzVal, err := builtins.StringOperand(va[1].Value, 1)
if err != nil {
return time.Time{}, err
}
tzName := string(tzVal)
switch tzName {
case "", "UTC":
// loc is already UTC
case "Local":
loc = time.Local
default:
var ok bool
tzCacheMutex.Lock()
loc, ok = tzCache[tzName]
if !ok {
loc, err = time.LoadLocation(tzName)
if err != nil {
tzCacheMutex.Unlock()
return time.Time{}, err
}
tzCache[tzName] = loc
}
tzCacheMutex.Unlock()
}
}
case ast.Number:
nVal = a
default:
return time.Time{}, builtins.NewOperandTypeErr(1, a, "either number (ns) or [number (ns), string (tz)]")
}
value, err := builtins.NumberOperand(nVal, 1)
if err != nil {
return time.Time{}, err
}
f := builtins.NumberToFloat(value)
i64, acc := f.Int64()
if acc != big.Exact {
return time.Time{}, fmt.Errorf("timestamp too big")
}
t = time.Unix(0, i64).In(loc)
return t, nil
}
func int64ToJSONNumber(i int64) json.Number {
return json.Number(strconv.FormatInt(i, 10))
}
func init() {
RegisterBuiltinFunc(ast.NowNanos.Name, builtinTimeNowNanos)
RegisterFunctionalBuiltin1(ast.ParseRFC3339Nanos.Name, builtinTimeParseRFC3339Nanos)
RegisterFunctionalBuiltin2(ast.ParseNanos.Name, builtinTimeParseNanos)
RegisterFunctionalBuiltin1(ast.ParseDurationNanos.Name, builtinParseDurationNanos)
RegisterFunctionalBuiltin1(ast.Date.Name, builtinDate)
RegisterFunctionalBuiltin1(ast.Clock.Name, builtinClock)
RegisterFunctionalBuiltin1(ast.Weekday.Name, builtinWeekday)
tzCacheMutex = &sync.Mutex{}
tzCache = make(map[string]*time.Location)
}

View File

@@ -0,0 +1,967 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"crypto"
"crypto/ecdsa"
"crypto/hmac"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jwk"
"github.com/open-policy-agent/opa/topdown/internal/jwx/jws"
)
var (
jwtEncKey = ast.StringTerm("enc")
jwtCtyKey = ast.StringTerm("cty")
jwtAlgKey = ast.StringTerm("alg")
jwtIssKey = ast.StringTerm("iss")
jwtExpKey = ast.StringTerm("exp")
jwtNbfKey = ast.StringTerm("nbf")
jwtAudKey = ast.StringTerm("aud")
)
// JSONWebToken represent the 3 parts (header, payload & signature) of
// a JWT in Base64.
type JSONWebToken struct {
header string
payload string
signature string
decodedHeader ast.Object
}
// decodeHeader populates the decodedHeader field.
func (token *JSONWebToken) decodeHeader() (err error) {
var h ast.Value
if h, err = builtinBase64UrlDecode(ast.String(token.header)); err != nil {
return fmt.Errorf("JWT header had invalid encoding: %v", err)
}
if token.decodedHeader, err = validateJWTHeader(string(h.(ast.String))); err != nil {
return err
}
return
}
// Implements JWT decoding/validation based on RFC 7519 Section 7.2:
// https://tools.ietf.org/html/rfc7519#section-7.2
// It does no data validation, it merely checks that the given string
// represents a structurally valid JWT. It supports JWTs using JWS compact
// serialization.
func builtinJWTDecode(a ast.Value) (ast.Value, error) {
token, err := decodeJWT(a)
if err != nil {
return nil, err
}
if err = token.decodeHeader(); err != nil {
return nil, err
}
p, err := builtinBase64UrlDecode(ast.String(token.payload))
if err != nil {
return nil, fmt.Errorf("JWT payload had invalid encoding: %v", err)
}
if cty := token.decodedHeader.Get(jwtCtyKey); cty != nil {
ctyVal := string(cty.Value.(ast.String))
// It is possible for the contents of a token to be another
// token as a result of nested signing or encryption. To handle
// the case where we are given a token such as this, we check
// the content type and recurse on the payload if the content
// is "JWT".
// When the payload is itself another encoded JWT, then its
// contents are quoted (behavior of https://jwt.io/). To fix
// this, remove leading and trailing quotes.
if ctyVal == "JWT" {
p, err = builtinTrim(p, ast.String(`"'`))
if err != nil {
panic("not reached")
}
return builtinJWTDecode(p)
}
}
payload, err := extractJSONObject(string(p.(ast.String)))
if err != nil {
return nil, err
}
s, err := builtinBase64UrlDecode(ast.String(token.signature))
if err != nil {
return nil, fmt.Errorf("JWT signature had invalid encoding: %v", err)
}
sign := hex.EncodeToString([]byte(s.(ast.String)))
arr := make(ast.Array, 3)
arr[0] = ast.NewTerm(token.decodedHeader)
arr[1] = ast.NewTerm(payload)
arr[2] = ast.StringTerm(sign)
return arr, nil
}
// Implements RS256 JWT signature verification
func builtinJWTVerifyRS256(a ast.Value, b ast.Value) (ast.Value, error) {
return builtinJWTVerifyRSA(a, b, func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error {
return rsa.VerifyPKCS1v15(
publicKey,
crypto.SHA256,
digest,
signature)
})
}
// Implements PS256 JWT signature verification
func builtinJWTVerifyPS256(a ast.Value, b ast.Value) (ast.Value, error) {
return builtinJWTVerifyRSA(a, b, func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error {
return rsa.VerifyPSS(
publicKey,
crypto.SHA256,
digest,
signature,
nil)
})
}
// Implements RSA JWT signature verification.
func builtinJWTVerifyRSA(a ast.Value, b ast.Value, verify func(publicKey *rsa.PublicKey, digest []byte, signature []byte) error) (ast.Value, error) {
return builtinJWTVerify(a, b, func(publicKey interface{}, digest []byte, signature []byte) error {
publicKeyRsa, ok := publicKey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("incorrect public key type")
}
return verify(publicKeyRsa, digest, signature)
})
}
// Implements ES256 JWT signature verification.
func builtinJWTVerifyES256(a ast.Value, b ast.Value) (ast.Value, error) {
return builtinJWTVerify(a, b, func(publicKey interface{}, digest []byte, signature []byte) error {
publicKeyEcdsa, ok := publicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("incorrect public key type")
}
r, s := &big.Int{}, &big.Int{}
n := len(signature) / 2
r.SetBytes(signature[:n])
s.SetBytes(signature[n:])
if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
return nil
}
return fmt.Errorf("ECDSA signature verification error")
})
}
// getKeyFromCertOrJWK returns the public key found in a X.509 certificate or JWK key(s).
// A valid PEM block is never valid JSON (and vice versa), hence can try parsing both.
func getKeyFromCertOrJWK(certificate string) ([]interface{}, error) {
if block, rest := pem.Decode([]byte(certificate)); block != nil {
if len(rest) > 0 {
return nil, fmt.Errorf("extra data after a PEM certificate block")
}
if block.Type == "CERTIFICATE" {
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to parse a PEM certificate")
}
return []interface{}{cert.PublicKey}, nil
}
if block.Type == "PUBLIC KEY" {
key, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(err, "failed to parse a PEM public key")
}
return []interface{}{key}, nil
}
return nil, fmt.Errorf("failed to extract a Key from the PEM certificate")
}
jwks, err := jwk.ParseString(certificate)
if err != nil {
return nil, errors.Wrap(err, "failed to parse a JWK key (set)")
}
var keys []interface{}
for _, k := range jwks.Keys {
key, err := k.Materialize()
if err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, nil
}
// Implements JWT signature verification.
func builtinJWTVerify(a ast.Value, b ast.Value, verify func(publicKey interface{}, digest []byte, signature []byte) error) (ast.Value, error) {
token, err := decodeJWT(a)
if err != nil {
return nil, err
}
s, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
keys, err := getKeyFromCertOrJWK(string(s))
if err != nil {
return nil, err
}
signature, err := token.decodeSignature()
if err != nil {
return nil, err
}
// Validate the JWT signature
for _, key := range keys {
err = verify(key,
getInputSHA([]byte(token.header+"."+token.payload)),
[]byte(signature))
if err == nil {
return ast.Boolean(true), nil
}
}
// None of the keys worked, return false
return ast.Boolean(false), nil
}
// Implements HS256 (secret) JWT signature verification
func builtinJWTVerifyHS256(a ast.Value, b ast.Value) (ast.Value, error) {
// Decode the JSON Web Token
token, err := decodeJWT(a)
if err != nil {
return nil, err
}
// Process Secret input
astSecret, err := builtins.StringOperand(b, 2)
if err != nil {
return nil, err
}
secret := string(astSecret)
mac := hmac.New(sha256.New, []byte(secret))
_, err = mac.Write([]byte(token.header + "." + token.payload))
if err != nil {
return nil, err
}
signature, err := token.decodeSignature()
if err != nil {
return nil, err
}
return ast.Boolean(hmac.Equal([]byte(signature), mac.Sum(nil))), nil
}
// -- Full JWT verification and decoding --
// Verification constraints. See tokens_test.go for unit tests.
// tokenConstraints holds decoded JWT verification constraints.
type tokenConstraints struct {
// The set of asymmetric keys we can verify with.
keys []interface{}
// The single symmetric key we will verify with.
secret string
// The algorithm that must be used to verify.
// If "", any algorithm is acceptable.
alg string
// The required issuer.
// If "", any issuer is acceptable.
iss string
// The required audience.
// If "", no audience is acceptable.
aud string
// The time to validate against, or -1 if no constraint set.
// (If unset, the current time will be used.)
time int64
}
// tokenConstraintHandler is the handler type for JWT verification constraints.
type tokenConstraintHandler func(value ast.Value, parameters *tokenConstraints) (err error)
// tokenConstraintTypes maps known JWT verification constraints to handlers.
var tokenConstraintTypes = map[string]tokenConstraintHandler{
"cert": tokenConstraintCert,
"secret": func(value ast.Value, constraints *tokenConstraints) (err error) {
return tokenConstraintString("secret", value, &constraints.secret)
},
"alg": func(value ast.Value, constraints *tokenConstraints) (err error) {
return tokenConstraintString("alg", value, &constraints.alg)
},
"iss": func(value ast.Value, constraints *tokenConstraints) (err error) {
return tokenConstraintString("iss", value, &constraints.iss)
},
"aud": func(value ast.Value, constraints *tokenConstraints) (err error) {
return tokenConstraintString("aud", value, &constraints.aud)
},
"time": tokenConstraintTime,
}
// tokenConstraintCert handles the `cert` constraint.
func tokenConstraintCert(value ast.Value, constraints *tokenConstraints) (err error) {
var s ast.String
var ok bool
if s, ok = value.(ast.String); !ok {
return fmt.Errorf("cert constraint: must be a string")
}
constraints.keys, err = getKeyFromCertOrJWK(string(s))
return
}
// tokenConstraintTime handles the `time` constraint.
func tokenConstraintTime(value ast.Value, constraints *tokenConstraints) (err error) {
var time ast.Number
var ok bool
if time, ok = value.(ast.Number); !ok {
err = fmt.Errorf("token time constraint: must be a number")
return
}
var timeFloat float64
if timeFloat, err = strconv.ParseFloat(string(time), 64); err != nil {
err = fmt.Errorf("token time constraint: %v", err)
return
}
if timeFloat < 0 {
err = fmt.Errorf("token time constraint: must not be negative")
return
}
constraints.time = int64(timeFloat)
return
}
// tokenConstraintString handles string constraints.
func tokenConstraintString(name string, value ast.Value, where *string) (err error) {
var av ast.String
var ok bool
if av, ok = value.(ast.String); !ok {
err = fmt.Errorf("%s constraint: must be a string", name)
return
}
*where = string(av)
return
}
// parseTokenConstraints parses the constraints argument.
func parseTokenConstraints(a ast.Value) (constraints tokenConstraints, err error) {
constraints.time = -1
var o ast.Object
var ok bool
if o, ok = a.(ast.Object); !ok {
err = fmt.Errorf("token constraints must be object")
return
}
if err = o.Iter(func(k *ast.Term, v *ast.Term) (err error) {
var handler tokenConstraintHandler
var ok bool
name := string(k.Value.(ast.String))
if handler, ok = tokenConstraintTypes[name]; ok {
if err = handler(v.Value, &constraints); err != nil {
return
}
} else {
// Anything unknown is rejected.
err = fmt.Errorf("unknown token validation constraint: %s", name)
return
}
return
}); err != nil {
return
}
return
}
// validate validates the constraints argument.
func (constraints *tokenConstraints) validate() (err error) {
keys := 0
if constraints.keys != nil {
keys++
}
if constraints.secret != "" {
keys++
}
if keys > 1 {
err = fmt.Errorf("duplicate key constraints")
return
}
if keys < 1 {
err = fmt.Errorf("no key constraint")
return
}
return
}
// verify verifies a JWT using the constraints and the algorithm from the header
func (constraints *tokenConstraints) verify(kid, alg, header, payload, signature string) error {
// Construct the payload
plaintext := []byte(header)
plaintext = append(plaintext, []byte(".")...)
plaintext = append(plaintext, payload...)
// Look up the algorithm
var ok bool
var a tokenAlgorithm
a, ok = tokenAlgorithms[alg]
if !ok {
return fmt.Errorf("unknown JWS algorithm: %s", alg)
}
// If we're configured with asymmetric key(s) then only trust that
if constraints.keys != nil {
verified := false
for _, key := range constraints.keys {
err := a.verify(key, a.hash, plaintext, []byte(signature))
if err == nil {
verified = true
break
}
}
if !verified {
return errSignatureNotVerified
}
return nil
}
if constraints.secret != "" {
return a.verify([]byte(constraints.secret), a.hash, plaintext, []byte(signature))
}
// (*tokenConstraints)validate() should prevent this happening
return errors.New("unexpectedly found no keys to trust")
}
// validAudience checks the audience of the JWT.
// It returns true if it meets the constraints and false otherwise.
func (constraints *tokenConstraints) validAudience(aud ast.Value) (valid bool) {
var ok bool
var s ast.String
if s, ok = aud.(ast.String); ok {
return string(s) == constraints.aud
}
var a ast.Array
if a, ok = aud.(ast.Array); ok {
for _, t := range a {
if s, ok = t.Value.(ast.String); ok {
if string(s) == constraints.aud {
return true
}
} else {
// Ill-formed aud claim
return false
}
}
}
return false
}
// JWT algorithms
type tokenVerifyFunction func(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error)
type tokenVerifyAsymmetricFunction func(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error)
// jwtAlgorithm describes a JWS 'alg' value
type tokenAlgorithm struct {
hash crypto.Hash
verify tokenVerifyFunction
}
// tokenAlgorithms is the known JWT algorithms
var tokenAlgorithms = map[string]tokenAlgorithm{
"RS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPKCS)},
"RS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPKCS)},
"RS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPKCS)},
"PS256": {crypto.SHA256, verifyAsymmetric(verifyRSAPSS)},
"PS384": {crypto.SHA384, verifyAsymmetric(verifyRSAPSS)},
"PS512": {crypto.SHA512, verifyAsymmetric(verifyRSAPSS)},
"ES256": {crypto.SHA256, verifyAsymmetric(verifyECDSA)},
"ES384": {crypto.SHA384, verifyAsymmetric(verifyECDSA)},
"ES512": {crypto.SHA512, verifyAsymmetric(verifyECDSA)},
"HS256": {crypto.SHA256, verifyHMAC},
"HS384": {crypto.SHA384, verifyHMAC},
"HS512": {crypto.SHA512, verifyHMAC},
}
// errSignatureNotVerified is returned when a signature cannot be verified.
var errSignatureNotVerified = errors.New("signature not verified")
func verifyHMAC(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error) {
macKey, ok := key.([]byte)
if !ok {
return fmt.Errorf("incorrect symmetric key type")
}
mac := hmac.New(hash.New, macKey)
if _, err = mac.Write([]byte(payload)); err != nil {
return
}
if !hmac.Equal(signature, mac.Sum([]byte{})) {
err = errSignatureNotVerified
}
return
}
func verifyAsymmetric(verify tokenVerifyAsymmetricFunction) tokenVerifyFunction {
return func(key interface{}, hash crypto.Hash, payload []byte, signature []byte) (err error) {
h := hash.New()
h.Write(payload)
return verify(key, hash, h.Sum([]byte{}), signature)
}
}
func verifyRSAPKCS(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
publicKeyRsa, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("incorrect public key type")
}
if err = rsa.VerifyPKCS1v15(publicKeyRsa, hash, digest, signature); err != nil {
err = errSignatureNotVerified
}
return
}
func verifyRSAPSS(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
publicKeyRsa, ok := key.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("incorrect public key type")
}
if err = rsa.VerifyPSS(publicKeyRsa, hash, digest, signature, nil); err != nil {
err = errSignatureNotVerified
}
return
}
func verifyECDSA(key interface{}, hash crypto.Hash, digest []byte, signature []byte) (err error) {
publicKeyEcdsa, ok := key.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("incorrect public key type")
}
r, s := &big.Int{}, &big.Int{}
n := len(signature) / 2
r.SetBytes(signature[:n])
s.SetBytes(signature[n:])
if ecdsa.Verify(publicKeyEcdsa, digest, r, s) {
return nil
}
return errSignatureNotVerified
}
// JWT header parsing and parameters. See tokens_test.go for unit tests.
// tokenHeaderType represents a recognized JWT header field
// tokenHeader is a parsed JWT header
type tokenHeader struct {
alg string
kid string
typ string
cty string
crit map[string]bool
unknown []string
}
// tokenHeaderHandler handles a JWT header parameters
type tokenHeaderHandler func(header *tokenHeader, value ast.Value) (err error)
// tokenHeaderTypes maps known JWT header parameters to handlers
var tokenHeaderTypes = map[string]tokenHeaderHandler{
"alg": func(header *tokenHeader, value ast.Value) (err error) {
return tokenHeaderString("alg", &header.alg, value)
},
"kid": func(header *tokenHeader, value ast.Value) (err error) {
return tokenHeaderString("kid", &header.kid, value)
},
"typ": func(header *tokenHeader, value ast.Value) (err error) {
return tokenHeaderString("typ", &header.typ, value)
},
"cty": func(header *tokenHeader, value ast.Value) (err error) {
return tokenHeaderString("cty", &header.cty, value)
},
"crit": tokenHeaderCrit,
}
// tokenHeaderCrit handles the 'crit' header parameter
func tokenHeaderCrit(header *tokenHeader, value ast.Value) (err error) {
var ok bool
var v ast.Array
if v, ok = value.(ast.Array); !ok {
err = fmt.Errorf("crit: must be a list")
return
}
header.crit = map[string]bool{}
for _, t := range v {
var tv ast.String
if tv, ok = t.Value.(ast.String); !ok {
err = fmt.Errorf("crit: must be a list of strings")
return
}
header.crit[string(tv)] = true
}
if len(header.crit) == 0 {
err = fmt.Errorf("crit: must be a nonempty list") // 'MUST NOT' use the empty list
return
}
return
}
// tokenHeaderString handles string-format JWT header parameters
func tokenHeaderString(name string, where *string, value ast.Value) (err error) {
var ok bool
var v ast.String
if v, ok = value.(ast.String); !ok {
err = fmt.Errorf("%s: must be a string", name)
return
}
*where = string(v)
return
}
// parseTokenHeader parses the JWT header.
func parseTokenHeader(token *JSONWebToken) (header tokenHeader, err error) {
header.unknown = []string{}
if err = token.decodedHeader.Iter(func(k *ast.Term, v *ast.Term) (err error) {
ks := string(k.Value.(ast.String))
var ok bool
var handler tokenHeaderHandler
if handler, ok = tokenHeaderTypes[ks]; ok {
if err = handler(&header, v.Value); err != nil {
return
}
} else {
header.unknown = append(header.unknown, ks)
}
return
}); err != nil {
return
}
return
}
// validTokenHeader returns true if the JOSE header is valid, otherwise false.
func (header *tokenHeader) valid() bool {
// RFC7515 s4.1.1 alg MUST be present
if header.alg == "" {
return false
}
// RFC7515 4.1.11 JWS is invalid if there is a critical parameter that we did not recognize
for _, u := range header.unknown {
if header.crit[u] {
return false
}
}
return true
}
func commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc string) (v ast.Value, err error) {
keys, err := jwk.ParseString(jwkSrc)
if err != nil {
return nil, err
}
key, err := keys.Keys[0].Materialize()
if err != nil {
return nil, err
}
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
return nil, fmt.Errorf("JWK derived key type and keyType parameter do not match")
}
standardHeaders := &jws.StandardHeaders{}
jwsHeaders := []byte(inputHeaders)
err = json.Unmarshal(jwsHeaders, standardHeaders)
if err != nil {
return nil, err
}
alg := standardHeaders.GetAlgorithm()
if (standardHeaders.Type == "" || standardHeaders.Type == "JWT") && !json.Valid([]byte(jwsPayload)) {
return nil, fmt.Errorf("type is JWT but payload is not JSON")
}
// process payload and sign
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders)
if err != nil {
return nil, err
}
return ast.String(jwsCompact[:]), nil
}
func builtinJWTEncodeSign(a ast.Value, b ast.Value, c ast.Value) (v ast.Value, err error) {
jwkSrc := c.String()
inputHeaders := a.String()
jwsPayload := b.String()
return commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc)
}
func builtinJWTEncodeSignRaw(a ast.Value, b ast.Value, c ast.Value) (v ast.Value, err error) {
jwkSrc, err := builtins.StringOperand(c, 1)
if err != nil {
return nil, err
}
inputHeaders, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
jwsPayload, err := builtins.StringOperand(b, 1)
if err != nil {
return nil, err
}
return commonBuiltinJWTEncodeSign(string(inputHeaders), string(jwsPayload), string(jwkSrc))
}
// Implements full JWT decoding, validation and verification.
func builtinJWTDecodeVerify(a ast.Value, b ast.Value) (v ast.Value, err error) {
// io.jwt.decode_verify(string, constraints, [valid, header, payload])
//
// If valid is true then the signature verifies and all constraints are met.
// If valid is false then either the signature did not verify or some constrain
// was not met.
//
// Decoding errors etc are returned as errors.
arr := make(ast.Array, 3)
arr[0] = ast.BooleanTerm(false) // by default, not verified
arr[1] = ast.NewTerm(ast.NewObject())
arr[2] = ast.NewTerm(ast.NewObject())
var constraints tokenConstraints
if constraints, err = parseTokenConstraints(b); err != nil {
return
}
if err = constraints.validate(); err != nil {
return
}
var token *JSONWebToken
var p ast.Value
for {
// RFC7519 7.2 #1-2 split into parts
if token, err = decodeJWT(a); err != nil {
return
}
// RFC7519 7.2 #3, #4, #6
if err = token.decodeHeader(); err != nil {
return
}
// RFC7159 7.2 #5 (and RFC7159 5.2 #5) validate header fields
var header tokenHeader
if header, err = parseTokenHeader(token); err != nil {
return
}
if !header.valid() {
return arr, nil
}
// Check constraints that impact signature verification.
if constraints.alg != "" && constraints.alg != header.alg {
return arr, nil
}
// RFC7159 7.2 #7 verify the signature
var signature string
if signature, err = token.decodeSignature(); err != nil {
return
}
if err = constraints.verify(header.kid, header.alg, token.header, token.payload, signature); err != nil {
if err == errSignatureNotVerified {
return arr, nil
}
return
}
// RFC7159 7.2 #9-10 decode the payload
if p, err = builtinBase64UrlDecode(ast.String(token.payload)); err != nil {
return nil, fmt.Errorf("JWT payload had invalid encoding: %v", err)
}
// RFC7159 7.2 #8 and 5.2 cty
if strings.ToUpper(header.cty) == "JWT" {
// Nested JWT, go round again
a = p
continue
} else {
// Non-nested JWT (or we've reached the bottom of the nesting).
break
}
}
var payload ast.Object
if payload, err = extractJSONObject(string(p.(ast.String))); err != nil {
return
}
// Check registered claim names against constraints or environment
// RFC7159 4.1.1 iss
if constraints.iss != "" {
if iss := payload.Get(jwtIssKey); iss != nil {
issVal := string(iss.Value.(ast.String))
if constraints.iss != issVal {
return arr, nil
}
}
}
// RFC7159 4.1.3 aud
if aud := payload.Get(jwtAudKey); aud != nil {
if !constraints.validAudience(aud.Value) {
return arr, nil
}
} else {
if constraints.aud != "" {
return arr, nil
}
}
// RFC7159 4.1.4 exp
if exp := payload.Get(jwtExpKey); exp != nil {
var expVal int64
if expVal, err = strconv.ParseInt(string(exp.Value.(ast.Number)), 10, 64); err != nil {
err = fmt.Errorf("parsing 'exp' JWT claim: %v", err)
return
}
if constraints.time < 0 {
constraints.time = time.Now().UnixNano()
}
// constraints.time is in nanoseconds but expVal is in seconds
if constraints.time/1000000000 >= expVal {
return arr, nil
}
}
// RFC7159 4.1.5 nbf
if nbf := payload.Get(jwtNbfKey); nbf != nil {
var nbfVal int64
if nbfVal, err = strconv.ParseInt(string(nbf.Value.(ast.Number)), 10, 64); err != nil {
err = fmt.Errorf("parsing 'nbf' JWT claim: %v", err)
return
}
if constraints.time < 0 {
constraints.time = time.Now().UnixNano()
}
// constraints.time is in nanoseconds but nbfVal is in seconds
if constraints.time/1000000000 < nbfVal {
return arr, nil
}
}
// Format the result
arr[0] = ast.BooleanTerm(true)
arr[1] = ast.NewTerm(token.decodedHeader)
arr[2] = ast.NewTerm(payload)
return arr, nil
}
// -- Utilities --
func decodeJWT(a ast.Value) (*JSONWebToken, error) {
// Parse the JSON Web Token
astEncode, err := builtins.StringOperand(a, 1)
if err != nil {
return nil, err
}
encoding := string(astEncode)
if !strings.Contains(encoding, ".") {
return nil, errors.New("encoded JWT had no period separators")
}
parts := strings.Split(encoding, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("encoded JWT must have 3 sections, found %d", len(parts))
}
return &JSONWebToken{header: parts[0], payload: parts[1], signature: parts[2]}, nil
}
func (token *JSONWebToken) decodeSignature() (string, error) {
decodedSignature, err := builtinBase64UrlDecode(ast.String(token.signature))
if err != nil {
return "", err
}
signatureAst, err := builtins.StringOperand(decodedSignature, 1)
if err != nil {
return "", err
}
return string(signatureAst), err
}
// Extract, validate and return the JWT header as an ast.Object.
func validateJWTHeader(h string) (ast.Object, error) {
header, err := extractJSONObject(h)
if err != nil {
return nil, fmt.Errorf("bad JWT header: %v", err)
}
// There are two kinds of JWT tokens, a JSON Web Signature (JWS) and
// a JSON Web Encryption (JWE). The latter is very involved, and we
// won't support it for now.
// This code checks which kind of JWT we are dealing with according to
// RFC 7516 Section 9: https://tools.ietf.org/html/rfc7516#section-9
if header.Get(jwtEncKey) != nil {
return nil, errors.New("JWT is a JWE object, which is not supported")
}
return header, nil
}
func extractJSONObject(s string) (ast.Object, error) {
// XXX: This code relies on undocumented behavior of Go's
// json.Unmarshal using the last occurrence of duplicate keys in a JSON
// Object. If duplicate keys are present in a JWT, the last must be
// used or the token rejected. Since detecting duplicates is tantamount
// to parsing it ourselves, we're relying on the Go implementation
// using the last occurring instance of the key, which is the behavior
// as of Go 1.8.1.
v, err := builtinJSONUnmarshal(ast.String(s))
if err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
o, ok := v.(ast.Object)
if !ok {
return nil, errors.New("decoded JSON type was not an Object")
}
return o, nil
}
// getInputSha returns the SHA256 checksum of the input
func getInputSHA(input []byte) (hash []byte) {
hasher := sha256.New()
hasher.Write(input)
return hasher.Sum(nil)
}
func init() {
RegisterFunctionalBuiltin1(ast.JWTDecode.Name, builtinJWTDecode)
RegisterFunctionalBuiltin2(ast.JWTVerifyRS256.Name, builtinJWTVerifyRS256)
RegisterFunctionalBuiltin2(ast.JWTVerifyPS256.Name, builtinJWTVerifyPS256)
RegisterFunctionalBuiltin2(ast.JWTVerifyES256.Name, builtinJWTVerifyES256)
RegisterFunctionalBuiltin2(ast.JWTVerifyHS256.Name, builtinJWTVerifyHS256)
RegisterFunctionalBuiltin2(ast.JWTDecodeVerify.Name, builtinJWTDecodeVerify)
RegisterFunctionalBuiltin3(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
RegisterFunctionalBuiltin3(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
}

View File

@@ -0,0 +1,304 @@
// Copyright 2016 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"io"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// Op defines the types of tracing events.
type Op string
const (
// EnterOp is emitted when a new query is about to be evaluated.
EnterOp Op = "Enter"
// ExitOp is emitted when a query has evaluated to true.
ExitOp Op = "Exit"
// EvalOp is emitted when an expression is about to be evaluated.
EvalOp Op = "Eval"
// RedoOp is emitted when an expression, rule, or query is being re-evaluated.
RedoOp Op = "Redo"
// SaveOp is emitted when an expression is saved instead of evaluated
// during partial evaluation.
SaveOp Op = "Save"
// FailOp is emitted when an expression evaluates to false.
FailOp Op = "Fail"
// NoteOp is emitted when an expression invokes a tracing built-in function.
NoteOp Op = "Note"
// IndexOp is emitted during an expression evaluation to represent lookup
// matches.
IndexOp Op = "Index"
)
// VarMetadata provides some user facing information about
// a variable in some policy.
type VarMetadata struct {
Name ast.Var `json:"name"`
Location *ast.Location `json:"location"`
}
// Event contains state associated with a tracing event.
type Event struct {
Op Op // Identifies type of event.
Node ast.Node // Contains AST node relevant to the event.
Location *ast.Location // The location of the Node this event relates to.
QueryID uint64 // Identifies the query this event belongs to.
ParentID uint64 // Identifies the parent query this event belongs to.
Locals *ast.ValueMap // Contains local variable bindings from the query context.
LocalMetadata map[ast.Var]VarMetadata // Contains metadata for the local variable bindings.
Message string // Contains message for Note events.
}
// HasRule returns true if the Event contains an ast.Rule.
func (evt *Event) HasRule() bool {
_, ok := evt.Node.(*ast.Rule)
return ok
}
// HasBody returns true if the Event contains an ast.Body.
func (evt *Event) HasBody() bool {
_, ok := evt.Node.(ast.Body)
return ok
}
// HasExpr returns true if the Event contains an ast.Expr.
func (evt *Event) HasExpr() bool {
_, ok := evt.Node.(*ast.Expr)
return ok
}
// Equal returns true if this event is equal to the other event.
func (evt *Event) Equal(other *Event) bool {
if evt.Op != other.Op {
return false
}
if evt.QueryID != other.QueryID {
return false
}
if evt.ParentID != other.ParentID {
return false
}
if !evt.equalNodes(other) {
return false
}
return evt.Locals.Equal(other.Locals)
}
func (evt *Event) String() string {
return fmt.Sprintf("%v %v %v (qid=%v, pqid=%v)", evt.Op, evt.Node, evt.Locals, evt.QueryID, evt.ParentID)
}
func (evt *Event) equalNodes(other *Event) bool {
switch a := evt.Node.(type) {
case ast.Body:
if b, ok := other.Node.(ast.Body); ok {
return a.Equal(b)
}
case *ast.Rule:
if b, ok := other.Node.(*ast.Rule); ok {
return a.Equal(b)
}
case *ast.Expr:
if b, ok := other.Node.(*ast.Expr); ok {
return a.Equal(b)
}
case nil:
return other.Node == nil
}
return false
}
// Tracer defines the interface for tracing in the top-down evaluation engine.
type Tracer interface {
Enabled() bool
Trace(*Event)
}
// BufferTracer implements the Tracer interface by simply buffering all events
// received.
type BufferTracer []*Event
// NewBufferTracer returns a new BufferTracer.
func NewBufferTracer() *BufferTracer {
return &BufferTracer{}
}
// Enabled always returns true if the BufferTracer is instantiated.
func (b *BufferTracer) Enabled() bool {
if b == nil {
return false
}
return true
}
// Trace adds the event to the buffer.
func (b *BufferTracer) Trace(evt *Event) {
*b = append(*b, evt)
}
// PrettyTrace pretty prints the trace to the writer.
func PrettyTrace(w io.Writer, trace []*Event) {
depths := depths{}
for _, event := range trace {
depth := depths.GetOrSet(event.QueryID, event.ParentID)
fmt.Fprintln(w, formatEvent(event, depth))
}
}
// PrettyTraceWithLocation prints the trace to the writer and includes location information
func PrettyTraceWithLocation(w io.Writer, trace []*Event) {
depths := depths{}
for _, event := range trace {
depth := depths.GetOrSet(event.QueryID, event.ParentID)
location := formatLocation(event)
fmt.Fprintln(w, fmt.Sprintf("%v %v", location, formatEvent(event, depth)))
}
}
func formatEvent(event *Event, depth int) string {
padding := formatEventPadding(event, depth)
if event.Op == NoteOp {
return fmt.Sprintf("%v%v %q", padding, event.Op, event.Message)
} else if event.Message != "" {
return fmt.Sprintf("%v%v %v %v", padding, event.Op, event.Node, event.Message)
} else {
switch node := event.Node.(type) {
case *ast.Rule:
return fmt.Sprintf("%v%v %v", padding, event.Op, node.Path())
default:
return fmt.Sprintf("%v%v %v", padding, event.Op, rewrite(event).Node)
}
}
}
func formatEventPadding(event *Event, depth int) string {
spaces := formatEventSpaces(event, depth)
padding := ""
if spaces > 1 {
padding += strings.Repeat("| ", spaces-1)
}
return padding
}
func formatEventSpaces(event *Event, depth int) int {
switch event.Op {
case EnterOp:
return depth
case RedoOp:
if _, ok := event.Node.(*ast.Expr); !ok {
return depth
}
}
return depth + 1
}
func formatLocation(event *Event) string {
if event.Op == NoteOp {
return fmt.Sprintf("%-19v", "note")
}
location := event.Location
if location == nil {
return fmt.Sprintf("%-19v", "")
}
if location.File == "" {
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", "query", location.Row))
}
return fmt.Sprintf("%-19v", fmt.Sprintf("%.15v:%v", location.File, location.Row))
}
// depths is a helper for computing the depth of an event. Events within the
// same query all have the same depth. The depth of query is
// depth(parent(query))+1.
type depths map[uint64]int
func (ds depths) GetOrSet(qid uint64, pqid uint64) int {
depth := ds[qid]
if depth == 0 {
depth = ds[pqid]
depth++
ds[qid] = depth
}
return depth
}
func builtinTrace(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
str, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return handleBuiltinErr(ast.Trace.Name, bctx.Location, err)
}
if !traceIsEnabled(bctx.Tracers) {
return iter(ast.BooleanTerm(true))
}
evt := &Event{
Op: NoteOp,
QueryID: bctx.QueryID,
ParentID: bctx.ParentID,
Message: string(str),
}
for i := range bctx.Tracers {
bctx.Tracers[i].Trace(evt)
}
return iter(ast.BooleanTerm(true))
}
func traceIsEnabled(tracers []Tracer) bool {
for i := range tracers {
if tracers[i].Enabled() {
return true
}
}
return false
}
func rewrite(event *Event) *Event {
cpy := *event
var node ast.Node
switch v := event.Node.(type) {
case *ast.Expr:
node = v.Copy()
case ast.Body:
node = v.Copy()
case *ast.Rule:
node = v.Copy()
}
ast.TransformVars(node, func(v ast.Var) (ast.Value, error) {
if meta, ok := cpy.LocalMetadata[v]; ok {
return meta.Name, nil
}
return v, nil
})
cpy.Node = node
return &cpy
}
func init() {
RegisterBuiltinFunc(ast.Trace.Name, builtinTrace)
}

View File

@@ -0,0 +1,82 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
)
func builtinIsNumber(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Number:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsString(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.String:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsBoolean(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Boolean:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsArray(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Array:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsSet(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Set:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsObject(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Object:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func builtinIsNull(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Null:
return ast.Boolean(true), nil
default:
return nil, BuiltinEmpty{}
}
}
func init() {
RegisterFunctionalBuiltin1(ast.IsNumber.Name, builtinIsNumber)
RegisterFunctionalBuiltin1(ast.IsString.Name, builtinIsString)
RegisterFunctionalBuiltin1(ast.IsBoolean.Name, builtinIsBoolean)
RegisterFunctionalBuiltin1(ast.IsArray.Name, builtinIsArray)
RegisterFunctionalBuiltin1(ast.IsSet.Name, builtinIsSet)
RegisterFunctionalBuiltin1(ast.IsObject.Name, builtinIsObject)
RegisterFunctionalBuiltin1(ast.IsNull.Name, builtinIsNull)
}

View File

@@ -0,0 +1,36 @@
// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"fmt"
"github.com/open-policy-agent/opa/ast"
)
func builtinTypeName(a ast.Value) (ast.Value, error) {
switch a.(type) {
case ast.Null:
return ast.String("null"), nil
case ast.Boolean:
return ast.String("boolean"), nil
case ast.Number:
return ast.String("number"), nil
case ast.String:
return ast.String("string"), nil
case ast.Array:
return ast.String("array"), nil
case ast.Object:
return ast.String("object"), nil
case ast.Set:
return ast.String("set"), nil
}
return nil, fmt.Errorf("illegal value")
}
func init() {
RegisterFunctionalBuiltin1(ast.TypeNameBuiltin.Name, builtinTypeName)
}

View File

@@ -0,0 +1,84 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
)
func evalWalk(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
input := args[0]
filter := getOutputPath(args)
var path ast.Array
return walk(filter, path, input, iter)
}
func walk(filter, path ast.Array, input *ast.Term, iter func(*ast.Term) error) error {
if len(filter) == 0 {
if err := iter(ast.ArrayTerm(ast.NewTerm(path), input)); err != nil {
return err
}
}
if len(filter) > 0 {
key := filter[0]
filter = filter[1:]
if key.IsGround() {
if term := input.Get(key); term != nil {
return walk(filter, append(path, key), term, iter)
}
return nil
}
}
switch v := input.Value.(type) {
case ast.Array:
for i := range v {
path = append(path, ast.IntNumberTerm(i))
if err := walk(filter, path, v[i], iter); err != nil {
return err
}
path = path[:len(path)-1]
}
case ast.Object:
return v.Iter(func(k, v *ast.Term) error {
path = append(path, k)
if err := walk(filter, path, v, iter); err != nil {
return err
}
path = path[:len(path)-1]
return nil
})
case ast.Set:
return v.Iter(func(elem *ast.Term) error {
path = append(path, elem)
if err := walk(filter, path, elem, iter); err != nil {
return err
}
path = path[:len(path)-1]
return nil
})
}
return nil
}
func getOutputPath(args []*ast.Term) ast.Array {
if len(args) == 2 {
if arr, ok := args[1].Value.(ast.Array); ok {
if len(arr) == 2 {
if path, ok := arr[0].Value.(ast.Array); ok {
return path
}
}
}
}
return nil
}
func init() {
RegisterBuiltinFunc(ast.WalkBuiltin.Name, evalWalk)
}