Files
kubesphere/vendor/github.com/open-policy-agent/opa/topdown/providers.go
hongming cfebd96a1f update dependencies (#6267)
Signed-off-by: hongming <coder.scala@gmail.com>
2024-11-06 10:27:06 +08:00

215 lines
5.7 KiB
Go

// Copyright 2022 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"encoding/json"
"net/url"
"time"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/providers/aws"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// awsRequiredConfigKeyNames holds the keys that validateAWSAuthParameters
// requires to be present, with string values, in the AWS config object.
// The terms are added in the listed order.
var awsRequiredConfigKeyNames = func() ast.Set {
	names := []string{
		"aws_service",
		"aws_access_key",
		"aws_secret_access_key",
		"aws_region",
	}
	terms := make([]*ast.Term, 0, len(names))
	for _, name := range names {
		terms = append(terms, ast.StringTerm(name))
	}
	return ast.NewSet(terms...)
}()
// stringFromTerm returns the Go string held by t when t's value is an
// ast.String, and the empty string for every other value type.
// t must not be nil, because its Value field is read unconditionally.
func stringFromTerm(t *ast.Term) string {
	switch s := t.Value.(type) {
	case ast.String:
		return string(s)
	default:
		return ""
	}
}
// getReqBodyBytes returns the payload bytes that should be signed for a
// request.
//
// rawBody takes precedence when present. Its string value is used verbatim,
// and a non-string value yields an empty payload because stringFromTerm
// returns "" for it. Otherwise body, when present, is converted with ast.JSON
// and serialized with json.Marshal. When neither is given the payload is empty.
// Any conversion or marshalling error is returned with a nil slice.
func getReqBodyBytes(body, rawBody *ast.Term) ([]byte, error) {
	if rawBody != nil {
		return []byte(stringFromTerm(rawBody)), nil
	}
	if body == nil {
		return []byte(""), nil
	}

	decoded, err := ast.JSON(body.Value)
	if err != nil {
		return nil, err
	}
	encoded, err := json.Marshal(decoded)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// objectToMap flattens an ast.Object into the map[string][]string form that
// is handed to aws.SignV4. Keys and values are converted with
// stringFromTerm, so any non-string key or value becomes "". Each value is
// stored as a one-element slice.
func objectToMap(o ast.Object) map[string][]string {
	result := make(map[string][]string, o.Len())
	o.Foreach(func(key, value *ast.Term) {
		result[stringFromTerm(key)] = []string{stringFromTerm(value)}
	})
	return result
}
// validateAWSAuthParameters checks the AWS config object (operand 2) in two
// passes, each reported as an operand-2 error listing the offending keys.
// First, every key in awsRequiredConfigKeyNames must be present. Second,
// every required key that is present must hold a string value.
//
// Note(philipc): This is roughly the same approach used for http.send.
func validateAWSAuthParameters(o ast.Object) error {
	present := ast.NewSet(o.Keys()...)
	if missing := awsRequiredConfigKeyNames.Diff(present); missing.Len() != 0 {
		return builtins.NewOperandErr(2, "missing required AWS config parameters(s): %v", missing)
	}

	badValues := ast.NewSet()
	awsRequiredConfigKeyNames.Foreach(func(name *ast.Term) {
		val := o.Get(name)
		if val == nil {
			return
		}
		if _, isString := val.Value.(ast.String); !isString {
			badValues.Add(name)
		}
	})
	if badValues.Len() != 0 {
		return builtins.NewOperandErr(2, "invalid values for required AWS config parameters(s): %v", badValues)
	}
	return nil
}
// builtinAWSSigV4SignReq implements the builtin declared by
// ast.ProvidersAWSSignReqObj (providers.aws.sign_req). It signs a request
// object with AWS Signature Version 4 via aws.SignV4. It then yields a copy of
// the request whose "headers" object holds three groups of headers: the
// original ones, an "Authorization" header, and the extra headers returned by
// aws.SignV4.
//
// Operands (1-based positions, as used in error messages):
//  1. request: object that must pass validateHTTPRequestOperand. It has
//     string "url" and "method", an optional "headers" object, and an
//     optional "raw_body" (used verbatim, takes precedence) or "body"
//     (JSON-serialized before signing).
//  2. aws_config: object with string "aws_service", "aws_access_key",
//     "aws_secret_access_key" and "aws_region". It may also carry an
//     optional boolean "disable_payload_signing".
//  3. time_ns: signing time in nanoseconds since the Unix epoch. It must fit
//     in an int64.
//
// Header values that are not strings are emitted as "" because objectToMap
// converts them with stringFromTerm.
func builtinAWSSigV4SignReq(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	// Request object (operand 1).
	reqObj, err := builtins.ObjectOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	// AWS SigV4 config object (operand 2).
	awsConfigObj, err := builtins.ObjectOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	// Make sure the required AWS config keys exist and hold strings.
	if err := validateAWSAuthParameters(awsConfigObj); err != nil {
		return err
	}
	service := stringFromTerm(awsConfigObj.Get(ast.StringTerm("aws_service")))
	awsCreds := aws.CredentialsFromObject(awsConfigObj)

	// Signing timestamp (operand 3), in nanoseconds since the Unix epoch.
	timestamp, err := builtins.NumberOperand(operands[2].Value, 3)
	if err != nil {
		return err
	}
	ts, ok := timestamp.Int64()
	if !ok {
		return builtins.NewOperandErr(3, "could not convert time_ns value into a unix timestamp")
	}
	signingTimestamp := time.Unix(0, ts)

	// Make sure the request's required keys exist.
	// This check is stricter than required, but better to break here than downstream.
	if _, err := validateHTTPRequestOperand(operands[0], 1); err != nil {
		return err
	}

	// Optional headers; when present they must form an object.
	headers := ast.NewObject()
	if headersTerm := reqObj.Get(ast.StringTerm("headers")); headersTerm != nil {
		h, ok := headersTerm.Value.(ast.Object)
		if !ok {
			return builtins.NewOperandTypeErr(1, headersTerm.Value, "object")
		}
		headers = h
	}

	// asString returns t's string value and true, or "" and false when t is
	// missing or not a string. The nil guard avoids a panic on absent keys.
	asString := func(t *ast.Term) (string, bool) {
		if t == nil {
			return "", false
		}
		s, ok := t.Value.(ast.String)
		return string(s), ok
	}

	// Check types on the required request parameters, reporting all bad ones at once.
	rawURL, urlOK := asString(reqObj.Get(ast.StringTerm("url")))
	method, methodOK := asString(reqObj.Get(ast.StringTerm("method")))
	invalidParameters := ast.NewSet()
	if !urlOK {
		invalidParameters.Add(ast.StringTerm("url"))
	}
	if !methodOK {
		invalidParameters.Add(ast.StringTerm("method"))
	}
	if invalidParameters.Len() > 0 {
		return builtins.NewOperandErr(1, "invalid values for required request parameters(s): %v", invalidParameters)
	}

	theURL, err := url.Parse(rawURL)
	if err != nil {
		return err
	}

	body, err := getReqBodyBytes(reqObj.Get(ast.StringTerm("body")), reqObj.Get(ast.StringTerm("raw_body")))
	if err != nil {
		return err
	}

	// If payload signing config is set, pass it down to the signing method.
	disablePayloadSigning := false
	if t := awsConfigObj.Get(ast.StringTerm("disable_payload_signing")); t != nil {
		v, ok := t.Value.(ast.Boolean)
		if !ok {
			return builtins.NewOperandErr(2, "invalid value for 'disable_payload_signing' in AWS config")
		}
		disablePayloadSigning = bool(v)
	}

	// Sign the request, then reconstruct the headers object.
	headersMap := objectToMap(headers)
	authHeader, awsHeadersMap := aws.SignV4(headersMap, method, theURL, body, service, awsCreds, signingTimestamp, disablePayloadSigning)

	signedHeadersObj := ast.NewObject()
	// Restore original headers; objectToMap only produces single-value entries.
	for k, v := range headersMap {
		if len(v) == 1 {
			signedHeadersObj.Insert(ast.StringTerm(k), ast.StringTerm(v[0]))
		}
	}
	// Set the authorization header, then the AWS signature headers.
	signedHeadersObj.Insert(ast.StringTerm("Authorization"), ast.StringTerm(authHeader))
	for k, v := range awsHeadersMap {
		signedHeadersObj.Insert(ast.StringTerm(k), ast.StringTerm(v))
	}

	// Yield a copy of the request with the signed headers; the input is not mutated.
	out := reqObj.Copy()
	out.Insert(ast.StringTerm("headers"), ast.NewTerm(signedHeadersObj))
	return iter(ast.NewTerm(out))
}
func init() {
RegisterBuiltinFunc(ast.ProvidersAWSSignReqObj.Name, builtinAWSSigV4SignReq)
}