update dependencies (#6267)

Signed-off-by: hongming <coder.scala@gmail.com>
This commit is contained in:
hongming
2024-11-06 10:27:06 +08:00
committed by GitHub
parent faf255a084
commit cfebd96a1f
4263 changed files with 341374 additions and 132036 deletions

87
vendor/k8s.io/apiserver/pkg/cel/cidr.go generated vendored Normal file
View File

@@ -0,0 +1,87 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"math"
"net/netip"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// CIDR provides a CEL representation of an network address.
type CIDR struct {
netip.Prefix
}
var (
CIDRType = cel.OpaqueType("net.CIDR")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) {
if reflect.TypeOf(d.Prefix).AssignableTo(typeDesc) {
return d.Prefix, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.Prefix.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'CIDR' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d CIDR) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case CIDRType:
return d
case types.TypeType:
return CIDRType
case types.StringType:
return types.String(d.Prefix.String())
}
return types.NewErr("type conversion error from '%s' to '%s'", CIDRType, typeVal)
}
// Equal implements ref.Val.Equal.
func (d CIDR) Equal(other ref.Val) ref.Val {
otherD, ok := other.(CIDR)
if !ok {
return types.ValOrErr(other, "no such overload")
}
return types.Bool(d.Prefix == otherD.Prefix)
}
// Type implements ref.Val.Type.
func (d CIDR) Type() ref.Type {
return CIDRType
}
// Value implements ref.Val.Value.
func (d CIDR) Value() any {
return d.Prefix
}
// Size returns the size of the CIDR prefix address in bytes.
// Used in the size estimation of the runtime cost.
func (d CIDR) Size() ref.Val {
return types.Int(int(math.Ceil(float64(d.Prefix.Bits()) / 8)))
}

View File

@@ -20,6 +20,7 @@ import (
"fmt"
"strconv"
"sync"
"sync/atomic"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
@@ -30,23 +31,34 @@ import (
"k8s.io/apimachinery/pkg/util/version"
celconfig "k8s.io/apiserver/pkg/apis/cel"
"k8s.io/apiserver/pkg/cel/library"
genericfeatures "k8s.io/apiserver/pkg/features"
utilfeature "k8s.io/apiserver/pkg/util/feature"
utilversion "k8s.io/apiserver/pkg/util/version"
)
// DefaultCompatibilityVersion returns a default compatibility version for use with EnvSet
// that guarantees compatibility with CEL features/libraries/parameters understood by
// an n-1 version
// the api server min compatibility version
//
// This default will be set to no more than n-1 the current Kubernetes major.minor version.
// This default will be set to no more than the current Kubernetes major.minor version.
//
// Note that a default version number less than n-1 indicates a wider range of version
// compatibility than strictly required for rollback. A wide range of compatibility is
// desirable because it means that CEL expressions are portable across a wider range
// of Kubernetes versions.
// Note that a default version number less than n-1 the current Kubernetes major.minor version
// indicates a wider range of version compatibility than strictly required for rollback.
// A wide range of compatibility is desirable because it means that CEL expressions are portable
// across a wider range of Kubernetes versions.
// A default version number equal to the current Kubernetes major.minor version
// indicates fast forward CEL features that can be used when rollback is no longer needed.
func DefaultCompatibilityVersion() *version.Version {
return version.MajorMinor(1, 28)
effectiveVer := utilversion.DefaultComponentGlobalsRegistry.EffectiveVersionFor(utilversion.DefaultKubeComponent)
if effectiveVer == nil {
effectiveVer = utilversion.DefaultKubeEffectiveVersion()
}
return effectiveVer.MinCompatibilityVersion()
}
var baseOpts = []VersionedOptions{
var baseOpts = append(baseOptsWithoutStrictCost, StrictCostOpt)
var baseOptsWithoutStrictCost = []VersionedOptions{
{
// CEL epoch was actually 1.23, but we artificially set it to 1.0 because these
// options should always be present.
@@ -123,6 +135,60 @@ var baseOpts = []VersionedOptions{
ext.Sets(),
},
},
{
IntroducedVersion: version.MajorMinor(1, 30),
EnvOptions: []cel.EnvOption{
library.IP(),
library.CIDR(),
},
},
// Format Library
{
IntroducedVersion: version.MajorMinor(1, 31),
EnvOptions: []cel.EnvOption{
library.Format(),
},
},
// Authz selectors
{
IntroducedVersion: version.MajorMinor(1, 31),
FeatureEnabled: func() bool {
enabled := utilfeature.DefaultFeatureGate.Enabled(genericfeatures.AuthorizeWithSelectors)
authzSelectorsLibraryInit.Do(func() {
// Record the first time feature enablement was checked for this library.
// This is checked from integration tests to ensure no cached cel envs
// are constructed before feature enablement is effectively set.
authzSelectorsLibraryEnabled.Store(enabled)
// Uncomment to debug where the first initialization is coming from if needed.
// debug.PrintStack()
})
return enabled
},
EnvOptions: []cel.EnvOption{
library.AuthzSelectors(),
},
},
}
var (
authzSelectorsLibraryInit sync.Once
authzSelectorsLibraryEnabled atomic.Value
)
// AuthzSelectorsLibraryEnabled returns whether the AuthzSelectors library was enabled when it was constructed.
// If it has not been contructed yet, this returns `false, false`.
// This is solely for the benefit of the integration tests making sure feature gates get correctly parsed before AuthzSelector ever has to check for enablement.
func AuthzSelectorsLibraryEnabled() (enabled, constructed bool) {
enabled, constructed = authzSelectorsLibraryEnabled.Load().(bool)
return
}
var StrictCostOpt = VersionedOptions{
// This is to configure the cost calculation for extended libraries
IntroducedVersion: version.MajorMinor(1, 0),
ProgramOptions: []cel.ProgramOption{
cel.CostTracking(&library.CostEstimator{}),
},
}
// MustBaseEnvSet returns the common CEL base environments for Kubernetes for Version, or panics
@@ -134,7 +200,8 @@ var baseOpts = []VersionedOptions{
// The returned environment contains no CEL variable definitions or custom type declarations and
// should be extended to construct environments with the appropriate variable definitions,
// type declarations and any other needed configuration.
func MustBaseEnvSet(ver *version.Version) *EnvSet {
// strictCost is used to determine whether to enforce strict cost calculation for CEL expressions.
func MustBaseEnvSet(ver *version.Version, strictCost bool) *EnvSet {
if ver == nil {
panic("version must be non-nil")
}
@@ -142,19 +209,33 @@ func MustBaseEnvSet(ver *version.Version) *EnvSet {
panic(fmt.Sprintf("version must contain an major and minor component, but got: %s", ver.String()))
}
key := strconv.FormatUint(uint64(ver.Major()), 10) + "." + strconv.FormatUint(uint64(ver.Minor()), 10)
if entry, ok := baseEnvs.Load(key); ok {
return entry.(*EnvSet)
var entry interface{}
if strictCost {
if entry, ok := baseEnvs.Load(key); ok {
return entry.(*EnvSet)
}
entry, _, _ = baseEnvsSingleflight.Do(key, func() (interface{}, error) {
entry := mustNewEnvSet(ver, baseOpts)
baseEnvs.Store(key, entry)
return entry, nil
})
} else {
if entry, ok := baseEnvsWithOption.Load(key); ok {
return entry.(*EnvSet)
}
entry, _, _ = baseEnvsWithOptionSingleflight.Do(key, func() (interface{}, error) {
entry := mustNewEnvSet(ver, baseOptsWithoutStrictCost)
baseEnvsWithOption.Store(key, entry)
return entry, nil
})
}
entry, _, _ := baseEnvsSingleflight.Do(key, func() (interface{}, error) {
entry := mustNewEnvSet(ver, baseOpts)
baseEnvs.Store(key, entry)
return entry, nil
})
return entry.(*EnvSet)
}
var (
baseEnvs = sync.Map{}
baseEnvsSingleflight = &singleflight.Group{}
baseEnvs = sync.Map{}
baseEnvsWithOption = sync.Map{}
baseEnvsSingleflight = &singleflight.Group{}
baseEnvsWithOptionSingleflight = &singleflight.Group{}
)

View File

@@ -175,7 +175,15 @@ type VersionedOptions struct {
//
// Optional.
RemovedVersion *version.Version
// FeatureEnabled returns true if these options are enabled by feature gates,
// and returns false if these options are not enabled due to feature gates.
//
// This takes priority over IntroducedVersion / RemovedVersion for the NewExpressions environment.
//
// The StoredExpressions environment ignores this function.
//
// Optional.
FeatureEnabled func() bool
// EnvOptions provides CEL EnvOptions. This may be used to add a cel.Variable, a
// cel.Library, or to enable other CEL EnvOptions such as language settings.
//
@@ -210,7 +218,7 @@ type VersionedOptions struct {
// making multiple calls to Extend.
func (e *EnvSet) Extend(options ...VersionedOptions) (*EnvSet, error) {
if len(options) > 0 {
newExprOpts, err := e.filterAndBuildOpts(e.newExpressions, e.compatibilityVersion, options)
newExprOpts, err := e.filterAndBuildOpts(e.newExpressions, e.compatibilityVersion, true, options)
if err != nil {
return nil, err
}
@@ -218,7 +226,7 @@ func (e *EnvSet) Extend(options ...VersionedOptions) (*EnvSet, error) {
if err != nil {
return nil, err
}
storedExprOpt, err := e.filterAndBuildOpts(e.storedExpressions, version.MajorMinor(math.MaxUint, math.MaxUint), options)
storedExprOpt, err := e.filterAndBuildOpts(e.storedExpressions, version.MajorMinor(math.MaxUint, math.MaxUint), false, options)
if err != nil {
return nil, err
}
@@ -231,13 +239,26 @@ func (e *EnvSet) Extend(options ...VersionedOptions) (*EnvSet, error) {
return e, nil
}
func (e *EnvSet) filterAndBuildOpts(base *cel.Env, compatVer *version.Version, opts []VersionedOptions) (cel.EnvOption, error) {
func (e *EnvSet) filterAndBuildOpts(base *cel.Env, compatVer *version.Version, honorFeatureGateEnablement bool, opts []VersionedOptions) (cel.EnvOption, error) {
var envOpts []cel.EnvOption
var progOpts []cel.ProgramOption
var declTypes []*apiservercel.DeclType
for _, opt := range opts {
var allowedByFeatureGate, allowedByVersion bool
if opt.FeatureEnabled != nil && honorFeatureGateEnablement {
// Feature-gate-enabled libraries must follow compatible default feature enablement.
// Enabling alpha features in their first release enables libraries the previous API server is unaware of.
allowedByFeatureGate = opt.FeatureEnabled()
if !allowedByFeatureGate {
continue
}
}
if compatVer.AtLeast(opt.IntroducedVersion) && (opt.RemovedVersion == nil || compatVer.LessThan(opt.RemovedVersion)) {
allowedByVersion = true
}
if allowedByFeatureGate || allowedByVersion {
envOpts = append(envOpts, opt.EnvOptions...)
progOpts = append(progOpts, opt.ProgramOptions...)
declTypes = append(declTypes, opt.DeclTypes...)
@@ -246,7 +267,10 @@ func (e *EnvSet) filterAndBuildOpts(base *cel.Env, compatVer *version.Version, o
if len(declTypes) > 0 {
provider := apiservercel.NewDeclTypeProvider(declTypes...)
providerOpts, err := provider.EnvOptions(base.TypeProvider())
if compatVer.AtLeast(version.MajorMinor(1, 31)) {
provider.SetRecognizeKeywordAsFieldName(true)
}
providerOpts, err := provider.EnvOptions(base.CELTypeProvider())
if err != nil {
return nil, err
}

View File

@@ -16,11 +16,46 @@ limitations under the License.
package cel
import (
"fmt"
"github.com/google/cel-go/cel"
)
// ErrInternal the basic error that occurs when the expression fails to evaluate
// due to internal reasons. Any Error that has the Type of
// ErrorInternal is considered equal to ErrInternal
var ErrInternal = fmt.Errorf("internal")
// ErrInvalid is the basic error that occurs when the expression fails to
// evaluate but not due to internal reasons. Any Error that has the Type of
// ErrorInvalid is considered equal to ErrInvalid.
var ErrInvalid = fmt.Errorf("invalid")
// ErrRequired is the basic error that occurs when the expression is required
// but absent.
// Any Error that has the Type of ErrorRequired is considered equal
// to ErrRequired.
var ErrRequired = fmt.Errorf("required")
// ErrCompilation is the basic error that occurs when the expression fails to
// compile. Any CompilationError wraps ErrCompilation.
// ErrCompilation wraps ErrInvalid
var ErrCompilation = fmt.Errorf("%w: compilation error", ErrInvalid)
// ErrOutOfBudget is the basic error that occurs when the expression fails due to
// exceeding budget.
var ErrOutOfBudget = fmt.Errorf("out of budget")
// Error is an implementation of the 'error' interface, which represents a
// XValidation error.
type Error struct {
Type ErrorType
Detail string
// Cause is an optional wrapped errors that can be useful to
// programmatically retrieve detailed errors.
Cause error
}
var _ error = &Error{}
@@ -30,7 +65,24 @@ func (v *Error) Error() string {
return v.Detail
}
// ErrorType is a machine readable value providing more detail about why
func (v *Error) Is(err error) bool {
switch v.Type {
case ErrorTypeRequired:
return err == ErrRequired
case ErrorTypeInvalid:
return err == ErrInvalid
case ErrorTypeInternal:
return err == ErrInternal
}
return false
}
// Unwrap returns the wrapped Cause.
func (v *Error) Unwrap() error {
return v.Cause
}
// ErrorType is a machine-readable value providing more detail about why
// a XValidation is invalid.
type ErrorType string
@@ -45,3 +97,28 @@ const (
// to user input. See InternalError().
ErrorTypeInternal ErrorType = "InternalError"
)
// CompilationError indicates an error during expression compilation.
// It wraps ErrCompilation.
type CompilationError struct {
err *Error
Issues *cel.Issues
}
// NewCompilationError wraps a cel.Issues to indicate a compilation failure.
func NewCompilationError(issues *cel.Issues) *CompilationError {
return &CompilationError{
Issues: issues,
err: &Error{
Type: ErrorTypeInvalid,
Detail: fmt.Sprintf("compilation error: %s", issues),
}}
}
func (e *CompilationError) Error() string {
return e.err.Error()
}
func (e *CompilationError) Unwrap() []error {
return []error{e.err, ErrCompilation}
}

73
vendor/k8s.io/apiserver/pkg/cel/format.go generated vendored Normal file
View File

@@ -0,0 +1,73 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
var (
FormatObject = decls.NewObjectType("kubernetes.NamedFormat")
FormatType = cel.ObjectType("kubernetes.NamedFormat")
)
// Format provdes a CEL representation of kubernetes format
type Format struct {
Name string
ValidateFunc func(string) []string
// Size of the regex string or estimated equivalent regex string used
// for cost estimation
MaxRegexSize int
}
func (d *Format) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return nil, fmt.Errorf("type conversion error from 'Format' to '%v'", typeDesc)
}
func (d *Format) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case FormatType:
return d
case types.TypeType:
return FormatType
default:
return types.NewErr("type conversion error from '%s' to '%s'", FormatType, typeVal)
}
}
func (d *Format) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(*Format)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
return types.Bool(d.Name == otherDur.Name)
}
func (d *Format) Type() ref.Type {
return FormatType
}
func (d *Format) Value() interface{} {
return d
}

86
vendor/k8s.io/apiserver/pkg/cel/ip.go generated vendored Normal file
View File

@@ -0,0 +1,86 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"math"
"net/netip"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// IP provides a CEL representation of an IP address.
type IP struct {
netip.Addr
}
var (
IPType = cel.OpaqueType("net.IP")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d IP) ConvertToNative(typeDesc reflect.Type) (any, error) {
if reflect.TypeOf(d.Addr).AssignableTo(typeDesc) {
return d.Addr, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.Addr.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'IP' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d IP) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case IPType:
return d
case types.TypeType:
return IPType
case types.StringType:
return types.String(d.Addr.String())
}
return types.NewErr("type conversion error from '%s' to '%s'", IPType, typeVal)
}
// Equal implements ref.Val.Equal.
func (d IP) Equal(other ref.Val) ref.Val {
otherD, ok := other.(IP)
if !ok {
return types.ValOrErr(other, "no such overload")
}
return types.Bool(d.Addr == otherD.Addr)
}
// Type implements ref.Val.Type.
func (d IP) Type() ref.Type {
return IPType
}
// Value implements ref.Val.Value.
func (d IP) Value() any {
return d.Addr
}
// Size returns the size of the IP address in bytes.
// Used in the size estimation of the runtime cost.
func (d IP) Size() ref.Val {
return types.Int(int(math.Ceil(float64(d.Addr.BitLen()) / 8)))
}

View File

@@ -22,6 +22,11 @@ import (
"reflect"
"strings"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/labels"
genericfeatures "k8s.io/apiserver/pkg/features"
utilfeature "k8s.io/apiserver/pkg/util/feature"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
@@ -194,6 +199,30 @@ import (
// Examples:
//
// authorizer.group('').resource('pods').namespace('default').check('create').error()
//
// fieldSelector
//
// Takes a string field selector, parses it to field selector requirements, and includes it in the authorization check.
// If the field selector does not parse successfully, no field selector requirements are included in the authorization check.
// Added in Kubernetes 1.31+, Authz library version 1.
//
// <ResourceCheck>.fieldSelector(<string>) <ResourceCheck>
//
// Examples:
//
// authorizer.group('').resource('pods').fieldSelector('spec.nodeName=mynode').check('list').allowed()
//
// labelSelector (added in v1, Kubernetes 1.31+)
//
// Takes a string label selector, parses it to label selector requirements, and includes it in the authorization check.
// If the label selector does not parse successfully, no label selector requirements are included in the authorization check.
// Added in Kubernetes 1.31+, Authz library version 1.
//
// <ResourceCheck>.labelSelector(<string>) <ResourceCheck>
//
// Examples:
//
// authorizer.group('').resource('pods').labelSelector('app=example').check('list').allowed()
func Authz() cel.EnvOption {
return cel.Lib(authzLib)
}
@@ -259,6 +288,66 @@ func (*authz) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
// AuthzSelectors provides a CEL function library extension for adding fieldSelector and
// labelSelector filters to authorization checks. This requires the Authz library.
// See documentation of the Authz library for use and availability of the authorizer variable.
//
// fieldSelector
//
// Takes a string field selector, parses it to field selector requirements, and includes it in the authorization check.
// If the field selector does not parse successfully, no field selector requirements are included in the authorization check.
// Added in Kubernetes 1.31+.
//
// <ResourceCheck>.fieldSelector(<string>) <ResourceCheck>
//
// Examples:
//
// authorizer.group('').resource('pods').fieldSelector('spec.nodeName=mynode').check('list').allowed()
//
// labelSelector
//
// Takes a string label selector, parses it to label selector requirements, and includes it in the authorization check.
// If the label selector does not parse successfully, no label selector requirements are included in the authorization check.
// Added in Kubernetes 1.31+.
//
// <ResourceCheck>.labelSelector(<string>) <ResourceCheck>
//
// Examples:
//
// authorizer.group('').resource('pods').labelSelector('app=example').check('list').allowed()
func AuthzSelectors() cel.EnvOption {
return cel.Lib(authzSelectorsLib)
}
var authzSelectorsLib = &authzSelectors{}
type authzSelectors struct{}
func (*authzSelectors) LibraryName() string {
return "k8s.authzSelectors"
}
var authzSelectorsLibraryDecls = map[string][]cel.FunctionOpt{
"fieldSelector": {
cel.MemberOverload("authorizer_fieldselector", []*cel.Type{ResourceCheckType, cel.StringType}, ResourceCheckType,
cel.BinaryBinding(resourceCheckFieldSelector))},
"labelSelector": {
cel.MemberOverload("authorizer_labelselector", []*cel.Type{ResourceCheckType, cel.StringType}, ResourceCheckType,
cel.BinaryBinding(resourceCheckLabelSelector))},
}
func (*authzSelectors) CompileOptions() []cel.EnvOption {
options := make([]cel.EnvOption, 0, len(authzSelectorsLibraryDecls))
for name, overloads := range authzSelectorsLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*authzSelectors) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func authorizerPath(arg1, arg2 ref.Val) ref.Val {
authz, ok := arg1.(authorizerVal)
if !ok {
@@ -354,6 +443,38 @@ func resourceCheckSubresource(arg1, arg2 ref.Val) ref.Val {
return result
}
func resourceCheckFieldSelector(arg1, arg2 ref.Val) ref.Val {
resourceCheck, ok := arg1.(resourceCheckVal)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
fieldSelector, ok := arg2.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
result := resourceCheck
result.fieldSelector = fieldSelector
return result
}
func resourceCheckLabelSelector(arg1, arg2 ref.Val) ref.Val {
resourceCheck, ok := arg1.(resourceCheckVal)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
labelSelector, ok := arg2.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
result := resourceCheck
result.labelSelector = labelSelector
return result
}
func resourceCheckNamespace(arg1, arg2 ref.Val) ref.Val {
resourceCheck, ok := arg1.(resourceCheckVal)
if !ok {
@@ -544,11 +665,13 @@ func (g groupCheckVal) resourceCheck(resource string) resourceCheckVal {
type resourceCheckVal struct {
receiverOnlyObjectVal
groupCheck groupCheckVal
resource string
subresource string
namespace string
name string
groupCheck groupCheckVal
resource string
subresource string
namespace string
name string
fieldSelector string
labelSelector string
}
func (a resourceCheckVal) Authorize(ctx context.Context, verb string) ref.Val {
@@ -563,6 +686,26 @@ func (a resourceCheckVal) Authorize(ctx context.Context, verb string) ref.Val {
Verb: verb,
User: a.groupCheck.authorizer.userInfo,
}
if utilfeature.DefaultFeatureGate.Enabled(genericfeatures.AuthorizeWithSelectors) {
if len(a.fieldSelector) > 0 {
selector, err := fields.ParseSelector(a.fieldSelector)
if err != nil {
attr.FieldSelectorRequirements, attr.FieldSelectorParsingErr = nil, err
} else {
attr.FieldSelectorRequirements, attr.FieldSelectorParsingErr = selector.Requirements(), nil
}
}
if len(a.labelSelector) > 0 {
requirements, err := labels.ParseToRequirements(a.labelSelector)
if err != nil {
attr.LabelSelectorRequirements, attr.LabelSelectorParsingErr = nil, err
} else {
attr.LabelSelectorRequirements, attr.LabelSelectorParsingErr = requirements, nil
}
}
}
decision, reason, err := a.groupCheck.authorizer.authAuthorizer.Authorize(ctx, attr)
return newDecision(decision, err, reason)
}

287
vendor/k8s.io/apiserver/pkg/cel/library/cidr.go generated vendored Normal file
View File

@@ -0,0 +1,287 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"net/netip"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apiservercel "k8s.io/apiserver/pkg/cel"
)
// CIDR provides a CEL function library extension of CIDR notation parsing functions.
//
// cidr
//
// Converts a string in CIDR notation to a network address representation or results in an error if the string is not a valid CIDR notation.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// Leading zeros in IPv4 address octets are not allowed.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4/24) are not allowed.
//
// cidr(<string>) <CIDR>
//
// Examples:
//
// cidr('192.168.0.0/16') // returns an IPv4 address with a CIDR mask
// cidr('::1/128') // returns an IPv6 address with a CIDR mask
// cidr('192.168.0.0/33') // error
// cidr('::1/129') // error
// cidr('192.168.0.1/16') // error, because there are non-0 bits after the prefix
//
// isCIDR
//
// Returns true if a string is a valid CIDR notation respresentation of a subnet with mask.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// Leading zeros in IPv4 address octets are not allowed.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4/24) are not allowed.
//
// isCIDR(<string>) <bool>
//
// Examples:
//
// isCIDR('192.168.0.0/16') // returns true
// isCIDR('::1/128') // returns true
// isCIDR('192.168.0.0/33') // returns false
// isCIDR('::1/129') // returns false
//
// containsIP / containerCIDR / ip / masked / prefixLength
//
// - containsIP: Returns true if a the CIDR contains the given IP address.
// The IP address must be an IPv4 or IPv6 address.
// May take either a string or IP address as an argument.
//
// - containsCIDR: Returns true if a the CIDR contains the given CIDR.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// May take either a string or CIDR as an argument.
//
// - ip: Returns the IP address representation of the CIDR.
//
// - masked: Returns the CIDR representation of the network address with a masked prefix.
// This can be used to return the canonical form of the CIDR network.
//
// - prefixLength: Returns the prefix length of the CIDR in bits.
// This is the number of bits in the mask.
//
// Examples:
//
// cidr('192.168.0.0/24').containsIP(ip('192.168.0.1')) // returns true
// cidr('192.168.0.0/24').containsIP(ip('192.168.1.1')) // returns false
// cidr('192.168.0.0/24').containsIP('192.168.0.1') // returns true
// cidr('192.168.0.0/24').containsIP('192.168.1.1') // returns false
// cidr('192.168.0.0/16').containsCIDR(cidr('192.168.10.0/24')) // returns true
// cidr('192.168.1.0/24').containsCIDR(cidr('192.168.2.0/24')) // returns false
// cidr('192.168.0.0/16').containsCIDR('192.168.10.0/24') // returns true
// cidr('192.168.1.0/24').containsCIDR('192.168.2.0/24') // returns false
// cidr('192.168.0.1/24').ip() // returns ipAddr('192.168.0.1')
// cidr('192.168.0.1/24').ip().family() // returns '4'
// cidr('::1/128').ip() // returns ipAddr('::1')
// cidr('::1/128').ip().family() // returns '6'
// cidr('192.168.0.0/24').masked() // returns cidr('192.168.0.0/24')
// cidr('192.168.0.1/24').masked() // returns cidr('192.168.0.0/24')
// cidr('192.168.0.0/24') == cidr('192.168.0.0/24').masked() // returns true, CIDR was already in canonical format
// cidr('192.168.0.1/24') == cidr('192.168.0.1/24').masked() // returns false, CIDR was not in canonical format
// cidr('192.168.0.0/16').prefixLength() // returns 16
// cidr('::1/128').prefixLength() // returns 128
func CIDR() cel.EnvOption {
return cel.Lib(cidrsLib)
}
var cidrsLib = &cidrs{}
type cidrs struct{}
func (*cidrs) LibraryName() string {
return "net.cidr"
}
var cidrLibraryDecls = map[string][]cel.FunctionOpt{
"cidr": {
cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, apiservercel.CIDRType,
cel.UnaryBinding(stringToCIDR)),
},
"containsIP": {
cel.MemberOverload("cidr_contains_ip_string", []*cel.Type{apiservercel.CIDRType, cel.StringType}, cel.BoolType,
cel.BinaryBinding(cidrContainsIPString)),
cel.MemberOverload("cidr_contains_ip_ip", []*cel.Type{apiservercel.CIDRType, apiservercel.IPType}, cel.BoolType,
cel.BinaryBinding(cidrContainsIP)),
},
"containsCIDR": {
cel.MemberOverload("cidr_contains_cidr_string", []*cel.Type{apiservercel.CIDRType, cel.StringType}, cel.BoolType,
cel.BinaryBinding(cidrContainsCIDRString)),
cel.MemberOverload("cidr_contains_cidr", []*cel.Type{apiservercel.CIDRType, apiservercel.CIDRType}, cel.BoolType,
cel.BinaryBinding(cidrContainsCIDR)),
},
"ip": {
cel.MemberOverload("cidr_ip", []*cel.Type{apiservercel.CIDRType}, apiservercel.IPType,
cel.UnaryBinding(cidrToIP)),
},
"prefixLength": {
cel.MemberOverload("cidr_prefix_length", []*cel.Type{apiservercel.CIDRType}, cel.IntType,
cel.UnaryBinding(prefixLength)),
},
"masked": {
cel.MemberOverload("cidr_masked", []*cel.Type{apiservercel.CIDRType}, apiservercel.CIDRType,
cel.UnaryBinding(masked)),
},
"isCIDR": {
cel.Overload("is_cidr", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(isCIDR)),
},
"string": {
cel.Overload("cidr_to_string", []*cel.Type{apiservercel.CIDRType}, cel.StringType,
cel.UnaryBinding(cidrToString)),
},
}
func (*cidrs) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{cel.Types(apiservercel.CIDRType),
cel.Variable(apiservercel.CIDRType.TypeName(), types.NewTypeTypeWithParam(apiservercel.CIDRType)),
}
for name, overloads := range cidrLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*cidrs) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func stringToCIDR(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
net, err := parseCIDR(s)
if err != nil {
return types.NewErr("network address parse error during conversion from string: %v", err)
}
return apiservercel.CIDR{
Prefix: net,
}
}
func cidrToString(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(cidr.Prefix.String())
}
func cidrContainsIPString(arg ref.Val, other ref.Val) ref.Val {
return cidrContainsIP(arg, stringToIP(other))
}
func cidrContainsCIDRString(arg ref.Val, other ref.Val) ref.Val {
return cidrContainsCIDR(arg, stringToCIDR(other))
}
func cidrContainsIP(arg ref.Val, other ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
ip, ok := other.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(cidr.Contains(ip.Addr))
}
func cidrContainsCIDR(arg ref.Val, other ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
containsCIDR, ok := other.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
equalMasked := cidr.Prefix.Masked() == netip.PrefixFrom(containsCIDR.Prefix.Addr(), cidr.Prefix.Bits())
return types.Bool(equalMasked && cidr.Prefix.Bits() <= containsCIDR.Prefix.Bits())
}
func prefixLength(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Int(cidr.Prefix.Bits())
}
func isCIDR(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := parseCIDR(s)
return types.Bool(err == nil)
}
func cidrToIP(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return apiservercel.IP{
Addr: cidr.Prefix.Addr(),
}
}
func masked(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
maskedCIDR := cidr.Prefix.Masked()
return apiservercel.CIDR{
Prefix: maskedCIDR,
}
}
// parseCIDR parses a string into an CIDR.
// We use this function to parse CIDR notation in the CEL library
// so that we can share the common logic of rejecting strings
// that IPv4-mapped IPv6 addresses or contain non-zero bits after the mask.
func parseCIDR(raw string) (netip.Prefix, error) {
net, err := netip.ParsePrefix(raw)
if err != nil {
return netip.Prefix{}, fmt.Errorf("network address parse error during conversion from string: %v", err)
}
if net.Addr().Is4In6() {
return netip.Prefix{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw)
}
return net, nil
}

View File

@@ -17,16 +17,37 @@ limitations under the License.
package library
import (
"fmt"
"math"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"k8s.io/apiserver/pkg/cel"
)
// panicOnUnknown makes cost estimate functions panic on unrecognized functions.
// This is only set to true for unit tests.
var panicOnUnknown = false
// builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator.
var knownUnhandledFunctions = map[string]bool{
"uint": true,
"duration": true,
"bytes": true,
"timestamp": true,
"value": true,
"_==_": true,
"_&&_": true,
"_>_": true,
"!_": true,
"strings.quote": true,
}
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
type CostEstimator struct {
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
@@ -34,6 +55,25 @@ type CostEstimator struct {
SizeEstimator checker.CostEstimator
}
const (
// shortest repeatable selector requirement that allocates a values slice is 2 characters: k,
selectorLengthToRequirementCount = float64(.5)
// the expensive parts to represent each requirement are a struct and a values slice
costPerRequirement = float64(common.ListCreateBaseCost + common.StructCreateBaseCost)
)
// a selector consists of a list of requirements held in a slice
var baseSelectorCost = checker.CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost}
func selectorCostEstimate(selectorLength checker.SizeEstimate) checker.CostEstimate {
parseCost := selectorLength.MultiplyByCostFactor(common.StringTraversalCostFactor)
requirementCount := selectorLength.MultiplyByCostFactor(selectorLengthToRequirementCount)
requirementCost := requirementCount.MultiplyByCostFactor(costPerRequirement)
return baseSelectorCost.Add(parseCost).Add(requirementCost)
}
func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 {
switch function {
case "check":
@@ -45,6 +85,13 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
// All authorization builder and accessor functions have a nominal cost
cost := uint64(1)
return &cost
case "fieldSelector", "labelSelector":
// field and label selector parse is a string parse into a structured set of requirements
if len(args) >= 2 {
selectorLength := actualSize(args[1])
cost := selectorCostEstimate(checker.SizeEstimate{Min: selectorLength, Max: selectorLength})
return &cost.Max
}
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
var cost uint64
if len(args) > 0 {
@@ -79,6 +126,118 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
cost := strCost * regexCost
return &cost
}
case "cidr", "isIP", "isCIDR":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
cost := uint64(1)
return &cost
}
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip.isCanonical":
if len(args) >= 1 {
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
cost := uint64(1)
return &cost
case "containsIP":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
case "containsCIDR":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
cost += uint64(math.Ceil(float64(cidrSize)*common.StringTraversalCostFactor)) + 1
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
case "quantity", "isQuantity":
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "validate":
if len(args) >= 2 {
format, isFormat := args[0].Value().(*cel.Format)
if isFormat {
strSize := actualSize(args[1])
// Dont have access to underlying regex, estimate a long regexp
regexSize := format.MaxRegexSize
// Copied from CEL implementation for regex cost
//
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(strSize)) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(regexSize) * common.RegexStringLengthCostFactor))
cost := strCost * regexCost
return &cost
}
}
case "format.named":
// Simply dictionary lookup
cost := uint64(1)
return &cost
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
cost := uint64(1)
return &cost
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
cost := uint64(1)
return &cost
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
}
return nil
}
@@ -94,6 +253,11 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
case "serviceAccount", "path", "group", "resource", "subresource", "namespace", "name", "allowed", "reason", "error", "errored":
// All authorization builder and accessor functions have a nominal cost
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "fieldSelector", "labelSelector":
// field and label selector parse is a string parse into a structured set of requirements
if len(args) == 1 {
return &checker.CallEstimate{CostEstimate: selectorCostEstimate(l.sizeEstimate(args[0]))}
}
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
if target != nil {
// Charge 1 cost for comparing each element in the list
@@ -179,8 +343,10 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// Worst case size is where is that a separator of "" is used, and each char is returned as a list element.
max := sz.Max
if len(args) > 1 {
if c := args[1].Expr().GetConstExpr(); c != nil {
max = uint64(c.GetInt64Value())
if v := args[1].Expr().AsLiteral(); v != nil {
if i, ok := v.Value().(int64); ok {
max = uint64(i)
}
}
}
// Cost is the traversal plus the construction of the result.
@@ -225,6 +391,93 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// worst case size of result is that every char is returned as separate find result.
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
}
case "cidr", "isIP", "isCIDR":
if target != nil {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
case "ip":
if target != nil && len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
} else if target != nil {
// The IP member of a CIDR is a just accessing a field, nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
case "ip.isCanonical":
if target != nil && len(args) >= 1 {
sz := l.sizeEstimate(args[0])
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "containsIP":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
case "containsCIDR":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
ipCompCost = ipCompCost.Add(checker.CostEstimate{Min: 1, Max: 1})
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
case "quantity", "isQuantity":
if target != nil {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
case "validate":
if target != nil {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).MultiplyByCostFactor(cel.MaxNameFormatRegexSize * common.RegexStringLengthCostFactor)}
}
case "format.named":
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
}
return nil
}
@@ -233,6 +486,10 @@ func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
if panicOnUnknown {
// debug.PrintStack()
panic(fmt.Errorf("actualSize: non-sizer type %T", value))
}
return 1
}
@@ -275,7 +532,7 @@ func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstim
type itemsNode struct {
path []string
t *types.Type
expr *exprpb.Expr
expr ast.Expr
}
func (i *itemsNode) Path() []string {
@@ -286,7 +543,7 @@ func (i *itemsNode) Type() *types.Type {
return i.t
}
func (i *itemsNode) Expr() *exprpb.Expr {
func (i *itemsNode) Expr() ast.Expr {
return i.expr
}
@@ -294,6 +551,8 @@ func (i *itemsNode) ComputedSize() *checker.SizeEstimate {
return nil
}
var _ checker.AstNode = (*itemsNode)(nil)
// traversalCost computes the cost of traversing a ref.Val as a data tree.
func traversalCost(v ref.Val) uint64 {
// TODO: This could potentially be optimized by sampling maps and lists instead of traversing.

270
vendor/k8s.io/apiserver/pkg/cel/library/format.go generated vendored Normal file
View File

@@ -0,0 +1,270 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"net/url"
"github.com/asaskevich/govalidator"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation"
"k8s.io/apimachinery/pkg/util/validation"
apiservercel "k8s.io/apiserver/pkg/cel"
"k8s.io/kube-openapi/pkg/validation/strfmt"
)
// Format provides a CEL library exposing common named Kubernetes string
// validations. Can be used in CRD ValidationRules messageExpression.
//
// Example:
//
// rule: format.dns1123label.validate(object.metadata.name).hasValue()
// messageExpression: format.dns1123label.validate(object.metadata.name).value().join("\n")
//
// format.named(name: string) -> ?Format
//
// Returns the Format with the given name, if it exists. Otherwise, optional.none
// Allowed names are:
// - `dns1123Label`
// - `dns1123Subdomain`
// - `dns1035Label`
// - `qualifiedName`
// - `dns1123LabelPrefix`
// - `dns1123SubdomainPrefix`
// - `dns1035LabelPrefix`
// - `labelValue`
// - `uri`
// - `uuid`
// - `byte`
// - `date`
// - `datetime`
//
// format.<formatName>() -> Format
//
// Convenience functions for all the named formats are also available
//
// Examples:
// format.dns1123Label().validate("my-label-name")
// format.dns1123Subdomain().validate("apiextensions.k8s.io")
// format.dns1035Label().validate("my-label-name")
// format.qualifiedName().validate("apiextensions.k8s.io/v1beta1")
// format.dns1123LabelPrefix().validate("my-label-prefix-")
// format.dns1123SubdomainPrefix().validate("mysubdomain.prefix.-")
// format.dns1035LabelPrefix().validate("my-label-prefix-")
// format.uri().validate("http://example.com")
// Uses same pattern as isURL, but returns an error
// format.uuid().validate("123e4567-e89b-12d3-a456-426614174000")
// format.byte().validate("aGVsbG8=")
// format.date().validate("2021-01-01")
// format.datetime().validate("2021-01-01T00:00:00Z")
//
// <Format>.validate(str: string) -> ?list<string>
//
// Validates the given string against the given format. Returns optional.none
// if the string is valid, otherwise a list of validation error strings.
func Format() cel.EnvOption {
return cel.Lib(formatLib)
}
var formatLib = &format{}
type format struct{}
func (*format) LibraryName() string {
return "format"
}
func ZeroArgumentFunctionBinding(binding func() ref.Val) decls.OverloadOpt {
return func(o *decls.OverloadDecl) (*decls.OverloadDecl, error) {
wrapped, err := decls.FunctionBinding(func(values ...ref.Val) ref.Val { return binding() })(o)
if err != nil {
return nil, err
}
if len(wrapped.ArgTypes()) != 0 {
return nil, fmt.Errorf("function binding must have 0 arguments")
}
return o, nil
}
}
func (*format) CompileOptions() []cel.EnvOption {
options := make([]cel.EnvOption, 0, len(formatLibraryDecls))
for name, overloads := range formatLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
for name, constantValue := range ConstantFormats {
prefixedName := "format." + name
options = append(options, cel.Function(prefixedName, cel.Overload(prefixedName, []*cel.Type{}, apiservercel.FormatType, ZeroArgumentFunctionBinding(func() ref.Val {
return constantValue
}))))
}
return options
}
func (*format) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
var ConstantFormats map[string]*apiservercel.Format = map[string]*apiservercel.Format{
"dns1123Label": {
Name: "DNS1123Label",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNSLabel(s, false) },
MaxRegexSize: 30,
},
"dns1123Subdomain": {
Name: "DNS1123Subdomain",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNSSubdomain(s, false) },
MaxRegexSize: 60,
},
"dns1035Label": {
Name: "DNS1035Label",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNS1035Label(s, false) },
MaxRegexSize: 30,
},
"qualifiedName": {
Name: "QualifiedName",
ValidateFunc: validation.IsQualifiedName,
MaxRegexSize: 60, // uses subdomain regex
},
"dns1123LabelPrefix": {
Name: "DNS1123LabelPrefix",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNSLabel(s, true) },
MaxRegexSize: 30,
},
"dns1123SubdomainPrefix": {
Name: "DNS1123SubdomainPrefix",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNSSubdomain(s, true) },
MaxRegexSize: 60,
},
"dns1035LabelPrefix": {
Name: "DNS1035LabelPrefix",
ValidateFunc: func(s string) []string { return apimachineryvalidation.NameIsDNS1035Label(s, true) },
MaxRegexSize: 30,
},
"labelValue": {
Name: "LabelValue",
ValidateFunc: validation.IsValidLabelValue,
MaxRegexSize: 40,
},
// CRD formats
// Implementations sourced from strfmt, which kube-openapi uses as its
// format library. There are other CRD formats supported, but they are
// covered by other portions of the CEL library (like IP/CIDR), or their
// use is discouraged (like bsonobjectid, email, etc)
"uri": {
Name: "URI",
ValidateFunc: func(s string) []string {
// Directly call ParseRequestURI since we can get a better error message
_, err := url.ParseRequestURI(s)
if err != nil {
return []string{err.Error()}
}
return nil
},
// Use govalidator url regex to estimate, since ParseRequestURI
// doesnt use regex
MaxRegexSize: len(govalidator.URL),
},
"uuid": {
Name: "uuid",
ValidateFunc: func(s string) []string {
if !strfmt.Default.Validates("uuid", s) {
return []string{"does not match the UUID format"}
}
return nil
},
MaxRegexSize: len(strfmt.UUIDPattern),
},
"byte": {
Name: "byte",
ValidateFunc: func(s string) []string {
if !strfmt.Default.Validates("byte", s) {
return []string{"invalid base64"}
}
return nil
},
MaxRegexSize: len(govalidator.Base64),
},
"date": {
Name: "date",
ValidateFunc: func(s string) []string {
if !strfmt.Default.Validates("date", s) {
return []string{"invalid date"}
}
return nil
},
// Estimated regex size for RFC3339FullDate which is
// a date format. Assume a date-time pattern is longer
// so use that to conservatively estimate this
MaxRegexSize: len(strfmt.DateTimePattern),
},
"datetime": {
Name: "datetime",
ValidateFunc: func(s string) []string {
if !strfmt.Default.Validates("datetime", s) {
return []string{"invalid datetime"}
}
return nil
},
MaxRegexSize: len(strfmt.DateTimePattern),
},
}
var formatLibraryDecls = map[string][]cel.FunctionOpt{
"validate": {
cel.MemberOverload("format-validate", []*cel.Type{apiservercel.FormatType, cel.StringType}, cel.OptionalType(cel.ListType(cel.StringType)), cel.BinaryBinding(formatValidate)),
},
"format.named": {
cel.Overload("format-named", []*cel.Type{cel.StringType}, cel.OptionalType(apiservercel.FormatType), cel.UnaryBinding(func(name ref.Val) ref.Val {
nameString, ok := name.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(name)
}
f, ok := ConstantFormats[nameString]
if !ok {
return types.OptionalNone
}
return types.OptionalOf(f)
})),
},
}
func formatValidate(arg1, arg2 ref.Val) ref.Val {
f, ok := arg1.Value().(*apiservercel.Format)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
str, ok := arg2.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
res := f.ValidateFunc(str)
if len(res) == 0 {
return types.OptionalNone
}
return types.OptionalOf(types.NewStringList(types.DefaultTypeAdapter, res))
}

329
vendor/k8s.io/apiserver/pkg/cel/library/ip.go generated vendored Normal file
View File

@@ -0,0 +1,329 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"net/netip"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apiservercel "k8s.io/apiserver/pkg/cel"
)
// IP provides a CEL function library extension of IP address parsing functions.
//
// ip
//
// Converts a string to an IP address or results in an error if the string is not a valid IP address.
// The IP address must be an IPv4 or IPv6 address.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4) are not allowed.
// IP addresses with zones (e.g. fe80::1%eth0) are not allowed.
// Leading zeros in IPv4 address octets are not allowed.
//
// ip(<string>) <IPAddr>
//
// Examples:
//
// ip('127.0.0.1') // returns an IPv4 address
// ip('::1') // returns an IPv6 address
// ip('127.0.0.256') // error
// ip(':::1') // error
//
// isIP
//
// Returns true if a string is a valid IP address.
// The IP address must be an IPv4 or IPv6 address.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4) are not allowed.
// IP addresses with zones (e.g. fe80::1%eth0) are not allowed.
// Leading zeros in IPv4 address octets are not allowed.
//
// isIP(<string>) <bool>
//
// Examples:
//
// isIP('127.0.0.1') // returns true
// isIP('::1') // returns true
// isIP('127.0.0.256') // returns false
// isIP(':::1') // returns false
//
// ip.isCanonical
//
// Returns true if the IP address is in its canonical form.
// There is exactly one canonical form for every IP address, so fields containing
// IPs in canonical form can just be treated as strings when checking for equality or uniqueness.
//
// ip.isCanonical(<string>) <bool>
//
// Examples:
//
// ip.isCanonical('127.0.0.1') // returns true; all valid IPv4 addresses are canonical
// ip.isCanonical('2001:db8::abcd') // returns true
// ip.isCanonical('2001:DB8::ABCD') // returns false
// ip.isCanonical('2001:db8::0:0:0:abcd') // returns false
//
// family / isUnspecified / isLoopback / isLinkLocalMulticast / isLinkLocalUnicast / isGlobalUnicast
//
// - family: returns the IP addresses' family (IPv4 or IPv6) as an integer, either '4' or '6'.
//
// - isUnspecified: returns true if the IP address is the unspecified address.
// Either the IPv4 address "0.0.0.0" or the IPv6 address "::".
//
// - isLoopback: returns true if the IP address is the loopback address.
// Either an IPv4 address with a value of 127.x.x.x or an IPv6 address with a value of ::1.
//
// - isLinkLocalMulticast: returns true if the IP address is a link-local multicast address.
// Either an IPv4 address with a value of 224.0.0.x or an IPv6 address in the network ff00::/8.
//
// - isLinkLocalUnicast: returns true if the IP address is a link-local unicast address.
// Either an IPv4 address with a value of 169.254.x.x or an IPv6 address in the network fe80::/10.
//
// - isGlobalUnicast: returns true if the IP address is a global unicast address.
// Either an IPv4 address that is not zero or 255.255.255.255 or an IPv6 address that is not a link-local unicast, loopback or multicast address.
//
// Examples:
//
// ip('127.0.0.1').family() // returns '4”
// ip('::1').family() // returns '6'
// ip('127.0.0.1').family() == 4 // returns true
// ip('::1').family() == 6 // returns true
// ip('0.0.0.0').isUnspecified() // returns true
// ip('127.0.0.1').isUnspecified() // returns false
// ip('::').isUnspecified() // returns true
// ip('::1').isUnspecified() // returns false
// ip('127.0.0.1').isLoopback() // returns true
// ip('192.168.0.1').isLoopback() // returns false
// ip('::1').isLoopback() // returns true
// ip('2001:db8::abcd').isLoopback() // returns false
// ip('224.0.0.1').isLinkLocalMulticast() // returns true
// ip('224.0.1.1').isLinkLocalMulticast() // returns false
// ip('ff02::1').isLinkLocalMulticast() // returns true
// ip('fd00::1').isLinkLocalMulticast() // returns false
// ip('169.254.169.254').isLinkLocalUnicast() // returns true
// ip('192.168.0.1').isLinkLocalUnicast() // returns false
// ip('fe80::1').isLinkLocalUnicast() // returns true
// ip('fd80::1').isLinkLocalUnicast() // returns false
// ip('192.168.0.1').isGlobalUnicast() // returns true
// ip('255.255.255.255').isGlobalUnicast() // returns false
// ip('2001:db8::abcd').isGlobalUnicast() // returns true
// ip('ff00::1').isGlobalUnicast() // returns false
func IP() cel.EnvOption {
return cel.Lib(ipLib)
}
var ipLib = &ip{}
type ip struct{}
func (*ip) LibraryName() string {
return "net.ip"
}
var ipLibraryDecls = map[string][]cel.FunctionOpt{
"ip": {
cel.Overload("string_to_ip", []*cel.Type{cel.StringType}, apiservercel.IPType,
cel.UnaryBinding(stringToIP)),
},
"family": {
cel.MemberOverload("ip_family", []*cel.Type{apiservercel.IPType}, cel.IntType,
cel.UnaryBinding(family)),
},
"ip.isCanonical": {
cel.Overload("ip_is_canonical", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(ipIsCanonical)),
},
"isUnspecified": {
cel.MemberOverload("ip_is_unspecified", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isUnspecified)),
},
"isLoopback": {
cel.MemberOverload("ip_is_loopback", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLoopback)),
},
"isLinkLocalMulticast": {
cel.MemberOverload("ip_is_link_local_multicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLinkLocalMulticast)),
},
"isLinkLocalUnicast": {
cel.MemberOverload("ip_is_link_local_unicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLinkLocalUnicast)),
},
"isGlobalUnicast": {
cel.MemberOverload("ip_is_global_unicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isGlobalUnicast)),
},
"isIP": {
cel.Overload("is_ip", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(isIP)),
},
"string": {
cel.Overload("ip_to_string", []*cel.Type{apiservercel.IPType}, cel.StringType,
cel.UnaryBinding(ipToString)),
},
}
func (*ip) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{cel.Types(apiservercel.IPType),
cel.Variable(apiservercel.IPType.TypeName(), types.NewTypeTypeWithParam(apiservercel.IPType)),
}
for name, overloads := range ipLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*ip) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func stringToIP(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
addr, err := parseIPAddr(s)
if err != nil {
// Don't add context, we control the error message already.
return types.NewErr("%v", err)
}
return apiservercel.IP{
Addr: addr,
}
}
func ipToString(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(ip.Addr.String())
}
func family(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
switch {
case ip.Addr.Is4():
return types.Int(4)
case ip.Addr.Is6():
return types.Int(6)
default:
return types.NewErr("IP address %q is not an IPv4 or IPv6 address", ip.Addr.String())
}
}
func ipIsCanonical(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
addr, err := parseIPAddr(s)
if err != nil {
// Don't add context, we control the error message already.
return types.NewErr("%v", err)
}
// Addr.String() always returns the canonical form of the IP address.
// Therefore comparing this with the original string representation
// will tell us if the IP address is in its canonical form.
return types.Bool(addr.String() == s)
}
func isIP(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := parseIPAddr(s)
return types.Bool(err == nil)
}
func isUnspecified(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsUnspecified())
}
func isLoopback(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLoopback())
}
func isLinkLocalMulticast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLinkLocalMulticast())
}
func isLinkLocalUnicast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLinkLocalUnicast())
}
func isGlobalUnicast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsGlobalUnicast())
}
// parseIPAddr parses a string into an IP address.
// We use this function to parse IP addresses in the CEL library
// so that we can share the common logic of rejecting IP addresses
// that contain zones or are IPv4-mapped IPv6 addresses.
func parseIPAddr(raw string) (netip.Addr, error) {
addr, err := netip.ParseAddr(raw)
if err != nil {
return netip.Addr{}, fmt.Errorf("IP Address %q parse error during conversion from string: %v", raw, err)
}
if addr.Zone() != "" {
return netip.Addr{}, fmt.Errorf("IP address %q with zone value is not allowed", raw)
}
if addr.Is4In6() {
return netip.Addr{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw)
}
return addr, nil
}

View File

@@ -47,4 +47,6 @@ const (
MinBoolSize = 4
// MinNumberSize is the length of literal 0
MinNumberSize = 1
MaxNameFormatRegexSize = 128
)

View File

@@ -44,14 +44,14 @@ func newCelMetrics() *CelMetrics {
Subsystem: subsystem,
Name: "compilation_duration_seconds",
Help: "CEL compilation time in seconds.",
StabilityLevel: metrics.ALPHA,
StabilityLevel: metrics.BETA,
}),
evaluationTime: metrics.NewHistogram(&metrics.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "evaluation_duration_seconds",
Help: "CEL evaluation time in seconds.",
StabilityLevel: metrics.ALPHA,
StabilityLevel: metrics.BETA,
}),
}

View File

@@ -50,7 +50,7 @@ func (d Quantity) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Quantity) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case typeValue:
case quantityTypeValue:
return d
case types.TypeType:
return quantityTypeValue

View File

@@ -27,7 +27,7 @@ import (
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/protobuf/proto"
"k8s.io/apimachinery/pkg/api/resource"
)
const (
@@ -348,9 +348,14 @@ func NewDeclTypeProvider(rootTypes ...*DeclType) *DeclTypeProvider {
// DeclTypeProvider extends the CEL ref.TypeProvider interface and provides an Open API Schema-based
// type-system.
type DeclTypeProvider struct {
registeredTypes map[string]*DeclType
typeProvider ref.TypeProvider
typeAdapter ref.TypeAdapter
registeredTypes map[string]*DeclType
typeProvider types.Provider
typeAdapter types.Adapter
recognizeKeywordAsFieldName bool
}
func (rt *DeclTypeProvider) SetRecognizeKeywordAsFieldName(recognize bool) {
rt.recognizeKeywordAsFieldName = recognize
}
func (rt *DeclTypeProvider) EnumValue(enumName string) ref.Val {
@@ -365,7 +370,7 @@ func (rt *DeclTypeProvider) FindIdent(identName string) (ref.Val, bool) {
// as well as a custom ref.TypeProvider.
//
// If the DeclTypeProvider value is nil, an empty []cel.EnvOption set is returned.
func (rt *DeclTypeProvider) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, error) {
func (rt *DeclTypeProvider) EnvOptions(tp types.Provider) ([]cel.EnvOption, error) {
if rt == nil {
return []cel.EnvOption{}, nil
}
@@ -381,54 +386,52 @@ func (rt *DeclTypeProvider) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, er
// WithTypeProvider returns a new DeclTypeProvider that sets the given TypeProvider
// If the original DeclTypeProvider is nil, the returned DeclTypeProvider is still nil.
func (rt *DeclTypeProvider) WithTypeProvider(tp ref.TypeProvider) (*DeclTypeProvider, error) {
func (rt *DeclTypeProvider) WithTypeProvider(tp types.Provider) (*DeclTypeProvider, error) {
if rt == nil {
return nil, nil
}
var ta ref.TypeAdapter = types.DefaultTypeAdapter
tpa, ok := tp.(ref.TypeAdapter)
var ta types.Adapter = types.DefaultTypeAdapter
tpa, ok := tp.(types.Adapter)
if ok {
ta = tpa
}
rtWithTypes := &DeclTypeProvider{
typeProvider: tp,
typeAdapter: ta,
registeredTypes: rt.registeredTypes,
typeProvider: tp,
typeAdapter: ta,
registeredTypes: rt.registeredTypes,
recognizeKeywordAsFieldName: rt.recognizeKeywordAsFieldName,
}
for name, declType := range rt.registeredTypes {
tpType, found := tp.FindType(name)
expT, err := declType.ExprType()
if err != nil {
return nil, fmt.Errorf("fail to get cel type: %s", err)
}
if found && !proto.Equal(tpType, expT) {
tpType, found := tp.FindStructType(name)
// cast celType to types.type
expT := declType.CelType()
if found && !expT.IsExactType(tpType) {
return nil, fmt.Errorf(
"type %s definition differs between CEL environment and type provider", name)
}
}
return rtWithTypes, nil
}
// FindType attempts to resolve the typeName provided from the rule's rule-schema, or if not
// FindStructType attempts to resolve the typeName provided from the rule's rule-schema, or if not
// from the embedded ref.TypeProvider.
//
// FindType overrides the default type-finding behavior of the embedded TypeProvider.
// FindStructType overrides the default type-finding behavior of the embedded TypeProvider.
//
// Note, when the type name is based on the Open API Schema, the name will reflect the object path
// where the type definition appears.
func (rt *DeclTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
func (rt *DeclTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if rt == nil {
return nil, false
}
declType, found := rt.findDeclType(typeName)
if found {
expT, err := declType.ExprType()
if err != nil {
return expT, false
}
expT := declType.CelType()
return expT, found
}
return rt.typeProvider.FindType(typeName)
return rt.typeProvider.FindStructType(typeName)
}
// FindDeclType returns the CPT type description which can be mapped to a CEL type.
@@ -439,37 +442,41 @@ func (rt *DeclTypeProvider) FindDeclType(typeName string) (*DeclType, bool) {
return rt.findDeclType(typeName)
}
// FindFieldType returns a field type given a type name and field name, if found.
// FindStructFieldNames returns the field names associated with the type, if the type
// is found.
func (rt *DeclTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) {
return []string{}, false
}
// FindStructFieldType returns a field type given a type name and field name, if found.
//
// Note, the type name for an Open API Schema type is likely to be its qualified object path.
// If, in the future an object instance rather than a type name were provided, the field
// resolution might more accurately reflect the expected type model. However, in this case
// concessions were made to align with the existing CEL interfaces.
func (rt *DeclTypeProvider) FindFieldType(typeName, fieldName string) (*ref.FieldType, bool) {
func (rt *DeclTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
st, found := rt.findDeclType(typeName)
if !found {
return rt.typeProvider.FindFieldType(typeName, fieldName)
return rt.typeProvider.FindStructFieldType(typeName, fieldName)
}
f, found := st.Fields[fieldName]
if rt.recognizeKeywordAsFieldName && !found && celReservedSymbols.Has(fieldName) {
f, found = st.Fields["__"+fieldName+"__"]
}
if found {
ft := f.Type
expT, err := ft.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{
expT := ft.CelType()
return &types.FieldType{
Type: expT,
}, true
}
// This could be a dynamic map.
if st.IsMap() {
et := st.ElemType
expT, err := et.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{
expT := et.CelType()
return &types.FieldType{
Type: expT,
}, true
}
@@ -576,6 +583,10 @@ var (
// labeled as Timestamp will necessarily have the same MinSerializedSize.
TimestampType = NewSimpleTypeWithMinSize("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}}, JSONDateSize)
// QuantityDeclType wraps a [QuantityType] and makes it usable with functions that expect
// a [DeclType].
QuantityDeclType = NewSimpleTypeWithMinSize("quantity", QuantityType, Quantity{Quantity: resource.NewQuantity(0, resource.DecimalSI)}, 8)
// UintType is equivalent to the CEL 'uint' type.
UintType = NewSimpleTypeWithMinSize("uint", cel.UintType, types.Uint(0), 1)