140 lines
3.5 KiB
Go
140 lines
3.5 KiB
Go
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
|
|
|
|
package topdown
|
|
|
|
import (
|
|
"github.com/open-policy-agent/opa/ast"
|
|
"github.com/open-policy-agent/opa/topdown/builtins"
|
|
"github.com/open-policy-agent/opa/types"
|
|
)
|
|
|
|
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
|
objA, err := builtins.ObjectOperand(operands[0].Value, 1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
objB, err := builtins.ObjectOperand(operands[1].Value, 2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
r := mergeWithOverwrite(objA, objB)
|
|
|
|
return iter(ast.NewTerm(r))
|
|
}
|
|
|
|
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
|
// Expect an object and an array/set/object of keys
|
|
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build a set of keys to remove
|
|
keysToRemove, err := getObjectKeysParam(operands[1].Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r := ast.NewObject()
|
|
obj.Foreach(func(key *ast.Term, value *ast.Term) {
|
|
if !keysToRemove.Contains(key) {
|
|
r.Insert(key, value)
|
|
}
|
|
})
|
|
|
|
return iter(ast.NewTerm(r))
|
|
}
|
|
|
|
func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
|
// Expect an object and an array/set/object of keys
|
|
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build a new object from the supplied filter keys
|
|
keys, err := getObjectKeysParam(operands[1].Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
filterObj := ast.NewObject()
|
|
keys.Foreach(func(key *ast.Term) {
|
|
filterObj.Insert(key, ast.NullTerm())
|
|
})
|
|
|
|
// Actually do the filtering
|
|
r, err := obj.Filter(filterObj)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return iter(ast.NewTerm(r))
|
|
}
|
|
|
|
func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
|
object, err := builtins.ObjectOperand(operands[0].Value, 1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if ret := object.Get(operands[1]); ret != nil {
|
|
return iter(ret)
|
|
}
|
|
|
|
return iter(operands[2])
|
|
}
|
|
|
|
// getObjectKeysParam returns a set of key values
|
|
// from a supplied ast array, object, set value
|
|
func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
|
|
keys := ast.NewSet()
|
|
|
|
switch v := arrayOrSet.(type) {
|
|
case ast.Array:
|
|
for _, f := range v {
|
|
keys.Add(f)
|
|
}
|
|
case ast.Set:
|
|
_ = v.Iter(func(f *ast.Term) error {
|
|
keys.Add(f)
|
|
return nil
|
|
})
|
|
case ast.Object:
|
|
_ = v.Iter(func(k *ast.Term, _ *ast.Term) error {
|
|
keys.Add(k)
|
|
return nil
|
|
})
|
|
default:
|
|
return nil, builtins.NewOperandTypeErr(2, arrayOrSet, ast.TypeName(types.Object{}), ast.TypeName(types.S), ast.TypeName(types.Array{}))
|
|
}
|
|
|
|
return keys, nil
|
|
}
|
|
|
|
func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
|
|
merged, _ := objA.MergeWith(objB, func(v1, v2 *ast.Term) (*ast.Term, bool) {
|
|
originalValueObj, ok2 := v1.Value.(ast.Object)
|
|
updateValueObj, ok1 := v2.Value.(ast.Object)
|
|
if !ok1 || !ok2 {
|
|
// If we can't merge, stick with the right-hand value
|
|
return v2, false
|
|
}
|
|
|
|
// Recursively update the existing value
|
|
merged := mergeWithOverwrite(originalValueObj, updateValueObj)
|
|
return ast.NewTerm(merged), false
|
|
})
|
|
return merged
|
|
}
|
|
|
|
func init() {
|
|
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
|
|
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
|
|
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
|
|
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
|
|
}
|