467 lines
12 KiB
Go
467 lines
12 KiB
Go
// Copyright 2018 The OPA Authors. All rights reserved.
|
|
// Use of this source code is governed by an Apache2
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package topdown
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/url"
|
|
"strconv"
|
|
|
|
"github.com/open-policy-agent/opa/internal/version"
|
|
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/open-policy-agent/opa/ast"
|
|
"github.com/open-policy-agent/opa/topdown/builtins"
|
|
)
|
|
|
|
// defaultHTTPRequestTimeoutEnv names the environment variable that, when set,
// overrides defaultHTTPRequestTimeout at package init (see initDefaults).
const defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT"

// defaultHTTPRequestTimeout is the client timeout used for http.send requests
// that do not specify a "timeout" parameter.
var defaultHTTPRequestTimeout = 5 * time.Second
|
|
|
|
var allowedKeyNames = [...]string{
|
|
"method",
|
|
"url",
|
|
"body",
|
|
"enable_redirect",
|
|
"force_json_decode",
|
|
"headers",
|
|
"raw_body",
|
|
"tls_use_system_certs",
|
|
"tls_ca_cert_file",
|
|
"tls_ca_cert_env_variable",
|
|
"tls_client_cert_env_variable",
|
|
"tls_client_key_env_variable",
|
|
"tls_client_cert_file",
|
|
"tls_client_key_file",
|
|
"tls_insecure_skip_verify",
|
|
"timeout",
|
|
}
|
|
var allowedKeys = ast.NewSet()
|
|
|
|
var requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))
|
|
|
|
// httpSendKey is a dedicated key type for http.send entries in the builtin
// context cache (presumably so they cannot collide with plain-string keys
// used by other builtins; confirm against the cache implementation).
type httpSendKey string

// httpSendBuiltinCacheKey is the builtin context cache key under which the
// http.send-specific response cache is stored.
const httpSendBuiltinCacheKey = httpSendKey("HTTP_SEND_CACHE_KEY")
|
|
|
|
func builtinHTTPSend(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
|
|
|
|
req, err := validateHTTPRequestOperand(args[0], 1)
|
|
if err != nil {
|
|
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
|
}
|
|
|
|
// check if cache already has a response for this query
|
|
resp := checkHTTPSendCache(bctx, req)
|
|
if resp == nil {
|
|
var err error
|
|
resp, err = executeHTTPRequest(bctx, req)
|
|
if err != nil {
|
|
return handleHTTPSendErr(bctx, err)
|
|
}
|
|
|
|
// add result to cache
|
|
insertIntoHTTPSendCache(bctx, req, resp)
|
|
}
|
|
|
|
return iter(ast.NewTerm(resp))
|
|
}
|
|
|
|
func init() {
|
|
createAllowedKeys()
|
|
initDefaults()
|
|
RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend)
|
|
}
|
|
|
|
func handleHTTPSendErr(bctx BuiltinContext, err error) error {
|
|
// Return HTTP client timeout errors in a generic error message to avoid confusion about what happened.
|
|
// Do not do this if the builtin context was cancelled and is what caused the request to stop.
|
|
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil {
|
|
err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL)
|
|
}
|
|
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
|
}
|
|
|
|
func initDefaults() {
|
|
timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv)
|
|
if timeoutDuration != "" {
|
|
var err error
|
|
defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration)
|
|
if err != nil {
|
|
// If it is set to something not valid don't let the process continue in a state
|
|
// that will almost definitely give unexpected results by having it set at 0
|
|
// which means no timeout..
|
|
// This environment variable isn't considered part of the public API.
|
|
// TODO(patrick-east): Remove the environment variable
|
|
panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) {
|
|
|
|
obj, err := builtins.ObjectOperand(term.Value, pos)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
requestKeys := ast.NewSet(obj.Keys()...)
|
|
|
|
invalidKeys := requestKeys.Diff(allowedKeys)
|
|
if invalidKeys.Len() != 0 {
|
|
return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys)
|
|
}
|
|
|
|
missingKeys := requiredKeys.Diff(requestKeys)
|
|
if missingKeys.Len() != 0 {
|
|
return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys)
|
|
}
|
|
|
|
return obj, nil
|
|
|
|
}
|
|
|
|
// addHeaders copies caller-supplied headers onto req.
//
// Each key is canonicalized with http.CanonicalHeaderKey. A "Host" entry, in
// any letter case, is assigned to req.Host instead of being stored in
// req.Header. Every other entry is appended to req.Header under its canonical
// key. The function stops at the first non-string value and returns false with
// an error. Otherwise it returns true and a nil error.
func addHeaders(req *http.Request, headers map[string]interface{}) (bool, error) {
	for name, raw := range headers {
		value, isString := raw.(string)
		if !isString {
			return false, fmt.Errorf("invalid type for headers value %q", raw)
		}

		if canonical := http.CanonicalHeaderKey(name); canonical == "Host" {
			req.Host = value
		} else {
			req.Header.Add(canonical, value)
		}
	}
	return true, nil
}
|
|
|
|
func executeHTTPRequest(bctx BuiltinContext, obj ast.Object) (ast.Value, error) {
|
|
var url string
|
|
var method string
|
|
var tlsCaCertEnvVar []byte
|
|
var tlsCaCertFile string
|
|
var tlsClientKeyEnvVar []byte
|
|
var tlsClientCertEnvVar []byte
|
|
var tlsClientCertFile string
|
|
var tlsClientKeyFile string
|
|
var body *bytes.Buffer
|
|
var rawBody *bytes.Buffer
|
|
var enableRedirect bool
|
|
var forceJSONDecode bool
|
|
var tlsUseSystemCerts bool
|
|
var tlsConfig tls.Config
|
|
var clientCerts []tls.Certificate
|
|
var customHeaders map[string]interface{}
|
|
var tlsInsecureSkipVerify bool
|
|
var timeout = defaultHTTPRequestTimeout
|
|
|
|
for _, val := range obj.Keys() {
|
|
key, err := ast.JSON(val.Value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
key = key.(string)
|
|
|
|
switch key {
|
|
case "method":
|
|
method = obj.Get(val).String()
|
|
method = strings.ToUpper(strings.Trim(method, "\""))
|
|
case "url":
|
|
url = obj.Get(val).String()
|
|
url = strings.Trim(url, "\"")
|
|
case "enable_redirect":
|
|
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "force_json_decode":
|
|
forceJSONDecode, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "body":
|
|
bodyVal := obj.Get(val).Value
|
|
bodyValInterface, err := ast.JSON(bodyVal)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
bodyValBytes, err := json.Marshal(bodyValInterface)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
body = bytes.NewBuffer(bodyValBytes)
|
|
case "raw_body":
|
|
s, ok := obj.Get(val).Value.(ast.String)
|
|
if !ok {
|
|
return nil, fmt.Errorf("raw_body must be a string")
|
|
}
|
|
rawBody = bytes.NewBuffer([]byte(s))
|
|
case "tls_use_system_certs":
|
|
tlsUseSystemCerts, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "tls_ca_cert_file":
|
|
tlsCaCertFile = obj.Get(val).String()
|
|
tlsCaCertFile = strings.Trim(tlsCaCertFile, "\"")
|
|
case "tls_ca_cert_env_variable":
|
|
caCertEnv := obj.Get(val).String()
|
|
caCertEnv = strings.Trim(caCertEnv, "\"")
|
|
tlsCaCertEnvVar = []byte(os.Getenv(caCertEnv))
|
|
case "tls_client_cert_env_variable":
|
|
clientCertEnv := obj.Get(val).String()
|
|
clientCertEnv = strings.Trim(clientCertEnv, "\"")
|
|
tlsClientCertEnvVar = []byte(os.Getenv(clientCertEnv))
|
|
case "tls_client_key_env_variable":
|
|
clientKeyEnv := obj.Get(val).String()
|
|
clientKeyEnv = strings.Trim(clientKeyEnv, "\"")
|
|
tlsClientKeyEnvVar = []byte(os.Getenv(clientKeyEnv))
|
|
case "tls_client_cert_file":
|
|
tlsClientCertFile = obj.Get(val).String()
|
|
tlsClientCertFile = strings.Trim(tlsClientCertFile, "\"")
|
|
case "tls_client_key_file":
|
|
tlsClientKeyFile = obj.Get(val).String()
|
|
tlsClientKeyFile = strings.Trim(tlsClientKeyFile, "\"")
|
|
case "headers":
|
|
headersVal := obj.Get(val).Value
|
|
headersValInterface, err := ast.JSON(headersVal)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var ok bool
|
|
customHeaders, ok = headersValInterface.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid type for headers key")
|
|
}
|
|
case "tls_insecure_skip_verify":
|
|
tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case "timeout":
|
|
timeout, err = parseTimeout(obj.Get(val).Value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("invalid parameter %q", key)
|
|
}
|
|
}
|
|
|
|
client := &http.Client{
|
|
Timeout: timeout,
|
|
}
|
|
|
|
if tlsInsecureSkipVerify {
|
|
client.Transport = &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: tlsInsecureSkipVerify},
|
|
}
|
|
}
|
|
if tlsClientCertFile != "" && tlsClientKeyFile != "" {
|
|
clientCertFromFile, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientCerts = append(clientCerts, clientCertFromFile)
|
|
}
|
|
|
|
if len(tlsClientCertEnvVar) > 0 && len(tlsClientKeyEnvVar) > 0 {
|
|
clientCertFromEnv, err := tls.X509KeyPair(tlsClientCertEnvVar, tlsClientKeyEnvVar)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientCerts = append(clientCerts, clientCertFromEnv)
|
|
}
|
|
|
|
isTLS := false
|
|
if len(clientCerts) > 0 {
|
|
isTLS = true
|
|
tlsConfig.Certificates = append(tlsConfig.Certificates, clientCerts...)
|
|
}
|
|
|
|
if tlsUseSystemCerts || len(tlsCaCertFile) > 0 || len(tlsCaCertEnvVar) > 0 {
|
|
isTLS = true
|
|
connRootCAs, err := createRootCAs(tlsCaCertFile, tlsCaCertEnvVar, tlsUseSystemCerts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tlsConfig.RootCAs = connRootCAs
|
|
}
|
|
|
|
if isTLS {
|
|
client.Transport = &http.Transport{
|
|
TLSClientConfig: &tlsConfig,
|
|
}
|
|
}
|
|
|
|
// check if redirects are enabled
|
|
if !enableRedirect {
|
|
client.CheckRedirect = func(*http.Request, []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
}
|
|
}
|
|
|
|
if rawBody != nil {
|
|
body = rawBody
|
|
} else if body == nil {
|
|
body = bytes.NewBufferString("")
|
|
}
|
|
|
|
// create the http request, use the builtin context's context to ensure
|
|
// the request is cancelled if evaluation is cancelled.
|
|
req, err := http.NewRequest(method, url, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req = req.WithContext(bctx.Context)
|
|
|
|
// Add custom headers
|
|
if len(customHeaders) != 0 {
|
|
if ok, err := addHeaders(req, customHeaders); !ok {
|
|
return nil, err
|
|
}
|
|
// Don't overwrite or append to one that was set in the custom headers
|
|
if _, hasUA := customHeaders["User-Agent"]; !hasUA {
|
|
req.Header.Add("User-Agent", version.UserAgent)
|
|
}
|
|
}
|
|
|
|
// execute the http request
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
// format the http result
|
|
var resultBody interface{}
|
|
var resultRawBody []byte
|
|
|
|
var buf bytes.Buffer
|
|
tee := io.TeeReader(resp.Body, &buf)
|
|
resultRawBody, err = ioutil.ReadAll(tee)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If the response body cannot be JSON decoded,
|
|
// an error will not be returned. Instead the "body" field
|
|
// in the result will be null.
|
|
if isContentTypeJSON(resp.Header) || forceJSONDecode {
|
|
json.NewDecoder(&buf).Decode(&resultBody)
|
|
}
|
|
|
|
result := make(map[string]interface{})
|
|
result["status"] = resp.Status
|
|
result["status_code"] = resp.StatusCode
|
|
result["body"] = resultBody
|
|
result["raw_body"] = string(resultRawBody)
|
|
|
|
resultObj, err := ast.InterfaceToValue(result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return resultObj, nil
|
|
}
|
|
|
|
// isContentTypeJSON reports whether the Content-Type header contains
// "application/json" anywhere in its value. A substring match is used, so
// parameters such as "; charset=utf-8" do not prevent a match.
func isContentTypeJSON(header http.Header) bool {
	contentType := header.Get("Content-Type")
	return strings.Contains(contentType, "application/json")
}
|
|
|
|
// In the BuiltinContext cache we only store a single entry that points to
|
|
// our ValueMap which is the "real" http.send() cache.
|
|
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
|
|
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
|
|
if !ok {
|
|
// Initialize if it isn't there
|
|
cache := ast.NewValueMap()
|
|
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
|
|
return cache
|
|
}
|
|
|
|
cache, ok := raw.(*ast.ValueMap)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return cache
|
|
}
|
|
|
|
// checkHTTPSendCache checks for the given key's value in the cache
|
|
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
|
|
requestCache := getHTTPSendCache(bctx)
|
|
if requestCache == nil {
|
|
return nil
|
|
}
|
|
|
|
return requestCache.Get(key)
|
|
}
|
|
|
|
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
|
|
requestCache := getHTTPSendCache(bctx)
|
|
if requestCache == nil {
|
|
// Should never happen.. if it does just skip caching the value
|
|
return
|
|
}
|
|
requestCache.Put(key, value)
|
|
}
|
|
|
|
func createAllowedKeys() {
|
|
for _, element := range allowedKeyNames {
|
|
allowedKeys.Add(ast.StringTerm(element))
|
|
}
|
|
}
|
|
|
|
func parseTimeout(timeoutVal ast.Value) (time.Duration, error) {
|
|
var timeout time.Duration
|
|
switch t := timeoutVal.(type) {
|
|
case ast.Number:
|
|
timeoutInt, ok := t.Int64()
|
|
if !ok {
|
|
return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal)
|
|
}
|
|
return time.Duration(timeoutInt), nil
|
|
case ast.String:
|
|
// Support strings without a unit, treat them the same as just a number value (ns)
|
|
var err error
|
|
timeoutInt, err := strconv.ParseInt(string(t), 10, 64)
|
|
if err == nil {
|
|
return time.Duration(timeoutInt), nil
|
|
}
|
|
|
|
// Try parsing it as a duration (requires a supported units suffix)
|
|
timeout, err = time.ParseDuration(string(t))
|
|
if err != nil {
|
|
return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err)
|
|
}
|
|
return timeout, nil
|
|
default:
|
|
return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t))
|
|
}
|
|
}
|