// Copyright 2016 The OPA Authors. All rights reserved. // Use of this source code is governed by an Apache2 // license that can be found in the LICENSE file. package topdown import ( "math/big" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/topdown/builtins" ) func builtinCount(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: return iter(ast.IntNumberTerm(a.Len())) case ast.Object: return iter(ast.IntNumberTerm(a.Len())) case ast.Set: return iter(ast.IntNumberTerm(a.Len())) case ast.String: return iter(ast.IntNumberTerm(len([]rune(a)))) } return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "object", "set", "string") } func builtinSum(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: sum := big.NewFloat(0) err := a.Iter(func(x *ast.Term) error { n, ok := x.Value.(ast.Number) if !ok { return builtins.NewOperandElementErr(1, a, x.Value, "number") } sum = new(big.Float).Add(sum, builtins.NumberToFloat(n)) return nil }) if err != nil { return err } return iter(ast.NewTerm(builtins.FloatToNumber(sum))) case ast.Set: sum := big.NewFloat(0) err := a.Iter(func(x *ast.Term) error { n, ok := x.Value.(ast.Number) if !ok { return builtins.NewOperandElementErr(1, a, x.Value, "number") } sum = new(big.Float).Add(sum, builtins.NumberToFloat(n)) return nil }) if err != nil { return err } return iter(ast.NewTerm(builtins.FloatToNumber(sum))) } return builtins.NewOperandTypeErr(1, operands[0].Value, "set", "array") } func builtinProduct(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: product := big.NewFloat(1) err := a.Iter(func(x *ast.Term) error { n, ok := x.Value.(ast.Number) if !ok { return builtins.NewOperandElementErr(1, a, x.Value, "number") } product = new(big.Float).Mul(product, builtins.NumberToFloat(n)) return nil }) if err != nil { return err } return iter(ast.NewTerm(builtins.FloatToNumber(product))) case ast.Set: product := big.NewFloat(1) err := a.Iter(func(x *ast.Term) error { n, ok := x.Value.(ast.Number) if !ok { return builtins.NewOperandElementErr(1, a, x.Value, "number") } product = new(big.Float).Mul(product, builtins.NumberToFloat(n)) return nil }) if err != nil { return err } return iter(ast.NewTerm(builtins.FloatToNumber(product))) } return builtins.NewOperandTypeErr(1, operands[0].Value, "set", "array") } func builtinMax(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: if a.Len() == 0 { return nil } var max = ast.Value(ast.Null{}) a.Foreach(func(x *ast.Term) { if ast.Compare(max, x.Value) <= 0 { max = x.Value } }) return iter(ast.NewTerm(max)) case ast.Set: if a.Len() == 0 { return nil } max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { if ast.Compare(max, elem) <= 0 { return elem, nil } return max, nil }) if err != nil { return err } return iter(max) } return builtins.NewOperandTypeErr(1, operands[0].Value, "set", "array") } func builtinMin(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: if a.Len() == 0 { return nil } min := a.Elem(0).Value a.Foreach(func(x *ast.Term) { if ast.Compare(min, x.Value) >= 0 { min = x.Value } }) return iter(ast.NewTerm(min)) case ast.Set: if a.Len() == 0 { return nil } min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { // The null term is considered to be less than any other term, // so in order for min of a set to make sense, we need to check // for it. if min.Value.Compare(ast.Null{}) == 0 { return elem, nil } if ast.Compare(min, elem) >= 0 { return elem, nil } return min, nil }) if err != nil { return err } return iter(min) } return builtins.NewOperandTypeErr(1, operands[0].Value, "set", "array") } func builtinSort(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch a := operands[0].Value.(type) { case *ast.Array: return iter(ast.NewTerm(a.Sorted())) case ast.Set: return iter(ast.NewTerm(a.Sorted())) } return builtins.NewOperandTypeErr(1, operands[0].Value, "set", "array") } func builtinAll(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch val := operands[0].Value.(type) { case ast.Set: res := true match := ast.BooleanTerm(true) val.Until(func(term *ast.Term) bool { if !match.Equal(term) { res = false return true } return false }) return iter(ast.BooleanTerm(res)) case *ast.Array: res := true match := ast.BooleanTerm(true) val.Until(func(term *ast.Term) bool { if !match.Equal(term) { res = false return true } return false }) return iter(ast.BooleanTerm(res)) default: return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set") } } func builtinAny(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { switch val := operands[0].Value.(type) { case ast.Set: res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true)) return iter(ast.BooleanTerm(res)) case *ast.Array: res := false match := ast.BooleanTerm(true) val.Until(func(term *ast.Term) bool { if match.Equal(term) { res = true return true } return false }) return iter(ast.BooleanTerm(res)) default: return builtins.NewOperandTypeErr(1, operands[0].Value, "array", "set") } } func builtinMember(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { containee := operands[0] switch c := operands[1].Value.(type) { case ast.Set: return iter(ast.BooleanTerm(c.Contains(containee))) case *ast.Array: ret := false c.Until(func(v *ast.Term) bool { if v.Value.Compare(containee.Value) == 0 { ret = true } return ret }) return iter(ast.BooleanTerm(ret)) case ast.Object: ret := false c.Until(func(_, v *ast.Term) bool { if v.Value.Compare(containee.Value) == 0 { ret = true } return ret }) return iter(ast.BooleanTerm(ret)) } return iter(ast.BooleanTerm(false)) } func builtinMemberWithKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { key, val := operands[0], operands[1] switch c := operands[2].Value.(type) { case interface{ Get(*ast.Term) *ast.Term }: ret := false if act := c.Get(key); act != nil { ret = act.Value.Compare(val.Value) == 0 } return iter(ast.BooleanTerm(ret)) } return iter(ast.BooleanTerm(false)) } func init() { RegisterBuiltinFunc(ast.Count.Name, builtinCount) RegisterBuiltinFunc(ast.Sum.Name, builtinSum) RegisterBuiltinFunc(ast.Product.Name, builtinProduct) RegisterBuiltinFunc(ast.Max.Name, builtinMax) RegisterBuiltinFunc(ast.Min.Name, builtinMin) RegisterBuiltinFunc(ast.Sort.Name, builtinSort) RegisterBuiltinFunc(ast.Any.Name, builtinAny) RegisterBuiltinFunc(ast.All.Name, builtinAll) RegisterBuiltinFunc(ast.Member.Name, builtinMember) RegisterBuiltinFunc(ast.MemberWithKey.Name, builtinMemberWithKey) }