// Copyright 2019 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. package topdown import ( "fmt" "strconv" "strings" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/topdown/builtins" ) func builtinJSONRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { // Expect an object and a string or array/set of strings _, err := builtins.ObjectOperand(operands[0].Value, 1) if err != nil { return err } // Build a list of json pointers to remove paths, err := getJSONPaths(operands[1].Value) if err != nil { return err } newObj, err := jsonRemove(operands[0], ast.NewTerm(pathsToObject(paths))) if err != nil { return err } if newObj == nil { return nil } return iter(newObj) } // jsonRemove returns a new term that is the result of walking // through a and omitting removing any values that are in b but // have ast.Null values (ie leaf nodes for b). func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) { if b == nil { // The paths diverged, return a return a, nil } var bObj ast.Object switch bValue := b.Value.(type) { case ast.Object: bObj = bValue case ast.Null: // Means we hit a leaf node on "b", dont add the value for a return nil, nil default: // The paths diverged, return a return a, nil } switch aValue := a.Value.(type) { case ast.String, ast.Number, ast.Boolean, ast.Null: return a, nil case ast.Object: newObj := ast.NewObject() err := aValue.Iter(func(k *ast.Term, v *ast.Term) error { // recurse and add the diff of sub objects as needed diffValue, err := jsonRemove(v, bObj.Get(k)) if err != nil || diffValue == nil { return err } newObj.Insert(k, diffValue) return nil }) if err != nil { return nil, err } return ast.NewTerm(newObj), nil case ast.Set: newSet := ast.NewSet() err := aValue.Iter(func(v *ast.Term) error { // recurse and add the diff of sub objects as needed diffValue, err := jsonRemove(v, bObj.Get(v)) if err != nil || diffValue == nil { return err } newSet.Add(diffValue) return nil }) if err != nil { return nil, err } return ast.NewTerm(newSet), nil case ast.Array: // When indexes are removed we shift left to close empty spots in the array // as per the JSON patch spec. var newArray ast.Array for i, v := range aValue { // recurse and add the diff of sub objects as needed // Note: Keys in b will be strings for the index, eg path /a/1/b => {"a": {"1": {"b": null}}} diffValue, err := jsonRemove(v, bObj.Get(ast.StringTerm(strconv.Itoa(i)))) if err != nil { return nil, err } if diffValue != nil { newArray = append(newArray, diffValue) } } return ast.NewTerm(newArray), nil default: return nil, fmt.Errorf("invalid value type %T", a) } } func builtinJSONFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { // Ensure we have the right parameters, expect an object and a string or array/set of strings obj, err := builtins.ObjectOperand(operands[0].Value, 1) if err != nil { return err } // Build a list of filter strings filters, err := getJSONPaths(operands[1].Value) if err != nil { return err } // Actually do the filtering filterObj := pathsToObject(filters) r, err := obj.Filter(filterObj) if err != nil { return err } return iter(ast.NewTerm(r)) } func getJSONPaths(operand ast.Value) ([]ast.Ref, error) { var paths []ast.Ref switch v := operand.(type) { case ast.Array: for _, f := range v { filter, err := parsePath(f) if err != nil { return nil, err } paths = append(paths, filter) } case ast.Set: err := v.Iter(func(f *ast.Term) error { filter, err := parsePath(f) if err != nil { return err } paths = append(paths, filter) return nil }) if err != nil { return nil, err } default: return nil, builtins.NewOperandTypeErr(2, v, "set", "array") } return paths, nil } func parsePath(path *ast.Term) (ast.Ref, error) { // paths can either be a `/` separated json path or // an array or set of values var pathSegments ast.Ref switch p := path.Value.(type) { case ast.String: parts := strings.Split(strings.Trim(string(p), "/"), "/") for _, part := range parts { part = strings.ReplaceAll(strings.ReplaceAll(part, "~1", "/"), "~0", "~") pathSegments = append(pathSegments, ast.StringTerm(part)) } case ast.Array: for _, term := range p { pathSegments = append(pathSegments, term) } default: return nil, builtins.NewOperandErr(2, "must be one of {set, array} containing string paths or array of path segments but got %v", ast.TypeName(p)) } return pathSegments, nil } func pathsToObject(paths []ast.Ref) ast.Object { root := ast.NewObject() for _, path := range paths { node := root var done bool for i := 0; i < len(path)-1 && !done; i++ { k := path[i] child := node.Get(k) if child == nil { obj := ast.NewObject() node.Insert(k, ast.NewTerm(obj)) node = obj continue } switch v := child.Value.(type) { case ast.Null: done = true case ast.Object: node = v default: panic("unreachable") } } if !done { node.Insert(path[len(path)-1], ast.NullTerm()) } } return root } func init() { RegisterBuiltinFunc(ast.JSONFilter.Name, builtinJSONFilter) RegisterBuiltinFunc(ast.JSONRemove.Name, builtinJSONRemove) }