feat: kubesphere 4.0 (#6115)

* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

* feat: kubesphere 4.0

Signed-off-by: ci-bot <ci-bot@kubesphere.io>

---------

Signed-off-by: ci-bot <ci-bot@kubesphere.io>
Co-authored-by: ks-ci-bot <ks-ci-bot@example.com>
Co-authored-by: joyceliu <joyceliu@yunify.com>
This commit is contained in:
KubeSphere CI Bot
2024-09-06 11:05:52 +08:00
committed by GitHub
parent b5015ec7b9
commit 447a51f08b
8557 changed files with 546695 additions and 1146174 deletions

View File

@@ -15,6 +15,7 @@ go_library(
"macro.go",
"options.go",
"program.go",
"validator.go",
],
importpath = "github.com/google/cel-go/cel",
visibility = ["//visibility:public"],
@@ -22,16 +23,20 @@ go_library(
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protodesc:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@@ -69,8 +74,10 @@ go_test(
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//encoding/prototext:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
],
)

File diff suppressed because it is too large Load Diff

View File

@@ -16,13 +16,14 @@ package cel
import (
"errors"
"fmt"
"sync"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/checker/decls"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@@ -40,8 +41,8 @@ type Ast struct {
expr *exprpb.Expr
info *exprpb.SourceInfo
source Source
refMap map[int64]*exprpb.Reference
typeMap map[int64]*exprpb.Type
refMap map[int64]*celast.ReferenceInfo
typeMap map[int64]*types.Type
}
// Expr returns the proto serializable instance of the parsed/checked expression.
@@ -60,21 +61,26 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo {
}
// ResultType returns the output type of the expression if the Ast has been type-checked, else
// returns decls.Dyn as the parse step cannot infer the type.
// returns chkdecls.Dyn as the parse step cannot infer the type.
//
// Deprecated: use OutputType
func (ast *Ast) ResultType() *exprpb.Type {
if !ast.IsChecked() {
return decls.Dyn
return chkdecls.Dyn
}
return ast.typeMap[ast.expr.GetId()]
out := ast.OutputType()
t, err := TypeToExprType(out)
if err != nil {
return chkdecls.Dyn
}
return t
}
// OutputType returns the output type of the expression if the Ast has been type-checked, else
// returns cel.DynType as the parse step cannot infer types.
func (ast *Ast) OutputType() *Type {
t, err := ExprTypeToType(ast.ResultType())
if err != nil {
t, found := ast.typeMap[ast.expr.GetId()]
if !found {
return DynType
}
return t
@@ -87,30 +93,44 @@ func (ast *Ast) Source() Source {
}
// FormatType converts a type message into a string representation.
//
// Deprecated: prefer FormatCELType
func FormatType(t *exprpb.Type) string {
return checker.FormatCheckedType(t)
}
// FormatCELType formats a cel.Type value to a string representation.
//
// The type formatting is identical to FormatType.
func FormatCELType(t *Type) string {
return checker.FormatCELType(t)
}
// Env encapsulates the context necessary to perform parsing, type checking, or generation of
// evaluable programs for different expressions.
type Env struct {
Container *containers.Container
functions map[string]*functionDecl
declarations []*exprpb.Decl
variables []*decls.VariableDecl
functions map[string]*decls.FunctionDecl
macros []parser.Macro
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
features map[int]bool
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption
// Internal parser representation
prsr *parser.Parser
prsr *parser.Parser
prsrOpts []parser.Option
// Internal checker representation
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
chkMutex sync.Mutex
chk *checker.Env
chkErr error
chkOnce sync.Once
chkOpts []checker.Option
// Program options tied to the environment
progOpts []ProgramOption
@@ -151,22 +171,29 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
return nil, err
}
return (&Env{
declarations: []*exprpb.Decl{},
functions: map[string]*functionDecl{},
variables: []*decls.VariableDecl{},
functions: map[string]*decls.FunctionDecl{},
macros: []parser.Macro{},
Container: containers.DefaultContainer,
adapter: registry,
provider: registry,
features: map[int]bool{},
appliedFeatures: map[int]bool{},
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}
// Check performs type-checking on the input Ast and yields a checked Ast and/or set of Issues.
// If any `ASTValidators` are configured on the environment, they will be applied after a valid
// type-check result. If any issues are detected, the validators will provide them on the
// output Issues object.
//
// Checking has failed if the returned Issues value and its Issues.Err() value are non-nil.
// Issues should be inspected if they are non-nil, but may not represent a fatal error.
// Either checking or validation has failed if the returned Issues value and its Issues.Err()
// value are non-nil. Issues should be inspected if they are non-nil, but may not represent a
// fatal error.
//
// It is possible to have both non-nil Ast and Issues values returned from this call: however,
// the mere presence of an Ast does not imply that it is valid for use.
@@ -175,25 +202,42 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
err := e.initChecker()
chk, err := e.initChecker()
if err != nil {
errs := common.NewErrors(ast.Source())
errs.ReportError(common.NoLocation, e.chkErr.Error())
return nil, NewIssues(errs)
errs.ReportError(common.NoLocation, err.Error())
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
}
res, errs := checker.Check(pe, ast.Source(), e.chk)
res, errs := checker.Check(pe, ast.Source(), chk)
if len(errs.GetErrors()) > 0 {
return nil, NewIssues(errs)
return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo())
}
// Manually create the Ast to ensure that the Ast source information (which may be more
// detailed than the information provided by Check), is returned to the caller.
return &Ast{
ast = &Ast{
source: ast.Source(),
expr: res.GetExpr(),
info: res.GetSourceInfo(),
refMap: res.GetReferenceMap(),
typeMap: res.GetTypeMap()}, nil
expr: res.Expr,
info: res.SourceInfo,
refMap: res.ReferenceMap,
typeMap: res.TypeMap}
// Generate a validator configuration from the set of configured validators.
vConfig := newValidatorConfig()
for _, v := range e.validators {
if cv, ok := v.(ASTValidatorConfigurer); ok {
cv.Configure(vConfig)
}
}
// Apply additional validators on the type-checked result.
iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo())
for _, v := range e.validators {
v.Validate(e, vConfig, res, iss)
}
if iss.Err() != nil {
return nil, iss
}
return ast, nil
}
// Compile combines the Parse and Check phases CEL program compilation to produce an Ast and
@@ -236,10 +280,14 @@ func (e *Env) CompileSource(src Source) (*Ast, *Issues) {
// TypeProvider are immutable, or that their underlying implementations are based on the
// ref.TypeRegistry which provides a Copy method which will be invoked by this method.
func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
if e.chkErr != nil {
return nil, e.chkErr
chk, chkErr := e.getCheckerOrError()
if chkErr != nil {
return nil, chkErr
}
prsrOptsCopy := make([]parser.Option, len(e.prsrOpts))
copy(prsrOptsCopy, e.prsrOpts)
// The type-checker is configured with Declarations. The declarations may either be provided
// as options which have not yet been validated, or may come from a previous checker instance
// whose types have already been validated.
@@ -247,16 +295,16 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
copy(chkOptsCopy, e.chkOpts)
// Copy the declarations if needed.
decsCopy := []*exprpb.Decl{}
if e.chk != nil {
varsCopy := []*decls.VariableDecl{}
if chk != nil {
// If the type-checker has already been instantiated, then the e.declarations have been
// valdiated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(e.chk))
// validated within the chk instance.
chkOptsCopy = append(chkOptsCopy, checker.ValidatedDeclarations(chk))
} else {
// If the type-checker has not been instantiated, ensure the unvalidated declarations are
// provided to the extended Env instance.
decsCopy = make([]*exprpb.Decl, len(e.declarations))
copy(decsCopy, e.declarations)
varsCopy = make([]*decls.VariableDecl, len(e.variables))
copy(varsCopy, e.variables)
}
// Copy macros and program options
@@ -268,8 +316,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// Copy the adapter / provider if they appear to be mutable.
adapter := e.adapter
provider := e.provider
adapterReg, isAdapterReg := e.adapter.(ref.TypeRegistry)
providerReg, isProviderReg := e.provider.(ref.TypeRegistry)
adapterReg, isAdapterReg := e.adapter.(*types.Registry)
providerReg, isProviderReg := e.provider.(*types.Registry)
// In most cases the provider and adapter will be a ref.TypeRegistry;
// however, in the rare cases where they are not, they are assumed to
// be immutable. Since it is possible to set the TypeProvider separately
@@ -300,23 +348,34 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
for k, v := range e.appliedFeatures {
appliedFeaturesCopy[k] = v
}
funcsCopy := make(map[string]*functionDecl, len(e.functions))
funcsCopy := make(map[string]*decls.FunctionDecl, len(e.functions))
for k, v := range e.functions {
funcsCopy[k] = v
}
libsCopy := make(map[string]bool, len(e.libraries))
for k, v := range e.libraries {
libsCopy[k] = v
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)
// TODO: functions copy needs to happen here.
ext := &Env{
Container: e.Container,
declarations: decsCopy,
variables: varsCopy,
functions: funcsCopy,
macros: macsCopy,
progOpts: progOptsCopy,
adapter: adapter,
features: featuresCopy,
appliedFeatures: appliedFeaturesCopy,
libraries: libsCopy,
validators: validatorsCopy,
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
@@ -328,6 +387,31 @@ func (e *Env) HasFeature(flag int) bool {
return has && enabled
}
// HasLibrary returns whether a specific SingletonLibrary has been configured in the environment.
func (e *Env) HasLibrary(libName string) bool {
configured, exists := e.libraries[libName]
return exists && configured
}
// Libraries returns a list of SingletonLibrary that have been configured in the environment.
func (e *Env) Libraries() []string {
libraries := make([]string, 0, len(e.libraries))
for libName := range e.libraries {
libraries = append(libraries, libName)
}
return libraries
}
// HasValidator returns whether a specific ASTValidator has been configured in the environment.
func (e *Env) HasValidator(name string) bool {
for _, v := range e.validators {
if v.Name() == name {
return true
}
}
return false
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
@@ -369,36 +453,64 @@ func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return newProgram(e, ast, optSet)
}
// CELTypeAdapter returns the `types.Adapter` configured for the environment.
func (e *Env) CELTypeAdapter() types.Adapter {
return e.adapter
}
// CELTypeProvider returns the `types.Provider` configured for the environment.
func (e *Env) CELTypeProvider() types.Provider {
return e.provider
}
// TypeAdapter returns the `ref.TypeAdapter` configured for the environment.
//
// Deprecated: use CELTypeAdapter()
func (e *Env) TypeAdapter() ref.TypeAdapter {
return e.adapter
}
// TypeProvider returns the `ref.TypeProvider` configured for the environment.
//
// Deprecated: use CELTypeProvider()
func (e *Env) TypeProvider() ref.TypeProvider {
return e.provider
if legacyProvider, ok := e.provider.(ref.TypeProvider); ok {
return legacyProvider
}
return &interopLegacyTypeProvider{Provider: e.provider}
}
// UnknownVars returns an interpreter.PartialActivation which marks all variables
// declared in the Env as unknown AttributePattern values.
// UnknownVars returns an interpreter.PartialActivation which marks all variables declared in the
// Env as unknown AttributePattern values.
//
// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation
// unless the PartialAttributes option is provided as a ProgramOption.
// Note, the UnknownVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) UnknownVars() interpreter.PartialActivation {
var unknownPatterns []*interpreter.AttributePattern
for _, d := range e.declarations {
switch d.GetDeclKind().(type) {
case *exprpb.Decl_Ident:
unknownPatterns = append(unknownPatterns,
interpreter.NewAttributePattern(d.GetName()))
}
}
part, _ := PartialVars(
interpreter.EmptyActivation(),
unknownPatterns...)
act := interpreter.EmptyActivation()
part, _ := PartialVars(act, e.computeUnknownVars(act)...)
return part
}
// PartialVars returns an interpreter.PartialActivation where all variables not in the input variable
// set, but which have been configured in the environment, are marked as unknown.
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
//
// Note, this is equivalent to calling cel.PartialVars and manually configuring the set of unknown
// variables. For more advanced use cases of partial state where portions of an object graph, rather
// than top-level variables, are missing the PartialVars() method may be a more suitable choice.
//
// Note, the PartialVars will behave the same as an interpreter.EmptyActivation unless the
// PartialAttributes option is provided as a ProgramOption.
func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
act, err := interpreter.NewActivation(vars)
if err != nil {
return nil, err
}
return PartialVars(act, e.computeUnknownVars(act)...)
}
// ResidualAst takes an Ast and its EvalDetails to produce a new Ast which only contains the
// attribute references which are unknown.
//
@@ -422,8 +534,8 @@ func (e *Env) UnknownVars() interpreter.PartialActivation {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.Expr(), details.State())
expr, err := AstToString(ParsedExprToAst(&exprpb.ParsedExpr{Expr: pruned}))
pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State())
expr, err := AstToString(ParsedExprToAst(pruned))
if err != nil {
return nil, err
}
@@ -443,12 +555,17 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast)
if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
return checker.Cost(checked, estimator), nil
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(checked, estimator, extendedOpts...)
}
// configure applies a series of EnvOptions to the current environment.
@@ -464,32 +581,22 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
}
// If the default UTC timezone fix has been enabled, make sure the library is configured
if e.HasFeature(featureDefaultUTCTimeZone) {
if _, found := e.appliedFeatures[featureDefaultUTCTimeZone]; !found {
e, err = Lib(timeUTCLibrary{})(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[featureDefaultUTCTimeZone] = true
}
}
// Initialize all of the functions configured within the environment.
for _, fn := range e.functions {
err = fn.init()
if err != nil {
return nil, err
}
e, err = e.maybeApplyFeature(featureDefaultUTCTimeZone, Lib(timeUTCLibrary{}))
if err != nil {
return nil, err
}
// Configure the parser.
prsrOpts := []parser.Option{parser.Macros(e.macros...)}
prsrOpts := []parser.Option{}
prsrOpts = append(prsrOpts, e.prsrOpts...)
prsrOpts = append(prsrOpts, parser.Macros(e.macros...))
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
if e.HasFeature(featureVariadicLogicalASTs) {
prsrOpts = append(prsrOpts, parser.EnableVariadicOperatorASTs(true))
}
e.prsr, err = parser.NewParser(prsrOpts...)
if err != nil {
return nil, err
@@ -497,7 +604,7 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
// Ensure that the checker init happens eagerly rather than lazily.
if e.HasFeature(featureEagerlyValidateDeclarations) {
err := e.initChecker()
_, err := e.initChecker()
if err != nil {
return nil, err
}
@@ -506,57 +613,115 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return e, nil
}
func (e *Env) initChecker() error {
func (e *Env) initChecker() (*checker.Env, error) {
e.chkOnce.Do(func() {
chkOpts := []checker.Option{}
chkOpts = append(chkOpts, e.chkOpts...)
chkOpts = append(chkOpts,
checker.HomogeneousAggregateLiterals(
e.HasFeature(featureDisableDynamicAggregateLiterals)),
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))
ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the statically configured declarations.
err = ce.Add(e.declarations...)
err = ce.AddIdents(e.variables...)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
// Add the function declarations which are derived from the FunctionDecl instances.
for _, fn := range e.functions {
fnDecl, err := functionDeclToExprDecl(fn)
if err != nil {
e.chkErr = err
return
if fn.IsDeclarationDisabled() {
continue
}
err = ce.Add(fnDecl)
err = ce.AddFunctions(fn)
if err != nil {
e.chkErr = err
e.setCheckerOrError(nil, err)
return
}
}
// Add function declarations here separately.
e.chk = ce
e.setCheckerOrError(ce, nil)
})
return e.chkErr
return e.getCheckerOrError()
}
// setCheckerOrError sets the checker.Env or error state in a concurrency-safe manner
func (e *Env) setCheckerOrError(chk *checker.Env, chkErr error) {
e.chkMutex.Lock()
e.chk = chk
e.chkErr = chkErr
e.chkMutex.Unlock()
}
// getCheckerOrError gets the checker.Env or error state in a concurrency-safe manner
func (e *Env) getCheckerOrError() (*checker.Env, error) {
e.chkMutex.Lock()
defer e.chkMutex.Unlock()
return e.chk, e.chkErr
}
// maybeApplyFeature determines whether the feature-guarded option is enabled, and if so applies
// the feature if it has not already been enabled.
func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) {
if !e.HasFeature(feature) {
return e, nil
}
_, applied := e.appliedFeatures[feature]
if applied {
return e, nil
}
e, err := option(e)
if err != nil {
return nil, err
}
// record that the feature has been applied since it will generate declarations
// and functions which will be propagated on Extend() calls and which should only
// be registered once.
e.appliedFeatures[feature] = true
return e, nil
}
// computeUnknownVars determines a set of missing variables based on the input activation and the
// environment's configured declaration set.
func (e *Env) computeUnknownVars(vars interpreter.Activation) []*interpreter.AttributePattern {
var unknownPatterns []*interpreter.AttributePattern
for _, v := range e.variables {
varName := v.Name()
if _, found := vars.ResolveName(varName); found {
continue
}
unknownPatterns = append(unknownPatterns, interpreter.NewAttributePattern(varName))
}
return unknownPatterns
}
// Error type which references an expression id, a location within source, and a message.
type Error = common.Error
// Issues defines methods for inspecting the error details of parse and check calls.
//
// Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct.
type Issues struct {
errs *common.Errors
info *exprpb.SourceInfo
}
// NewIssues returns an Issues struct from a common.Errors object.
func NewIssues(errs *common.Errors) *Issues {
return NewIssuesWithSourceInfo(errs, nil)
}
// NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata
// which can be used with the `ReportErrorAtID` method for additional error reports within the context
// information that's inferred from an expression id.
func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues {
return &Issues{
errs: errs,
info: info,
}
}
@@ -572,9 +737,9 @@ func (i *Issues) Err() error {
}
// Errors returns the collection of errors encountered in more granular detail.
func (i *Issues) Errors() []common.Error {
func (i *Issues) Errors() []*Error {
if i == nil {
return []common.Error{}
return []*Error{}
}
return i.errs.GetErrors()
}
@@ -598,6 +763,37 @@ func (i *Issues) String() string {
return i.errs.ToDisplayString()
}
// ReportErrorAtID reports an error message with an optional set of formatting arguments.
//
// The source metadata for the expression at `id`, if present, is attached to the error report.
// To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo.
func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) {
i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...)
}
// locationByID returns a common.Location given an expression id.
//
// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source
// as this implementation relies on the abstractions present in the protobuf SourceInfo object,
// and is replicated in the checker.
func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location {
positions := sourceInfo.GetPositions()
var line = 1
if offset, found := positions[id]; found {
col := int(offset)
for _, lineOffset := range sourceInfo.GetLineOffsets() {
if lineOffset < offset {
line++
col = int(offset - lineOffset)
} else {
break
}
}
return common.NewLocation(line, col)
}
return common.NoLocation
}
// getStdEnv lazy initializes the CEL standard environment.
func getStdEnv() (*Env, error) {
stdEnvInit.Do(func() {
@@ -606,6 +802,90 @@ func getStdEnv() (*Env, error) {
return stdEnv, stdEnvErr
}
// interopCELTypeProvider layers support for the types.Provider interface on top of a ref.TypeProvider.
type interopCELTypeProvider struct {
ref.TypeProvider
}
// FindStructType returns a types.Type instance for the given fully-qualified typeName if one exists.
//
// This method proxies to the underyling ref.TypeProvider's FindType method and converts protobuf type
// into a native type representation. If the conversion fails, the type is listed as not found.
func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if et, found := p.FindType(typeName); found {
t, err := types.ExprTypeToType(et)
if err != nil {
return nil, false
}
return t, true
}
return nil, false
}
// FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field
// name, if one exists.
//
// This method proxies to the underyling ref.TypeProvider's FindFieldType method and converts protobuf type
// into a native type representation. If the conversion fails, the type is listed as not found.
func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) {
if ft, found := p.FindFieldType(structType, fieldName); found {
t, err := types.ExprTypeToType(ft.Type)
if err != nil {
return nil, false
}
return &types.FieldType{
Type: t,
IsSet: ft.IsSet,
GetFrom: ft.GetFrom,
}, true
}
return nil, false
}
// interopLegacyTypeProvider layers support for the ref.TypeProvider interface on top of a types.Provider.
type interopLegacyTypeProvider struct {
types.Provider
}
// FindType retruns the protobuf Type representation for the input type name if one exists.
//
// This method proxies to the underlying types.Provider FindStructType method and converts the types.Type
// value to a protobuf Type representation.
//
// Failure to convert the type will result in the type not being found.
func (p *interopLegacyTypeProvider) FindType(typeName string) (*exprpb.Type, bool) {
if t, found := p.FindStructType(typeName); found {
et, err := types.TypeToExprType(t)
if err != nil {
return nil, false
}
return et, true
}
return nil, false
}
// FindFieldType returns the protobuf-based FieldType representation for the input type name and field,
// if one exists.
//
// This call proxies to the types.Provider FindStructFieldType method and converts the types.FIeldType
// value to a protobuf-based ref.FieldType representation if found.
//
// Failure to convert the FieldType will result in the field not being found.
func (p *interopLegacyTypeProvider) FindFieldType(structType, fieldName string) (*ref.FieldType, bool) {
if cft, found := p.FindStructFieldType(structType, fieldName); found {
et, err := types.TypeToExprType(cft.Type)
if err != nil {
return nil, false
}
return &ref.FieldType{
Type: et,
IsSet: cft.IsSet,
GetFrom: cft.GetFrom,
}, true
}
return nil, false
}
var (
stdEnvInit sync.Once
stdEnv *Env

View File

@@ -19,21 +19,23 @@ import (
"fmt"
"reflect"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
)
// CheckedExprToAst converts a checked expression proto message to an Ast.
func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
return CheckedExprToAstWithSource(checkedExpr, nil)
checked, _ := CheckedExprToAstWithSource(checkedExpr, nil)
return checked
}
// CheckedExprToAstWithSource converts a checked expression proto message to an Ast,
@@ -44,29 +46,18 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast {
// through future calls.
//
// Prefer CheckedExprToAst if loading expressions from storage.
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) *Ast {
refMap := checkedExpr.GetReferenceMap()
if refMap == nil {
refMap = map[int64]*exprpb.Reference{}
}
typeMap := checkedExpr.GetTypeMap()
if typeMap == nil {
typeMap = map[int64]*exprpb.Type{}
}
si := checkedExpr.GetSourceInfo()
if si == nil {
si = &exprpb.SourceInfo{}
}
if src == nil {
src = common.NewInfoSource(si)
func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) {
checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr)
if err != nil {
return nil, err
}
return &Ast{
expr: checkedExpr.GetExpr(),
info: si,
expr: checkedAST.Expr,
info: checkedAST.SourceInfo,
source: src,
refMap: refMap,
typeMap: typeMap,
}
refMap: checkedAST.ReferenceMap,
typeMap: checkedAST.TypeMap,
}, nil
}
// AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value.
@@ -76,12 +67,13 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
return &exprpb.CheckedExpr{
Expr: a.Expr(),
SourceInfo: a.SourceInfo(),
cAst := &ast.CheckedAST{
Expr: a.expr,
SourceInfo: a.info,
ReferenceMap: a.refMap,
TypeMap: a.typeMap,
}, nil
}
return ast.CheckedASTToCheckedExpr(cAst)
}
// ParsedExprToAst converts a parsed expression proto message to an Ast.
@@ -202,7 +194,7 @@ func RefValueToValue(res ref.Val) (*exprpb.Value, error) {
}
var (
typeNameToTypeValue = map[string]*types.TypeValue{
typeNameToTypeValue = map[string]ref.Val{
"bool": types.BoolType,
"bytes": types.BytesType,
"double": types.DoubleType,
@@ -219,7 +211,7 @@ var (
)
// ValueToRefValue converts between exprpb.Value and ref.Val.
func ValueToRefValue(adapter ref.TypeAdapter, v *exprpb.Value) (ref.Val, error) {
func ValueToRefValue(adapter types.Adapter, v *exprpb.Value) (ref.Val, error) {
switch v.Kind.(type) {
case *exprpb.Value_NullValue:
return types.NullValue, nil

View File

@@ -15,15 +15,32 @@
package cel
import (
"math"
"strconv"
"strings"
"time"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
optMapMacro = "optMap"
optFlatMapMacro = "optFlatMap"
hasValueFunc = "hasValue"
optionalNoneFunc = "optional.none"
optionalOfFunc = "optional.of"
optionalOfNonZeroValueFunc = "optional.ofNonZeroValue"
valueFunc = "value"
unusedIterVar = "#unused"
)
// Library provides a collection of EnvOption and ProgramOption values used to configure a CEL
@@ -42,10 +59,27 @@ type Library interface {
ProgramOptions() []ProgramOption
}
// SingletonLibrary refines the Library interface to ensure that libraries in this format are only
// configured once within the environment.
type SingletonLibrary interface {
Library
// LibraryName provides a namespaced name which is used to check whether the library has already
// been configured in the environment.
LibraryName() string
}
// Lib creates an EnvOption out of a Library, allowing libraries to be provided as functional args,
// and to be linked to each other.
func Lib(l Library) EnvOption {
singleton, isSingleton := l.(SingletonLibrary)
return func(e *Env) (*Env, error) {
if isSingleton {
if e.HasLibrary(singleton.LibraryName()) {
return e, nil
}
e.libraries[singleton.LibraryName()] = true
}
var err error
for _, opt := range l.CompileOptions() {
e, err = opt(e)
@@ -67,19 +101,439 @@ func StdLib() EnvOption {
// features documented in the specification.
type stdLibrary struct{}
// EnvOptions returns options for the standard CEL function declarations and macros.
// LibraryName implements the SingletonLibrary interface method.
func (stdLibrary) LibraryName() string {
return "cel.lib.std"
}
// CompileOptions returns options for the standard CEL function declarations and macros.
func (stdLibrary) CompileOptions() []EnvOption {
return []EnvOption{
Declarations(checker.StandardDeclarations()...),
func(e *Env) (*Env, error) {
var err error
for _, fn := range stdlib.Functions() {
existing, found := e.functions[fn.Name()]
if found {
fn, err = existing.Merge(fn)
if err != nil {
return nil, err
}
}
e.functions[fn.Name()] = fn
}
return e, nil
},
func(e *Env) (*Env, error) {
e.variables = append(e.variables, stdlib.Types()...)
return e, nil
},
Macros(StandardMacros...),
}
}
// ProgramOptions returns function implementations for the standard CEL functions.
func (stdLibrary) ProgramOptions() []ProgramOption {
return []ProgramOption{
Functions(functions.StandardOverloads()...),
return []ProgramOption{}
}
// OptionalTypes enable support for optional syntax and types in CEL.
//
// The optional value type makes it possible to express whether variables have
// been provided, whether a result has been computed, and in the future whether
// an object field path, map key value, or list index has a value.
//
// # Syntax Changes
//
// OptionalTypes are unlike other CEL extensions because they modify the CEL
// syntax itself, notably through the use of a `?` preceding a field name or
// index value.
//
// ## Field Selection
//
// The optional syntax in field selection is denoted as `obj.?field`. In other
// words, if a field is set, return `optional.of(obj.field)“, else
// `optional.none()`. The optional field selection is viral in the sense that
// after the first optional selection all subsequent selections or indices
// are treated as optional, i.e. the following expressions are equivalent:
//
// obj.?field.subfield
// obj.?field.?subfield
//
// ## Indexing
//
// Similar to field selection, the optional syntax can be used in index
// expressions on maps and lists:
//
// list[?0]
// map[?key]
//
// ## Optional Field Setting
//
// When creating map or message literals, if a field may be optionally set
// based on its presence, then placing a `?` before the field name or key
// will ensure the type on the right-hand side must be optional(T) where T
// is the type of the field or key-value.
//
// The following returns a map with the key expression set only if the
// subfield is present, otherwise an empty map is created:
//
// {?key: obj.?field.subfield}
//
// ## Optional Element Setting
//
// When creating list literals, an element in the list may be optionally added
// when the element expression is preceded by a `?`:
//
// [a, ?b, ?c] // return a list with either [a], [a, b], [a, b, c], or [a, c]
//
// # Optional.Of
//
// Create an optional(T) value of a given value with type T.
//
// optional.of(10)
//
// # Optional.OfNonZeroValue
//
// Create an optional(T) value of a given value with type T if it is not a
// zero-value. A zero-value the default empty value for any given CEL type,
// including empty protobuf message types. If the value is empty, the result
// of this call will be optional.none().
//
// optional.ofNonZeroValue([1, 2, 3]) // optional(list(int))
// optional.ofNonZeroValue([]) // optional.none()
// optional.ofNonZeroValue(0) // optional.none()
// optional.ofNonZeroValue("") // optional.none()
//
// # Optional.None
//
// Create an empty optional value.
//
// # HasValue
//
// Determine whether the optional contains a value.
//
// optional.of(b'hello').hasValue() // true
// optional.ofNonZeroValue({}).hasValue() // false
//
// # Value
//
// Get the value contained by the optional. If the optional does not have a
// value, the result will be a CEL error.
//
// optional.of(b'hello').value() // b'hello'
// optional.ofNonZeroValue({}).value() // error
//
// # Or
//
// If the value on the left-hand side is optional.none(), the optional value
// on the right hand side is returned. If the value on the left-hand set is
// valued, then it is returned. This operation is short-circuiting and will
// only evaluate as many links in the `or` chain as are needed to return a
// non-empty optional value.
//
// obj.?field.or(m[?key])
// l[?index].or(obj.?field.subfield).or(obj.?other)
//
// # OrValue
//
// Either return the value contained within the optional on the left-hand side
// or return the alternative value on the right hand side.
//
// m[?key].orValue("none")
//
// # OptMap
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return an optional typed result based on the transformation. The
// transformation expression type must return a type T which is wrapped into
// an optional.
//
// msg.?elements.optMap(e, e.size()).orValue(0)
//
// # OptFlatMap
//
// Introduced in version: 1
//
// Apply a transformation to the optional's underlying value if it is not empty
// and return the result. The transform expression must return an optional(T)
// rather than type T. This can be useful when dealing with zero values and
// conditionally generating an empty or non-empty result in ways which cannot
// be expressed with `optMap`.
//
// msg.?elements.optFlatMap(e, e[?0]) // return the first element if present.
func OptionalTypes(opts ...OptionalTypesOption) EnvOption {
lib := &optionalLib{version: math.MaxUint32}
for _, opt := range opts {
lib = opt(lib)
}
return Lib(lib)
}
type optionalLib struct {
version uint32
}
// OptionalTypesOption is a functional interface for configuring the strings library.
type OptionalTypesOption func(*optionalLib) *optionalLib
// OptionalTypesVersion configures the version of the optional type library.
//
// The version limits which functions are available. Only functions introduced
// below or equal to the given version included in the library. If this option
// is not set, all functions are available.
//
// See the library documentation to determine which version a function was introduced.
// If the documentation does not state which version a function was introduced, it can
// be assumed to be introduced at version 0, when the library was first created.
func OptionalTypesVersion(version uint32) OptionalTypesOption {
return func(lib *optionalLib) *optionalLib {
lib.version = version
return lib
}
}
// LibraryName implements the SingletonLibrary interface method.
func (lib *optionalLib) LibraryName() string {
return "cel.lib.optional"
}
// CompileOptions implements the Library interface method.
func (lib *optionalLib) CompileOptions() []EnvOption {
paramTypeK := TypeParamType("K")
paramTypeV := TypeParamType("V")
optionalTypeV := OptionalType(paramTypeV)
listTypeV := ListType(paramTypeV)
mapTypeKV := MapType(paramTypeK, paramTypeV)
opts := []EnvOption{
// Enable the optional syntax in the parser.
enableOptionalSyntax(),
// Introduce the optional type.
Types(types.OptionalType),
// Configure the optMap and optFlatMap macros.
Macros(NewReceiverMacro(optMapMacro, 2, optMap)),
// Global and member functions for working with optional values.
Function(optionalOfFunc,
Overload("optional_of", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
return types.OptionalOf(value)
}))),
Function(optionalOfNonZeroValueFunc,
Overload("optional_ofNonZeroValue", []*Type{paramTypeV}, optionalTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
v, isZeroer := value.(traits.Zeroer)
if !isZeroer || !v.IsZeroValue() {
return types.OptionalOf(value)
}
return types.OptionalNone
}))),
Function(optionalNoneFunc,
Overload("optional_none", []*Type{}, optionalTypeV,
FunctionBinding(func(values ...ref.Val) ref.Val {
return types.OptionalNone
}))),
Function(valueFunc,
MemberOverload("optional_value", []*Type{optionalTypeV}, paramTypeV,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return opt.GetValue()
}))),
Function(hasValueFunc,
MemberOverload("optional_hasValue", []*Type{optionalTypeV}, BoolType,
UnaryBinding(func(value ref.Val) ref.Val {
opt := value.(*types.Optional)
return types.Bool(opt.HasValue())
}))),
// Implementation of 'or' and 'orValue' are special-cased to support short-circuiting in the
// evaluation chain.
Function("or",
MemberOverload("optional_or_optional", []*Type{optionalTypeV, optionalTypeV}, optionalTypeV)),
Function("orValue",
MemberOverload("optional_orValue_value", []*Type{optionalTypeV, paramTypeV}, paramTypeV)),
// OptSelect is handled specially by the type-checker, so the receiver's field type is used to determine the
// optput type.
Function(operators.OptSelect,
Overload("select_optional_field", []*Type{DynType, StringType}, optionalTypeV)),
// OptIndex is handled mostly like any other indexing operation on a list or map, so the type-checker can use
// these signatures to determine type-agreement without any special handling.
Function(operators.OptIndex,
Overload("list_optindex_optional_int", []*Type{listTypeV, IntType}, optionalTypeV),
Overload("optional_list_optindex_optional_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("map_optindex_optional_value", []*Type{mapTypeKV, paramTypeK}, optionalTypeV),
Overload("optional_map_optindex_optional_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
// Index overloads to accommodate using an optional value as the operand.
Function(operators.Index,
Overload("optional_list_index_int", []*Type{OptionalType(listTypeV), IntType}, optionalTypeV),
Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)),
}
if lib.version >= 1 {
opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap)))
}
return opts
}
// ProgramOptions implements the Library interface method.
func (lib *optionalLib) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
}
}
func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.GlobalCall(optionalOfFunc,
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier")
}
mapExpr := args[1]
return meh.GlobalCall(
operators.Conditional,
meh.ReceiverCall(hasValueFunc, target),
meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
meh.ReceiverCall(valueFunc, target),
meh.LiteralBool(false),
meh.Ident(varName),
mapExpr,
),
meh.GlobalCall(optionalNoneFunc),
), nil
}
func enableOptionalSyntax() EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.EnableOptionalSyntax(true))
return e, nil
}
}
func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
args := call.Args()
if len(args) != 2 {
return i, nil
}
switch call.Function() {
case "or":
if call.OverloadID() != "" && call.OverloadID() != "optional_or_optional" {
return i, nil
}
return &evalOptionalOr{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
case "orValue":
if call.OverloadID() != "" && call.OverloadID() != "optional_orValue_value" {
return i, nil
}
return &evalOptionalOrValue{
id: call.ID(),
lhs: args[0],
rhs: args[1],
}, nil
default:
return i, nil
}
}
// evalOptionalOr selects between two optional values, either the first if it has a value, or
// the second optional expression is evaluated and returned.
type evalOptionalOr struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOr) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal
}
return opt.rhs.Eval(ctx)
}
// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
// its value is returned, otherwise the alternative value expression is evaluated and returned.
type evalOptionalOrValue struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
}
// ID implements the Interpretable interface method.
func (opt *evalOptionalOrValue) ID() int64 {
return opt.id
}
// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optVal, ok := optLHS.(*types.Optional)
if !ok {
return optLHS
}
if optVal.HasValue() {
return optVal.GetValue()
}
return opt.rhs.Eval(ctx)
}
type timeUTCLibrary struct{}
@@ -100,28 +554,16 @@ var (
timeOverloadDeclarations = []EnvOption{
Function(overloads.TimeGetHours,
MemberOverload(overloads.DurationToHours, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Hours())
}))),
UnaryBinding(types.DurationGetHours))),
Function(overloads.TimeGetMinutes,
MemberOverload(overloads.DurationToMinutes, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Minutes())
}))),
UnaryBinding(types.DurationGetMinutes))),
Function(overloads.TimeGetSeconds,
MemberOverload(overloads.DurationToSeconds, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Seconds())
}))),
UnaryBinding(types.DurationGetSeconds))),
Function(overloads.TimeGetMilliseconds,
MemberOverload(overloads.DurationToMilliseconds, []*Type{DurationType}, IntType,
UnaryBinding(func(dur ref.Val) ref.Val {
d := dur.(types.Duration)
return types.Int(d.Milliseconds())
}))),
UnaryBinding(types.DurationGetMilliseconds))),
Function(overloads.TimeGetFullYear,
MemberOverload(overloads.TimestampToYear, []*Type{TimestampType}, IntType,
UnaryBinding(func(ts ref.Val) ref.Val {

View File

@@ -15,8 +15,8 @@
package cel
import (
"github.com/google/cel-go/common"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -26,8 +26,11 @@ import (
// a Macro should be created per arg-count or as a var arg macro.
type Macro = parser.Macro
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree, or an error
// if the input arguments are not suitable for the expansion requirements for the macro in question.
// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree.
//
// If the MacroExpander determines within the implementation that an expansion is not needed it may return
// a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments
// are not well-formed, the result of the expansion will be an error.
//
// The MacroExpander accepts as arguments a MacroExprHelper as well as the arguments used in the function call
// and produces as output an Expr ast node.
@@ -59,21 +62,21 @@ func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro {
}
// HasMacroExpander expands the input call arguments into a presence test, e.g. has(<operand>.field)
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeHas(meh, target, args)
}
// ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the
// elements in the range match the predicate expressions:
// <iterRange>.exists(<iterVar>, <predicate>)
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExists(meh, target, args)
}
// ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly
// one of the elements in the range match the predicate expressions:
// <iterRange>.exists_one(<iterVar>, <predicate>)
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeExistsOne(meh, target, args)
}
@@ -81,18 +84,20 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex
// input to produce an output list.
//
// There are two call patterns supported by map:
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// <iterRange>.map(<iterVar>, <transform>)
// <iterRange>.map(<iterVar>, <predicate>, <transform>)
//
// In the second form only iterVar values which return true when provided to the predicate expression
// are transformed.
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeMap(meh, target, args)
}
// FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains
// only elements which match the provided predicate expression:
// <iterRange>.filter(<iterVar>, <predicate>)
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) {
func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) {
return parser.MakeFilter(meh, target, args)
}

View File

@@ -23,12 +23,14 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
@@ -40,13 +42,6 @@ import (
const (
_ = iota
// Disallow heterogeneous aggregate (list, map) literals.
// Note, it is still possible to have heterogeneous aggregates when
// provided as variables to the expression, as well as via conversion
// of well-known dynamic types, or with unchecked expressions.
// Affects checking. Provides a subset of standard behavior.
featureDisableDynamicAggregateLiterals
// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking
@@ -61,6 +56,11 @@ const (
// on a CEL timestamp operation. This fixes the scenario where the input time
// is not already in UTC.
featureDefaultUTCTimeZone
// Enable the serialization of logical operator ASTs as variadic calls, thus
// compressing the logic graph to a single call when multiple like-operator
// expressions occur: e.g. a && b && c && d -> call(_&&_, [a, b, c, d])
featureVariadicLogicalASTs
)
// EnvOption is a functional interface for configuring the environment.
@@ -77,23 +77,26 @@ func ClearMacros() EnvOption {
}
}
// CustomTypeAdapter swaps the default ref.TypeAdapter implementation with a custom one.
// CustomTypeAdapter swaps the default types.Adapter implementation with a custom one.
//
// Note: This option must be specified before the Types and TypeDescs options when used together.
func CustomTypeAdapter(adapter ref.TypeAdapter) EnvOption {
func CustomTypeAdapter(adapter types.Adapter) EnvOption {
return func(e *Env) (*Env, error) {
e.adapter = adapter
return e, nil
}
}
// CustomTypeProvider swaps the default ref.TypeProvider implementation with a custom one.
// CustomTypeProvider replaces the types.Provider implementation with a custom one.
//
// The `provider` variable type may either be types.Provider or ref.TypeProvider (deprecated)
//
// Note: This option must be specified before the Types and TypeDescs options when used together.
func CustomTypeProvider(provider ref.TypeProvider) EnvOption {
func CustomTypeProvider(provider any) EnvOption {
return func(e *Env) (*Env, error) {
e.provider = provider
return e, nil
var err error
e.provider, err = maybeInteropProvider(provider)
return e, err
}
}
@@ -103,8 +106,28 @@ func CustomTypeProvider(provider ref.TypeProvider) EnvOption {
// for the environment. The NewEnv call builds on top of the standard CEL declarations. For a
// purely custom set of declarations use NewCustomEnv.
func Declarations(decls ...*exprpb.Decl) EnvOption {
declOpts := []EnvOption{}
var err error
var opt EnvOption
// Convert the declarations to `EnvOption` values ahead of time.
// Surface any errors in conversion when the options are applied.
for _, d := range decls {
opt, err = ExprDeclToDeclaration(d)
if err != nil {
break
}
declOpts = append(declOpts, opt)
}
return func(e *Env) (*Env, error) {
e.declarations = append(e.declarations, decls...)
if err != nil {
return nil, err
}
for _, o := range declOpts {
e, err = o(e)
if err != nil {
return nil, err
}
}
return e, nil
}
}
@@ -121,14 +144,25 @@ func EagerlyValidateDeclarations(enabled bool) EnvOption {
return features(featureEagerlyValidateDeclarations, enabled)
}
// HomogeneousAggregateLiterals option ensures that list and map literal entry types must agree
// during type-checking.
// HomogeneousAggregateLiterals disables mixed type list and map literal values.
//
// Note, it is still possible to have heterogeneous aggregates when provided as variables to the
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
return features(featureDisableDynamicAggregateLiterals, true)
return ASTValidators(ValidateHomogeneousAggregateLiterals())
}
// variadicLogicalOperatorASTs flatten like-operator chained logical expressions into a single
// variadic call with N-terms. This behavior is useful when serializing to a protocol buffer as
// it will reduce the number of recursive calls needed to deserialize the AST later.
//
// For example, given the following expression the call graph will be rendered accordingly:
//
// expression: a && b && c && (d || e)
// ast: call(_&&_, [a, b, c, call(_||_, [d, e])])
func variadicLogicalOperatorASTs() EnvOption {
return features(featureVariadicLogicalASTs, true)
}
// Macros option extends the macro set configured in the environment.
@@ -163,19 +197,19 @@ func Container(name string) EnvOption {
// Abbreviations can be useful when working with variables, functions, and especially types from
// multiple namespaces:
//
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
// // CEL object construction
// qual.pkg.version.ObjTypeName{
// field: alt.container.ver.FieldTypeName{value: ...}
// }
//
// Only one the qualified names above may be used as the CEL container, so at least one of these
// references must be a long qualified name within an otherwise short CEL program. Using the
// following abbreviations, the program becomes much simpler:
//
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
// // CEL Go option
// Abbrevs("qual.pkg.version.ObjTypeName", "alt.container.ver.FieldTypeName")
// // Simplified Object construction
// ObjTypeName{field: FieldTypeName{value: ...}}
//
// There are a few rules for the qualified names and the simple abbreviations generated from them:
// - Qualified names must be dot-delimited, e.g. `package.subpkg.name`.
@@ -188,9 +222,12 @@ func Container(name string) EnvOption {
// - Expanded abbreviations do not participate in namespace resolution.
// - Abbreviation expansion is done instead of the container search for a matching identifier.
// - Containers follow C++ namespace resolution rules with searches from the most qualified name
// to the least qualified name.
//
// to the least qualified name.
//
// - Container references within the CEL program may be relative, and are resolved to fully
// qualified names at either type-check time or program plan time, whichever comes first.
//
// qualified names at either type-check time or program plan time, whichever comes first.
//
// If there is ever a case where an identifier could be in both the container and as an
// abbreviation, the abbreviation wins as this will ensure that the meaning of a program is
@@ -216,9 +253,14 @@ func Abbrevs(qualifiedNames ...string) EnvOption {
// environment by default.
//
// Note: This option must be specified after the CustomTypeProvider option when used together.
func Types(addTypes ...interface{}) EnvOption {
func Types(addTypes ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
var reg ref.TypeRegistry
var isReg bool
reg, isReg = e.provider.(*types.Registry)
if !isReg {
reg, isReg = e.provider.(ref.TypeRegistry)
}
if !isReg {
return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider)
}
@@ -253,7 +295,7 @@ func Types(addTypes ...interface{}) EnvOption {
//
// TypeDescs are hermetic to a single Env object, but may be copied to other Env values via
// extension or by re-using the same EnvOption with another NewEnv() call.
func TypeDescs(descs ...interface{}) EnvOption {
func TypeDescs(descs ...any) EnvOption {
return func(e *Env) (*Env, error) {
reg, isReg := e.provider.(ref.TypeRegistry)
if !isReg {
@@ -350,8 +392,8 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// variables with the same name provided to the Eval() call. If Globals is used in a Library with
// a Lib EnvOption, vars may shadow variables provided by previously added libraries.
//
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`.
func Globals(vars interface{}) ProgramOption {
// The vars value may either be an `interpreter.Activation` instance or a `map[string]any`.
func Globals(vars any) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err := interpreter.NewActivation(vars)
if err != nil {
@@ -404,6 +446,9 @@ const (
// OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails
// cost calculation is available via func ActualCost()
OptTrackCost EvalOption = 1 << iota
// OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality.
OptCheckStringFormat EvalOption = 1 << iota
)
// EvalOptions sets one or more evaluation options which may affect the evaluation or Result.
@@ -425,6 +470,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}
// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}
// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}
// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
@@ -446,25 +509,21 @@ func CostLimit(costLimit uint64) ProgramOption {
}
}
func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) {
if field.Kind() == protoreflect.MessageKind || field.Kind() == protoreflect.GroupKind {
msgName := (string)(field.Message().FullName())
wellKnownType, found := pb.CheckedWellKnowns[msgName]
if found {
return wellKnownType, nil
}
return decls.NewObjectType(msgName), nil
return ObjectType(msgName), nil
}
if primitiveType, found := pb.CheckedPrimitives[field.Kind()]; found {
if primitiveType, found := types.ProtoCELPrimitives[field.Kind()]; found {
return primitiveType, nil
}
if field.Kind() == protoreflect.EnumKind {
return decls.Int, nil
return IntType, nil
}
return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String())
}
func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
func fieldToVariable(field protoreflect.FieldDescriptor) (EnvOption, error) {
name := string(field.Name())
if field.IsMap() {
mapKey := field.MapKey()
@@ -477,20 +536,20 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
return Variable(name, MapType(keyType, valueType)), nil
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
return Variable(name, ListType(elemType)), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
return Variable(name, celType), nil
}
// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
@@ -498,25 +557,53 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
// https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment
func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return func(e *Env) (*Env, error) {
var decls []*exprpb.Decl
fields := descriptor.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
decl, err := fieldToDecl(field)
variable, err := fieldToVariable(field)
if err != nil {
return nil, err
}
e, err = variable(e)
if err != nil {
return nil, err
}
decls = append(decls, decl)
}
var err error
e, err = Declarations(decls...)(e)
if err != nil {
return nil, err
}
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}
// ContextProtoVars uses the fields of the input proto.Messages as top-level variables within an Activation.
//
// Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using
// protocol buffers.
func ContextProtoVars(ctx proto.Message) (interpreter.Activation, error) {
if ctx == nil || !ctx.ProtoReflect().IsValid() {
return interpreter.EmptyActivation(), nil
}
reg, err := types.NewRegistry(ctx)
if err != nil {
return nil, err
}
pbRef := ctx.ProtoReflect()
typeName := string(pbRef.Descriptor().FullName())
fields := pbRef.Descriptor().Fields()
vars := make(map[string]any, fields.Len())
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
sft, found := reg.FindStructFieldType(typeName, field.TextName())
if !found {
return nil, fmt.Errorf("no such field: %s", field.TextName())
}
fieldVal, err := sft.GetFrom(ctx)
if err != nil {
return nil, err
}
vars[field.TextName()] = fieldVal
}
return interpreter.NewActivation(vars)
}
// EnableMacroCallTracking ensures that call expressions which are replaced by macros
// are tracked in the `SourceInfo` of parsed and checked expressions.
func EnableMacroCallTracking() EnvOption {
@@ -541,3 +628,32 @@ func features(flag int, enabled bool) EnvOption {
return e, nil
}
}
// ParserRecursionLimit adjusts the AST depth the parser will tolerate.
// Defaults defined in the parser package.
func ParserRecursionLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit))
return e, nil
}
}
// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse.
// Defaults defined in the parser package.
func ParserExpressionSizeLimit(limit int) EnvOption {
return func(e *Env) (*Env, error) {
e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit))
return e, nil
}
}
func maybeInteropProvider(provider any) (types.Provider, error) {
switch p := provider.(type) {
case types.Provider:
return p, nil
case ref.TypeProvider:
return &interopCELTypeProvider{TypeProvider: p}, nil
default:
return nil, fmt.Errorf("unsupported type provider: %T", provider)
}
}

View File

@@ -17,11 +17,9 @@ package cel
import (
"context"
"fmt"
"math"
"sync"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
celast "github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
@@ -31,7 +29,7 @@ import (
type Program interface {
// Eval returns the result of an evaluation of the Ast and environment against the input vars.
//
// The vars value may either be an `interpreter.Activation` or a `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or a `map[string]any`.
//
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be:
@@ -43,16 +41,16 @@ type Program interface {
// An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption`
// or `ProgramOption` values used in the creation of the evaluation environment or executable
// program.
Eval(interface{}) (ref.Val, *EvalDetails, error)
Eval(any) (ref.Val, *EvalDetails, error)
// ContextEval evaluates the program with a set of input variables and a context object in order
// to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
//
// The vars value may either be an `interpreter.Activation` or `map[string]interface{}`.
// The vars value may either be an `interpreter.Activation` or `map[string]any`.
//
// The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, interface{}) (ref.Val, *EvalDetails, error)
ContextEval(context.Context, any) (ref.Val, *EvalDetails, error)
}
// NoVars returns an empty Activation.
@@ -63,9 +61,12 @@ func NoVars() interpreter.Activation {
// PartialVars returns a PartialActivation which contains variables and a set of AttributePattern
// values that indicate variables or parts of variables whose value are not yet known.
//
// This method relies on manually configured sets of missing attribute patterns. For a method which
// infers the missing variables from the input and the configured environment, use Env.PartialVars().
//
// The `vars` value may either be an interpreter.Activation or any valid input to the
// interpreter.NewActivation call.
func PartialVars(vars interface{},
func PartialVars(vars any,
unknowns ...*interpreter.AttributePattern) (interpreter.PartialActivation, error) {
return interpreter.NewPartialActivation(vars, unknowns...)
}
@@ -105,7 +106,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
@@ -129,10 +130,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}
func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)
return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
@@ -154,9 +159,10 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}
// Configure the program via the ProgramOption values.
@@ -170,7 +176,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Add the function bindings created via Function() options.
for _, fn := range e.functions {
bindings, err := fn.bindings()
bindings, err := fn.Bindings()
if err != nil {
return nil, err
}
@@ -207,12 +213,46 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t := ast.typeMap[id]
if t.Kind() == DynKind {
return true, nil
}
for _, vt := range validTypes {
k, err := typeValueToKind(vt)
if err != nil {
return false, err
}
if t.Kind() == k {
return true, nil
}
}
return false, nil
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
@@ -254,10 +294,11 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
var checked *exprpb.CheckedExpr
checked, err := AstToCheckedExpr(ast)
if err != nil {
return nil, err
checked := &celast.CheckedAST{
Expr: ast.Expr(),
SourceInfo: ast.SourceInfo(),
TypeMap: ast.typeMap,
ReferenceMap: ast.refMap,
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
if err != nil {
@@ -268,7 +309,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor
}
// Eval implements the Program interface method.
func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error) {
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
@@ -287,11 +328,11 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
switch v := input.(type) {
case interpreter.Activation:
vars = v
case map[string]interface{}:
case map[string]any:
vars = activationPool.Setup(v)
defer activationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
@@ -307,7 +348,7 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
}
// ContextEval implements the Program interface.
func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@@ -318,22 +359,17 @@ func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *Ev
case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]interface{}:
case map[string]any:
rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
}
// Cost implements the Coster interface method.
func (p *prog) Cost() (min, max int64) {
return estimateCost(p.interpretable)
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
@@ -346,7 +382,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
@@ -354,12 +394,15 @@ func newProgGen(factory progFactory) (Program, error) {
}
// Eval implements the Program interface method.
func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
@@ -379,7 +422,7 @@ func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
}
// ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
func (gen *progGen) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
if ctx == nil {
return nil, nil, fmt.Errorf("context can not be nil")
}
@@ -387,7 +430,10 @@ func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
@@ -406,29 +452,6 @@ func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val
return v, det, nil
}
// Cost implements the Coster interface method.
func (gen *progGen) Cost() (min, max int64) {
// Use an empty state value since no evaluation is performed.
p, err := gen.factory(emptyEvalState, nil)
if err != nil {
return 0, math.MaxInt64
}
return estimateCost(p)
}
// EstimateCost returns the heuristic cost interval for the program.
func EstimateCost(p Program) (min, max int64) {
return estimateCost(p)
}
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(interpreter.Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
}
type ctxEvalActivation struct {
parent interpreter.Activation
interrupt <-chan struct{}
@@ -438,7 +461,7 @@ type ctxEvalActivation struct {
// ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) {
func (a *ctxEvalActivation) ResolveName(name string) (any, bool) {
if name == "#interrupted" {
a.interruptCheckCount++
if a.interruptCheckCount%a.interruptCheckFrequency == 0 {
@@ -461,7 +484,7 @@ func (a *ctxEvalActivation) Parent() interpreter.Activation {
func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
New: func() any {
return &ctxEvalActivation{}
},
},
@@ -483,21 +506,21 @@ func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan s
}
type evalActivation struct {
vars map[string]interface{}
lazyVars map[string]interface{}
vars map[string]any
lazyVars map[string]any
}
// ResolveName looks up the value of the input variable name, if found.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() any
// - func() ref.Val
//
// The lazy binding will only be invoked once per evaluation.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
// the types.Adapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (any, bool) {
v, found := a.vars[name]
if !found {
return nil, false
@@ -510,7 +533,7 @@ func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
lazy := obj()
a.lazyVars[name] = lazy
return lazy, true
case func() interface{}:
case func() any:
if resolved, found := a.lazyVars[name]; found {
return resolved, true
}
@@ -530,8 +553,8 @@ func (a *evalActivation) Parent() interpreter.Activation {
func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &evalActivation{lazyVars: make(map[string]interface{})}
New: func() any {
return &evalActivation{lazyVars: make(map[string]any)}
},
},
}
@@ -542,13 +565,13 @@ type evalActivationPool struct {
}
// Setup initializes a pooled Activation object with the map input.
func (p *evalActivationPool) Setup(vars map[string]interface{}) *evalActivation {
func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation {
a := p.Pool.Get().(*evalActivation)
a.vars = vars
return a
}
func (p *evalActivationPool) Put(value interface{}) {
func (p *evalActivationPool) Put(value any) {
a := value.(*evalActivation)
for k := range a.lazyVars {
delete(a.lazyVars, k)
@@ -559,7 +582,7 @@ func (p *evalActivationPool) Put(value interface{}) {
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]interface{} inputs
// activationPool is an internally managed pool of Activation values that wrap map[string]any inputs
activationPool = newEvalActivationPool()
// ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable

388
vendor/github.com/google/cel-go/cel/validator.go generated vendored Normal file
View File

@@ -0,0 +1,388 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cel
import (
"fmt"
"reflect"
"regexp"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
homogeneousValidatorName = "cel.lib.std.validate.types.homogeneous"
// HomogeneousAggregateLiteralExemptFunctions is the ValidatorConfig key used to configure
// the set of function names which are exempt from homogeneous type checks. The expected type
// is a string list of function names.
//
// As an example, the `<string>.format([args])` call expects the input arguments list to be
// comprised of a variety of types which correspond to the types expected by the format control
// clauses; however, all other uses of a mixed element type list, would be unexpected.
HomogeneousAggregateLiteralExemptFunctions = homogeneousValidatorName + ".exempt"
)
// ASTValidators configures a set of ASTValidator instances into the target environment.
//
// Validators are applied in the order in which the are specified and are treated as singletons.
// The same ASTValidator with a given name will not be applied more than once.
func ASTValidators(validators ...ASTValidator) EnvOption {
return func(e *Env) (*Env, error) {
for _, v := range validators {
if !e.HasValidator(v.Name()) {
e.validators = append(e.validators, v)
}
}
return e, nil
}
}
// ASTValidator defines a singleton interface for validating a type-checked Ast against an environment.
//
// Note: the Issues argument is mutable in the sense that it is intended to collect errors which will be
// reported to the caller.
type ASTValidator interface {
// Name returns the name of the validator. Names must be unique.
Name() string
// Validate validates a given Ast within an Environment and collects a set of potential issues.
//
// The ValidatorConfig is generated from the set of ASTValidatorConfigurer instances prior to
// the invocation of the Validate call. The expectation is that the validator configuration
// is created in sequence and immutable once provided to the Validate call.
//
// See individual validators for more information on their configuration keys and configuration
// properties.
Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues)
}
// ValidatorConfig provides an accessor method for querying validator configuration state.
type ValidatorConfig interface {
GetOrDefault(name string, value any) any
}
// MutableValidatorConfig provides mutation methods for querying and updating validator configuration
// settings.
type MutableValidatorConfig interface {
ValidatorConfig
Set(name string, value any) error
}
// ASTValidatorConfigurer indicates that this object, currently expected to be an ASTValidator,
// participates in validator configuration settings.
//
// This interface may be split from the expectation of being an ASTValidator instance in the future.
type ASTValidatorConfigurer interface {
Configure(MutableValidatorConfig) error
}
// validatorConfig implements the ValidatorConfig and MutableValidatorConfig interfaces.
type validatorConfig struct {
data map[string]any
}
// newValidatorConfig initializes the validator config with default values for core CEL validators.
func newValidatorConfig() *validatorConfig {
return &validatorConfig{
data: map[string]any{
HomogeneousAggregateLiteralExemptFunctions: []string{},
},
}
}
// GetOrDefault returns the configured value for the name, if present, else the input default value.
//
// Note, the type-agreement between the input default and configured value is not checked on read.
func (config *validatorConfig) GetOrDefault(name string, value any) any {
v, found := config.data[name]
if !found {
return value
}
return v
}
// Set configures a validator option with the given name and value.
//
// If the value had previously been set, the new value must have the same reflection type as the old one,
// or the call will error.
func (config *validatorConfig) Set(name string, value any) error {
v, found := config.data[name]
if found && reflect.TypeOf(v) != reflect.TypeOf(value) {
return fmt.Errorf("incompatible configuration type for %s, got %T, wanted %T", name, value, v)
}
config.data[name] = value
return nil
}
// ExtendedValidations collects a set of common AST validations which reduce the likelihood of runtime errors.
//
// - Validate duration and timestamp literals
// - Ensure regex strings are valid
// - Disable mixed type list and map literals
func ExtendedValidations() EnvOption {
return ASTValidators(
ValidateDurationLiterals(),
ValidateTimestampLiterals(),
ValidateRegexLiterals(),
ValidateHomogeneousAggregateLiterals(),
)
}
// ValidateDurationLiterals ensures that duration literal arguments are valid immediately after type-check.
func ValidateDurationLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertDuration, 0, evalCall)
}
// ValidateTimestampLiterals ensures that timestamp literal arguments are valid immediately after type-check.
func ValidateTimestampLiterals() ASTValidator {
return newFormatValidator(overloads.TypeConvertTimestamp, 0, evalCall)
}
// ValidateRegexLiterals ensures that regex patterns are validated after type-check.
func ValidateRegexLiterals() ASTValidator {
return newFormatValidator(overloads.Matches, 0, compileRegex)
}
// ValidateHomogeneousAggregateLiterals checks that all list and map literals entries have the same types, i.e.
// no mixed list element types or mixed map key or map value types.
//
// Note: the string format call relies on a mixed element type list for ease of use, so this check skips all
// literals which occur within string format calls.
func ValidateHomogeneousAggregateLiterals() ASTValidator {
return homogeneousAggregateLiteralValidator{}
}
// ValidateComprehensionNestingLimit ensures that comprehension nesting does not exceed the specified limit.
//
// This validator can be useful for preventing arbitrarily nested comprehensions which can take high polynomial
// time to complete.
//
// Note, this limit does not apply to comprehensions with an empty iteration range, as these comprehensions have
// no actual looping cost. The cel.bind() utilizes the comprehension structure to perform local variable
// assignments and supplies an empty iteration range, so they won't count against the nesting limit either.
func ValidateComprehensionNestingLimit(limit int) ASTValidator {
return nestingLimitValidator{limit: limit}
}
type argChecker func(env *Env, call, arg ast.NavigableExpr) error
func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator {
return formatValidator{
funcName: funcName,
check: check,
argNum: argNum,
}
}
type formatValidator struct {
funcName string
argNum int
check argChecker
}
// Name returns the unique name of this function format validator.
func (v formatValidator) Name() string {
return fmt.Sprintf("cel.lib.std.validate.functions.%s", v.funcName)
}
// Validate searches the AST for uses of a given function name with a constant argument and performs a check
// on whether the argument is a valid literal value.
func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName))
for _, call := range funcCalls {
callArgs := call.AsCall().Args()
if len(callArgs) <= v.argNum {
continue
}
litArg := callArgs[v.argNum]
if litArg.Kind() != ast.LiteralKind {
continue
}
if err := v.check(e, call, litArg); err != nil {
iss.ReportErrorAtID(litArg.ID(), "invalid %s argument", v.funcName)
}
}
}
func evalCall(env *Env, call, arg ast.NavigableExpr) error {
ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()})
prg, err := env.Program(ast)
if err != nil {
return err
}
_, _, err = prg.Eval(NoVars())
return err
}
func compileRegex(_ *Env, _, arg ast.NavigableExpr) error {
pattern := arg.AsLiteral().Value().(string)
_, err := regexp.Compile(pattern)
return err
}
type homogeneousAggregateLiteralValidator struct{}
// Name returns the unique name of the homogeneous type validator.
func (homogeneousAggregateLiteralValidator) Name() string {
return homogeneousValidatorName
}
// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard
// and exempt functions from homogeneous aggregate literal checks.
//
// TODO: Move this call into the string.format() ASTValidator once ported.
func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error {
emptyList := []string{}
exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string)
exemptFunctions = append(exemptFunctions, "format")
return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions)
}
// Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types.
//
// This validator makes an exception for list and map literals which occur at any level of nesting within
// string format calls.
func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
var exemptedFunctions []string
exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string)
root := ast.NavigateCheckedAST(a)
listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind))
for _, listExpr := range listExprs {
if inExemptFunction(listExpr, exemptedFunctions) {
continue
}
l := listExpr.AsList()
elements := l.Elements()
optIndices := l.OptionalIndices()
var elemType *Type
for i, e := range elements {
et := e.Type()
if isOptionalIndex(i, optIndices) {
et = et.Parameters()[0]
}
if elemType == nil {
elemType = et
continue
}
if !elemType.IsEquivalentType(et) {
v.typeMismatch(iss, e.ID(), elemType, et)
break
}
}
}
mapExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.MapKind))
for _, mapExpr := range mapExprs {
if inExemptFunction(mapExpr, exemptedFunctions) {
continue
}
m := mapExpr.AsMap()
entries := m.Entries()
var keyType, valType *Type
for _, e := range entries {
key, val := e.Key(), e.Value()
kt, vt := key.Type(), val.Type()
if e.IsOptional() {
vt = vt.Parameters()[0]
}
if keyType == nil && valType == nil {
keyType, valType = kt, vt
continue
}
if !keyType.IsEquivalentType(kt) {
v.typeMismatch(iss, key.ID(), keyType, kt)
}
if !valType.IsEquivalentType(vt) {
v.typeMismatch(iss, val.ID(), valType, vt)
}
}
}
}
func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool {
if parent, found := e.Parent(); found {
if parent.Kind() == ast.CallKind {
fnName := parent.AsCall().FunctionName()
for _, exempt := range exemptFunctions {
if exempt == fnName {
return true
}
}
}
if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind {
return inExemptFunction(parent, exemptFunctions)
}
}
return false
}
func isOptionalIndex(i int, optIndices []int32) bool {
for _, optInd := range optIndices {
if i == int(optInd) {
return true
}
}
return false
}
func (homogeneousAggregateLiteralValidator) typeMismatch(iss *Issues, id int64, expected, actual *Type) {
iss.ReportErrorAtID(id, "expected type '%s' but found '%s'", FormatCELType(expected), FormatCELType(actual))
}
type nestingLimitValidator struct {
limit int
}
func (v nestingLimitValidator) Name() string {
return "cel.lib.std.validate.comprehension_nesting_limit"
}
func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) {
root := ast.NavigateCheckedAST(a)
comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind))
if len(comprehensions) <= v.limit {
return
}
for _, comp := range comprehensions {
count := 0
e := comp
hasParent := true
for hasParent {
// When the expression is not a comprehension, continue to the next ancestor.
if e.Kind() != ast.ComprehensionKind {
e, hasParent = e.Parent()
continue
}
// When the comprehension has an empty range, continue to the next ancestor
// as this comprehension does not have any associated cost.
iterRange := e.AsComprehension().IterRange()
if iterRange.Kind() == ast.ListKind && iterRange.AsList().Size() == 0 {
e, hasParent = e.Parent()
continue
}
// Otherwise check the nesting limit.
count++
if count > v.limit {
iss.ReportErrorAtID(comp.ID(), "comprehension exceeds nesting limit")
break
}
e, hasParent = e.Parent()
}
}
}

View File

@@ -11,9 +11,11 @@ go_library(
"cost.go",
"env.go",
"errors.go",
"format.go",
"mapping.go",
"options.go",
"printer.go",
"scopes.go",
"standard.go",
"types.go",
],
@@ -22,15 +24,18 @@ go_library(
deps = [
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/debug:go_default_library",
"//common/decls:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@@ -44,6 +49,7 @@ go_test(
"checker_test.go",
"cost_test.go",
"env_test.go",
"format_test.go",
],
embed = [
":go_default_library",
@@ -54,7 +60,7 @@ go_test(
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr//:go_default_library",
"@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

View File

@@ -18,14 +18,13 @@ package checker
import (
"fmt"
"reflect"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -36,8 +35,8 @@ type checker struct {
mappings *mapping
freeTypeVarCounter int
sourceInfo *exprpb.SourceInfo
types map[int64]*exprpb.Type
references map[int64]*exprpb.Reference
types map[int64]*types.Type
references map[int64]*ast.ReferenceInfo
}
// Check performs type checking, giving a typed AST.
@@ -46,40 +45,38 @@ type checker struct {
// descriptions of protocol buffers, and a registry for errors.
// Returns a CheckedExpr proto, which might not be usable if
// there are errors in the error registry.
func Check(parsedExpr *exprpb.ParsedExpr,
source common.Source,
env *Env) (*exprpb.CheckedExpr, *common.Errors) {
func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) {
errs := common.NewErrors(source)
c := checker{
env: env,
errors: &typeErrors{common.NewErrors(source)},
errors: &typeErrors{errs: errs},
mappings: newMapping(),
freeTypeVarCounter: 0,
sourceInfo: parsedExpr.GetSourceInfo(),
types: make(map[int64]*exprpb.Type),
references: make(map[int64]*exprpb.Reference),
types: make(map[int64]*types.Type),
references: make(map[int64]*ast.ReferenceInfo),
}
c.check(parsedExpr.GetExpr())
// Walk over the final type map substituting any type parameters either by their bound value or
// by DYN.
m := make(map[int64]*exprpb.Type)
for k, v := range c.types {
m[k] = substitute(c.mappings, v, true)
m := make(map[int64]*types.Type)
for id, t := range c.types {
m[id] = substitute(c.mappings, t, true)
}
return &exprpb.CheckedExpr{
return &ast.CheckedAST{
Expr: parsedExpr.GetExpr(),
SourceInfo: parsedExpr.GetSourceInfo(),
TypeMap: m,
ReferenceMap: c.references,
}, c.errors.Errors
}, errs
}
func (c *checker) check(e *exprpb.Expr) {
if e == nil {
return
}
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
@@ -112,53 +109,51 @@ func (c *checker) check(e *exprpb.Expr) {
case *exprpb.Expr_ComprehensionExpr:
c.checkComprehension(e)
default:
c.errors.ReportError(
c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e))
c.errors.unexpectedASTType(e.GetId(), c.location(e), e)
}
}
func (c *checker) checkInt64Literal(e *exprpb.Expr) {
c.setType(e, decls.Int)
c.setType(e, types.IntType)
}
func (c *checker) checkUint64Literal(e *exprpb.Expr) {
c.setType(e, decls.Uint)
c.setType(e, types.UintType)
}
func (c *checker) checkStringLiteral(e *exprpb.Expr) {
c.setType(e, decls.String)
c.setType(e, types.StringType)
}
func (c *checker) checkBytesLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bytes)
c.setType(e, types.BytesType)
}
func (c *checker) checkDoubleLiteral(e *exprpb.Expr) {
c.setType(e, decls.Double)
c.setType(e, types.DoubleType)
}
func (c *checker) checkBoolLiteral(e *exprpb.Expr) {
c.setType(e, decls.Bool)
c.setType(e, types.BoolType)
}
func (c *checker) checkNullLiteral(e *exprpb.Expr) {
c.setType(e, decls.Null)
c.setType(e, types.NullType)
}
func (c *checker) checkIdent(e *exprpb.Expr) {
identExpr := e.GetIdentExpr()
// Check to see if the identifier is declared.
if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil {
c.setType(e, ident.GetIdent().GetType())
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().GetValue()))
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
// Overwrite the identifier with its fully qualified name.
identExpr.Name = ident.GetName()
identExpr.Name = ident.Name()
return
}
c.setType(e, decls.Error)
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), identExpr.GetName())
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName())
}
func (c *checker) checkSelect(e *exprpb.Expr) {
@@ -173,9 +168,9 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
// Rewrite the node to be a variable reference to the resolved fully-qualified
// variable name.
c.setType(e, ident.GetIdent().Type)
c.setReference(e, newIdentReference(ident.GetName(), ident.GetIdent().Value))
identName := ident.GetName()
c.setType(e, ident.Type())
c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value()))
identName := ident.Name()
e.ExprKind = &exprpb.Expr_IdentExpr{
IdentExpr: &exprpb.Expr_Ident{
Name: identName,
@@ -185,43 +180,72 @@ func (c *checker) checkSelect(e *exprpb.Expr) {
}
}
resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false)
if sel.TestOnly {
resultType = types.BoolType
}
c.setType(e, substitute(c.mappings, resultType, false))
}
func (c *checker) checkOptSelect(e *exprpb.Expr) {
// Collect metadata related to the opt select call packaged by the parser.
call := e.GetCallExpr()
operand := call.GetArgs()[0]
field := call.GetArgs()[1]
fieldName, isString := maybeUnwrapString(field)
if !isString {
c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field)
return
}
// Perform type-checking using the field selection logic.
resultType := c.checkSelectField(e, operand, fieldName, true)
c.setType(e, substitute(c.mappings, resultType, false))
c.setReference(e, ast.NewFunctionReference("select_optional_field"))
}
func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type {
// Interpret as field selection, first traversing down the operand.
c.check(sel.GetOperand())
targetType := substitute(c.mappings, c.getType(sel.GetOperand()), false)
c.check(operand)
operandType := substitute(c.mappings, c.getType(operand), false)
// If the target type is 'optional', unwrap it for the sake of this check.
targetType, isOpt := maybeUnwrapOptional(operandType)
// Assume error type by default as most types do not support field selection.
resultType := decls.Error
switch kindOf(targetType) {
case kindMap:
resultType := types.ErrorType
switch targetType.Kind() {
case types.MapKind:
// Maps yield their value type as the selection result type.
mapType := targetType.GetMapType()
resultType = mapType.GetValueType()
case kindObject:
resultType = targetType.Parameters()[1]
case types.StructKind:
// Objects yield their field type declaration as the selection result type, but only if
// the field is defined.
messageType := targetType
if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), sel.GetField()); found {
resultType = fieldType.Type
if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found {
resultType = fieldType
}
case kindTypeParam:
case types.TypeParamKind:
// Set the operand type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, targetType)
c.isAssignable(types.DynType, targetType)
// Also, set the result type to DYN.
resultType = decls.Dyn
resultType = types.DynType
default:
// Dynamic / error values are treated as DYN type. Errors are handled this way as well
// in order to allow forward progress on the check.
if isDynOrError(targetType) {
resultType = decls.Dyn
} else {
c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType)
if !isDynOrError(targetType) {
c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType)
}
resultType = types.DynType
}
if sel.TestOnly {
resultType = decls.Bool
// If the target type was optional coming in, then the result must be optional going out.
if isOpt || optional {
return types.NewOptionalType(resultType)
}
c.setType(e, substitute(c.mappings, resultType, false))
return resultType
}
func (c *checker) checkCall(e *exprpb.Expr) {
@@ -229,29 +253,32 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// please consider the impact on planner.go and consolidate implementations or mirror code
// as appropriate.
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
fnName := call.GetFunction()
if fnName == operators.OptSelect {
c.checkOptSelect(e)
return
}
args := call.GetArgs()
// Traverse arguments.
for _, arg := range args {
c.check(arg)
}
target := call.GetTarget()
// Regular static call with simple name.
if target == nil {
// Check for the existence of the function.
fn := c.env.LookupFunction(fnName)
if fn == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), fnName)
c.setType(e, decls.Error)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
return
}
// Overwrite the function name with its fully qualified resolved name.
call.Function = fn.GetName()
call.Function = fn.Name()
// Check to see whether the overload resolves.
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
c.resolveOverloadOrError(e, fn, nil, args)
return
}
@@ -269,8 +296,8 @@ func (c *checker) checkCall(e *exprpb.Expr) {
// be an inaccurate representation of the desired evaluation behavior.
// Overwrite with fully-qualified resolved function name sans receiver target.
call.Target = nil
call.Function = fn.GetName()
c.resolveOverloadOrError(c.location(e), e, fn, nil, args)
call.Function = fn.Name()
c.resolveOverloadOrError(e, fn, nil, args)
return
}
}
@@ -280,22 +307,21 @@ func (c *checker) checkCall(e *exprpb.Expr) {
fn := c.env.LookupFunction(fnName)
// Function found, attempt overload resolution.
if fn != nil {
c.resolveOverloadOrError(c.location(e), e, fn, target, args)
c.resolveOverloadOrError(e, fn, target, args)
return
}
// Function name not declared, record error.
c.errors.undeclaredReference(c.location(e), c.env.container.Name(), fnName)
c.setType(e, types.ErrorType)
c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName)
}
func (c *checker) resolveOverloadOrError(
loc common.Location,
e *exprpb.Expr,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) {
e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) {
// Attempt to resolve the overload.
resolution := c.resolveOverload(loc, fn, target, args)
resolution := c.resolveOverload(e, fn, target, args)
// No such overload, error noted in the resolveOverload call, type recorded here.
if resolution == nil {
c.setType(e, decls.Error)
c.setType(e, types.ErrorType)
return
}
// Overload found.
@@ -304,10 +330,9 @@ func (c *checker) resolveOverloadOrError(
}
func (c *checker) resolveOverload(
loc common.Location,
fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution {
var argTypes []*exprpb.Type
var argTypes []*types.Type
if target != nil {
argTypes = append(argTypes, c.getType(target))
}
@@ -315,52 +340,75 @@ func (c *checker) resolveOverload(
argTypes = append(argTypes, c.getType(arg))
}
var resultType *exprpb.Type
var checkedRef *exprpb.Reference
for _, overload := range fn.GetFunction().GetOverloads() {
var resultType *types.Type
var checkedRef *ast.ReferenceInfo
for _, overload := range fn.OverloadDecls() {
// Determine whether the overload is currently considered.
if c.env.isOverloadDisabled(overload.GetOverloadId()) {
if c.env.isOverloadDisabled(overload.ID()) {
continue
}
// Ensure the call style for the overload matches.
if (target == nil && overload.GetIsInstanceFunction()) ||
(target != nil && !overload.GetIsInstanceFunction()) {
if (target == nil && overload.IsMemberFunction()) ||
(target != nil && !overload.IsMemberFunction()) {
// not a compatible call style.
continue
}
overloadType := decls.NewFunctionType(overload.ResultType, overload.Params...)
if len(overload.GetTypeParams()) > 0 {
// Alternative type-checking behavior when the logical operators are compacted into
// variadic AST representations.
if fn.Name() == operators.LogicalAnd || fn.Name() == operators.LogicalOr {
checkedRef = ast.NewFunctionReference(overload.ID())
for i, argType := range argTypes {
if !c.isAssignable(argType, types.BoolType) {
c.errors.typeMismatch(
args[i].GetId(),
c.locationByID(args[i].GetId()),
types.BoolType,
argType)
resultType = types.ErrorType
}
}
if isError(resultType) {
return nil
}
return newResolution(checkedRef, types.BoolType)
}
overloadType := newFunctionType(overload.ResultType(), overload.ArgTypes()...)
typeParams := overload.TypeParams()
if len(typeParams) != 0 {
// Instantiate overload's type with fresh type variables.
substitutions := newMapping()
for _, typePar := range overload.GetTypeParams() {
substitutions.add(decls.NewTypeParamType(typePar), c.newTypeVar())
for _, typePar := range typeParams {
substitutions.add(types.NewTypeParamType(typePar), c.newTypeVar())
}
overloadType = substitute(substitutions, overloadType, false)
}
candidateArgTypes := overloadType.GetFunction().GetArgTypes()
candidateArgTypes := overloadType.Parameters()[1:]
if c.isAssignableList(argTypes, candidateArgTypes) {
if checkedRef == nil {
checkedRef = newFunctionReference(overload.GetOverloadId())
checkedRef = ast.NewFunctionReference(overload.ID())
} else {
checkedRef.OverloadId = append(checkedRef.GetOverloadId(), overload.GetOverloadId())
checkedRef.AddOverload(overload.ID())
}
// First matching overload, determines result type.
fnResultType := substitute(c.mappings, overloadType.GetFunction().GetResultType(), false)
fnResultType := substitute(c.mappings, overloadType.Parameters()[0], false)
if resultType == nil {
resultType = fnResultType
} else if !isDyn(resultType) && !proto.Equal(fnResultType, resultType) {
resultType = decls.Dyn
} else if !isDyn(resultType) && !fnResultType.IsExactType(resultType) {
resultType = types.DynType
}
}
}
if resultType == nil {
c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil)
resultType = decls.Error
for i, argType := range argTypes {
argTypes[i] = substitute(c.mappings, argType, true)
}
c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil)
return nil
}
@@ -369,16 +417,29 @@ func (c *checker) resolveOverload(
func (c *checker) checkCreateList(e *exprpb.Expr) {
create := e.GetListExpr()
var elemType *exprpb.Type
for _, e := range create.GetElements() {
var elemsType *types.Type
optionalIndices := create.GetOptionalIndices()
optionals := make(map[int32]bool, len(optionalIndices))
for _, optInd := range optionalIndices {
optionals[optInd] = true
}
for i, e := range create.GetElements() {
c.check(e)
elemType = c.joinTypes(c.location(e), elemType, c.getType(e))
elemType := c.getType(e)
if optionals[int32(i)] {
var isOptional bool
elemType, isOptional = maybeUnwrapOptional(elemType)
if !isOptional && !isDyn(elemType) {
c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType)
}
}
elemsType = c.joinTypes(e, elemsType, elemType)
}
if elemType == nil {
if elemsType == nil {
// If the list is empty, assign free type var to elem type.
elemType = c.newTypeVar()
elemsType = c.newTypeVar()
}
c.setType(e, decls.NewListType(elemType))
c.setType(e, types.NewListType(elemsType))
}
func (c *checker) checkCreateStruct(e *exprpb.Expr) {
@@ -392,55 +453,68 @@ func (c *checker) checkCreateStruct(e *exprpb.Expr) {
func (c *checker) checkCreateMap(e *exprpb.Expr) {
mapVal := e.GetStructExpr()
var keyType *exprpb.Type
var valueType *exprpb.Type
var mapKeyType *types.Type
var mapValueType *types.Type
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
c.check(key)
keyType = c.joinTypes(c.location(key), keyType, c.getType(key))
mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key))
c.check(ent.Value)
valueType = c.joinTypes(c.location(ent.Value), valueType, c.getType(ent.Value))
val := ent.GetValue()
c.check(val)
valType := c.getType(val)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType)
}
}
mapValueType = c.joinTypes(val, mapValueType, valType)
}
if keyType == nil {
if mapKeyType == nil {
// If the map is empty, assign free type variables to typeKey and value type.
keyType = c.newTypeVar()
valueType = c.newTypeVar()
mapKeyType = c.newTypeVar()
mapValueType = c.newTypeVar()
}
c.setType(e, decls.NewMapType(keyType, valueType))
c.setType(e, types.NewMapType(mapKeyType, mapValueType))
}
func (c *checker) checkCreateMessage(e *exprpb.Expr) {
msgVal := e.GetStructExpr()
// Determine the type of the message.
messageType := decls.Error
decl := c.env.LookupIdent(msgVal.GetMessageName())
if decl == nil {
resultType := types.ErrorType
ident := c.env.LookupIdent(msgVal.GetMessageName())
if ident == nil {
c.errors.undeclaredReference(
c.location(e), c.env.container.Name(), msgVal.GetMessageName())
e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName())
c.setType(e, types.ErrorType)
return
}
// Ensure the type name is fully qualified in the AST.
msgVal.MessageName = decl.GetName()
c.setReference(e, newIdentReference(decl.GetName(), nil))
ident := decl.GetIdent()
identKind := kindOf(ident.GetType())
if identKind != kindError {
if identKind != kindType {
c.errors.notAType(c.location(e), ident.GetType())
typeName := ident.Name()
msgVal.MessageName = typeName
c.setReference(e, ast.NewIdentReference(ident.Name(), nil))
identKind := ident.Type().Kind()
if identKind != types.ErrorKind {
if identKind != types.TypeKind {
c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName())
} else {
messageType = ident.GetType().GetType()
if kindOf(messageType) != kindObject {
c.errors.notAMessageType(c.location(e), messageType)
messageType = decls.Error
resultType = ident.Type().Parameters()[0]
// Backwards compatibility test between well-known types and message types
// In this context, the type is being instantiated by its protobuf name which
// is not ideal or recommended, but some users expect this to work.
if isWellKnownType(resultType) {
typeName = getWellKnownTypeName(resultType)
} else if resultType.Kind() == types.StructKind {
typeName = resultType.DeclaredTypeName()
} else {
c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName())
resultType = types.ErrorType
}
}
}
if isObjectWellKnownType(messageType) {
c.setType(e, getObjectWellKnownType(messageType))
} else {
c.setType(e, messageType)
}
c.setType(e, resultType)
// Check the field initializers.
for _, ent := range msgVal.GetEntries() {
@@ -448,16 +522,22 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) {
value := ent.GetValue()
c.check(value)
fieldType := decls.Error
if t, found := c.lookupFieldType(
c.locationByID(ent.GetId()),
messageType.GetMessageType(),
field); found {
fieldType = t.Type
fieldType := types.ErrorType
ft, found := c.lookupFieldType(ent.GetId(), typeName, field)
if found {
fieldType = ft
}
if !c.isAssignable(fieldType, c.getType(value)) {
c.errors.fieldTypeMismatch(
c.locationByID(ent.Id), field, fieldType, c.getType(value))
valType := c.getType(value)
if ent.GetOptionalEntry() {
var isOptional bool
valType, isOptional = maybeUnwrapOptional(valType)
if !isOptional && !isDyn(valType) {
c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType)
}
}
if !c.isAssignable(fieldType, valType) {
c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType)
}
}
}
@@ -468,36 +548,36 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
c.check(comp.GetAccuInit())
accuType := c.getType(comp.GetAccuInit())
rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false)
var varType *exprpb.Type
var varType *types.Type
switch kindOf(rangeType) {
case kindList:
varType = rangeType.GetListType().GetElemType()
case kindMap:
switch rangeType.Kind() {
case types.ListKind:
varType = rangeType.Parameters()[0]
case types.MapKind:
// Ranges over the keys.
varType = rangeType.GetMapType().GetKeyType()
case kindDyn, kindError, kindTypeParam:
varType = rangeType.Parameters()[0]
case types.DynKind, types.ErrorKind, types.TypeParamKind:
// Set the range type to DYN to prevent assignment to a potentially incorrect type
// at a later point in type-checking. The isAssignable call will update the type
// substitutions for the type param under the covers.
c.isAssignable(decls.Dyn, rangeType)
c.isAssignable(types.DynType, rangeType)
// Set the range iteration variable to type DYN as well.
varType = decls.Dyn
varType = types.DynType
default:
c.errors.notAComprehensionRange(c.location(comp.GetIterRange()), rangeType)
varType = decls.Error
c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType)
varType = types.ErrorType
}
// Create a scope for the comprehension since it has a local accumulation variable.
// This scope will contain the accumulation variable used to compute the result.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetAccuVar(), accuType))
c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType))
// Create a block scope for the loop.
c.env = c.env.enterScope()
c.env.Add(decls.NewVar(comp.GetIterVar(), varType))
c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType))
// Check the variable references in the condition and step.
c.check(comp.GetLoopCondition())
c.assertType(comp.GetLoopCondition(), decls.Bool)
c.assertType(comp.GetLoopCondition(), types.BoolType)
c.check(comp.GetLoopStep())
c.assertType(comp.GetLoopStep(), accuType)
// Exit the loop's block scope before checking the result.
@@ -509,9 +589,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) {
}
// Checks compatibility of joined types, and returns the most general common type.
func (c *checker) joinTypes(loc common.Location,
previous *exprpb.Type,
current *exprpb.Type) *exprpb.Type {
func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type {
if previous == nil {
return current
}
@@ -519,23 +597,23 @@ func (c *checker) joinTypes(loc common.Location,
return mostGeneral(previous, current)
}
if c.dynAggregateLiteralElementTypesEnabled() {
return decls.Dyn
return types.DynType
}
c.errors.typeMismatch(loc, previous, current)
return decls.Error
c.errors.typeMismatch(e.GetId(), c.location(e), previous, current)
return types.ErrorType
}
func (c *checker) dynAggregateLiteralElementTypesEnabled() bool {
return c.env.aggLitElemType == dynElementType
}
func (c *checker) newTypeVar() *exprpb.Type {
func (c *checker) newTypeVar() *types.Type {
id := c.freeTypeVarCounter
c.freeTypeVarCounter++
return decls.NewTypeParamType(fmt.Sprintf("_var%d", id))
return types.NewTypeParamType(fmt.Sprintf("_var%d", id))
}
func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
func (c *checker) isAssignable(t1, t2 *types.Type) bool {
subs := isAssignable(c.mappings, t1, t2)
if subs != nil {
c.mappings = subs
@@ -545,7 +623,7 @@ func (c *checker) isAssignable(t1 *exprpb.Type, t2 *exprpb.Type) bool {
return false
}
func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func (c *checker) isAssignableList(l1, l2 []*types.Type) bool {
subs := isAssignableList(c.mappings, l1, l2)
if subs != nil {
c.mappings = subs
@@ -555,57 +633,52 @@ func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
return false
}
func (c *checker) lookupFieldType(l common.Location, messageType string, fieldName string) (*ref.FieldType, bool) {
if _, found := c.env.provider.FindType(messageType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(l, messageType)
return nil, false
func maybeUnwrapString(e *exprpb.Expr) (string, bool) {
switch e.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
literal := e.GetConstExpr()
switch literal.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
return literal.GetStringValue(), true
}
}
if ft, found := c.env.provider.FindFieldType(messageType, fieldName); found {
return ft, found
}
c.errors.undefinedField(l, fieldName)
return nil, false
return "", false
}
func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) {
if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) {
c.errors.ReportError(c.location(e),
"(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t)
func (c *checker) setType(e *exprpb.Expr, t *types.Type) {
if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) {
c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t)
return
}
c.types[e.GetId()] = t
}
func (c *checker) getType(e *exprpb.Expr) *exprpb.Type {
func (c *checker) getType(e *exprpb.Expr) *types.Type {
return c.types[e.GetId()]
}
func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) {
if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) {
c.errors.ReportError(c.location(e),
"Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r)
func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) {
if old, found := c.references[e.GetId()]; found && !old.Equals(r) {
c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r)
return
}
c.references[e.GetId()] = r
}
func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) {
func (c *checker) assertType(e *exprpb.Expr, t *types.Type) {
if !c.isAssignable(t, c.getType(e)) {
c.errors.typeMismatch(c.location(e), t, c.getType(e))
c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e))
}
}
type overloadResolution struct {
Reference *exprpb.Reference
Type *exprpb.Type
Type *types.Type
Reference *ast.ReferenceInfo
}
func newResolution(checkedRef *exprpb.Reference, t *exprpb.Type) *overloadResolution {
func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution {
return &overloadResolution{
Reference: checkedRef,
Reference: r,
Type: t,
}
}
@@ -632,10 +705,56 @@ func (c *checker) locationByID(id int64) common.Location {
return common.NoLocation
}
func newIdentReference(name string, value *exprpb.Constant) *exprpb.Reference {
return &exprpb.Reference{Name: name, Value: value}
func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) {
if _, found := c.env.provider.FindStructType(structType); !found {
// This should not happen, anyway, report an error.
c.errors.unexpectedFailedResolution(exprID, c.locationByID(exprID), structType)
return nil, false
}
if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
return ft.Type, found
}
c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
return nil, false
}
func newFunctionReference(overloads ...string) *exprpb.Reference {
return &exprpb.Reference{OverloadId: overloads}
func isWellKnownType(t *types.Type) bool {
switch t.Kind() {
case types.AnyKind, types.TimestampKind, types.DurationKind, types.DynKind, types.NullTypeKind:
return true
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind:
return t.IsAssignableType(types.NullType)
case types.ListKind:
return t.Parameters()[0] == types.DynType
case types.MapKind:
return t.Parameters()[0] == types.StringType && t.Parameters()[1] == types.DynType
}
return false
}
func getWellKnownTypeName(t *types.Type) string {
if name, found := wellKnownTypes[t.Kind()]; found {
return name
}
return ""
}
var (
wellKnownTypes = map[types.Kind]string{
types.AnyKind: "google.protobuf.Any",
types.BoolKind: "google.protobuf.BoolValue",
types.BytesKind: "google.protobuf.BytesValue",
types.DoubleKind: "google.protobuf.DoubleValue",
types.DurationKind: "google.protobuf.Duration",
types.DynKind: "google.protobuf.Value",
types.IntKind: "google.protobuf.Int64Value",
types.ListKind: "google.protobuf.ListValue",
types.NullTypeKind: "google.protobuf.NullValue",
types.MapKind: "google.protobuf.Struct",
types.StringKind: "google.protobuf.StringValue",
types.TimestampKind: "google.protobuf.Timestamp",
types.UintKind: "google.protobuf.UInt64Value",
}
)

View File

@@ -18,7 +18,9 @@ import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -54,7 +56,7 @@ type AstNode interface {
// The first path element is a variable. All subsequent path elements are one of: field name, '@items', '@keys', '@values'.
Path() []string
// Type returns the deduced type of the AstNode.
Type() *exprpb.Type
Type() *types.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
@@ -66,7 +68,7 @@ type AstNode interface {
type astNode struct {
path []string
t *exprpb.Type
t *types.Type
expr *exprpb.Expr
derivedSize *SizeEstimate
}
@@ -75,7 +77,7 @@ func (e astNode) Path() []string {
return e.path
}
func (e astNode) Type() *exprpb.Type {
func (e astNode) Type() *types.Type {
return e.t
}
@@ -92,7 +94,10 @@ func (e astNode) ComputedSize() *SizeEstimate {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.GetConstantKind().(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
// converting to runes here is an O(n) operation, but
// this is consistent with how size is computed at runtime,
// and how the language definition defines string size
v = uint64(len([]rune(ck.StringValue)))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
@@ -225,7 +230,7 @@ func addUint64NoOverflow(x, y uint64) uint64 {
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
if y != 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
@@ -237,7 +242,11 @@ func multiplyByCostFactor(x uint64, y float64) uint64 {
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
ceil := math.Ceil(xFloat * y)
if ceil >= doubleTwoTo64 {
return math.MaxUint64
}
return uint64(ceil)
}
var (
@@ -255,9 +264,12 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.CheckedAST
estimator CostEstimator
overloadEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
@@ -280,16 +292,55 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
return 0, false
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate {
c := coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
// CostOption configures flags which affect cost computations.
type CostOption func(*coster) error
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
if hasCost {
c.presenceTestCost = selectAndIdentCost
return nil
}
c.presenceTestCost = CostEstimate{Min: 0, Max: 0}
return nil
}
return c.cost(checker.GetExpr())
}
// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate
// OverloadCostEstimate binds a FunctionCoster to a specific function overload ID.
//
// When a OverloadCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func OverloadCostEstimate(overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
c.overloadEstimators[overloadID] = functionCoster
return nil
}
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checker,
estimator: estimator,
overloadEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
if err != nil {
return CostEstimate{}, err
}
}
return c.cost(checker.Expr), nil
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
@@ -323,10 +374,10 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].GetTypeKind().(type) {
case *exprpb.Type_ListType_:
switch c.checkedAST.TypeMap[iterRange].Kind() {
case types.ListKind:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
case types.MapKind:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
@@ -340,12 +391,18 @@ func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
var sum CostEstimate
if sel.GetTestOnly() {
// recurse, but do not add any cost
// this is equivalent to how evalTestOnly increments the runtime cost counter
// but does not add any additional cost for the qualifier, except here we do
// the reverse (ident adds cost)
sum = sum.Add(c.presenceTestCost)
sum = sum.Add(c.cost(sel.GetOperand()))
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
switch kindOf(targetType) {
case kindMap, kindObject, kindTypeParam:
switch targetType.Kind() {
case types.MapKind, types.StructKind, types.TypeParamKind:
sum = sum.Add(selectAndIdentCost)
}
@@ -369,8 +426,8 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedExpr.ReferenceMap[e.GetId()]
if ref == nil || len(ref.GetOverloadId()) == 0 {
ref := c.checkedAST.ReferenceMap[e.GetId()]
if ref == nil || len(ref.OverloadIDs) == 0 {
return CostEstimate{}
}
var targetType AstNode
@@ -383,7 +440,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
for _, overload := range ref.OverloadIDs {
overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
@@ -496,14 +553,44 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}
if len(c.overloadEstimators) != 0 {
if estimator, found := c.overloadEstimators[overloadID]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum())}
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
switch overloadID {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
case overloads.ExtFormatString:
if target != nil {
// ResultSize not calculated because we can't bound the max size.
return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
case overloads.StringToBytes:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize max is when each char converts to 4 bytes.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min, Max: sz.Max * 4}}
}
case overloads.BytesToString:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize min is when 4 bytes convert to 1 char.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}}
}
case overloads.ExtQuoteString:
if len(args) == 1 {
sz := c.sizeEstimate(args[0])
// ResultSize max is when each char is escaped. 2 quote chars always added.
return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min + 2, Max: sz.Max*2 + 2}}
}
case overloads.StartsWithString, overloads.EndsWithString:
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
@@ -584,8 +671,8 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *exprpb.Type {
return c.checkedExpr.TypeMap[e.GetId()]
func (c *coster) getType(e *exprpb.Expr) *types.Type {
return c.checkedAST.TypeMap[e.GetId()]
}
func (c *coster) getPath(e *exprpb.Expr) []string {
@@ -606,22 +693,24 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
if size, ok := c.computedSizes[e.GetId()]; ok {
derivedSize = &size
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
return &astNode{
path: path,
t: c.getType(e),
expr: e,
derivedSize: derivedSize}
}
// isScalar returns true if the given type is known to be of a constant size at
// compile time. isScalar will return false for strings (they are variable-width)
// in addition to protobuf.Any and protobuf.Value (their size is not knowable at compile time).
func isScalar(t *exprpb.Type) bool {
switch kindOf(t) {
case kindPrimitive:
if t.GetPrimitive() != exprpb.Type_STRING && t.GetPrimitive() != exprpb.Type_BYTES {
return true
}
case kindWellKnown:
if t.GetWellKnown() == exprpb.Type_DURATION || t.GetWellKnown() == exprpb.Type_TIMESTAMP {
return true
}
func isScalar(t *types.Type) bool {
switch t.Kind() {
case types.BoolKind, types.DoubleKind, types.DurationKind, types.IntKind, types.TimestampKind, types.UintKind:
return true
}
return false
}
var (
doubleTwoTo64 = math.Ldexp(1.0, 64)
)

View File

@@ -9,11 +9,10 @@ go_library(
name = "go_default_library",
srcs = [
"decls.go",
"scopes.go",
],
importpath = "github.com/google/cel-go/checker/decls",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],

View File

@@ -16,9 +16,9 @@
package decls
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
emptypb "google.golang.org/protobuf/types/known/emptypb"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
@@ -64,6 +64,12 @@ func NewAbstractType(name string, paramTypes ...*exprpb.Type) *exprpb.Type {
ParameterTypes: paramTypes}}}
}
// NewOptionalType constructs an abstract type indicating that the parameterized type
// may be contained within the object.
func NewOptionalType(paramType *exprpb.Type) *exprpb.Type {
return NewAbstractType("optional", paramType)
}
// NewFunctionType creates a function invocation contract, typically only used
// by type-checking steps after overload resolution.
func NewFunctionType(resultType *exprpb.Type,

View File

@@ -18,17 +18,11 @@ import (
"fmt"
"strings"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
type aggregateLiteralElementType int
@@ -76,15 +70,15 @@ var (
// which can be used to assist with type-checking.
type Env struct {
container *containers.Container
provider ref.TypeProvider
declarations *decls.Scopes
provider types.Provider
declarations *Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
}
// NewEnv returns a new *Env with the given parameters.
func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...Option) (*Env, error) {
declarations := decls.NewScopes()
func NewEnv(container *containers.Container, provider types.Provider, opts ...Option) (*Env, error) {
declarations := newScopes()
declarations.Push()
envOptions := &options{}
@@ -113,24 +107,31 @@ func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...
}, nil
}
// Add adds new Decl protos to the Env.
// Returns an error for identifier redeclarations.
func (e *Env) Add(decls ...*exprpb.Decl) error {
// AddIdents configures the checker with a list of variable declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddIdents(declarations ...*decls.VariableDecl) error {
errMsgs := make([]errorMsg, 0)
for _, decl := range decls {
switch decl.DeclKind.(type) {
case *exprpb.Decl_Ident:
errMsgs = append(errMsgs, e.addIdent(sanitizeIdent(decl)))
case *exprpb.Decl_Function:
errMsgs = append(errMsgs, e.setFunction(sanitizeFunction(decl))...)
}
for _, d := range declarations {
errMsgs = append(errMsgs, e.addIdent(d))
}
return formatError(errMsgs)
}
// AddFunctions configures the checker with a list of function declarations.
//
// If there are overlapping declarations, the method will error.
func (e *Env) AddFunctions(declarations ...*decls.FunctionDecl) error {
errMsgs := make([]errorMsg, 0)
for _, d := range declarations {
errMsgs = append(errMsgs, e.setFunction(d)...)
}
return formatError(errMsgs)
}
// LookupIdent returns a Decl proto for typeName as an identifier in the Env.
// Returns nil if no such identifier is found in the Env.
func (e *Env) LookupIdent(name string) *exprpb.Decl {
func (e *Env) LookupIdent(name string) *decls.VariableDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if ident := e.declarations.FindIdent(candidate); ident != nil {
return ident
@@ -139,8 +140,8 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import the name as a reference to a message type. If found,
// the declaration is added to the outest (global) scope of the
// environment, so next time we can access it faster.
if t, found := e.provider.FindType(candidate); found {
decl := decls.NewVar(candidate, t)
if t, found := e.provider.FindStructType(candidate); found {
decl := decls.NewVariable(candidate, t)
e.declarations.AddIdent(decl)
return decl
}
@@ -148,11 +149,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// Next try to import this as an enum value by splitting the name in a type prefix and
// the enum inside.
if enumValue := e.provider.EnumValue(candidate); enumValue.Type() != types.ErrType {
decl := decls.NewIdent(candidate,
decls.Int,
&exprpb.Constant{
ConstantKind: &exprpb.Constant_Int64Value{
Int64Value: int64(enumValue.(types.Int))}})
decl := decls.NewConstant(candidate, types.IntType, enumValue)
e.declarations.AddIdent(decl)
return decl
}
@@ -162,7 +159,7 @@ func (e *Env) LookupIdent(name string) *exprpb.Decl {
// LookupFunction returns a Decl proto for typeName as a function in env.
// Returns nil if no such function is found in env.
func (e *Env) LookupFunction(name string) *exprpb.Decl {
func (e *Env) LookupFunction(name string) *decls.FunctionDecl {
for _, candidate := range e.container.ResolveCandidateNames(name) {
if fn := e.declarations.FindFunction(candidate); fn != nil {
return fn
@@ -171,88 +168,46 @@ func (e *Env) LookupFunction(name string) *exprpb.Decl {
return nil
}
// addOverload adds overload to function declaration f.
// Returns one or more errorMsg values if the overload overlaps with an existing overload or macro.
func (e *Env) addOverload(f *exprpb.Decl, overload *exprpb.Decl_FunctionDecl_Overload) []errorMsg {
errMsgs := make([]errorMsg, 0)
function := f.GetFunction()
emptyMappings := newMapping()
overloadFunction := decls.NewFunctionType(overload.GetResultType(),
overload.GetParams()...)
overloadErased := substitute(emptyMappings, overloadFunction, true)
for _, existing := range function.GetOverloads() {
existingFunction := decls.NewFunctionType(existing.GetResultType(), existing.GetParams()...)
existingErased := substitute(emptyMappings, existingFunction, true)
overlap := isAssignable(emptyMappings, overloadErased, existingErased) != nil ||
isAssignable(emptyMappings, existingErased, overloadErased) != nil
if overlap &&
overload.GetIsInstanceFunction() == existing.GetIsInstanceFunction() {
errMsgs = append(errMsgs,
overlappingOverloadError(f.Name,
overload.GetOverloadId(), overloadFunction,
existing.GetOverloadId(), existingFunction))
}
}
for _, macro := range parser.AllMacros {
if macro.Function() == f.Name &&
macro.IsReceiverStyle() == overload.GetIsInstanceFunction() &&
macro.ArgCount() == len(overload.GetParams()) {
errMsgs = append(errMsgs, overlappingMacroError(f.Name, macro.ArgCount()))
}
}
if len(errMsgs) > 0 {
return errMsgs
}
function.Overloads = append(function.GetOverloads(), overload)
return errMsgs
}
// setFunction adds the function Decl to the Env.
// Adds a function decl if one doesn't already exist, then adds all overloads from the Decl.
// If overload overlaps with an existing overload, adds to the errors in the Env instead.
func (e *Env) setFunction(decl *exprpb.Decl) []errorMsg {
errorMsgs := make([]errorMsg, 0)
overloads := decl.GetFunction().GetOverloads()
current := e.declarations.FindFunction(decl.Name)
if current == nil {
//Add the function declaration without overloads and check the overloads below.
current = decls.NewFunction(decl.Name)
} else {
existingOverloads := map[string]*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range current.GetFunction().GetOverloads() {
existingOverloads[overload.GetOverloadId()] = overload
func (e *Env) setFunction(fn *decls.FunctionDecl) []errorMsg {
errMsgs := make([]errorMsg, 0)
current := e.declarations.FindFunction(fn.Name())
if current != nil {
var err error
current, err = current.Merge(fn)
if err != nil {
return append(errMsgs, errorMsg(err.Error()))
}
newOverloads := []*exprpb.Decl_FunctionDecl_Overload{}
for _, overload := range overloads {
existing, found := existingOverloads[overload.GetOverloadId()]
if !found || !proto.Equal(existing, overload) {
newOverloads = append(newOverloads, overload)
} else {
current = fn
}
for _, overload := range current.OverloadDecls() {
for _, macro := range parser.AllMacros {
if macro.Function() == current.Name() &&
macro.IsReceiverStyle() == overload.IsMemberFunction() &&
macro.ArgCount() == len(overload.ArgTypes()) {
errMsgs = append(errMsgs, overlappingMacroError(current.Name(), macro.ArgCount()))
}
}
overloads = newOverloads
if len(newOverloads) == 0 {
return errorMsgs
if len(errMsgs) > 0 {
return errMsgs
}
// Copy on write since we don't know where this original definition came from.
current = proto.Clone(current).(*exprpb.Decl)
}
e.declarations.SetFunction(current)
for _, overload := range overloads {
errorMsgs = append(errorMsgs, e.addOverload(current, overload)...)
}
return errorMsgs
return errMsgs
}
// addIdent adds the Decl to the declarations in the Env.
// Returns a non-empty errorMsg if the identifier is already declared in the scope.
func (e *Env) addIdent(decl *exprpb.Decl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name)
func (e *Env) addIdent(decl *decls.VariableDecl) errorMsg {
current := e.declarations.FindIdentInScope(decl.Name())
if current != nil {
if proto.Equal(current, decl) {
if current.DeclarationIsEquivalent(decl) {
return ""
}
return overlappingIdentifierError(decl.Name)
return overlappingIdentifierError(decl.Name())
}
e.declarations.AddIdent(decl)
return ""
@@ -264,86 +219,9 @@ func (e *Env) isOverloadDisabled(overloadID string) bool {
return found
}
// sanitizeFunction replaces well-known types referenced by message name with their equivalent
// CEL built-in type instances.
func sanitizeFunction(decl *exprpb.Decl) *exprpb.Decl {
fn := decl.GetFunction()
// Determine whether the declaration requires replacements from proto-based message type
// references to well-known CEL type references.
var needsSanitizing bool
for _, o := range fn.GetOverloads() {
if isObjectWellKnownType(o.GetResultType()) {
needsSanitizing = true
break
}
for _, p := range o.GetParams() {
if isObjectWellKnownType(p) {
needsSanitizing = true
break
}
}
}
// Early return if the declaration requires no modification.
if !needsSanitizing {
return decl
}
// Sanitize all of the overloads if any overload requires an update to its type references.
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(fn.GetOverloads()))
for i, o := range fn.GetOverloads() {
rt := o.GetResultType()
if isObjectWellKnownType(rt) {
rt = getObjectWellKnownType(rt)
}
params := make([]*exprpb.Type, len(o.GetParams()))
copy(params, o.GetParams())
for j, p := range params {
if isObjectWellKnownType(p) {
params[j] = getObjectWellKnownType(p)
}
}
// If sanitized, replace the overload definition.
if o.IsInstanceFunction {
overloads[i] =
decls.NewInstanceOverload(o.GetOverloadId(), params, rt)
} else {
overloads[i] =
decls.NewOverload(o.GetOverloadId(), params, rt)
}
}
return decls.NewFunction(decl.GetName(), overloads...)
}
// sanitizeIdent replaces the identifier's well-known types referenced by message name with
// references to CEL built-in type instances.
func sanitizeIdent(decl *exprpb.Decl) *exprpb.Decl {
id := decl.GetIdent()
t := id.GetType()
if !isObjectWellKnownType(t) {
return decl
}
return decls.NewIdent(decl.GetName(), getObjectWellKnownType(t), id.GetValue())
}
// isObjectWellKnownType returns true if the input type is an OBJECT type with a message name
// that corresponds the message name of a built-in CEL type.
func isObjectWellKnownType(t *exprpb.Type) bool {
if kindOf(t) != kindObject {
return false
}
_, found := pb.CheckedWellKnowns[t.GetMessageType()]
return found
}
// getObjectWellKnownType returns the built-in CEL type declaration for input type's message name.
func getObjectWellKnownType(t *exprpb.Type) *exprpb.Type {
return pb.CheckedWellKnowns[t.GetMessageType()]
}
// validatedDeclarations returns a reference to the validated variable and function declaration scope stack.
// must be copied before use.
func (e *Env) validatedDeclarations() *decls.Scopes {
func (e *Env) validatedDeclarations() *Scopes {
return e.declarations
}
@@ -377,19 +255,6 @@ func overlappingIdentifierError(name string) errorMsg {
return errorMsg(fmt.Sprintf("overlapping identifier for name '%s'", name))
}
func overlappingOverloadError(name string,
overloadID1 string, f1 *exprpb.Type,
overloadID2 string, f2 *exprpb.Type) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping overload for name '%s' (type '%s' with overloadId: '%s' "+
"cannot be distinguished from '%s' with overloadId: '%s')",
name,
FormatCheckedType(f1),
overloadID1,
FormatCheckedType(f2),
overloadID2))
}
func overlappingMacroError(name string, argCount int) errorMsg {
return errorMsg(fmt.Sprintf(
"overlapping macro for name '%s' with %d args", name, argCount))

View File

@@ -15,82 +15,78 @@
package checker
import (
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// typeErrors is a specialization of Errors.
type typeErrors struct {
*common.Errors
errs *common.Errors
}
func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) {
e.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container)
func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, field, value *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCELType(field), FormatCELType(value))
}
func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) {
e.ReportError(l, "type '%s' does not support field selection", t)
func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) {
e.errs.ReportErrorAtID(id, l,
"incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) undefinedField(l common.Location, field string) {
e.ReportError(l, "undefined field '%s'", field)
func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) {
signature := formatFunctionDeclType(nil, args, isInstance)
e.errs.ReportErrorAtID(id, l, "found no matching overload for '%s' applied to '%s'", name, signature)
}
func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) {
signature := formatFunction(nil, args, isInstance)
e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature)
func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCELType(t))
}
func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t)
func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field)
}
func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) {
e.ReportError(l, "'%s' is not a message type", FormatCheckedType(t))
func (e *typeErrors) notAType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a type", typeName)
}
func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) {
e.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'",
name, FormatCheckedType(field), FormatCheckedType(value))
func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName)
}
func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) {
e.ReportError(l, "[internal] unexpected failed resolution of '%s'", typeName)
func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) {
e.errs.ReportErrorAtID(id, l,
"reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next)
}
func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) {
e.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)",
FormatCheckedType(t))
func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) {
e.errs.ReportErrorAtID(id, l, "type '%s' does not support field selection", FormatCELType(t))
}
func (e *typeErrors) typeMismatch(l common.Location, expected *exprpb.Type, actual *exprpb.Type) {
e.ReportError(l, "expected type '%s' but found '%s'",
FormatCheckedType(expected), FormatCheckedType(actual))
func (e *typeErrors) typeMismatch(id int64, l common.Location, expected, actual *types.Type) {
e.errs.ReportErrorAtID(id, l, "expected type '%s' but found '%s'",
FormatCELType(expected), FormatCELType(actual))
}
func formatFunction(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += FormatCheckedType(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += FormatCheckedType(arg)
}
result += ")"
if resultType != nil {
result += " -> "
result += FormatCheckedType(resultType)
}
return result
func (e *typeErrors) undefinedField(id int64, l common.Location, field string) {
e.errs.ReportErrorAtID(id, l, "undefined field '%s'", field)
}
func (e *typeErrors) undeclaredReference(id int64, l common.Location, container string, name string) {
e.errs.ReportErrorAtID(id, l, "undeclared reference to '%s' (in container '%s')", name, container)
}
func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typeName string) {
e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName)
}
func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) {
e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex))
}

216
vendor/github.com/google/cel-go/checker/format.go generated vendored Normal file
View File

@@ -0,0 +1,216 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
import (
"fmt"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunctionExprType(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil || t.GetType().GetTypeKind() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(chkdecls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
case kindAbstract:
at := t.GetAbstractType()
params := at.GetParameterTypes()
paramStrs := make([]string, len(params))
for i, p := range params {
paramStrs[i] = FormatCheckedType(p)
}
return fmt.Sprintf("%s(%s)", at.GetName(), strings.Join(paramStrs, ", "))
}
return t.String()
}
type formatter func(any) string
// FormatCELType formats a types.Type value to a string representation.
//
// The type formatting is identical to FormatCheckedType.
func FormatCELType(t any) string {
dt := t.(*types.Type)
switch dt.Kind() {
case types.AnyKind:
return "any"
case types.DurationKind:
return "duration"
case types.ErrorKind:
return "!error!"
case types.NullTypeKind:
return "null"
case types.TimestampKind:
return "timestamp"
case types.TypeParamKind:
return dt.TypeName()
case types.OpaqueKind:
if dt.TypeName() == "function" {
// There is no explicit function type in the new types representation, so information like
// whether the function is a member function is absent.
return formatFunctionDeclType(dt.Parameters()[0], dt.Parameters()[1:], false)
}
case types.UnspecifiedKind:
return ""
}
if len(dt.Parameters()) == 0 {
return dt.DeclaredTypeName()
}
paramTypeNames := make([]string, 0, len(dt.Parameters()))
for _, p := range dt.Parameters() {
paramTypeNames = append(paramTypeNames, FormatCELType(p))
}
return fmt.Sprintf("%s(%s)", dt.TypeName(), strings.Join(paramTypeNames, ", "))
}
func formatExprType(t any) string {
if t == nil {
return ""
}
return FormatCheckedType(t.(*exprpb.Type))
}
func formatFunctionExprType(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string {
return formatFunctionInternal[*exprpb.Type](resultType, argTypes, isInstance, formatExprType)
}
func formatFunctionDeclType(resultType *types.Type, argTypes []*types.Type, isInstance bool) string {
return formatFunctionInternal[*types.Type](resultType, argTypes, isInstance, FormatCELType)
}
func formatFunctionInternal[T any](resultType T, argTypes []T, isInstance bool, format formatter) string {
result := ""
if isInstance {
target := argTypes[0]
argTypes = argTypes[1:]
result += format(target)
result += "."
}
result += "("
for i, arg := range argTypes {
if i > 0 {
result += ", "
}
result += format(arg)
}
result += ")"
rt := format(resultType)
if rt != "" {
result += " -> "
result += rt
}
return result
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}

View File

@@ -15,25 +15,25 @@
package checker
import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
type mapping struct {
mapping map[string]*exprpb.Type
mapping map[string]*types.Type
}
func newMapping() *mapping {
return &mapping{
mapping: make(map[string]*exprpb.Type),
mapping: make(map[string]*types.Type),
}
}
func (m *mapping) add(from *exprpb.Type, to *exprpb.Type) {
m.mapping[typeKey(from)] = to
func (m *mapping) add(from, to *types.Type) {
m.mapping[FormatCELType(from)] = to
}
func (m *mapping) find(from *exprpb.Type) (*exprpb.Type, bool) {
if r, found := m.mapping[typeKey(from)]; found {
func (m *mapping) find(from *types.Type) (*types.Type, bool) {
if r, found := m.mapping[FormatCELType(from)]; found {
return r, found
}
return nil, false

View File

@@ -14,12 +14,10 @@
package checker
import "github.com/google/cel-go/checker/decls"
type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
validatedDeclarations *decls.Scopes
validatedDeclarations *Scopes
}
// Option is a functional option for configuring the type-checker
@@ -34,15 +32,6 @@ func CrossTypeNumericComparisons(enabled bool) Option {
}
}
// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all
// have the same type.
func HomogeneousAggregateLiterals(enabled bool) Option {
return func(opts *options) error {
opts.homogeneousAggregateLiterals = enabled
return nil
}
}
// ValidatedDeclarations provides a references to validated declarations which will be copied
// into new checker instances.
func ValidatedDeclarations(env *Env) Option {

View File

@@ -15,6 +15,8 @@
package checker
import (
"sort"
"github.com/google/cel-go/common/debug"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@@ -26,7 +28,7 @@ type semanticAdorner struct {
var _ debug.Adorner = &semanticAdorner{}
func (a *semanticAdorner) GetMetadata(elem interface{}) string {
func (a *semanticAdorner) GetMetadata(elem any) string {
result := ""
e, isExpr := elem.(*exprpb.Expr)
if !isExpr {
@@ -47,6 +49,7 @@ func (a *semanticAdorner) GetMetadata(elem interface{}) string {
if len(ref.GetOverloadId()) == 0 {
result += "^" + ref.Name
} else {
sort.Strings(ref.GetOverloadId())
for i, overload := range ref.GetOverloadId() {
if i == 0 {
result += "^"

View File

@@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package decls
package checker
import exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
import (
"github.com/google/cel-go/common/decls"
)
// Scopes represents nested Decl sets where the Scopes value contains a Groups containing all
// identifiers in scope and an optional parent representing outer scopes.
@@ -25,9 +27,9 @@ type Scopes struct {
scopes *Group
}
// NewScopes creates a new, empty Scopes.
// newScopes creates a new, empty Scopes.
// Some operations can't be safely performed until a Group is added with Push.
func NewScopes() *Scopes {
func newScopes() *Scopes {
return &Scopes{
scopes: newGroup(),
}
@@ -35,7 +37,7 @@ func NewScopes() *Scopes {
// Copy creates a copy of the current Scopes values, including a copy of its parent if non-nil.
func (s *Scopes) Copy() *Scopes {
cpy := NewScopes()
cpy := newScopes()
if s == nil {
return cpy
}
@@ -66,14 +68,14 @@ func (s *Scopes) Pop() *Scopes {
// AddIdent adds the ident Decl in the current scope.
// Note: If the name collides with an existing identifier in the scope, the Decl is overwritten.
func (s *Scopes) AddIdent(decl *exprpb.Decl) {
s.scopes.idents[decl.Name] = decl
func (s *Scopes) AddIdent(decl *decls.VariableDecl) {
s.scopes.idents[decl.Name()] = decl
}
// FindIdent finds the first ident Decl with a matching name in Scopes, or nil if one cannot be
// found.
// Note: The search is performed from innermost to outermost.
func (s *Scopes) FindIdent(name string) *exprpb.Decl {
func (s *Scopes) FindIdent(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@@ -86,7 +88,7 @@ func (s *Scopes) FindIdent(name string) *exprpb.Decl {
// FindIdentInScope finds the first ident Decl with a matching name in the current Scopes value, or
// nil if one does not exist.
// Note: The search is only performed on the current scope and does not search outer scopes.
func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
func (s *Scopes) FindIdentInScope(name string) *decls.VariableDecl {
if ident, found := s.scopes.idents[name]; found {
return ident
}
@@ -95,14 +97,14 @@ func (s *Scopes) FindIdentInScope(name string) *exprpb.Decl {
// SetFunction adds the function Decl to the current scope.
// Note: Any previous entry for a function in the current scope with the same name is overwritten.
func (s *Scopes) SetFunction(fn *exprpb.Decl) {
s.scopes.functions[fn.Name] = fn
func (s *Scopes) SetFunction(fn *decls.FunctionDecl) {
s.scopes.functions[fn.Name()] = fn
}
// FindFunction finds the first function Decl with a matching name in Scopes.
// The search is performed from innermost to outermost.
// Returns nil if no such function in Scopes.
func (s *Scopes) FindFunction(name string) *exprpb.Decl {
func (s *Scopes) FindFunction(name string) *decls.FunctionDecl {
if fn, found := s.scopes.functions[name]; found {
return fn
}
@@ -116,16 +118,16 @@ func (s *Scopes) FindFunction(name string) *exprpb.Decl {
// Contains separate namespaces for identifier and function Decls.
// (Should be named "Scope" perhaps?)
type Group struct {
idents map[string]*exprpb.Decl
functions map[string]*exprpb.Decl
idents map[string]*decls.VariableDecl
functions map[string]*decls.FunctionDecl
}
// copy creates a new Group instance with a shallow copy of the variables and functions.
// If callers need to mutate the exprpb.Decl definitions for a Function, they should copy-on-write.
func (g *Group) copy() *Group {
cpy := &Group{
idents: make(map[string]*exprpb.Decl, len(g.idents)),
functions: make(map[string]*exprpb.Decl, len(g.functions)),
idents: make(map[string]*decls.VariableDecl, len(g.idents)),
functions: make(map[string]*decls.FunctionDecl, len(g.functions)),
}
for n, id := range g.idents {
cpy.idents[n] = id
@@ -139,7 +141,7 @@ func (g *Group) copy() *Group {
// newGroup creates a new Group with empty maps for identifiers and functions.
func newGroup() *Group {
return &Group{
idents: make(map[string]*exprpb.Decl),
functions: make(map[string]*exprpb.Decl),
idents: make(map[string]*decls.VariableDecl),
functions: make(map[string]*decls.FunctionDecl),
}
}

View File

@@ -15,478 +15,21 @@
package checker
import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
standardDeclarations []*exprpb.Decl
)
func init() {
// Some shortcuts we use when building declarations.
paramA := decls.NewTypeParamType("A")
typeParamAList := []string{"A"}
listOfA := decls.NewListType(paramA)
paramB := decls.NewTypeParamType("B")
typeParamABList := []string{"A", "B"}
mapOfAB := decls.NewMapType(paramA, paramB)
var idents []*exprpb.Decl
for _, t := range []*exprpb.Type{
decls.Int, decls.Uint, decls.Bool,
decls.Double, decls.Bytes, decls.String} {
idents = append(idents,
decls.NewVar(FormatCheckedType(t), decls.NewTypeType(t)))
}
idents = append(idents,
decls.NewVar("list", decls.NewTypeType(listOfA)),
decls.NewVar("map", decls.NewTypeType(mapOfAB)),
decls.NewVar("null_type", decls.NewTypeType(decls.Null)),
decls.NewVar("type", decls.NewTypeType(decls.NewTypeType(nil))))
standardDeclarations = append(standardDeclarations, idents...)
standardDeclarations = append(standardDeclarations, []*exprpb.Decl{
// Booleans
decls.NewFunction(operators.Conditional,
decls.NewParameterizedOverload(overloads.Conditional,
[]*exprpb.Type{decls.Bool, paramA, paramA}, paramA,
typeParamAList)),
decls.NewFunction(operators.LogicalAnd,
decls.NewOverload(overloads.LogicalAnd,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalOr,
decls.NewOverload(overloads.LogicalOr,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool)),
decls.NewFunction(operators.LogicalNot,
decls.NewOverload(overloads.LogicalNot,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.NotStrictlyFalse,
decls.NewOverload(overloads.NotStrictlyFalse,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
decls.NewFunction(operators.Equals,
decls.NewParameterizedOverload(overloads.Equals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
decls.NewFunction(operators.NotEquals,
decls.NewParameterizedOverload(overloads.NotEquals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
typeParamAList)),
// Algebra.
decls.NewFunction(operators.Subtract,
decls.NewOverload(overloads.SubtractInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.SubtractUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.SubtractDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.SubtractTimestampTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Duration),
decls.NewOverload(overloads.SubtractTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.SubtractDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Multiply,
decls.NewOverload(overloads.MultiplyInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.MultiplyUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.MultiplyDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Divide,
decls.NewOverload(overloads.DivideInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.DivideUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.DivideDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double)),
decls.NewFunction(operators.Modulo,
decls.NewOverload(overloads.ModuloInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.ModuloUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint)),
decls.NewFunction(operators.Add,
decls.NewOverload(overloads.AddInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Int),
decls.NewOverload(overloads.AddUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Uint),
decls.NewOverload(overloads.AddDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Double),
decls.NewOverload(overloads.AddString,
[]*exprpb.Type{decls.String, decls.String}, decls.String),
decls.NewOverload(overloads.AddBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bytes),
decls.NewParameterizedOverload(overloads.AddList,
[]*exprpb.Type{listOfA, listOfA}, listOfA,
typeParamAList),
decls.NewOverload(overloads.AddTimestampDuration,
[]*exprpb.Type{decls.Timestamp, decls.Duration}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationTimestamp,
[]*exprpb.Type{decls.Duration, decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.AddDurationDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Duration)),
decls.NewFunction(operators.Negate,
decls.NewOverload(overloads.NegateInt64,
[]*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.NegateDouble,
[]*exprpb.Type{decls.Double}, decls.Double)),
// Index.
decls.NewFunction(operators.Index,
decls.NewParameterizedOverload(overloads.IndexList,
[]*exprpb.Type{listOfA, decls.Int}, paramA,
typeParamAList),
decls.NewParameterizedOverload(overloads.IndexMap,
[]*exprpb.Type{mapOfAB, paramA}, paramB,
typeParamABList)),
// Collections.
decls.NewFunction(overloads.Size,
decls.NewInstanceOverload(overloads.SizeStringInst,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.SizeBytesInst,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedInstanceOverload(overloads.SizeListInst,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedInstanceOverload(overloads.SizeMapInst,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList),
decls.NewOverload(overloads.SizeString,
[]*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.SizeBytes,
[]*exprpb.Type{decls.Bytes}, decls.Int),
decls.NewParameterizedOverload(overloads.SizeList,
[]*exprpb.Type{listOfA}, decls.Int, typeParamAList),
decls.NewParameterizedOverload(overloads.SizeMap,
[]*exprpb.Type{mapOfAB}, decls.Int, typeParamABList)),
decls.NewFunction(operators.In,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Deprecated 'in()' function.
decls.NewFunction(overloads.DeprecatedIn,
decls.NewParameterizedOverload(overloads.InList,
[]*exprpb.Type{paramA, listOfA}, decls.Bool,
typeParamAList),
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
// Conversions to type.
decls.NewFunction(overloads.TypeConvertType,
decls.NewParameterizedOverload(overloads.TypeConvertType,
[]*exprpb.Type{paramA}, decls.NewTypeType(paramA), typeParamAList)),
// Conversions to int.
decls.NewFunction(overloads.TypeConvertInt,
decls.NewOverload(overloads.IntToInt, []*exprpb.Type{decls.Int}, decls.Int),
decls.NewOverload(overloads.UintToInt, []*exprpb.Type{decls.Uint}, decls.Int),
decls.NewOverload(overloads.DoubleToInt, []*exprpb.Type{decls.Double}, decls.Int),
decls.NewOverload(overloads.StringToInt, []*exprpb.Type{decls.String}, decls.Int),
decls.NewOverload(overloads.TimestampToInt, []*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewOverload(overloads.DurationToInt, []*exprpb.Type{decls.Duration}, decls.Int)),
// Conversions to uint.
decls.NewFunction(overloads.TypeConvertUint,
decls.NewOverload(overloads.UintToUint, []*exprpb.Type{decls.Uint}, decls.Uint),
decls.NewOverload(overloads.IntToUint, []*exprpb.Type{decls.Int}, decls.Uint),
decls.NewOverload(overloads.DoubleToUint, []*exprpb.Type{decls.Double}, decls.Uint),
decls.NewOverload(overloads.StringToUint, []*exprpb.Type{decls.String}, decls.Uint)),
// Conversions to double.
decls.NewFunction(overloads.TypeConvertDouble,
decls.NewOverload(overloads.DoubleToDouble, []*exprpb.Type{decls.Double}, decls.Double),
decls.NewOverload(overloads.IntToDouble, []*exprpb.Type{decls.Int}, decls.Double),
decls.NewOverload(overloads.UintToDouble, []*exprpb.Type{decls.Uint}, decls.Double),
decls.NewOverload(overloads.StringToDouble, []*exprpb.Type{decls.String}, decls.Double)),
// Conversions to bool.
decls.NewFunction(overloads.TypeConvertBool,
decls.NewOverload(overloads.BoolToBool, []*exprpb.Type{decls.Bool}, decls.Bool),
decls.NewOverload(overloads.StringToBool, []*exprpb.Type{decls.String}, decls.Bool)),
// Conversions to string.
decls.NewFunction(overloads.TypeConvertString,
decls.NewOverload(overloads.StringToString, []*exprpb.Type{decls.String}, decls.String),
decls.NewOverload(overloads.BoolToString, []*exprpb.Type{decls.Bool}, decls.String),
decls.NewOverload(overloads.IntToString, []*exprpb.Type{decls.Int}, decls.String),
decls.NewOverload(overloads.UintToString, []*exprpb.Type{decls.Uint}, decls.String),
decls.NewOverload(overloads.DoubleToString, []*exprpb.Type{decls.Double}, decls.String),
decls.NewOverload(overloads.BytesToString, []*exprpb.Type{decls.Bytes}, decls.String),
decls.NewOverload(overloads.TimestampToString, []*exprpb.Type{decls.Timestamp}, decls.String),
decls.NewOverload(overloads.DurationToString, []*exprpb.Type{decls.Duration}, decls.String)),
// Conversions to bytes.
decls.NewFunction(overloads.TypeConvertBytes,
decls.NewOverload(overloads.BytesToBytes, []*exprpb.Type{decls.Bytes}, decls.Bytes),
decls.NewOverload(overloads.StringToBytes, []*exprpb.Type{decls.String}, decls.Bytes)),
// Conversions to timestamps.
decls.NewFunction(overloads.TypeConvertTimestamp,
decls.NewOverload(overloads.TimestampToTimestamp,
[]*exprpb.Type{decls.Timestamp}, decls.Timestamp),
decls.NewOverload(overloads.StringToTimestamp,
[]*exprpb.Type{decls.String}, decls.Timestamp),
decls.NewOverload(overloads.IntToTimestamp,
[]*exprpb.Type{decls.Int}, decls.Timestamp)),
// Conversions to durations.
decls.NewFunction(overloads.TypeConvertDuration,
decls.NewOverload(overloads.DurationToDuration,
[]*exprpb.Type{decls.Duration}, decls.Duration),
decls.NewOverload(overloads.StringToDuration,
[]*exprpb.Type{decls.String}, decls.Duration),
decls.NewOverload(overloads.IntToDuration,
[]*exprpb.Type{decls.Int}, decls.Duration)),
// Conversions to Dyn.
decls.NewFunction(overloads.TypeConvertDyn,
decls.NewParameterizedOverload(overloads.ToDyn,
[]*exprpb.Type{paramA}, decls.Dyn,
typeParamAList)),
// String functions.
decls.NewFunction(overloads.Contains,
decls.NewInstanceOverload(overloads.ContainsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.EndsWith,
decls.NewInstanceOverload(overloads.EndsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.Matches,
decls.NewInstanceOverload(overloads.MatchesString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
decls.NewFunction(overloads.StartsWith,
decls.NewInstanceOverload(overloads.StartsWithString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool)),
// Date/time functions.
decls.NewFunction(overloads.TimeGetFullYear,
decls.NewInstanceOverload(overloads.TimestampToYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetMonth,
decls.NewInstanceOverload(overloads.TimestampToMonth,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMonthWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfYear,
decls.NewInstanceOverload(overloads.TimestampToDayOfYear,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfYearWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfMonth,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDate,
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBased,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfMonthOneBasedWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetDayOfWeek,
decls.NewInstanceOverload(overloads.TimestampToDayOfWeek,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToDayOfWeekWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int)),
decls.NewFunction(overloads.TimeGetHours,
decls.NewInstanceOverload(overloads.TimestampToHours,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToHoursWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToHours,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMinutes,
decls.NewInstanceOverload(overloads.TimestampToMinutes,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMinutesWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMinutes,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetSeconds,
decls.NewInstanceOverload(overloads.TimestampToSeconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToSecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToSeconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
decls.NewFunction(overloads.TimeGetMilliseconds,
decls.NewInstanceOverload(overloads.TimestampToMilliseconds,
[]*exprpb.Type{decls.Timestamp}, decls.Int),
decls.NewInstanceOverload(overloads.TimestampToMillisecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMilliseconds,
[]*exprpb.Type{decls.Duration}, decls.Int)),
// Relations.
decls.NewFunction(operators.Less,
decls.NewOverload(overloads.LessBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.LessEquals,
decls.NewOverload(overloads.LessEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Greater,
decls.NewOverload(overloads.GreaterBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.GreaterEquals,
decls.NewOverload(overloads.GreaterEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
}...)
// StandardFunctions returns the Decls for all functions in the evaluator.
//
// Deprecated: prefer stdlib.FunctionExprDecls()
func StandardFunctions() []*exprpb.Decl {
return stdlib.FunctionExprDecls()
}
// StandardDeclarations returns the Decls for all functions and constants in the evaluator.
func StandardDeclarations() []*exprpb.Decl {
return standardDeclarations
// StandardTypes returns the set of type identifiers for standard library types.
//
// Deprecated: prefer stdlib.TypeExprDecls()
func StandardTypes() []*exprpb.Decl {
return stdlib.TypeExprDecls()
}

View File

@@ -15,119 +15,54 @@
package checker
import (
"fmt"
"strings"
"github.com/google/cel-go/checker/decls"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
)
const (
kindUnknown = iota + 1
kindError
kindFunction
kindDyn
kindPrimitive
kindWellKnown
kindWrapper
kindNull
kindAbstract
kindType
kindList
kindMap
kindObject
kindTypeParam
)
// FormatCheckedType converts a type message into a string representation.
func FormatCheckedType(t *exprpb.Type) string {
switch kindOf(t) {
case kindDyn:
return "dyn"
case kindFunction:
return formatFunction(t.GetFunction().GetResultType(),
t.GetFunction().GetArgTypes(),
false)
case kindList:
return fmt.Sprintf("list(%s)", FormatCheckedType(t.GetListType().GetElemType()))
case kindObject:
return t.GetMessageType()
case kindMap:
return fmt.Sprintf("map(%s, %s)",
FormatCheckedType(t.GetMapType().GetKeyType()),
FormatCheckedType(t.GetMapType().GetValueType()))
case kindNull:
return "null"
case kindPrimitive:
switch t.GetPrimitive() {
case exprpb.Type_UINT64:
return "uint"
case exprpb.Type_INT64:
return "int"
}
return strings.Trim(strings.ToLower(t.GetPrimitive().String()), " ")
case kindType:
if t.GetType() == nil {
return "type"
}
return fmt.Sprintf("type(%s)", FormatCheckedType(t.GetType()))
case kindWellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return "any"
case exprpb.Type_DURATION:
return "duration"
case exprpb.Type_TIMESTAMP:
return "timestamp"
}
case kindWrapper:
return fmt.Sprintf("wrapper(%s)",
FormatCheckedType(decls.NewPrimitiveType(t.GetWrapper())))
case kindError:
return "!error!"
case kindTypeParam:
return t.GetTypeParam()
}
return t.String()
}
// isDyn returns true if the input t is either type DYN or a well-known ANY message.
func isDyn(t *exprpb.Type) bool {
func isDyn(t *types.Type) bool {
// Note: object type values that are well-known and map to a DYN value in practice
// are sanitized prior to being added to the environment.
switch kindOf(t) {
case kindDyn:
switch t.Kind() {
case types.DynKind, types.AnyKind:
return true
case kindWellKnown:
return t.GetWellKnown() == exprpb.Type_ANY
default:
return false
}
}
// isDynOrError returns true if the input is either an Error, DYN, or well-known ANY message.
func isDynOrError(t *exprpb.Type) bool {
switch kindOf(t) {
case kindError:
return true
default:
return isDyn(t)
func isDynOrError(t *types.Type) bool {
return isError(t) || isDyn(t)
}
func isError(t *types.Type) bool {
return t.Kind() == types.ErrorKind
}
func isOptional(t *types.Type) bool {
if t.Kind() == types.OpaqueKind {
return t.TypeName() == "optional"
}
return false
}
func maybeUnwrapOptional(t *types.Type) (*types.Type, bool) {
if isOptional(t) {
return t.Parameters()[0], true
}
return t, false
}
// isEqualOrLessSpecific checks whether one type is equal or less specific than the other one.
// A type is less specific if it matches the other type using the DYN type.
func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
kind1, kind2 := kindOf(t1), kindOf(t2)
func isEqualOrLessSpecific(t1, t2 *types.Type) bool {
kind1, kind2 := t1.Kind(), t2.Kind()
// The first type is less specific.
if isDyn(t1) || kind1 == kindTypeParam {
if isDyn(t1) || kind1 == types.TypeParamKind {
return true
}
// The first type is not less specific.
if isDyn(t2) || kind2 == kindTypeParam {
if isDyn(t2) || kind2 == types.TypeParamKind {
return false
}
// Types must be of the same kind to be equal.
@@ -138,38 +73,34 @@ func isEqualOrLessSpecific(t1 *exprpb.Type, t2 *exprpb.Type) bool {
// With limited exceptions for ANY and JSON values, the types must agree and be equivalent in
// order to return true.
switch kind1 {
case kindAbstract:
a1 := t1.GetAbstractType()
a2 := t2.GetAbstractType()
if a1.GetName() != a2.GetName() ||
len(a1.GetParameterTypes()) != len(a2.GetParameterTypes()) {
case types.OpaqueKind:
if t1.TypeName() != t2.TypeName() ||
len(t1.Parameters()) != len(t2.Parameters()) {
return false
}
for i, p1 := range a1.GetParameterTypes() {
if !isEqualOrLessSpecific(p1, a2.GetParameterTypes()[i]) {
for i, p1 := range t1.Parameters() {
if !isEqualOrLessSpecific(p1, t2.Parameters()[i]) {
return false
}
}
return true
case kindList:
return isEqualOrLessSpecific(t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
m1 := t1.GetMapType()
m2 := t2.GetMapType()
return isEqualOrLessSpecific(m1.GetKeyType(), m2.GetKeyType()) &&
isEqualOrLessSpecific(m1.GetValueType(), m2.GetValueType())
case kindType:
case types.ListKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0])
case types.MapKind:
return isEqualOrLessSpecific(t1.Parameters()[0], t2.Parameters()[0]) &&
isEqualOrLessSpecific(t1.Parameters()[1], t2.Parameters()[1])
case types.TypeKind:
return true
default:
return proto.Equal(t1, t2)
return t1.IsExactType(t2)
}
}
// / internalIsAssignable returns true if t1 is assignable to t2.
func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
func internalIsAssignable(m *mapping, t1, t2 *types.Type) bool {
// Process type parameters.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind2 == kindTypeParam {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind2 == types.TypeParamKind {
// If t2 is a valid type substitution for t1, return true.
valid, t2HasSub := isValidTypeSubstitution(m, t1, t2)
if valid {
@@ -182,7 +113,7 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
}
// Otherwise, fall through to check whether t1 is a possible substitution for t2.
}
if kind1 == kindTypeParam {
if kind1 == types.TypeParamKind {
// Return whether t1 is a valid substitution for t2. If not, do no additional checks as the
// possible type substitutions have been searched in both directions.
valid, _ := isValidTypeSubstitution(m, t2, t1)
@@ -193,40 +124,25 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
if isDynOrError(t1) || isDynOrError(t2) {
return true
}
// Preserve the nullness checks of the legacy type-checker.
if kind1 == types.NullTypeKind {
return internalIsAssignableNull(t2)
}
if kind2 == types.NullTypeKind {
return internalIsAssignableNull(t1)
}
// Test for when the types do not need to agree, but are more specific than dyn.
switch kind1 {
case kindNull:
return internalIsAssignableNull(t2)
case kindPrimitive:
return internalIsAssignablePrimitive(t1.GetPrimitive(), t2)
case kindWrapper:
return internalIsAssignable(m, decls.NewPrimitiveType(t1.GetWrapper()), t2)
default:
if kind1 != kind2 {
return false
}
}
// Test for when the types must agree.
switch kind1 {
// ERROR, TYPE_PARAM, and DYN handled above.
case kindAbstract:
return internalIsAssignableAbstractType(m, t1.GetAbstractType(), t2.GetAbstractType())
case kindFunction:
return internalIsAssignableFunction(m, t1.GetFunction(), t2.GetFunction())
case kindList:
return internalIsAssignable(m, t1.GetListType().GetElemType(), t2.GetListType().GetElemType())
case kindMap:
return internalIsAssignableMap(m, t1.GetMapType(), t2.GetMapType())
case kindObject:
return t1.GetMessageType() == t2.GetMessageType()
case kindType:
// A type is a type is a type, any additional parameterization of the
// type cannot affect method resolution or assignability.
return true
case kindWellKnown:
return t1.GetWellKnown() == t2.GetWellKnown()
case types.BoolKind, types.BytesKind, types.DoubleKind, types.IntKind, types.StringKind, types.UintKind,
types.AnyKind, types.DurationKind, types.TimestampKind,
types.StructKind:
return t1.IsAssignableType(t2)
case types.TypeKind:
return kind2 == types.TypeKind
case types.OpaqueKind, types.ListKind, types.MapKind:
return t1.Kind() == t2.Kind() && t1.TypeName() == t2.TypeName() &&
internalIsAssignableList(m, t1.Parameters(), t2.Parameters())
default:
return false
}
@@ -236,19 +152,19 @@ func internalIsAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) bool {
// substitution for t1, and whether t2 has a type substitution in mapping m.
//
// The type t2 is a valid substitution for t1 if any of the following statements is true
// - t2 has a type substitition (t2sub) equal to t1
// - t2 has a type substitution (t2sub) equal to t1
// - t2 has a type substitution (t2sub) assignable to t1
// - t2 does not occur within t1.
func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub bool) {
func isValidTypeSubstitution(m *mapping, t1, t2 *types.Type) (valid, hasSub bool) {
// Early return if the t1 and t2 are the same instance.
kind1, kind2 := kindOf(t1), kindOf(t2)
if kind1 == kind2 && (t1 == t2 || proto.Equal(t1, t2)) {
kind1, kind2 := t1.Kind(), t2.Kind()
if kind1 == kind2 && t1.IsExactType(t2) {
return true, true
}
if t2Sub, found := m.find(t2); found {
// Early return if t1 and t2Sub are the same instance as otherwise the mapping
// might mark a type as being a subtitution for itself.
if kind1 == kindOf(t2Sub) && (t1 == t2Sub || proto.Equal(t1, t2Sub)) {
if kind1 == t2Sub.Kind() && t1.IsExactType(t2Sub) {
return true, true
}
// If the types are compatible, pick the more general type and return true
@@ -270,28 +186,10 @@ func isValidTypeSubstitution(m *mapping, t1, t2 *exprpb.Type) (valid, hasSub boo
return false, false
}
// internalIsAssignableAbstractType returns true if the abstract type names agree and all type
// parameters are assignable.
func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, a2 *exprpb.Type_AbstractType) bool {
return a1.GetName() == a2.GetName() &&
internalIsAssignableList(m, a1.GetParameterTypes(), a2.GetParameterTypes())
}
// internalIsAssignableFunction returns true if the function return type and arg types are
// assignable.
func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool {
f1ArgTypes := flattenFunctionTypes(f1)
f2ArgTypes := flattenFunctionTypes(f2)
if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) {
return true
}
return false
}
// internalIsAssignableList returns true if the element types at each index in the list are
// assignable from l1[i] to l2[i]. The list lengths must also agree for the lists to be
// assignable.
func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) bool {
func internalIsAssignableList(m *mapping, l1, l2 []*types.Type) bool {
if len(l1) != len(l2) {
return false
}
@@ -303,41 +201,22 @@ func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type)
return true
}
// internalIsAssignableMap returns true if map m1 may be assigned to map m2.
func internalIsAssignableMap(m *mapping, m1 *exprpb.Type_MapType, m2 *exprpb.Type_MapType) bool {
if internalIsAssignableList(m,
[]*exprpb.Type{m1.GetKeyType(), m1.GetValueType()},
[]*exprpb.Type{m2.GetKeyType(), m2.GetValueType()}) {
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *types.Type) bool {
return isLegacyNullable(t) || t.IsAssignableType(types.NullType)
}
// isLegacyNullable preserves the null-ness compatibility of the original type-checker implementation.
func isLegacyNullable(t *types.Type) bool {
switch t.Kind() {
case types.OpaqueKind, types.StructKind, types.AnyKind, types.DurationKind, types.TimestampKind:
return true
}
return false
}
// internalIsAssignableNull returns true if the type is nullable.
func internalIsAssignableNull(t *exprpb.Type) bool {
switch kindOf(t) {
case kindAbstract, kindObject, kindNull, kindWellKnown, kindWrapper:
return true
default:
return false
}
}
// internalIsAssignablePrimitive returns true if the target type is the same or if it is a wrapper
// for the primitive type.
func internalIsAssignablePrimitive(p exprpb.Type_PrimitiveType, target *exprpb.Type) bool {
switch kindOf(target) {
case kindPrimitive:
return p == target.GetPrimitive()
case kindWrapper:
return p == target.GetWrapper()
default:
return false
}
}
// isAssignable returns an updated type substitution mapping if t1 is assignable to t2.
func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
func isAssignable(m *mapping, t1, t2 *types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignable(mCopy, t1, t2) {
return mCopy
@@ -346,7 +225,7 @@ func isAssignable(m *mapping, t1 *exprpb.Type, t2 *exprpb.Type) *mapping {
}
// isAssignableList returns an updated type substitution mapping if l1 is assignable to l2.
func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping {
func isAssignableList(m *mapping, l1, l2 []*types.Type) *mapping {
mCopy := m.copy()
if internalIsAssignableList(mCopy, l1, l2) {
return mCopy
@@ -354,44 +233,8 @@ func isAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) *mapping
return nil
}
// kindOf returns the kind of the type as defined in the checked.proto.
func kindOf(t *exprpb.Type) int {
if t == nil || t.TypeKind == nil {
return kindUnknown
}
switch t.GetTypeKind().(type) {
case *exprpb.Type_Error:
return kindError
case *exprpb.Type_Function:
return kindFunction
case *exprpb.Type_Dyn:
return kindDyn
case *exprpb.Type_Primitive:
return kindPrimitive
case *exprpb.Type_WellKnown:
return kindWellKnown
case *exprpb.Type_Wrapper:
return kindWrapper
case *exprpb.Type_Null:
return kindNull
case *exprpb.Type_Type:
return kindType
case *exprpb.Type_ListType_:
return kindList
case *exprpb.Type_MapType_:
return kindMap
case *exprpb.Type_MessageType:
return kindObject
case *exprpb.Type_TypeParam:
return kindTypeParam
case *exprpb.Type_AbstractType_:
return kindAbstract
}
return kindUnknown
}
// mostGeneral returns the more general of two types which are known to unify.
func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
func mostGeneral(t1, t2 *types.Type) *types.Type {
if isEqualOrLessSpecific(t1, t2) {
return t1
}
@@ -401,32 +244,25 @@ func mostGeneral(t1 *exprpb.Type, t2 *exprpb.Type) *exprpb.Type {
// notReferencedIn checks whether the type doesn't appear directly or transitively within the other
// type. This is a standard requirement for type unification, commonly referred to as the "occurs
// check".
func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
if proto.Equal(t, withinType) {
func notReferencedIn(m *mapping, t, withinType *types.Type) bool {
if t.IsExactType(withinType) {
return false
}
withinKind := kindOf(withinType)
withinKind := withinType.Kind()
switch withinKind {
case kindTypeParam:
case types.TypeParamKind:
wtSub, found := m.find(withinType)
if !found {
return true
}
return notReferencedIn(m, t, wtSub)
case kindAbstract:
for _, pt := range withinType.GetAbstractType().GetParameterTypes() {
case types.OpaqueKind, types.ListKind, types.MapKind:
for _, pt := range withinType.Parameters() {
if !notReferencedIn(m, t, pt) {
return false
}
}
return true
case kindList:
return notReferencedIn(m, t, withinType.GetListType().GetElemType())
case kindMap:
mt := withinType.GetMapType()
return notReferencedIn(m, t, mt.GetKeyType()) && notReferencedIn(m, t, mt.GetValueType())
case kindWrapper:
return notReferencedIn(m, t, decls.NewPrimitiveType(withinType.GetWrapper()))
default:
return true
}
@@ -434,39 +270,25 @@ func notReferencedIn(m *mapping, t *exprpb.Type, withinType *exprpb.Type) bool {
// substitute replaces all direct and indirect occurrences of bound type parameters. Unbound type
// parameters are replaced by DYN if typeParamToDyn is true.
func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
func substitute(m *mapping, t *types.Type, typeParamToDyn bool) *types.Type {
if tSub, found := m.find(t); found {
return substitute(m, tSub, typeParamToDyn)
}
kind := kindOf(t)
if typeParamToDyn && kind == kindTypeParam {
return decls.Dyn
kind := t.Kind()
if typeParamToDyn && kind == types.TypeParamKind {
return types.DynType
}
switch kind {
case kindAbstract:
at := t.GetAbstractType()
params := make([]*exprpb.Type, len(at.GetParameterTypes()))
for i, p := range at.GetParameterTypes() {
params[i] = substitute(m, p, typeParamToDyn)
}
return decls.NewAbstractType(at.GetName(), params...)
case kindFunction:
fn := t.GetFunction()
rt := substitute(m, fn.ResultType, typeParamToDyn)
args := make([]*exprpb.Type, len(fn.GetArgTypes()))
for i, a := range fn.ArgTypes {
args[i] = substitute(m, a, typeParamToDyn)
}
return decls.NewFunctionType(rt, args...)
case kindList:
return decls.NewListType(substitute(m, t.GetListType().GetElemType(), typeParamToDyn))
case kindMap:
mt := t.GetMapType()
return decls.NewMapType(substitute(m, mt.GetKeyType(), typeParamToDyn),
substitute(m, mt.GetValueType(), typeParamToDyn))
case kindType:
if t.GetType() != nil {
return decls.NewTypeType(substitute(m, t.GetType(), typeParamToDyn))
case types.OpaqueKind:
return types.NewOpaqueType(t.TypeName(), substituteParams(m, t.Parameters(), typeParamToDyn)...)
case types.ListKind:
return types.NewListType(substitute(m, t.Parameters()[0], typeParamToDyn))
case types.MapKind:
return types.NewMapType(substitute(m, t.Parameters()[0], typeParamToDyn),
substitute(m, t.Parameters()[1], typeParamToDyn))
case types.TypeKind:
if len(t.Parameters()) > 0 {
return types.NewTypeTypeWithParam(substitute(m, t.Parameters()[0], typeParamToDyn))
}
return t
default:
@@ -474,21 +296,14 @@ func substitute(m *mapping, t *exprpb.Type, typeParamToDyn bool) *exprpb.Type {
}
}
func typeKey(t *exprpb.Type) string {
return FormatCheckedType(t)
func substituteParams(m *mapping, typeParams []*types.Type, typeParamToDyn bool) []*types.Type {
subParams := make([]*types.Type, len(typeParams))
for i, tp := range typeParams {
subParams[i] = substitute(m, tp, typeParamToDyn)
}
return subParams
}
// flattenFunctionTypes takes a function with arg types T1, T2, ..., TN and result type TR
// and returns a slice containing {T1, T2, ..., TN, TR}.
func flattenFunctionTypes(f *exprpb.Type_FunctionType) []*exprpb.Type {
argTypes := f.GetArgTypes()
if len(argTypes) == 0 {
return []*exprpb.Type{f.GetResultType()}
}
flattend := make([]*exprpb.Type, len(argTypes)+1, len(argTypes)+1)
for i, at := range argTypes {
flattend[i] = at
}
flattend[len(argTypes)] = f.GetResultType()
return flattend
func newFunctionType(resultType *types.Type, argTypes ...*types.Type) *types.Type {
return types.NewOpaqueType("function", append([]*types.Type{resultType}, argTypes...)...)
}

View File

@@ -17,7 +17,7 @@ go_library(
importpath = "github.com/google/cel-go/common",
deps = [
"//common/runes:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_x_text//width:go_default_library",
],
)

52
vendor/github.com/google/cel-go/common/ast/BUILD.bazel generated vendored Normal file
View File

@@ -0,0 +1,52 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = [
"//cel:__subpackages__",
"//checker:__subpackages__",
"//common:__subpackages__",
"//interpreter:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"ast.go",
"expr.go",
],
importpath = "github.com/google/cel-go/common/ast",
deps = [
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"ast_test.go",
"expr_test.go",
],
embed = [
":go_default_library",
],
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common:go_default_library",
"//common/containers:go_default_library",
"//common/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//parser:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

226
vendor/github.com/google/cel-go/common/ast/ast.go generated vendored Normal file
View File

@@ -0,0 +1,226 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ast declares data structures useful for parsed and checked abstract syntax trees
package ast
import (
"fmt"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
structpb "google.golang.org/protobuf/types/known/structpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information.
type CheckedAST struct {
Expr *exprpb.Expr
SourceInfo *exprpb.SourceInfo
TypeMap map[int64]*types.Type
ReferenceMap map[int64]*ReferenceInfo
}
// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf.
func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) {
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap))
for id, ref := range ast.ReferenceMap {
r, err := ReferenceInfoToReferenceExpr(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap))
for id, typ := range ast.TypeMap {
t, err := types.TypeToExprType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &exprpb.CheckedExpr{
Expr: ast.Expr,
SourceInfo: ast.SourceInfo,
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance.
func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) {
refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap()))
for id, ref := range checked.GetReferenceMap() {
r, err := ReferenceExprToReferenceInfo(ref)
if err != nil {
return nil, err
}
refMap[id] = r
}
typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap()))
for id, typ := range checked.GetTypeMap() {
t, err := types.ExprTypeToType(typ)
if err != nil {
return nil, err
}
typeMap[id] = t
}
return &CheckedAST{
Expr: checked.GetExpr(),
SourceInfo: checked.GetSourceInfo(),
ReferenceMap: refMap,
TypeMap: typeMap,
}, nil
}
// ReferenceInfo contains a CEL native representation of an identifier reference which may refer to
// either a qualified identifier name, a set of overload ids, or a constant value from an enum.
type ReferenceInfo struct {
Name string
OverloadIDs []string
Value ref.Val
}
// NewIdentReference creates a ReferenceInfo instance for an identifier with an optional constant value.
func NewIdentReference(name string, value ref.Val) *ReferenceInfo {
return &ReferenceInfo{Name: name, Value: value}
}
// NewFunctionReference creates a ReferenceInfo instance for a set of function overloads.
func NewFunctionReference(overloads ...string) *ReferenceInfo {
info := &ReferenceInfo{}
for _, id := range overloads {
info.AddOverload(id)
}
return info
}
// AddOverload appends a function overload ID to the ReferenceInfo.
func (r *ReferenceInfo) AddOverload(overloadID string) {
for _, id := range r.OverloadIDs {
if id == overloadID {
return
}
}
r.OverloadIDs = append(r.OverloadIDs, overloadID)
}
// Equals returns whether two references are identical to each other.
func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
if r.Name != other.Name {
return false
}
if len(r.OverloadIDs) != len(other.OverloadIDs) {
return false
}
if len(r.OverloadIDs) != 0 {
overloadMap := make(map[string]struct{}, len(r.OverloadIDs))
for _, id := range r.OverloadIDs {
overloadMap[id] = struct{}{}
}
for _, id := range other.OverloadIDs {
_, found := overloadMap[id]
if !found {
return false
}
}
}
if r.Value == nil && other.Value == nil {
return true
}
if r.Value == nil && other.Value != nil ||
r.Value != nil && other.Value == nil ||
r.Value.Equal(other.Value) != types.True {
return false
}
return true
}
// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization.
func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) {
c, err := ValToConstant(info.Value)
if err != nil {
return nil, err
}
return &exprpb.Reference{
Name: info.Name,
OverloadId: info.OverloadIDs,
Value: c,
}, nil
}
// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance.
func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) {
v, err := ConstantToVal(ref.GetValue())
if err != nil {
return nil, err
}
return &ReferenceInfo{
Name: ref.GetName(),
OverloadIDs: ref.GetOverloadId(),
Value: v,
}, nil
}
// ValToConstant converts a CEL-native ref.Val to a protobuf Constant.
//
// Only simple scalar types are supported by this method.
func ValToConstant(v ref.Val) (*exprpb.Constant, error) {
if v == nil {
return nil, nil
}
switch v.Type() {
case types.BoolType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil
case types.NullType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil
case types.StringType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil
case types.UintType:
return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", v.Type())
}
// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val.
func ConstantToVal(c *exprpb.Constant) (ref.Val, error) {
if c == nil {
return nil, nil
}
switch c.GetConstantKind().(type) {
case *exprpb.Constant_BoolValue:
return types.Bool(c.GetBoolValue()), nil
case *exprpb.Constant_BytesValue:
return types.Bytes(c.GetBytesValue()), nil
case *exprpb.Constant_DoubleValue:
return types.Double(c.GetDoubleValue()), nil
case *exprpb.Constant_Int64Value:
return types.Int(c.GetInt64Value()), nil
case *exprpb.Constant_NullValue:
return types.NullValue, nil
case *exprpb.Constant_StringValue:
return types.String(c.GetStringValue()), nil
case *exprpb.Constant_Uint64Value:
return types.Uint(c.GetUint64Value()), nil
}
return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind())
}

709
vendor/github.com/google/cel-go/common/ast/expr.go generated vendored Normal file
View File

@@ -0,0 +1,709 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// ExprKind represents the expression node kind.
type ExprKind int
const (
// UnspecifiedKind represents an unset expression with no specified properties.
UnspecifiedKind ExprKind = iota
// LiteralKind represents a primitive scalar literal.
LiteralKind
// IdentKind represents a simple variable, constant, or type identifier.
IdentKind
// SelectKind represents a field selection expression.
SelectKind
// CallKind represents a function call.
CallKind
// ListKind represents a list literal expression.
ListKind
// MapKind represents a map literal expression.
MapKind
// StructKind represents a struct literal expression.
StructKind
// ComprehensionKind represents a comprehension expression generated by a macro.
ComprehensionKind
)
// NavigateCheckedAST converts a CheckedAST to a NavigableExpr
func NavigateCheckedAST(ast *CheckedAST) NavigableExpr {
return newNavigableExpr(nil, ast.Expr, ast.TypeMap)
}
// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match.
//
// This function type should be use with the `Match` and `MatchList` calls.
type ExprMatcher func(NavigableExpr) bool
// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr
// is comprised of all constant values, such as a simple literal or even list and map literal.
func ConstantValueMatcher() ExprMatcher {
return matchIsConstantValue
}
// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches
// the specified `kind`.
func KindMatcher(kind ExprKind) ExprMatcher {
return func(e NavigableExpr) bool {
return e.Kind() == kind
}
}
// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose
// function name is equal to `funcName`.
func FunctionMatcher(funcName string) ExprMatcher {
return func(e NavigableExpr) bool {
if e.Kind() != CallKind {
return false
}
return e.AsCall().FunctionName() == funcName
}
}
// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list.
//
// Such a result would work well with subsequent MatchList calls.
func AllMatcher() ExprMatcher {
return func(NavigableExpr) bool {
return true
}
}
// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the
// descendants which match.
func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr {
return matchListInternal([]NavigableExpr{expr}, matcher, true)
}
// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a
// subset of NavigableExpr values which match.
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr {
visit := make([]NavigableExpr, len(exprs))
copy(visit, exprs)
return matchListInternal(visit, matcher, false)
}
func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr {
var matched []NavigableExpr
for len(visit) != 0 {
e := visit[0]
if matcher(e) {
matched = append(matched, e)
}
if visitDescendants {
visit = append(visit[1:], e.Children()...)
} else {
visit = visit[1:]
}
}
return matched
}
func matchIsConstantValue(e NavigableExpr) bool {
if e.Kind() == LiteralKind {
return true
}
if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind {
for _, child := range e.Children() {
if !matchIsConstantValue(child) {
return false
}
}
return true
}
return false
}
// NavigableExpr represents the base navigable expression value.
//
// Depending on the `Kind()` value, the NavigableExpr may be converted to a concrete expression types
// as indicated by the `As<Kind>` methods.
//
// NavigableExpr values and their concrete expression types should be nil-safe. Conversion of an expr
// to the wrong kind should produce a nil value.
type NavigableExpr interface {
// ID of the expression as it appears in the AST
ID() int64
// Kind of the expression node. See ExprKind for the valid enum values.
Kind() ExprKind
// Type of the expression node.
Type() *types.Type
// Parent returns the parent expression node, if one exists.
Parent() (NavigableExpr, bool)
// Children returns a list of child expression nodes.
Children() []NavigableExpr
// ToExpr adapts this NavigableExpr to a protobuf representation.
ToExpr() *exprpb.Expr
// AsCall adapts the expr into a NavigableCallExpr
//
// The Kind() must be equal to a CallKind for the conversion to be well-defined.
AsCall() NavigableCallExpr
// AsComprehension adapts the expr into a NavigableComprehensionExpr.
//
// The Kind() must be equal to a ComprehensionKind for the conversion to be well-defined.
AsComprehension() NavigableComprehensionExpr
// AsIdent adapts the expr into an identifier string.
//
// The Kind() must be equal to an IdentKind for the conversion to be well-defined.
AsIdent() string
// AsLiteral adapts the expr into a constant ref.Val.
//
// The Kind() must be equal to a LiteralKind for the conversion to be well-defined.
AsLiteral() ref.Val
// AsList adapts the expr into a NavigableListExpr.
//
// The Kind() must be equal to a ListKind for the conversion to be well-defined.
AsList() NavigableListExpr
// AsMap adapts the expr into a NavigableMapExpr.
//
// The Kind() must be equal to a MapKind for the conversion to be well-defined.
AsMap() NavigableMapExpr
// AsSelect adapts the expr into a NavigableSelectExpr.
//
// The Kind() must be equal to a SelectKind for the conversion to be well-defined.
AsSelect() NavigableSelectExpr
// AsStruct adapts the expr into a NavigableStructExpr.
//
// The Kind() must be equal to a StructKind for the conversion to be well-defined.
AsStruct() NavigableStructExpr
// marker interface method
isNavigable()
}
// NavigableCallExpr defines an interface for inspecting a function call and its arugments.
type NavigableCallExpr interface {
// FunctionName returns the name of the function.
FunctionName() string
// Target returns the target of the expression if one is present.
Target() NavigableExpr
// Args returns the list of call arguments, excluding the target.
Args() []NavigableExpr
// ReturnType returns the result type of the call.
ReturnType() *types.Type
// marker interface method
isNavigable()
}
// NavigableListExpr defines an interface for inspecting a list literal expression.
type NavigableListExpr interface {
// Elements returns the list elements as navigable expressions.
Elements() []NavigableExpr
// OptionalIndicies returns the list of optional indices in the list literal.
OptionalIndices() []int32
// Size returns the number of elements in the list.
Size() int
// marker interface method
isNavigable()
}
// NavigableSelectExpr defines an interface for inspecting a select expression.
type NavigableSelectExpr interface {
// Operand returns the selection operand expression.
Operand() NavigableExpr
// FieldName returns the field name being selected from the operand.
FieldName() string
// IsTestOnly indicates whether the select expression is a presence test generated by a macro.
IsTestOnly() bool
// marker interface method
isNavigable()
}
// NavigableMapExpr defines an interface for inspecting a map expression.
type NavigableMapExpr interface {
// Entries returns the map key value pairs as NavigableEntry values.
Entries() []NavigableEntry
// Size returns the number of entries in the map.
Size() int
// marker interface method
isNavigable()
}
// NavigableEntry defines an interface for inspecting a map entry.
type NavigableEntry interface {
// Key returns the map entry key expression.
Key() NavigableExpr
// Value returns the map entry value expression.
Value() NavigableExpr
// IsOptional returns whether the entry is optional.
IsOptional() bool
// marker interface method
isNavigable()
}
// NavigableStructExpr defines an interfaces for inspecting a struct and its field initializers.
type NavigableStructExpr interface {
// TypeName returns the struct type name.
TypeName() string
// Fields returns the set of field initializers in the struct expression as NavigableField values.
Fields() []NavigableField
// marker interface method
isNavigable()
}
// NavigableField defines an interface for inspecting a struct field initialization.
type NavigableField interface {
// FieldName returns the name of the field.
FieldName() string
// Value returns the field initialization expression.
Value() NavigableExpr
// IsOptional returns whether the field is optional.
IsOptional() bool
// marker interface method
isNavigable()
}
// NavigableComprehensionExpr defines an interface for inspecting a comprehension expression.
type NavigableComprehensionExpr interface {
// IterRange returns the iteration range expression.
IterRange() NavigableExpr
// IterVar returns the iteration variable name.
IterVar() string
// AccuVar returns the accumulation variable name.
AccuVar() string
// AccuInit returns the accumulation variable initialization expression.
AccuInit() NavigableExpr
// LoopCondition returns the loop condition expression.
LoopCondition() NavigableExpr
// LoopStep returns the loop step expression.
LoopStep() NavigableExpr
// Result returns the comprehension result expression.
Result() NavigableExpr
// marker interface method
isNavigable()
}
func newNavigableExpr(parent NavigableExpr, expr *exprpb.Expr, typeMap map[int64]*types.Type) NavigableExpr {
kind, factory := kindOf(expr)
nav := &navigableExprImpl{
parent: parent,
kind: kind,
expr: expr,
typeMap: typeMap,
createChildren: factory,
}
return nav
}
type navigableExprImpl struct {
parent NavigableExpr
kind ExprKind
expr *exprpb.Expr
typeMap map[int64]*types.Type
createChildren childFactory
}
func (nav *navigableExprImpl) ID() int64 {
return nav.ToExpr().GetId()
}
func (nav *navigableExprImpl) Kind() ExprKind {
return nav.kind
}
func (nav *navigableExprImpl) Type() *types.Type {
if t, found := nav.typeMap[nav.ID()]; found {
return t
}
return types.DynType
}
func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) {
if nav.parent != nil {
return nav.parent, true
}
return nil, false
}
func (nav *navigableExprImpl) Children() []NavigableExpr {
return nav.createChildren(nav)
}
func (nav *navigableExprImpl) ToExpr() *exprpb.Expr {
return nav.expr
}
func (nav *navigableExprImpl) AsCall() NavigableCallExpr {
return navigableCallImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsComprehension() NavigableComprehensionExpr {
return navigableComprehensionImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsIdent() string {
return nav.ToExpr().GetIdentExpr().GetName()
}
func (nav *navigableExprImpl) AsLiteral() ref.Val {
if nav.Kind() != LiteralKind {
return nil
}
val, err := ConstantToVal(nav.ToExpr().GetConstExpr())
if err != nil {
panic(err)
}
return val
}
func (nav *navigableExprImpl) AsList() NavigableListExpr {
return navigableListImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsMap() NavigableMapExpr {
return navigableMapImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsSelect() NavigableSelectExpr {
return navigableSelectImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) AsStruct() NavigableStructExpr {
return navigableStructImpl{navigableExprImpl: nav}
}
func (nav *navigableExprImpl) createChild(e *exprpb.Expr) NavigableExpr {
return newNavigableExpr(nav, e, nav.typeMap)
}
func (nav *navigableExprImpl) isNavigable() {}
type navigableCallImpl struct {
*navigableExprImpl
}
func (call navigableCallImpl) FunctionName() string {
return call.ToExpr().GetCallExpr().GetFunction()
}
func (call navigableCallImpl) Target() NavigableExpr {
t := call.ToExpr().GetCallExpr().GetTarget()
if t != nil {
return call.createChild(t)
}
return nil
}
func (call navigableCallImpl) Args() []NavigableExpr {
args := call.ToExpr().GetCallExpr().GetArgs()
navArgs := make([]NavigableExpr, len(args))
for i, a := range args {
navArgs[i] = call.createChild(a)
}
return navArgs
}
func (call navigableCallImpl) ReturnType() *types.Type {
return call.Type()
}
type navigableComprehensionImpl struct {
*navigableExprImpl
}
func (comp navigableComprehensionImpl) IterRange() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetIterRange())
}
func (comp navigableComprehensionImpl) IterVar() string {
return comp.ToExpr().GetComprehensionExpr().GetIterVar()
}
func (comp navigableComprehensionImpl) AccuVar() string {
return comp.ToExpr().GetComprehensionExpr().GetAccuVar()
}
func (comp navigableComprehensionImpl) AccuInit() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetAccuInit())
}
func (comp navigableComprehensionImpl) LoopCondition() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopCondition())
}
func (comp navigableComprehensionImpl) LoopStep() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopStep())
}
func (comp navigableComprehensionImpl) Result() NavigableExpr {
return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetResult())
}
type navigableListImpl struct {
*navigableExprImpl
}
func (l navigableListImpl) Elements() []NavigableExpr {
return l.Children()
}
func (l navigableListImpl) OptionalIndices() []int32 {
return l.ToExpr().GetListExpr().GetOptionalIndices()
}
func (l navigableListImpl) Size() int {
return len(l.ToExpr().GetListExpr().GetElements())
}
type navigableMapImpl struct {
*navigableExprImpl
}
func (m navigableMapImpl) Entries() []NavigableEntry {
mapExpr := m.ToExpr().GetStructExpr()
entries := make([]NavigableEntry, len(mapExpr.GetEntries()))
for i, e := range mapExpr.GetEntries() {
entries[i] = navigableEntryImpl{
key: m.createChild(e.GetMapKey()),
val: m.createChild(e.GetValue()),
isOpt: e.GetOptionalEntry(),
}
}
return entries
}
func (m navigableMapImpl) Size() int {
return len(m.ToExpr().GetStructExpr().GetEntries())
}
type navigableEntryImpl struct {
key NavigableExpr
val NavigableExpr
isOpt bool
}
func (e navigableEntryImpl) Key() NavigableExpr {
return e.key
}
func (e navigableEntryImpl) Value() NavigableExpr {
return e.val
}
func (e navigableEntryImpl) IsOptional() bool {
return e.isOpt
}
func (e navigableEntryImpl) isNavigable() {}
type navigableSelectImpl struct {
*navigableExprImpl
}
func (sel navigableSelectImpl) FieldName() string {
return sel.ToExpr().GetSelectExpr().GetField()
}
func (sel navigableSelectImpl) IsTestOnly() bool {
return sel.ToExpr().GetSelectExpr().GetTestOnly()
}
func (sel navigableSelectImpl) Operand() NavigableExpr {
return sel.createChild(sel.ToExpr().GetSelectExpr().GetOperand())
}
type navigableStructImpl struct {
*navigableExprImpl
}
func (s navigableStructImpl) TypeName() string {
return s.ToExpr().GetStructExpr().GetMessageName()
}
func (s navigableStructImpl) Fields() []NavigableField {
fieldInits := s.ToExpr().GetStructExpr().GetEntries()
fields := make([]NavigableField, len(fieldInits))
for i, f := range fieldInits {
fields[i] = navigableFieldImpl{
name: f.GetFieldKey(),
val: s.createChild(f.GetValue()),
isOpt: f.GetOptionalEntry(),
}
}
return fields
}
type navigableFieldImpl struct {
name string
val NavigableExpr
isOpt bool
}
func (f navigableFieldImpl) FieldName() string {
return f.name
}
func (f navigableFieldImpl) Value() NavigableExpr {
return f.val
}
func (f navigableFieldImpl) IsOptional() bool {
return f.isOpt
}
func (f navigableFieldImpl) isNavigable() {}
func kindOf(expr *exprpb.Expr) (ExprKind, childFactory) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
return LiteralKind, noopFactory
case *exprpb.Expr_IdentExpr:
return IdentKind, noopFactory
case *exprpb.Expr_SelectExpr:
return SelectKind, selectFactory
case *exprpb.Expr_CallExpr:
return CallKind, callArgFactory
case *exprpb.Expr_ListExpr:
return ListKind, listElemFactory
case *exprpb.Expr_StructExpr:
if expr.GetStructExpr().GetMessageName() != "" {
return StructKind, structEntryFactory
}
return MapKind, mapEntryFactory
case *exprpb.Expr_ComprehensionExpr:
return ComprehensionKind, comprehensionFactory
default:
return UnspecifiedKind, noopFactory
}
}
type childFactory func(*navigableExprImpl) []NavigableExpr
func noopFactory(*navigableExprImpl) []NavigableExpr {
return nil
}
func selectFactory(nav *navigableExprImpl) []NavigableExpr {
return []NavigableExpr{
nav.createChild(nav.ToExpr().GetSelectExpr().GetOperand()),
}
}
func callArgFactory(nav *navigableExprImpl) []NavigableExpr {
call := nav.ToExpr().GetCallExpr()
argCount := len(call.GetArgs())
if call.GetTarget() != nil {
argCount++
}
navExprs := make([]NavigableExpr, argCount)
i := 0
if call.GetTarget() != nil {
navExprs[i] = nav.createChild(call.GetTarget())
i++
}
for _, arg := range call.GetArgs() {
navExprs[i] = nav.createChild(arg)
i++
}
return navExprs
}
func listElemFactory(nav *navigableExprImpl) []NavigableExpr {
l := nav.ToExpr().GetListExpr()
navExprs := make([]NavigableExpr, len(l.GetElements()))
for i, e := range l.GetElements() {
navExprs[i] = nav.createChild(e)
}
return navExprs
}
func structEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.ToExpr().GetStructExpr()
entries := make([]NavigableExpr, len(s.GetEntries()))
for i, e := range s.GetEntries() {
entries[i] = nav.createChild(e.GetValue())
}
return entries
}
func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr {
s := nav.ToExpr().GetStructExpr()
entries := make([]NavigableExpr, len(s.GetEntries())*2)
j := 0
for _, e := range s.GetEntries() {
entries[j] = nav.createChild(e.GetMapKey())
entries[j+1] = nav.createChild(e.GetValue())
j += 2
}
return entries
}
func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr {
compre := nav.ToExpr().GetComprehensionExpr()
return []NavigableExpr{
nav.createChild(compre.GetIterRange()),
nav.createChild(compre.GetAccuInit()),
nav.createChild(compre.GetLoopCondition()),
nav.createChild(compre.GetLoopStep()),
nav.createChild(compre.GetResult()),
}
}

View File

@@ -12,7 +12,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/containers",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)
@@ -26,6 +26,6 @@ go_test(
":go_default_library",
],
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@@ -13,6 +13,6 @@ go_library(
importpath = "github.com/google/cel-go/common/debug",
deps = [
"//common:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@@ -29,7 +29,7 @@ import (
// representation of an expression.
type Adorner interface {
// GetMetadata for the input context.
GetMetadata(ctx interface{}) string
GetMetadata(ctx any) string
}
// Writer manages writing expressions to an internal string.
@@ -46,7 +46,7 @@ type emptyDebugAdorner struct {
var emptyAdorner Adorner = &emptyDebugAdorner{}
func (a *emptyDebugAdorner) GetMetadata(e interface{}) string {
func (a *emptyDebugAdorner) GetMetadata(e any) string {
return ""
}
@@ -170,6 +170,9 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.append(entry.GetFieldKey())
w.append(":")
w.Buffer(entry.GetValue())
@@ -191,6 +194,9 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) {
w.append(",")
w.appendLine()
}
if entry.GetOptionalEntry() {
w.append("?")
}
w.Buffer(entry.GetMapKey())
w.append(":")
w.Buffer(entry.GetValue())
@@ -269,7 +275,7 @@ func (w *debugWriter) append(s string) {
w.buffer.WriteString(s)
}
func (w *debugWriter) appendFormat(f string, args ...interface{}) {
func (w *debugWriter) appendFormat(f string, args ...any) {
w.append(fmt.Sprintf(f, args...))
}
@@ -280,7 +286,7 @@ func (w *debugWriter) doIndent() {
}
}
func (w *debugWriter) adorn(e interface{}) {
func (w *debugWriter) adorn(e any) {
w.append(w.adorner.GetMetadata(e))
}

View File

@@ -0,0 +1,39 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"decls.go",
],
importpath = "github.com/google/cel-go/common/decls",
deps = [
"//checker/decls:go_default_library",
"//common/functions:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"decls_test.go",
],
embed = [":go_default_library"],
deps = [
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)

844
vendor/github.com/google/cel-go/common/decls/decls.go generated vendored Normal file
View File

@@ -0,0 +1,844 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package decls contains function and variable declaration structs and helper methods.
package decls
import (
"fmt"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// NewFunction creates a new function declaration with a set of function options to configure overloads
// and function definitions (implementations).
//
// Functions are checked for name collisions and singleton redefinition.
func NewFunction(name string, opts ...FunctionOpt) (*FunctionDecl, error) {
fn := &FunctionDecl{
name: name,
overloads: map[string]*OverloadDecl{},
overloadOrdinals: []string{},
}
var err error
for _, opt := range opts {
fn, err = opt(fn)
if err != nil {
return nil, err
}
}
if len(fn.overloads) == 0 {
return nil, fmt.Errorf("function %s must have at least one overload", name)
}
return fn, nil
}
// FunctionDecl defines a function name, overload set, and optionally a singleton definition for all
// overload instances.
type FunctionDecl struct {
name string
// overloads associated with the function name.
overloads map[string]*OverloadDecl
// singleton implementation of the function for all overloads.
//
// If this option is set, an error will occur if any overloads specify a per-overload implementation
// or if another function with the same name attempts to redefine the singleton.
singleton *functions.Overload
// disableTypeGuards is a performance optimization to disable detailed runtime type checks which could
// add overhead on common operations. Setting this option true leaves error checks and argument checks
// intact.
disableTypeGuards bool
// state indicates that the binding should be provided as a declaration, as a runtime binding, or both.
state declarationState
// overloadOrdinals indicates the order in which the overload was declared.
overloadOrdinals []string
}
type declarationState int
const (
declarationStateUnset declarationState = iota
declarationDisabled
declarationEnabled
)
// Name returns the function name in human-readable terms, e.g. 'contains' of 'math.least'
func (f *FunctionDecl) Name() string {
if f == nil {
return ""
}
return f.name
}
// IsDeclarationDisabled indicates that the function implementation should be added to the dispatcher, but the
// declaration should not be exposed for use in expressions.
func (f *FunctionDecl) IsDeclarationDisabled() bool {
return f.state == declarationDisabled
}
// Merge combines an existing function declaration with another.
//
// If a function is extended, by say adding new overloads to an existing function, then it is merged with the
// prior definition of the function at which point its overloads must not collide with pre-existing overloads
// and its bindings (singleton, or per-overload) must not conflict with previous definitions either.
func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
if f == other {
return f, nil
}
if f.Name() != other.Name() {
return nil, fmt.Errorf("cannot merge unrelated functions. %s and %s", f.Name(), other.Name())
}
merged := &FunctionDecl{
name: f.Name(),
overloads: make(map[string]*OverloadDecl, len(f.overloads)),
singleton: f.singleton,
overloadOrdinals: make([]string, len(f.overloads)),
// if one function is expecting type-guards and the other is not, then they
// must not be disabled.
disableTypeGuards: f.disableTypeGuards && other.disableTypeGuards,
// default to the current functions declaration state.
state: f.state,
}
// If the other state indicates that the declaration should be explicitly enabled or
// disabled, then update the merged state with the most recent value.
if other.state != declarationStateUnset {
merged.state = other.state
}
// baseline copy of the overloads and their ordinals
copy(merged.overloadOrdinals, f.overloadOrdinals)
for oID, o := range f.overloads {
merged.overloads[oID] = o
}
// overloads and their ordinals are added from the left
for _, oID := range other.overloadOrdinals {
o := other.overloads[oID]
err := merged.AddOverload(o)
if err != nil {
return nil, fmt.Errorf("function declaration merge failed: %v", err)
}
}
if other.singleton != nil {
if merged.singleton != nil && merged.singleton != other.singleton {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
merged.singleton = other.singleton
}
return merged, nil
}
// AddOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as its
// difficult to compare functions by object identity.
func (f *FunctionDecl) AddOverload(overload *OverloadDecl) error {
if f == nil {
return fmt.Errorf("nil function cannot add overload: %s", overload.ID())
}
for oID, o := range f.overloads {
if oID != overload.ID() && o.SignatureOverlaps(overload) {
return fmt.Errorf("overload signature collision in function %s: %s collides with %s", f.Name(), oID, overload.ID())
}
if oID == overload.ID() {
if o.SignatureEquals(overload) && o.IsNonStrict() == overload.IsNonStrict() {
// Allow redefinition of an overload implementation so long as the signatures match.
f.overloads[oID] = overload
return nil
}
return fmt.Errorf("overload redefinition in function. %s: %s has multiple definitions", f.Name(), oID)
}
}
f.overloadOrdinals = append(f.overloadOrdinals, overload.ID())
f.overloads[overload.ID()] = overload
return nil
}
// OverloadDecls returns the overload declarations in the order in which they were declared.
func (f *FunctionDecl) OverloadDecls() []*OverloadDecl {
if f == nil {
return []*OverloadDecl{}
}
overloads := make([]*OverloadDecl, 0, len(f.overloads))
for _, oID := range f.overloadOrdinals {
overloads = append(overloads, f.overloads[oID])
}
return overloads
}
// Bindings produces a set of function bindings, if any are defined.
func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) {
if f == nil {
return []*functions.Overload{}, nil
}
overloads := []*functions.Overload{}
nonStrict := false
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
if o.hasBinding() {
overload := &functions.Overload{
Operator: o.ID(),
Unary: o.guardedUnaryOp(f.Name(), f.disableTypeGuards),
Binary: o.guardedBinaryOp(f.Name(), f.disableTypeGuards),
Function: o.guardedFunctionOp(f.Name(), f.disableTypeGuards),
OperandTrait: o.OperandTrait(),
NonStrict: o.IsNonStrict(),
}
overloads = append(overloads, overload)
nonStrict = nonStrict || o.IsNonStrict()
}
}
if f.singleton != nil {
if len(overloads) != 0 {
return nil, fmt.Errorf("singleton function incompatible with specialized overloads: %s", f.Name())
}
overloads = []*functions.Overload{
{
Operator: f.Name(),
Unary: f.singleton.Unary,
Binary: f.singleton.Binary,
Function: f.singleton.Function,
OperandTrait: f.singleton.OperandTrait,
},
}
// fall-through to return single overload case.
}
if len(overloads) == 0 {
return overloads, nil
}
// Single overload. Replicate an entry for it using the function name as well.
if len(overloads) == 1 {
if overloads[0].Operator == f.Name() {
return overloads, nil
}
return append(overloads, &functions.Overload{
Operator: f.Name(),
Unary: overloads[0].Unary,
Binary: overloads[0].Binary,
Function: overloads[0].Function,
NonStrict: overloads[0].NonStrict,
OperandTrait: overloads[0].OperandTrait,
}), nil
}
// All of the defined overloads are wrapped into a top-level function which
// performs dynamic dispatch to the proper overload based on the argument types.
bindings := append([]*functions.Overload{}, overloads...)
funcDispatch := func(args ...ref.Val) ref.Val {
for _, oID := range f.overloadOrdinals {
o := f.overloads[oID]
// During dynamic dispatch over multiple functions, signature agreement checks
// are preserved in order to assist with the function resolution step.
switch len(args) {
case 1:
if o.unaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.unaryOp(args[0])
}
case 2:
if o.binaryOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.binaryOp(args[0], args[1])
}
}
if o.functionOp != nil && o.matchesRuntimeSignature( /* disableTypeGuards=*/ false, args...) {
return o.functionOp(args...)
}
// eventually this will fall through to the noSuchOverload below.
}
return MaybeNoSuchOverload(f.Name(), args...)
}
function := &functions.Overload{
Operator: f.Name(),
Function: funcDispatch,
NonStrict: nonStrict,
}
return append(bindings, function), nil
}
// MaybeNoSuchOverload determines whether to propagate an error if one is provided as an argument, or
// to return an unknown set, or to produce a new error for a missing function signature.
func MaybeNoSuchOverload(funcName string, args ...ref.Val) ref.Val {
argTypes := make([]string, len(args))
var unk *types.Unknown = nil
for i, arg := range args {
if types.IsError(arg) {
return arg
}
if types.IsUnknown(arg) {
unk = types.MergeUnknowns(arg.(*types.Unknown), unk)
}
argTypes[i] = arg.Type().TypeName()
}
if unk != nil {
return unk
}
signature := strings.Join(argTypes, ", ")
return types.NewErr("no such overload: %s(%s)", funcName, signature)
}
// FunctionOpt defines a functional option for mutating a function declaration.
type FunctionOpt func(*FunctionDecl) (*FunctionDecl, error)
// DisableTypeGuards disables automatically generated function invocation guards on direct overload calls.
// Type guards remain on during dynamic dispatch for parsed-only expressions.
func DisableTypeGuards(value bool) FunctionOpt {
return func(fn *FunctionDecl) (*FunctionDecl, error) {
fn.disableTypeGuards = value
return fn, nil
}
}
// DisableDeclaration indicates that the function declaration should be disabled, but the runtime function
// binding should be provided. Marking a function as runtime-only is a safe way to manage deprecations
// of function declarations while still preserving the runtime behavior for previously compiled expressions.
func DisableDeclaration(value bool) FunctionOpt {
return func(fn *FunctionDecl) (*FunctionDecl, error) {
if value {
fn.state = declarationDisabled
} else {
fn.state = declarationEnabled
}
return fn, nil
}
}
// SingletonUnaryBinding creates a singleton function definition to be used for all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonUnaryBinding(fn functions.UnaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Unary: fn,
OperandTrait: trait,
}
return f, nil
}
}
// SingletonBinaryBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonBinaryBinding(fn functions.BinaryOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Binary: fn,
OperandTrait: trait,
}
return f, nil
}
}
// SingletonFunctionBinding creates a singleton function definition to be used with all function overloads.
//
// Note, this approach works well if operand is expected to have a specific trait which it implements,
// e.g. traits.ContainerType. Otherwise, prefer per-overload function bindings.
func SingletonFunctionBinding(fn functions.FunctionOp, traits ...int) FunctionOpt {
trait := 0
for _, t := range traits {
trait = trait | t
}
return func(f *FunctionDecl) (*FunctionDecl, error) {
if f.singleton != nil {
return nil, fmt.Errorf("function already has a singleton binding: %s", f.Name())
}
f.singleton = &functions.Overload{
Operator: f.Name(),
Function: fn,
OperandTrait: trait,
}
return f, nil
}
}
// Overload defines a new global overload with an overload id, argument types, and result type. Through the
// use of OverloadOpt options, the overload may also be configured with a binding, an operand trait, and to
// be non-strict.
//
// Note: function bindings should be commonly configured with Overload instances whereas operand traits and
// strict-ness should be rare occurrences.
func Overload(overloadID string,
args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return newOverload(overloadID, false, args, resultType, opts...)
}
// MemberOverload defines a new receiver-style overload (or member function) with an overload id, argument types,
// and result type. Through the use of OverloadOpt options, the overload may also be configured with a binding,
// an operand trait, and to be non-strict.
//
// Note: function bindings should be commonly configured with Overload instances whereas operand traits and
// strict-ness should be rare occurrences.
func MemberOverload(overloadID string,
args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return newOverload(overloadID, true, args, resultType, opts...)
}
func newOverload(overloadID string,
memberFunction bool, args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) FunctionOpt {
return func(f *FunctionDecl) (*FunctionDecl, error) {
overload, err := newOverloadInternal(overloadID, memberFunction, args, resultType, opts...)
if err != nil {
return nil, err
}
err = f.AddOverload(overload)
if err != nil {
return nil, err
}
return f, nil
}
}
func newOverloadInternal(overloadID string,
memberFunction bool, args []*types.Type, resultType *types.Type,
opts ...OverloadOpt) (*OverloadDecl, error) {
overload := &OverloadDecl{
id: overloadID,
argTypes: args,
resultType: resultType,
isMemberFunction: memberFunction,
}
var err error
for _, opt := range opts {
overload, err = opt(overload)
if err != nil {
return nil, err
}
}
return overload, nil
}
// OverloadDecl contains the definition of a single overload id with a specific signature, and an optional
// implementation.
type OverloadDecl struct {
id string
argTypes []*types.Type
resultType *types.Type
isMemberFunction bool
// nonStrict indicates that the function will accept error and unknown arguments as inputs.
nonStrict bool
// operandTrait indicates whether the member argument should have a specific type-trait.
//
// This is useful for creating overloads which operate on a type-interface rather than a concrete type.
operandTrait int
// Function implementation options. Optional, but encouraged.
// unaryOp is a function binding that takes a single argument.
unaryOp functions.UnaryOp
// binaryOp is a function binding that takes two arguments.
binaryOp functions.BinaryOp
// functionOp is a catch-all for zero-arity and three-plus arity functions.
functionOp functions.FunctionOp
}
// ID mirrors the overload signature and provides a unique id which may be referenced within the type-checker
// and interpreter to optimize performance.
//
// The ID format is usually one of two styles:
// global: <functionName>_<argType>_<argTypeN>
// member: <memberType>_<functionName>_<argType>_<argTypeN>
func (o *OverloadDecl) ID() string {
if o == nil {
return ""
}
return o.id
}
// ArgTypes contains the set of argument types expected by the overload.
//
// For member functions ArgTypes[0] represents the member operand type.
func (o *OverloadDecl) ArgTypes() []*types.Type {
if o == nil {
return emptyArgs
}
return o.argTypes
}
// IsMemberFunction indicates whether the overload is a member function
func (o *OverloadDecl) IsMemberFunction() bool {
if o == nil {
return false
}
return o.isMemberFunction
}
// IsNonStrict returns whether the overload accepts errors and unknown values as arguments.
func (o *OverloadDecl) IsNonStrict() bool {
if o == nil {
return false
}
return o.nonStrict
}
// OperandTrait returns the trait mask of the first operand to the overload call, e.g.
// `traits.Indexer`
func (o *OverloadDecl) OperandTrait() int {
if o == nil {
return 0
}
return o.operandTrait
}
// ResultType indicates the output type from calling the function.
func (o *OverloadDecl) ResultType() *types.Type {
if o == nil {
// *types.Type is nil-safe
return nil
}
return o.resultType
}
// TypeParams returns the type parameter names associated with the overload.
func (o *OverloadDecl) TypeParams() []string {
typeParams := map[string]struct{}{}
collectParamNames(typeParams, o.ResultType())
for _, arg := range o.ArgTypes() {
collectParamNames(typeParams, arg)
}
params := make([]string, 0, len(typeParams))
for param := range typeParams {
params = append(params, param)
}
return params
}
// SignatureEquals determines whether the incoming overload declaration signature is equal to the current signature.
//
// Result type, operand trait, and strict-ness are not considered as part of signature equality.
func (o *OverloadDecl) SignatureEquals(other *OverloadDecl) bool {
if o == other {
return true
}
if o.ID() != other.ID() || o.IsMemberFunction() != other.IsMemberFunction() || len(o.ArgTypes()) != len(other.ArgTypes()) {
return false
}
for i, at := range o.ArgTypes() {
oat := other.ArgTypes()[i]
if !at.IsEquivalentType(oat) {
return false
}
}
return o.ResultType().IsEquivalentType(other.ResultType())
}
// SignatureOverlaps indicates whether two functions have non-equal, but overloapping function signatures.
//
// For example, list(dyn) collides with list(string) since the 'dyn' type can contain a 'string' type.
func (o *OverloadDecl) SignatureOverlaps(other *OverloadDecl) bool {
if o.IsMemberFunction() != other.IsMemberFunction() || len(o.ArgTypes()) != len(other.ArgTypes()) {
return false
}
argsOverlap := true
for i, argType := range o.ArgTypes() {
otherArgType := other.ArgTypes()[i]
argsOverlap = argsOverlap &&
(argType.IsAssignableType(otherArgType) ||
otherArgType.IsAssignableType(argType))
}
return argsOverlap
}
// hasBinding indicates whether the overload already has a definition.
func (o *OverloadDecl) hasBinding() bool {
return o != nil && (o.unaryOp != nil || o.binaryOp != nil || o.functionOp != nil)
}
// guardedUnaryOp creates an invocation guard around the provided unary operator, if one is defined.
func (o *OverloadDecl) guardedUnaryOp(funcName string, disableTypeGuards bool) functions.UnaryOp {
if o.unaryOp == nil {
return nil
}
return func(arg ref.Val) ref.Val {
if !o.matchesRuntimeUnarySignature(disableTypeGuards, arg) {
return MaybeNoSuchOverload(funcName, arg)
}
return o.unaryOp(arg)
}
}
// guardedBinaryOp creates an invocation guard around the provided binary operator, if one is defined.
func (o *OverloadDecl) guardedBinaryOp(funcName string, disableTypeGuards bool) functions.BinaryOp {
if o.binaryOp == nil {
return nil
}
return func(arg1, arg2 ref.Val) ref.Val {
if !o.matchesRuntimeBinarySignature(disableTypeGuards, arg1, arg2) {
return MaybeNoSuchOverload(funcName, arg1, arg2)
}
return o.binaryOp(arg1, arg2)
}
}
// guardedFunctionOp creates an invocation guard around the provided variadic function binding, if one is provided.
func (o *OverloadDecl) guardedFunctionOp(funcName string, disableTypeGuards bool) functions.FunctionOp {
if o.functionOp == nil {
return nil
}
return func(args ...ref.Val) ref.Val {
if !o.matchesRuntimeSignature(disableTypeGuards, args...) {
return MaybeNoSuchOverload(funcName, args...)
}
return o.functionOp(args...)
}
}
// matchesRuntimeUnarySignature indicates whether the argument type is runtime assiganble to the overload's expected argument.
func (o *OverloadDecl) matchesRuntimeUnarySignature(disableTypeGuards bool, arg ref.Val) bool {
return matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[0], arg) &&
matchOperandTrait(o.OperandTrait(), arg)
}
// matchesRuntimeBinarySignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *OverloadDecl) matchesRuntimeBinarySignature(disableTypeGuards bool, arg1, arg2 ref.Val) bool {
return matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[0], arg1) &&
matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[1], arg2) &&
matchOperandTrait(o.OperandTrait(), arg1)
}
// matchesRuntimeSignature indicates whether the argument types are runtime assiganble to the overload's expected arguments.
func (o *OverloadDecl) matchesRuntimeSignature(disableTypeGuards bool, args ...ref.Val) bool {
if len(args) != len(o.ArgTypes()) {
return false
}
if len(args) == 0 {
return true
}
for i, arg := range args {
if !matchRuntimeArgType(o.IsNonStrict(), disableTypeGuards, o.ArgTypes()[i], arg) {
return false
}
}
return matchOperandTrait(o.OperandTrait(), args[0])
}
func matchRuntimeArgType(nonStrict, disableTypeGuards bool, argType *types.Type, arg ref.Val) bool {
if nonStrict && (disableTypeGuards || types.IsUnknownOrError(arg)) {
return true
}
if types.IsUnknownOrError(arg) {
return false
}
return disableTypeGuards || argType.IsAssignableRuntimeType(arg)
}
func matchOperandTrait(trait int, arg ref.Val) bool {
return trait == 0 || arg.Type().HasTrait(trait) || types.IsUnknownOrError(arg)
}
// OverloadOpt is a functional option for configuring a function overload.
type OverloadOpt func(*OverloadDecl) (*OverloadDecl, error)
// UnaryBinding provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if len(o.ArgTypes()) != 1 {
return nil, fmt.Errorf("unary function bound to non-unary overload: %s", o.ID())
}
o.unaryOp = binding
return o, nil
}
}
// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
if len(o.ArgTypes()) != 2 {
return nil, fmt.Errorf("binary function bound to non-binary overload: %s", o.ID())
}
o.binaryOp = binding
return o, nil
}
}
// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
if o.hasBinding() {
return nil, fmt.Errorf("overload already has a binding: %s", o.ID())
}
o.functionOp = binding
return o, nil
}
}
// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
func OverloadIsNonStrict() OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
o.nonStrict = true
return o, nil
}
}
// OverloadOperandTrait configures a set of traits which the first argument to the overload must implement in order to be
// successfully invoked.
func OverloadOperandTrait(trait int) OverloadOpt {
return func(o *OverloadDecl) (*OverloadDecl, error) {
o.operandTrait = trait
return o, nil
}
}
// NewConstant creates a new constant declaration.
func NewConstant(name string, t *types.Type, v ref.Val) *VariableDecl {
return &VariableDecl{name: name, varType: t, value: v}
}
// NewVariable creates a new variable declaration.
func NewVariable(name string, t *types.Type) *VariableDecl {
return &VariableDecl{name: name, varType: t}
}
// VariableDecl defines a variable declaration which may optionally have a constant value.
type VariableDecl struct {
name string
varType *types.Type
value ref.Val
}
// Name returns the fully-qualified variable name
func (v *VariableDecl) Name() string {
if v == nil {
return ""
}
return v.name
}
// Type returns the types.Type value associated with the variable.
func (v *VariableDecl) Type() *types.Type {
if v == nil {
// types.Type is nil-safe
return nil
}
return v.varType
}
// Value returns the constant value associated with the declaration.
func (v *VariableDecl) Value() ref.Val {
if v == nil {
return nil
}
return v.value
}
// DeclarationIsEquivalent returns true if one variable declaration has the same name and same type as the input.
func (v *VariableDecl) DeclarationIsEquivalent(other *VariableDecl) bool {
if v == other {
return true
}
return v.Name() == other.Name() && v.Type().IsEquivalentType(other.Type())
}
// VariableDeclToExprDecl converts a go-native variable declaration into a protobuf-type variable declaration.
func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
varType, err := types.TypeToExprType(v.Type())
if err != nil {
return nil, err
}
return chkdecls.NewVar(v.Name(), varType), nil
}
// TypeVariable creates a new type identifier for use within a types.Provider
func TypeVariable(t *types.Type) *VariableDecl {
return NewVariable(t.TypeName(), types.NewTypeTypeWithParam(t))
}
// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.overloads))
for i, oID := range f.overloadOrdinals {
o := f.overloads[oID]
paramNames := map[string]struct{}{}
argTypes := make([]*exprpb.Type, len(o.ArgTypes()))
for j, a := range o.ArgTypes() {
collectParamNames(paramNames, a)
at, err := types.TypeToExprType(a)
if err != nil {
return nil, err
}
argTypes[j] = at
}
collectParamNames(paramNames, o.ResultType())
resultType, err := types.TypeToExprType(o.ResultType())
if err != nil {
return nil, err
}
if len(paramNames) == 0 {
if o.IsMemberFunction() {
overloads[i] = chkdecls.NewInstanceOverload(oID, argTypes, resultType)
} else {
overloads[i] = chkdecls.NewOverload(oID, argTypes, resultType)
}
} else {
params := []string{}
for pn := range paramNames {
params = append(params, pn)
}
if o.IsMemberFunction() {
overloads[i] = chkdecls.NewParameterizedInstanceOverload(oID, argTypes, resultType, params)
} else {
overloads[i] = chkdecls.NewParameterizedOverload(oID, argTypes, resultType, params)
}
}
}
return chkdecls.NewFunction(f.Name(), overloads...), nil
}
func collectParamNames(paramNames map[string]struct{}, arg *types.Type) {
if arg.Kind() == types.TypeParamKind {
paramNames[arg.TypeName()] = struct{}{}
}
for _, param := range arg.Parameters() {
collectParamNames(paramNames, param)
}
}
var (
emptyArgs = []*types.Type{}
)

View File

@@ -22,10 +22,16 @@ import (
"golang.org/x/text/width"
)
// Error type which references a location within source and a message.
// NewError creates an error associated with an expression id with the given message at the given location.
func NewError(id int64, message string, location Location) *Error {
return &Error{Message: message, Location: location, ExprID: id}
}
// Error type which references an expression id, a location within source, and a message.
type Error struct {
Location Location
Message string
ExprID int64
}
const (

View File

@@ -22,7 +22,7 @@ import (
// Errors type which contains a list of errors observed during parsing.
type Errors struct {
errors []Error
errors []*Error
source Source
numErrors int
maxErrorsToReport int
@@ -31,19 +31,25 @@ type Errors struct {
// NewErrors creates a new instance of the Errors type.
func NewErrors(source Source) *Errors {
return &Errors{
errors: []Error{},
errors: []*Error{},
source: source,
maxErrorsToReport: 100,
}
}
// ReportError records an error at a source location.
func (e *Errors) ReportError(l Location, format string, args ...interface{}) {
func (e *Errors) ReportError(l Location, format string, args ...any) {
e.ReportErrorAtID(0, l, format, args...)
}
// ReportErrorAtID records an error at a source location and expression id.
func (e *Errors) ReportErrorAtID(id int64, l Location, format string, args ...any) {
e.numErrors++
if e.numErrors > e.maxErrorsToReport {
return
}
err := Error{
err := &Error{
ExprID: id,
Location: l,
Message: fmt.Sprintf(format, args...),
}
@@ -51,12 +57,12 @@ func (e *Errors) ReportError(l Location, format string, args ...interface{}) {
}
// GetErrors returns the list of observed errors.
func (e *Errors) GetErrors() []Error {
func (e *Errors) GetErrors() []*Error {
return e.errors[:]
}
// Append creates a new Errors object with the current and input errors.
func (e *Errors) Append(errs []Error) *Errors {
func (e *Errors) Append(errs []*Error) *Errors {
return &Errors{
errors: append(e.errors, errs...),
source: e.source,

View File

@@ -0,0 +1,17 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"functions.go",
],
importpath = "github.com/google/cel-go/common/functions",
deps = [
"//common/types/ref:go_default_library",
],
)

View File

@@ -0,0 +1,61 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package functions defines the standard builtin functions supported by the interpreter
package functions
import "github.com/google/cel-go/common/types/ref"
// Overload defines a named overload of a function, indicating an operand trait
// which must be present on the first argument to the overload as well as one
// of either a unary, binary, or function implementation.
//
// The majority of operators within the expression language are unary or binary
// and the specializations simplify the call contract for implementers of
// types with operator overloads. Any added complexity is assumed to be handled
// by the generic FunctionOp.
type Overload struct {
// Operator name as written in an expression or defined within
// operators.go.
Operator string
// Operand trait used to dispatch the call. The zero-value indicates a
// global function overload or that one of the Unary / Binary / Function
// definitions should be used to execute the call.
OperandTrait int
// Unary defines the overload with a UnaryOp implementation. May be nil.
Unary UnaryOp
// Binary defines the overload with a BinaryOp implementation. May be nil.
Binary BinaryOp
// Function defines the overload with a FunctionOp implementation. May be
// nil.
Function FunctionOp
// NonStrict specifies whether the Overload will tolerate arguments that
// are types.Err or types.Unknown.
NonStrict bool
}
// UnaryOp is a function that takes a single value and produces an output.
type UnaryOp func(value ref.Val) ref.Val
// BinaryOp is a function that takes two values and produces an output.
type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val
// FunctionOp is a function with accepts zero or more arguments and produces
// a value or error as a result.
type FunctionOp func(values ...ref.Val) ref.Val

View File

@@ -37,6 +37,8 @@ const (
Modulo = "_%_"
Negate = "-_"
Index = "_[_]"
OptIndex = "_[?_]"
OptSelect = "_?._"
// Macros, must have a valid identifier.
Has = "has"
@@ -99,6 +101,8 @@ var (
LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2},
OptIndex: {displayName: "", precedence: 1, arity: 2},
OptSelect: {displayName: "", precedence: 1, arity: 2},
}
)

View File

@@ -148,6 +148,11 @@ const (
StartsWith = "startsWith"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtQuoteString = "strings_quote"
)
// String function overload names.
const (
ContainsString = "contains_string"
@@ -156,6 +161,11 @@ const (
StartsWithString = "starts_with_string"
)
// Extension function overloads with complex behaviors that need to be referenced in runtime and static analysis cost computations.
const (
ExtFormatString = "string_format"
)
// Time-based functions.
const (
TimeGetFullYear = "getFullYear"

View File

@@ -64,7 +64,6 @@ type sourceImpl struct {
runes.Buffer
description string
lineOffsets []int32
idOffsets map[int64]int32
}
var _ runes.Buffer = &sourceImpl{}
@@ -92,7 +91,6 @@ func NewStringSource(contents string, description string) Source {
Buffer: runes.NewBuffer(contents),
description: description,
lineOffsets: offsets,
idOffsets: map[int64]int32{},
}
}
@@ -102,7 +100,6 @@ func NewInfoSource(info *exprpb.SourceInfo) Source {
Buffer: runes.NewBuffer(""),
description: info.GetLocation(),
lineOffsets: info.GetLineOffsets(),
idOffsets: info.GetPositions(),
}
}

View File

@@ -0,0 +1,25 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
go_library(
name = "go_default_library",
srcs = [
"standard.go",
],
importpath = "github.com/google/cel-go/common/stdlib",
deps = [
"//checker/decls:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
],
)

View File

@@ -0,0 +1,661 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package stdlib contains all of the standard library function declarations and definitions for CEL.
package stdlib
import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
var (
stdFunctions []*decls.FunctionDecl
stdFnDecls []*exprpb.Decl
stdTypes []*decls.VariableDecl
stdTypeDecls []*exprpb.Decl
)
func init() {
paramA := types.NewTypeParamType("A")
paramB := types.NewTypeParamType("B")
listOfA := types.NewListType(paramA)
mapOfAB := types.NewMapType(paramA, paramB)
stdTypes = []*decls.VariableDecl{
decls.TypeVariable(types.BoolType),
decls.TypeVariable(types.BytesType),
decls.TypeVariable(types.DoubleType),
decls.TypeVariable(types.DurationType),
decls.TypeVariable(types.IntType),
decls.TypeVariable(listOfA),
decls.TypeVariable(mapOfAB),
decls.TypeVariable(types.NullType),
decls.TypeVariable(types.StringType),
decls.TypeVariable(types.TimestampType),
decls.TypeVariable(types.TypeType),
decls.TypeVariable(types.UintType),
}
stdTypeDecls = make([]*exprpb.Decl, 0, len(stdTypes))
for _, stdType := range stdTypes {
typeVar, err := decls.VariableDeclToExprDecl(stdType)
if err != nil {
panic(err)
}
stdTypeDecls = append(stdTypeDecls, typeVar)
}
stdFunctions = []*decls.FunctionDecl{
// Logical operators. Special-cased within the interpreter.
// Note, the singleton binding prevents extensions from overriding the operator behavior.
function(operators.Conditional,
decls.Overload(overloads.Conditional, argTypes(types.BoolType, paramA, paramA), paramA,
decls.OverloadIsNonStrict()),
decls.SingletonFunctionBinding(noFunctionOverrides)),
function(operators.LogicalAnd,
decls.Overload(overloads.LogicalAnd, argTypes(types.BoolType, types.BoolType), types.BoolType,
decls.OverloadIsNonStrict()),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.LogicalOr,
decls.Overload(overloads.LogicalOr, argTypes(types.BoolType, types.BoolType), types.BoolType,
decls.OverloadIsNonStrict()),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.LogicalNot,
decls.Overload(overloads.LogicalNot, argTypes(types.BoolType), types.BoolType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
b, ok := val.(types.Bool)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
return b.Negate()
})),
// Comprehension short-circuiting related function
function(operators.NotStrictlyFalse,
decls.Overload(overloads.NotStrictlyFalse, argTypes(types.BoolType), types.BoolType,
decls.OverloadIsNonStrict(),
decls.UnaryBinding(notStrictlyFalse))),
// Deprecated: __not_strictly_false__
function(operators.OldNotStrictlyFalse,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(operators.OldNotStrictlyFalse, argTypes(types.BoolType), types.BoolType,
decls.OverloadIsNonStrict(),
decls.UnaryBinding(notStrictlyFalse))),
// Equality / inequality. Special-cased in the interpreter
function(operators.Equals,
decls.Overload(overloads.Equals, argTypes(paramA, paramA), types.BoolType),
decls.SingletonBinaryBinding(noBinaryOverrides)),
function(operators.NotEquals,
decls.Overload(overloads.NotEquals, argTypes(paramA, paramA), types.BoolType),
decls.SingletonBinaryBinding(noBinaryOverrides)),
// Mathematical operators
function(operators.Add,
decls.Overload(overloads.AddBytes,
argTypes(types.BytesType, types.BytesType), types.BytesType),
decls.Overload(overloads.AddDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.AddDurationDuration,
argTypes(types.DurationType, types.DurationType), types.DurationType),
decls.Overload(overloads.AddDurationTimestamp,
argTypes(types.DurationType, types.TimestampType), types.TimestampType),
decls.Overload(overloads.AddTimestampDuration,
argTypes(types.TimestampType, types.DurationType), types.TimestampType),
decls.Overload(overloads.AddInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.AddList,
argTypes(listOfA, listOfA), listOfA),
decls.Overload(overloads.AddString,
argTypes(types.StringType, types.StringType), types.StringType),
decls.Overload(overloads.AddUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Adder).Add(rhs)
}, traits.AdderType)),
function(operators.Divide,
decls.Overload(overloads.DivideDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.DivideInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.DivideUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Divider).Divide(rhs)
}, traits.DividerType)),
function(operators.Modulo,
decls.Overload(overloads.ModuloInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.ModuloUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Modder).Modulo(rhs)
}, traits.ModderType)),
function(operators.Multiply,
decls.Overload(overloads.MultiplyDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.MultiplyInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.MultiplyUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Multiplier).Multiply(rhs)
}, traits.MultiplierType)),
function(operators.Negate,
decls.Overload(overloads.NegateDouble, argTypes(types.DoubleType), types.DoubleType),
decls.Overload(overloads.NegateInt64, argTypes(types.IntType), types.IntType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
if types.IsBool(val) {
return types.MaybeNoSuchOverloadErr(val)
}
return val.(traits.Negater).Negate()
}, traits.NegatorType)),
function(operators.Subtract,
decls.Overload(overloads.SubtractDouble,
argTypes(types.DoubleType, types.DoubleType), types.DoubleType),
decls.Overload(overloads.SubtractDurationDuration,
argTypes(types.DurationType, types.DurationType), types.DurationType),
decls.Overload(overloads.SubtractInt64,
argTypes(types.IntType, types.IntType), types.IntType),
decls.Overload(overloads.SubtractTimestampDuration,
argTypes(types.TimestampType, types.DurationType), types.TimestampType),
decls.Overload(overloads.SubtractTimestampTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.DurationType),
decls.Overload(overloads.SubtractUint64,
argTypes(types.UintType, types.UintType), types.UintType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Subtractor).Subtract(rhs)
}, traits.SubtractorType)),
// Relations operators
function(operators.Less,
decls.Overload(overloads.LessBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.LessInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.LessInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.LessUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.LessUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.LessDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.LessDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.LessString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.LessBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.LessTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.LessDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne {
return types.True
}
if cmp == types.IntOne || cmp == types.IntZero {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.LessEquals,
decls.Overload(overloads.LessEqualsBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.LessEqualsDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.LessEqualsDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.LessEqualsString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.LessEqualsBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.LessEqualsTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.LessEqualsDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntOne {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.Greater,
decls.Overload(overloads.GreaterBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.GreaterInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.GreaterBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.GreaterTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.GreaterDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne {
return types.True
}
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.False
}
return cmp
}, traits.ComparerType)),
function(operators.GreaterEquals,
decls.Overload(overloads.GreaterEqualsBool,
argTypes(types.BoolType, types.BoolType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64,
argTypes(types.IntType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64Double,
argTypes(types.IntType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsInt64Uint64,
argTypes(types.IntType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64,
argTypes(types.UintType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64Double,
argTypes(types.UintType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsUint64Int64,
argTypes(types.UintType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDouble,
argTypes(types.DoubleType, types.DoubleType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDoubleInt64,
argTypes(types.DoubleType, types.IntType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDoubleUint64,
argTypes(types.DoubleType, types.UintType), types.BoolType),
decls.Overload(overloads.GreaterEqualsString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.Overload(overloads.GreaterEqualsBytes,
argTypes(types.BytesType, types.BytesType), types.BoolType),
decls.Overload(overloads.GreaterEqualsTimestamp,
argTypes(types.TimestampType, types.TimestampType), types.BoolType),
decls.Overload(overloads.GreaterEqualsDuration,
argTypes(types.DurationType, types.DurationType), types.BoolType),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntNegOne {
return types.False
}
return cmp
}, traits.ComparerType)),
// Indexing
function(operators.Index,
decls.Overload(overloads.IndexList, argTypes(listOfA, types.IntType), paramA),
decls.Overload(overloads.IndexMap, argTypes(mapOfAB, paramA), paramB),
decls.SingletonBinaryBinding(func(lhs, rhs ref.Val) ref.Val {
return lhs.(traits.Indexer).Get(rhs)
}, traits.IndexerType)),
// Collections operators
function(operators.In,
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(operators.OldIn,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(overloads.DeprecatedIn,
decls.DisableDeclaration(true), // safe deprecation
decls.Overload(overloads.InList, argTypes(paramA, listOfA), types.BoolType),
decls.Overload(overloads.InMap, argTypes(paramA, mapOfAB), types.BoolType),
decls.SingletonBinaryBinding(inAggregate)),
function(overloads.Size,
decls.Overload(overloads.SizeBytes, argTypes(types.BytesType), types.IntType),
decls.MemberOverload(overloads.SizeBytesInst, argTypes(types.BytesType), types.IntType),
decls.Overload(overloads.SizeList, argTypes(listOfA), types.IntType),
decls.MemberOverload(overloads.SizeListInst, argTypes(listOfA), types.IntType),
decls.Overload(overloads.SizeMap, argTypes(mapOfAB), types.IntType),
decls.MemberOverload(overloads.SizeMapInst, argTypes(mapOfAB), types.IntType),
decls.Overload(overloads.SizeString, argTypes(types.StringType), types.IntType),
decls.MemberOverload(overloads.SizeStringInst, argTypes(types.StringType), types.IntType),
decls.SingletonUnaryBinding(func(val ref.Val) ref.Val {
return val.(traits.Sizer).Size()
}, traits.SizerType)),
// Type conversions
function(overloads.TypeConvertType,
decls.Overload(overloads.TypeConvertType, argTypes(paramA), types.NewTypeTypeWithParam(paramA)),
decls.SingletonUnaryBinding(convertToType(types.TypeType))),
// Bool conversions
function(overloads.TypeConvertBool,
decls.Overload(overloads.BoolToBool, argTypes(types.BoolType), types.BoolType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.StringToBool, argTypes(types.StringType), types.BoolType,
decls.UnaryBinding(convertToType(types.BoolType)))),
// Bytes conversions
function(overloads.TypeConvertBytes,
decls.Overload(overloads.BytesToBytes, argTypes(types.BytesType), types.BytesType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.StringToBytes, argTypes(types.StringType), types.BytesType,
decls.UnaryBinding(convertToType(types.BytesType)))),
// Double conversions
function(overloads.TypeConvertDouble,
decls.Overload(overloads.DoubleToDouble, argTypes(types.DoubleType), types.DoubleType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToDouble, argTypes(types.IntType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType))),
decls.Overload(overloads.StringToDouble, argTypes(types.StringType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType))),
decls.Overload(overloads.UintToDouble, argTypes(types.UintType), types.DoubleType,
decls.UnaryBinding(convertToType(types.DoubleType)))),
// Duration conversions
function(overloads.TypeConvertDuration,
decls.Overload(overloads.DurationToDuration, argTypes(types.DurationType), types.DurationType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToDuration, argTypes(types.IntType), types.DurationType,
decls.UnaryBinding(convertToType(types.DurationType))),
decls.Overload(overloads.StringToDuration, argTypes(types.StringType), types.DurationType,
decls.UnaryBinding(convertToType(types.DurationType)))),
// Dyn conversions
function(overloads.TypeConvertDyn,
decls.Overload(overloads.ToDyn, argTypes(paramA), types.DynType),
decls.SingletonUnaryBinding(identity)),
// Int conversions
function(overloads.TypeConvertInt,
decls.Overload(overloads.IntToInt, argTypes(types.IntType), types.IntType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.DoubleToInt, argTypes(types.DoubleType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.DurationToInt, argTypes(types.DurationType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.StringToInt, argTypes(types.StringType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.TimestampToInt, argTypes(types.TimestampType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
decls.Overload(overloads.UintToInt, argTypes(types.UintType), types.IntType,
decls.UnaryBinding(convertToType(types.IntType))),
),
// String conversions
function(overloads.TypeConvertString,
decls.Overload(overloads.StringToString, argTypes(types.StringType), types.StringType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.BoolToString, argTypes(types.BoolType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.BytesToString, argTypes(types.BytesType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.DoubleToString, argTypes(types.DoubleType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.DurationToString, argTypes(types.DurationType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.IntToString, argTypes(types.IntType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.TimestampToString, argTypes(types.TimestampType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType))),
decls.Overload(overloads.UintToString, argTypes(types.UintType), types.StringType,
decls.UnaryBinding(convertToType(types.StringType)))),
// Timestamp conversions
function(overloads.TypeConvertTimestamp,
decls.Overload(overloads.TimestampToTimestamp, argTypes(types.TimestampType), types.TimestampType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.IntToTimestamp, argTypes(types.IntType), types.TimestampType,
decls.UnaryBinding(convertToType(types.TimestampType))),
decls.Overload(overloads.StringToTimestamp, argTypes(types.StringType), types.TimestampType,
decls.UnaryBinding(convertToType(types.TimestampType)))),
// Uint conversions
function(overloads.TypeConvertUint,
decls.Overload(overloads.UintToUint, argTypes(types.UintType), types.UintType,
decls.UnaryBinding(identity)),
decls.Overload(overloads.DoubleToUint, argTypes(types.DoubleType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType))),
decls.Overload(overloads.IntToUint, argTypes(types.IntType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType))),
decls.Overload(overloads.StringToUint, argTypes(types.StringType), types.UintType,
decls.UnaryBinding(convertToType(types.UintType)))),
// String functions
function(overloads.Contains,
decls.MemberOverload(overloads.ContainsString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringContains)),
decls.DisableTypeGuards(true)),
function(overloads.EndsWith,
decls.MemberOverload(overloads.EndsWithString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringEndsWith)),
decls.DisableTypeGuards(true)),
function(overloads.StartsWith,
decls.MemberOverload(overloads.StartsWithString,
argTypes(types.StringType, types.StringType), types.BoolType,
decls.BinaryBinding(types.StringStartsWith)),
decls.DisableTypeGuards(true)),
function(overloads.Matches,
decls.Overload(overloads.Matches, argTypes(types.StringType, types.StringType), types.BoolType),
decls.MemberOverload(overloads.MatchesString,
argTypes(types.StringType, types.StringType), types.BoolType),
decls.SingletonBinaryBinding(func(str, pat ref.Val) ref.Val {
return str.(traits.Matcher).Match(pat)
}, traits.MatcherType)),
// Timestamp / duration functions
function(overloads.TimeGetFullYear,
decls.MemberOverload(overloads.TimestampToYear,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToYearWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetMonth,
decls.MemberOverload(overloads.TimestampToMonth,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMonthWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfYear,
decls.MemberOverload(overloads.TimestampToDayOfYear,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfYearWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfMonth,
decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBased,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfMonthZeroBasedWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDate,
decls.MemberOverload(overloads.TimestampToDayOfMonthOneBased,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfMonthOneBasedWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetDayOfWeek,
decls.MemberOverload(overloads.TimestampToDayOfWeek,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToDayOfWeekWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType)),
function(overloads.TimeGetHours,
decls.MemberOverload(overloads.TimestampToHours,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToHoursWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToHours,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetMinutes,
decls.MemberOverload(overloads.TimestampToMinutes,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMinutesWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToMinutes,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetSeconds,
decls.MemberOverload(overloads.TimestampToSeconds,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToSecondsWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToSeconds,
argTypes(types.DurationType), types.IntType)),
function(overloads.TimeGetMilliseconds,
decls.MemberOverload(overloads.TimestampToMilliseconds,
argTypes(types.TimestampType), types.IntType),
decls.MemberOverload(overloads.TimestampToMillisecondsWithTz,
argTypes(types.TimestampType, types.StringType), types.IntType),
decls.MemberOverload(overloads.DurationToMilliseconds,
argTypes(types.DurationType), types.IntType)),
}
stdFnDecls = make([]*exprpb.Decl, 0, len(stdFunctions))
for _, fn := range stdFunctions {
if fn.IsDeclarationDisabled() {
continue
}
ed, err := decls.FunctionDeclToExprDecl(fn)
if err != nil {
panic(err)
}
stdFnDecls = append(stdFnDecls, ed)
}
}
// Functions returns the set of standard library function declarations and definitions for CEL.
func Functions() []*decls.FunctionDecl {
return stdFunctions
}
// FunctionExprDecls returns the legacy style protobuf-typed declarations for all functions and overloads
// in the CEL standard environment.
//
// Deprecated: use Functions
func FunctionExprDecls() []*exprpb.Decl {
return stdFnDecls
}
// Types returns the set of standard library types for CEL.
func Types() []*decls.VariableDecl {
return stdTypes
}
// TypeExprDecls returns the legacy style protobuf-typed declarations for all types in the CEL
// standard environment.
//
// Deprecated: use Types
func TypeExprDecls() []*exprpb.Decl {
return stdTypeDecls
}
func notStrictlyFalse(value ref.Val) ref.Val {
if types.IsBool(value) {
return value
}
return types.True
}
func inAggregate(lhs ref.Val, rhs ref.Val) ref.Val {
if rhs.Type().HasTrait(traits.ContainerType) {
return rhs.(traits.Container).Contains(lhs)
}
return types.ValOrErr(rhs, "no such overload")
}
func function(name string, opts ...decls.FunctionOpt) *decls.FunctionDecl {
fn, err := decls.NewFunction(name, opts...)
if err != nil {
panic(err)
}
return fn
}
func argTypes(args ...*types.Type) []*types.Type {
return args
}
func noBinaryOverrides(rhs, lhs ref.Val) ref.Val {
return types.NoSuchOverloadErr()
}
func noFunctionOverrides(args ...ref.Val) ref.Val {
return types.NoSuchOverloadErr()
}
func identity(val ref.Val) ref.Val {
return val
}
func convertToType(t ref.Type) functions.UnaryOp {
return func(val ref.Val) ref.Val {
return val.ConvertToType(t)
}
}

View File

@@ -22,26 +22,25 @@ go_library(
"map.go",
"null.go",
"object.go",
"optional.go",
"overflow.go",
"provider.go",
"string.go",
"timestamp.go",
"type.go",
"types.go",
"uint.go",
"unknown.go",
"util.go",
],
importpath = "github.com/google/cel-go/common/types",
deps = [
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
@@ -68,11 +67,13 @@ go_test(
"map_test.go",
"null_test.go",
"object_test.go",
"optional_test.go",
"provider_test.go",
"string_test.go",
"timestamp_test.go",
"type_test.go",
"types_test.go",
"uint_test.go",
"unknown_test.go",
"util_test.go",
],
embed = [":go_default_library"],
@@ -80,7 +81,7 @@ go_test(
"//common/types/ref:go_default_library",
"//test:go_default_library",
"//test/proto3pb:test_all_types_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",

View File

@@ -20,7 +20,6 @@ import (
"strconv"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -31,11 +30,6 @@ import (
type Bool bool
var (
// BoolType singleton.
BoolType = NewTypeValue("bool",
traits.ComparerType,
traits.NegatorType)
// boolWrapperType golang reflected type for protobuf bool wrapper type.
boolWrapperType = reflect.TypeOf(&wrapperspb.BoolValue{})
)
@@ -62,7 +56,7 @@ func (b Bool) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bool) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bool) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Bool:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@@ -114,6 +108,11 @@ func (b Bool) Equal(other ref.Val) ref.Val {
return Bool(ok && b == otherBool)
}
// IsZeroValue returns true if the boolean value is false.
func (b Bool) IsZeroValue() bool {
return b == False
}
// Negate implements the traits.Negater interface method.
func (b Bool) Negate() ref.Val {
return !b
@@ -125,7 +124,7 @@ func (b Bool) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bool) Value() interface{} {
func (b Bool) Value() any {
return bool(b)
}

View File

@@ -22,7 +22,6 @@ import (
"unicode/utf8"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -34,12 +33,6 @@ import (
type Bytes []byte
var (
// BytesType singleton.
BytesType = NewTypeValue("bytes",
traits.AdderType,
traits.ComparerType,
traits.SizerType)
// byteWrapperType golang reflected type for protobuf bytes wrapper type.
byteWrapperType = reflect.TypeOf(&wrapperspb.BytesValue{})
)
@@ -63,7 +56,7 @@ func (b Bytes) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (b Bytes) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Array, reflect.Slice:
return reflect.ValueOf(b).Convert(typeDesc).Interface(), nil
@@ -116,6 +109,11 @@ func (b Bytes) Equal(other ref.Val) ref.Val {
return Bool(ok && bytes.Equal(b, otherBytes))
}
// IsZeroValue returns true if the byte array is empty.
func (b Bytes) IsZeroValue() bool {
return len(b) == 0
}
// Size implements the traits.Sizer interface method.
func (b Bytes) Size() ref.Val {
return Int(len(b))
@@ -127,6 +125,6 @@ func (b Bytes) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (b Bytes) Value() interface{} {
func (b Bytes) Value() any {
return []byte(b)
}

View File

@@ -20,7 +20,6 @@ import (
"reflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -32,15 +31,6 @@ import (
type Double float64
var (
// DoubleType singleton.
DoubleType = NewTypeValue("double",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.MultiplierType,
traits.NegatorType,
traits.SubtractorType)
// doubleWrapperType reflected type for protobuf double wrapper type.
doubleWrapperType = reflect.TypeOf(&wrapperspb.DoubleValue{})
@@ -78,7 +68,7 @@ func (d Double) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Double) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Double) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Float32:
v := float32(d)
@@ -134,13 +124,13 @@ func (d Double) ConvertToType(typeVal ref.Type) ref.Val {
case IntType:
i, err := doubleToInt64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(i)
case UintType:
i, err := doubleToUint64Checked(float64(d))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(i)
case DoubleType:
@@ -182,6 +172,11 @@ func (d Double) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if double value is 0.0
func (d Double) IsZeroValue() bool {
return float64(d) == 0.0
}
// Multiply implements traits.Multiplier.Multiply.
func (d Double) Multiply(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
@@ -211,6 +206,6 @@ func (d Double) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Double) Value() interface{} {
func (d Double) Value() any {
return float64(d)
}

View File

@@ -22,7 +22,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
dpb "google.golang.org/protobuf/types/known/durationpb"
@@ -41,13 +40,14 @@ func durationOf(d time.Duration) Duration {
}
var (
// DurationType singleton.
DurationType = NewTypeValue("google.protobuf.Duration",
traits.AdderType,
traits.ComparerType,
traits.NegatorType,
traits.ReceiverType,
traits.SubtractorType)
durationValueType = reflect.TypeOf(&dpb.Duration{})
durationZeroArgOverloads = map[string]func(ref.Val) ref.Val{
overloads.TimeGetHours: DurationGetHours,
overloads.TimeGetMinutes: DurationGetMinutes,
overloads.TimeGetSeconds: DurationGetSeconds,
overloads.TimeGetMilliseconds: DurationGetMilliseconds,
}
)
// Add implements traits.Adder.Add.
@@ -57,14 +57,14 @@ func (d Duration) Add(other ref.Val) ref.Val {
dur2 := other.(Duration)
val, err := addDurationChecked(d.Duration, dur2.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
case TimestampType:
ts := other.(Timestamp).Time
val, err := addTimeDurationChecked(ts, d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return timestampOf(val)
}
@@ -90,7 +90,7 @@ func (d Duration) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (d Duration) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (d Duration) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the duration is already assignable to the desired type return it.
if reflect.TypeOf(d.Duration).AssignableTo(typeDesc) {
return d.Duration, nil
@@ -138,11 +138,16 @@ func (d Duration) Equal(other ref.Val) ref.Val {
return Bool(ok && d.Duration == otherDur.Duration)
}
// IsZeroValue returns true if the duration value is zero
func (d Duration) IsZeroValue() bool {
return d.Duration == 0
}
// Negate implements traits.Negater.Negate.
func (d Duration) Negate() ref.Val {
val, err := negateDurationChecked(d.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@@ -151,7 +156,7 @@ func (d Duration) Negate() ref.Val {
func (d Duration) Receive(function string, overload string, args []ref.Val) ref.Val {
if len(args) == 0 {
if f, found := durationZeroArgOverloads[function]; found {
return f(d.Duration)
return f(d)
}
}
return NoSuchOverloadErr()
@@ -165,7 +170,7 @@ func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractDurationChecked(d.Duration, subtraDur.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@@ -176,24 +181,42 @@ func (d Duration) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (d Duration) Value() interface{} {
func (d Duration) Value() any {
return d.Duration
}
var (
durationValueType = reflect.TypeOf(&dpb.Duration{})
// DurationGetHours returns the duration in hours.
func DurationGetHours(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Hours())
}
durationZeroArgOverloads = map[string]func(time.Duration) ref.Val{
overloads.TimeGetHours: func(dur time.Duration) ref.Val {
return Int(dur.Hours())
},
overloads.TimeGetMinutes: func(dur time.Duration) ref.Val {
return Int(dur.Minutes())
},
overloads.TimeGetSeconds: func(dur time.Duration) ref.Val {
return Int(dur.Seconds())
},
overloads.TimeGetMilliseconds: func(dur time.Duration) ref.Val {
return Int(dur.Milliseconds())
}}
)
// DurationGetMinutes returns duration in minutes.
func DurationGetMinutes(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Minutes())
}
// DurationGetSeconds returns duration in seconds.
func DurationGetSeconds(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Seconds())
}
// DurationGetMilliseconds returns duration in milliseconds.
func DurationGetMilliseconds(val ref.Val) ref.Val {
dur, ok := val.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(val)
}
return Int(dur.Milliseconds())
}

View File

@@ -22,6 +22,12 @@ import (
"github.com/google/cel-go/common/types/ref"
)
// Error interface which allows types types.Err values to be treated as error values.
type Error interface {
error
ref.Val
}
// Err type which extends the built-in go error and implements ref.Val.
type Err struct {
error
@@ -29,7 +35,7 @@ type Err struct {
var (
// ErrType singleton.
ErrType = NewTypeValue("error")
ErrType = NewOpaqueType("error")
// errDivideByZero is an error indicating a division by zero of an integer value.
errDivideByZero = errors.New("division by zero")
@@ -51,7 +57,7 @@ var (
// NewErr creates a new Err described by the format string and args.
// TODO: Audit the use of this function and standardize the error messages and codes.
func NewErr(format string, args ...interface{}) ref.Val {
func NewErr(format string, args ...any) ref.Val {
return &Err{fmt.Errorf(format, args...)}
}
@@ -62,7 +68,7 @@ func NoSuchOverloadErr() ref.Val {
// UnsupportedRefValConversionErr returns a types.NewErr instance with a no such conversion
// message that indicates that the native value could not be converted to a CEL ref.Val.
func UnsupportedRefValConversionErr(val interface{}) ref.Val {
func UnsupportedRefValConversionErr(val any) ref.Val {
return NewErr("unsupported conversion to ref.Val: (%T)%v", val, val)
}
@@ -74,20 +80,20 @@ func MaybeNoSuchOverloadErr(val ref.Val) ref.Val {
// ValOrErr either returns the existing error or creates a new one.
// TODO: Audit the use of this function and standardize the error messages and codes.
func ValOrErr(val ref.Val, format string, args ...interface{}) ref.Val {
func ValOrErr(val ref.Val, format string, args ...any) ref.Val {
if val == nil || !IsUnknownOrError(val) {
return NewErr(format, args...)
}
return val
}
// wrapErr wraps an existing Go error value into a CEL Err value.
func wrapErr(err error) ref.Val {
// WrapErr wraps an existing Go error value into a CEL Err value.
func WrapErr(err error) ref.Val {
return &Err{error: err}
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (e *Err) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (e *Err) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, e.error
}
@@ -114,7 +120,17 @@ func (e *Err) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (e *Err) Value() interface{} {
func (e *Err) Value() any {
return e.error
}
// Is implements errors.Is.
func (e *Err) Is(target error) bool {
return e.error.Error() == target.Error()
}
// Unwrap implements errors.Unwrap.
func (e *Err) Unwrap() error {
return e.error
}

View File

@@ -22,7 +22,6 @@ import (
"time"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -41,16 +40,6 @@ const (
)
var (
// IntType singleton.
IntType = NewTypeValue("int",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.ModderType,
traits.MultiplierType,
traits.NegatorType,
traits.SubtractorType)
// int32WrapperType reflected type for protobuf int32 wrapper type.
int32WrapperType = reflect.TypeOf(&wrapperspb.Int32Value{})
@@ -66,7 +55,7 @@ func (i Int) Add(other ref.Val) ref.Val {
}
val, err := addInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -89,7 +78,7 @@ func (i Int) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (i Int) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int, reflect.Int32:
// Enums are also mapped as int32 derivations.
@@ -176,7 +165,7 @@ func (i Int) ConvertToType(typeVal ref.Type) ref.Val {
case UintType:
u, err := int64ToUint64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(u)
case DoubleType:
@@ -204,7 +193,7 @@ func (i Int) Divide(other ref.Val) ref.Val {
}
val, err := divideInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -226,6 +215,11 @@ func (i Int) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if integer is equal to 0
func (i Int) IsZeroValue() bool {
return i == IntZero
}
// Modulo implements traits.Modder.Modulo.
func (i Int) Modulo(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
@@ -234,7 +228,7 @@ func (i Int) Modulo(other ref.Val) ref.Val {
}
val, err := moduloInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -247,7 +241,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
}
val, err := multiplyInt64Checked(int64(i), int64(otherInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -256,7 +250,7 @@ func (i Int) Multiply(other ref.Val) ref.Val {
func (i Int) Negate() ref.Val {
val, err := negateInt64Checked(int64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -269,7 +263,7 @@ func (i Int) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractInt64Checked(int64(i), int64(subtraInt))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(val)
}
@@ -280,7 +274,7 @@ func (i Int) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (i Int) Value() interface{} {
func (i Int) Value() any {
return int64(i)
}

View File

@@ -24,7 +24,7 @@ import (
var (
// IteratorType singleton.
IteratorType = NewTypeValue("iterator", traits.IteratorType)
IteratorType = NewObjectType("iterator", traits.IteratorType)
)
// baseIterator is the basis for list, map, and object iterators.
@@ -34,7 +34,7 @@ var (
// interpreter.
type baseIterator struct{}
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (*baseIterator) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion on iterators not supported")
}
@@ -50,6 +50,6 @@ func (*baseIterator) Type() ref.Type {
return IteratorType
}
func (*baseIterator) Value() interface{} {
func (*baseIterator) Value() any {
return nil
}

View File

@@ -25,4 +25,5 @@ var (
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonListValueType = reflect.TypeOf(&structpb.ListValue{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
)

View File

@@ -17,104 +17,99 @@ package types
import (
"fmt"
"reflect"
"strings"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
var (
// ListType singleton.
ListType = NewTypeValue("list",
traits.AdderType,
traits.ContainerType,
traits.IndexerType,
traits.IterableType,
traits.SizerType)
)
// NewDynamicList returns a traits.Lister with heterogenous elements.
// value should be an array of "native" types, i.e. any type that
// NativeToValue() can convert to a ref.Val.
func NewDynamicList(adapter ref.TypeAdapter, value interface{}) traits.Lister {
func NewDynamicList(adapter Adapter, value any) traits.Lister {
refValue := reflect.ValueOf(value)
return &baseList{
TypeAdapter: adapter,
value: value,
size: refValue.Len(),
get: func(i int) interface{} {
Adapter: adapter,
value: value,
size: refValue.Len(),
get: func(i int) any {
return refValue.Index(i).Interface()
},
}
}
// NewStringList returns a traits.Lister containing only strings.
func NewStringList(adapter ref.TypeAdapter, elems []string) traits.Lister {
func NewStringList(adapter Adapter, elems []string) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
Adapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
}
}
// NewRefValList returns a traits.Lister with ref.Val elements.
//
// This type specialization is used with list literals within CEL expressions.
func NewRefValList(adapter ref.TypeAdapter, elems []ref.Val) traits.Lister {
func NewRefValList(adapter Adapter, elems []ref.Val) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: elems,
size: len(elems),
get: func(i int) interface{} { return elems[i] },
Adapter: adapter,
value: elems,
size: len(elems),
get: func(i int) any { return elems[i] },
}
}
// NewProtoList returns a traits.Lister based on a pb.List instance.
func NewProtoList(adapter ref.TypeAdapter, list protoreflect.List) traits.Lister {
func NewProtoList(adapter Adapter, list protoreflect.List) traits.Lister {
return &baseList{
TypeAdapter: adapter,
value: list,
size: list.Len(),
get: func(i int) interface{} { return list.Get(i).Interface() },
Adapter: adapter,
value: list,
size: list.Len(),
get: func(i int) any { return list.Get(i).Interface() },
}
}
// NewJSONList returns a traits.Lister based on structpb.ListValue instance.
func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
func NewJSONList(adapter Adapter, l *structpb.ListValue) traits.Lister {
vals := l.GetValues()
return &baseList{
TypeAdapter: adapter,
value: l,
size: len(vals),
get: func(i int) interface{} { return vals[i] },
Adapter: adapter,
value: l,
size: len(vals),
get: func(i int) any { return vals[i] },
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
func NewMutableList(adapter ref.TypeAdapter) traits.MutableLister {
func NewMutableList(adapter Adapter) traits.MutableLister {
var mutableValues []ref.Val
return &mutableList{
l := &mutableList{
baseList: &baseList{
TypeAdapter: adapter,
value: mutableValues,
size: 0,
get: func(i int) interface{} { return mutableValues[i] },
Adapter: adapter,
value: mutableValues,
size: 0,
},
mutableValues: mutableValues,
}
l.get = func(i int) any {
return l.mutableValues[i]
}
return l
}
// baseList points to a list containing elements of any type.
// The `value` is an array of native values, and refValue is its reflection object.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
// The `Adapter` enables native type to CEL type conversions.
type baseList struct {
ref.TypeAdapter
value interface{}
Adapter
value any
// size indicates the number of elements within the list.
// Since objects are immutable the size of a list is static.
@@ -122,7 +117,7 @@ type baseList struct {
// get returns a value at the specified integer index.
// The index is guaranteed to be checked against the list index range.
get func(int) interface{}
get func(int) any
}
// Add implements the traits.Adder interface method.
@@ -138,9 +133,9 @@ func (l *baseList) Add(other ref.Val) ref.Val {
return l
}
return &concatList{
TypeAdapter: l.TypeAdapter,
prevList: l,
nextList: otherList}
Adapter: l.Adapter,
prevList: l,
nextList: otherList}
}
// Contains implements the traits.Container interface method.
@@ -157,7 +152,7 @@ func (l *baseList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (l *baseList) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the underlying list value is assignable to the reflected type return it.
if reflect.TypeOf(l.value).AssignableTo(typeDesc) {
return l.value, nil
@@ -240,7 +235,7 @@ func (l *baseList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *baseList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@@ -250,6 +245,11 @@ func (l *baseList) Get(index ref.Val) ref.Val {
return l.NativeToValue(l.get(ind))
}
// IsZeroValue returns true if the list is empty.
func (l *baseList) IsZeroValue() bool {
return l.size == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *baseList) Iterator() traits.Iterator {
return newListIterator(l)
@@ -266,10 +266,24 @@ func (l *baseList) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (l *baseList) Value() interface{} {
func (l *baseList) Value() any {
return l.value
}
// String converts the list to a human readable string form.
func (l *baseList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := 0; i < l.size; i++ {
sb.WriteString(fmt.Sprintf("%v", l.get(i)))
if i != l.size-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
*baseList
@@ -298,14 +312,14 @@ func (l *mutableList) Add(other ref.Val) ref.Val {
func (l *mutableList) ToImmutableList() traits.Lister {
// The reference to internal state is guaranteed to be safe as this call is only performed
// when mutations have been completed.
return NewRefValList(l.TypeAdapter, l.mutableValues)
return NewRefValList(l.Adapter, l.mutableValues)
}
// concatList combines two list implementations together into a view.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
// The `Adapter` enables native type to CEL type conversions.
type concatList struct {
ref.TypeAdapter
value interface{}
Adapter
value any
prevList traits.Lister
nextList traits.Lister
}
@@ -323,9 +337,9 @@ func (l *concatList) Add(other ref.Val) ref.Val {
return l
}
return &concatList{
TypeAdapter: l.TypeAdapter,
prevList: l,
nextList: otherList}
Adapter: l.Adapter,
prevList: l,
nextList: otherList}
}
// Contains implements the traits.Container interface method.
@@ -351,8 +365,8 @@ func (l *concatList) Contains(elem ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
combined := NewDynamicList(l.TypeAdapter, l.Value().([]interface{}))
func (l *concatList) ConvertToNative(typeDesc reflect.Type) (any, error) {
combined := NewDynamicList(l.Adapter, l.Value().([]any))
return combined.ConvertToNative(typeDesc)
}
@@ -396,7 +410,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
// Get implements the traits.Indexer interface method.
func (l *concatList) Get(index ref.Val) ref.Val {
ind, err := indexOrError(index)
ind, err := IndexOrError(index)
if err != nil {
return ValOrErr(index, err.Error())
}
@@ -408,6 +422,11 @@ func (l *concatList) Get(index ref.Val) ref.Val {
return l.nextList.Get(offset)
}
// IsZeroValue returns true if the list is empty.
func (l *concatList) IsZeroValue() bool {
return l.Size().(Int) == 0
}
// Iterator implements the traits.Iterable interface method.
func (l *concatList) Iterator() traits.Iterator {
return newListIterator(l)
@@ -418,15 +437,29 @@ func (l *concatList) Size() ref.Val {
return l.prevList.Size().(Int).Add(l.nextList.Size())
}
// String converts the concatenated list to a human-readable string.
func (l *concatList) String() string {
var sb strings.Builder
sb.WriteString("[")
for i := Int(0); i < l.Size().(Int); i++ {
sb.WriteString(fmt.Sprintf("%v", l.Get(i)))
if i != l.Size().(Int)-1 {
sb.WriteString(", ")
}
}
sb.WriteString("]")
return sb.String()
}
// Type implements the ref.Val interface method.
func (l *concatList) Type() ref.Type {
return ListType
}
// Value implements the ref.Val interface method.
func (l *concatList) Value() interface{} {
func (l *concatList) Value() any {
if l.value == nil {
merged := make([]interface{}, l.Size().(Int))
merged := make([]any, l.Size().(Int))
prevLen := l.prevList.Size().(Int)
for i := Int(0); i < prevLen; i++ {
merged[i] = l.prevList.Get(i).Value()
@@ -469,7 +502,8 @@ func (it *listIterator) Next() ref.Val {
return nil
}
func indexOrError(index ref.Val) (int, error) {
// IndexOrError converts an input index value into either a lossless integer index or an error.
func IndexOrError(index ref.Val) (int, error) {
switch iv := index.(type) {
case Int:
return int(iv), nil

View File

@@ -17,23 +17,25 @@ package types
import (
"fmt"
"reflect"
"strings"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/stoewer/go-strcase"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper {
func NewDynamicMap(adapter Adapter, value any) traits.Mapper {
refValue := reflect.ValueOf(value)
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newReflectMapAccessor(adapter, refValue),
value: value,
size: refValue.Len(),
@@ -44,10 +46,10 @@ func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper {
// encoded in protocol buffer form.
//
// The `adapter` argument provides type adaptation capabilities from proto to CEL.
func NewJSONStruct(adapter ref.TypeAdapter, value *structpb.Struct) traits.Mapper {
func NewJSONStruct(adapter Adapter, value *structpb.Struct) traits.Mapper {
fields := value.GetFields()
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newJSONStructAccessor(adapter, fields),
value: value,
size: len(fields),
@@ -55,9 +57,9 @@ func NewJSONStruct(adapter ref.TypeAdapter, value *structpb.Struct) traits.Mappe
}
// NewRefValMap returns a specialized traits.Mapper with CEL valued keys and values.
func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Mapper {
func NewRefValMap(adapter Adapter, value map[ref.Val]ref.Val) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newRefValMapAccessor(value),
value: value,
size: len(value),
@@ -65,9 +67,9 @@ func NewRefValMap(adapter ref.TypeAdapter, value map[ref.Val]ref.Val) traits.Map
}
// NewStringInterfaceMap returns a specialized traits.Mapper with string keys and interface values.
func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]interface{}) traits.Mapper {
func NewStringInterfaceMap(adapter Adapter, value map[string]any) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newStringIfaceMapAccessor(adapter, value),
value: value,
size: len(value),
@@ -75,9 +77,9 @@ func NewStringInterfaceMap(adapter ref.TypeAdapter, value map[string]interface{}
}
// NewStringStringMap returns a specialized traits.Mapper with string keys and values.
func NewStringStringMap(adapter ref.TypeAdapter, value map[string]string) traits.Mapper {
func NewStringStringMap(adapter Adapter, value map[string]string) traits.Mapper {
return &baseMap{
TypeAdapter: adapter,
Adapter: adapter,
mapAccessor: newStringMapAccessor(value),
value: value,
size: len(value),
@@ -85,22 +87,13 @@ func NewStringStringMap(adapter ref.TypeAdapter, value map[string]string) traits
}
// NewProtoMap returns a specialized traits.Mapper for handling protobuf map values.
func NewProtoMap(adapter ref.TypeAdapter, value *pb.Map) traits.Mapper {
func NewProtoMap(adapter Adapter, value *pb.Map) traits.Mapper {
return &protoMap{
TypeAdapter: adapter,
value: value,
Adapter: adapter,
value: value,
}
}
var (
// MapType singleton.
MapType = NewTypeValue("map",
traits.ContainerType,
traits.IndexerType,
traits.IterableType,
traits.SizerType)
)
// mapAccessor is a private interface for finding values within a map and iterating over the keys.
// This interface implements portions of the API surface area required by the traits.Mapper
// interface.
@@ -119,13 +112,13 @@ type mapAccessor interface {
// Since CEL is side-effect free, the base map represents an immutable object.
type baseMap struct {
// TypeAdapter used to convert keys and values accessed within the map.
ref.TypeAdapter
Adapter
// mapAccessor interface implementation used to find and iterate over map keys.
mapAccessor
// value is the native Go value upon which the map type operators.
value interface{}
value any
// size is the number of entries in the map.
size int
@@ -138,7 +131,7 @@ func (m *baseMap) Contains(index ref.Val) ref.Val {
}
// ConvertToNative implements the ref.Val interface method.
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
if reflect.TypeOf(m.value).AssignableTo(typeDesc) {
@@ -275,30 +268,54 @@ func (m *baseMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *baseMap) IsZeroValue() bool {
return m.size == 0
}
// Size implements the traits.Sizer interface method.
func (m *baseMap) Size() ref.Val {
return Int(m.size)
}
// String converts the map into a human-readable string.
func (m *baseMap) String() string {
var sb strings.Builder
sb.WriteString("{")
it := m.Iterator()
i := 0
for it.HasNext() == True {
k := it.Next()
v, _ := m.Find(k)
sb.WriteString(fmt.Sprintf("%v: %v", k, v))
if i != m.size-1 {
sb.WriteString(", ")
}
i++
}
sb.WriteString("}")
return sb.String()
}
// Type implements the ref.Val interface method.
func (m *baseMap) Type() ref.Type {
return MapType
}
// Value implements the ref.Val interface method.
func (m *baseMap) Value() interface{} {
func (m *baseMap) Value() any {
return m.value
}
func newJSONStructAccessor(adapter ref.TypeAdapter, st map[string]*structpb.Value) mapAccessor {
func newJSONStructAccessor(adapter Adapter, st map[string]*structpb.Value) mapAccessor {
return &jsonStructAccessor{
TypeAdapter: adapter,
st: st,
Adapter: adapter,
st: st,
}
}
type jsonStructAccessor struct {
ref.TypeAdapter
Adapter
st map[string]*structpb.Value
}
@@ -333,17 +350,17 @@ func (a *jsonStructAccessor) Iterator() traits.Iterator {
}
}
func newReflectMapAccessor(adapter ref.TypeAdapter, value reflect.Value) mapAccessor {
func newReflectMapAccessor(adapter Adapter, value reflect.Value) mapAccessor {
keyType := value.Type().Key()
return &reflectMapAccessor{
TypeAdapter: adapter,
refValue: value,
keyType: keyType,
Adapter: adapter,
refValue: value,
keyType: keyType,
}
}
type reflectMapAccessor struct {
ref.TypeAdapter
Adapter
refValue reflect.Value
keyType reflect.Type
}
@@ -401,9 +418,9 @@ func (m *reflectMapAccessor) findInternal(key ref.Val) (ref.Val, bool) {
// Iterator creates a Golang reflection based traits.Iterator.
func (m *reflectMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: m.TypeAdapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
Adapter: m.Adapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
}
}
@@ -454,9 +471,9 @@ func (a *refValMapAccessor) Find(key ref.Val) (ref.Val, bool) {
// Iterator produces a new traits.Iterator which iterates over the map keys via Golang reflection.
func (a *refValMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: DefaultTypeAdapter,
mapKeys: reflect.ValueOf(a.mapVal).MapRange(),
len: len(a.mapVal),
Adapter: DefaultTypeAdapter,
mapKeys: reflect.ValueOf(a.mapVal).MapRange(),
len: len(a.mapVal),
}
}
@@ -498,16 +515,16 @@ func (a *stringMapAccessor) Iterator() traits.Iterator {
}
}
func newStringIfaceMapAccessor(adapter ref.TypeAdapter, mapVal map[string]interface{}) mapAccessor {
func newStringIfaceMapAccessor(adapter Adapter, mapVal map[string]any) mapAccessor {
return &stringIfaceMapAccessor{
TypeAdapter: adapter,
mapVal: mapVal,
Adapter: adapter,
mapVal: mapVal,
}
}
type stringIfaceMapAccessor struct {
ref.TypeAdapter
mapVal map[string]interface{}
Adapter
mapVal map[string]any
}
// Find uses native map accesses to find the key, returning (value, true) if present.
@@ -543,7 +560,7 @@ func (a *stringIfaceMapAccessor) Iterator() traits.Iterator {
// protoMap is a specialized, separate implementation of the traits.Mapper interfaces tailored to
// accessing protoreflect.Map values.
type protoMap struct {
ref.TypeAdapter
Adapter
value *pb.Map
}
@@ -556,7 +573,7 @@ func (m *protoMap) Contains(key ref.Val) ref.Val {
// ConvertToNative implements the ref.Val interface method.
//
// Note, assignment to Golang struct types is not yet supported.
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the map is already assignable to the desired type return it, e.g. interfaces and
// maps with the same key value types.
switch typeDesc {
@@ -601,9 +618,9 @@ func (m *protoMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
m.value.Range(func(key protoreflect.MapKey, val protoreflect.Value) bool {
ntvKey := key.Interface()
ntvVal := val.Interface()
switch ntvVal.(type) {
switch pv := ntvVal.(type) {
case protoreflect.Message:
ntvVal = ntvVal.(protoreflect.Message).Interface()
ntvVal = pv.Interface()
}
if keyType == otherKeyType && valType == otherValType {
mapVal.SetMapIndex(reflect.ValueOf(ntvKey), reflect.ValueOf(ntvVal))
@@ -732,6 +749,11 @@ func (m *protoMap) Get(key ref.Val) ref.Val {
return v
}
// IsZeroValue returns true if the map is empty.
func (m *protoMap) IsZeroValue() bool {
return m.value.Len() == 0
}
// Iterator implements the traits.Iterable interface method.
func (m *protoMap) Iterator() traits.Iterator {
// Copy the keys to make their order stable.
@@ -741,9 +763,9 @@ func (m *protoMap) Iterator() traits.Iterator {
return true
})
return &protoMapIterator{
TypeAdapter: m.TypeAdapter,
mapKeys: mapKeys,
len: m.value.Len(),
Adapter: m.Adapter,
mapKeys: mapKeys,
len: m.value.Len(),
}
}
@@ -758,13 +780,13 @@ func (m *protoMap) Type() ref.Type {
}
// Value implements the ref.Val interface method.
func (m *protoMap) Value() interface{} {
func (m *protoMap) Value() any {
return m.value
}
type mapIterator struct {
*baseIterator
ref.TypeAdapter
Adapter
mapKeys *reflect.MapIter
cursor int
len int
@@ -787,7 +809,7 @@ func (it *mapIterator) Next() ref.Val {
type protoMapIterator struct {
*baseIterator
ref.TypeAdapter
Adapter
mapKeys []protoreflect.MapKey
cursor int
len int

View File

@@ -18,9 +18,10 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
@@ -29,19 +30,23 @@ import (
type Null structpb.NullValue
var (
// NullType singleton.
NullType = NewTypeValue("null_type")
// NullValue singleton.
NullValue = Null(structpb.NullValue_NULL_VALUE)
jsonNullType = reflect.TypeOf(structpb.NullValue_NULL_VALUE)
// golang reflect type for Null values.
nullReflectType = reflect.TypeOf(NullValue)
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (n Null) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Int32:
return reflect.ValueOf(n).Convert(typeDesc).Interface(), nil
switch typeDesc {
case jsonNullType:
return structpb.NullValue_NULL_VALUE, nil
case nullReflectType:
return n, nil
}
case reflect.Ptr:
switch typeDesc {
case anyValueType:
@@ -54,6 +59,10 @@ func (n Null) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
return anypb.New(pb.(proto.Message))
case jsonValueType:
return structpb.NewNullValue(), nil
case boolWrapperType, byteWrapperType, doubleWrapperType, floatWrapperType,
int32WrapperType, int64WrapperType, stringWrapperType, uint32WrapperType,
uint64WrapperType:
return nil, nil
}
case reflect.Interface:
nv := n.Value()
@@ -86,12 +95,17 @@ func (n Null) Equal(other ref.Val) ref.Val {
return Bool(NullType == other.Type())
}
// IsZeroValue returns true as null always represents an absent value.
func (n Null) IsZeroValue() bool {
return true
}
// Type implements ref.Val.Type.
func (n Null) Type() ref.Type {
return NullType
}
// Value implements ref.Val.Value.
func (n Null) Value() interface{} {
func (n Null) Value() any {
return structpb.NullValue_NULL_VALUE
}

View File

@@ -18,20 +18,21 @@ import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
)
type protoObj struct {
ref.TypeAdapter
Adapter
value proto.Message
typeDesc *pb.TypeDescription
typeValue *TypeValue
typeValue ref.Val
}
// NewObject returns an object based on a proto.Message value which handles
@@ -41,18 +42,18 @@ type protoObj struct {
// Note: the type value is pulled from the list of registered types within the
// type provider. If the proto type is not registered within the type provider,
// then this will result in an error within the type adapter / provider.
func NewObject(adapter ref.TypeAdapter,
func NewObject(adapter Adapter,
typeDesc *pb.TypeDescription,
typeValue *TypeValue,
typeValue ref.Val,
value proto.Message) ref.Val {
return &protoObj{
TypeAdapter: adapter,
value: value,
typeDesc: typeDesc,
typeValue: typeValue}
Adapter: adapter,
value: value,
typeDesc: typeDesc,
typeValue: typeValue}
}
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
srcPB := o.value
if reflect.TypeOf(srcPB).AssignableTo(typeDesc) {
return srcPB, nil
@@ -133,6 +134,11 @@ func (o *protoObj) IsSet(field ref.Val) ref.Val {
return False
}
// IsZeroValue returns true if the protobuf object is empty.
func (o *protoObj) IsZeroValue() bool {
return proto.Equal(o.value, o.typeDesc.Zero())
}
func (o *protoObj) Get(index ref.Val) ref.Val {
protoFieldName, ok := index.(String)
if !ok {
@@ -151,9 +157,9 @@ func (o *protoObj) Get(index ref.Val) ref.Val {
}
func (o *protoObj) Type() ref.Type {
return o.typeValue
return o.typeValue.(ref.Type)
}
func (o *protoObj) Value() interface{} {
func (o *protoObj) Value() any {
return o.value
}

View File

@@ -0,0 +1,108 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"errors"
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
)
var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewOpaqueType("optional")
// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}
)
// OptionalOf returns an optional value which wraps a concrete CEL value.
func OptionalOf(value ref.Val) *Optional {
return &Optional{value: value}
}
// Optional value which points to a value if non-empty.
type Optional struct {
value ref.Val
}
// HasValue returns true if the optional has a value.
func (o *Optional) HasValue() bool {
return o.value != nil
}
// GetValue returns the wrapped value contained in the optional.
func (o *Optional) GetValue() ref.Val {
if !o.HasValue() {
return NewErr("optional.none() dereference")
}
return o.value
}
// ConvertToNative implements the ref.Val interface method.
func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !o.HasValue() {
return nil, errors.New("optional.none() dereference")
}
return o.value.ConvertToNative(typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *Optional) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case OptionalType:
return o
case TypeType:
return OptionalType
}
return NewErr("type conversion error from '%s' to '%s'", OptionalType, typeVal)
}
// Equal determines whether the values contained by two optional values are equal.
func (o *Optional) Equal(other ref.Val) ref.Val {
otherOpt, isOpt := other.(*Optional)
if !isOpt {
return False
}
if !o.HasValue() {
return Bool(!otherOpt.HasValue())
}
if !otherOpt.HasValue() {
return False
}
return o.value.Equal(otherOpt.value)
}
func (o *Optional) String() string {
if o.HasValue() {
return fmt.Sprintf("optional(%v)", o.GetValue())
}
return "optional.none()"
}
// Type implements the ref.Val interface method.
func (o *Optional) Type() ref.Type {
return OptionalType
}
// Value returns the underlying 'Value()' of the wrapped value, if present.
func (o *Optional) Value() any {
if o.value == nil {
return nil
}
return o.value.Value()
}

View File

@@ -17,7 +17,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/types/pb",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",

View File

@@ -18,9 +18,9 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)
// NewEnumValueDescription produces an enum value description with the fully qualified enum value
// newEnumValueDescription produces an enum value description with the fully qualified enum value
// name and the enum value descriptor.
func NewEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
func newEnumValueDescription(name string, desc protoreflect.EnumValueDescriptor) *EnumValueDescription {
return &EnumValueDescription{
enumValueName: name,
desc: desc,

View File

@@ -18,32 +18,66 @@ import (
"fmt"
"google.golang.org/protobuf/reflect/protoreflect"
dynamicpb "google.golang.org/protobuf/types/dynamicpb"
)
// NewFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values declared within any scope in the file.
func NewFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) *FileDescription {
// newFileDescription returns a FileDescription instance with a complete listing of all the message
// types and enum values, as well as a map of extensions declared within any scope in the file.
func newFileDescription(fileDesc protoreflect.FileDescriptor, pbdb *Db) (*FileDescription, extensionMap) {
metadata := collectFileMetadata(fileDesc)
enums := make(map[string]*EnumValueDescription)
for name, enumVal := range metadata.enumValues {
enums[name] = NewEnumValueDescription(name, enumVal)
enums[name] = newEnumValueDescription(name, enumVal)
}
types := make(map[string]*TypeDescription)
for name, msgType := range metadata.msgTypes {
types[name] = NewTypeDescription(name, msgType)
types[name] = newTypeDescription(name, msgType, pbdb.extensions)
}
fileExtMap := make(extensionMap)
for typeName, extensions := range metadata.msgExtensionMap {
messageExtMap, found := fileExtMap[typeName]
if !found {
messageExtMap = make(map[string]*FieldDescription)
}
for _, ext := range extensions {
extDesc := dynamicpb.NewExtensionType(ext).TypeDescriptor()
messageExtMap[string(ext.FullName())] = newFieldDescription(extDesc)
}
fileExtMap[typeName] = messageExtMap
}
return &FileDescription{
name: fileDesc.Path(),
types: types,
enums: enums,
}
}, fileExtMap
}
// FileDescription holds a map of all types and enum values declared within a proto file.
type FileDescription struct {
name string
types map[string]*TypeDescription
enums map[string]*EnumValueDescription
}
// Copy creates a copy of the FileDescription with updated Db references within its types.
func (fd *FileDescription) Copy(pbdb *Db) *FileDescription {
typesCopy := make(map[string]*TypeDescription, len(fd.types))
for k, v := range fd.types {
typesCopy[k] = v.Copy(pbdb)
}
return &FileDescription{
name: fd.name,
types: typesCopy,
enums: fd.enums,
}
}
// GetName returns the fully qualified file path for the file.
func (fd *FileDescription) GetName() string {
return fd.name
}
// GetEnumDescription returns an EnumDescription for a qualified enum value
// name declared within the .proto file.
func (fd *FileDescription) GetEnumDescription(enumName string) (*EnumValueDescription, bool) {
@@ -94,6 +128,10 @@ type fileMetadata struct {
msgTypes map[string]protoreflect.MessageDescriptor
// enumValues maps from fully-qualified enum value to enum value descriptor.
enumValues map[string]protoreflect.EnumValueDescriptor
// msgExtensionMap maps from the protobuf message name being extended to a set of extensions
// for the type.
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor
// TODO: support enum type definitions for use in future type-check enhancements.
}
@@ -102,28 +140,38 @@ type fileMetadata struct {
func collectFileMetadata(fileDesc protoreflect.FileDescriptor) *fileMetadata {
msgTypes := make(map[string]protoreflect.MessageDescriptor)
enumValues := make(map[string]protoreflect.EnumValueDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues)
msgExtensionMap := make(map[string][]protoreflect.ExtensionDescriptor)
collectMsgTypes(fileDesc.Messages(), msgTypes, enumValues, msgExtensionMap)
collectEnumValues(fileDesc.Enums(), enumValues)
collectExtensions(fileDesc.Extensions(), msgExtensionMap)
return &fileMetadata{
msgTypes: msgTypes,
enumValues: enumValues,
msgTypes: msgTypes,
enumValues: enumValues,
msgExtensionMap: msgExtensionMap,
}
}
// collectMsgTypes recursively collects messages, nested messages, and nested enums into a map of
// fully qualified protobuf names to descriptors.
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors, msgTypeMap map[string]protoreflect.MessageDescriptor, enumValueMap map[string]protoreflect.EnumValueDescriptor) {
func collectMsgTypes(msgTypes protoreflect.MessageDescriptors,
msgTypeMap map[string]protoreflect.MessageDescriptor,
enumValueMap map[string]protoreflect.EnumValueDescriptor,
msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < msgTypes.Len(); i++ {
msgType := msgTypes.Get(i)
msgTypeMap[string(msgType.FullName())] = msgType
nestedMsgTypes := msgType.Messages()
if nestedMsgTypes.Len() != 0 {
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap)
collectMsgTypes(nestedMsgTypes, msgTypeMap, enumValueMap, msgExtensionMap)
}
nestedEnumTypes := msgType.Enums()
if nestedEnumTypes.Len() != 0 {
collectEnumValues(nestedEnumTypes, enumValueMap)
}
nestedExtensions := msgType.Extensions()
if nestedExtensions.Len() != 0 {
collectExtensions(nestedExtensions, msgExtensionMap)
}
}
}
@@ -139,3 +187,16 @@ func collectEnumValues(enumTypes protoreflect.EnumDescriptors, enumValueMap map[
}
}
}
func collectExtensions(extensions protoreflect.ExtensionDescriptors, msgExtensionMap map[string][]protoreflect.ExtensionDescriptor) {
for i := 0; i < extensions.Len(); i++ {
ext := extensions.Get(i)
extendsMsg := string(ext.ContainingMessage().FullName())
msgExts, found := msgExtensionMap[extendsMsg]
if !found {
msgExts = []protoreflect.ExtensionDescriptor{}
}
msgExts = append(msgExts, ext)
msgExtensionMap[extendsMsg] = msgExts
}
}

View File

@@ -40,13 +40,19 @@ type Db struct {
revFileDescriptorMap map[string]*FileDescription
// files contains the deduped set of FileDescriptions whose types are contained in the pb.Db.
files []*FileDescription
// extensions contains the mapping between a given type name, extension name and its FieldDescription
extensions map[string]map[string]*FieldDescription
}
// extensionsMap is a type alias to a map[typeName]map[extensionName]*FieldDescription
type extensionMap = map[string]map[string]*FieldDescription
var (
// DefaultDb used at evaluation time or unless overridden at check time.
DefaultDb = &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
)
@@ -80,6 +86,7 @@ func NewDb() *Db {
pbdb := &Db{
revFileDescriptorMap: make(map[string]*FileDescription),
files: []*FileDescription{},
extensions: make(extensionMap),
}
// The FileDescription objects in the default db contain lazily initialized TypeDescription
// values which may point to the state contained in the DefaultDb irrespective of this shallow
@@ -96,19 +103,34 @@ func NewDb() *Db {
// Copy creates a copy of the current database with its own internal descriptor mapping.
func (pbdb *Db) Copy() *Db {
copy := NewDb()
for k, v := range pbdb.revFileDescriptorMap {
copy.revFileDescriptorMap[k] = v
}
for _, f := range pbdb.files {
for _, fd := range pbdb.files {
hasFile := false
for _, f2 := range copy.files {
if f2 == f {
for _, fd2 := range copy.files {
if fd2 == fd {
hasFile = true
}
}
if !hasFile {
copy.files = append(copy.files, f)
fd = fd.Copy(copy)
copy.files = append(copy.files, fd)
}
for _, enumValName := range fd.GetEnumNames() {
copy.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
copy.revFileDescriptorMap[msgTypeName] = fd
}
copy.revFileDescriptorMap[fd.GetName()] = fd
}
for typeName, extFieldMap := range pbdb.extensions {
copyExtFieldMap, found := copy.extensions[typeName]
if !found {
copyExtFieldMap = make(map[string]*FieldDescription, len(extFieldMap))
}
for extFieldName, fd := range extFieldMap {
copyExtFieldMap[extFieldName] = fd
}
copy.extensions[typeName] = copyExtFieldMap
}
return copy
}
@@ -137,17 +159,30 @@ func (pbdb *Db) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) (*FileD
if err == nil {
fileDesc = globalFD
}
fd = NewFileDescription(fileDesc, pbdb)
var fileExtMap extensionMap
fd, fileExtMap = newFileDescription(fileDesc, pbdb)
for _, enumValName := range fd.GetEnumNames() {
pbdb.revFileDescriptorMap[enumValName] = fd
}
for _, msgTypeName := range fd.GetTypeNames() {
pbdb.revFileDescriptorMap[msgTypeName] = fd
}
pbdb.revFileDescriptorMap[fileDesc.Path()] = fd
pbdb.revFileDescriptorMap[fd.GetName()] = fd
// Return the specific file descriptor registered.
pbdb.files = append(pbdb.files, fd)
// Index the protobuf message extensions from the file into the pbdb
for typeName, extMap := range fileExtMap {
typeExtMap, found := pbdb.extensions[typeName]
if !found {
pbdb.extensions[typeName] = extMap
continue
}
for extName, field := range extMap {
typeExtMap[extName] = field
}
}
return fd, nil
}

View File

@@ -38,22 +38,23 @@ type description interface {
Zero() proto.Message
}
// NewTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// newTypeDescription produces a TypeDescription value for the fully-qualified proto type name
// with a given descriptor.
func NewTypeDescription(typeName string, desc protoreflect.MessageDescriptor) *TypeDescription {
func newTypeDescription(typeName string, desc protoreflect.MessageDescriptor, extensions extensionMap) *TypeDescription {
msgType := dynamicpb.NewMessageType(desc)
msgZero := dynamicpb.NewMessage(desc)
fieldMap := map[string]*FieldDescription{}
fields := desc.Fields()
for i := 0; i < fields.Len(); i++ {
f := fields.Get(i)
fieldMap[string(f.Name())] = NewFieldDescription(f)
fieldMap[string(f.Name())] = newFieldDescription(f)
}
return &TypeDescription{
typeName: typeName,
desc: desc,
msgType: msgType,
fieldMap: fieldMap,
extensions: extensions,
reflectType: reflectTypeOf(msgZero),
zeroMsg: zeroValueOf(msgZero),
}
@@ -66,10 +67,24 @@ type TypeDescription struct {
desc protoreflect.MessageDescriptor
msgType protoreflect.MessageType
fieldMap map[string]*FieldDescription
extensions extensionMap
reflectType reflect.Type
zeroMsg proto.Message
}
// Copy copies the type description with updated references to the Db.
func (td *TypeDescription) Copy(pbdb *Db) *TypeDescription {
return &TypeDescription{
typeName: td.typeName,
desc: td.desc,
msgType: td.msgType,
fieldMap: td.fieldMap,
extensions: pbdb.extensions,
reflectType: td.reflectType,
zeroMsg: td.zeroMsg,
}
}
// FieldMap returns a string field name to FieldDescription map.
func (td *TypeDescription) FieldMap() map[string]*FieldDescription {
return td.fieldMap
@@ -78,16 +93,21 @@ func (td *TypeDescription) FieldMap() map[string]*FieldDescription {
// FieldByName returns (FieldDescription, true) if the field name is declared within the type.
func (td *TypeDescription) FieldByName(name string) (*FieldDescription, bool) {
fd, found := td.fieldMap[name]
if found {
return fd, true
}
extFieldMap, found := td.extensions[td.typeName]
if !found {
return nil, false
}
return fd, true
fd, found = extFieldMap[name]
return fd, found
}
// MaybeUnwrap accepts a proto message as input and unwraps it to a primitive CEL type if possible.
//
// This method returns the unwrapped value and 'true', else the original value and 'false'.
func (td *TypeDescription) MaybeUnwrap(msg proto.Message) (interface{}, bool, error) {
func (td *TypeDescription) MaybeUnwrap(msg proto.Message) (any, bool, error) {
return unwrap(td, msg)
}
@@ -111,8 +131,8 @@ func (td *TypeDescription) Zero() proto.Message {
return td.zeroMsg
}
// NewFieldDescription creates a new field description from a protoreflect.FieldDescriptor.
func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription {
// newFieldDescription creates a new field description from a protoreflect.FieldDescriptor.
func newFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescription {
var reflectType reflect.Type
var zeroMsg proto.Message
switch fieldDesc.Kind() {
@@ -124,9 +144,17 @@ func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti
default:
reflectType = reflectTypeOf(fieldDesc.Default().Interface())
if fieldDesc.IsList() {
parentMsg := dynamicpb.NewMessage(fieldDesc.ContainingMessage())
listField := parentMsg.NewField(fieldDesc).List()
elem := listField.NewElement().Interface()
var elemValue protoreflect.Value
if fieldDesc.IsExtension() {
et := dynamicpb.NewExtensionType(fieldDesc)
elemValue = et.New().List().NewElement()
} else {
parentMsgType := fieldDesc.ContainingMessage()
parentMsg := dynamicpb.NewMessage(parentMsgType)
listField := parentMsg.NewField(fieldDesc).List()
elemValue = listField.NewElement()
}
elem := elemValue.Interface()
switch elemType := elem.(type) {
case protoreflect.Message:
elem = elemType.Interface()
@@ -140,8 +168,8 @@ func NewFieldDescription(fieldDesc protoreflect.FieldDescriptor) *FieldDescripti
}
var keyType, valType *FieldDescription
if fieldDesc.IsMap() {
keyType = NewFieldDescription(fieldDesc.MapKey())
valType = NewFieldDescription(fieldDesc.MapValue())
keyType = newFieldDescription(fieldDesc.MapKey())
valType = newFieldDescription(fieldDesc.MapValue())
}
return &FieldDescription{
desc: fieldDesc,
@@ -195,7 +223,7 @@ func (fd *FieldDescription) Descriptor() protoreflect.FieldDescriptor {
//
// This function implements the FieldType.IsSet function contract which can be used to operate on
// more than just protobuf field accesses; however, the target here must be a protobuf.Message.
func (fd *FieldDescription) IsSet(target interface{}) bool {
func (fd *FieldDescription) IsSet(target any) bool {
switch v := target.(type) {
case proto.Message:
pbRef := v.ProtoReflect()
@@ -219,14 +247,14 @@ func (fd *FieldDescription) IsSet(target interface{}) bool {
//
// This function implements the FieldType.GetFrom function contract which can be used to operate
// on more than just protobuf field accesses; however, the target here must be a protobuf.Message.
func (fd *FieldDescription) GetFrom(target interface{}) (interface{}, error) {
func (fd *FieldDescription) GetFrom(target any) (any, error) {
v, ok := target.(proto.Message)
if !ok {
return nil, fmt.Errorf("unsupported field selection target: (%T)%v", target, target)
}
pbRef := v.ProtoReflect()
pbDesc := pbRef.Descriptor()
var fieldVal interface{}
var fieldVal any
if pbDesc == fd.desc.ContainingMessage() {
// When the target protobuf shares the same message descriptor instance as the field
// descriptor, use the cached field descriptor value.
@@ -257,7 +285,7 @@ func (fd *FieldDescription) GetFrom(target interface{}) (interface{}, error) {
// IsEnum returns true if the field type refers to an enum value.
func (fd *FieldDescription) IsEnum() bool {
return fd.desc.Kind() == protoreflect.EnumKind
return fd.ProtoKind() == protoreflect.EnumKind
}
// IsMap returns true if the field is of map type.
@@ -267,7 +295,7 @@ func (fd *FieldDescription) IsMap() bool {
// IsMessage returns true if the field is of message type.
func (fd *FieldDescription) IsMessage() bool {
kind := fd.desc.Kind()
kind := fd.ProtoKind()
return kind == protoreflect.MessageKind || kind == protoreflect.GroupKind
}
@@ -289,7 +317,7 @@ func (fd *FieldDescription) IsList() bool {
//
// This function returns the unwrapped value and 'true' on success, or the original value
// and 'false' otherwise.
func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (interface{}, bool, error) {
func (fd *FieldDescription) MaybeUnwrapDynamic(msg protoreflect.Message) (any, bool, error) {
return unwrapDynamic(fd, msg)
}
@@ -298,6 +326,11 @@ func (fd *FieldDescription) Name() string {
return string(fd.desc.Name())
}
// ProtoKind returns the protobuf reflected kind of the field.
func (fd *FieldDescription) ProtoKind() protoreflect.Kind {
return fd.desc.Kind()
}
// ReflectType returns the Golang reflect.Type for this field.
func (fd *FieldDescription) ReflectType() reflect.Type {
return fd.reflectType
@@ -317,17 +350,17 @@ func (fd *FieldDescription) Zero() proto.Message {
}
func (fd *FieldDescription) typeDefToType() *exprpb.Type {
if fd.desc.Kind() == protoreflect.MessageKind || fd.desc.Kind() == protoreflect.GroupKind {
if fd.IsMessage() {
msgType := string(fd.desc.Message().FullName())
if wk, found := CheckedWellKnowns[msgType]; found {
return wk
}
return checkedMessageType(msgType)
}
if fd.desc.Kind() == protoreflect.EnumKind {
if fd.IsEnum() {
return checkedInt
}
return CheckedPrimitives[fd.desc.Kind()]
return CheckedPrimitives[fd.ProtoKind()]
}
// Map wraps the protoreflect.Map object with a key and value FieldDescription for use in
@@ -362,7 +395,7 @@ func checkedWrap(t *exprpb.Type) *exprpb.Type {
// input message is a *dynamicpb.Message which obscures the typing information from Go.
//
// Returns the unwrapped value and 'true' if unwrapped, otherwise the input value and 'false'.
func unwrap(desc description, msg proto.Message) (interface{}, bool, error) {
func unwrap(desc description, msg proto.Message) (any, bool, error) {
switch v := msg.(type) {
case *anypb.Any:
dynMsg, err := v.UnmarshalNew()
@@ -418,7 +451,7 @@ func unwrap(desc description, msg proto.Message) (interface{}, bool, error) {
// unwrapDynamic unwraps a reflected protobuf Message value.
//
// Returns the unwrapped value and 'true' if unwrapped, otherwise the input value and 'false'.
func unwrapDynamic(desc description, refMsg protoreflect.Message) (interface{}, bool, error) {
func unwrapDynamic(desc description, refMsg protoreflect.Message) (any, bool, error) {
msg := refMsg.Interface()
if !refMsg.IsValid() {
msg = desc.Zero()
@@ -435,13 +468,13 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (interface{},
unwrappedAny := &anypb.Any{}
err := Merge(unwrappedAny, msg)
if err != nil {
return nil, false, err
return nil, false, fmt.Errorf("unwrap dynamic field failed: %v", err)
}
dynMsg, err := unwrappedAny.UnmarshalNew()
if err != nil {
// Allow the error to move further up the stack as it should result in an type
// conversion error if the caller does not recover it somehow.
return nil, false, err
return nil, false, fmt.Errorf("unmarshal dynamic any failed: %v", err)
}
// Attempt to unwrap the dynamic type, otherwise return the dynamic message.
unwrapped, nested, err := unwrapDynamic(desc, dynMsg.ProtoReflect())
@@ -508,7 +541,7 @@ func unwrapDynamic(desc description, refMsg protoreflect.Message) (interface{},
// reflectTypeOf intercepts the reflect.Type call to ensure that dynamicpb.Message types preserve
// well-known protobuf reflected types expected by the CEL type system.
func reflectTypeOf(val interface{}) reflect.Type {
func reflectTypeOf(val any) reflect.Type {
switch v := val.(type) {
case proto.Message:
return reflect.TypeOf(zeroValueOf(v))
@@ -532,8 +565,10 @@ func zeroValueOf(msg proto.Message) proto.Message {
}
var (
jsonValueTypeURL = "types.googleapis.com/google.protobuf.Value"
zeroValueMap = map[string]proto.Message{
"google.protobuf.Any": &anypb.Any{},
"google.protobuf.Any": &anypb.Any{TypeUrl: jsonValueTypeURL},
"google.protobuf.Duration": &dpb.Duration{},
"google.protobuf.ListValue": &structpb.ListValue{},
"google.protobuf.Struct": &structpb.Struct{},

View File

@@ -19,11 +19,12 @@ import (
"reflect"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
@@ -32,17 +33,64 @@ import (
tpb "google.golang.org/protobuf/types/known/timestamppb"
)
type protoTypeRegistry struct {
revTypeMap map[string]ref.Type
// Adapter converts native Go values of varying type and complexity to equivalent CEL values.
type Adapter = ref.TypeAdapter
// Provider specifies functions for creating new object instances and for resolving
// enum values by name.
type Provider interface {
// EnumValue returns the numeric value of the given enum value name.
EnumValue(enumName string) ref.Val
// FindIdent takes a qualified identifier name and returns a ref.Val if one exists.
FindIdent(identName string) (ref.Val, bool)
// FindStructType returns the Type give a qualified type name.
//
// For historical reasons, only struct types are expected to be returned through this
// method, and the type values are expected to be wrapped in a TypeType instance using
// TypeTypeWithParam(<structType>).
//
// Returns false if not found.
FindStructType(structType string) (*Type, bool)
// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
FindStructFieldType(structType, fieldName string) (*FieldType, bool)
// NewValue creates a new type value from a qualified name and map of field
// name to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
NewValue(structType string, fields map[string]ref.Val) ref.Val
}
// FieldType represents a field's type value and whether that field supports presence detection.
type FieldType struct {
// Type of the field as a CEL native type value.
Type *Type
// IsSet indicates whether the field is set on an input object.
IsSet ref.FieldTester
// GetFrom retrieves the field value on the input object, if set.
GetFrom ref.FieldGetter
}
// Registry provides type information for a set of registered types.
type Registry struct {
revTypeMap map[string]*Type
pbdb *pb.Db
}
// NewRegistry accepts a list of proto message instances and returns a type
// provider which can create new instances of the provided message or any
// message that proto depends upon in its FileDescriptor.
func NewRegistry(types ...proto.Message) (ref.TypeRegistry, error) {
p := &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
func NewRegistry(types ...proto.Message) (*Registry, error) {
p := &Registry{
revTypeMap: make(map[string]*Type),
pbdb: pb.NewDb(),
}
err := p.RegisterType(
@@ -78,18 +126,17 @@ func NewRegistry(types ...proto.Message) (ref.TypeRegistry, error) {
}
// NewEmptyRegistry returns a registry which is completely unconfigured.
func NewEmptyRegistry() ref.TypeRegistry {
return &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
func NewEmptyRegistry() *Registry {
return &Registry{
revTypeMap: make(map[string]*Type),
pbdb: pb.NewDb(),
}
}
// Copy implements the ref.TypeRegistry interface method which copies the current state of the
// registry into its own memory space.
func (p *protoTypeRegistry) Copy() ref.TypeRegistry {
copy := &protoTypeRegistry{
revTypeMap: make(map[string]ref.Type),
// Copy copies the current state of the registry into its own memory space.
func (p *Registry) Copy() *Registry {
copy := &Registry{
revTypeMap: make(map[string]*Type),
pbdb: p.pbdb.Copy(),
}
for k, v := range p.revTypeMap {
@@ -98,7 +145,8 @@ func (p *protoTypeRegistry) Copy() ref.TypeRegistry {
return copy
}
func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val {
// EnumValue returns the numeric value of the given enum value name.
func (p *Registry) EnumValue(enumName string) ref.Val {
enumVal, found := p.pbdb.DescribeEnum(enumName)
if !found {
return NewErr("unknown enum name '%s'", enumName)
@@ -106,9 +154,12 @@ func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val {
return Int(enumVal.Value())
}
func (p *protoTypeRegistry) FindFieldType(messageType string,
fieldName string) (*ref.FieldType, bool) {
msgType, found := p.pbdb.DescribeType(messageType)
// FieldFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
//
// Deprecated: use FindStructFieldType
func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return nil, false
}
@@ -117,15 +168,32 @@ func (p *protoTypeRegistry) FindFieldType(messageType string,
return nil, false
}
return &ref.FieldType{
Type: field.CheckedType(),
IsSet: field.IsSet,
GetFrom: field.GetFrom},
true
Type: field.CheckedType(),
IsSet: field.IsSet,
GetFrom: field.GetFrom}, true
}
func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) {
// FieldStructFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) {
msgType, found := p.pbdb.DescribeType(structType)
if !found {
return nil, false
}
field, found := msgType.FieldByName(fieldName)
if !found {
return nil, false
}
return &FieldType{
Type: fieldDescToCELType(field),
IsSet: field.IsSet,
GetFrom: field.GetFrom}, true
}
// FindIdent takes a qualified identifier name and returns a ref.Val if one exists.
func (p *Registry) FindIdent(identName string) (ref.Val, bool) {
if t, found := p.revTypeMap[identName]; found {
return t.(ref.Val), true
return t, true
}
if enumVal, found := p.pbdb.DescribeEnum(identName); found {
return Int(enumVal.Value()), true
@@ -133,24 +201,50 @@ func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) {
return nil, false
}
func (p *protoTypeRegistry) FindType(typeName string) (*exprpb.Type, bool) {
if _, found := p.pbdb.DescribeType(typeName); !found {
// FindType looks up the Type given a qualified typeName. Returns false if not found.
//
// Deprecated: use FindStructType
func (p *Registry) FindType(structType string) (*exprpb.Type, bool) {
if _, found := p.pbdb.DescribeType(structType); !found {
return nil, false
}
if typeName != "" && typeName[0] == '.' {
typeName = typeName[1:]
if structType != "" && structType[0] == '.' {
structType = structType[1:]
}
return &exprpb.Type{
TypeKind: &exprpb.Type_Type{
Type: &exprpb.Type{
TypeKind: &exprpb.Type_MessageType{
MessageType: typeName}}}}, true
MessageType: structType}}}}, true
}
func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val) ref.Val {
td, found := p.pbdb.DescribeType(typeName)
// FindStructType returns the Type give a qualified type name.
//
// For historical reasons, only struct types are expected to be returned through this
// method, and the type values are expected to be wrapped in a TypeType instance using
// TypeTypeWithParam(<structType>).
//
// Returns false if not found.
func (p *Registry) FindStructType(structType string) (*Type, bool) {
if _, found := p.pbdb.DescribeType(structType); !found {
return nil, false
}
if structType != "" && structType[0] == '.' {
structType = structType[1:]
}
return NewTypeTypeWithParam(NewObjectType(structType)), true
}
// NewValue creates a new type value from a qualified name and map of field
// name to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
func (p *Registry) NewValue(structType string, fields map[string]ref.Val) ref.Val {
td, found := p.pbdb.DescribeType(structType)
if !found {
return NewErr("unknown type '%s'", typeName)
return NewErr("unknown type '%s'", structType)
}
msg := td.New()
fieldMap := td.FieldMap()
@@ -167,7 +261,8 @@ func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val)
return p.NativeToValue(msg.Interface())
}
func (p *protoTypeRegistry) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) error {
// RegisterDescriptor registers the contents of a protocol buffer `FileDescriptor`.
func (p *Registry) RegisterDescriptor(fileDesc protoreflect.FileDescriptor) error {
fd, err := p.pbdb.RegisterDescriptor(fileDesc)
if err != nil {
return err
@@ -175,7 +270,8 @@ func (p *protoTypeRegistry) RegisterDescriptor(fileDesc protoreflect.FileDescrip
return p.registerAllTypes(fd)
}
func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error {
// RegisterMessage registers a protocol buffer message and its dependencies.
func (p *Registry) RegisterMessage(message proto.Message) error {
fd, err := p.pbdb.RegisterMessage(message)
if err != nil {
return err
@@ -183,11 +279,32 @@ func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error {
return p.registerAllTypes(fd)
}
func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
// RegisterType registers a type value with the provider which ensures the provider is aware of how to
// map the type to an identifier.
//
// If the `ref.Type` value is a `*types.Type` it will be registered directly by its runtime type name.
// If the `ref.Type` value is not a `*types.Type` instance, a `*types.Type` instance which reflects the
// traits present on the input and the runtime type name. By default this foreign type will be treated
// as a types.StructKind. To avoid potential issues where the `ref.Type` values does not match the
// generated `*types.Type` instance, consider always using the `*types.Type` to represent type extensions
// to CEL, even when they're not based on protobuf types.
func (p *Registry) RegisterType(types ...ref.Type) error {
for _, t := range types {
p.revTypeMap[t.TypeName()] = t
celType := maybeForeignType(t)
existing, found := p.revTypeMap[t.TypeName()]
if !found {
p.revTypeMap[t.TypeName()] = celType
continue
}
if !existing.IsEquivalentType(celType) {
return fmt.Errorf("type registration conflict. found: %v, input: %v", existing, celType)
}
if existing.traitMask != celType.traitMask {
return fmt.Errorf(
"type registered with conflicting traits: %v with traits %v, input: %v",
existing.TypeName(), existing.traitMask, celType.traitMask)
}
}
// TODO: generate an error when the type name is registered more than once.
return nil
}
@@ -195,7 +312,7 @@ func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
// providing support for custom proto-based types.
//
// This method should be the inverse of ref.Val.ConvertToNative.
func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val {
func (p *Registry) NativeToValue(value any) ref.Val {
if val, found := nativeToValue(p, value); found {
return val
}
@@ -217,7 +334,7 @@ func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val {
if !found {
return NewErr("unknown type: '%s'", typeName)
}
return NewObject(p, td, typeVal.(*TypeValue), v)
return NewObject(p, td, typeVal, v)
case *pb.Map:
return NewProtoMap(p, v)
case protoreflect.List:
@@ -230,8 +347,13 @@ func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val {
return UnsupportedRefValConversionErr(value)
}
func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error {
func (p *Registry) registerAllTypes(fd *pb.FileDescription) error {
for _, typeName := range fd.GetTypeNames() {
// skip well-known type names since they're automatically sanitized
// during NewObjectType() calls.
if _, found := checkedWellKnowns[typeName]; found {
continue
}
err := p.RegisterType(NewObjectTypeValue(typeName))
if err != nil {
return err
@@ -240,6 +362,28 @@ func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error {
return nil
}
func fieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsMap() {
return NewMapType(
singularFieldDescToCELType(field.KeyType),
singularFieldDescToCELType(field.ValueType))
}
if field.IsList() {
return NewListType(singularFieldDescToCELType(field))
}
return singularFieldDescToCELType(field)
}
func singularFieldDescToCELType(field *pb.FieldDescription) *Type {
if field.IsMessage() {
return NewObjectType(string(field.Descriptor().Message().FullName()))
}
if field.IsEnum() {
return IntType
}
return ProtoCELPrimitives[field.ProtoKind()]
}
// defaultTypeAdapter converts go native types to CEL values.
type defaultTypeAdapter struct{}
@@ -249,7 +393,7 @@ var (
)
// NativeToValue implements the ref.TypeAdapter interface.
func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val {
func (a *defaultTypeAdapter) NativeToValue(value any) ref.Val {
if val, found := nativeToValue(a, value); found {
return val
}
@@ -258,7 +402,7 @@ func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val {
// nativeToValue returns the converted (ref.Val, true) of a conversion is found,
// otherwise (nil, false)
func nativeToValue(a ref.TypeAdapter, value interface{}) (ref.Val, bool) {
func nativeToValue(a Adapter, value any) (ref.Val, bool) {
switch v := value.(type) {
case nil:
return NullValue, true
@@ -364,7 +508,7 @@ func nativeToValue(a ref.TypeAdapter, value interface{}) (ref.Val, bool) {
// specializations for common map types.
case map[string]string:
return NewStringStringMap(a, v), true
case map[string]interface{}:
case map[string]any:
return NewStringInterfaceMap(a, v), true
case map[ref.Val]ref.Val:
return NewRefValMap(a, v), true
@@ -479,9 +623,12 @@ func msgSetField(target protoreflect.Message, field *pb.FieldDescription, val re
if err != nil {
return fieldTypeConversionError(field, err)
}
switch v.(type) {
if v == nil {
return nil
}
switch pv := v.(type) {
case proto.Message:
v = v.(proto.Message).ProtoReflect()
v = pv.ProtoReflect()
}
target.Set(field.Descriptor(), protoreflect.ValueOf(v))
return nil
@@ -495,6 +642,9 @@ func msgSetListField(target protoreflect.List, listField *pb.FieldDescription, l
if err != nil {
return fieldTypeConversionError(listField, err)
}
if elemVal == nil {
continue
}
switch ev := elemVal.(type) {
case proto.Message:
elemVal = ev.ProtoReflect()
@@ -519,9 +669,12 @@ func msgSetMapField(target protoreflect.Map, mapField *pb.FieldDescription, mapV
if err != nil {
return fieldTypeConversionError(mapField, err)
}
switch v.(type) {
if v == nil {
continue
}
switch pv := v.(type) {
case proto.Message:
v = v.(proto.Message).ProtoReflect()
v = pv.ProtoReflect()
}
target.Set(protoreflect.ValueOf(k).MapKey(), protoreflect.ValueOf(v))
}
@@ -537,3 +690,24 @@ func fieldTypeConversionError(field *pb.FieldDescription, err error) error {
msgName := field.Descriptor().ContainingMessage().FullName()
return fmt.Errorf("field type conversion error for %v.%v value type: %v", msgName, field.Name(), err)
}
var (
// ProtoCELPrimitives provides a map from the protoreflect Kind to the equivalent CEL type.
ProtoCELPrimitives = map[protoreflect.Kind]*Type{
protoreflect.BoolKind: BoolType,
protoreflect.BytesKind: BytesType,
protoreflect.DoubleKind: DoubleType,
protoreflect.FloatKind: DoubleType,
protoreflect.Int32Kind: IntType,
protoreflect.Int64Kind: IntType,
protoreflect.Sint32Kind: IntType,
protoreflect.Sint64Kind: IntType,
protoreflect.Uint32Kind: UintType,
protoreflect.Uint64Kind: UintType,
protoreflect.Fixed32Kind: UintType,
protoreflect.Fixed64Kind: UintType,
protoreflect.Sfixed32Kind: IntType,
protoreflect.Sfixed64Kind: IntType,
protoreflect.StringKind: StringType,
}
)

View File

@@ -13,7 +13,7 @@ go_library(
],
importpath = "github.com/google/cel-go/common/types/ref",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
],

View File

@@ -23,45 +23,45 @@ import (
// TypeProvider specifies functions for creating new object instances and for
// resolving enum values by name.
//
// Deprecated: use types.Provider
type TypeProvider interface {
// EnumValue returns the numeric value of the given enum value name.
EnumValue(enumName string) Val
// FindIdent takes a qualified identifier name and returns a Value if one
// exists.
// FindIdent takes a qualified identifier name and returns a Value if one exists.
FindIdent(identName string) (Val, bool)
// FindType looks up the Type given a qualified typeName. Returns false
// if not found.
//
// Used during type-checking only.
// FindType looks up the Type given a qualified typeName. Returns false if not found.
FindType(typeName string) (*exprpb.Type, bool)
// FieldFieldType returns the field type for a checked type value. Returns
// false if the field could not be found.
//
// Used during type-checking only.
FindFieldType(messageType string, fieldName string) (*FieldType, bool)
// FieldFieldType returns the field type for a checked type value. Returns false if
// the field could not be found.
FindFieldType(messageType, fieldName string) (*FieldType, bool)
// NewValue creates a new type value from a qualified name and map of field
// name to value.
// NewValue creates a new type value from a qualified name and map of field name
// to value.
//
// Note, for each value, the Val.ConvertToNative function will be invoked
// to convert the Val to the field's native type. If an error occurs during
// conversion, the NewValue will be a types.Err.
// Note, for each value, the Val.ConvertToNative function will be invoked to convert
// the Val to the field's native type. If an error occurs during conversion, the
// NewValue will be a types.Err.
NewValue(typeName string, fields map[string]Val) Val
}
// TypeAdapter converts native Go values of varying type and complexity to equivalent CEL values.
//
// Deprecated: use types.Adapter
type TypeAdapter interface {
// NativeToValue converts the input `value` to a CEL `ref.Val`.
NativeToValue(value interface{}) Val
NativeToValue(value any) Val
}
// TypeRegistry allows third-parties to add custom types to CEL. Not all `TypeProvider`
// implementations support type-customization, so these features are optional. However, a
// `TypeRegistry` should be a `TypeProvider` and a `TypeAdapter` to ensure that types
// which are registered can be converted to CEL representations.
//
// Deprecated: use types.Registry
type TypeRegistry interface {
TypeAdapter
TypeProvider
@@ -78,15 +78,14 @@ type TypeRegistry interface {
// If a type is provided more than once with an alternative definition, the
// call will result in an error.
RegisterType(types ...Type) error
// Copy the TypeRegistry and return a new registry whose mutable state is isolated.
Copy() TypeRegistry
}
// FieldType represents a field's type value and whether that field supports
// presence detection.
//
// Deprecated: use types.FieldType
type FieldType struct {
// Type of the field.
// Type of the field as a protobuf type value.
Type *exprpb.Type
// IsSet indicates whether the field is set on an input object.
@@ -97,7 +96,7 @@ type FieldType struct {
}
// FieldTester is used to test field presence on an input object.
type FieldTester func(target interface{}) bool
type FieldTester func(target any) bool
// FieldGetter is used to get the field value from an input object, if set.
type FieldGetter func(target interface{}) (interface{}, error)
type FieldGetter func(target any) (any, error)

View File

@@ -37,9 +37,18 @@ type Type interface {
type Val interface {
// ConvertToNative converts the Value to a native Go struct according to the
// reflected type description, or error if the conversion is not feasible.
ConvertToNative(typeDesc reflect.Type) (interface{}, error)
//
// The ConvertToNative method is intended to be used to support conversion between CEL types
// and native types during object creation expressions or by clients who need to adapt the,
// returned CEL value into an equivalent Go value instance.
//
// When implementing or using ConvertToNative, the following guidelines apply:
// - Use ConvertToNative when marshalling CEL evaluation results to native types.
// - Do not use ConvertToNative within CEL extension functions.
// - Document whether your implementation supports non-CEL field types, such as Go or Protobuf.
ConvertToNative(typeDesc reflect.Type) (any, error)
// ConvertToType supports type conversions between value types supported by the expression language.
// ConvertToType supports type conversions between CEL value types supported by the expression language.
ConvertToType(typeValue Type) Val
// Equal returns true if the `other` value has the same type and content as the implementing struct.
@@ -50,5 +59,5 @@ type Val interface {
// Value returns the raw value of the instance which may not be directly compatible with the expression
// language types.
Value() interface{}
Value() any
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -36,18 +35,10 @@ import (
type String string
var (
// StringType singleton.
StringType = NewTypeValue("string",
traits.AdderType,
traits.ComparerType,
traits.MatcherType,
traits.ReceiverType,
traits.SizerType)
stringOneArgOverloads = map[string]func(String, ref.Val) ref.Val{
overloads.Contains: stringContains,
overloads.EndsWith: stringEndsWith,
overloads.StartsWith: stringStartsWith,
stringOneArgOverloads = map[string]func(ref.Val, ref.Val) ref.Val{
overloads.Contains: StringContains,
overloads.EndsWith: StringEndsWith,
overloads.StartsWith: StringStartsWith,
}
stringWrapperType = reflect.TypeOf(&wrapperspb.StringValue{})
@@ -72,7 +63,7 @@ func (s String) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (s String) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.String:
if reflect.TypeOf(s).AssignableTo(typeDesc) {
@@ -154,6 +145,11 @@ func (s String) Equal(other ref.Val) ref.Val {
return Bool(ok && s == otherString)
}
// IsZeroValue returns true if the string is empty.
func (s String) IsZeroValue() bool {
return len(s) == 0
}
// Match implements traits.Matcher.Match.
func (s String) Match(pattern ref.Val) ref.Val {
pat, ok := pattern.(String)
@@ -189,30 +185,45 @@ func (s String) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (s String) Value() interface{} {
func (s String) Value() any {
return string(s)
}
func stringContains(s String, sub ref.Val) ref.Val {
// StringContains returns whether the string contains a substring.
func StringContains(s, sub ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
subStr, ok := sub.(String)
if !ok {
return MaybeNoSuchOverloadErr(sub)
}
return Bool(strings.Contains(string(s), string(subStr)))
return Bool(strings.Contains(string(str), string(subStr)))
}
func stringEndsWith(s String, suf ref.Val) ref.Val {
// StringEndsWith returns whether the target string contains the input suffix.
func StringEndsWith(s, suf ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
sufStr, ok := suf.(String)
if !ok {
return MaybeNoSuchOverloadErr(suf)
}
return Bool(strings.HasSuffix(string(s), string(sufStr)))
return Bool(strings.HasSuffix(string(str), string(sufStr)))
}
func stringStartsWith(s String, pre ref.Val) ref.Val {
// StringStartsWith returns whether the target string contains the input prefix.
func StringStartsWith(s, pre ref.Val) ref.Val {
str, ok := s.(String)
if !ok {
return MaybeNoSuchOverloadErr(s)
}
preStr, ok := pre.(String)
if !ok {
return MaybeNoSuchOverloadErr(pre)
}
return Bool(strings.HasPrefix(string(s), string(preStr)))
return Bool(strings.HasPrefix(string(str), string(preStr)))
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -53,15 +52,6 @@ const (
maxUnixTime int64 = 253402300799
)
var (
// TimestampType singleton.
TimestampType = NewTypeValue("google.protobuf.Timestamp",
traits.AdderType,
traits.ComparerType,
traits.ReceiverType,
traits.SubtractorType)
)
// Add implements traits.Adder.Add.
func (t Timestamp) Add(other ref.Val) ref.Val {
switch other.Type() {
@@ -89,7 +79,7 @@ func (t Timestamp) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t Timestamp) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (t Timestamp) ConvertToNative(typeDesc reflect.Type) (any, error) {
// If the timestamp is already assignable to the desired type return it.
if reflect.TypeOf(t.Time).AssignableTo(typeDesc) {
return t.Time, nil
@@ -138,6 +128,11 @@ func (t Timestamp) Equal(other ref.Val) ref.Val {
return Bool(ok && t.Time.Equal(otherTime.Time))
}
// IsZeroValue returns true if the timestamp is epoch 0.
func (t Timestamp) IsZeroValue() bool {
return t.IsZero()
}
// Receive implements traits.Receiver.Receive.
func (t Timestamp) Receive(function string, overload string, args []ref.Val) ref.Val {
switch len(args) {
@@ -160,14 +155,14 @@ func (t Timestamp) Subtract(subtrahend ref.Val) ref.Val {
dur := subtrahend.(Duration)
val, err := subtractTimeDurationChecked(t.Time, dur.Duration)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return timestampOf(val)
case TimestampType:
t2 := subtrahend.(Timestamp).Time
val, err := subtractTimeChecked(t.Time, t2)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return durationOf(val)
}
@@ -180,7 +175,7 @@ func (t Timestamp) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (t Timestamp) Value() interface{} {
func (t Timestamp) Value() any {
return t.Time
}
@@ -288,7 +283,7 @@ func timeZone(tz ref.Val, visitor timestampVisitor) timestampVisitor {
if ind == -1 {
loc, err := time.LoadLocation(val)
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return visitor(t.In(loc))
}
@@ -297,11 +292,11 @@ func timeZone(tz ref.Val, visitor timestampVisitor) timestampVisitor {
// in the format ^(+|-)(0[0-9]|1[0-4]):[0-5][0-9]$. The numerical input is parsed in terms of hours and minutes.
hr, err := strconv.Atoi(string(val[0:ind]))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
min, err := strconv.Atoi(string(val[ind+1:]))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
var offset int
if string(val[0]) == "-" {

View File

@@ -20,6 +20,7 @@ go_library(
"receiver.go",
"sizer.go",
"traits.go",
"zeroer.go",
],
importpath = "github.com/google/cel-go/common/types/traits",
deps = [

View File

@@ -1,4 +1,4 @@
// Copyright 2020 Google LLC
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,24 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
package traits
import "math"
// TODO: remove Coster.
// Coster calculates the heuristic cost incurred during evaluation.
// Deprecated: Please migrate cel.EstimateCost, it supports length estimates for input data and cost estimates for
// extension functions.
type Coster interface {
Cost() (min, max int64)
}
// estimateCost returns the heuristic cost interval for the program.
func estimateCost(i interface{}) (min, max int64) {
c, ok := i.(Coster)
if !ok {
return 0, math.MaxInt64
}
return c.Cost()
// Zeroer interface for testing whether a CEL value is a zero value for its type.
type Zeroer interface {
// IsZeroValue indicates whether the object is the zero value for the type.
IsZeroValue() bool
}

View File

@@ -1,102 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"fmt"
"reflect"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
var (
// TypeType is the type of a TypeValue.
TypeType = NewTypeValue("type")
)
// TypeValue is an instance of a Value that describes a value's type.
type TypeValue struct {
name string
traitMask int
}
// NewTypeValue returns *TypeValue which is both a ref.Type and ref.Val.
func NewTypeValue(name string, traits ...int) *TypeValue {
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &TypeValue{
name: name,
traitMask: traitMask}
}
// NewObjectTypeValue returns a *TypeValue based on the input name, which is
// annotated with the traits relevant to all objects.
func NewObjectTypeValue(name string) *TypeValue {
return NewTypeValue(name,
traits.FieldTesterType,
traits.IndexerType)
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *TypeValue) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
// TODO: replace the internal type representation with a proto-value.
return nil, fmt.Errorf("type conversion not supported for 'type'")
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *TypeValue) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case TypeType:
return TypeType
case StringType:
return String(t.TypeName())
}
return NewErr("type conversion error from '%s' to '%s'", TypeType, typeVal)
}
// Equal implements ref.Val.Equal.
func (t *TypeValue) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait indicates whether the type supports the given trait.
// Trait codes are defined in the traits package, e.g. see traits.AdderType.
func (t *TypeValue) HasTrait(trait int) bool {
return trait&t.traitMask == trait
}
// String implements fmt.Stringer.
func (t *TypeValue) String() string {
return t.name
}
// Type implements ref.Val.Type.
func (t *TypeValue) Type() ref.Type {
return TypeType
}
// TypeName gives the type's name as a string.
func (t *TypeValue) TypeName() string {
return t.name
}
// Value implements ref.Val.Value.
func (t *TypeValue) Value() interface{} {
return t.name
}

806
vendor/github.com/google/cel-go/common/types/types.go generated vendored Normal file
View File

@@ -0,0 +1,806 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"fmt"
"reflect"
"strings"
chkdecls "github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Kind indicates a CEL type's kind which is used to differentiate quickly between simple
// and complex types.
type Kind uint
const (
// UnspecifiedKind is returned when the type is nil or its kind is not specified.
UnspecifiedKind Kind = iota
// DynKind represents a dynamic type. This kind only exists at type-check time.
DynKind
// AnyKind represents a google.protobuf.Any type. This kind only exists at type-check time.
// Prefer DynKind to AnyKind as AnyKind has a specific meaning which is based on protobuf
// well-known types.
AnyKind
// BoolKind represents a boolean type.
BoolKind
// BytesKind represents a bytes type.
BytesKind
// DoubleKind represents a double type.
DoubleKind
// DurationKind represents a CEL duration type.
DurationKind
// ErrorKind represents a CEL error type.
ErrorKind
// IntKind represents an integer type.
IntKind
// ListKind represents a list type.
ListKind
// MapKind represents a map type.
MapKind
// NullTypeKind represents a null type.
NullTypeKind
// OpaqueKind represents an abstract type which has no accessible fields.
OpaqueKind
// StringKind represents a string type.
StringKind
// StructKind represents a structured object with typed fields.
StructKind
// TimestampKind represents a a CEL time type.
TimestampKind
// TypeKind represents the CEL type.
TypeKind
// TypeParamKind represents a parameterized type whose type name will be resolved at type-check time, if possible.
TypeParamKind
// UintKind represents a uint type.
UintKind
// UnknownKind represents an unknown value type.
UnknownKind
)
var (
// AnyType represents the google.protobuf.Any type.
AnyType = &Type{
kind: AnyKind,
runtimeTypeName: "google.protobuf.Any",
traitMask: traits.FieldTesterType |
traits.IndexerType,
}
// BoolType represents the bool type.
BoolType = &Type{
kind: BoolKind,
runtimeTypeName: "bool",
traitMask: traits.ComparerType |
traits.NegatorType,
}
// BytesType represents the bytes type.
BytesType = &Type{
kind: BytesKind,
runtimeTypeName: "bytes",
traitMask: traits.AdderType |
traits.ComparerType |
traits.SizerType,
}
// DoubleType represents the double type.
DoubleType = &Type{
kind: DoubleKind,
runtimeTypeName: "double",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.MultiplierType |
traits.NegatorType |
traits.SubtractorType,
}
// DurationType represents the CEL duration type.
DurationType = &Type{
kind: DurationKind,
runtimeTypeName: "google.protobuf.Duration",
traitMask: traits.AdderType |
traits.ComparerType |
traits.NegatorType |
traits.ReceiverType |
traits.SubtractorType,
}
// DynType represents a dynamic CEL type whose type will be determined at runtime from context.
DynType = &Type{
kind: DynKind,
runtimeTypeName: "dyn",
}
// ErrorType represents a CEL error value.
ErrorType = &Type{
kind: ErrorKind,
runtimeTypeName: "error",
}
// IntType represents the int type.
IntType = &Type{
kind: IntKind,
runtimeTypeName: "int",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.ModderType |
traits.MultiplierType |
traits.NegatorType |
traits.SubtractorType,
}
// ListType represents the runtime list type.
ListType = NewListType(nil)
// MapType represents the runtime map type.
MapType = NewMapType(nil, nil)
// NullType represents the type of a null value.
NullType = &Type{
kind: NullTypeKind,
runtimeTypeName: "null_type",
}
// StringType represents the string type.
StringType = &Type{
kind: StringKind,
runtimeTypeName: "string",
traitMask: traits.AdderType |
traits.ComparerType |
traits.MatcherType |
traits.ReceiverType |
traits.SizerType,
}
// TimestampType represents the time type.
TimestampType = &Type{
kind: TimestampKind,
runtimeTypeName: "google.protobuf.Timestamp",
traitMask: traits.AdderType |
traits.ComparerType |
traits.ReceiverType |
traits.SubtractorType,
}
// TypeType represents a CEL type
TypeType = &Type{
kind: TypeKind,
runtimeTypeName: "type",
}
// UintType represents a uint type.
UintType = &Type{
kind: UintKind,
runtimeTypeName: "uint",
traitMask: traits.AdderType |
traits.ComparerType |
traits.DividerType |
traits.ModderType |
traits.MultiplierType |
traits.SubtractorType,
}
// UnknownType represents an unknown value type.
UnknownType = &Type{
kind: UnknownKind,
runtimeTypeName: "unknown",
}
)
var _ ref.Type = &Type{}
var _ ref.Val = &Type{}
// Type holds a reference to a runtime type with an optional type-checked set of type parameters.
type Type struct {
// kind indicates general category of the type.
kind Kind
// parameters holds the optional type-checked set of type Parameters that are used during static analysis.
parameters []*Type
// runtimeTypeName indicates the runtime type name of the type.
runtimeTypeName string
// isAssignableType function determines whether one type is assignable to this type.
// A nil value for the isAssignableType function falls back to equality of kind, runtimeType, and parameters.
isAssignableType func(other *Type) bool
// isAssignableRuntimeType function determines whether the runtime type (with erasure) is assignable to this type.
// A nil value for the isAssignableRuntimeType function falls back to the equality of the type or type name.
isAssignableRuntimeType func(other ref.Val) bool
// traitMask is a mask of flags which indicate the capabilities of the type.
traitMask int
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *Type) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion not supported for 'type'")
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *Type) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case TypeType:
return TypeType
case StringType:
return String(t.TypeName())
}
return NewErr("type conversion error from '%s' to '%s'", TypeType, typeVal)
}
// Equal indicates whether two types have the same runtime type name.
//
// The name Equal is a bit of a misnomer, but for historical reasons, this is the
// runtime behavior. For a more accurate definition see IsType().
func (t *Type) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait implements the ref.Type interface method.
func (t *Type) HasTrait(trait int) bool {
return trait&t.traitMask == trait
}
// IsExactType indicates whether the two types are exactly the same. This check also verifies type parameter type names.
func (t *Type) IsExactType(other *Type) bool {
return t.isTypeInternal(other, true)
}
// IsEquivalentType indicates whether two types are equivalent. This check ignores type parameter type names.
func (t *Type) IsEquivalentType(other *Type) bool {
return t.isTypeInternal(other, false)
}
// Kind indicates general category of the type.
func (t *Type) Kind() Kind {
if t == nil {
return UnspecifiedKind
}
return t.kind
}
// isTypeInternal checks whether the two types are equivalent or exactly the same based on the checkTypeParamName flag.
func (t *Type) isTypeInternal(other *Type, checkTypeParamName bool) bool {
if t == nil {
return false
}
if t == other {
return true
}
if t.Kind() != other.Kind() || len(t.Parameters()) != len(other.Parameters()) {
return false
}
if (checkTypeParamName || t.Kind() != TypeParamKind) && t.TypeName() != other.TypeName() {
return false
}
for i, p := range t.Parameters() {
if !p.isTypeInternal(other.Parameters()[i], checkTypeParamName) {
return false
}
}
return true
}
// IsAssignableType determines whether the current type is type-check assignable from the input fromType.
func (t *Type) IsAssignableType(fromType *Type) bool {
if t == nil {
return false
}
if t.isAssignableType != nil {
return t.isAssignableType(fromType)
}
return t.defaultIsAssignableType(fromType)
}
// IsAssignableRuntimeType determines whether the current type is runtime assignable from the input runtimeType.
//
// At runtime, parameterized types are erased and so a function which type-checks to support a map(string, string)
// will have a runtime assignable type of a map.
func (t *Type) IsAssignableRuntimeType(val ref.Val) bool {
if t == nil {
return false
}
if t.isAssignableRuntimeType != nil {
return t.isAssignableRuntimeType(val)
}
return t.defaultIsAssignableRuntimeType(val)
}
// Parameters returns the list of type parameters if set.
//
// For ListKind, Parameters()[0] represents the list element type
// For MapKind, Parameters()[0] represents the map key type, and Parameters()[1] represents the map
// value type.
func (t *Type) Parameters() []*Type {
if t == nil {
return emptyParams
}
return t.parameters
}
// DeclaredTypeName indicates the fully qualified and parameterized type-check type name.
func (t *Type) DeclaredTypeName() string {
// if the type itself is neither null, nor dyn, but is assignable to null, then it's a wrapper type.
if t.Kind() != NullTypeKind && !t.isDyn() && t.IsAssignableType(NullType) {
return fmt.Sprintf("wrapper(%s)", t.TypeName())
}
return t.TypeName()
}
// Type implements the ref.Val interface method.
func (t *Type) Type() ref.Type {
return TypeType
}
// Value implements the ref.Val interface method.
func (t *Type) Value() any {
return t.TypeName()
}
// TypeName returns the type-erased fully qualified runtime type name.
//
// TypeName implements the ref.Type interface method.
func (t *Type) TypeName() string {
if t == nil {
return ""
}
return t.runtimeTypeName
}
// String returns a human-readable definition of the type name.
func (t *Type) String() string {
if len(t.Parameters()) == 0 {
return t.DeclaredTypeName()
}
params := make([]string, len(t.Parameters()))
for i, p := range t.Parameters() {
params[i] = p.String()
}
return fmt.Sprintf("%s(%s)", t.DeclaredTypeName(), strings.Join(params, ", "))
}
// isDyn indicates whether the type is dynamic in any way.
func (t *Type) isDyn() bool {
k := t.Kind()
return k == DynKind || k == AnyKind || k == TypeParamKind
}
// defaultIsAssignableType provides the standard definition of what it means for one type to be assignable to another
// where any of the following may return a true result:
// - The from types are the same instance
// - The target type is dynamic
// - The fromType has the same kind and type name as the target type, and all parameters of the target type
//
// are IsAssignableType() from the parameters of the fromType.
func (t *Type) defaultIsAssignableType(fromType *Type) bool {
if t == fromType || t.isDyn() {
return true
}
if t.Kind() != fromType.Kind() ||
t.TypeName() != fromType.TypeName() ||
len(t.Parameters()) != len(fromType.Parameters()) {
return false
}
for i, tp := range t.Parameters() {
fp := fromType.Parameters()[i]
if !tp.IsAssignableType(fp) {
return false
}
}
return true
}
// defaultIsAssignableRuntimeType inspects the type and in the case of list and map elements, the key and element types
// to determine whether a ref.Val is assignable to the declared type for a function signature.
func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool {
valType := val.Type()
// If the current type and value type don't agree, then return
if !(t.isDyn() || t.TypeName() == valType.TypeName()) {
return false
}
switch t.Kind() {
case ListKind:
elemType := t.Parameters()[0]
l := val.(traits.Lister)
if l.Size() == IntZero {
return true
}
it := l.Iterator()
elemVal := it.Next()
return elemType.IsAssignableRuntimeType(elemVal)
case MapKind:
keyType := t.Parameters()[0]
elemType := t.Parameters()[1]
m := val.(traits.Mapper)
if m.Size() == IntZero {
return true
}
it := m.Iterator()
keyVal := it.Next()
elemVal := m.Get(keyVal)
return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal)
}
return true
}
// NewListType creates an instances of a list type value with the provided element type.
func NewListType(elemType *Type) *Type {
return &Type{
kind: ListKind,
parameters: []*Type{elemType},
runtimeTypeName: "list",
traitMask: traits.AdderType |
traits.ContainerType |
traits.IndexerType |
traits.IterableType |
traits.SizerType,
}
}
// NewMapType creates an instance of a map type value with the provided key and value types.
func NewMapType(keyType, valueType *Type) *Type {
return &Type{
kind: MapKind,
parameters: []*Type{keyType, valueType},
runtimeTypeName: "map",
traitMask: traits.ContainerType |
traits.IndexerType |
traits.IterableType |
traits.SizerType,
}
}
// NewNullableType creates an instance of a nullable type with the provided wrapped type.
//
// Note: only primitive types are supported as wrapped types.
func NewNullableType(wrapped *Type) *Type {
return &Type{
kind: wrapped.Kind(),
parameters: wrapped.Parameters(),
runtimeTypeName: wrapped.TypeName(),
traitMask: wrapped.traitMask,
isAssignableType: func(other *Type) bool {
return NullType.IsAssignableType(other) || wrapped.IsAssignableType(other)
},
isAssignableRuntimeType: func(other ref.Val) bool {
return NullType.IsAssignableRuntimeType(other) || wrapped.IsAssignableRuntimeType(other)
},
}
}
// NewOptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
func NewOptionalType(param *Type) *Type {
return NewOpaqueType("optional", param)
}
// NewOpaqueType creates an abstract parameterized type with a given name.
func NewOpaqueType(name string, params ...*Type) *Type {
return &Type{
kind: OpaqueKind,
parameters: params,
runtimeTypeName: name,
}
}
// NewObjectType creates a type reference to an externally defined type, e.g. a protobuf message type.
//
// An object type is assumed to support field presence testing and field indexing. Additionally, the
// type may also indicate additional traits through the use of the optional traits vararg argument.
func NewObjectType(typeName string, traits ...int) *Type {
// Function sanitizes object types on the fly
if wkt, found := checkedWellKnowns[typeName]; found {
return wkt
}
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &Type{
kind: StructKind,
parameters: emptyParams,
runtimeTypeName: typeName,
traitMask: structTypeTraitMask | traitMask,
}
}
// NewObjectTypeValue creates a type reference to an externally defined type.
//
// Deprecated: use cel.ObjectType(typeName)
func NewObjectTypeValue(typeName string) *Type {
return NewObjectType(typeName)
}
// NewTypeValue creates an opaque type which has a set of optional type traits as defined in
// the common/types/traits package.
//
// Deprecated: use cel.ObjectType(typeName, traits)
func NewTypeValue(typeName string, traits ...int) *Type {
traitMask := 0
for _, trait := range traits {
traitMask |= trait
}
return &Type{
kind: StructKind,
parameters: emptyParams,
runtimeTypeName: typeName,
traitMask: traitMask,
}
}
// NewTypeParamType creates a parameterized type instance.
func NewTypeParamType(paramName string) *Type {
return &Type{
kind: TypeParamKind,
runtimeTypeName: paramName,
}
}
// NewTypeTypeWithParam creates a type with a type parameter.
// Used for type-checking purposes, but equivalent to TypeType otherwise.
func NewTypeTypeWithParam(param *Type) *Type {
return &Type{
kind: TypeKind,
runtimeTypeName: "type",
parameters: []*Type{param},
}
}
// TypeToExprType converts a CEL-native type representation to a protobuf CEL Type representation.
func TypeToExprType(t *Type) (*exprpb.Type, error) {
switch t.Kind() {
case AnyKind:
return chkdecls.Any, nil
case BoolKind:
return maybeWrapper(t, chkdecls.Bool), nil
case BytesKind:
return maybeWrapper(t, chkdecls.Bytes), nil
case DoubleKind:
return maybeWrapper(t, chkdecls.Double), nil
case DurationKind:
return chkdecls.Duration, nil
case DynKind:
return chkdecls.Dyn, nil
case ErrorKind:
return chkdecls.Error, nil
case IntKind:
return maybeWrapper(t, chkdecls.Int), nil
case ListKind:
if len(t.Parameters()) != 1 {
return nil, fmt.Errorf("invalid list, got %d parameters, wanted one", len(t.Parameters()))
}
et, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
return chkdecls.NewListType(et), nil
case MapKind:
if len(t.Parameters()) != 2 {
return nil, fmt.Errorf("invalid map, got %d parameters, wanted two", len(t.Parameters()))
}
kt, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
vt, err := TypeToExprType(t.Parameters()[1])
if err != nil {
return nil, err
}
return chkdecls.NewMapType(kt, vt), nil
case NullTypeKind:
return chkdecls.Null, nil
case OpaqueKind:
params := make([]*exprpb.Type, len(t.Parameters()))
for i, p := range t.Parameters() {
pt, err := TypeToExprType(p)
if err != nil {
return nil, err
}
params[i] = pt
}
return chkdecls.NewAbstractType(t.TypeName(), params...), nil
case StringKind:
return maybeWrapper(t, chkdecls.String), nil
case StructKind:
return chkdecls.NewObjectType(t.TypeName()), nil
case TimestampKind:
return chkdecls.Timestamp, nil
case TypeParamKind:
return chkdecls.NewTypeParamType(t.TypeName()), nil
case TypeKind:
if len(t.Parameters()) == 1 {
p, err := TypeToExprType(t.Parameters()[0])
if err != nil {
return nil, err
}
return chkdecls.NewTypeType(p), nil
}
return chkdecls.NewTypeType(nil), nil
case UintKind:
return maybeWrapper(t, chkdecls.Uint), nil
}
return nil, fmt.Errorf("missing type conversion to proto: %v", t)
}
// ExprTypeToType converts a protobuf CEL type representation to a CEL-native type representation.
func ExprTypeToType(t *exprpb.Type) (*Type, error) {
switch t.GetTypeKind().(type) {
case *exprpb.Type_Dyn:
return DynType, nil
case *exprpb.Type_AbstractType_:
paramTypes := make([]*Type, len(t.GetAbstractType().GetParameterTypes()))
for i, p := range t.GetAbstractType().GetParameterTypes() {
pt, err := ExprTypeToType(p)
if err != nil {
return nil, err
}
paramTypes[i] = pt
}
return NewOpaqueType(t.GetAbstractType().GetName(), paramTypes...), nil
case *exprpb.Type_ListType_:
et, err := ExprTypeToType(t.GetListType().GetElemType())
if err != nil {
return nil, err
}
return NewListType(et), nil
case *exprpb.Type_MapType_:
kt, err := ExprTypeToType(t.GetMapType().GetKeyType())
if err != nil {
return nil, err
}
vt, err := ExprTypeToType(t.GetMapType().GetValueType())
if err != nil {
return nil, err
}
return NewMapType(kt, vt), nil
case *exprpb.Type_MessageType:
return NewObjectType(t.GetMessageType()), nil
case *exprpb.Type_Null:
return NullType, nil
case *exprpb.Type_Primitive:
switch t.GetPrimitive() {
case exprpb.Type_BOOL:
return BoolType, nil
case exprpb.Type_BYTES:
return BytesType, nil
case exprpb.Type_DOUBLE:
return DoubleType, nil
case exprpb.Type_INT64:
return IntType, nil
case exprpb.Type_STRING:
return StringType, nil
case exprpb.Type_UINT64:
return UintType, nil
default:
return nil, fmt.Errorf("unsupported primitive type: %v", t)
}
case *exprpb.Type_TypeParam:
return NewTypeParamType(t.GetTypeParam()), nil
case *exprpb.Type_Type:
if t.GetType().GetTypeKind() != nil {
p, err := ExprTypeToType(t.GetType())
if err != nil {
return nil, err
}
return NewTypeTypeWithParam(p), nil
}
return TypeType, nil
case *exprpb.Type_WellKnown:
switch t.GetWellKnown() {
case exprpb.Type_ANY:
return AnyType, nil
case exprpb.Type_DURATION:
return DurationType, nil
case exprpb.Type_TIMESTAMP:
return TimestampType, nil
default:
return nil, fmt.Errorf("unsupported well-known type: %v", t)
}
case *exprpb.Type_Wrapper:
t, err := ExprTypeToType(&exprpb.Type{TypeKind: &exprpb.Type_Primitive{Primitive: t.GetWrapper()}})
if err != nil {
return nil, err
}
return NewNullableType(t), nil
case *exprpb.Type_Error:
return ErrorType, nil
default:
return nil, fmt.Errorf("unsupported type: %v", t)
}
}
func maybeWrapper(t *Type, pbType *exprpb.Type) *exprpb.Type {
if t.IsAssignableType(NullType) {
return chkdecls.NewWrapperType(pbType)
}
return pbType
}
func maybeForeignType(t ref.Type) *Type {
if celType, ok := t.(*Type); ok {
return celType
}
// Inspect the incoming type to determine its traits. The assumption will be that the incoming
// type does not have any field values; however, if the trait mask indicates that field testing
// and indexing are supported, the foreign type is marked as a struct.
traitMask := 0
for _, trait := range allTraits {
if t.HasTrait(trait) {
traitMask |= trait
}
}
// Treat the value like a struct. If it has no fields, this is harmless to denote the type
// as such since it basically becomes an opaque type by convention.
return NewObjectType(t.TypeName(), traitMask)
}
var (
checkedWellKnowns = map[string]*Type{
// Wrapper types.
"google.protobuf.BoolValue": NewNullableType(BoolType),
"google.protobuf.BytesValue": NewNullableType(BytesType),
"google.protobuf.DoubleValue": NewNullableType(DoubleType),
"google.protobuf.FloatValue": NewNullableType(DoubleType),
"google.protobuf.Int64Value": NewNullableType(IntType),
"google.protobuf.Int32Value": NewNullableType(IntType),
"google.protobuf.UInt64Value": NewNullableType(UintType),
"google.protobuf.UInt32Value": NewNullableType(UintType),
"google.protobuf.StringValue": NewNullableType(StringType),
// Well-known types.
"google.protobuf.Any": AnyType,
"google.protobuf.Duration": DurationType,
"google.protobuf.Timestamp": TimestampType,
// Json types.
"google.protobuf.ListValue": NewListType(DynType),
"google.protobuf.NullValue": NullType,
"google.protobuf.Struct": NewMapType(StringType, DynType),
"google.protobuf.Value": DynType,
}
emptyParams = []*Type{}
allTraits = []int{
traits.AdderType,
traits.ComparerType,
traits.ContainerType,
traits.DividerType,
traits.FieldTesterType,
traits.IndexerType,
traits.IterableType,
traits.IteratorType,
traits.MatcherType,
traits.ModderType,
traits.MultiplierType,
traits.NegatorType,
traits.ReceiverType,
traits.SizerType,
traits.SubtractorType,
}
structTypeTraitMask = traits.FieldTesterType | traits.IndexerType
)

View File

@@ -21,7 +21,6 @@ import (
"strconv"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
anypb "google.golang.org/protobuf/types/known/anypb"
structpb "google.golang.org/protobuf/types/known/structpb"
@@ -32,15 +31,6 @@ import (
type Uint uint64
var (
// UintType singleton.
UintType = NewTypeValue("uint",
traits.AdderType,
traits.ComparerType,
traits.DividerType,
traits.ModderType,
traits.MultiplierType,
traits.SubtractorType)
uint32WrapperType = reflect.TypeOf(&wrapperspb.UInt32Value{})
uint64WrapperType = reflect.TypeOf(&wrapperspb.UInt64Value{})
@@ -59,7 +49,7 @@ func (i Uint) Add(other ref.Val) ref.Val {
}
val, err := addUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@@ -82,7 +72,7 @@ func (i Uint) Compare(other ref.Val) ref.Val {
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (i Uint) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.Uint, reflect.Uint32:
v, err := uint64ToUint32Checked(uint64(i))
@@ -149,7 +139,7 @@ func (i Uint) ConvertToType(typeVal ref.Type) ref.Val {
case IntType:
v, err := uint64ToInt64Checked(uint64(i))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Int(v)
case UintType:
@@ -172,7 +162,7 @@ func (i Uint) Divide(other ref.Val) ref.Val {
}
div, err := divideUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(div)
}
@@ -194,6 +184,11 @@ func (i Uint) Equal(other ref.Val) ref.Val {
}
}
// IsZeroValue returns true if the uint is zero.
func (i Uint) IsZeroValue() bool {
return i == 0
}
// Modulo implements traits.Modder.Modulo.
func (i Uint) Modulo(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
@@ -202,7 +197,7 @@ func (i Uint) Modulo(other ref.Val) ref.Val {
}
mod, err := moduloUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(mod)
}
@@ -215,7 +210,7 @@ func (i Uint) Multiply(other ref.Val) ref.Val {
}
val, err := multiplyUint64Checked(uint64(i), uint64(otherUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@@ -228,7 +223,7 @@ func (i Uint) Subtract(subtrahend ref.Val) ref.Val {
}
val, err := subtractUint64Checked(uint64(i), uint64(subtraUint))
if err != nil {
return wrapErr(err)
return WrapErr(err)
}
return Uint(val)
}
@@ -239,7 +234,7 @@ func (i Uint) Type() ref.Type {
}
// Value implements ref.Val.Value.
func (i Uint) Value() interface{} {
func (i Uint) Value() any {
return uint64(i)
}

View File

@@ -15,52 +15,312 @@
package types
import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"github.com/google/cel-go/common/types/ref"
)
// Unknown type implementation which collects expression ids which caused the
// current value to become unknown.
type Unknown []int64
var (
// UnknownType singleton.
UnknownType = NewTypeValue("unknown")
unspecifiedAttribute = &AttributeTrail{qualifierPath: []any{}}
)
// NewAttributeTrail creates a new simple attribute from a variable name.
func NewAttributeTrail(variable string) *AttributeTrail {
if variable == "" {
return unspecifiedAttribute
}
return &AttributeTrail{variable: variable}
}
// AttributeTrail specifies a variable with an optional qualifier path. An attribute value is expected to
// correspond to an AbsoluteAttribute, meaning a field selection which starts with a top-level variable.
//
// The qualifer path elements adhere to the AttributeQualifier type constraint.
type AttributeTrail struct {
variable string
qualifierPath []any
}
// Equal returns whether two attribute values have the same variable name and qualifier paths.
func (a *AttributeTrail) Equal(other *AttributeTrail) bool {
if a.Variable() != other.Variable() || len(a.QualifierPath()) != len(other.QualifierPath()) {
return false
}
for i, q := range a.QualifierPath() {
qual := other.QualifierPath()[i]
if !qualifiersEqual(q, qual) {
return false
}
}
return true
}
func qualifiersEqual(a, b any) bool {
if a == b {
return true
}
switch numA := a.(type) {
case int64:
numB, ok := b.(uint64)
if !ok {
return false
}
return intUintEqual(numA, numB)
case uint64:
numB, ok := b.(int64)
if !ok {
return false
}
return intUintEqual(numB, numA)
default:
return false
}
}
func intUintEqual(i int64, u uint64) bool {
if i < 0 || u > math.MaxInt64 {
return false
}
return i == int64(u)
}
// Variable returns the variable name associated with the attribute.
func (a *AttributeTrail) Variable() string {
return a.variable
}
// QualifierPath returns the optional set of qualifying fields or indices applied to the variable.
func (a *AttributeTrail) QualifierPath() []any {
return a.qualifierPath
}
// String returns the string representation of the Attribute.
func (a *AttributeTrail) String() string {
if a.variable == "" {
return "<unspecified>"
}
var str strings.Builder
str.WriteString(a.variable)
for _, q := range a.qualifierPath {
switch q := q.(type) {
case bool, int64:
str.WriteString(fmt.Sprintf("[%v]", q))
case uint64:
str.WriteString(fmt.Sprintf("[%vu]", q))
case string:
if isIdentifierCharacter(q) {
str.WriteString(fmt.Sprintf(".%v", q))
} else {
str.WriteString(fmt.Sprintf("[%q]", q))
}
}
}
return str.String()
}
func isIdentifierCharacter(str string) bool {
for _, c := range str {
if unicode.IsLetter(c) || unicode.IsDigit(c) || string(c) == "_" {
continue
}
return false
}
return true
}
// AttributeQualifier constrains the possible types which may be used to qualify an attribute.
type AttributeQualifier interface {
bool | int64 | uint64 | string
}
// QualifyAttribute qualifies an attribute using a valid AttributeQualifier type.
func QualifyAttribute[T AttributeQualifier](attr *AttributeTrail, qualifier T) *AttributeTrail {
attr.qualifierPath = append(attr.qualifierPath, qualifier)
return attr
}
// Unknown type which collects expression ids which caused the current value to become unknown.
type Unknown struct {
attributeTrails map[int64][]*AttributeTrail
}
// NewUnknown creates a new unknown at a given expression id for an attribute.
//
// If the attribute is nil, the attribute value will be the `unspecifiedAttribute`.
func NewUnknown(id int64, attr *AttributeTrail) *Unknown {
if attr == nil {
attr = unspecifiedAttribute
}
return &Unknown{
attributeTrails: map[int64][]*AttributeTrail{id: {attr}},
}
}
// IDs returns the set of unknown expression ids contained by this value.
//
// Numeric identifiers are guaranteed to be in sorted order.
func (u *Unknown) IDs() []int64 {
ids := make(int64Slice, len(u.attributeTrails))
i := 0
for id := range u.attributeTrails {
ids[i] = id
i++
}
ids.Sort()
return ids
}
// GetAttributeTrails returns the attribute trails, if present, missing for a given expression id.
func (u *Unknown) GetAttributeTrails(id int64) ([]*AttributeTrail, bool) {
trails, found := u.attributeTrails[id]
return trails, found
}
// Contains returns true if the input unknown is a subset of the current unknown.
func (u *Unknown) Contains(other *Unknown) bool {
for id, otherTrails := range other.attributeTrails {
trails, found := u.attributeTrails[id]
if !found || len(otherTrails) != len(trails) {
return false
}
for _, ot := range otherTrails {
found := false
for _, t := range trails {
if t.Equal(ot) {
found = true
break
}
}
if !found {
return false
}
}
}
return true
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (u Unknown) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
func (u *Unknown) ConvertToNative(typeDesc reflect.Type) (any, error) {
return u.Value(), nil
}
// ConvertToType is an identity function since unknown values cannot be modified.
func (u Unknown) ConvertToType(typeVal ref.Type) ref.Val {
func (u *Unknown) ConvertToType(typeVal ref.Type) ref.Val {
return u
}
// Equal is an identity function since unknown values cannot be modified.
func (u Unknown) Equal(other ref.Val) ref.Val {
func (u *Unknown) Equal(other ref.Val) ref.Val {
return u
}
// String implements the Stringer interface
func (u *Unknown) String() string {
var str strings.Builder
for id, attrs := range u.attributeTrails {
if str.Len() != 0 {
str.WriteString(", ")
}
if len(attrs) == 1 {
str.WriteString(fmt.Sprintf("%v (%d)", attrs[0], id))
} else {
str.WriteString(fmt.Sprintf("%v (%d)", attrs, id))
}
}
return str.String()
}
// Type implements ref.Val.Type.
func (u Unknown) Type() ref.Type {
func (u *Unknown) Type() ref.Type {
return UnknownType
}
// Value implements ref.Val.Value.
func (u Unknown) Value() interface{} {
return []int64(u)
func (u *Unknown) Value() any {
return u
}
// IsUnknown returns whether the element ref.Type or ref.Val is equal to the
// UnknownType singleton.
// IsUnknown returns whether the element ref.Val is in instance of *types.Unknown
func IsUnknown(val ref.Val) bool {
switch val.(type) {
case Unknown:
case *Unknown:
return true
default:
return false
}
}
// MaybeMergeUnknowns determines whether an input value and another, possibly nil, unknown will produce
// an unknown result.
//
// If the input `val` is another Unknown, then the result will be the merge of the `val` and the input
// `unk`. If the `val` is not unknown, then the result will depend on whether the input `unk` is nil.
// If both values are non-nil and unknown, then the return value will be a merge of both unknowns.
func MaybeMergeUnknowns(val ref.Val, unk *Unknown) (*Unknown, bool) {
src, isUnk := val.(*Unknown)
if !isUnk {
if unk != nil {
return unk, true
}
return unk, false
}
return MergeUnknowns(src, unk), true
}
// MergeUnknowns combines two unknown values into a new unknown value.
func MergeUnknowns(unk1, unk2 *Unknown) *Unknown {
if unk1 == nil {
return unk2
}
if unk2 == nil {
return unk1
}
out := &Unknown{
attributeTrails: make(map[int64][]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)),
}
for id, ats := range unk1.attributeTrails {
out.attributeTrails[id] = ats
}
for id, ats := range unk2.attributeTrails {
existing, found := out.attributeTrails[id]
if !found {
out.attributeTrails[id] = ats
continue
}
for _, at := range ats {
found := false
for _, et := range existing {
if at.Equal(et) {
found = true
break
}
}
if !found {
existing = append(existing, at)
}
}
out.attributeTrails[id] = existing
}
return out
}
// int64Slice is an implementation of the sort.Interface
type int64Slice []int64
// Len returns the number of elements in the slice.
func (x int64Slice) Len() int { return len(x) }
// Less indicates whether the value at index i is less than the value at index j.
func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] }
// Swap swaps the values at indices i and j in place.
func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
// Sort is a convenience method: x.Sort() calls Sort(x).
func (x int64Slice) Sort() { sort.Sort(x) }

View File

@@ -21,7 +21,7 @@ import (
// IsUnknownOrError returns whether the input element ref.Val is an ErrType or UnknownType.
func IsUnknownOrError(val ref.Val) bool {
switch val.(type) {
case Unknown, *Err:
case *Unknown, *Err:
return true
}
return false

View File

@@ -9,14 +9,31 @@ go_library(
srcs = [
"encoders.go",
"guards.go",
"lists.go",
"math.go",
"native.go",
"protos.go",
"sets.go",
"strings.go",
],
importpath = "github.com/google/cel-go/ext",
visibility = ["//visibility:public"],
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
"@org_golang_x_text//language:go_default_library",
"@org_golang_x_text//message:go_default_library",
],
)
@@ -25,6 +42,11 @@ go_test(
size = "small",
srcs = [
"encoders_test.go",
"lists_test.go",
"math_test.go",
"native_test.go",
"protos_test.go",
"sets_test.go",
"strings_test.go",
],
embed = [
@@ -32,5 +54,16 @@ go_test(
],
deps = [
"//cel:go_default_library",
"//checker:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/wrapperspb:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
],
)

View File

@@ -3,6 +3,30 @@
CEL extensions are a related set of constants, functions, macros, or other
features which may not be covered by the core CEL spec.
## Bindings
Returns a cel.EnvOption to configure support for local variable bindings
in expressions.
# Cel.Bind
Binds a simple identifier to an initialization expression which may be used
in a subsequenct result expression. Bindings may also be nested within each
other.
cel.bind(<varName>, <initExpr>, <resultExpr>)
Examples:
cel.bind(a, 'hello',
cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
// Avoid a list allocation within the exists comprehension.
cel.bind(valid_values, [a, b, c],
[d, e, f].exists(elem, elem in valid_values))
Local bindings are not guaranteed to be evaluated before use.
## Encoders
Encoding utilies for marshalling data into standardized representations.
@@ -31,6 +55,173 @@ Example:
base64.encode(b'hello') // return 'aGVsbG8='
## Math
Math helper macros and functions.
Note, all macros use the 'math' namespace; however, at the time of macro
expansion the namespace looks just like any other identifier. If you are
currently using a variable named 'math', the macro will likely work just as
intended; however, there is some chance for collision.
### Math.Greatest
Returns the greatest valued number present in the arguments to the macro.
Greatest is a variable argument count macro which must take at least one
argument. Simple numeric and list literals are supported as valid argument
types; however, other literals will be flagged as errors during macro
expansion. If the argument expression does not resolve to a numeric or
list(numeric) type during type-checking, or during runtime then an error
will be produced. If a list argument is empty, this too will produce an
error.
math.greatest(<arg>, ...) -> <double|int|uint>
Examples:
math.greatest(1) // 1
math.greatest(1u, 2u) // 2u
math.greatest(-42.0, -21.5, -100.0) // -21.5
math.greatest([-42.0, -21.5, -100.0]) // -21.5
math.greatest(numbers) // numbers must be list(numeric)
math.greatest() // parse error
math.greatest('string') // parse error
math.greatest(a, b) // check-time error if a or b is non-numeric
math.greatest(dyn('string')) // runtime error
### Math.Least
Returns the least valued number present in the arguments to the macro.
Least is a variable argument count macro which must take at least one
argument. Simple numeric and list literals are supported as valid argument
types; however, other literals will be flagged as errors during macro
expansion. If the argument expression does not resolve to a numeric or
list(numeric) type during type-checking, or during runtime then an error
will be produced. If a list argument is empty, this too will produce an error.
math.least(<arg>, ...) -> <double|int|uint>
Examples:
math.least(1) // 1
math.least(1u, 2u) // 1u
math.least(-42.0, -21.5, -100.0) // -100.0
math.least([-42.0, -21.5, -100.0]) // -100.0
math.least(numbers) // numbers must be list(numeric)
math.least() // parse error
math.least('string') // parse error
math.least(a, b) // check-time error if a or b is non-numeric
math.least(dyn('string')) // runtime error
## Protos
Protos configure extended macros and functions for proto manipulation.
Note, all macros use the 'proto' namespace; however, at the time of macro
expansion the namespace looks just like any other identifier. If you are
currently using a variable named 'proto', the macro will likely work just as
you intend; however, there is some chance for collision.
### Protos.GetExt
Macro which generates a select expression that retrieves an extension field
from the input proto2 syntax message. If the field is not set, the default
value forthe extension field is returned according to safe-traversal semantics.
proto.getExt(<msg>, <fully.qualified.extension.name>) -> <field-type>
Example:
proto.getExt(msg, google.expr.proto2.test.int32_ext) // returns int value
### Protos.HasExt
Macro which generates a test-only select expression that determines whether
an extension field is set on a proto2 syntax message.
proto.hasExt(<msg>, <fully.qualified.extension.name>) -> <bool>
Example:
proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false
## Lists
Extended functions for list manipulation. As a general note, all indices are
zero-based.
### Slice
Returns a new sub-list using the indexes provided.
<list>.slice(<int>, <int>) -> <list>
Examples:
[1,2,3,4].slice(1, 3) // return [2, 3]
[1,2,3,4].slice(2, 4) // return [3 ,4]
## Sets
Sets provides set relationship tests.
There is no set type within CEL, and while one may be introduced in the
future, there are cases where a `list` type is known to behave like a set.
For such cases, this library provides some basic functionality for
determining set containment, equivalence, and intersection.
### Sets.Contains
Returns whether the first list argument contains all elements in the second
list argument. The list may contain elements of any type and standard CEL
equality is used to determine whether a value exists in both lists. If the
second list is empty, the result will always return true.
sets.contains(list(T), list(T)) -> bool
Examples:
sets.contains([], []) // true
sets.contains([], [1]) // false
sets.contains([1, 2, 3, 4], [2, 3]) // true
sets.contains([1, 2.0, 3u], [1.0, 2u, 3]) // true
### Sets.Equivalent
Returns whether the first and second list are set equivalent. Lists are set
equivalent if for every item in the first list, there is an element in the
second which is equal. The lists may not be of the same size as they do not
guarantee the elements within them are unique, so size does not factor into
the computation.
sets.equivalent(list(T), list(T)) -> bool
Examples:
sets.equivalent([], []) // true
sets.equivalent([1], [1, 1]) // true
sets.equivalent([1], [1u, 1.0]) // true
sets.equivalent([1, 2, 3], [3u, 2.0, 1]) // true
### Sets.Intersects
Returns whether the first list has at least one element whose value is equal
to an element in the second list. If either list is empty, the result will
be false.
sets.intersects(list(T), list(T)) -> bool
Examples:
sets.intersects([1], []) // false
sets.intersects([1], [1, 2]) // true
sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]]) // true
## Strings
Extended functions for string manipulation. As a general note, all indices are
@@ -70,6 +261,23 @@ Examples:
'hello mellow'.indexOf('ello', 2) // returns 7
'hello mellow'.indexOf('ello', 20) // error
### Join
Returns a new string where the elements of string list are concatenated.
The function also accepts an optional separator which is placed between
elements in the resulting string.
<list<string>>.join() -> <string>
<list<string>>.join(<string>) -> <string>
Examples:
['hello', 'mellow'].join() // returns 'hellomellow'
['hello', 'mellow'].join(' ') // returns 'hello mellow'
[].join() // returns ''
[].join('/') // returns ''
### LastIndexOf
Returns the integer index of the last occurrence of the search string. If the
@@ -105,6 +313,20 @@ Examples:
'TacoCat'.lowerAscii() // returns 'tacocat'
'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
### Quote
**Introduced in version 1**
Takes the given string and makes it safe to print (without any formatting due to escape sequences).
If any invalid UTF-8 characters are encountered, they are replaced with \uFFFD.
strings.quote(<string>)
Examples:
strings.quote('single-quote with "double quote"') // returns '"single-quote with \"double quote\""'
strings.quote("two escape sequences \a\n") // returns '"two escape sequences \\a\\n"'
### Replace
Returns a new string based on the target, which replaces the occurrences of a

96
vendor/github.com/google/cel-go/ext/bindings.go generated vendored Normal file
View File

@@ -0,0 +1,96 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Bindings returns a cel.EnvOption to configure support for local variable
// bindings in expressions.
//
// # Cel.Bind
//
// Binds a simple identifier to an initialization expression which may be used
// in a subsequenct result expression. Bindings may also be nested within each
// other.
//
// cel.bind(<varName>, <initExpr>, <resultExpr>)
//
// Examples:
//
// cel.bind(a, 'hello',
// cel.bind(b, 'world', a + b + b + a)) // "helloworldworldhello"
//
// // Avoid a list allocation within the exists comprehension.
// cel.bind(valid_values, [a, b, c],
// [d, e, f].exists(elem, elem in valid_values))
//
// Local bindings are not guaranteed to be evaluated before use.
func Bindings() cel.EnvOption {
return cel.Lib(celBindings{})
}
const (
celNamespace = "cel"
bindMacro = "bind"
unusedIterVar = "#unused"
)
type celBindings struct{}
func (celBindings) LibraryName() string {
return "cel.lib.ext.cel.bindings"
}
func (celBindings) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// cel.bind(var, <init>, <expr>)
cel.NewReceiverMacro(bindMacro, 3, celBind),
),
}
}
func (celBindings) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(celNamespace, target) {
return nil, nil
}
varIdent := args[0]
varName := ""
switch varIdent.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
varName = varIdent.GetIdentExpr().GetName()
default:
return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers")
}
varInit := args[1]
resultExpr := args[2]
return meh.Fold(
unusedIterVar,
meh.NewList(),
varName,
varInit,
meh.LiteralBool(false),
meh.Ident(varName),
resultExpr,
), nil
}

View File

@@ -16,7 +16,6 @@ package ext
import (
"encoding/base64"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
@@ -26,34 +25,38 @@ import (
// Encoders returns a cel.EnvOption to configure extended functions for string, byte, and object
// encodings.
//
// Base64.Decode
// # Base64.Decode
//
// Decodes base64-encoded string to bytes.
//
// This function will return an error if the string input is not base64-encoded.
//
// base64.decode(<string>) -> <bytes>
// base64.decode(<string>) -> <bytes>
//
// Examples:
//
// base64.decode('aGVsbG8=') // return b'hello'
// base64.decode('aGVsbG8') // error
// base64.decode('aGVsbG8=') // return b'hello'
// base64.decode('aGVsbG8') // error
//
// Base64.Encode
// # Base64.Encode
//
// Encodes bytes to a base64-encoded string.
//
// base64.encode(<bytes>) -> <string>
// base64.encode(<bytes>) -> <string>
//
// Examples:
//
// base64.encode(b'hello') // return b'aGVsbG8='
// base64.encode(b'hello') // return b'aGVsbG8='
func Encoders() cel.EnvOption {
return cel.Lib(encoderLib{})
}
type encoderLib struct{}
func (encoderLib) LibraryName() string {
return "cel.lib.ext.encoders"
}
func (encoderLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Function("base64.decode",
@@ -82,7 +85,3 @@ func base64DecodeString(str string) ([]byte, error) {
func base64EncodeBytes(bytes []byte) (string, error) {
return base64.StdEncoding.EncodeToString(bytes), nil
}
var (
bytesListType = reflect.TypeOf([]byte{})
)

View File

@@ -17,6 +17,8 @@ package ext
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// function invocation guards for common call signatures within extension functions.
@@ -48,3 +50,15 @@ func listStringOrError(strs []string, err error) ref.Val {
}
return types.DefaultTypeAdapter.NativeToValue(strs)
}
func macroTargetMatchesNamespace(ns string, target *exprpb.Expr) bool {
switch target.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
if target.GetIdentExpr().GetName() != ns {
return false
}
return true
default:
return false
}
}

94
vendor/github.com/google/cel-go/ext/lists.go generated vendored Normal file
View File

@@ -0,0 +1,94 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// Lists returns a cel.EnvOption to configure extended functions for list manipulation.
// As a general note, all indices are zero-based.
// # Slice
//
// Returns a new sub-list using the indexes provided.
//
// <list>.slice(<int>, <int>) -> <list>
//
// Examples:
//
// [1,2,3,4].slice(1, 3) // return [2, 3]
// [1,2,3,4].slice(2, 4) // return [3 ,4]
func Lists() cel.EnvOption {
return cel.Lib(listsLib{})
}
type listsLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (listsLib) LibraryName() string {
return "cel.lib.ext.lists"
}
// CompileOptions implements the Library interface method.
func (listsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
return []cel.EnvOption{
cel.Function("slice",
cel.MemberOverload("list_slice",
[]*cel.Type{listType, cel.IntType, cel.IntType}, listType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
list := args[0].(traits.Lister)
start := args[1].(types.Int)
end := args[2].(types.Int)
result, err := slice(list, start, end)
if err != nil {
return types.WrapErr(err)
}
return result
}),
),
),
}
}
// ProgramOptions implements the Library interface method.
func (listsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func slice(list traits.Lister, start, end types.Int) (ref.Val, error) {
listLength := list.Size().(types.Int)
if start < 0 || end < 0 {
return nil, fmt.Errorf("cannot slice(%d, %d), negative indexes not supported", start, end)
}
if start > end {
return nil, fmt.Errorf("cannot slice(%d, %d), start index must be less than or equal to end index", start, end)
}
if listLength < end {
return nil, fmt.Errorf("cannot slice(%d, %d), list is length %d", start, end, listLength)
}
var newList []ref.Val
for i := types.Int(start); i < end; i++ {
val := list.Get(i)
newList = append(newList, val)
}
return types.DefaultTypeAdapter.NativeToValue(newList), nil
}

373
vendor/github.com/google/cel-go/ext/math.go generated vendored Normal file
View File

@@ -0,0 +1,373 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Math returns a cel.EnvOption to configure namespaced math helper macros and
// functions.
//
// Note, all macros use the 'math' namespace; however, at the time of macro
// expansion the namespace looks just like any other identifier. If you are
// currently using a variable named 'math', the macro will likely work just as
// intended; however, there is some chance for collision.
//
// # Math.Greatest
//
// Returns the greatest valued number present in the arguments to the macro.
//
// Greatest is a variable argument count macro which must take at least one
// argument. Simple numeric and list literals are supported as valid argument
// types; however, other literals will be flagged as errors during macro
// expansion. If the argument expression does not resolve to a numeric or
// list(numeric) type during type-checking, or during runtime then an error
// will be produced. If a list argument is empty, this too will produce an
// error.
//
// math.greatest(<arg>, ...) -> <double|int|uint>
//
// Examples:
//
// math.greatest(1) // 1
// math.greatest(1u, 2u) // 2u
// math.greatest(-42.0, -21.5, -100.0) // -21.5
// math.greatest([-42.0, -21.5, -100.0]) // -21.5
// math.greatest(numbers) // numbers must be list(numeric)
//
// math.greatest() // parse error
// math.greatest('string') // parse error
// math.greatest(a, b) // check-time error if a or b is non-numeric
// math.greatest(dyn('string')) // runtime error
//
// # Math.Least
//
// Returns the least valued number present in the arguments to the macro.
//
// Least is a variable argument count macro which must take at least one
// argument. Simple numeric and list literals are supported as valid argument
// types; however, other literals will be flagged as errors during macro
// expansion. If the argument expression does not resolve to a numeric or
// list(numeric) type during type-checking, or during runtime then an error
// will be produced. If a list argument is empty, this too will produce an
// error.
//
// math.least(<arg>, ...) -> <double|int|uint>
//
// Examples:
//
// math.least(1) // 1
// math.least(1u, 2u) // 1u
// math.least(-42.0, -21.5, -100.0) // -100.0
// math.least([-42.0, -21.5, -100.0]) // -100.0
// math.least(numbers) // numbers must be list(numeric)
//
// math.least() // parse error
// math.least('string') // parse error
// math.least(a, b) // check-time error if a or b is non-numeric
// math.least(dyn('string')) // runtime error
func Math() cel.EnvOption {
return cel.Lib(mathLib{})
}
const (
mathNamespace = "math"
leastMacro = "least"
greatestMacro = "greatest"
minFunc = "math.@min"
maxFunc = "math.@max"
)
type mathLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (mathLib) LibraryName() string {
return "cel.lib.ext.math"
}
// CompileOptions implements the Library interface method.
func (mathLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// math.least(num, ...)
cel.NewReceiverVarArgMacro(leastMacro, mathLeast),
// math.greatest(num, ...)
cel.NewReceiverVarArgMacro(greatestMacro, mathGreatest),
),
cel.Function(minFunc,
cel.Overload("math_@min_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_int", []*cel.Type{cel.IntType}, cel.IntType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_uint", []*cel.Type{cel.UintType}, cel.UintType,
cel.UnaryBinding(identity)),
cel.Overload("math_@min_double_double", []*cel.Type{cel.DoubleType, cel.DoubleType}, cel.DoubleType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_int", []*cel.Type{cel.IntType, cel.IntType}, cel.IntType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_uint", []*cel.Type{cel.UintType, cel.UintType}, cel.UintType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_uint", []*cel.Type{cel.IntType, cel.UintType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_int_double", []*cel.Type{cel.IntType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_double_int", []*cel.Type{cel.DoubleType, cel.IntType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_double_uint", []*cel.Type{cel.DoubleType, cel.UintType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_int", []*cel.Type{cel.UintType, cel.IntType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_uint_double", []*cel.Type{cel.UintType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(minPair)),
cel.Overload("math_@min_list_double", []*cel.Type{cel.ListType(cel.DoubleType)}, cel.DoubleType,
cel.UnaryBinding(minList)),
cel.Overload("math_@min_list_int", []*cel.Type{cel.ListType(cel.IntType)}, cel.IntType,
cel.UnaryBinding(minList)),
cel.Overload("math_@min_list_uint", []*cel.Type{cel.ListType(cel.UintType)}, cel.UintType,
cel.UnaryBinding(minList)),
),
cel.Function(maxFunc,
cel.Overload("math_@max_double", []*cel.Type{cel.DoubleType}, cel.DoubleType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_int", []*cel.Type{cel.IntType}, cel.IntType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_uint", []*cel.Type{cel.UintType}, cel.UintType,
cel.UnaryBinding(identity)),
cel.Overload("math_@max_double_double", []*cel.Type{cel.DoubleType, cel.DoubleType}, cel.DoubleType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_int", []*cel.Type{cel.IntType, cel.IntType}, cel.IntType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_uint", []*cel.Type{cel.UintType, cel.UintType}, cel.UintType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_uint", []*cel.Type{cel.IntType, cel.UintType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_int_double", []*cel.Type{cel.IntType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_double_int", []*cel.Type{cel.DoubleType, cel.IntType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_double_uint", []*cel.Type{cel.DoubleType, cel.UintType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_int", []*cel.Type{cel.UintType, cel.IntType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_uint_double", []*cel.Type{cel.UintType, cel.DoubleType}, cel.DynType,
cel.BinaryBinding(maxPair)),
cel.Overload("math_@max_list_double", []*cel.Type{cel.ListType(cel.DoubleType)}, cel.DoubleType,
cel.UnaryBinding(maxList)),
cel.Overload("math_@max_list_int", []*cel.Type{cel.ListType(cel.IntType)}, cel.IntType,
cel.UnaryBinding(maxList)),
cel.Overload("math_@max_list_uint", []*cel.Type{cel.ListType(cel.UintType)}, cel.UintType,
cel.UnaryBinding(maxList)),
),
}
}
// ProgramOptions implements the Library interface method.
func (mathLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(minFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.least()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(minFunc, meh.NewList(args...)), nil
}
}
func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(mathNamespace, target) {
return nil, nil
}
switch len(args) {
case 0:
return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument")
case 1:
if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) {
return meh.GlobalCall(maxFunc, args[0]), nil
}
return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value")
case 2:
err := checkInvalidArgs(meh, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, args...), nil
default:
err := checkInvalidArgs(meh, "math.greatest()", args)
if err != nil {
return nil, err
}
return meh.GlobalCall(maxFunc, meh.NewList(args...)), nil
}
}
func identity(val ref.Val) ref.Val {
return val
}
func minPair(first, second ref.Val) ref.Val {
cmp, ok := first.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(first)
}
out := cmp.Compare(second)
if types.IsUnknownOrError(out) {
return maybeSuffixError(out, "math.@min")
}
if out == types.IntOne {
return second
}
return first
}
func minList(numList ref.Val) ref.Val {
l := numList.(traits.Lister)
size := l.Size().(types.Int)
if size == types.IntZero {
return types.NewErr("math.@min(list) argument must not be empty")
}
min := l.Get(types.IntZero)
for i := types.IntOne; i < size; i++ {
min = minPair(min, l.Get(i))
}
switch min.Type() {
case types.IntType, types.DoubleType, types.UintType, types.UnknownType:
return min
default:
return types.NewErr("no such overload: math.@min")
}
}
func maxPair(first, second ref.Val) ref.Val {
cmp, ok := first.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(first)
}
out := cmp.Compare(second)
if types.IsUnknownOrError(out) {
return maybeSuffixError(out, "math.@max")
}
if out == types.IntNegOne {
return second
}
return first
}
func maxList(numList ref.Val) ref.Val {
l := numList.(traits.Lister)
size := l.Size().(types.Int)
if size == types.IntZero {
return types.NewErr("math.@max(list) argument must not be empty")
}
max := l.Get(types.IntZero)
for i := types.IntOne; i < size; i++ {
max = maxPair(max, l.Get(i))
}
switch max.Type() {
case types.IntType, types.DoubleType, types.UintType, types.UnknownType:
return max
default:
return types.NewErr("no such overload: math.@max")
}
}
func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error {
for _, arg := range args {
err := checkInvalidArgLiteral(funcName, arg)
if err != nil {
return meh.NewError(arg.GetId(), err.Error())
}
}
return nil
}
func checkInvalidArgLiteral(funcName string, arg *exprpb.Expr) error {
if !isValidArgType(arg) {
return fmt.Errorf("%s simple literal arguments must be numeric", funcName)
}
return nil
}
func isValidArgType(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ConstExpr:
c := arg.GetConstExpr()
switch c.GetConstantKind().(type) {
case *exprpb.Constant_DoubleValue, *exprpb.Constant_Int64Value, *exprpb.Constant_Uint64Value:
return true
default:
return false
}
case *exprpb.Expr_ListExpr, *exprpb.Expr_StructExpr:
return false
default:
return true
}
}
func isListLiteralWithValidArgs(arg *exprpb.Expr) bool {
switch arg.GetExprKind().(type) {
case *exprpb.Expr_ListExpr:
list := arg.GetListExpr()
if len(list.GetElements()) == 0 {
return false
}
for _, e := range list.GetElements() {
if !isValidArgType(e) {
return false
}
}
return true
}
return false
}
func maybeSuffixError(val ref.Val, suffix string) ref.Val {
if types.IsError(val) {
msg := val.(*types.Err).String()
if !strings.Contains(msg, suffix) {
return types.NewErr("%s: %s", msg, suffix)
}
}
return val
}

574
vendor/github.com/google/cel-go/ext/native.go generated vendored Normal file
View File

@@ -0,0 +1,574 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"fmt"
"reflect"
"strings"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
structpb "google.golang.org/protobuf/types/known/structpb"
)
var (
nativeObjTraitMask = traits.FieldTesterType | traits.IndexerType
jsonValueType = reflect.TypeOf(&structpb.Value{})
jsonStructType = reflect.TypeOf(&structpb.Struct{})
)
// NativeTypes creates a type provider which uses reflect.Type and reflect.Value instances
// to produce type definitions that can be used within CEL.
//
// All struct types in Go are exposed to CEL via their simple package name and struct type name:
//
// ```go
// package identity
//
// type Account struct {
// ID int
// }
//
// ```
//
// The type `identity.Account` would be exported to CEL using the same qualified name, e.g.
// `identity.Account{ID: 1234}` would create a new `Account` instance with the `ID` field
// populated.
//
// Only exported fields are exposed via NativeTypes, and the type-mapping between Go and CEL
// is as follows:
//
// | Go type | CEL type |
// |-------------------------------------|-----------|
// | bool | bool |
// | []byte | bytes |
// | float32, float64 | double |
// | int, int8, int16, int32, int64 | int |
// | string | string |
// | uint, uint8, uint16, uint32, uint64 | uint |
// | time.Duration | duration |
// | time.Time | timestamp |
// | array, slice | list |
// | map | map |
//
// Please note, if you intend to configure support for proto messages in addition to native
// types, you will need to provide the protobuf types before the golang native types. The
// same advice holds if you are using custom type adapters and type providers. The native type
// provider composes over whichever type adapter and provider is configured in the cel.Env at
// the time that it is invoked.
func NativeTypes(refTypes ...any) cel.EnvOption {
return func(env *cel.Env) (*cel.Env, error) {
tp, err := newNativeTypeProvider(env.CELTypeAdapter(), env.CELTypeProvider(), refTypes...)
if err != nil {
return nil, err
}
env, err = cel.CustomTypeAdapter(tp)(env)
if err != nil {
return nil, err
}
return cel.CustomTypeProvider(tp)(env)
}
}
func newNativeTypeProvider(adapter types.Adapter, provider types.Provider, refTypes ...any) (*nativeTypeProvider, error) {
nativeTypes := make(map[string]*nativeType, len(refTypes))
for _, refType := range refTypes {
switch rt := refType.(type) {
case reflect.Type:
t, err := newNativeType(rt)
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
case reflect.Value:
t, err := newNativeType(rt.Type())
if err != nil {
return nil, err
}
nativeTypes[t.TypeName()] = t
default:
return nil, fmt.Errorf("unsupported native type: %v (%T) must be reflect.Type or reflect.Value", rt, rt)
}
}
return &nativeTypeProvider{
nativeTypes: nativeTypes,
baseAdapter: adapter,
baseProvider: provider,
}, nil
}
type nativeTypeProvider struct {
nativeTypes map[string]*nativeType
baseAdapter types.Adapter
baseProvider types.Provider
}
// EnumValue proxies to the types.Provider configured at the times the NativeTypes
// option was configured.
func (tp *nativeTypeProvider) EnumValue(enumName string) ref.Val {
return tp.baseProvider.EnumValue(enumName)
}
// FindIdent looks up natives type instances by qualified identifier, and if not found
// proxies to the composed types.Provider.
func (tp *nativeTypeProvider) FindIdent(typeName string) (ref.Val, bool) {
if t, found := tp.nativeTypes[typeName]; found {
return t, true
}
return tp.baseProvider.FindIdent(typeName)
}
// FindStructType looks up the CEL type definition by qualified identifier, and if not found
// proxies to the composed types.Provider.
func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool) {
if _, found := tp.nativeTypes[typeName]; found {
return types.NewTypeTypeWithParam(types.NewObjectType(typeName)), true
}
if celType, found := tp.baseProvider.FindStructType(typeName); found {
return celType, true
}
return tp.baseProvider.FindStructType(typeName)
}
// FindStructFieldType looks up a native type's field definition, and if the type name is not a native
// type then proxies to the composed types.Provider
func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) {
t, found := tp.nativeTypes[typeName]
if !found {
return tp.baseProvider.FindStructFieldType(typeName, fieldName)
}
refField, isDefined := t.hasField(fieldName)
if !found || !isDefined {
return nil, false
}
celType, ok := convertToCelType(refField.Type)
if !ok {
return nil, false
}
return &types.FieldType{
Type: celType,
IsSet: func(obj any) bool {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return !refField.IsZero()
},
GetFrom: func(obj any) (any, error) {
refVal := reflect.Indirect(reflect.ValueOf(obj))
refField := refVal.FieldByName(fieldName)
return getFieldValue(tp, refField), nil
},
}, true
}
// NewValue implements the ref.TypeProvider interface method.
func (tp *nativeTypeProvider) NewValue(typeName string, fields map[string]ref.Val) ref.Val {
t, found := tp.nativeTypes[typeName]
if !found {
return tp.baseProvider.NewValue(typeName, fields)
}
refPtr := reflect.New(t.refType)
refVal := refPtr.Elem()
for fieldName, val := range fields {
refFieldDef, isDefined := t.hasField(fieldName)
if !isDefined {
return types.NewErr("no such field: %s", fieldName)
}
fieldVal, err := val.ConvertToNative(refFieldDef.Type)
if err != nil {
return types.NewErr(err.Error())
}
refField := refVal.FieldByIndex(refFieldDef.Index)
refFieldVal := reflect.ValueOf(fieldVal)
refField.Set(refFieldVal)
}
return tp.NativeToValue(refPtr.Interface())
}
// NewValue adapts native values to CEL values and will proxy to the composed type adapter
// for non-native types.
func (tp *nativeTypeProvider) NativeToValue(val any) ref.Val {
if val == nil {
return types.NullValue
}
if v, ok := val.(ref.Val); ok {
return v
}
rawVal := reflect.ValueOf(val)
refVal := rawVal
if refVal.Kind() == reflect.Ptr {
refVal = reflect.Indirect(refVal)
}
// This isn't quite right if you're also supporting proto,
// but maybe an acceptable limitation.
switch refVal.Kind() {
case reflect.Array, reflect.Slice:
switch val := val.(type) {
case []byte:
return tp.baseAdapter.NativeToValue(val)
default:
return types.NewDynamicList(tp, val)
}
case reflect.Map:
return types.NewDynamicMap(tp, val)
case reflect.Struct:
switch val := val.(type) {
case proto.Message, *pb.Map, protoreflect.List, protoreflect.Message, protoreflect.Value,
time.Time:
return tp.baseAdapter.NativeToValue(val)
default:
return newNativeObject(tp, val, rawVal)
}
default:
return tp.baseAdapter.NativeToValue(val)
}
}
func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
switch refType.Kind() {
case reflect.Bool:
return cel.BoolType, true
case reflect.Float32, reflect.Float64:
return cel.DoubleType, true
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if refType == durationType {
return cel.DurationType, true
}
return cel.IntType, true
case reflect.String:
return cel.StringType, true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return cel.UintType, true
case reflect.Array, reflect.Slice:
refElem := refType.Elem()
if refElem == reflect.TypeOf(byte(0)) {
return cel.BytesType, true
}
elemType, ok := convertToCelType(refElem)
if !ok {
return nil, false
}
return cel.ListType(elemType), true
case reflect.Map:
keyType, ok := convertToCelType(refType.Key())
if !ok {
return nil, false
}
// Ensure the key type is a int, bool, uint, string
elemType, ok := convertToCelType(refType.Elem())
if !ok {
return nil, false
}
return cel.MapType(keyType, elemType), true
case reflect.Struct:
if refType == timestampType {
return cel.TimestampType, true
}
return cel.ObjectType(
fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
), true
case reflect.Pointer:
if refType.Implements(pbMsgInterfaceType) {
pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage)
return cel.ObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
}
return convertToCelType(refType.Elem())
}
return nil, false
}
func newNativeObject(adapter types.Adapter, val any, refValue reflect.Value) ref.Val {
valType, err := newNativeType(refValue.Type())
if err != nil {
return types.NewErr(err.Error())
}
return &nativeObj{
Adapter: adapter,
val: val,
valType: valType,
refValue: refValue,
}
}
type nativeObj struct {
types.Adapter
val any
valType *nativeType
refValue reflect.Value
}
// ConvertToNative implements the ref.Val interface method.
//
// CEL does not have a notion of pointers, so whether a field is a pointer or value
// is handled as part of this conversion step.
func (o *nativeObj) ConvertToNative(typeDesc reflect.Type) (any, error) {
if o.refValue.Type() == typeDesc {
return o.val, nil
}
if o.refValue.Kind() == reflect.Pointer && o.refValue.Type().Elem() == typeDesc {
return o.refValue.Elem().Interface(), nil
}
if typeDesc.Kind() == reflect.Pointer && o.refValue.Type() == typeDesc.Elem() {
ptr := reflect.New(typeDesc.Elem())
ptr.Elem().Set(o.refValue)
return ptr.Interface(), nil
}
switch typeDesc {
case jsonValueType:
jsonStruct, err := o.ConvertToNative(jsonStructType)
if err != nil {
return nil, err
}
return structpb.NewStructValue(jsonStruct.(*structpb.Struct)), nil
case jsonStructType:
refVal := reflect.Indirect(o.refValue)
refType := refVal.Type()
fields := make(map[string]*structpb.Value, refVal.NumField())
for i := 0; i < refVal.NumField(); i++ {
fieldType := refType.Field(i)
fieldValue := refVal.Field(i)
if !fieldValue.IsValid() || fieldValue.IsZero() {
continue
}
fieldCELVal := o.NativeToValue(fieldValue.Interface())
fieldJSONVal, err := fieldCELVal.ConvertToNative(jsonValueType)
if err != nil {
return nil, err
}
fields[fieldType.Name] = fieldJSONVal.(*structpb.Value)
}
return &structpb.Struct{Fields: fields}, nil
}
return nil, fmt.Errorf("type conversion error from '%v' to '%v'", o.Type(), typeDesc)
}
// ConvertToType implements the ref.Val interface method.
func (o *nativeObj) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.TypeType:
return o.valType
default:
if typeVal.TypeName() == o.valType.typeName {
return o
}
}
return types.NewErr("type conversion error from '%s' to '%s'", o.Type(), typeVal)
}
// Equal implements the ref.Val interface method.
//
// Note, that in Golang a pointer to a value is not equal to the value it contains.
// In CEL pointers and values to which they point are equal.
func (o *nativeObj) Equal(other ref.Val) ref.Val {
otherNtv, ok := other.(*nativeObj)
if !ok {
return types.False
}
val := o.val
otherVal := otherNtv.val
refVal := o.refValue
otherRefVal := otherNtv.refValue
if refVal.Kind() != otherRefVal.Kind() {
if refVal.Kind() == reflect.Pointer {
val = refVal.Elem().Interface()
} else if otherRefVal.Kind() == reflect.Pointer {
otherVal = otherRefVal.Elem().Interface()
}
}
return types.Bool(reflect.DeepEqual(val, otherVal))
}
// IsZeroValue indicates whether the contained Golang value is a zero value.
//
// Golang largely follows proto3 semantics for zero values.
func (o *nativeObj) IsZeroValue() bool {
return reflect.Indirect(o.refValue).IsZero()
}
// IsSet tests whether a field which is defined is set to a non-default value.
func (o *nativeObj) IsSet(field ref.Val) ref.Val {
refField, refErr := o.getReflectedField(field)
if refErr != nil {
return refErr
}
return types.Bool(!refField.IsZero())
}
// Get returns the value fo a field name.
func (o *nativeObj) Get(field ref.Val) ref.Val {
refField, refErr := o.getReflectedField(field)
if refErr != nil {
return refErr
}
return adaptFieldValue(o, refField)
}
func (o *nativeObj) getReflectedField(field ref.Val) (reflect.Value, ref.Val) {
fieldName, ok := field.(types.String)
if !ok {
return reflect.Value{}, types.MaybeNoSuchOverloadErr(field)
}
fieldNameStr := string(fieldName)
refField, isDefined := o.valType.hasField(fieldNameStr)
if !isDefined {
return reflect.Value{}, types.NewErr("no such field: %s", fieldName)
}
refVal := reflect.Indirect(o.refValue)
return refVal.FieldByIndex(refField.Index), nil
}
// Type implements the ref.Val interface method.
func (o *nativeObj) Type() ref.Type {
return o.valType
}
// Value implements the ref.Val interface method.
func (o *nativeObj) Value() any {
return o.val
}
func newNativeType(rawType reflect.Type) (*nativeType, error) {
refType := rawType
if refType.Kind() == reflect.Pointer {
refType = refType.Elem()
}
if !isValidObjectType(refType) {
return nil, fmt.Errorf("unsupported reflect.Type %v, must be reflect.Struct", rawType)
}
return &nativeType{
typeName: fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
refType: refType,
}, nil
}
type nativeType struct {
typeName string
refType reflect.Type
}
// ConvertToNative implements ref.Val.ConvertToNative.
func (t *nativeType) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, fmt.Errorf("type conversion error for type to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (t *nativeType) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case types.TypeType:
return types.TypeType
}
return types.NewErr("type conversion error from '%s' to '%s'", types.TypeType, typeVal)
}
// Equal returns true of both type names are equal to each other.
func (t *nativeType) Equal(other ref.Val) ref.Val {
otherType, ok := other.(ref.Type)
return types.Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait implements the ref.Type interface method.
func (t *nativeType) HasTrait(trait int) bool {
return nativeObjTraitMask&trait == trait
}
// String implements the strings.Stringer interface method.
func (t *nativeType) String() string {
return t.typeName
}
// Type implements the ref.Val interface method.
func (t *nativeType) Type() ref.Type {
return types.TypeType
}
// TypeName implements the ref.Type interface method.
func (t *nativeType) TypeName() string {
return t.typeName
}
// Value implements the ref.Val interface method.
func (t *nativeType) Value() any {
return t.typeName
}
// hasField returns whether a field name has a corresponding Golang reflect.StructField
func (t *nativeType) hasField(fieldName string) (reflect.StructField, bool) {
f, found := t.refType.FieldByName(fieldName)
if !found || !f.IsExported() || !isSupportedType(f.Type) {
return reflect.StructField{}, false
}
return f, true
}
func adaptFieldValue(adapter types.Adapter, refField reflect.Value) ref.Val {
return adapter.NativeToValue(getFieldValue(adapter, refField))
}
func getFieldValue(adapter types.Adapter, refField reflect.Value) any {
if refField.IsZero() {
switch refField.Kind() {
case reflect.Array, reflect.Slice:
return types.NewDynamicList(adapter, []ref.Val{})
case reflect.Map:
return types.NewDynamicMap(adapter, map[ref.Val]ref.Val{})
case reflect.Struct:
if refField.Type() == timestampType {
return types.Timestamp{Time: time.Unix(0, 0)}
}
return reflect.New(refField.Type()).Elem().Interface()
case reflect.Pointer:
return reflect.New(refField.Type().Elem()).Interface()
}
}
return refField.Interface()
}
func simplePkgAlias(pkgPath string) string {
paths := strings.Split(pkgPath, "/")
if len(paths) == 0 {
return ""
}
return paths[len(paths)-1]
}
func isValidObjectType(refType reflect.Type) bool {
return refType.Kind() == reflect.Struct
}
func isSupportedType(refType reflect.Type) bool {
switch refType.Kind() {
case reflect.Chan, reflect.Complex64, reflect.Complex128, reflect.Func, reflect.UnsafePointer, reflect.Uintptr:
return false
case reflect.Array, reflect.Slice:
return isSupportedType(refType.Elem())
case reflect.Map:
return isSupportedType(refType.Key()) && isSupportedType(refType.Elem())
}
return true
}
var (
pbMsgInterfaceType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem()
timestampType = reflect.TypeOf(time.Now())
durationType = reflect.TypeOf(time.Nanosecond)
)

141
vendor/github.com/google/cel-go/ext/protos.go generated vendored Normal file
View File

@@ -0,0 +1,141 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Protos returns a cel.EnvOption to configure extended macros and functions for
// proto manipulation.
//
// Note, all macros use the 'proto' namespace; however, at the time of macro
// expansion the namespace looks just like any other identifier. If you are
// currently using a variable named 'proto', the macro will likely work just as
// intended; however, there is some chance for collision.
//
// # Protos.GetExt
//
// Macro which generates a select expression that retrieves an extension field
// from the input proto2 syntax message. If the field is not set, the default
// value forthe extension field is returned according to safe-traversal semantics.
//
// proto.getExt(<msg>, <fully.qualified.extension.name>) -> <field-type>
//
// Examples:
//
// proto.getExt(msg, google.expr.proto2.test.int32_ext) // returns int value
//
// # Protos.HasExt
//
// Macro which generates a test-only select expression that determines whether
// an extension field is set on a proto2 syntax message.
//
// proto.hasExt(<msg>, <fully.qualified.extension.name>) -> <bool>
//
// Examples:
//
// proto.hasExt(msg, google.expr.proto2.test.int32_ext) // returns true || false
func Protos() cel.EnvOption {
return cel.Lib(protoLib{})
}
var (
protoNamespace = "proto"
hasExtension = "hasExt"
getExtension = "getExt"
)
type protoLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (protoLib) LibraryName() string {
return "cel.lib.ext.protos"
}
// CompileOptions implements the Library interface method.
func (protoLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Macros(
// proto.getExt(msg, select_expression)
cel.NewReceiverMacro(getExtension, 2, getProtoExt),
// proto.hasExt(msg, select_expression)
cel.NewReceiverMacro(hasExtension, 2, hasProtoExt),
),
}
}
// ProgramOptions implements the Library interface method.
func (protoLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
// hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message.
func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extensionField, err := getExtFieldName(meh, args[1])
if err != nil {
return nil, err
}
return meh.PresenceTest(args[0], extensionField), nil
}
// getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message.
func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) {
if !macroTargetMatchesNamespace(protoNamespace, target) {
return nil, nil
}
extFieldName, err := getExtFieldName(meh, args[1])
if err != nil {
return nil, err
}
return meh.Select(args[0], extFieldName), nil
}
func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) {
isValid := false
extensionField := ""
switch expr.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
extensionField, isValid = validateIdentifier(expr)
}
if !isValid {
return "", meh.NewError(expr.GetId(), "invalid extension field")
}
return extensionField, nil
}
func validateIdentifier(expr *exprpb.Expr) (string, bool) {
switch expr.GetExprKind().(type) {
case *exprpb.Expr_IdentExpr:
return expr.GetIdentExpr().GetName(), true
case *exprpb.Expr_SelectExpr:
sel := expr.GetSelectExpr()
if sel.GetTestOnly() {
return "", false
}
opStr, isIdent := validateIdentifier(sel.GetOperand())
if !isIdent {
return "", false
}
return opStr + "." + sel.GetField(), true
default:
return "", false
}
}

197
vendor/github.com/google/cel-go/ext/sets.go generated vendored Normal file
View File

@@ -0,0 +1,197 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ext
import (
"math"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
// Sets returns a cel.EnvOption to configure namespaced set relationship
// functions.
//
// There is no set type within CEL, and while one may be introduced in the
// future, there are cases where a `list` type is known to behave like a set.
// For such cases, this library provides some basic functionality for
// determining set containment, equivalence, and intersection.
//
// # Sets.Contains
//
// Returns whether the first list argument contains all elements in the second
// list argument. The list may contain elements of any type and standard CEL
// equality is used to determine whether a value exists in both lists. If the
// second list is empty, the result will always return true.
//
// sets.contains(list(T), list(T)) -> bool
//
// Examples:
//
// sets.contains([], []) // true
// sets.contains([], [1]) // false
// sets.contains([1, 2, 3, 4], [2, 3]) // true
// sets.contains([1, 2.0, 3u], [1.0, 2u, 3]) // true
//
// # Sets.Equivalent
//
// Returns whether the first and second list are set equivalent. Lists are set
// equivalent if for every item in the first list, there is an element in the
// second which is equal. The lists may not be of the same size as they do not
// guarantee the elements within them are unique, so size does not factor into
// the computation.
//
// Examples:
//
// sets.equivalent([], []) // true
// sets.equivalent([1], [1, 1]) // true
// sets.equivalent([1], [1u, 1.0]) // true
// sets.equivalent([1, 2, 3], [3u, 2.0, 1]) // true
//
// # Sets.Intersects
//
// Returns whether the first list has at least one element whose value is equal
// to an element in the second list. If either list is empty, the result will
// be false.
//
// Examples:
//
// sets.intersects([1], []) // false
// sets.intersects([1], [1, 2]) // true
// sets.intersects([[1], [2, 3]], [[1, 2], [2, 3.0]]) // true
func Sets() cel.EnvOption {
return cel.Lib(setsLib{})
}
type setsLib struct{}
// LibraryName implements the SingletonLibrary interface method.
func (setsLib) LibraryName() string {
return "cel.lib.ext.sets"
}
// CompileOptions implements the Library interface method.
func (setsLib) CompileOptions() []cel.EnvOption {
listType := cel.ListType(cel.TypeParamType("T"))
return []cel.EnvOption{
cel.Function("sets.contains",
cel.Overload("list_sets_contains_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsContains))),
cel.Function("sets.equivalent",
cel.Overload("list_sets_equivalent_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsEquivalent))),
cel.Function("sets.intersects",
cel.Overload("list_sets_intersects_list", []*cel.Type{listType, listType}, cel.BoolType,
cel.BinaryBinding(setsIntersects))),
cel.CostEstimatorOptions(
checker.OverloadCostEstimate("list_sets_contains_list", estimateSetsCost(1)),
checker.OverloadCostEstimate("list_sets_intersects_list", estimateSetsCost(1)),
// equivalence requires potentially two m*n comparisons to ensure each list is contained by the other
checker.OverloadCostEstimate("list_sets_equivalent_list", estimateSetsCost(2)),
),
}
}
// ProgramOptions implements the Library interface method.
func (setsLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{
cel.CostTrackerOptions(
interpreter.OverloadCostTracker("list_sets_contains_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_intersects_list", trackSetsCost(1)),
interpreter.OverloadCostTracker("list_sets_equivalent_list", trackSetsCost(2)),
),
}
}
func setsIntersects(listA, listB ref.Val) ref.Val {
lA := listA.(traits.Lister)
lB := listB.(traits.Lister)
it := lA.Iterator()
for it.HasNext() == types.True {
exists := lB.Contains(it.Next())
if exists == types.True {
return types.True
}
}
return types.False
}
func setsContains(list, sublist ref.Val) ref.Val {
l := list.(traits.Lister)
sub := sublist.(traits.Lister)
it := sub.Iterator()
for it.HasNext() == types.True {
exists := l.Contains(it.Next())
if exists != types.True {
return exists
}
}
return types.True
}
func setsEquivalent(listA, listB ref.Val) ref.Val {
aContainsB := setsContains(listA, listB)
if aContainsB != types.True {
return aContainsB
}
return setsContains(listB, listA)
}
func estimateSetsCost(costFactor float64) checker.FunctionEstimator {
return func(estimator checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
if len(args) == 2 {
arg0Size := estimateSize(estimator, args[0])
arg1Size := estimateSize(estimator, args[1])
costEstimate := arg0Size.Multiply(arg1Size).MultiplyByCostFactor(costFactor).Add(callCostEstimate)
return &checker.CallEstimate{CostEstimate: costEstimate}
}
return nil
}
}
func estimateSize(estimator checker.CostEstimator, node checker.AstNode) checker.SizeEstimate {
if l := node.ComputedSize(); l != nil {
return *l
}
if l := estimator.EstimateSize(node); l != nil {
return *l
}
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
}
func trackSetsCost(costFactor float64) interpreter.FunctionTracker {
return func(args []ref.Val, _ ref.Val) *uint64 {
lhsSize := actualSize(args[0])
rhsSize := actualSize(args[1])
cost := callCost + uint64(float64(lhsSize*rhsSize)*costFactor)
return &cost
}
}
func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}
var (
callCostEstimate = checker.CostEstimate{Min: 1, Max: 1}
callCost = uint64(1)
)

View File

@@ -19,32 +19,92 @@ package ext
import (
"fmt"
"math"
"reflect"
"sort"
"strings"
"unicode"
"unicode/utf8"
"golang.org/x/text/language"
"golang.org/x/text/message"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter"
)
const (
defaultLocale = "en-US"
defaultPrecision = 6
)
// Strings returns a cel.EnvOption to configure extended functions for string manipulation.
// As a general note, all indices are zero-based.
//
// CharAt
// # CharAt
//
// Returns the character at the given position. If the position is negative, or greater than
// the length of the string, the function will produce an error:
//
// <string>.charAt(<int>) -> <string>
// <string>.charAt(<int>) -> <string>
//
// Examples:
//
// 'hello'.charAt(4) // return 'o'
// 'hello'.charAt(5) // return ''
// 'hello'.charAt(-1) // error
// 'hello'.charAt(4) // return 'o'
// 'hello'.charAt(5) // return ''
// 'hello'.charAt(-1) // error
//
// IndexOf
// # Format
//
// Introduced at version: 1
//
// Returns a new string with substitutions being performed, printf-style.
// The valid formatting clauses are:
//
// `%s` - substitutes a string. This can also be used on bools, lists, maps, bytes,
// Duration and Timestamp, in addition to all numerical types (int, uint, and double).
// Note that the dot/period decimal separator will always be used when printing a list
// or map that contains a double, and that null can be passed (which results in the
// string "null") in addition to types.
// `%d` - substitutes an integer.
// `%f` - substitutes a double with fixed-point precision. The default precision is 6, but
// this can be adjusted. The strings `Infinity`, `-Infinity`, and `NaN` are also valid input
// for this clause.
// `%e` - substitutes a double in scientific notation. The default precision is 6, but this
// can be adjusted.
// `%b` - substitutes an integer with its equivalent binary string. Can also be used on bools.
// `%x` - substitutes an integer with its equivalent in hexadecimal, or if given a string or
// bytes, will output each character's equivalent in hexadecimal.
// `%X` - same as above, but with A-F capitalized.
// `%o` - substitutes an integer with its equivalent in octal.
//
// <string>.format(<list>) -> <string>
//
// Examples:
//
// "this is a string: %s\nand an integer: %d".format(["str", 42]) // returns "this is a string: str\nand an integer: 42"
// "a double substituted with %%s: %s".format([64.2]) // returns "a double substituted with %s: 64.2"
// "string type: %s".format([type(string)]) // returns "string type: string"
// "timestamp: %s".format([timestamp("2023-02-03T23:31:20+00:00")]) // returns "timestamp: 2023-02-03T23:31:20Z"
// "duration: %s".format([duration("1h45m47s")]) // returns "duration: 6347s"
// "%f".format([3.14]) // returns "3.140000"
// "scientific notation: %e".format([2.71828]) // returns "scientific notation: 2.718280\u202f\u00d7\u202f10\u2070\u2070"
// "5 in binary: %b".format([5]), // returns "5 in binary; 101"
// "26 in hex: %x".format([26]), // returns "26 in hex: 1a"
// "26 in hex (uppercase): %X".format([26]) // returns "26 in hex (uppercase): 1A"
// "30 in octal: %o".format([30]) // returns "30 in octal: 36"
// "a map inside a list: %s".format([[1, 2, 3, {"a": "x", "b": "y", "c": "z"}]]) // returns "a map inside a list: [1, 2, 3, {"a":"x", "b":"y", "c":"d"}]"
// "true bool: %s - false bool: %s\nbinary bool: %b".format([true, false, true]) // returns "true bool: true - false bool: false\nbinary bool: 1"
//
// Passing an incorrect type (an integer to `%s`) is considered an error, as well as attempting
// to use more formatting clauses than there are arguments (`%d %d %d` while passing two ints, for instance).
// If compile-time checking is enabled, and the formatting string is a constant, and the argument list is a literal,
// then letting any arguments go unused/unformatted is also considered an error.
//
// # IndexOf
//
// Returns the integer index of the first occurrence of the search string. If the search string is
// not found the function returns -1.
@@ -52,19 +112,19 @@ import (
// The function also accepts an optional position from which to begin the substring search. If the
// substring is the empty string, the index where the search starts is returned (zero or custom).
//
// <string>.indexOf(<string>) -> <int>
// <string>.indexOf(<string>, <int>) -> <int>
// <string>.indexOf(<string>) -> <int>
// <string>.indexOf(<string>, <int>) -> <int>
//
// Examples:
//
// 'hello mellow'.indexOf('') // returns 0
// 'hello mellow'.indexOf('ello') // returns 1
// 'hello mellow'.indexOf('jello') // returns -1
// 'hello mellow'.indexOf('', 2) // returns 2
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
// 'hello mellow'.indexOf('') // returns 0
// 'hello mellow'.indexOf('ello') // returns 1
// 'hello mellow'.indexOf('jello') // returns -1
// 'hello mellow'.indexOf('', 2) // returns 2
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
//
// Join
// # Join
//
// Returns a new string where the elements of string list are concatenated.
//
@@ -75,12 +135,12 @@ import (
//
// Examples:
//
// ['hello', 'mellow'].join() // returns 'hellomellow'
// ['hello', 'mellow'].join(' ') // returns 'hello mellow'
// [].join() // returns ''
// [].join('/') // returns ''
// ['hello', 'mellow'].join() // returns 'hellomellow'
// ['hello', 'mellow'].join(' ') // returns 'hello mellow'
// [].join() // returns ''
// [].join('/') // returns ''
//
// LastIndexOf
// # LastIndexOf
//
// Returns the integer index at the start of the last occurrence of the search string. If the
// search string is not found the function returns -1.
@@ -89,31 +149,45 @@ import (
// considered as the beginning of the substring match. If the substring is the empty string,
// the index where the search starts is returned (string length or custom).
//
// <string>.lastIndexOf(<string>) -> <int>
// <string>.lastIndexOf(<string>, <int>) -> <int>
// <string>.lastIndexOf(<string>) -> <int>
// <string>.lastIndexOf(<string>, <int>) -> <int>
//
// Examples:
//
// 'hello mellow'.lastIndexOf('') // returns 12
// 'hello mellow'.lastIndexOf('ello') // returns 7
// 'hello mellow'.lastIndexOf('jello') // returns -1
// 'hello mellow'.lastIndexOf('ello', 6) // returns 1
// 'hello mellow'.lastIndexOf('ello', -1) // error
// 'hello mellow'.lastIndexOf('') // returns 12
// 'hello mellow'.lastIndexOf('ello') // returns 7
// 'hello mellow'.lastIndexOf('jello') // returns -1
// 'hello mellow'.lastIndexOf('ello', 6) // returns 1
// 'hello mellow'.lastIndexOf('ello', -1) // error
//
// LowerAscii
// # LowerAscii
//
// Returns a new string where all ASCII characters are lower-cased.
//
// This function does not perform Unicode case-mapping for characters outside the ASCII range.
//
// <string>.lowerAscii() -> <string>
// <string>.lowerAscii() -> <string>
//
// Examples:
//
// 'TacoCat'.lowerAscii() // returns 'tacocat'
// 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
// 'TacoCat'.lowerAscii() // returns 'tacocat'
// 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii'
//
// Replace
// # Strings.Quote
//
// Introduced in version: 1
//
// Takes the given string and makes it safe to print (without any formatting due to escape sequences).
// If any invalid UTF-8 characters are encountered, they are replaced with \uFFFD.
//
// strings.quote(<string>)
//
// Examples:
//
// strings.quote('single-quote with "double quote"') // returns '"single-quote with \"double quote\""'
// strings.quote("two escape sequences \a\n") // returns '"two escape sequences \\a\\n"'
//
// # Replace
//
// Returns a new string based on the target, which replaces the occurrences of a search string
// with a replacement string if present. The function accepts an optional limit on the number of
@@ -122,17 +196,17 @@ import (
// When the replacement limit is 0, the result is the original string. When the limit is a negative
// number, the function behaves the same as replace all.
//
// <string>.replace(<string>, <string>) -> <string>
// <string>.replace(<string>, <string>, <int>) -> <string>
// <string>.replace(<string>, <string>) -> <string>
// <string>.replace(<string>, <string>, <int>) -> <string>
//
// Examples:
//
// 'hello hello'.replace('he', 'we') // returns 'wello wello'
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
// 'hello hello'.replace('he', 'we') // returns 'wello wello'
// 'hello hello'.replace('he', 'we', -1) // returns 'wello wello'
// 'hello hello'.replace('he', 'we', 1) // returns 'wello hello'
// 'hello hello'.replace('he', 'we', 0) // returns 'hello hello'
//
// Split
// # Split
//
// Returns a list of strings split from the input by the given separator. The function accepts
// an optional argument specifying a limit on the number of substrings produced by the split.
@@ -141,18 +215,18 @@ import (
// target string to split. When the limit is a negative number, the function behaves the same as
// split all.
//
// <string>.split(<string>) -> <list<string>>
// <string>.split(<string>, <int>) -> <list<string>>
// <string>.split(<string>) -> <list<string>>
// <string>.split(<string>, <int>) -> <list<string>>
//
// Examples:
//
// 'hello hello hello'.split(' ') // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ', 0) // returns []
// 'hello hello hello'.split(' ', 1) // returns ['hello hello hello']
// 'hello hello hello'.split(' ', 2) // returns ['hello', 'hello hello']
// 'hello hello hello'.split(' ', -1) // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ') // returns ['hello', 'hello', 'hello']
// 'hello hello hello'.split(' ', 0) // returns []
// 'hello hello hello'.split(' ', 1) // returns ['hello hello hello']
// 'hello hello hello'.split(' ', 2) // returns ['hello', 'hello hello']
// 'hello hello hello'.split(' ', -1) // returns ['hello', 'hello', 'hello']
//
// Substring
// # Substring
//
// Returns the substring given a numeric range corresponding to character positions. Optionally
// may omit the trailing range for a substring from a given character position until the end of
@@ -162,48 +236,104 @@ import (
// error to specify an end range that is lower than the start range, or for either the start or end
// index to be negative or exceed the string length.
//
// <string>.substring(<int>) -> <string>
// <string>.substring(<int>, <int>) -> <string>
// <string>.substring(<int>) -> <string>
// <string>.substring(<int>, <int>) -> <string>
//
// Examples:
//
// 'tacocat'.substring(4) // returns 'cat'
// 'tacocat'.substring(0, 4) // returns 'taco'
// 'tacocat'.substring(-1) // error
// 'tacocat'.substring(2, 1) // error
// 'tacocat'.substring(4) // returns 'cat'
// 'tacocat'.substring(0, 4) // returns 'taco'
// 'tacocat'.substring(-1) // error
// 'tacocat'.substring(2, 1) // error
//
// Trim
// # Trim
//
// Returns a new string which removes the leading and trailing whitespace in the target string.
// The trim function uses the Unicode definition of whitespace which does not include the
// zero-width spaces. See: https://en.wikipedia.org/wiki/Whitespace_character#Unicode
//
// <string>.trim() -> <string>
// <string>.trim() -> <string>
//
// Examples:
//
// ' \ttrim\n '.trim() // returns 'trim'
// ' \ttrim\n '.trim() // returns 'trim'
//
// UpperAscii
// # UpperAscii
//
// Returns a new string where all ASCII characters are upper-cased.
//
// This function does not perform Unicode case-mapping for characters outside the ASCII range.
//
// <string>.upperAscii() -> <string>
// <string>.upperAscii() -> <string>
//
// Examples:
//
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
func Strings() cel.EnvOption {
return cel.Lib(stringLib{})
// 'TacoCat'.upperAscii() // returns 'TACOCAT'
// 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII'
func Strings(options ...StringsOption) cel.EnvOption {
s := &stringLib{version: math.MaxUint32}
for _, o := range options {
s = o(s)
}
return cel.Lib(s)
}
type stringLib struct{}
type stringLib struct {
locale string
version uint32
}
func (stringLib) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
// LibraryName implements the SingletonLibrary interface method.
func (*stringLib) LibraryName() string {
return "cel.lib.ext.strings"
}
// StringsOption is a functional interface for configuring the strings library.
type StringsOption func(*stringLib) *stringLib
// StringsLocale configures the library with the given locale. The locale tag will
// be checked for validity at the time that EnvOptions are configured. If this option
// is not passed, string.format will behave as if en_US was passed as the locale.
func StringsLocale(locale string) StringsOption {
return func(sl *stringLib) *stringLib {
sl.locale = locale
return sl
}
}
// StringsVersion configures the version of the string library.
//
// The version limits which functions are available. Only functions introduced
// below or equal to the given version included in the library. If this option
// is not set, all functions are available.
//
// See the library documentation to determine which version a function was introduced.
// If the documentation does not state which version a function was introduced, it can
// be assumed to be introduced at version 0, when the library was first created.
func StringsVersion(version uint32) StringsOption {
return func(lib *stringLib) *stringLib {
lib.version = version
return lib
}
}
// CompileOptions implements the Library interface method.
func (lib *stringLib) CompileOptions() []cel.EnvOption {
formatLocale := "en_US"
if lib.locale != "" {
// ensure locale is properly-formed if set
_, err := language.Parse(lib.locale)
if err != nil {
return []cel.EnvOption{
func(e *cel.Env) (*cel.Env, error) {
return nil, fmt.Errorf("failed to parse locale: %w", err)
},
}
}
formatLocale = lib.locale
}
opts := []cel.EnvOption{
cel.Function("charAt",
cel.MemberOverload("string_char_at_int", []*cel.Type{cel.StringType, cel.IntType}, cel.StringType,
cel.BinaryBinding(func(str, ind ref.Val) ref.Val {
@@ -303,28 +433,64 @@ func (stringLib) CompileOptions() []cel.EnvOption {
s := str.(types.String)
return stringOrError(upperASCII(string(s)))
}))),
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
return stringOrError(join(l.([]string)))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
}
if lib.version >= 1 {
opts = append(opts, cel.Function("format",
cel.MemberOverload("string_format", []*cel.Type{cel.StringType, cel.ListType(cel.DynType)}, cel.StringType,
cel.FunctionBinding(func(args ...ref.Val) ref.Val {
s := string(args[0].(types.String))
formatArgs := args[1].(traits.Lister)
return stringOrError(interpreter.ParseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale))
}))),
cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType,
cel.UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return stringOrError(quote(string(s)))
}))))
}
if lib.version >= 2 {
opts = append(opts,
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l := list.(traits.Lister)
return stringOrError(joinValSeparator(l, ""))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l := list.(traits.Lister)
d := delim.(types.String)
return stringOrError(joinValSeparator(l, string(d)))
}))),
)
} else {
opts = append(opts,
cel.Function("join",
cel.MemberOverload("list_join", []*cel.Type{cel.ListType(cel.StringType)}, cel.StringType,
cel.UnaryBinding(func(list ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
return stringOrError(join(l.([]string)))
})),
cel.MemberOverload("list_join_string", []*cel.Type{cel.ListType(cel.StringType), cel.StringType}, cel.StringType,
cel.BinaryBinding(func(list, delim ref.Val) ref.Val {
l, err := list.ConvertToNative(stringListType)
if err != nil {
return types.NewErr(err.Error())
}
d := delim.(types.String)
return stringOrError(joinSeparator(l.([]string), string(d)))
}))),
)
}
return opts
}
func (stringLib) ProgramOptions() []cel.ProgramOption {
// ProgramOptions implements the Library interface method.
func (*stringLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
@@ -478,6 +644,452 @@ func join(strs []string) (string, error) {
return strings.Join(strs, ""), nil
}
func joinValSeparator(strs traits.Lister, separator string) (string, error) {
sz := strs.Size().(types.Int)
var sb strings.Builder
for i := types.Int(0); i < sz; i++ {
if i != 0 {
sb.WriteString(separator)
}
elem := strs.Get(i)
str, ok := elem.(types.String)
if !ok {
return "", fmt.Errorf("join: invalid input: %v", elem)
}
sb.WriteString(string(str))
}
return sb.String(), nil
}
type clauseImpl func(ref.Val, string) (string, error)
func clauseForType(argType ref.Type) (clauseImpl, error) {
switch argType {
case types.IntType, types.UintType:
return formatDecimal, nil
case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType:
return FormatString, nil
case types.TimestampType, types.DurationType:
// special case to ensure timestamps/durations get printed as CEL literals
return func(arg ref.Val, locale string) (string, error) {
argStrVal := arg.ConvertToType(types.StringType)
argStr := argStrVal.Value().(string)
if arg.Type() == types.TimestampType {
return fmt.Sprintf("timestamp(%q)", argStr), nil
}
if arg.Type() == types.DurationType {
return fmt.Sprintf("duration(%q)", argStr), nil
}
return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName())
}, nil
case types.ListType:
return formatList, nil
case types.MapType:
return formatMap, nil
case types.DoubleType:
// avoid formatFixed so we can output a period as the decimal separator in order
// to always be a valid CEL literal
return func(arg ref.Val, locale string) (string, error) {
argDouble, ok := arg.Value().(float64)
if !ok {
return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName())
}
fmtStr := fmt.Sprintf("%%.%df", defaultPrecision)
return fmt.Sprintf(fmtStr, argDouble), nil
}, nil
case types.TypeType:
return func(arg ref.Val, locale string) (string, error) {
return fmt.Sprintf("type(%s)", arg.Value().(string)), nil
}, nil
default:
return nil, fmt.Errorf("no formatting function for %s", argType.TypeName())
}
}
func formatList(arg ref.Val, locale string) (string, error) {
argList := arg.(traits.Lister)
argIterator := argList.Iterator()
var listStrBuilder strings.Builder
_, err := listStrBuilder.WriteRune('[')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
for argIterator.HasNext() == types.True {
member := argIterator.Next()
memberFormat, err := clauseForType(member.Type())
if err != nil {
return "", err
}
unquotedStr, err := memberFormat(member, locale)
if err != nil {
return "", err
}
str := quoteForCEL(member, unquotedStr)
_, err = listStrBuilder.WriteString(str)
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
if argIterator.HasNext() == types.True {
_, err = listStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
}
}
_, err = listStrBuilder.WriteRune(']')
if err != nil {
return "", fmt.Errorf("error writing to list string: %w", err)
}
return listStrBuilder.String(), nil
}
func formatMap(arg ref.Val, locale string) (string, error) {
argMap := arg.(traits.Mapper)
argIterator := argMap.Iterator()
type mapPair struct {
key string
value string
}
argPairs := make([]mapPair, argMap.Size().Value().(int64))
i := 0
for argIterator.HasNext() == types.True {
key := argIterator.Next()
var keyFormat clauseImpl
switch key.Type() {
case types.StringType, types.BoolType:
keyFormat = FormatString
case types.IntType, types.UintType:
keyFormat = formatDecimal
default:
return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName())
}
unquotedKeyStr, err := keyFormat(key, locale)
if err != nil {
return "", err
}
keyStr := quoteForCEL(key, unquotedKeyStr)
value, found := argMap.Find(key)
if !found {
return "", fmt.Errorf("could not find key: %q", key)
}
valueFormat, err := clauseForType(value.Type())
if err != nil {
return "", err
}
unquotedValueStr, err := valueFormat(value, locale)
if err != nil {
return "", err
}
valueStr := quoteForCEL(value, unquotedValueStr)
argPairs[i] = mapPair{keyStr, valueStr}
i++
}
sort.SliceStable(argPairs, func(x, y int) bool {
return argPairs[x].key < argPairs[y].key
})
var mapStrBuilder strings.Builder
_, err := mapStrBuilder.WriteRune('{')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
for i, entry := range argPairs {
_, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value))
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
if i < len(argPairs)-1 {
_, err = mapStrBuilder.WriteString(", ")
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
}
}
_, err = mapStrBuilder.WriteRune('}')
if err != nil {
return "", fmt.Errorf("error writing to map string: %w", err)
}
return mapStrBuilder.String(), nil
}
// quoteForCEL takes a formatted, unquoted value and quotes it in a manner
// suitable for embedding directly in CEL.
func quoteForCEL(refVal ref.Val, unquotedValue string) string {
switch refVal.Type() {
case types.StringType:
return fmt.Sprintf("%q", unquotedValue)
case types.BytesType:
return fmt.Sprintf("b%q", unquotedValue)
case types.DoubleType:
// special case to handle infinity/NaN
num := refVal.Value().(float64)
if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) {
return fmt.Sprintf("%q", unquotedValue)
}
return unquotedValue
default:
return unquotedValue
}
}
// FormatString returns the string representation of a CEL value.
// It is used to implement the %s specifier in the (string).format() extension
// function.
func FormatString(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.ListType:
return formatList(arg, locale)
case types.MapType:
return formatMap(arg, locale)
case types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType:
argStrVal := arg.ConvertToType(types.StringType)
argStr, ok := argStrVal.Value().(string)
if !ok {
return "", fmt.Errorf("could not convert argument %q to string", argStrVal)
}
return argStr, nil
case types.NullType:
return "null", nil
default:
return "", fmt.Errorf("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", arg.Type().TypeName())
}
}
func formatDecimal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt, ok := arg.ConvertToType(types.IntType).Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
case types.UintType:
argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf("%d", argInt), nil
default:
return "", fmt.Errorf("decimal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
func matchLanguage(locale string) (language.Tag, error) {
matcher, err := makeMatcher(locale)
if err != nil {
return language.Und, err
}
tag, _ := language.MatchStrings(matcher, locale)
return tag, nil
}
func makeMatcher(locale string) (language.Matcher, error) {
tags := make([]language.Tag, 0)
tag, err := language.Parse(locale)
if err != nil {
return nil, err
}
tags = append(tags, tag)
return language.NewMatcher(tags), nil
}
// quote implements a string quoting function. The string will be wrapped in
// double quotes, and all valid CEL escape sequences will be escaped to show up
// literally if printed. If the input contains any invalid UTF-8, the invalid runes
// will be replaced with utf8.RuneError.
func quote(s string) (string, error) {
var quotedStrBuilder strings.Builder
for _, c := range sanitize(s) {
switch c {
case '\a':
quotedStrBuilder.WriteString("\\a")
case '\b':
quotedStrBuilder.WriteString("\\b")
case '\f':
quotedStrBuilder.WriteString("\\f")
case '\n':
quotedStrBuilder.WriteString("\\n")
case '\r':
quotedStrBuilder.WriteString("\\r")
case '\t':
quotedStrBuilder.WriteString("\\t")
case '\v':
quotedStrBuilder.WriteString("\\v")
case '\\':
quotedStrBuilder.WriteString("\\\\")
case '"':
quotedStrBuilder.WriteString("\\\"")
default:
quotedStrBuilder.WriteRune(c)
}
}
escapedStr := quotedStrBuilder.String()
return "\"" + escapedStr + "\"", nil
}
// sanitize replaces all invalid runes in the given string with utf8.RuneError.
func sanitize(s string) string {
var sanitizedStringBuilder strings.Builder
for _, r := range s {
if !utf8.ValidRune(r) {
sanitizedStringBuilder.WriteRune(utf8.RuneError)
} else {
sanitizedStringBuilder.WriteRune(r)
}
}
return sanitizedStringBuilder.String()
}
type stringFormatter struct{}
func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) {
return FormatString(arg, locale)
}
func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) {
return formatDecimal(arg, locale)
}
func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("fixed-point clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
fmtStr := fmt.Sprintf("%%.%df", *precision)
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) {
if precision == nil {
precision = new(int)
*precision = defaultPrecision
}
return func(arg ref.Val, locale string) (string, error) {
strException := false
if arg.Type() == types.StringType {
argStr := arg.Value().(string)
if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" {
strException = true
}
}
if arg.Type() != types.DoubleType && !strException {
return "", fmt.Errorf("scientific clause can only be used on doubles, was given %s", arg.Type().TypeName())
}
argFloatVal := arg.ConvertToType(types.DoubleType)
argFloat, ok := argFloatVal.Value().(float64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value())
}
matchedLocale, err := matchLanguage(locale)
if err != nil {
return "", fmt.Errorf("error matching locale: %w", err)
}
fmtStr := fmt.Sprintf("%%%de", *precision)
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil
}
}
func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
// locale is intentionally unused as integers formatted as binary
// strings are locale-independent
return fmt.Sprintf("%b", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%b", argInt), nil
case types.BoolType:
argBool := arg.Value().(bool)
if argBool {
return "1", nil
}
return "0", nil
default:
return "", fmt.Errorf("only integers and bools can be formatted as binary, was given %s", arg.Type().TypeName())
}
}
func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
fmtStr := "%x"
if useUpper {
fmtStr = "%X"
}
switch arg.Type() {
case types.StringType, types.BytesType:
if arg.Type() == types.BytesType {
return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil
}
return fmt.Sprintf(fmtStr, arg.Value().(string)), nil
case types.IntType:
argInt, ok := arg.Value().(int64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
case types.UintType:
argInt, ok := arg.Value().(uint64)
if !ok {
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value())
}
return fmt.Sprintf(fmtStr, argInt), nil
default:
return "", fmt.Errorf("only integers, byte buffers, and strings can be formatted as hex, was given %s", arg.Type().TypeName())
}
}
}
func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) {
switch arg.Type() {
case types.IntType:
argInt := arg.Value().(int64)
return fmt.Sprintf("%o", argInt), nil
case types.UintType:
argInt := arg.Value().(uint64)
return fmt.Sprintf("%o", argInt), nil
default:
return "", fmt.Errorf("octal clause can only be used on integers, was given %s", arg.Type().TypeName())
}
}
type stringArgList struct {
args traits.Lister
}
func (c *stringArgList) Arg(index int64) (ref.Val, error) {
if index >= c.args.Size().Value().(int64) {
return nil, fmt.Errorf("index %d out of range", index)
}
return c.args.Get(types.Int(index)), nil
}
func (c *stringArgList) ArgSize() int64 {
return c.args.Size().Value().(int64)
}
var (
stringListType = reflect.TypeOf([]string{})
)

View File

@@ -11,10 +11,10 @@ go_library(
"activation.go",
"attribute_patterns.go",
"attributes.go",
"coster.go",
"decorators.go",
"dispatcher.go",
"evalstate.go",
"formatting.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",
@@ -25,14 +25,15 @@ go_library(
importpath = "github.com/google/cel-go/interpreter",
deps = [
"//common:go_default_library",
"//common/ast:go_default_library",
"//common/containers:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter/functions:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
@@ -49,23 +50,25 @@ go_test(
"attributes_test.go",
"interpreter_test.go",
"prune_test.go",
"runtimecost_test.go",
],
embed = [
":go_default_library",
],
deps = [
"//checker:go_default_library",
"//checker/decls:go_default_library",
"//common/containers:go_default_library",
"//common/debug:go_default_library",
"//common/decls:go_default_library",
"//common/functions:go_default_library",
"//common/operators:go_default_library",
"//common/stdlib:go_default_library",
"//common/types:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
],

View File

@@ -28,7 +28,7 @@ import (
type Activation interface {
// ResolveName returns a value from the activation by qualified name, or false if the name
// could not be found.
ResolveName(name string) (interface{}, bool)
ResolveName(name string) (any, bool)
// Parent returns the parent of the current activation, may be nil.
// If non-nil, the parent will be searched during resolve calls.
@@ -43,23 +43,23 @@ func EmptyActivation() Activation {
// emptyActivation is a variable-free activation.
type emptyActivation struct{}
func (emptyActivation) ResolveName(string) (interface{}, bool) { return nil, false }
func (emptyActivation) Parent() Activation { return nil }
func (emptyActivation) ResolveName(string) (any, bool) { return nil, false }
func (emptyActivation) Parent() Activation { return nil }
// NewActivation returns an activation based on a map-based binding where the map keys are
// expected to be qualified names used with ResolveName calls.
//
// The input `bindings` may either be of type `Activation` or `map[string]interface{}`.
// The input `bindings` may either be of type `Activation` or `map[string]any`.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() any
// - func() ref.Val
//
// The output of the lazy binding will overwrite the variable reference in the internal map.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func NewActivation(bindings interface{}) (Activation, error) {
// the types.Adapter configured in the environment.
func NewActivation(bindings any) (Activation, error) {
if bindings == nil {
return nil, errors.New("bindings must be non-nil")
}
@@ -67,7 +67,7 @@ func NewActivation(bindings interface{}) (Activation, error) {
if isActivation {
return a, nil
}
m, isMap := bindings.(map[string]interface{})
m, isMap := bindings.(map[string]any)
if !isMap {
return nil, fmt.Errorf(
"activation input must be an activation or map[string]interface: got %T",
@@ -81,7 +81,7 @@ func NewActivation(bindings interface{}) (Activation, error) {
// Named bindings may lazily supply values by providing a function which accepts no arguments and
// produces an interface value.
type mapActivation struct {
bindings map[string]interface{}
bindings map[string]any
}
// Parent implements the Activation interface method.
@@ -90,7 +90,7 @@ func (a *mapActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (a *mapActivation) ResolveName(name string) (interface{}, bool) {
func (a *mapActivation) ResolveName(name string) (any, bool) {
obj, found := a.bindings[name]
if !found {
return nil, false
@@ -100,7 +100,7 @@ func (a *mapActivation) ResolveName(name string) (interface{}, bool) {
obj = fn()
a.bindings[name] = obj
}
fnRaw, isLazy := obj.(func() interface{})
fnRaw, isLazy := obj.(func() any)
if isLazy {
obj = fnRaw()
a.bindings[name] = obj
@@ -121,7 +121,7 @@ func (a *hierarchicalActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (a *hierarchicalActivation) ResolveName(name string) (interface{}, bool) {
func (a *hierarchicalActivation) ResolveName(name string) (any, bool) {
if object, found := a.child.ResolveName(name); found {
return object, found
}
@@ -138,8 +138,8 @@ func NewHierarchicalActivation(parent Activation, child Activation) Activation {
// representing field and index operations that should result in a 'types.Unknown' result.
//
// The `bindings` value may be any value type supported by the interpreter.NewActivation call,
// but is typically either an existing Activation or map[string]interface{}.
func NewPartialActivation(bindings interface{},
// but is typically either an existing Activation or map[string]any.
func NewPartialActivation(bindings any,
unknowns ...*AttributePattern) (PartialActivation, error) {
a, err := NewActivation(bindings)
if err != nil {
@@ -184,7 +184,7 @@ func (v *varActivation) Parent() Activation {
}
// ResolveName implements the Activation interface method.
func (v *varActivation) ResolveName(name string) (interface{}, bool) {
func (v *varActivation) ResolveName(name string) (any, bool) {
if name == v.name {
return v.val, true
}
@@ -194,7 +194,7 @@ func (v *varActivation) ResolveName(name string) (interface{}, bool) {
var (
// pool of var activations to reduce allocations during folds.
varActivationPool = &sync.Pool{
New: func() interface{} {
New: func() any {
return &varActivation{}
},
}

View File

@@ -36,9 +36,9 @@ import (
//
// Examples:
//
// 1. ns.myvar["complex-value"]
// 2. ns.myvar["complex-value"][0]
// 3. ns.myvar["complex-value"].*.name
// 1. ns.myvar["complex-value"]
// 2. ns.myvar["complex-value"][0]
// 3. ns.myvar["complex-value"].*.name
//
// The first example is simple: match an attribute where the variable is 'ns.myvar' with a
// field access on 'complex-value'. The second example expands the match to indicate that only
@@ -108,7 +108,7 @@ func (apat *AttributePattern) QualifierPatterns() []*AttributeQualifierPattern {
// AttributeQualifierPattern holds a wildcard or valued qualifier pattern.
type AttributeQualifierPattern struct {
wildcard bool
value interface{}
value any
}
// Matches returns true if the qualifier pattern is a wildcard, or the Qualifier implements the
@@ -134,44 +134,44 @@ func (qpat *AttributeQualifierPattern) Matches(q Qualifier) bool {
type qualifierValueEquator interface {
// QualifierValueEquals returns true if the input value is equal to the value held in the
// Qualifier.
QualifierValueEquals(value interface{}) bool
QualifierValueEquals(value any) bool
}
// QualifierValueEquals implementation for boolean qualifiers.
func (q *boolQualifier) QualifierValueEquals(value interface{}) bool {
func (q *boolQualifier) QualifierValueEquals(value any) bool {
bval, ok := value.(bool)
return ok && q.value == bval
}
// QualifierValueEquals implementation for field qualifiers.
func (q *fieldQualifier) QualifierValueEquals(value interface{}) bool {
func (q *fieldQualifier) QualifierValueEquals(value any) bool {
sval, ok := value.(string)
return ok && q.Name == sval
}
// QualifierValueEquals implementation for string qualifiers.
func (q *stringQualifier) QualifierValueEquals(value interface{}) bool {
func (q *stringQualifier) QualifierValueEquals(value any) bool {
sval, ok := value.(string)
return ok && q.value == sval
}
// QualifierValueEquals implementation for int qualifiers.
func (q *intQualifier) QualifierValueEquals(value interface{}) bool {
func (q *intQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for uint qualifiers.
func (q *uintQualifier) QualifierValueEquals(value interface{}) bool {
func (q *uintQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for double qualifiers.
func (q *doubleQualifier) QualifierValueEquals(value interface{}) bool {
func (q *doubleQualifier) QualifierValueEquals(value any) bool {
return numericValueEquals(value, q.celValue)
}
// numericValueEquals uses CEL equality to determine whether two number values are
func numericValueEquals(value interface{}, celValue ref.Val) bool {
func numericValueEquals(value any, celValue ref.Val) bool {
val := types.DefaultTypeAdapter.NativeToValue(value)
return celValue.Equal(val) == types.True
}
@@ -179,8 +179,8 @@ func numericValueEquals(value interface{}, celValue ref.Val) bool {
// NewPartialAttributeFactory returns an AttributeFactory implementation capable of performing
// AttributePattern matches with PartialActivation inputs.
func NewPartialAttributeFactory(container *containers.Container,
adapter ref.TypeAdapter,
provider ref.TypeProvider) AttributeFactory {
adapter types.Adapter,
provider types.Provider) AttributeFactory {
fac := NewAttributeFactory(container, adapter, provider)
return &partialAttributeFactory{
AttributeFactory: fac,
@@ -193,8 +193,8 @@ func NewPartialAttributeFactory(container *containers.Container,
type partialAttributeFactory struct {
AttributeFactory
container *containers.Container
adapter ref.TypeAdapter
provider ref.TypeProvider
adapter types.Adapter
provider types.Provider
}
// AbsoluteAttribute implementation of the AttributeFactory interface which wraps the
@@ -243,12 +243,15 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
vars PartialActivation,
attrID int64,
variableNames []string,
qualifiers []Qualifier) (types.Unknown, error) {
qualifiers []Qualifier) (*types.Unknown, error) {
patterns := vars.UnknownAttributePatterns()
candidateIndices := map[int]struct{}{}
for _, variable := range variableNames {
for i, pat := range patterns {
if pat.VariableMatches(variable) {
if len(qualifiers) == 0 {
return types.NewUnknown(attrID, types.NewAttributeTrail(variable)), nil
}
candidateIndices[i] = struct{}{}
}
}
@@ -257,10 +260,6 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
if len(candidateIndices) == 0 {
return nil, nil
}
// Determine whether to return early if there are no qualifiers.
if len(qualifiers) == 0 {
return types.Unknown{attrID}, nil
}
// Resolve the attribute qualifiers into a static set. This prevents more dynamic
// Attribute resolutions than necessary when there are multiple unknown patterns
// that traverse the same Attribute-based qualifier field.
@@ -272,13 +271,9 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
if err != nil {
return nil, err
}
unk, isUnk := val.(types.Unknown)
if isUnk {
return unk, nil
}
// If this resolution behavior ever changes, new implementations of the
// qualifierValueEquator may be required to handle proper resolution.
qual, err = fac.NewQualifier(nil, qual.ID(), val)
qual, err = fac.NewQualifier(nil, qual.ID(), val, attr.IsOptional())
if err != nil {
return nil, err
}
@@ -306,7 +301,28 @@ func (fac *partialAttributeFactory) matchesUnknownPatterns(
}
}
if isUnk {
return types.Unknown{matchExprID}, nil
attr := types.NewAttributeTrail(pat.variable)
for i := 0; i < len(qualPats) && i < len(newQuals); i++ {
if qual, ok := newQuals[i].(ConstantQualifier); ok {
switch v := qual.Value().Value().(type) {
case bool:
types.QualifyAttribute[bool](attr, v)
case float64:
types.QualifyAttribute[int64](attr, int64(v))
case int64:
types.QualifyAttribute[int64](attr, v)
case string:
types.QualifyAttribute[string](attr, v)
case uint64:
types.QualifyAttribute[uint64](attr, v)
default:
types.QualifyAttribute[string](attr, fmt.Sprintf("%v", v))
}
} else {
types.QualifyAttribute[string](attr, "*")
}
}
return types.NewUnknown(matchExprID, attr), nil
}
}
return nil, nil
@@ -338,24 +354,10 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) {
return m, nil
}
// Resolve is an implementation of the Attribute interface method which uses the
// attributeMatcher TryResolve implementation rather than the embedded NamespacedAttribute
// Resolve implementation.
func (m *attributeMatcher) Resolve(vars Activation) (interface{}, error) {
obj, found, err := m.TryResolve(vars)
if err != nil {
return nil, err
}
if !found {
return nil, fmt.Errorf("no such attribute: %v", m.NamespacedAttribute)
}
return obj, nil
}
// TryResolve is an implementation of the NamespacedAttribute interface method which tests
// Resolve is an implementation of the NamespacedAttribute interface method which tests
// for matching unknown attribute patterns and returns types.Unknown if present. Otherwise,
// the standard Resolve logic applies.
func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error) {
func (m *attributeMatcher) Resolve(vars Activation) (any, error) {
id := m.NamespacedAttribute.ID()
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := toPartialActivation(vars)
@@ -366,30 +368,23 @@ func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error
m.CandidateVariableNames(),
m.qualifiers)
if err != nil {
return nil, true, err
return nil, err
}
if unk != nil {
return unk, true, nil
return unk, nil
}
}
return m.NamespacedAttribute.TryResolve(vars)
return m.NamespacedAttribute.Resolve(vars)
}
// Qualify is an implementation of the Qualifier interface method.
func (m *attributeMatcher) Qualify(vars Activation, obj interface{}) (interface{}, error) {
val, err := m.Resolve(vars)
if err != nil {
return nil, err
}
unk, isUnk := val.(types.Unknown)
if isUnk {
return unk, nil
}
qual, err := m.fac.NewQualifier(nil, m.ID(), val)
if err != nil {
return nil, err
}
return qual.Qualify(vars, obj)
func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) {
return attrQualify(m.fac, vars, obj, m)
}
// QualifyIfPresent is an implementation of the Qualifier interface method.
func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly)
}
func toPartialActivation(vars Activation) (PartialActivation, bool) {

File diff suppressed because it is too large Load Diff

View File

@@ -29,7 +29,7 @@ type InterpretableDecorator func(Interpretable) (Interpretable, error)
func decObserveEval(observer EvalObserver) InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
switch inst := i.(type) {
case *evalWatch, *evalWatchAttr, *evalWatchConst:
case *evalWatch, *evalWatchAttr, *evalWatchConst, *evalWatchConstructor:
// these instruction are already watching, return straight-away.
return i, nil
case InterpretableAttribute:
@@ -42,6 +42,11 @@ func decObserveEval(observer EvalObserver) InterpretableDecorator {
InterpretableConst: inst,
observer: observer,
}, nil
case InterpretableConstructor:
return &evalWatchConstructor{
constructor: inst,
observer: observer,
}, nil
default:
return &evalWatch{
Interpretable: i,
@@ -70,15 +75,13 @@ func decDisableShortcircuits() InterpretableDecorator {
switch expr := i.(type) {
case *evalOr:
return &evalExhaustiveOr{
id: expr.id,
lhs: expr.lhs,
rhs: expr.rhs,
id: expr.id,
terms: expr.terms,
}, nil
case *evalAnd:
return &evalExhaustiveAnd{
id: expr.id,
lhs: expr.lhs,
rhs: expr.rhs,
id: expr.id,
terms: expr.terms,
}, nil
case *evalFold:
expr.exhaustive = true
@@ -224,8 +227,8 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
valueSet := make(map[ref.Val]ref.Val)
for it.HasNext() == types.True {
elem := it.Next()
if !types.IsPrimitiveType(elem) {
// Note, non-primitive type are not yet supported.
if !types.IsPrimitiveType(elem) || elem.Type() == types.BytesType {
// Note, non-primitive type are not yet supported, and []byte isn't hashable.
return i, nil
}
valueSet[elem] = types.True

View File

@@ -17,7 +17,7 @@ package interpreter
import (
"fmt"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/common/functions"
)
// Dispatcher resolves function calls to their appropriate overload.

View File

@@ -66,7 +66,11 @@ func (s *evalState) Value(exprID int64) (ref.Val, bool) {
// SetValue is an implementation of the EvalState interface method.
func (s *evalState) SetValue(exprID int64, val ref.Val) {
s.values[exprID] = val
if val == nil {
delete(s.values, exprID)
} else {
s.values[exprID] = val
}
}
// Reset implements the EvalState interface method.

View File

@@ -0,0 +1,383 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"errors"
"fmt"
"strconv"
"strings"
"unicode"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
type typeVerifier func(int64, ...ref.Type) (bool, error)
// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports
// any errors at compile time.
func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator {
return func(inter Interpretable) (Interpretable, error) {
call, ok := inter.(InterpretableCall)
if !ok {
return inter, nil
}
if call.OverloadID() != "string_format" {
return inter, nil
}
args := call.Args()
if len(args) != 2 {
return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args))
}
fmtStrInter, ok := args[0].(InterpretableConst)
if !ok {
return inter, nil
}
var fmtArgsInter InterpretableConstructor
fmtArgsInter, ok = args[1].(InterpretableConstructor)
if !ok {
return inter, nil
}
if fmtArgsInter.Type() != types.ListType {
// don't necessarily return an error since the list may be DynType
return inter, nil
}
formatStr := fmtStrInter.Value().Value().(string)
initVals := fmtArgsInter.InitVals()
formatCheck := &formatCheck{
args: initVals,
verifier: verifier,
}
// use a placeholder locale, since locale doesn't affect syntax
_, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US")
if err != nil {
return nil, err
}
seenArgs := formatCheck.argsRequested
if len(initVals) > seenArgs {
return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals))
}
return inter, nil
}
}
type formatCheck struct {
args []Interpretable
argsRequested int
curArgIndex int64
enableCheckArgTypes bool
verifier typeVerifier
}
func (c *formatCheck) String(arg ref.Val, locale string) (string, error) {
valid, err := verifyString(c.args[c.curArgIndex], c.verifier)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps")
}
return "", nil
}
func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("integer clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
// we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("fixed-point clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.DoubleType, types.StringType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("scientific clause can only be used on doubles")
}
return "", nil
}
}
func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers and bools can be formatted as binary")
}
return "", nil
}
func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) {
return func(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("only integers, byte buffers, and strings can be formatted as hex")
}
return "", nil
}
}
func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) {
id := c.args[c.curArgIndex].ID()
valid, err := c.verifier(id, types.IntType, types.UintType)
if err != nil {
return "", err
}
if !valid {
return "", errors.New("octal clause can only be used on integers")
}
return "", nil
}
func (c *formatCheck) Arg(index int64) (ref.Val, error) {
c.argsRequested++
c.curArgIndex = index
// return a dummy value - this is immediately passed to back to us
// through one of the FormatCallback functions, so anything will do
return types.Int(0), nil
}
func (c *formatCheck) ArgSize() int64 {
return int64(len(c.args))
}
func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) {
subVerified, err := verifier(sub.ID(),
types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType,
types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType)
if err != nil {
return false, err
}
if !subVerified {
return false, nil
}
con, ok := sub.(InterpretableConstructor)
if ok {
members := con.InitVals()
for _, m := range members {
// recursively verify if we're dealing with a list/map
verified, err := verifyString(m, verifier)
if err != nil {
return false, err
}
if !verified {
return false, nil
}
}
}
return true, nil
}
// FormatStringInterpolator is an interface that allows user-defined behavior
// for formatting clause implementations, as well as argument retrieval.
// Each function is expected to support the appropriate types as laid out in
// the string.format documentation, and to return an error if given an inappropriate type.
type FormatStringInterpolator interface {
// String takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a string, or an error if one occurred.
String(ref.Val, string) (string, error)
// Decimal takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a decimal integer, or an error if one occurred.
Decimal(ref.Val, string) (string, error)
// Fixed takes an int pointer representing precision (or nil if none was given) and
// returns a function operating in a similar manner to String and Decimal, taking a
// ref.Val and locale and returning the appropriate string. A closure is returned
// so precision can be set without needing an additional function call/configuration.
Fixed(*int) func(ref.Val, string) (string, error)
// Scientific functions identically to Fixed, except the string returned from the closure
// is expected to be in scientific notation.
Scientific(*int) func(ref.Val, string) (string, error)
// Binary takes a ref.Val and a string representing the current locale identifier
// and returns the Val formatted as a binary integer, or an error if one occurred.
Binary(ref.Val, string) (string, error)
// Hex takes a boolean that, if true, indicates the hex string output by the returned
// closure should use uppercase letters for A-F.
Hex(bool) func(ref.Val, string) (string, error)
// Octal takes a ref.Val and a string representing the current locale identifier and
// returns the Val formatted in octal, or an error if one occurred.
Octal(ref.Val, string) (string, error)
}
// FormatList is an interface that allows user-defined list-like datatypes to be used
// for formatting clause implementations.
type FormatList interface {
// Arg returns the ref.Val at the given index, or an error if one occurred.
Arg(int64) (ref.Val, error)
// ArgSize returns the length of the argument list.
ArgSize() int64
}
type clauseImpl func(ref.Val, string) (string, error)
// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations
// from the provided FormatCallback and the args from the given FormatList.
func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) {
i := 0
argIndex := 0
var builtStr strings.Builder
for i < len(formatStr) {
if formatStr[i] == '%' {
if i+1 < len(formatStr) && formatStr[i+1] == '%' {
err := builtStr.WriteByte('%')
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += 2
continue
} else {
argAny, err := list.Arg(int64(argIndex))
if err != nil {
return "", err
}
if i+1 >= len(formatStr) {
return "", errors.New("unexpected end of string")
}
if int64(argIndex) >= list.ArgSize() {
return "", fmt.Errorf("index %d out of range", argIndex)
}
numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale)
if refErr != nil {
return "", refErr
}
_, err = builtStr.WriteString(val)
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i += numRead
argIndex++
}
} else {
err := builtStr.WriteByte(formatStr[i])
if err != nil {
return "", fmt.Errorf("error writing format string: %w", err)
}
i++
}
}
return builtStr.String(), nil
}
// parseAndFormatClause parses the format clause at the start of the given string with val, and returns
// how many characters were consumed and the substituted string form of val, or an error if one occurred.
func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) {
i := 1
read, formatter, err := parseFormattingClause(formatStr[i:], callback)
i += read
if err != nil {
return -1, "", fmt.Errorf("could not parse formatting clause: %s", err)
}
valStr, err := formatter(val, locale)
if err != nil {
return -1, "", fmt.Errorf("error during formatting: %s", err)
}
return i, valStr, nil
}
func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) {
i := 0
read, precision, err := parsePrecision(formatStr[i:])
i += read
if err != nil {
return -1, nil, fmt.Errorf("error while parsing precision: %w", err)
}
r := rune(formatStr[i])
i++
switch r {
case 's':
return i, callback.String, nil
case 'd':
return i, callback.Decimal, nil
case 'f':
return i, callback.Fixed(precision), nil
case 'e':
return i, callback.Scientific(precision), nil
case 'b':
return i, callback.Binary, nil
case 'x', 'X':
return i, callback.Hex(unicode.IsUpper(r)), nil
case 'o':
return i, callback.Octal, nil
default:
return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r)
}
}
func parsePrecision(formatStr string) (int, *int, error) {
i := 0
if formatStr[i] != '.' {
return i, nil, nil
}
i++
var buffer strings.Builder
for {
if i >= len(formatStr) {
return -1, nil, errors.New("could not find end of precision specifier")
}
if !isASCIIDigit(rune(formatStr[i])) {
break
}
buffer.WriteByte(formatStr[i])
i++
}
precision, err := strconv.Atoi(buffer.String())
if err != nil {
return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err)
}
return i, &precision, nil
}
func isASCIIDigit(r rune) bool {
return r <= unicode.MaxASCII && unicode.IsDigit(r)
}

View File

@@ -7,16 +7,11 @@ package(
go_library(
name = "go_default_library",
srcs = [
srcs = [
"functions.go",
"standard.go",
],
importpath = "github.com/google/cel-go/interpreter/functions",
deps = [
"//common/operators:go_default_library",
"//common/overloads:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//common/functions:go_default_library",
],
)

View File

@@ -16,7 +16,7 @@
// interpreter and as declared within the checker#StandardDeclarations.
package functions
import "github.com/google/cel-go/common/types/ref"
import fn "github.com/google/cel-go/common/functions"
// Overload defines a named overload of a function, indicating an operand trait
// which must be present on the first argument to the overload as well as one
@@ -26,37 +26,14 @@ import "github.com/google/cel-go/common/types/ref"
// and the specializations simplify the call contract for implementers of
// types with operator overloads. Any added complexity is assumed to be handled
// by the generic FunctionOp.
type Overload struct {
// Operator name as written in an expression or defined within
// operators.go.
Operator string
// Operand trait used to dispatch the call. The zero-value indicates a
// global function overload or that one of the Unary / Binary / Function
// definitions should be used to execute the call.
OperandTrait int
// Unary defines the overload with a UnaryOp implementation. May be nil.
Unary UnaryOp
// Binary defines the overload with a BinaryOp implementation. May be nil.
Binary BinaryOp
// Function defines the overload with a FunctionOp implementation. May be
// nil.
Function FunctionOp
// NonStrict specifies whether the Overload will tolerate arguments that
// are types.Err or types.Unknown.
NonStrict bool
}
type Overload = fn.Overload
// UnaryOp is a function that takes a single value and produces an output.
type UnaryOp func(value ref.Val) ref.Val
type UnaryOp = fn.UnaryOp
// BinaryOp is a function that takes two values and produces an output.
type BinaryOp func(lhs ref.Val, rhs ref.Val) ref.Val
type BinaryOp = fn.BinaryOp
// FunctionOp is a function with accepts zero or more arguments and produces
// an value (as interface{}) or error as a result.
type FunctionOp func(values ...ref.Val) ref.Val
// a value or error as a result.
type FunctionOp = fn.FunctionOp

View File

@@ -1,270 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package functions
import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// StandardOverloads returns the definitions of the built-in overloads.
func StandardOverloads() []*Overload {
return []*Overload{
// Logical not (!a)
{
Operator: operators.LogicalNot,
OperandTrait: traits.NegatorType,
Unary: func(value ref.Val) ref.Val {
if !types.IsBool(value) {
return types.ValOrErr(value, "no such overload")
}
return value.(traits.Negater).Negate()
}},
// Not strictly false: IsBool(a) ? a : true
{
Operator: operators.NotStrictlyFalse,
Unary: notStrictlyFalse},
// Deprecated: not strictly false, may be overridden in the environment.
{
Operator: operators.OldNotStrictlyFalse,
Unary: notStrictlyFalse},
// Less than operator
{Operator: operators.Less,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne {
return types.True
}
if cmp == types.IntOne || cmp == types.IntZero {
return types.False
}
return cmp
}},
// Less than or equal operator
{Operator: operators.LessEquals,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntOne {
return types.False
}
return cmp
}},
// Greater than operator
{Operator: operators.Greater,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne {
return types.True
}
if cmp == types.IntNegOne || cmp == types.IntZero {
return types.False
}
return cmp
}},
// Greater than equal operators
{Operator: operators.GreaterEquals,
OperandTrait: traits.ComparerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
cmp := lhs.(traits.Comparer).Compare(rhs)
if cmp == types.IntOne || cmp == types.IntZero {
return types.True
}
if cmp == types.IntNegOne {
return types.False
}
return cmp
}},
// Add operator
{Operator: operators.Add,
OperandTrait: traits.AdderType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Adder).Add(rhs)
}},
// Subtract operators
{Operator: operators.Subtract,
OperandTrait: traits.SubtractorType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Subtractor).Subtract(rhs)
}},
// Multiply operator
{Operator: operators.Multiply,
OperandTrait: traits.MultiplierType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Multiplier).Multiply(rhs)
}},
// Divide operator
{Operator: operators.Divide,
OperandTrait: traits.DividerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Divider).Divide(rhs)
}},
// Modulo operator
{Operator: operators.Modulo,
OperandTrait: traits.ModderType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Modder).Modulo(rhs)
}},
// Negate operator
{Operator: operators.Negate,
OperandTrait: traits.NegatorType,
Unary: func(value ref.Val) ref.Val {
if types.IsBool(value) {
return types.ValOrErr(value, "no such overload")
}
return value.(traits.Negater).Negate()
}},
// Index operator
{Operator: operators.Index,
OperandTrait: traits.IndexerType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Indexer).Get(rhs)
}},
// Size function
{Operator: overloads.Size,
OperandTrait: traits.SizerType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Sizer).Size()
}},
// In operator
{Operator: operators.In, Binary: inAggregate},
// Deprecated: in operator, may be overridden in the environment.
{Operator: operators.OldIn, Binary: inAggregate},
// Matches function
{Operator: overloads.Matches,
OperandTrait: traits.MatcherType,
Binary: func(lhs ref.Val, rhs ref.Val) ref.Val {
return lhs.(traits.Matcher).Match(rhs)
}},
// Type conversion functions
// TODO: verify type conversion safety of numeric values.
// Int conversions.
{Operator: overloads.TypeConvertInt,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.IntType)
}},
// Uint conversions.
{Operator: overloads.TypeConvertUint,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.UintType)
}},
// Double conversions.
{Operator: overloads.TypeConvertDouble,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.DoubleType)
}},
// Bool conversions.
{Operator: overloads.TypeConvertBool,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.BoolType)
}},
// Bytes conversions.
{Operator: overloads.TypeConvertBytes,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.BytesType)
}},
// String conversions.
{Operator: overloads.TypeConvertString,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.StringType)
}},
// Timestamp conversions.
{Operator: overloads.TypeConvertTimestamp,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.TimestampType)
}},
// Duration conversions.
{Operator: overloads.TypeConvertDuration,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.DurationType)
}},
// Type operations.
{Operator: overloads.TypeConvertType,
Unary: func(value ref.Val) ref.Val {
return value.ConvertToType(types.TypeType)
}},
// Dyn conversion (identity function).
{Operator: overloads.TypeConvertDyn,
Unary: func(value ref.Val) ref.Val {
return value
}},
{Operator: overloads.Iterator,
OperandTrait: traits.IterableType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterable).Iterator()
}},
{Operator: overloads.HasNext,
OperandTrait: traits.IteratorType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterator).HasNext()
}},
{Operator: overloads.Next,
OperandTrait: traits.IteratorType,
Unary: func(value ref.Val) ref.Val {
return value.(traits.Iterator).Next()
}},
}
}
func notStrictlyFalse(value ref.Val) ref.Val {
if types.IsBool(value) {
return value
}
return types.True
}
func inAggregate(lhs ref.Val, rhs ref.Val) ref.Val {
if rhs.Type().HasTrait(traits.ContainerType) {
return rhs.(traits.Container).Contains(lhs)
}
return types.ValOrErr(rhs, "no such overload")
}

File diff suppressed because it is too large Load Diff

View File

@@ -18,9 +18,10 @@
package interpreter
import (
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -29,19 +30,17 @@ import (
type Interpreter interface {
// NewInterpretable creates an Interpretable from a checked expression and an
// optional list of InterpretableDecorator values.
NewInterpretable(checked *exprpb.CheckedExpr,
decorators ...InterpretableDecorator) (Interpretable, error)
NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error)
// NewUncheckedInterpretable returns an Interpretable from a parsed expression
// and an optional list of InterpretableDecorator values.
NewUncheckedInterpretable(expr *exprpb.Expr,
decorators ...InterpretableDecorator) (Interpretable, error)
NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
type EvalObserver func(id int64, programStep interface{}, value ref.Val)
type EvalObserver func(id int64, programStep any, value ref.Val)
// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable
// or Qualifier during program evaluation.
@@ -49,7 +48,7 @@ func Observe(observers ...EvalObserver) InterpretableDecorator {
if len(observers) == 1 {
return decObserveEval(observers[0])
}
observeFn := func(id int64, programStep interface{}, val ref.Val) {
observeFn := func(id int64, programStep any, val ref.Val) {
for _, observer := range observers {
observer(id, programStep, val)
}
@@ -96,7 +95,7 @@ func TrackState(state EvalState) InterpretableDecorator {
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
func EvalStateObserver(state EvalState) EvalObserver {
return func(id int64, programStep interface{}, val ref.Val) {
return func(id int64, programStep any, val ref.Val) {
state.SetValue(id, val)
}
}
@@ -156,8 +155,8 @@ func CompileRegexConstants(regexOptimizations ...*RegexOptimization) Interpretab
type exprInterpreter struct {
dispatcher Dispatcher
container *containers.Container
provider ref.TypeProvider
adapter ref.TypeAdapter
provider types.Provider
adapter types.Adapter
attrFactory AttributeFactory
}
@@ -165,8 +164,8 @@ type exprInterpreter struct {
// throughout the Eval of all Interpretable instances generated from it.
func NewInterpreter(dispatcher Dispatcher,
container *containers.Container,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory) Interpreter {
return &exprInterpreter{
dispatcher: dispatcher,
@@ -176,20 +175,9 @@ func NewInterpreter(dispatcher Dispatcher,
attrFactory: attrFactory}
}
// NewStandardInterpreter builds a Dispatcher and TypeProvider with support for all of the CEL
// builtins defined in the language definition.
func NewStandardInterpreter(container *containers.Container,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
resolver AttributeFactory) Interpreter {
dispatcher := NewDispatcher()
dispatcher.Add(functions.StandardOverloads()...)
return NewInterpreter(dispatcher, container, provider, adapter, resolver)
}
// NewIntepretable implements the Interpreter interface method.
func (i *exprInterpreter) NewInterpretable(
checked *exprpb.CheckedExpr,
checked *ast.CheckedAST,
decorators ...InterpretableDecorator) (Interpretable, error) {
p := newPlanner(
i.dispatcher,
@@ -199,7 +187,7 @@ func (i *exprInterpreter) NewInterpretable(
i.container,
checked,
decorators...)
return p.Plan(checked.GetExpr())
return p.Plan(checked.Expr)
}
// NewUncheckedIntepretable implements the Interpreter interface method.

View File

@@ -18,11 +18,12 @@ import (
"fmt"
"strings"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@@ -38,11 +39,11 @@ type interpretablePlanner interface {
// functions, types, and namespaced identifiers at plan time rather than at runtime since
// it only needs to be done once and may be semi-expensive to compute.
func newPlanner(disp Dispatcher,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
checked *exprpb.CheckedExpr,
checked *ast.CheckedAST,
decorators ...InterpretableDecorator) interpretablePlanner {
return &planner{
disp: disp,
@@ -50,8 +51,8 @@ func newPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: checked.GetReferenceMap(),
typeMap: checked.GetTypeMap(),
refMap: checked.ReferenceMap,
typeMap: checked.TypeMap,
decorators: decorators,
}
}
@@ -60,8 +61,8 @@ func newPlanner(disp Dispatcher,
// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in
// Select expressions are resolved lazily at evaluation time.
func newUncheckedPlanner(disp Dispatcher,
provider ref.TypeProvider,
adapter ref.TypeAdapter,
provider types.Provider,
adapter types.Adapter,
attrFactory AttributeFactory,
cont *containers.Container,
decorators ...InterpretableDecorator) interpretablePlanner {
@@ -71,8 +72,8 @@ func newUncheckedPlanner(disp Dispatcher,
adapter: adapter,
attrFactory: attrFactory,
container: cont,
refMap: make(map[int64]*exprpb.Reference),
typeMap: make(map[int64]*exprpb.Type),
refMap: make(map[int64]*ast.ReferenceInfo),
typeMap: make(map[int64]*types.Type),
decorators: decorators,
}
}
@@ -80,12 +81,12 @@ func newUncheckedPlanner(disp Dispatcher,
// planner is an implementation of the interpretablePlanner interface.
type planner struct {
disp Dispatcher
provider ref.TypeProvider
adapter ref.TypeAdapter
provider types.Provider
adapter types.Adapter
attrFactory AttributeFactory
container *containers.Container
refMap map[int64]*exprpb.Reference
typeMap map[int64]*exprpb.Type
refMap map[int64]*ast.ReferenceInfo
typeMap map[int64]*types.Type
decorators []InterpretableDecorator
}
@@ -144,22 +145,19 @@ func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) {
}, nil
}
func (p *planner) planCheckedIdent(id int64, identRef *exprpb.Reference) (Interpretable, error) {
func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) {
// Plan a constant reference if this is the case for this simple identifier.
if identRef.GetValue() != nil {
return p.Plan(&exprpb.Expr{Id: id,
ExprKind: &exprpb.Expr_ConstExpr{
ConstExpr: identRef.GetValue(),
}})
if identRef.Value != nil {
return NewConstValue(id, identRef.Value), nil
}
// Check to see whether the type map indicates this is a type name. All types should be
// registered with the provider.
cType := p.typeMap[id]
if cType.GetType() != nil {
cVal, found := p.provider.FindIdent(identRef.GetName())
if cType.Kind() == types.TypeKind {
cVal, found := p.provider.FindIdent(identRef.Name)
if !found {
return nil, fmt.Errorf("reference to undefined type: %s", identRef.GetName())
return nil, fmt.Errorf("reference to undefined type: %s", identRef.Name)
}
return NewConstValue(id, cVal), nil
}
@@ -167,7 +165,7 @@ func (p *planner) planCheckedIdent(id int64, identRef *exprpb.Reference) (Interp
// Otherwise, return the attribute for the resolved identifier name.
return &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.AbsoluteAttribute(id, identRef.GetName()),
attr: p.attrFactory.AbsoluteAttribute(id, identRef.Name),
}, nil
}
@@ -189,16 +187,7 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
if err != nil {
return nil, err
}
// Determine the field type if this is a proto message type.
var fieldType *ref.FieldType
opType := p.typeMap[sel.GetOperand().GetId()]
if opType.GetMessageType() != "" {
ft, found := p.provider.FindFieldType(opType.GetMessageType(), sel.GetField())
if found && ft.IsSet != nil && ft.GetFrom != nil {
fieldType = ft
}
}
// If the Select was marked TestOnly, this is a presence test.
//
@@ -211,37 +200,31 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) {
// If a string named 'a.b.c' is declared in the environment and referenced within `has(a.b.c)`,
// it is not clear whether has should error or follow the convention defined for structured
// values.
if sel.TestOnly {
// Return the test only eval expression.
return &evalTestOnly{
id: expr.GetId(),
field: types.String(sel.GetField()),
fieldType: fieldType,
op: op,
}, nil
}
// Build a qualifier.
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), sel.GetField())
if err != nil {
return nil, err
}
// Lastly, create a field selection Interpretable.
// Establish the attribute reference.
attr, isAttr := op.(InterpretableAttribute)
if isAttr {
_, err = attr.AddQualifier(qual)
return attr, err
if !isAttr {
attr, err = p.relativeAttr(op.ID(), op, false)
if err != nil {
return nil, err
}
}
relAttr, err := p.relativeAttr(op.ID(), op)
// Build a qualifier for the attribute.
qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false)
if err != nil {
return nil, err
}
_, err = relAttr.AddQualifier(qual)
if err != nil {
return nil, err
// Modify the attribute to be test-only.
if sel.GetTestOnly() {
attr = &evalTestOnly{
id: expr.GetId(),
InterpretableAttribute: attr,
}
}
return relAttr, nil
// Append the qualifier on the attribute.
_, err = attr.AddQualifier(qual)
return attr, err
}
// planCall creates a callable Interpretable while specializing for common functions and invocation
@@ -286,7 +269,9 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) {
case operators.NotEquals:
return p.planCallNotEqual(expr, args)
case operators.Index:
return p.planCallIndex(expr, args)
return p.planCallIndex(expr, args, false)
case operators.OptSelect, operators.OptIndex:
return p.planCallIndex(expr, args, true)
}
// Otherwise, generate Interpretable calls specialized by argument count.
@@ -423,8 +408,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr,
}
// planCallEqual generates an equals (==) Interpretable.
func (p *planner) planCallEqual(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalEq{
id: expr.GetId(),
lhs: args[0],
@@ -433,8 +417,7 @@ func (p *planner) planCallEqual(expr *exprpb.Expr,
}
// planCallNotEqual generates a not equals (!=) Interpretable.
func (p *planner) planCallNotEqual(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalNe{
id: expr.GetId(),
lhs: args[0],
@@ -443,30 +426,24 @@ func (p *planner) planCallNotEqual(expr *exprpb.Expr,
}
// planCallLogicalAnd generates a logical and (&&) Interpretable.
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalAnd{
id: expr.GetId(),
lhs: args[0],
rhs: args[1],
id: expr.GetId(),
terms: args,
}, nil
}
// planCallLogicalOr generates a logical or (||) Interpretable.
func (p *planner) planCallLogicalOr(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
return &evalOr{
id: expr.GetId(),
lhs: args[0],
rhs: args[1],
id: expr.GetId(),
terms: args,
}, nil
}
// planCallConditional generates a conditional / ternary (c ? t : f) Interpretable.
func (p *planner) planCallConditional(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) {
cond := args[0]
t := args[1]
var tAttr Attribute
truthyAttr, isTruthyAttr := t.(InterpretableAttribute)
@@ -493,48 +470,54 @@ func (p *planner) planCallConditional(expr *exprpb.Expr,
// planCallIndex either extends an attribute with the argument to the index operation, or creates
// a relative attribute based on the return of a function call or operation.
func (p *planner) planCallIndex(expr *exprpb.Expr,
args []Interpretable) (Interpretable, error) {
func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) {
op := args[0]
ind := args[1]
opAttr, err := p.relativeAttr(op.ID(), op)
if err != nil {
return nil, err
}
opType := p.typeMap[expr.GetCallExpr().GetTarget().GetId()]
indConst, isIndConst := ind.(InterpretableConst)
if isIndConst {
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), indConst.Value())
opType := p.typeMap[op.ID()]
// Establish the attribute reference.
var err error
attr, isAttr := op.(InterpretableAttribute)
if !isAttr {
attr, err = p.relativeAttr(op.ID(), op, false)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(qual)
return opAttr, err
}
indAttr, isIndAttr := ind.(InterpretableAttribute)
if isIndAttr {
qual, err := p.attrFactory.NewQualifier(
opType, expr.GetId(), indAttr)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(qual)
return opAttr, err
// Construct the qualifier type.
var qual Qualifier
switch ind := ind.(type) {
case InterpretableConst:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional)
case InterpretableAttribute:
qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional)
default:
qual, err = p.relativeAttr(expr.GetId(), ind, optional)
}
indQual, err := p.relativeAttr(expr.GetId(), ind)
if err != nil {
return nil, err
}
_, err = opAttr.AddQualifier(indQual)
return opAttr, err
// Add the qualifier to the attribute
_, err = attr.AddQualifier(qual)
return attr, err
}
// planCreateList generates a list construction Interpretable.
func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
list := expr.GetListExpr()
elems := make([]Interpretable, len(list.GetElements()))
for i, elem := range list.GetElements() {
optionalIndices := list.GetOptionalIndices()
elements := list.GetElements()
optionals := make([]bool, len(elements))
for _, index := range optionalIndices {
if index < 0 || index >= int32(len(elements)) {
return nil, fmt.Errorf("optional index %d out of element bounds [0, %d]", index, len(elements))
}
optionals[index] = true
}
elems := make([]Interpretable, len(elements))
for i, elem := range elements {
elemVal, err := p.Plan(elem)
if err != nil {
return nil, err
@@ -542,9 +525,11 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) {
elems[i] = elemVal
}
return &evalList{
id: expr.GetId(),
elems: elems,
adapter: p.adapter,
id: expr.GetId(),
elems: elems,
optionals: optionals,
hasOptionals: len(optionals) != 0,
adapter: p.adapter,
}, nil
}
@@ -555,6 +540,7 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
return p.planCreateObj(expr)
}
entries := str.GetEntries()
optionals := make([]bool, len(entries))
keys := make([]Interpretable, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
@@ -569,23 +555,27 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) {
return nil, err
}
vals[i] = valVal
optionals[i] = entry.GetOptionalEntry()
}
return &evalMap{
id: expr.GetId(),
keys: keys,
vals: vals,
adapter: p.adapter,
id: expr.GetId(),
keys: keys,
vals: vals,
optionals: optionals,
hasOptionals: len(optionals) != 0,
adapter: p.adapter,
}, nil
}
// planCreateObj generates an object construction Interpretable.
func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
obj := expr.GetStructExpr()
typeName, defined := p.resolveTypeName(obj.MessageName)
typeName, defined := p.resolveTypeName(obj.GetMessageName())
if !defined {
return nil, fmt.Errorf("unknown type: %s", typeName)
return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName())
}
entries := obj.GetEntries()
optionals := make([]bool, len(entries))
fields := make([]string, len(entries))
vals := make([]Interpretable, len(entries))
for i, entry := range entries {
@@ -595,13 +585,16 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) {
return nil, err
}
vals[i] = val
optionals[i] = entry.GetOptionalEntry()
}
return &evalObj{
id: expr.GetId(),
typeName: typeName,
fields: fields,
vals: vals,
provider: p.provider,
id: expr.GetId(),
typeName: typeName,
fields: fields,
vals: vals,
optionals: optionals,
hasOptionals: len(optionals) != 0,
provider: p.provider,
}, nil
}
@@ -679,7 +672,7 @@ func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) {
// namespace resolution rules to it in a scan over possible matching types in the TypeProvider.
func (p *planner) resolveTypeName(typeName string) (string, bool) {
for _, qualifiedTypeName := range p.container.ResolveCandidateNames(typeName) {
if _, found := p.provider.FindType(qualifiedTypeName); found {
if _, found := p.provider.FindStructType(qualifiedTypeName); found {
return qualifiedTypeName, true
}
}
@@ -706,8 +699,8 @@ func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, stri
// function name as the fnName value.
oRef, hasOverload := p.refMap[expr.GetId()]
if hasOverload {
if len(oRef.GetOverloadId()) == 1 {
return target, fnName, oRef.GetOverloadId()[0]
if len(oRef.OverloadIDs) == 1 {
return target, fnName, oRef.OverloadIDs[0]
}
// Note, this namespaced function name will not appear as a fully qualified name in ASTs
// built and stored before cel-go v0.5.0; however, this functionality did not work at all
@@ -753,14 +746,18 @@ func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, stri
return target, fnName, ""
}
func (p *planner) relativeAttr(id int64, eval Interpretable) (InterpretableAttribute, error) {
// relativeAttr indicates that the attribute in this case acts as a qualifier and as such needs to
// be observed to ensure that it's evaluation value is properly recorded for state tracking.
func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) {
eAttr, ok := eval.(InterpretableAttribute)
if !ok {
eAttr = &evalAttr{
adapter: p.adapter,
attr: p.attrFactory.RelativeAttribute(id, eval),
adapter: p.adapter,
attr: p.attrFactory.RelativeAttribute(id, eval),
optional: opt,
}
}
// This looks like it should either decorate the new evalAttr node, or early return the InterpretableAttribute
decAttr, err := p.decorate(eAttr, nil)
if err != nil {
return nil, err

View File

@@ -16,6 +16,7 @@ package interpreter
import (
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@@ -26,6 +27,7 @@ import (
type astPruner struct {
expr *exprpb.Expr
macroCalls map[int64]*exprpb.Expr
state EvalState
nextExprID int64
}
@@ -65,13 +67,22 @@ type astPruner struct {
// compiled and constant folded expressions, but is not willing to constant
// fold(and thus cache results of) some external calls, then they can prepare
// the overloads accordingly.
func PruneAst(expr *exprpb.Expr, state EvalState) *exprpb.Expr {
func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr {
pruneState := NewEvalState()
for _, id := range state.IDs() {
v, _ := state.Value(id)
pruneState.SetValue(id, v)
}
pruner := &astPruner{
expr: expr,
state: state,
nextExprID: 1}
newExpr, _ := pruner.prune(expr)
return newExpr
macroCalls: macroCalls,
state: pruneState,
nextExprID: getMaxID(expr)}
newExpr, _ := pruner.maybePrune(expr)
return &exprpb.ParsedExpr{
Expr: newExpr,
SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls},
}
}
func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
@@ -84,28 +95,50 @@ func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr {
}
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) {
switch val.Type() {
case types.BoolType:
switch v := val.(type) {
case types.Bool:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
case types.IntType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true
case types.Bytes:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
case types.UintType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true
case types.Double:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
case types.StringType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true
case types.Duration:
p.state.SetValue(id, val)
durationString := string(v.ConvertToType(types.StringType).(types.String))
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: &exprpb.Expr_Call{
Function: overloads.TypeConvertDuration,
Args: []*exprpb.Expr{
p.createLiteral(p.nextID(),
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}),
},
},
},
}, true
case types.Int:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
case types.DoubleType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true
case types.Uint:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
case types.BytesType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true
case types.String:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
case types.NullType:
&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true
case types.Null:
p.state.SetValue(id, val)
return p.createLiteral(id,
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true
}
// Attempt to build a list literal.
@@ -123,6 +156,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
}
elemExprs[i] = elemExpr
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_ListExpr{
@@ -162,6 +196,7 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
entries[i] = entry
i++
}
p.state.SetValue(id, val)
return &exprpb.Expr{
Id: id,
ExprKind: &exprpb.Expr_StructExpr{
@@ -177,70 +212,152 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo
return nil, false
}
func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) {
elemVal, found := p.value(elem.GetId())
if found && elemVal.Type() == types.OptionalType {
opt := elemVal.(*types.Optional)
if !opt.HasValue() {
return nil, true
}
if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned {
return newElem, true
}
}
return elem, false
}
func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) {
// elem in list
call := node.GetCallExpr()
val, exists := p.maybeValue(call.GetArgs()[1].GetId())
if !exists {
return nil, false
}
if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return nil, false
}
func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
arg := call.GetArgs()[0]
val, exists := p.maybeValue(arg.GetId())
if !exists {
return nil, false
}
if b, ok := val.(types.Bool); ok {
return p.maybeCreateLiteral(node.GetId(), !b)
}
return nil, false
}
func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if p.existsWithKnownValue(call.Args[0].GetId()) {
return call.Args[1], true
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
}
return call.GetArgs()[1], true
}
if p.existsWithKnownValue(call.Args[1].GetId()) {
return call.Args[0], true
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v == types.True {
return p.maybeCreateLiteral(node.GetId(), types.True)
}
return call.GetArgs()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) {
call := node.GetCallExpr()
// We know result is unknown, so we have at least one unknown arg
// and if one side is a known value, we know we can ignore it.
if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return call.GetArgs()[1], true
}
if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists {
if v == types.False {
return p.maybeCreateLiteral(node.GetId(), types.False)
}
return call.GetArgs()[0], true
}
return nil, false
}
func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
if !p.existsWithUnknownValue(node.GetId()) {
return nil, false
}
call := node.GetCallExpr()
condVal, condValueExists := p.value(call.Args[0].GetId())
if !condValueExists || types.IsUnknownOrError(condVal) {
cond, exists := p.maybeValue(call.GetArgs()[0].GetId())
if !exists {
return nil, false
}
if condVal.Value().(bool) {
return call.Args[1], true
if cond.Value().(bool) {
return call.GetArgs()[1], true
}
return call.Args[2], true
return call.GetArgs()[2], true
}
func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
if _, exists := p.value(node.GetId()); !exists {
return nil, false
}
call := node.GetCallExpr()
if call.Function == operators.LogicalOr || call.Function == operators.LogicalAnd {
return p.maybePruneAndOr(node)
if call.Function == operators.LogicalOr {
return p.maybePruneOr(node)
}
if call.Function == operators.LogicalAnd {
return p.maybePruneAnd(node)
}
if call.Function == operators.Conditional {
return p.maybePruneConditional(node)
}
if call.Function == operators.In {
return p.maybePruneIn(node)
}
if call.Function == operators.LogicalNot {
return p.maybePruneLogicalNot(node)
}
return nil, false
}
func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) {
return p.prune(node)
}
func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
if node == nil {
return node, false
}
val, valueExists := p.value(node.GetId())
if valueExists && !types.IsUnknownOrError(val) {
val, valueExists := p.maybeValue(node.GetId())
if valueExists {
if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok {
delete(p.macroCalls, node.GetId())
return newNode, true
}
}
if macro, found := p.macroCalls[node.GetId()]; found {
// Ensure that intermediate values for the comprehension are cleared during pruning
compre := node.GetComprehensionExpr()
if compre != nil {
visit(macro, clearIterVarVisitor(compre.IterVar, p.state))
}
// prune the expression in terms of the macro call instead of the expanded form.
if newMacro, pruned := p.prune(macro); pruned {
p.macroCalls[node.GetId()] = newMacro
}
}
// We have either an unknown/error value, or something we don't want to
// transform, or expression was not evaluated. If possible, drill down
// more.
switch node.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
if operand, pruned := p.prune(node.GetSelectExpr().GetOperand()); pruned {
if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_SelectExpr{
@@ -253,10 +370,6 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}, true
}
case *exprpb.Expr_CallExpr:
if newExpr, pruned := p.maybePruneFunction(node); pruned {
newExpr, _ = p.prune(newExpr)
return newExpr, true
}
var prunedCall bool
call := node.GetCallExpr()
args := call.GetArgs()
@@ -268,40 +381,75 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
}
for i, arg := range args {
newArgs[i] = arg
if newArg, prunedArg := p.prune(arg); prunedArg {
if newArg, prunedArg := p.maybePrune(arg); prunedArg {
prunedCall = true
newArgs[i] = newArg
}
}
if newTarget, prunedTarget := p.prune(call.GetTarget()); prunedTarget {
if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget {
prunedCall = true
newCall.Target = newTarget
}
newNode := &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}
if newExpr, pruned := p.maybePruneFunction(newNode); pruned {
newExpr, _ = p.maybePrune(newExpr)
return newExpr, true
}
if prunedCall {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_CallExpr{
CallExpr: newCall,
},
}, true
return newNode, true
}
case *exprpb.Expr_ListExpr:
elems := node.GetListExpr().GetElements()
newElems := make([]*exprpb.Expr, len(elems))
optIndices := node.GetListExpr().GetOptionalIndices()
optIndexMap := map[int32]bool{}
for _, i := range optIndices {
optIndexMap[i] = true
}
newOptIndexMap := make(map[int32]bool, len(optIndexMap))
newElems := make([]*exprpb.Expr, 0, len(elems))
var prunedList bool
prunedIdx := 0
for i, elem := range elems {
newElems[i] = elem
if newElem, prunedElem := p.prune(elem); prunedElem {
newElems[i] = newElem
prunedList = true
_, isOpt := optIndexMap[int32(i)]
if isOpt {
newElem, pruned := p.maybePruneOptional(elem)
if pruned {
prunedList = true
if newElem != nil {
newElems = append(newElems, newElem)
prunedIdx++
}
continue
}
newOptIndexMap[int32(prunedIdx)] = true
}
if newElem, prunedElem := p.maybePrune(elem); prunedElem {
newElems = append(newElems, newElem)
prunedList = true
} else {
newElems = append(newElems, elem)
}
prunedIdx++
}
optIndices = make([]int32, len(newOptIndexMap))
idx := 0
for i := range newOptIndexMap {
optIndices[idx] = i
idx++
}
if prunedList {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: newElems,
Elements: newElems,
OptionalIndices: optIndices,
},
},
}, true
@@ -313,8 +461,8 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries))
for i, entry := range entries {
newEntries[i] = entry
newKey, prunedKey := p.prune(entry.GetMapKey())
newValue, prunedValue := p.prune(entry.GetValue())
newKey, prunedKey := p.maybePrune(entry.GetMapKey())
newValue, prunedValue := p.maybePrune(entry.GetValue())
if !prunedKey && !prunedValue {
continue
}
@@ -331,6 +479,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
MapKey: newKey,
}
}
newEntry.OptionalEntry = entry.GetOptionalEntry()
newEntries[i] = newEntry
}
if prunedStruct {
@@ -349,7 +498,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
// Only the range of the comprehension is pruned since the state tracking only records
// the last iteration of the comprehension and not each step in the evaluation which
// means that the any residuals computed in between might be inaccurate.
if newRange, pruned := p.prune(compre.GetIterRange()); pruned {
if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned {
return &exprpb.Expr{
Id: node.GetId(),
ExprKind: &exprpb.Expr_ComprehensionExpr{
@@ -374,24 +523,97 @@ func (p *astPruner) value(id int64) (ref.Val, bool) {
return val, (found && val != nil)
}
func (p *astPruner) existsWithUnknownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && types.IsUnknown(val)
}
func (p *astPruner) existsWithKnownValue(id int64) bool {
val, valueExists := p.value(id)
return valueExists && !types.IsUnknown(val)
func (p *astPruner) maybeValue(id int64) (ref.Val, bool) {
val, found := p.value(id)
if !found || types.IsUnknownOrError(val) {
return nil, false
}
return val, true
}
func (p *astPruner) nextID() int64 {
for {
_, found := p.state.Value(p.nextExprID)
if !found {
next := p.nextExprID
p.nextExprID++
return next
}
p.nextExprID++
next := p.nextExprID
p.nextExprID++
return next
}
type astVisitor struct {
// visitEntry is called on every expr node, including those within a map/struct entry.
visitExpr func(expr *exprpb.Expr)
// visitEntry is called before entering the key, value of a map/struct entry.
visitEntry func(entry *exprpb.Expr_CreateStruct_Entry)
}
func getMaxID(expr *exprpb.Expr) int64 {
maxID := int64(1)
visit(expr, maxIDVisitor(&maxID))
return maxID
}
func clearIterVarVisitor(varName string, state EvalState) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
ident := e.GetIdentExpr()
if ident != nil && ident.GetName() == varName {
state.SetValue(e.GetId(), nil)
}
},
}
}
func maxIDVisitor(maxID *int64) astVisitor {
return astVisitor{
visitExpr: func(e *exprpb.Expr) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) {
if e.GetId() >= *maxID {
*maxID = e.GetId() + 1
}
},
}
}
func visit(expr *exprpb.Expr, visitor astVisitor) {
exprs := []*exprpb.Expr{expr}
for len(exprs) != 0 {
e := exprs[0]
if visitor.visitExpr != nil {
visitor.visitExpr(e)
}
exprs = exprs[1:]
switch e.GetExprKind().(type) {
case *exprpb.Expr_SelectExpr:
exprs = append(exprs, e.GetSelectExpr().GetOperand())
case *exprpb.Expr_CallExpr:
call := e.GetCallExpr()
if call.GetTarget() != nil {
exprs = append(exprs, call.GetTarget())
}
exprs = append(exprs, call.GetArgs()...)
case *exprpb.Expr_ComprehensionExpr:
compre := e.GetComprehensionExpr()
exprs = append(exprs,
compre.GetIterRange(),
compre.GetAccuInit(),
compre.GetLoopCondition(),
compre.GetLoopStep(),
compre.GetResult())
case *exprpb.Expr_ListExpr:
list := e.GetListExpr()
exprs = append(exprs, list.GetElements()...)
case *exprpb.Expr_StructExpr:
for _, entry := range e.GetStructExpr().GetEntries() {
if visitor.visitEntry != nil {
visitor.visitEntry(entry)
}
if entry.GetMapKey() != nil {
exprs = append(exprs, entry.GetMapKey())
}
exprs = append(exprs, entry.GetValue())
}
}
}
}

View File

@@ -36,7 +36,7 @@ type ActualCostEstimator interface {
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(tracker *CostTracker) EvalObserver {
observer := func(id int64, programStep interface{}, val ref.Val) {
observer := func(id int64, programStep any, val ref.Val) {
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
@@ -53,6 +53,11 @@ func CostObserver(tracker *CostTracker) EvalObserver {
tracker.stack.drop(t.Attr().ID())
tracker.cost += common.SelectAndIdentCost
}
if !tracker.presenceTestHasCost {
if _, isTestOnly := programStep.(*evalTestOnly); isTestOnly {
tracker.cost -= common.SelectAndIdentCost
}
}
case *evalExhaustiveConditional:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
tracker.stack.drop(t.attr.falsy.ID(), t.attr.truthy.ID(), t.attr.expr.ID())
@@ -60,13 +65,21 @@ func CostObserver(tracker *CostTracker) EvalObserver {
// While the field names are identical, the boolean operation eval structs do not share an interface and so
// must be handled individually.
case *evalOr:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalAnd:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveOr:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalExhaustiveAnd:
tracker.stack.drop(t.rhs.ID(), t.lhs.ID())
for _, term := range t.terms {
tracker.stack.drop(term.ID())
}
case *evalFold:
tracker.stack.drop(t.iterRange.ID())
case Qualifier:
@@ -95,24 +108,86 @@ func CostObserver(tracker *CostTracker) EvalObserver {
return observer
}
// CostTracker represents the information needed for tacking runtime cost
// CostTrackerOption configures the behavior of CostTracker objects.
type CostTrackerOption func(*CostTracker) error
// CostTrackerLimit sets the runtime limit on the evaluation cost during execution and will terminate the expression
// evaluation if the limit is exceeded.
func CostTrackerLimit(limit uint64) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.Limit = &limit
return nil
}
}
// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.presenceTestHasCost = hasCost
return nil
}
}
// NewCostTracker creates a new CostTracker with a given estimator and a set of functional CostTrackerOption values.
func NewCostTracker(estimator ActualCostEstimator, opts ...CostTrackerOption) (*CostTracker, error) {
tracker := &CostTracker{
Estimator: estimator,
overloadTrackers: map[string]FunctionTracker{},
presenceTestHasCost: true,
}
for _, opt := range opts {
err := opt(tracker)
if err != nil {
return nil, err
}
}
return tracker, nil
}
// OverloadCostTracker binds an overload ID to a runtime FunctionTracker implementation.
//
// OverloadCostTracker instances augment or override ActualCostEstimator decisions, allowing for versioned and/or
// optional cost tracking changes.
func OverloadCostTracker(overloadID string, fnTracker FunctionTracker) CostTrackerOption {
return func(tracker *CostTracker) error {
tracker.overloadTrackers[overloadID] = fnTracker
return nil
}
}
// FunctionTracker computes the actual cost of evaluating the functions with the given arguments and result.
type FunctionTracker func(args []ref.Val, result ref.Val) *uint64
// CostTracker represents the information needed for tracking runtime cost.
type CostTracker struct {
Estimator ActualCostEstimator
Limit *uint64
Estimator ActualCostEstimator
overloadTrackers map[string]FunctionTracker
Limit *uint64
presenceTestHasCost bool
cost uint64
stack refValStack
}
// ActualCost returns the runtime cost
func (c CostTracker) ActualCost() uint64 {
func (c *CostTracker) ActualCost() uint64 {
return c.cost
}
func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, result ref.Val) uint64 {
func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result ref.Val) uint64 {
var cost uint64
if len(c.overloadTrackers) != 0 {
if tracker, found := c.overloadTrackers[call.OverloadID()]; found {
callCost := tracker(args, result)
if callCost != nil {
cost += *callCost
return cost
}
}
}
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), argValues, result)
callCost := c.Estimator.CallCost(call.Function(), call.OverloadID(), args, result)
if callCost != nil {
cost += *callCost
return cost
@@ -122,12 +197,12 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
// if user has their own implementation of ActualCostEstimator, make sure to cover the mapping between overloadId and cost calculation
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString:
cost += uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
cost += c.actualSize(args[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
@@ -135,8 +210,8 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
lhsSize := c.actualSize(args[0])
rhsSize := c.actualSize(args[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
@@ -145,23 +220,23 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += uint64(math.Ceil(float64(c.actualSize(args[0])+c.actualSize(args[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(args[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
regexCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
strCost := uint64(math.Ceil(float64(c.actualSize(args[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(args[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost
default:
@@ -179,7 +254,7 @@ func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val, resul
}
// actualSize returns the size of value
func (c CostTracker) actualSize(value ref.Val) uint64 {
func (c *CostTracker) actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}

Some files were not shown because too many files have changed in this diff Show More