feat: kubesphere 4.0 (#6115)

* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

---------

Signed-off-by: ci-bot <ci-bot@kubesphere.io>
Co-authored-by: ks-ci-bot <ks-ci-bot@example.com>
Co-authored-by: joyceliu <joyceliu@yunify.com>
This commit is contained in:
KubeSphere CI Bot
2024-09-06 11:05:52 +08:00
committed by GitHub
parent b5015ec7b9
commit 447a51f08b
8557 changed files with 546695 additions and 1146174 deletions

View File

@@ -5,9 +5,8 @@
package topdown
import (
"math/big"
"fmt"
"math/big"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -54,6 +53,54 @@ func arithFloor(a *big.Float) (*big.Float, error) {
return new(big.Float).Sub(f, big.NewFloat(1.0)), nil
}
// builtinPlus implements the arithmetic "plus" built-in. Small integer
// operands take an int fast path; everything else is added via
// arbitrary-precision floats.
func builtinPlus(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	lhs, err := builtins.NumberOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	rhs, err := builtins.NumberOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	// Fast path: both operands are small integers.
	li, lok := lhs.Int()
	ri, rok := rhs.Int()
	if lok && rok && inSmallIntRange(li) && inSmallIntRange(ri) {
		return iter(ast.IntNumberTerm(li + ri))
	}

	sum, err := arithPlus(builtins.NumberToFloat(lhs), builtins.NumberToFloat(rhs))
	if err != nil {
		return err
	}
	return iter(ast.NewTerm(builtins.FloatToNumber(sum)))
}
// builtinMultiply implements the arithmetic "multiply" built-in. Small
// integer operands take an int fast path; everything else is multiplied
// via arbitrary-precision floats.
func builtinMultiply(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	lhs, err := builtins.NumberOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	rhs, err := builtins.NumberOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	// Fast path: both operands are small integers.
	li, lok := lhs.Int()
	ri, rok := rhs.Int()
	if lok && rok && inSmallIntRange(li) && inSmallIntRange(ri) {
		return iter(ast.IntNumberTerm(li * ri))
	}

	product, err := arithMultiply(builtins.NumberToFloat(lhs), builtins.NumberToFloat(rhs))
	if err != nil {
		return err
	}
	return iter(ast.NewTerm(builtins.FloatToNumber(product)))
}
func arithPlus(a, b *big.Float) (*big.Float, error) {
return new(big.Float).Add(a, b), nil
}
@@ -119,6 +166,14 @@ func builtinMinus(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) e
n2, ok2 := operands[1].Value.(ast.Number)
if ok1 && ok2 {
x, okx := n1.Int()
y, oky := n2.Int()
if okx && oky && inSmallIntRange(x) && inSmallIntRange(y) {
return iter(ast.IntNumberTerm(x - y))
}
f, err := arithMinus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
if err != nil {
return err
@@ -150,6 +205,17 @@ func builtinRem(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err
if ok1 && ok2 {
x, okx := n1.Int()
y, oky := n2.Int()
if okx && oky && inSmallIntRange(x) && inSmallIntRange(y) {
if y == 0 {
return fmt.Errorf("modulo by zero")
}
return iter(ast.IntNumberTerm(x % y))
}
op1, err1 := builtins.NumberToInt(n1)
op2, err2 := builtins.NumberToInt(n2)
@@ -171,14 +237,18 @@ func builtinRem(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) err
return builtins.NewOperandTypeErr(2, operands[1].Value, "number")
}
// inSmallIntRange reports whether num lies strictly between -1000 and 1000,
// the range for which the arithmetic built-ins use the integer fast path.
func inSmallIntRange(num int) bool {
	return num > -1000 && num < 1000
}
// init registers the arithmetic built-in functions.
//
// NOTE: the pre-refactor registrations of Plus and Multiply via
// builtinArithArity2 were left in alongside the new dedicated handlers
// (merge residue); only the dedicated handlers are registered here.
func init() {
	RegisterBuiltinFunc(ast.Abs.Name, builtinArithArity1(arithAbs))
	RegisterBuiltinFunc(ast.Round.Name, builtinArithArity1(arithRound))
	RegisterBuiltinFunc(ast.Ceil.Name, builtinArithArity1(arithCeil))
	RegisterBuiltinFunc(ast.Floor.Name, builtinArithArity1(arithFloor))
	RegisterBuiltinFunc(ast.Plus.Name, builtinPlus)
	RegisterBuiltinFunc(ast.Minus.Name, builtinMinus)
	RegisterBuiltinFunc(ast.Multiply.Name, builtinMultiply)
	RegisterBuiltinFunc(ast.Divide.Name, builtinArithArity2(arithDivide))
	RegisterBuiltinFunc(ast.Rem.Name, builtinRem)
}

View File

@@ -182,17 +182,19 @@ func handleBuiltinErr(name string, loc *ast.Location, err error) error {
case *Error, Halt:
return err
case builtins.ErrOperand:
return &Error{
e := &Error{
Code: TypeErr,
Message: fmt.Sprintf("%v: %v", name, err.Error()),
Location: loc,
}
return e.Wrap(err)
default:
return &Error{
e := &Error{
Code: BuiltinErr,
Message: fmt.Sprintf("%v: %v", name, err.Error()),
Location: loc,
}
return e.Wrap(err)
}
}

View File

@@ -249,7 +249,11 @@ func NumberToFloat(n ast.Number) *big.Float {
// FloatToNumber converts f to a number.
//
// NOTE: an unreachable leftover of the previous implementation
// (`return ast.Number(f.Text('g', -1))`) sat above the real body
// (merge residue) and is removed here.
func FloatToNumber(f *big.Float) ast.Number {
	// Use fixed-point notation for integral values so integers are not
	// rendered in scientific notation; 'g' (shortest form) otherwise.
	var format byte = 'g'
	if f.IsInt() {
		format = 'f'
	}
	return ast.Number(f.Text(format, -1))
}
// NumberToInt converts n to a big int.

View File

@@ -7,16 +7,20 @@ package cache
import (
"container/list"
"context"
"fmt"
"math"
"sync"
"time"
"github.com/open-policy-agent/opa/ast"
"sync"
"github.com/open-policy-agent/opa/util"
)
const (
defaultMaxSizeBytes = int64(0) // unlimited
defaultMaxSizeBytes = int64(0) // unlimited
defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes
defaultStaleEntryEvictionPeriodSeconds = int64(0) // never
)
// Config represents the configuration of the inter-query cache.
@@ -25,8 +29,13 @@ type Config struct {
}
// InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize.
// MaxSizeBytes - max capacity of cache in bytes
// ForcedEvictionThresholdPercentage - capacity usage in percentage after which forced FIFO eviction starts
// StaleEntryEvictionPeriodSeconds - time period between end of previous and start of new stale entry eviction routine
type InterQueryBuiltinCacheConfig struct {
MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"`
MaxSizeBytes *int64 `json:"max_size_bytes,omitempty"`
ForcedEvictionThresholdPercentage *int64 `json:"forced_eviction_threshold_percentage,omitempty"`
StaleEntryEvictionPeriodSeconds *int64 `json:"stale_entry_eviction_period_seconds,omitempty"`
}
// ParseCachingConfig returns the config for the inter-query cache.
@@ -34,7 +43,11 @@ func ParseCachingConfig(raw []byte) (*Config, error) {
if raw == nil {
maxSize := new(int64)
*maxSize = defaultMaxSizeBytes
return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize}}, nil
threshold := new(int64)
*threshold = defaultForcedEvictionThresholdPercentage
period := new(int64)
*period = defaultStaleEntryEvictionPeriodSeconds
return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}}, nil
}
var config Config
@@ -56,34 +69,88 @@ func (c *Config) validateAndInjectDefaults() error {
*maxSize = defaultMaxSizeBytes
c.InterQueryBuiltinCache.MaxSizeBytes = maxSize
}
if c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage == nil {
threshold := new(int64)
*threshold = defaultForcedEvictionThresholdPercentage
c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage = threshold
} else {
threshold := *c.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage
if threshold < 0 || threshold > 100 {
return fmt.Errorf("invalid forced_eviction_threshold_percentage %v", threshold)
}
}
if c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds == nil {
period := new(int64)
*period = defaultStaleEntryEvictionPeriodSeconds
c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds = period
} else {
period := *c.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds
if period < 0 {
return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period)
}
}
return nil
}
// InterQueryCacheValue defines the interface for the data that the inter-query cache holds.
type InterQueryCacheValue interface {
	// SizeInBytes reports the value's size, used for cache capacity accounting.
	SizeInBytes() int64
	// Clone returns a copy of the value.
	Clone() (InterQueryCacheValue, error)
}
// InterQueryCache defines the interface for the inter-query cache.
type InterQueryCache interface {
	// Get returns the value stored under key, if present.
	Get(key ast.Value) (value InterQueryCacheValue, found bool)
	// Insert stores value under key with no expiry and returns the number
	// of entries dropped to make room.
	Insert(key ast.Value, value InterQueryCacheValue) int
	// InsertWithExpiry stores value under key until expiresAt (a zero time
	// means no expiry) and returns the number of entries dropped.
	InsertWithExpiry(key ast.Value, value InterQueryCacheValue, expiresAt time.Time) int
	// Delete removes key from the cache.
	Delete(key ast.Value)
	// UpdateConfig replaces the cache's configuration.
	UpdateConfig(config *Config)
	// Clone returns a copy of value via its Clone method.
	Clone(value InterQueryCacheValue) (InterQueryCacheValue, error)
}
// NewInterQueryCache returns a new inter-query cache.
// The cache uses a FIFO eviction policy when it reaches the forced eviction threshold.
// Parameters:
//
// config - to configure the InterQueryCache
func NewInterQueryCache(config *Config) InterQueryCache {
return &cache{
items: map[string]cacheItem{},
usage: 0,
config: config,
l: list.New(),
return newCache(config)
}
// NewInterQueryCacheWithContext returns a new inter-query cache with context.
// The cache uses a combination of FIFO eviction policy when it reaches the forced eviction threshold
// and a periodic cleanup routine to remove stale entries that exceed their expiration time, if specified.
// If configured with a zero stale_entry_eviction_period_seconds value, the stale entry cleanup routine is disabled.
//
// Parameters:
//
// ctx - used to control lifecycle of the stale entry cleanup routine
// config - to configure the InterQueryCache
func NewInterQueryCacheWithContext(ctx context.Context, config *Config) InterQueryCache {
	iqCache := newCache(config)
	if period := iqCache.staleEntryEvictionTimePeriodSeconds(); period > 0 {
		cleanupTicker := time.NewTicker(time.Duration(period) * time.Second)
		go func() {
			defer cleanupTicker.Stop()
			for {
				select {
				case <-cleanupTicker.C:
					// Pause the ticker during the sweep so a slow cleanup
					// does not immediately fire again, then re-arm with the
					// (possibly updated) configured period. Reset reuses the
					// ticker instead of allocating a new one each cycle.
					cleanupTicker.Stop()
					iqCache.cleanStaleValues()
					cleanupTicker.Reset(time.Duration(iqCache.staleEntryEvictionTimePeriodSeconds()) * time.Second)
				case <-ctx.Done():
					return
				}
			}
		}()
	}
	return iqCache
}
// cacheItem holds a cached value together with its expiry time and its
// position in the cache's FIFO eviction list.
type cacheItem struct {
	value InterQueryCacheValue
	// expiresAt is the absolute expiry time; a zero value means the item
	// never expires.
	expiresAt time.Time
	// keyElement is the item's element in the FIFO list, used for eviction
	// ordering and O(1) removal.
	keyElement *list.Element
}
@@ -95,11 +162,26 @@ type cache struct {
mtx sync.Mutex
}
// Insert inserts a key k into the cache with value v.
func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) {
// newCache constructs an empty cache with the given configuration.
func newCache(config *Config) *cache {
	c := cache{
		items:  map[string]cacheItem{},
		usage:  0,
		config: config,
		l:      list.New(),
	}
	return &c
}
// InsertWithExpiry inserts a key k into the cache with value v with an expiration time expiresAt.
// A zero time value for expiresAt indicates no expiry. It returns the number
// of entries dropped to make room.
//
// NOTE: a stale two-argument `return c.unsafeInsert(k, v)` from before the
// expiry refactor sat above the real return (merge residue) and is removed.
func (c *cache) InsertWithExpiry(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	return c.unsafeInsert(k, v, expiresAt)
}
// Insert inserts a key k into the cache with value v with no expiration time.
func (c *cache) Insert(k ast.Value, v InterQueryCacheValue) (dropped int) {
	var noExpiry time.Time // zero time => the entry never expires
	return c.InsertWithExpiry(k, v, noExpiry)
}
// Get returns the value in the cache for k.
@@ -130,10 +212,15 @@ func (c *cache) UpdateConfig(config *Config) {
c.config = config
}
func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int) {
size := v.SizeInBytes()
limit := c.maxSizeBytes()
// Clone returns a copy of value produced by its Clone method, serialized
// under the cache mutex.
func (c *cache) Clone(value InterQueryCacheValue) (InterQueryCacheValue, error) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	return c.unsafeClone(value)
}
func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue, expiresAt time.Time) (dropped int) {
size := v.SizeInBytes()
limit := int64(math.Ceil(float64(c.forcedEvictionThresholdPercentage())/100.0) * (float64(c.maxSizeBytes())))
if limit > 0 {
if size > limit {
dropped++
@@ -152,6 +239,7 @@ func (c *cache) unsafeInsert(k ast.Value, v InterQueryCacheValue) (dropped int)
c.items[k.String()] = cacheItem{
value: v,
expiresAt: expiresAt,
keyElement: c.l.PushBack(k),
}
c.usage += size
@@ -174,9 +262,42 @@ func (c *cache) unsafeDelete(k ast.Value) {
c.l.Remove(cacheItem.keyElement)
}
// unsafeClone copies value via its Clone method. Callers must hold c.mtx.
func (c *cache) unsafeClone(value InterQueryCacheValue) (InterQueryCacheValue, error) {
	return value.Clone()
}
// maxSizeBytes returns the configured cache capacity in bytes, falling back
// to the default (unlimited) when no config is set.
func (c *cache) maxSizeBytes() int64 {
	if c.config != nil {
		return *c.config.InterQueryBuiltinCache.MaxSizeBytes
	}
	return defaultMaxSizeBytes
}
// forcedEvictionThresholdPercentage returns the configured usage percentage
// at which forced FIFO eviction starts, falling back to the default when no
// config is set.
func (c *cache) forcedEvictionThresholdPercentage() int64 {
	if c.config != nil {
		return *c.config.InterQueryBuiltinCache.ForcedEvictionThresholdPercentage
	}
	return defaultForcedEvictionThresholdPercentage
}
// staleEntryEvictionTimePeriodSeconds returns the configured period between
// stale-entry cleanup runs, falling back to the default (disabled) when no
// config is set.
func (c *cache) staleEntryEvictionTimePeriodSeconds() int64 {
	if c.config != nil {
		return *c.config.InterQueryBuiltinCache.StaleEntryEvictionPeriodSeconds
	}
	return defaultStaleEntryEvictionPeriodSeconds
}
// cleanStaleValues removes every expired entry from the cache and returns
// the number of entries dropped. Entries with a zero expiry never expire.
func (c *cache) cleanStaleValues() (dropped int) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	// Hoisted out of the loop: one clock read per sweep instead of one per
	// entry, which also makes the sweep's notion of "now" consistent.
	now := time.Now()
	for key := c.l.Front(); key != nil; {
		nextKey := key.Next()
		k := key.Value.(ast.Value) // single assertion instead of two per entry
		// if expiresAt is zero, the item doesn't have an expiry
		if ea := c.items[k.String()].expiresAt; !ea.IsZero() && ea.Before(now) {
			c.unsafeDelete(k)
			dropped++
		}
		key = nextKey
	}
	return dropped
}

View File

@@ -6,11 +6,13 @@ package topdown
import (
"bytes"
"crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
@@ -20,8 +22,9 @@ import (
"os"
"strings"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/jwx/jwk"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/util"
)
@@ -38,7 +41,8 @@ const (
blockTypeRSAPrivateKey = "RSA PRIVATE KEY"
// blockTypeRSAPrivateKey indicates this PEM block contains a RSA private key.
// Exported for tests.
blockTypePrivateKey = "PRIVATE KEY"
blockTypePrivateKey = "PRIVATE KEY"
blockTypeEcPrivateKey = "EC PRIVATE KEY"
)
func builtinCryptoX509ParseCertificates(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -52,7 +56,7 @@ func builtinCryptoX509ParseCertificates(_ BuiltinContext, operands []*ast.Term,
return err
}
v, err := ast.InterfaceToValue(certs)
v, err := ast.InterfaceToValue(extendCertificates(certs))
if err != nil {
return err
}
@@ -60,6 +64,28 @@ func builtinCryptoX509ParseCertificates(_ BuiltinContext, operands []*ast.Term,
return iter(ast.NewTerm(v))
}
// extendedCert is a wrapper around x509.Certificate that adds additional fields for JSON serialization.
type extendedCert struct {
x509.Certificate
URIStrings []string
}
func extendCertificates(certs []*x509.Certificate) []extendedCert {
// add a field to certs containing the URIs as strings
processedCerts := make([]extendedCert, len(certs))
for i, cert := range certs {
processedCerts[i].Certificate = *cert
if cert.URIs != nil {
processedCerts[i].URIStrings = make([]string, len(cert.URIs))
for j, uri := range cert.URIs {
processedCerts[i].URIStrings[j] = uri.String()
}
}
}
return processedCerts
}
func builtinCryptoX509ParseAndVerifyCertificates(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a := operands[0].Value
@@ -83,7 +109,7 @@ func builtinCryptoX509ParseAndVerifyCertificates(_ BuiltinContext, operands []*a
return iter(invalid)
}
value, err := ast.InterfaceToValue(verified)
value, err := ast.InterfaceToValue(extendCertificates(verified))
if err != nil {
return err
}
@@ -96,6 +122,29 @@ func builtinCryptoX509ParseAndVerifyCertificates(_ BuiltinContext, operands []*a
return iter(valid)
}
// builtinCryptoX509ParseKeyPair parses a certificate/private-key pair (each
// given as PEM, base64-encoded PEM, or DER — see getTLSx509KeyPairFromString)
// and yields the resulting TLS certificate as an AST value.
func builtinCryptoX509ParseKeyPair(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	certificate, err := builtins.StringOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	// BUG FIX: the key is operand 2; it was previously reported as operand 1
	// in type-error messages.
	key, err := builtins.StringOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	certs, err := getTLSx509KeyPairFromString([]byte(certificate), []byte(key))
	if err != nil {
		return err
	}

	v, err := ast.InterfaceToValue(certs)
	if err != nil {
		return err
	}

	return iter(ast.NewTerm(v))
}
func builtinCryptoX509ParseCertificateRequest(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
input, err := builtins.StringOperand(operands[0].Value, 1)
if err != nil {
@@ -144,7 +193,8 @@ func builtinCryptoX509ParseCertificateRequest(_ BuiltinContext, operands []*ast.
return iter(ast.NewTerm(v))
}
func builtinCryptoX509ParseRSAPrivateKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
func builtinCryptoJWKFromPrivateKey(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
var x interface{}
a := operands[0].Value
input, err := builtins.StringOperand(a, 1)
@@ -153,23 +203,83 @@ func builtinCryptoX509ParseRSAPrivateKey(_ BuiltinContext, operands []*ast.Term,
}
// get the raw private key
rawKey, err := getRSAPrivateKeyFromString(string(input))
pemDataString := string(input)
if pemDataString == "" {
return fmt.Errorf("input PEM data was empty")
}
// This built in must be supplied a valid PEM or base64 encoded string.
// If the input is not a PEM string, attempt to decode b64.
// If the base64 decode fails - this is an error
if !strings.HasPrefix(pemDataString, "-----BEGIN") {
bs, err := base64.StdEncoding.DecodeString(pemDataString)
if err != nil {
return err
}
pemDataString = string(bs)
}
rawKeys, err := getPrivateKeysFromPEMData(pemDataString)
if err != nil {
return err
}
rsaPrivateKey, err := jwk.New(rawKey)
if len(rawKeys) == 0 {
return iter(ast.NullTerm())
}
key, err := jwk.New(rawKeys[0])
if err != nil {
return err
}
jsonKey, err := json.Marshal(rsaPrivateKey)
jsonKey, err := json.Marshal(key)
if err != nil {
return err
}
if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
return err
}
value, err := ast.InterfaceToValue(x)
if err != nil {
return err
}
return iter(ast.NewTerm(value))
}
func builtinCryptoParsePrivateKeys(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
a := operands[0].Value
input, err := builtins.StringOperand(a, 1)
if err != nil {
return err
}
if string(input) == "" {
return iter(ast.NullTerm())
}
// get the raw private key
rawKeys, err := getPrivateKeysFromPEMData(string(input))
if err != nil {
return err
}
if len(rawKeys) == 0 {
return iter(ast.NewTerm(ast.NewArray()))
}
bs, err := json.Marshal(rawKeys)
if err != nil {
return err
}
var x interface{}
if err := util.UnmarshalJSON(jsonKey, &x); err != nil {
if err := util.UnmarshalJSON(bs, &x); err != nil {
return err
}
@@ -249,6 +359,24 @@ func builtinCryptoHmacSha512(_ BuiltinContext, operands []*ast.Term, iter func(*
return hmacHelper(operands, iter, sha512.New)
}
// builtinCryptoHmacEqual compares two string MACs in constant time via
// hmac.Equal and yields the boolean result.
func builtinCryptoHmacEqual(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	mac1, err := builtins.StringOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	mac2, err := builtins.StringOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	equal := hmac.Equal([]byte(mac1), []byte(mac2))
	return iter(ast.BooleanTerm(equal))
}
func init() {
RegisterBuiltinFunc(ast.CryptoX509ParseCertificates.Name, builtinCryptoX509ParseCertificates)
RegisterBuiltinFunc(ast.CryptoX509ParseAndVerifyCertificates.Name, builtinCryptoX509ParseAndVerifyCertificates)
@@ -256,11 +384,14 @@ func init() {
RegisterBuiltinFunc(ast.CryptoSha1.Name, builtinCryptoSha1)
RegisterBuiltinFunc(ast.CryptoSha256.Name, builtinCryptoSha256)
RegisterBuiltinFunc(ast.CryptoX509ParseCertificateRequest.Name, builtinCryptoX509ParseCertificateRequest)
RegisterBuiltinFunc(ast.CryptoX509ParseRSAPrivateKey.Name, builtinCryptoX509ParseRSAPrivateKey)
RegisterBuiltinFunc(ast.CryptoX509ParseRSAPrivateKey.Name, builtinCryptoJWKFromPrivateKey)
RegisterBuiltinFunc(ast.CryptoParsePrivateKeys.Name, builtinCryptoParsePrivateKeys)
RegisterBuiltinFunc(ast.CryptoX509ParseKeyPair.Name, builtinCryptoX509ParseKeyPair)
RegisterBuiltinFunc(ast.CryptoHmacMd5.Name, builtinCryptoHmacMd5)
RegisterBuiltinFunc(ast.CryptoHmacSha1.Name, builtinCryptoHmacSha1)
RegisterBuiltinFunc(ast.CryptoHmacSha256.Name, builtinCryptoHmacSha256)
RegisterBuiltinFunc(ast.CryptoHmacSha512.Name, builtinCryptoHmacSha512)
RegisterBuiltinFunc(ast.CryptoHmacEqual.Name, builtinCryptoHmacEqual)
}
func verifyX509CertificateChain(certs []*x509.Certificate) ([]*x509.Certificate, error) {
@@ -334,43 +465,56 @@ func getX509CertsFromPem(pemBlocks []byte) ([]*x509.Certificate, error) {
return x509.ParseCertificates(decodedCerts)
}
func getRSAPrivateKeyFromString(key string) (interface{}, error) {
// if the input is PEM handle that
if strings.HasPrefix(key, "-----BEGIN") {
return getRSAPrivateKeyFromPEM([]byte(key))
func getPrivateKeysFromPEMData(pemData string) ([]crypto.PrivateKey, error) {
pemBlockString := pemData
var validPrivateKeys []crypto.PrivateKey
// if the input is base64, decode it
bs, err := base64.StdEncoding.DecodeString(pemBlockString)
if err == nil {
pemBlockString = string(bs)
}
bs = []byte(pemBlockString)
// assume input is base64 if not PEM
b64, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
for len(bs) > 0 {
inputLen := len(bs)
var block *pem.Block
block, bs = pem.Decode(bs)
if block == nil && len(bs) == 0 {
break
}
// should only happen if end of input is not a valid PEM block. See TestParseRSAPrivateKeyVariedPemInput.
if inputLen == len(bs) {
break
}
if block == nil {
continue
}
switch block.Type {
case blockTypeRSAPrivateKey:
parsedKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
validPrivateKeys = append(validPrivateKeys, parsedKey)
case blockTypePrivateKey:
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
validPrivateKeys = append(validPrivateKeys, parsedKey)
case blockTypeEcPrivateKey:
parsedKey, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, err
}
validPrivateKeys = append(validPrivateKeys, parsedKey)
}
}
return getRSAPrivateKeyFromPEM(b64)
}
func getRSAPrivateKeyFromPEM(pemBlocks []byte) (interface{}, error) {
// decode the pem into the Block struct
p, _ := pem.Decode(pemBlocks)
if p == nil {
return nil, fmt.Errorf("failed to parse PEM block containing the key")
}
// if the key is in PKCS1 format
if p.Type == blockTypeRSAPrivateKey {
return x509.ParsePKCS1PrivateKey(p.Bytes)
}
// if the key is in PKCS8 format
if p.Type == blockTypePrivateKey {
return x509.ParsePKCS8PrivateKey(p.Bytes)
}
// unsupported key format
return nil, fmt.Errorf("PEM block type is '%s', expected %s or %s", p.Type, blockTypeRSAPrivateKey,
blockTypePrivateKey)
return validPrivateKeys, nil
}
// addCACertsFromFile adds CA certificates from filePath into the given pool.
@@ -406,7 +550,7 @@ func addCACertsFromBytes(pool *x509.CertPool, pemBytes []byte) (*x509.CertPool,
return pool, nil
}
// addCACertsFromBytes adds CA certificates from the environment variable named
// addCACertsFromEnv adds CA certificates from the environment variable named
// by envName into the given pool. If pool is nil, it creates a new x509.CertPool.
// pool is returned.
func addCACertsFromEnv(pool *x509.CertPool, envName string) (*x509.CertPool, error) {
@@ -428,6 +572,60 @@ func readCertFromFile(localCertFile string) ([]byte, error) {
return certPEM, nil
}
// getTLSx509KeyPairFromString turns a certificate and a private key — each
// given as PEM, base64-encoded PEM, or raw DER — into a tls.Certificate.
//
// Idiom fix: prefix checks now use bytes.HasPrefix directly instead of
// converting each []byte to string for strings.HasPrefix, avoiding the
// per-call copies and using one consistent check throughout.
func getTLSx509KeyPairFromString(certPemBlock []byte, keyPemBlock []byte) (*tls.Certificate, error) {
	pemPrefix := []byte("-----BEGIN")

	// If an input is not PEM, first try to interpret it as base64-encoded PEM.
	if !bytes.HasPrefix(certPemBlock, pemPrefix) {
		s, err := base64.StdEncoding.DecodeString(string(certPemBlock))
		if err != nil {
			return nil, err
		}
		certPemBlock = s
	}

	if !bytes.HasPrefix(keyPemBlock, pemPrefix) {
		s, err := base64.StdEncoding.DecodeString(string(keyPemBlock))
		if err != nil {
			return nil, err
		}
		keyPemBlock = s
	}

	// Still not PEM: we assume it is a DER certificate and wrap it in PEM.
	if !bytes.HasPrefix(certPemBlock, pemPrefix) {
		pemBlock := &pem.Block{
			Type:  "CERTIFICATE",
			Bytes: certPemBlock,
		}
		var buf bytes.Buffer
		if err := pem.Encode(&buf, pemBlock); err != nil {
			return nil, err
		}
		certPemBlock = buf.Bytes()
	}

	// Likewise assume a DER key and wrap it in PEM.
	if !bytes.HasPrefix(keyPemBlock, pemPrefix) {
		pemBlock := &pem.Block{
			Type:  "PRIVATE KEY",
			Bytes: keyPemBlock,
		}
		var buf bytes.Buffer
		if err := pem.Encode(&buf, pemBlock); err != nil {
			return nil, err
		}
		keyPemBlock = buf.Bytes()
	}

	cert, err := tls.X509KeyPair(certPemBlock, keyPemBlock)
	if err != nil {
		return nil, err
	}
	return &cert, nil
}
// ReadKeyFromFile reads a key from file
func readKeyFromFile(localKeyFile string) ([]byte, error) {
// Read in the cert file

View File

@@ -13,7 +13,7 @@ import (
"net/url"
"strings"
ghodss "github.com/ghodss/yaml"
"sigs.k8s.io/yaml"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -232,7 +232,7 @@ func builtinYAMLMarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
return err
}
bs, err := ghodss.JSONToYAML(buf.Bytes())
bs, err := yaml.JSONToYAML(buf.Bytes())
if err != nil {
return err
}
@@ -247,7 +247,7 @@ func builtinYAMLUnmarshal(_ BuiltinContext, operands []*ast.Term, iter func(*ast
return err
}
bs, err := ghodss.YAMLToJSON([]byte(str))
bs, err := yaml.YAMLToJSON([]byte(str))
if err != nil {
return err
}
@@ -273,7 +273,7 @@ func builtinYAMLIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
}
var x interface{}
err = ghodss.Unmarshal([]byte(str), &x)
err = yaml.Unmarshal([]byte(str), &x)
return iter(ast.BooleanTerm(err == nil))
}

View File

@@ -29,6 +29,7 @@ type Error struct {
Code string `json:"code"`
Message string `json:"message"`
Location *ast.Location `json:"location,omitempty"`
err error `json:"-"`
}
const (
@@ -90,6 +91,15 @@ func (e *Error) Error() string {
return msg
}
// Wrap records err as the wrapped cause of e and returns e, enabling
// errors.Is/errors.As matching through Unwrap.
func (e *Error) Wrap(err error) *Error {
	e.err = err
	return e
}
// Unwrap returns the wrapped cause of e, if any, for errors.Is/errors.As.
func (e *Error) Unwrap() error {
	return e.err
}
func functionConflictErr(loc *ast.Location) error {
return &Error{
Code: ConflictErr,

View File

@@ -23,6 +23,8 @@ type evalIterator func(*eval) error
type unifyIterator func() error
type unifyRefIterator func(pos int) error
// queryIDFactory produces query identifiers from a counter.
type queryIDFactory struct {
	// curr is the counter state — presumably incremented per query;
	// NOTE(review): confirm against the factory's method elsewhere in file.
	curr uint64
}
@@ -1795,13 +1797,16 @@ type evalFunc struct {
func (e evalFunc) eval(iter unifyIterator) error {
// default functions aren't supported:
// https://github.com/open-policy-agent/opa/issues/2445
if len(e.ir.Rules) == 0 {
if e.ir.Empty() {
return nil
}
argCount := len(e.ir.Rules[0].Head.Args)
var argCount int
if len(e.ir.Rules) > 0 {
argCount = len(e.ir.Rules[0].Head.Args)
} else if e.ir.Default != nil {
argCount = len(e.ir.Default.Head.Args)
}
if len(e.ir.Else) > 0 && e.e.unknown(e.e.query[e.e.index], e.e.bindings) {
// Partial evaluation of ordered rules is not supported currently. Save the
@@ -1820,6 +1825,7 @@ func (e evalFunc) eval(iter unifyIterator) error {
return e.partialEvalSupport(argCount, iter)
}
}
return suppressEarlyExit(e.evalValue(iter, argCount, e.ir.EarlyExit))
}
@@ -1859,6 +1865,11 @@ func (e evalFunc) evalValue(iter unifyIterator, argCount int, findOne bool) erro
}
}
if e.ir.Default != nil && prev == nil {
_, err := e.evalOneRule(iter, e.ir.Default, cacheKey, prev, findOne)
return err
}
return nil
}
@@ -2269,6 +2280,14 @@ func (e evalVirtual) eval(iter unifyIterator) error {
switch ir.Kind {
case ast.MultiValue:
var empty *ast.Term
if ir.OnlyGroundRefs {
// rule ref contains no vars, so we're building a set
empty = ast.SetTerm()
} else {
// rule ref contains vars, so we're building an object containing a set leaf
empty = ast.ObjectTerm()
}
eval := evalVirtualPartial{
e: e.e,
ref: e.ref,
@@ -2278,12 +2297,10 @@ func (e evalVirtual) eval(iter unifyIterator) error {
bindings: e.bindings,
rterm: e.rterm,
rbindings: e.rbindings,
empty: ast.SetTerm(),
empty: empty,
}
return eval.eval(iter)
case ast.SingleValue:
// NOTE(sr): If we allow vars in others than the last position of a ref, we need
// to start reworking things here
if ir.OnlyGroundRefs {
eval := evalVirtualComplete{
e: e.e,
@@ -2350,9 +2367,31 @@ func (e evalVirtualPartial) eval(iter unifyIterator) error {
return e.evalEachRule(iter, unknown)
}
// maxRefLength returns the maximum length a ref can be without being longer
// than the longest rule ref in rules, capped at ceil. Multi-value rules
// count one extra position (for the set key).
func maxRefLength(rules []*ast.Rule, ceil int) int {
	longest := 0
	for _, r := range rules {
		length := len(r.Ref())
		if r.Head.RuleKind() == ast.MultiValue {
			length++
		}
		if length >= ceil {
			return ceil
		}
		if length > longest {
			longest = length
		}
	}
	return longest
}
func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error {
if e.e.unknown(e.ref[e.pos+1], e.bindings) {
if e.ir.Empty() {
return nil
}
m := maxRefLength(e.ir.Rules, len(e.ref))
if e.e.unknown(e.ref[e.pos+1:m], e.bindings) {
for _, rule := range e.ir.Rules {
if err := e.evalOneRulePostUnify(iter, rule); err != nil {
return err
@@ -2378,12 +2417,25 @@ func (e evalVirtualPartial) evalEachRule(iter unifyIterator, unknown bool) error
}
result := e.empty
var visitedRefs []ast.Ref
for _, rule := range e.ir.Rules {
if err := e.evalOneRulePreUnify(iter, rule, hint, result, unknown); err != nil {
result, err = e.evalOneRulePreUnify(iter, rule, result, unknown, &visitedRefs)
if err != nil {
return err
}
}
if hint.key != nil {
if v, err := result.Value.Find(hint.key[e.pos+1:]); err == nil && v != nil {
e.e.virtualCache.Put(hint.key, ast.NewTerm(v))
}
}
if !unknown {
return e.evalTerm(iter, e.pos+1, result, e.bindings)
}
return nil
}
@@ -2413,13 +2465,15 @@ func (e evalVirtualPartial) evalAllRules(iter unifyIterator, rules []*ast.Rule)
func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, error) {
result := e.empty
var visitedRefs []ast.Ref
for _, rule := range rules {
child := e.e.child(rule.Body)
child.traceEnter(rule)
err := child.eval(func(*eval) error {
child.traceExit(rule)
var err error
result, _, err = e.reduce(rule.Head, child.bindings, result)
result, _, err = e.reduce(rule, child.bindings, result, &visitedRefs)
if err != nil {
return err
}
@@ -2436,9 +2490,18 @@ func (e evalVirtualPartial) evalAllRulesNoCache(rules []*ast.Rule) (*ast.Term, e
return result, nil
}
func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, hint evalVirtualPartialCacheHint, result *ast.Term, unknown bool) error {
// wrapInObjects nests leaf inside one single-entry object per ref element.
// The objects are built leaf-to-root (innermost first) to preserve
// ground-ness of the terms.
func wrapInObjects(leaf *ast.Term, ref ast.Ref) *ast.Term {
	result := leaf
	for i := len(ref) - 1; i >= 0; i-- {
		result = ast.ObjectTerm(ast.Item(ref[i], result))
	}
	return result
}
func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Rule, result *ast.Term, unknown bool, visitedRefs *[]ast.Ref) (*ast.Term, error) {
key := e.ref[e.pos+1]
child := e.e.child(rule.Body)
child.traceEnter(rule)
@@ -2448,63 +2511,89 @@ func (e evalVirtualPartial) evalOneRulePreUnify(iter unifyIterator, rule *ast.Ru
if headKey == nil {
headKey = rule.Head.Reference[len(rule.Head.Reference)-1]
}
err := child.biunify(headKey, key, child.bindings, e.bindings, func() error {
// Walk the dynamic portion of rule ref and key to unify vars
err := child.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, child.bindings, func(pos int) error {
defined = true
return child.eval(func(child *eval) error {
child.traceExit(rule)
term := rule.Head.Value
if term == nil {
term = headKey
}
if hint.key != nil {
result := child.bindings.Plug(term)
e.e.virtualCache.Put(hint.key, result)
}
if unknown {
term, termbindings := child.bindings.apply(term)
// NOTE(tsandall): if the rule set depends on any unknowns then do
// not perform the duplicate check because evaluation of the ruleset
// may not produce a definitive result. This is a bit strict--we
// could improve by skipping only when saves occur.
if !unknown {
var dup bool
var err error
result, dup, err = e.reduce(rule.Head, child.bindings, result)
if rule.Head.RuleKind() == ast.MultiValue {
term = ast.SetTerm(term)
}
objRef := rule.Ref()[e.pos+1:]
term = wrapInObjects(term, objRef)
err := e.evalTerm(iter, e.pos+1, term, termbindings)
if err != nil {
return err
} else if dup {
}
} else {
var dup bool
var err error
result, dup, err = e.reduce(rule, child.bindings, result, visitedRefs)
if err != nil {
return err
} else if !unknown && dup {
child.traceDuplicate(rule)
return nil
}
}
child.traceExit(rule)
term, termbindings := child.bindings.apply(term)
err := e.evalTerm(iter, e.pos+2, term, termbindings)
if err != nil {
return err
}
child.traceRedo(rule)
return nil
})
})
if err != nil {
return err
return nil, err
}
// TODO(tsandall): why are we tracing here? this looks wrong.
if !defined {
child.traceFail(rule)
}
return nil
return result, nil
}
// biunifyRuleHead unifies the dynamic portion of a rule's ref with the query
// ref, and — for multi-value rules whose ref is no longer than the query ref —
// additionally unifies the rule's key with the next query ref term before
// continuing via iter with the position past everything consumed.
func (e *eval) biunifyRuleHead(pos int, ref ast.Ref, rule *ast.Rule, refBindings, ruleBindings *bindings, iter unifyRefIterator) error {
	return e.biunifyDynamicRef(pos, ref, rule.Ref(), refBindings, ruleBindings, func(pos int) error {
		// FIXME: Is there a simpler, more robust way of figuring out that we should biunify the rule key?
		if rule.Head.RuleKind() != ast.MultiValue || pos >= len(ref) || len(rule.Ref()) > len(ref) {
			return iter(pos)
		}
		headKey := rule.Head.Key
		if headKey == nil {
			// Ref-head rule with no explicit key: the last ref term acts as the key.
			headKey = rule.Head.Reference[len(rule.Head.Reference)-1]
		}
		return e.biunify(ref[pos], headKey, refBindings, ruleBindings, func() error {
			return iter(pos + 1)
		})
	})
}
// biunifyDynamicRef pairwise unifies a[pos:] with b[pos:] until either ref is
// exhausted, then invokes iter with the first position that was not unified.
func (e *eval) biunifyDynamicRef(pos int, a, b ast.Ref, b1, b2 *bindings, iter unifyRefIterator) error {
	if pos < len(a) && pos < len(b) {
		return e.biunify(a[pos], b[pos], b1, b2, func() error {
			return e.biunifyDynamicRef(pos+1, a, b, b1, b2, iter)
		})
	}
	return iter(pos)
}
func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.Rule) error {
key := e.ref[e.pos+1]
child := e.e.child(rule.Body)
child.traceEnter(rule)
@@ -2512,7 +2601,7 @@ func (e evalVirtualPartial) evalOneRulePostUnify(iter unifyIterator, rule *ast.R
err := child.eval(func(child *eval) error {
defined = true
return e.e.biunify(rule.Head.Key, key, child.bindings, e.bindings, func() error {
return e.e.biunifyRuleHead(e.pos+1, e.ref, rule, e.bindings, child.bindings, func(pos int) error {
return e.evalOneRuleContinue(iter, rule, child)
})
})
@@ -2538,7 +2627,15 @@ func (e evalVirtualPartial) evalOneRuleContinue(iter unifyIterator, rule *ast.Ru
}
term, termbindings := child.bindings.apply(term)
err := e.evalTerm(iter, e.pos+2, term, termbindings)
if rule.Head.RuleKind() == ast.MultiValue {
term = ast.SetTerm(term)
}
objRef := rule.Ref()[e.pos+1:]
term = wrapInObjects(term, objRef)
err := e.evalTerm(iter, e.pos+1, term, termbindings)
if err != nil {
return err
}
@@ -2597,17 +2694,32 @@ func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, path ast.Ref)
// Skip this rule body if it fails to type-check.
// Type-checking failure means the rule body will never succeed.
if e.e.compiler.PassesTypeCheck(plugged) {
var key, value *ast.Term
if rule.Head.Key != nil {
key = child.bindings.PlugNamespaced(rule.Head.Key, e.e.caller.bindings)
}
var value *ast.Term
if rule.Head.Value != nil {
value = child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings)
}
head := ast.NewHead(rule.Head.Name, key, value)
ref := e.e.namespaceRef(rule.Ref())
for i := 1; i < len(ref); i++ {
ref[i] = child.bindings.plugNamespaced(ref[i], e.e.caller.bindings)
}
pkg, ruleRef := splitPackageAndRule(ref)
head := ast.RefHead(ruleRef, value)
// key is also part of ref in single-value rules, and can be dropped
if rule.Head.Key != nil && rule.Head.RuleKind() == ast.MultiValue {
head.Key = child.bindings.PlugNamespaced(rule.Head.Key, e.e.caller.bindings)
}
if rule.Head.RuleKind() == ast.SingleValue && len(ruleRef) == 2 {
head.Key = ruleRef[len(ruleRef)-1]
}
if head.Name.Equal(ast.Var("")) && (len(ruleRef) == 1 || (len(ruleRef) == 2 && rule.Head.RuleKind() == ast.SingleValue)) {
head.Name = ruleRef[0].Value.(ast.Var)
}
if !e.e.inliningControl.shallow {
cp := copypropagation.New(head.Vars()).
@@ -2616,7 +2728,7 @@ func (e evalVirtualPartial) partialEvalSupportRule(rule *ast.Rule, path ast.Ref)
plugged = applyCopyPropagation(cp, e.e.instr, plugged)
}
e.e.saveSupport.Insert(path, &ast.Rule{
e.e.saveSupport.InsertByPkg(pkg, &ast.Rule{
Head: head,
Body: plugged,
Default: rule.Default,
@@ -2649,6 +2761,7 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac
var hint evalVirtualPartialCacheHint
if e.e.unknown(e.ref[:e.pos+1], e.bindings) {
// FIXME: Return empty hint if unknowns in any e.ref elem overlapping with applicable rule refs?
return hint, nil
}
@@ -2660,17 +2773,29 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac
plugged := e.bindings.Plug(e.ref[e.pos+1])
if plugged.IsGround() {
hint.key = append(e.plugged[:e.pos+1], plugged)
if _, ok := plugged.Value.(ast.Var); ok {
hint.full = true
hint.key = e.plugged[:e.pos+1]
e.e.instr.counterIncr(evalOpVirtualCacheMiss)
return hint, nil
}
m := maxRefLength(e.ir.Rules, len(e.ref))
for i := e.pos + 1; i < m; i++ {
plugged = e.bindings.Plug(e.ref[i])
if !plugged.IsGround() {
break
}
hint.key = append(e.plugged[:i], plugged)
if cached, _ := e.e.virtualCache.Get(hint.key); cached != nil {
e.e.instr.counterIncr(evalOpVirtualCacheHit)
hint.hit = true
return hint, e.evalTerm(iter, e.pos+2, cached, e.bindings)
return hint, e.evalTerm(iter, i+1, cached, e.bindings)
}
} else if _, ok := plugged.Value.(ast.Var); ok {
hint.full = true
hint.key = e.plugged[:e.pos+1]
}
e.e.instr.counterIncr(evalOpVirtualCacheMiss)
@@ -2678,26 +2803,99 @@ func (e evalVirtualPartial) evalCache(iter unifyIterator) (evalVirtualPartialCac
return hint, nil
}
func (e evalVirtualPartial) reduce(head *ast.Head, b *bindings, result *ast.Term) (*ast.Term, bool, error) {
// getNestedObject walks (and lazily creates) a chain of nested objects inside
// rootObj following the plugged terms of ref, returning the innermost object.
// If an intermediate key already holds a non-object value, the walk fails with
// an object-key-conflict error located at l.
func getNestedObject(ref ast.Ref, rootObj *ast.Object, b *bindings, l *ast.Location) (*ast.Object, error) {
	node := rootObj
	for i := 0; i < len(ref); i++ {
		key := b.Plug(ref[i])
		existing := (*node).Get(key)
		if existing == nil {
			// No entry yet: create an empty object and descend into it.
			next := ast.NewObject()
			(*node).Insert(key, ast.NewTerm(next))
			node = &next
			continue
		}
		obj, ok := existing.Value.(ast.Object)
		if !ok {
			// The path collides with a previously inserted scalar/non-object value.
			return nil, objectDocKeyConflictErr(l)
		}
		node = &obj
	}
	return node, nil
}
// hasCollisions reports whether the plugged path is a strict extension of any
// previously visited ref (which would mean two rules write overlapping object
// paths). When there is no collision the path is recorded in visitedRefs.
func hasCollisions(path ast.Ref, visitedRefs *[]ast.Ref, b *bindings) bool {
	plugged := b.Plug(ast.NewTerm(path)).Value.(ast.Ref)
	for _, seen := range *visitedRefs {
		if plugged.HasPrefix(seen) && !plugged.Equal(seen) {
			return true
		}
	}
	*visitedRefs = append(*visitedRefs, plugged)
	return false
}
func (e evalVirtualPartial) reduce(rule *ast.Rule, b *bindings, result *ast.Term, visitedRefs *[]ast.Ref) (*ast.Term, bool, error) {
var exists bool
head := rule.Head
switch v := result.Value.(type) {
case ast.Set: // MultiValue
case ast.Set:
key := b.Plug(head.Key)
exists = v.Contains(key)
v.Add(key)
case ast.Object: // SingleValue
key := head.Reference[len(head.Reference)-1] // NOTE(sr): multiple vars in ref heads need to deal with this better
key = b.Plug(key)
value := b.Plug(head.Value)
if curr := v.Get(key); curr != nil {
if !curr.Equal(value) {
return nil, false, objectDocKeyConflictErr(head.Location)
case ast.Object:
// data.p.q[r].s.t := 42 {...}
// |----|-|
// ^ ^
// | leafKey
// objPath
fullPath := rule.Ref()
collisionPath := fullPath[e.pos+1:]
if hasCollisions(collisionPath, visitedRefs, b) {
return nil, false, objectDocKeyConflictErr(head.Location)
}
objPath := fullPath[e.pos+1 : len(fullPath)-1] // the portion of the ref that generates nested objects
leafKey := b.Plug(fullPath[len(fullPath)-1]) // the portion of the ref that is the deepest nested key for the value
leafObj, err := getNestedObject(objPath, &v, b, head.Location)
if err != nil {
return nil, false, err
}
if kind := head.RuleKind(); kind == ast.SingleValue {
// We're inserting into an object
val := b.Plug(head.Value) // head.Value instance is shared between rule enumerations;but this is ok, as we don't allow rules to modify each others values.
if curr := (*leafObj).Get(leafKey); curr != nil {
if !curr.Equal(val) {
return nil, false, objectDocKeyConflictErr(head.Location)
}
exists = true
} else {
(*leafObj).Insert(leafKey, val)
}
exists = true
} else {
v.Insert(key, value)
// We're inserting into a set
var set *ast.Set
if leaf := (*leafObj).Get(leafKey); leaf != nil {
if s, ok := leaf.Value.(ast.Set); ok {
set = &s
} else {
return nil, false, objectDocKeyConflictErr(head.Location)
}
} else {
s := ast.NewSet()
(*leafObj).Insert(leafKey, ast.NewTerm(s))
set = &s
}
key := b.Plug(head.Key)
exists = (*set).Contains(key)
(*set).Add(key)
}
}
@@ -2916,15 +3114,8 @@ func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref
// Skip this rule body if it fails to type-check.
// Type-checking failure means the rule body will never succeed.
if e.e.compiler.PassesTypeCheck(plugged) {
var name ast.Var
switch ref := rule.Head.Ref().GroundPrefix(); len(ref) {
case 1:
name = ref[0].Value.(ast.Var)
default:
s := ref[len(ref)-1].Value.(ast.String)
name = ast.Var(s)
}
head := ast.NewHead(name, nil, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings))
pkg, ruleRef := splitPackageAndRule(path)
head := ast.RefHead(ruleRef, child.bindings.PlugNamespaced(rule.Head.Value, e.e.caller.bindings))
if !e.e.inliningControl.shallow {
cp := copypropagation.New(head.Vars()).
@@ -2933,7 +3124,7 @@ func (e evalVirtualComplete) partialEvalSupportRule(rule *ast.Rule, path ast.Ref
plugged = applyCopyPropagation(cp, e.e.instr, plugged)
}
e.e.saveSupport.Insert(path, &ast.Rule{
e.e.saveSupport.InsertByPkg(pkg, &ast.Rule{
Head: head,
Body: plugged,
Default: rule.Default,

View File

@@ -67,6 +67,7 @@ var allowedKeyNames = [...]string{
"force_cache_duration_seconds",
"raise_error",
"caching_mode",
"max_retry_attempts",
}
// ref: https://www.rfc-editor.org/rfc/rfc7231#section-6.1
@@ -104,6 +105,12 @@ const (
// HTTPSendNetworkErr represents a network error.
HTTPSendNetworkErr string = "eval_http_send_network_error"
// minRetryDelay is amount of time to backoff after the first failure.
minRetryDelay = time.Millisecond * 100
// maxRetryDelay is the upper bound of backoff delay.
maxRetryDelay = time.Second * 60
)
func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -161,6 +168,7 @@ func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) {
if resp == nil {
httpResp, err := reqExecutor.ExecuteHTTPRequest()
if err != nil {
reqExecutor.InsertErrorIntoCache(err)
return nil, err
}
defer util.Close(httpResp)
@@ -460,7 +468,7 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
case "cache", "caching_mode",
"force_cache", "force_cache_duration_seconds",
"force_json_decode", "force_yaml_decode",
"raise_error": // no-op
"raise_error", "max_retry_attempts": // no-op
default:
return nil, nil, fmt.Errorf("invalid parameter %q", key)
}
@@ -646,8 +654,39 @@ func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *htt
return req, client, nil
}
func executeHTTPRequest(req *http.Request, client *http.Client) (*http.Response, error) {
return client.Do(req)
// executeHTTPRequest performs the request, retrying on error up to the number
// of attempts given by the request object's "max_retry_attempts" field (0
// means a single attempt, no retries). Between attempts it sleeps with
// exponential backoff bounded by [minRetryDelay, maxRetryDelay], aborting
// early if the request's context is done.
func executeHTTPRequest(req *http.Request, client *http.Client, inputReqObj ast.Object) (*http.Response, error) {
	var err error
	var retry int
	// Invalid (non-integer or negative) values are rejected here.
	retry, err = getNumberValFromReqObj(inputReqObj, ast.StringTerm("max_retry_attempts"))
	if err != nil {
		return nil, err
	}
	for i := 0; true; i++ {
		var resp *http.Response
		resp, err = client.Do(req)
		if err == nil {
			return resp, nil
		}
		// final attempt
		if i == retry {
			break
		}
		// NOTE(review): client.Do typically wraps errors in *url.Error, so this
		// direct comparison may not match a wrapped context.Canceled — confirm
		// whether errors.Is semantics were intended here.
		if err == context.Canceled {
			return nil, err
		}
		// Back off before the next attempt, or bail out if the caller cancels.
		select {
		case <-time.After(util.DefaultBackoff(float64(minRetryDelay), float64(maxRetryDelay), i)):
		case <-req.Context().Done():
			return nil, context.Canceled
		}
	}
	// All attempts failed; return the last error observed.
	return nil, err
}
func isContentType(header http.Header, typ ...string) bool {
@@ -659,51 +698,121 @@ func isContentType(header http.Header, typ ...string) bool {
return false
}
// httpSendCacheEntry holds either a cached http.send response value or the
// error produced when the request failed; exactly one field is set.
type httpSendCacheEntry struct {
	response *ast.Value // cached successful response, nil if the call errored
	error    error      // cached failure, nil if the call succeeded
}
// The httpSendCache is used for intra-query caching of http.send results.
type httpSendCache struct {
	entries *util.HashMap // keyed by the ast.Value request object
}
// newHTTPSendCache returns an empty intra-query http.send cache backed by a
// hash map keyed on ast.Value equality.
func newHTTPSendCache() *httpSendCache {
	c := &httpSendCache{}
	c.entries = util.NewHashMap(valueEq, valueHash)
	return c
}
// valueHash adapts ast.Value hashing to the util.HashMap hash signature.
func valueHash(v util.T) int {
	val := v.(ast.Value)
	return val.Hash()
}
// valueEq adapts ast.Value comparison to the util.HashMap equality signature.
func valueEq(a, b util.T) bool {
	return a.(ast.Value).Compare(b.(ast.Value)) == 0
}
// get returns a pointer to a copy of the cached entry for k, or nil when no
// entry exists.
func (cache *httpSendCache) get(k ast.Value) *httpSendCacheEntry {
	raw, ok := cache.entries.Get(k)
	if !ok {
		return nil
	}
	entry := raw.(httpSendCacheEntry)
	return &entry
}
// putResponse caches a successful http.send response for key k, replacing any
// prior entry.
func (cache *httpSendCache) putResponse(k ast.Value, v *ast.Value) {
	cache.entries.Put(k, httpSendCacheEntry{response: v})
}
// putError caches a failed http.send outcome for key k so the same error is
// returned consistently within the query, replacing any prior entry.
func (cache *httpSendCache) putError(k ast.Value, v error) {
	cache.entries.Put(k, httpSendCacheEntry{error: v})
}
// In the BuiltinContext cache we only store a single entry that points to
// our ValueMap which is the "real" http.send() cache.
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
func getHTTPSendCache(bctx BuiltinContext) *httpSendCache {
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
if !ok {
// Initialize if it isn't there
cache := ast.NewValueMap()
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
return cache
c := newHTTPSendCache()
bctx.Cache.Put(httpSendBuiltinCacheKey, c)
return c
}
cache, ok := raw.(*ast.ValueMap)
c, ok := raw.(*httpSendCache)
if !ok {
return nil
}
return cache
return c
}
// checkHTTPSendCache checks for the given key's value in the cache
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) (ast.Value, error) {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
return nil
return nil, nil
}
return requestCache.Get(key)
v := requestCache.get(key)
if v != nil {
if v.error != nil {
return nil, v.error
}
if v.response != nil {
return *v.response, nil
}
// This should never happen
}
return nil, nil
}
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
requestCache := getHTTPSendCache(bctx)
if requestCache == nil {
// Should never happen.. if it does just skip caching the value
// FIXME: return error instead, to prevent inconsistencies?
return
}
requestCache.Put(key, value)
requestCache.putResponse(key, &value)
}
// insertErrorIntoHTTPSendCache records err in the intra-query http.send cache
// for the given request key, so a repeated identical request within the same
// query observes the same failure.
func insertErrorIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, err error) {
	requestCache := getHTTPSendCache(bctx)
	if requestCache == nil {
		// Should never happen.. if it does just skip caching the value
		// FIXME: return error instead, to prevent inconsistencies?
		return
	}
	requestCache.putError(key, err)
}
// checkHTTPSendInterQueryCache checks for the given key's value in the inter-query cache
func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
requestCache := c.bctx.InterQueryBuiltinCache
value, found := requestCache.Get(c.key)
cachedValue, found := requestCache.Get(c.key)
if !found {
return nil, nil
}
value, cerr := requestCache.Clone(cachedValue)
if cerr != nil {
return nil, handleHTTPSendErr(c.bctx, cerr)
}
c.bctx.Metrics.Counter(httpSendInterQueryCacheHits).Incr()
var cachedRespData *interQueryCacheData
@@ -730,15 +839,12 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
return nil, handleHTTPSendErr(c.bctx, err)
}
headers, err := parseResponseHeaders(cachedRespData.Headers)
if err != nil {
return nil, err
}
headers := parseResponseHeaders(cachedRespData.Headers)
// check with the server if the stale response is still up-to-date.
// If server returns a new response (ie. status_code=200), update the cache with the new response
// If server returns an unmodified response (ie. status_code=304), update the headers for the existing response
result, modified, err := revalidateCachedResponse(c.httpReq, c.httpClient, headers)
result, modified, err := revalidateCachedResponse(c.httpReq, c.httpClient, c.key, headers)
requestCache.Delete(c.key)
if err != nil || result == nil {
return nil, err
@@ -755,11 +861,16 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
}
}
expiresAt, err := expiryFromHeaders(result.Header)
if err != nil {
return nil, err
if forceCaching(c.forceCacheParams) {
createdAt := getCurrentTime(c.bctx)
cachedRespData.ExpiresAt = createdAt.Add(time.Second * time.Duration(c.forceCacheParams.forceCacheDurationSeconds))
} else {
expiresAt, err := expiryFromHeaders(result.Header)
if err != nil {
return nil, err
}
cachedRespData.ExpiresAt = expiresAt
}
cachedRespData.ExpiresAt = expiresAt
cachingMode, err := getCachingMode(c.key)
if err != nil {
@@ -777,7 +888,7 @@ func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
pcv = cachedRespData
}
c.bctx.InterQueryBuiltinCache.Insert(c.key, pcv)
c.bctx.InterQueryBuiltinCache.InsertWithExpiry(c.key, pcv, cachedRespData.ExpiresAt)
return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode)
}
@@ -813,18 +924,19 @@ func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp
}
var pcv cache.InterQueryCacheValue
var pcvData *interQueryCacheData
if cachingMode == defaultCachingMode {
pcv, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams)
pcv, pcvData, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams)
} else {
pcv, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams)
pcvData, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams)
pcv = pcvData
}
if err != nil {
return err
}
requestCache.Insert(key, pcv)
requestCache.InsertWithExpiry(key, pcv, pcvData.ExpiresAt)
return nil
}
@@ -879,6 +991,23 @@ func getBoolValFromReqObj(req ast.Object, key *ast.Term) (bool, error) {
return bool(b), nil
}
// getNumberValFromReqObj reads an optional non-negative integer field from the
// request object. A missing field yields 0 with no error; a non-number,
// non-integer, or negative value yields an error.
func getNumberValFromReqObj(req ast.Object, key *ast.Term) (int, error) {
	term := req.Get(key)
	if term == nil {
		return 0, nil
	}
	t, ok := term.Value.(ast.Number)
	if !ok {
		return 0, fmt.Errorf("invalid value %v for field %v", term.String(), key.String())
	}
	num, intOK := t.Int()
	if !intOK || num < 0 {
		return 0, fmt.Errorf("invalid value %v for field %v", t.String(), key.String())
	}
	return num, nil
}
func getCachingMode(req ast.Object) (cachingMode, error) {
key := ast.StringTerm("caching_mode")
var s ast.String
@@ -902,17 +1031,23 @@ type interQueryCacheValue struct {
Data []byte
}
func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, error) {
func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, *interQueryCacheData, error) {
data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams)
if err != nil {
return nil, err
return nil, nil, err
}
b, err := json.Marshal(data)
if err != nil {
return nil, err
return nil, nil, err
}
return &interQueryCacheValue{Data: b}, nil
return &interQueryCacheValue{Data: b}, data, nil
}
// Clone returns a deep copy of the cache value so the caller can mutate it
// without affecting the cached original.
func (cb interQueryCacheValue) Clone() (cache.InterQueryCacheValue, error) {
	dup := append(make([]byte, 0, len(cb.Data)), cb.Data...)
	return &interQueryCacheValue{Data: dup}, nil
}
func (cb interQueryCacheValue) SizeInBytes() int64 {
@@ -998,44 +1133,38 @@ func (c *interQueryCacheData) SizeInBytes() int64 {
return 0
}
// Clone returns a deep copy of the cached response data: the body bytes and
// the header map are duplicated so the copy is independent of the original.
func (c *interQueryCacheData) Clone() (cache.InterQueryCacheValue, error) {
	body := append(make([]byte, 0, len(c.RespBody)), c.RespBody...)
	clone := interQueryCacheData{
		ExpiresAt:  c.ExpiresAt,
		RespBody:   body,
		Status:     c.Status,
		StatusCode: c.StatusCode,
		Headers:    c.Headers.Clone(),
	}
	return &clone, nil
}
type responseHeaders struct {
date time.Time // origination date and time of response
cacheControl map[string]string // response cache-control header
maxAge deltaSeconds // max-age cache control directive
expires time.Time // date/time after which the response is considered stale
etag string // identifier for a specific version of the response
lastModified string // date and time response was last modified as per origin server
etag string // identifier for a specific version of the response
lastModified string // date and time response was last modified as per origin server
}
// deltaSeconds specifies a non-negative integer, representing
// time in seconds: http://tools.ietf.org/html/rfc7234#section-1.2.1
type deltaSeconds int32
func parseResponseHeaders(headers http.Header) (*responseHeaders, error) {
var err error
func parseResponseHeaders(headers http.Header) *responseHeaders {
result := responseHeaders{}
result.date, err = getResponseHeaderDate(headers)
if err != nil {
return nil, err
}
result.cacheControl = parseCacheControlHeader(headers)
result.maxAge, err = parseMaxAgeCacheDirective(result.cacheControl)
if err != nil {
return nil, err
}
result.expires = getResponseHeaderExpires(headers)
result.etag = headers.Get("etag")
result.lastModified = headers.Get("last-modified")
return &result, nil
return &result
}
func revalidateCachedResponse(req *http.Request, client *http.Client, headers *responseHeaders) (*http.Response, bool, error) {
func revalidateCachedResponse(req *http.Request, client *http.Client, inputReqObj ast.Object, headers *responseHeaders) (*http.Response, bool, error) {
etag := headers.etag
lastModified := headers.lastModified
@@ -1053,7 +1182,7 @@ func revalidateCachedResponse(req *http.Request, client *http.Client, headers *r
cloneReq.Header.Set("if-modified-since", lastModified)
}
response, err := client.Do(cloneReq)
response, err := executeHTTPRequest(cloneReq, client, inputReqObj)
if err != nil {
return nil, false, err
}
@@ -1233,6 +1362,7 @@ func getResponseHeaders(headers http.Header) map[string]interface{} {
type httpRequestExecutor interface {
CheckCache() (ast.Value, error)
InsertIntoCache(value *http.Response) (ast.Value, error)
InsertErrorIntoCache(err error)
ExecuteHTTPRequest() (*http.Response, error)
}
@@ -1268,6 +1398,15 @@ func newInterQueryCache(bctx BuiltinContext, key ast.Object, forceCacheParams *f
func (c *interQueryCache) CheckCache() (ast.Value, error) {
var err error
// Checking the intra-query cache first ensures consistency of errors and HTTP responses within a query.
resp, err := checkHTTPSendCache(c.bctx, c.key)
if err != nil {
return nil, err
}
if resp != nil {
return resp, nil
}
c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode"))
if err != nil {
return nil, handleHTTPSendErr(c.bctx, err)
@@ -1277,14 +1416,14 @@ func (c *interQueryCache) CheckCache() (ast.Value, error) {
return nil, handleHTTPSendErr(c.bctx, err)
}
resp, err := c.checkHTTPSendInterQueryCache()
// fallback to the http send cache if response not found in the inter-query cache or inter-query cache look-up results
// in an error
if resp == nil || err != nil {
return checkHTTPSendCache(c.bctx, c.key), nil
resp, err = c.checkHTTPSendInterQueryCache()
// Always insert the result of the inter-query cache into the intra-query cache, to maintain consistency within the same query.
if err != nil {
insertErrorIntoHTTPSendCache(c.bctx, c.key, err)
}
if resp != nil {
insertIntoHTTPSendCache(c.bctx, c.key, resp)
}
return resp, err
}
@@ -1295,14 +1434,19 @@ func (c *interQueryCache) InsertIntoCache(value *http.Response) (ast.Value, erro
return nil, handleHTTPSendErr(c.bctx, err)
}
// fallback to the http send cache if error encountered while inserting response in inter-query cache
err = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams)
if err != nil {
insertIntoHTTPSendCache(c.bctx, c.key, result)
}
// Always insert into the intra-query cache, to maintain consistency within the same query.
insertIntoHTTPSendCache(c.bctx, c.key, result)
// We ignore errors when populating the inter-query cache, because we've already populated the intra-cache,
// and query consistency is our primary concern.
_ = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams)
return result, nil
}
// InsertErrorIntoCache records err in the intra-query cache under this
// request's key, keeping error results consistent within the same query.
func (c *interQueryCache) InsertErrorIntoCache(err error) {
	insertErrorIntoHTTPSendCache(c.bctx, c.key, err)
}
// ExecuteHTTPRequest executes a HTTP request
func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
var err error
@@ -1311,7 +1455,7 @@ func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
return nil, handleHTTPSendErr(c.bctx, err)
}
return executeHTTPRequest(c.httpReq, c.httpClient)
return executeHTTPRequest(c.httpReq, c.httpClient, c.key)
}
type intraQueryCache struct {
@@ -1325,7 +1469,7 @@ func newIntraQueryCache(bctx BuiltinContext, key ast.Object) (*intraQueryCache,
// CheckCache checks the cache for the value of the key set on this object
func (c *intraQueryCache) CheckCache() (ast.Value, error) {
return checkHTTPSendCache(c.bctx, c.key), nil
return checkHTTPSendCache(c.bctx, c.key)
}
// InsertIntoCache inserts the key set on this object into the cache with the given value
@@ -1351,13 +1495,17 @@ func (c *intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, erro
return result, nil
}
// InsertErrorIntoCache records err in the intra-query cache under this
// request's key, keeping error results consistent within the same query.
func (c *intraQueryCache) InsertErrorIntoCache(err error) {
	insertErrorIntoHTTPSendCache(c.bctx, c.key, err)
}
// ExecuteHTTPRequest executes a HTTP request
func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
httpReq, httpClient, err := createHTTPRequest(c.bctx, c.key)
if err != nil {
return nil, handleHTTPSendErr(c.bctx, err)
}
return executeHTTPRequest(httpReq, httpClient)
return executeHTTPRequest(httpReq, httpClient, c.key)
}
func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) {

View File

@@ -11,10 +11,11 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/internal/edittree"
)
func builtinJSONRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Expect an object and a string or array/set of strings
_, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
@@ -116,7 +117,6 @@ func jsonRemove(a *ast.Term, b *ast.Term) (*ast.Term, error) {
}
func builtinJSONFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
// Ensure we have the right parameters, expect an object and a string or array/set of strings
obj, err := builtins.ObjectOperand(operands[0].Value, 1)
if err != nil {
@@ -196,7 +196,6 @@ func parsePath(path *ast.Term) (ast.Ref, error) {
}
func pathsToObject(paths []ast.Ref) ast.Object {
root := ast.NewObject()
for _, path := range paths {
@@ -239,288 +238,147 @@ func pathsToObject(paths []ast.Ref) ast.Object {
return root
}
// toIndex tries to convert path elements (that may be strings) into indices into
// an array.
func toIndex(arr *ast.Array, term *ast.Term) (int, error) {
i := 0
// jsonPatch is one decoded JSON-Patch (RFC 6902) operation.
type jsonPatch struct {
	op    string    // one of: add, remove, replace, move, copy, test
	path  *ast.Term // target location of the operation
	from  *ast.Term // source location; set only for move/copy
	value *ast.Term // operand value; set only for add/replace/test
}
func getPatch(o ast.Object) (jsonPatch, error) {
validOps := map[string]struct{}{"add": {}, "remove": {}, "replace": {}, "move": {}, "copy": {}, "test": {}}
var out jsonPatch
var ok bool
switch v := term.Value.(type) {
case ast.Number:
if i, ok = v.Int(); !ok {
return 0, fmt.Errorf("Invalid number type for indexing")
getAttribute := func(attr string) (*ast.Term, error) {
if term := o.Get(ast.StringTerm(attr)); term != nil {
return term, nil
}
case ast.String:
if v == "-" {
return arr.Len(), nil
}
num := ast.Number(v)
if i, ok = num.Int(); !ok {
return 0, fmt.Errorf("Invalid string for indexing")
}
if v != "0" && strings.HasPrefix(string(v), "0") {
return 0, fmt.Errorf("Leading zeros are not allowed in JSON paths")
}
default:
return 0, fmt.Errorf("Invalid type for indexing")
return nil, fmt.Errorf("missing '%s' attribute", attr)
}
return i, nil
}
// patchWorkerris a worker that modifies a direct child of a term located
// at the given key. It returns the new term, and optionally a result that
// is passed back to the caller.
type patchWorker = func(parent, key *ast.Term) (updated, result *ast.Term)
func jsonPatchTraverse(
target *ast.Term,
path ast.Ref,
worker patchWorker,
) (*ast.Term, *ast.Term) {
if len(path) < 1 {
return nil, nil
opTerm, err := getAttribute("op")
if err != nil {
return out, err
}
op, ok := opTerm.Value.(ast.String)
if !ok {
return out, fmt.Errorf("attribute 'op' must be a string")
}
out.op = string(op)
if _, found := validOps[out.op]; !found {
out.op = ""
return out, fmt.Errorf("unrecognized op '%s'", string(op))
}
key := path[0]
if len(path) == 1 {
return worker(target, key)
pathTerm, err := getAttribute("path")
if err != nil {
return out, err
}
out.path = pathTerm
success := false
var updated, result *ast.Term
switch parent := target.Value.(type) {
case ast.Object:
obj := ast.NewObject()
parent.Foreach(func(k, v *ast.Term) {
if k.Equal(key) {
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
obj.Insert(k, v)
success = true
}
} else {
obj.Insert(k, v)
}
})
updated = ast.NewTerm(obj)
case *ast.Array:
idx, err := toIndex(parent, key)
// Only fetch the "from" parameter for move/copy ops.
switch out.op {
case "move", "copy":
fromTerm, err := getAttribute("from")
if err != nil {
return nil, nil
return out, err
}
arr := ast.NewArray()
for i := 0; i < parent.Len(); i++ {
v := parent.Elem(i)
if idx == i {
if v, result = jsonPatchTraverse(v, path[1:], worker); v != nil {
arr = arr.Append(v)
success = true
}
} else {
arr = arr.Append(v)
out.from = fromTerm
}
// Only fetch the "value" parameter for add/replace/test ops.
switch out.op {
case "add", "replace", "test":
valueTerm, err := getAttribute("value")
if err != nil {
return out, err
}
out.value = valueTerm
}
return out, nil
}
func applyPatches(source *ast.Term, operations *ast.Array) (*ast.Term, error) {
et := edittree.NewEditTree(source)
for i := 0; i < operations.Len(); i++ {
object, ok := operations.Elem(i).Value.(ast.Object)
if !ok {
return nil, fmt.Errorf("must be an array of JSON-Patch objects, but at least one element is not an object")
}
patch, err := getPatch(object)
if err != nil {
return nil, err
}
path, err := parsePath(patch.path)
if err != nil {
return nil, err
}
switch patch.op {
case "add":
_, err = et.InsertAtPath(path, patch.value)
if err != nil {
return nil, err
}
case "remove":
_, err = et.DeleteAtPath(path)
if err != nil {
return nil, err
}
case "replace":
_, err = et.DeleteAtPath(path)
if err != nil {
return nil, err
}
_, err = et.InsertAtPath(path, patch.value)
if err != nil {
return nil, err
}
case "move":
from, err := parsePath(patch.from)
if err != nil {
return nil, err
}
chunk, err := et.RenderAtPath(from)
if err != nil {
return nil, err
}
_, err = et.DeleteAtPath(from)
if err != nil {
return nil, err
}
_, err = et.InsertAtPath(path, chunk)
if err != nil {
return nil, err
}
case "copy":
from, err := parsePath(patch.from)
if err != nil {
return nil, err
}
chunk, err := et.RenderAtPath(from)
if err != nil {
return nil, err
}
_, err = et.InsertAtPath(path, chunk)
if err != nil {
return nil, err
}
case "test":
chunk, err := et.RenderAtPath(path)
if err != nil {
return nil, err
}
if !chunk.Equal(patch.value) {
return nil, fmt.Errorf("value from EditTree != patch value.\n\nExpected: %v\n\nFound: %v", patch.value, chunk)
}
}
updated = ast.NewTerm(arr)
case ast.Set:
set := ast.NewSet()
parent.Foreach(func(k *ast.Term) {
if k.Equal(key) {
if k, result = jsonPatchTraverse(k, path[1:], worker); k != nil {
set.Add(k)
success = true
}
} else {
set.Add(k)
}
})
updated = ast.NewTerm(set)
}
if success {
return updated, result
}
return nil, nil
}
// jsonPatchGet goes one step further than jsonPatchTraverse and returns the
// term at the location specified by the path. It is used in functions
// where we want to read a value but not manipulate its parent: for example
// jsonPatchTest and jsonPatchCopy.
//
// Because it uses jsonPatchTraverse, it makes shallow copies of the objects
// along the path. We could possibly add a signaling mechanism that we didn't
// make any changes to avoid this.
func jsonPatchGet(target *ast.Term, path ast.Ref) *ast.Term {
	// Special case: get entire document.
	if len(path) == 0 {
		return target
	}
	// Traverse to the parent of the addressed element and read the child;
	// returning the unmodified parent keeps the traversal "successful" so the
	// looked-up result propagates back out. Returns nil if the path is absent.
	_, result := jsonPatchTraverse(target, path, func(parent, key *ast.Term) (*ast.Term, *ast.Term) {
		switch v := parent.Value.(type) {
		case ast.Object:
			return parent, v.Get(key)
		case *ast.Array:
			i, err := toIndex(v, key)
			if err == nil {
				return parent, v.Elem(i)
			}
		case ast.Set:
			// Set membership: the key itself is the value when present.
			if v.Contains(key) {
				return parent, key
			}
		}
		// Unsupported container type or bad index: lookup fails.
		return nil, nil
	})
	return result
}
// jsonPatchAdd implements the JSON-Patch "add" operation on target at path,
// returning the updated document (nil when the path cannot be applied).
// Containers along the path are shallow-copied rather than mutated in place.
func jsonPatchAdd(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
	// Special case: replacing root document.
	if len(path) == 0 {
		return value
	}
	target, _ = jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
		switch original := parent.Value.(type) {
		case ast.Object:
			// Copy the object, then insert/overwrite the key.
			obj := ast.NewObject()
			original.Foreach(func(k, v *ast.Term) {
				obj.Insert(k, v)
			})
			obj.Insert(key, value)
			return ast.NewTerm(obj), nil
		case *ast.Array:
			// Insert at idx; idx == len appends (the "-" index resolves there).
			idx, err := toIndex(original, key)
			if err != nil || idx < 0 || idx > original.Len() {
				return nil, nil
			}
			arr := ast.NewArray()
			for i := 0; i < idx; i++ {
				arr = arr.Append(original.Elem(i))
			}
			arr = arr.Append(value)
			for i := idx; i < original.Len(); i++ {
				arr = arr.Append(original.Elem(i))
			}
			return ast.NewTerm(arr), nil
		case ast.Set:
			// Sets have no keys: adding only makes sense when key == value.
			if !key.Equal(value) {
				return nil, nil
			}
			set := ast.NewSet()
			original.Foreach(func(k *ast.Term) {
				set.Add(k)
			})
			set.Add(key)
			return ast.NewTerm(set), nil
		}
		// Scalars cannot contain children; the add fails.
		return nil, nil
	})
	return target
}
// jsonPatchRemove implements the JSON-Patch "remove" operation. It returns
// the rewritten document and the removed element, or (nil, nil) when the
// path cannot be resolved or nothing was removed.
func jsonPatchRemove(target *ast.Term, path ast.Ref) (*ast.Term, *ast.Term) {
	// Special case: replacing root document.
	if len(path) == 0 {
		return nil, nil
	}
	target, removed := jsonPatchTraverse(target, path, func(parent *ast.Term, key *ast.Term) (*ast.Term, *ast.Term) {
		// NOTE: this inner `removed` deliberately shadows the outer one;
		// it is the second return value of the closure.
		var removed *ast.Term
		switch original := parent.Value.(type) {
		case ast.Object:
			// Shallow-copy every entry except the one being removed.
			obj := ast.NewObject()
			original.Foreach(func(k, v *ast.Term) {
				if k.Equal(key) {
					removed = v
				} else {
					obj.Insert(k, v)
				}
			})
			return ast.NewTerm(obj), removed
		case *ast.Array:
			idx, err := toIndex(original, key)
			if err != nil || idx < 0 || idx >= original.Len() {
				return nil, nil
			}
			// Rebuild the array, skipping the element at idx.
			arr := ast.NewArray()
			for i := 0; i < idx; i++ {
				arr = arr.Append(original.Elem(i))
			}
			removed = original.Elem(idx)
			for i := idx + 1; i < original.Len(); i++ {
				arr = arr.Append(original.Elem(i))
			}
			return ast.NewTerm(arr), removed
		case ast.Set:
			// Shallow-copy the set without the removed element.
			set := ast.NewSet()
			original.Foreach(func(k *ast.Term) {
				if k.Equal(key) {
					removed = k
				} else {
					set.Add(k)
				}
			})
			return ast.NewTerm(set), removed
		}
		return nil, nil
	})
	// Only report success when both the rewrite succeeded and an element
	// was actually removed.
	if target != nil && removed != nil {
		return target, removed
	}
	return nil, nil
}
// jsonPatchReplace implements the JSON-Patch "replace" operation, which
// RFC 6902 defines as a "remove" followed by an "add" at the same path.
// Returns nil when the path does not exist.
func jsonPatchReplace(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
	// An empty path replaces the whole document.
	if len(path) == 0 {
		return value
	}
	stripped, _ := jsonPatchRemove(target, path)
	if stripped == nil {
		return nil
	}
	return jsonPatchAdd(stripped, path, value)
}
// jsonPatchMove implements the JSON-Patch "move" operation: "remove" at from,
// then "add" of the removed value at path. Returns nil on failure.
func jsonPatchMove(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
	remainder, moved := jsonPatchRemove(target, from)
	if remainder == nil || moved == nil {
		return nil
	}
	return jsonPatchAdd(remainder, path, moved)
}
// jsonPatchCopy implements the JSON-Patch "copy" operation: read the value at
// from and "add" it at path. Returns nil when from does not resolve.
func jsonPatchCopy(target *ast.Term, path ast.Ref, from ast.Ref) *ast.Term {
	if value := jsonPatchGet(target, from); value != nil {
		return jsonPatchAdd(target, path, value)
	}
	return nil
}
// jsonPatchTest implements the JSON-Patch "test" operation: it succeeds
// (returning target unchanged) iff the value at path equals value, and
// returns nil otherwise.
//
// Fix: removed unreachable statements after the final return that referenced
// an undefined variable (`et`) — leftover residue that made the function
// invalid and could never execute.
func jsonPatchTest(target *ast.Term, path ast.Ref, value *ast.Term) *ast.Term {
	actual := jsonPatchGet(target, path)
	if actual == nil {
		return nil
	}
	if actual.Equal(value) {
		return target
	}
	return nil
}
func builtinJSONPatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -533,93 +391,11 @@ func builtinJSONPatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter
return err
}
// Apply operations one by one.
for i := 0; i < operations.Len(); i++ {
if object, ok := operations.Elem(i).Value.(ast.Object); ok {
getAttribute := func(attr string) (*ast.Term, error) {
if term := object.Get(ast.StringTerm(attr)); term != nil {
return term, nil
}
return nil, builtins.NewOperandErr(2, fmt.Sprintf("patch is missing '%s' attribute", attr))
}
getPathAttribute := func(attr string) (ast.Ref, error) {
term, err := getAttribute(attr)
if err != nil {
return ast.Ref{}, err
}
path, err := parsePath(term)
if err != nil {
return ast.Ref{}, err
}
return path, nil
}
// Parse operation.
opTerm, err := getAttribute("op")
if err != nil {
return err
}
op, ok := opTerm.Value.(ast.String)
if !ok {
return builtins.NewOperandErr(2, "patch attribute 'op' must be a string")
}
// Parse path.
path, err := getPathAttribute("path")
if err != nil {
return err
}
switch op {
case "add":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchAdd(target, path, value)
case "remove":
target, _ = jsonPatchRemove(target, path)
case "replace":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchReplace(target, path, value)
case "move":
from, err := getPathAttribute("from")
if err != nil {
return err
}
target = jsonPatchMove(target, path, from)
case "copy":
from, err := getPathAttribute("from")
if err != nil {
return err
}
target = jsonPatchCopy(target, path, from)
case "test":
value, err := getAttribute("value")
if err != nil {
return err
}
target = jsonPatchTest(target, path, value)
default:
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
}
} else {
return builtins.NewOperandErr(2, "must be an array of JSON-Patch objects")
}
// JSON patches should work atomically; and if one of them fails,
// we should not try to continue.
if target == nil {
return nil
}
patched, err := applyPatches(target, operations)
if err != nil {
return nil
}
return iter(target)
return iter(patched)
}
func init() {

View File

@@ -0,0 +1,109 @@
// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"encoding/json"
"errors"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/gojsonschema"
)
// astValueToJSONSchemaLoader converts an ast.Value into a gojsonschema
// JSONLoader. Strings are treated as raw JSON text (pre-validated, since
// gojsonschema does not check this itself); objects are serialized to their
// JSON representation. Any other type yields an error.
func astValueToJSONSchemaLoader(value ast.Value) (gojsonschema.JSONLoader, error) {
	switch v := value.(type) {
	case ast.String:
		// Reject malformed JSON up front because gojsonschema won't.
		if !json.Valid([]byte(v)) {
			return nil, errors.New("invalid JSON string")
		}
		return gojsonschema.NewStringLoader(string(v)), nil
	case ast.Object:
		data, err := ast.JSON(value)
		if err != nil {
			return nil, err
		}
		return gojsonschema.NewGoLoader(data), nil
	}
	return nil, errors.New("wrong type, expected string or object")
}
// newResultTerm builds the standard two-element [valid, data] result array
// returned by the JSON schema builtins below.
func newResultTerm(valid bool, data *ast.Term) *ast.Term {
	return ast.ArrayTerm(ast.BooleanTerm(valid), data)
}
// builtinJSONSchemaVerify accepts a single operand (string or object) and
// checks whether it is a valid JSON schema. It always yields a result term:
// [false, "jsonschema: <error>"] on failure, [true, null] on success; it
// never raises an evaluation error for bad input.
func builtinJSONSchemaVerify(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	fail := func(err error) error {
		return iter(newResultTerm(false, ast.StringTerm("jsonschema: "+err.Error())))
	}
	// Build a JSON loader from the operand.
	loader, err := astValueToJSONSchemaLoader(operands[0].Value)
	if err != nil {
		return fail(err)
	}
	// Ensure the schema itself compiles without errors.
	if _, err := gojsonschema.NewSchema(loader); err != nil {
		return fail(err)
	}
	return iter(newResultTerm(true, ast.NullTerm()))
}
// builtinJSONMatchSchema accepts two operands (each a string or object) and
// validates the first (the document) against the second (the schema). It
// yields [match, errors]: match is a boolean, and errors is an array of
// objects describing each validation failure (empty on success). Internal
// errors (bad operand types, schema compile failure) are raised as
// evaluation errors instead.
func builtinJSONMatchSchema(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	// Document to validate, from a Rego string or object.
	documentLoader, err := astValueToJSONSchemaLoader(operands[0].Value)
	if err != nil {
		return err
	}
	// Schema to validate against, from a Rego string or object.
	schemaLoader, err := astValueToJSONSchemaLoader(operands[1].Value)
	if err != nil {
		return err
	}
	result, err := gojsonschema.Validate(schemaLoader, documentLoader)
	if err != nil {
		return err
	}
	// Convert each validation error into a descriptive Rego object.
	errs := ast.NewArray()
	for _, re := range result.Errors() {
		detail := ast.NewObject(
			[2]*ast.Term{ast.StringTerm("error"), ast.StringTerm(re.String())},
			[2]*ast.Term{ast.StringTerm("type"), ast.StringTerm(re.Type())},
			[2]*ast.Term{ast.StringTerm("field"), ast.StringTerm(re.Field())},
			[2]*ast.Term{ast.StringTerm("desc"), ast.StringTerm(re.Description())},
		)
		errs = errs.Append(ast.NewTerm(detail))
	}
	return iter(newResultTerm(result.Valid(), ast.NewTerm(errs)))
}
// init registers the JSON schema builtins with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.JSONSchemaVerify.Name, builtinJSONSchemaVerify)
	RegisterBuiltinFunc(ast.JSONMatchSchema.Name, builtinJSONMatchSchema)
}

View File

@@ -28,32 +28,71 @@ func builtinNumbersRange(bctx BuiltinContext, operands []*ast.Term, iter func(*a
return err
}
result := ast.NewArray()
ast, err := generateRange(bctx, x, y, one, "numbers.range")
if err != nil {
return err
}
return iter(ast)
}
// builtinNumbersRangeStep implements numbers.range_step(start, stop, step).
// The step must be strictly positive; direction is inferred by generateRange
// from the relative order of start and stop.
func builtinNumbersRangeStep(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	start, err := builtins.BigIntOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}
	stop, err := builtins.BigIntOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}
	step, err := builtins.BigIntOperand(operands[2].Value, 3)
	if err != nil {
		return err
	}
	// A zero or negative step could never terminate; reject it up front.
	// (Sign() <= 0 is equivalent to Cmp(big.NewInt(0)) <= 0.)
	if step.Sign() <= 0 {
		return fmt.Errorf("numbers.range_step: step must be a positive number above zero")
	}
	// Avoid naming the result `ast` — that would shadow the ast package.
	result, err := generateRange(bctx, start, stop, step, "numbers.range_step")
	if err != nil {
		return err
	}
	return iter(result)
}
func generateRange(bctx BuiltinContext, x *big.Int, y *big.Int, step *big.Int, funcName string) (*ast.Term, error) {
cmp := x.Cmp(y)
comp := func(i *big.Int, y *big.Int) bool { return i.Cmp(y) <= 0 }
iter := func(i *big.Int) *big.Int { return i.Add(i, step) }
if cmp > 0 {
comp = func(i *big.Int, y *big.Int) bool { return i.Cmp(y) >= 0 }
iter = func(i *big.Int) *big.Int { return i.Sub(i, step) }
}
result := ast.NewArray()
haltErr := Halt{
Err: &Error{
Code: CancelErr,
Message: "numbers.range: timed out before generating all numbers in range",
Message: fmt.Sprintf("%s: timed out before generating all numbers in range", funcName),
},
}
if cmp <= 0 {
for i := new(big.Int).Set(x); i.Cmp(y) <= 0; i = i.Add(i, one) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return haltErr
}
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
}
} else {
for i := new(big.Int).Set(x); i.Cmp(y) >= 0; i = i.Sub(i, one) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return haltErr
}
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
for i := new(big.Int).Set(x); comp(i, y); i = iter(i) {
if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
return nil, haltErr
}
result = result.Append(ast.NewTerm(builtins.IntToNumber(i)))
}
return iter(ast.NewTerm(result))
return ast.NewTerm(result), nil
}
func builtinRandIntn(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
@@ -95,5 +134,6 @@ func builtinRandIntn(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.T
func init() {
RegisterBuiltinFunc(ast.NumbersRange.Name, builtinNumbersRange)
RegisterBuiltinFunc(ast.NumbersRangeStep.Name, builtinNumbersRangeStep)
RegisterBuiltinFunc(ast.RandIntn.Name, builtinRandIntn)
}

View File

@@ -200,7 +200,16 @@ func mergewithOverwriteInPlace(obj, other ast.Object, frozenKeys map[*ast.Term]s
v2 := obj.Get(k)
// The key didn't exist in other, keep the original value.
if v2 == nil {
obj.Insert(k, v)
nestedObj, ok := v.Value.(ast.Object)
if !ok {
// v is not an object
obj.Insert(k, v)
} else {
// Copy the nested object so the original object would not be modified
nestedObjCopy := nestedObj.Copy()
obj.Insert(k, ast.NewTerm(nestedObjCopy))
}
return
}
// The key exists in both. Merge or reject change.

View File

@@ -53,7 +53,7 @@ func getReqBodyBytes(body, rawBody *ast.Term) ([]byte, error) {
}
func objectToMap(o ast.Object) map[string][]string {
var out map[string][]string
out := make(map[string][]string, o.Len())
o.Foreach(func(k, v *ast.Term) {
ks := stringFromTerm(k)
vs := stringFromTerm(v)
@@ -172,10 +172,21 @@ func builtinAWSSigV4SignReq(ctx BuiltinContext, operands []*ast.Term, iter func(
}
// Sign the request object's headers, and reconstruct the headers map.
authHeader, signedHeadersMap := aws.SignV4(objectToMap(headers), method, theURL, body, service, awsCreds, signingTimestamp)
headersMap := objectToMap(headers)
authHeader, awsHeadersMap := aws.SignV4(headersMap, method, theURL, body, service, awsCreds, signingTimestamp)
signedHeadersObj := ast.NewObject()
// Restore original headers
for k, v := range headersMap {
// objectToMap doesn't support arrays
if len(v) == 1 {
signedHeadersObj.Insert(ast.StringTerm(k), ast.StringTerm(v[0]))
}
}
// Set authorization header
signedHeadersObj.Insert(ast.StringTerm("Authorization"), ast.StringTerm(authHeader))
for k, v := range signedHeadersMap {
// set aws signature headers
for k, v := range awsHeadersMap {
signedHeadersObj.Insert(ast.StringTerm(k), ast.StringTerm(v))
}

View File

@@ -471,6 +471,14 @@ func (q *Query) Run(ctx context.Context) (QueryResultSet, error) {
// Iter executes the query and invokes the iter function with query results
// produced by evaluating the query.
func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error {
// Query evaluation must not be allowed if the compiler has errors and is in an undefined, possibly inconsistent state
if q.compiler != nil && len(q.compiler.Errors) > 0 {
return &Error{
Code: InternalErr,
Message: "compiler has errors",
}
}
if q.seed == nil {
q.seed = rand.Reader
}

View File

@@ -46,7 +46,7 @@ func builtinRegexMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te
if err != nil {
return err
}
return iter(ast.BooleanTerm(re.Match([]byte(s2))))
return iter(ast.BooleanTerm(re.MatchString(string(s2))))
}
func builtinRegexMatchTemplate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {

View File

@@ -287,22 +287,37 @@ func (s *saveSupport) List() []*ast.Module {
}
func (s *saveSupport) Exists(path ast.Ref) bool {
k := path[:len(path)-1].String()
module, ok := s.modules[k]
pkg, ruleRef := splitPackageAndRule(path)
module, ok := s.modules[pkg.String()]
if !ok {
return false
}
name := ast.Var(path[len(path)-1].Value.(ast.String))
if len(ruleRef) == 1 {
name := ruleRef[0].Value.(ast.Var)
for _, rule := range module.Rules {
if rule.Head.Name.Equal(name) {
return true
}
}
return false
}
for _, rule := range module.Rules {
if rule.Head.Name.Equal(name) {
if rule.Head.Ref().HasPrefix(ruleRef) {
return true
}
}
return false
}
func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
pkg := path[:len(path)-1]
pkg, _ := splitPackageAndRule(path)
s.InsertByPkg(pkg, rule)
}
func (s *saveSupport) InsertByPkg(pkg ast.Ref, rule *ast.Rule) {
k := pkg.String()
module, ok := s.modules[k]
if !ok {
@@ -317,6 +332,25 @@ func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
module.Rules = append(module.Rules, rule)
}
// splitPackageAndRule splits a full data path into its package portion and
// its rule reference portion. The returned rule ref's first term is converted
// from ast.String to ast.Var so it can be matched against rule heads.
// NOTE(review): the split point is found by extending the package prefix over
// consecutive var-compatible string terms within the path's string prefix —
// presumably so that nested package refs keep as much as possible in the
// package; confirm against callers (Exists/Insert) before changing.
func splitPackageAndRule(path ast.Ref) (ast.Ref, ast.Ref) {
	// Copy so the caller's ref is not mutated when rule[0] is rewritten below.
	p := path.Copy()
	ruleRefStart := 2 // path always contains at least 3 terms (data. + one term in package + rule name)
	for i := ruleRefStart; i < len(p.StringPrefix()); i++ {
		t := p[i]
		if str, ok := t.Value.(ast.String); ok && ast.IsVarCompatibleString(string(str)) {
			// Still looks like a package segment; push the split point forward.
			ruleRefStart = i
		} else {
			break
		}
	}
	pkg := p[:ruleRefStart]
	rule := p[ruleRefStart:]
	// The head of a rule ref is a variable, not a string.
	rule[0].Value = ast.Var(rule[0].Value.(ast.String))
	return pkg, rule
}
// saveRequired returns true if the statement x will result in some expressions
// being saved. This check allows the evaluator to evaluate statements
// completely during partial evaluation as long as they do not depend on any

View File

@@ -193,12 +193,12 @@ func arraySubset(super, sub *ast.Array) bool {
return true
}
if superCursor == super.Len() {
if superCursor+subCursor == super.Len() {
return false
}
subElem := sub.Elem(subCursor)
superElem := sub.Elem(superCursor + subCursor)
superElem := super.Elem(superCursor + subCursor)
if superElem == nil {
return false
}

View File

@@ -0,0 +1,45 @@
package topdown
import (
"bytes"
"text/template"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// renderTemplate evaluates a text/template string (operand 1) against a map
// of template variables (operand 2) and yields the rendered string.
func renderTemplate(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	templateSrc, err := builtins.StringOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}
	varsTerm, err := builtins.ObjectOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}
	var vars map[string]interface{}
	if err := ast.As(varsTerm, &vars); err != nil {
		return err
	}
	tmpl, err := template.New("template").Parse(string(templateSrc))
	if err != nil {
		return err
	}
	// Fail rendering (rather than emitting "<no value>") when the template
	// references a key missing from the supplied variables.
	tmpl.Option("missingkey=error")
	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, vars); err != nil {
		return err
	}
	return iter(ast.StringTerm(rendered.String()))
}
// init registers the template rendering builtin with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.RenderTemplate.Name, renderTemplate)
}

View File

@@ -12,6 +12,7 @@ import (
"strconv"
"sync"
"time"
_ "time/tzdata" // this is needed to have LoadLocation when no filesystem tzdata is available
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/topdown/builtins"
@@ -49,7 +50,13 @@ func builtinTimeParseNanos(_ BuiltinContext, operands []*ast.Term, iter func(*as
return err
}
result, err := time.Parse(string(format), string(value))
formatStr := string(format)
// look for the formatStr in our acceptedTimeFormats and
// use the constant instead if it matches
if f, ok := acceptedTimeFormats[formatStr]; ok {
formatStr = f
}
result, err := time.Parse(formatStr, string(value))
if err != nil {
return err
}
@@ -82,6 +89,20 @@ func builtinParseDurationNanos(_ BuiltinContext, operands []*ast.Term, iter func
return iter(ast.NumberTerm(int64ToJSONNumber(int64(value))))
}
// acceptedTimeFormats maps the names of the stdlib time package's exported
// layout constants (e.g. "RFC3339") to their layout strings, so callers may
// pass the constant's name instead of spelling out the reference layout.
var acceptedTimeFormats = map[string]string{
	"ANSIC":       time.ANSIC,
	"UnixDate":    time.UnixDate,
	"RubyDate":    time.RubyDate,
	"RFC822":      time.RFC822,
	"RFC822Z":     time.RFC822Z,
	"RFC850":      time.RFC850,
	"RFC1123":     time.RFC1123,
	"RFC1123Z":    time.RFC1123Z,
	"RFC3339":     time.RFC3339,
	"RFC3339Nano": time.RFC3339Nano,
}
func builtinFormat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
t, layout, err := tzTime(operands[0].Value)
if err != nil {
@@ -90,7 +111,12 @@ func builtinFormat(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term)
// Using RFC3339Nano time formatting as default
if layout == "" {
layout = time.RFC3339Nano
} else if layoutStr, ok := acceptedTimeFormats[layout]; ok {
// if we can find a constant specified, use the constant
layout = layoutStr
}
// otherwise try to treat the fmt string as a datetime fmt string
timestamp := t.Format(layout)
return iter(ast.StringTerm(timestamp))
}

View File

@@ -1065,6 +1065,8 @@ func builtinJWTDecodeVerify(bctx BuiltinContext, operands []*ast.Term, iter func
if constraints.iss != issVal {
return iter(unverified)
}
} else {
return iter(unverified)
}
}
// RFC7159 4.1.3 aud

View File

@@ -7,6 +7,7 @@ package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/uuid"
"github.com/open-policy-agent/opa/topdown/builtins"
)
type uuidCachingKey string
@@ -31,6 +32,25 @@ func builtinUUIDRFC4122(bctx BuiltinContext, operands []*ast.Term, iter func(*as
return iter(result)
}
// builtinUUIDParse parses the string operand as a UUID and yields the parsed
// representation as a term.
func builtinUUIDParse(_ BuiltinContext, operands []*ast.Term, iter func(term *ast.Term) error) error {
	str, err := builtins.StringOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}
	parsed, err := uuid.Parse(string(str))
	if err != nil {
		// Deliberate: an unparsable UUID leaves the builtin undefined
		// (iter is never called) instead of raising an evaluation error.
		return nil
	}
	val, err := ast.InterfaceToValue(parsed)
	if err != nil {
		return err
	}
	return iter(ast.NewTerm(val))
}
// init registers the UUID builtins with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.UUIDRFC4122.Name, builtinUUIDRFC4122)
	RegisterBuiltinFunc(ast.UUIDParse.Name, builtinUUIDParse)
}

View File

@@ -10,6 +10,15 @@ import (
// evalWalk implements the walk builtin. When the caller binds the path to a
// wildcard — walk(input, [_, value]) — the paths would be discarded anyway,
// so we take the cheaper walkNoPath route that skips path construction and
// reuses a single pointer per iteration.
func evalWalk(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	input := operands[0]
	if pathIsWildcard(operands) {
		return walkNoPath(input, iter)
	}
	return walk(getOutputPath(operands), nil, input, iter)
}
@@ -70,6 +79,33 @@ func walk(filter, path *ast.Array, input *ast.Term, iter func(*ast.Term) error)
return nil
}
// emptyArr is the shared empty-path term reused by walkNoPath for every result.
var emptyArr = ast.ArrayTerm()
// walkNoPath recursively yields every node reachable from input, pairing each
// with the shared empty path term instead of building real paths. Used when
// the caller only needs the values.
func walkNoPath(input *ast.Term, iter func(*ast.Term) error) error {
	// Emit the current node first.
	if err := iter(ast.ArrayTerm(emptyArr, input)); err != nil {
		return err
	}
	// Then recurse into any container children.
	switch node := input.Value.(type) {
	case ast.Object:
		return node.Iter(func(_, child *ast.Term) error {
			return walkNoPath(child, iter)
		})
	case *ast.Array:
		for i := 0; i < node.Len(); i++ {
			if err := walkNoPath(node.Elem(i), iter); err != nil {
				return err
			}
		}
	case ast.Set:
		return node.Iter(func(child *ast.Term) error {
			return walkNoPath(child, iter)
		})
	}
	return nil
}
func pathAppend(path *ast.Array, key *ast.Term) *ast.Array {
if path == nil {
return ast.NewArray(key)
@@ -80,17 +116,26 @@ func pathAppend(path *ast.Array, key *ast.Term) *ast.Array {
func getOutputPath(operands []*ast.Term) *ast.Array {
if len(operands) == 2 {
if arr, ok := operands[1].Value.(*ast.Array); ok {
if arr.Len() == 2 {
if path, ok := arr.Elem(0).Value.(*ast.Array); ok {
return path
}
if arr, ok := operands[1].Value.(*ast.Array); ok && arr.Len() == 2 {
if path, ok := arr.Elem(0).Value.(*ast.Array); ok {
return path
}
}
}
return nil
}
// pathIsWildcard reports whether walk's output pattern is a two-element array
// whose path component is a wildcard variable, i.e. walk(input, [_, value]).
func pathIsWildcard(operands []*ast.Term) bool {
	if len(operands) != 2 {
		return false
	}
	arr, ok := operands[1].Value.(*ast.Array)
	if !ok || arr.Len() != 2 {
		return false
	}
	v, ok := arr.Elem(0).Value.(ast.Var)
	return ok && v.IsWildcard()
}
// init registers the walk builtin with the topdown evaluator.
func init() {
	RegisterBuiltinFunc(ast.WalkBuiltin.Name, evalWalk)
}