1412 lines
38 KiB
Go
1412 lines
38 KiB
Go
// Copyright 2018 The OPA Authors. All rights reserved.
|
|
// Use of this source code is governed by an Apache2
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package topdown
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/open-policy-agent/opa/ast"
|
|
"github.com/open-policy-agent/opa/internal/version"
|
|
"github.com/open-policy-agent/opa/topdown/builtins"
|
|
"github.com/open-policy-agent/opa/topdown/cache"
|
|
"github.com/open-policy-agent/opa/tracing"
|
|
"github.com/open-policy-agent/opa/util"
|
|
)
|
|
|
|
// cachingMode selects how http.send responses are stored in the inter-query
// cache (request parameter "caching_mode").
type cachingMode string

// defaultHTTPRequestTimeoutEnv names the environment variable read by
// initDefaults to override defaultHTTPRequestTimeout.
const defaultHTTPRequestTimeoutEnv = "HTTP_SEND_TIMEOUT"

// Accepted values for the "caching_mode" request parameter.
const (
	// defaultCachingMode stores responses JSON-encoded (interQueryCacheValue).
	defaultCachingMode cachingMode = "serialized"
	// cachingModeDeserialized stores responses as *interQueryCacheData.
	cachingModeDeserialized cachingMode = "deserialized"
)
|
|
|
|
// defaultHTTPRequestTimeout is the client timeout used when a request does not
// set "timeout". initDefaults may replace it from HTTP_SEND_TIMEOUT.
var defaultHTTPRequestTimeout = 5 * time.Second
|
|
|
|
// allowedKeyNames lists every parameter accepted in an http.send request
// object; createAllowedKeys loads them into the allowedKeys set used by
// validateHTTPRequestOperand.
var allowedKeyNames = [...]string{
	// Request basics.
	"method", "url", "body",
	"enable_redirect",
	"force_json_decode", "force_yaml_decode",
	"headers", "raw_body",
	// TLS configuration.
	"tls_use_system_certs",
	"tls_ca_cert", "tls_ca_cert_file", "tls_ca_cert_env_variable",
	"tls_client_cert", "tls_client_cert_file", "tls_client_cert_env_variable",
	"tls_client_key", "tls_client_key_file", "tls_client_key_env_variable",
	"tls_insecure_skip_verify", "tls_server_name",
	// Execution and caching.
	"timeout",
	"cache", "force_cache", "force_cache_duration_seconds",
	"raise_error", "caching_mode",
}
|
|
|
|
// cacheableHTTPStatusCodes are the status codes that are cacheable by default
// (ref: https://www.rfc-editor.org/rfc/rfc7231#section-6.1). Only responses
// with one of these codes are inserted into the inter-query cache.
var cacheableHTTPStatusCodes = [...]int{
	http.StatusOK, http.StatusNonAuthoritativeInfo, http.StatusNoContent, http.StatusPartialContent,
	http.StatusMultipleChoices, http.StatusMovedPermanently,
	http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusGone, http.StatusRequestURITooLong,
	http.StatusNotImplemented,
}
|
|
|
|
var (
|
|
allowedKeys = ast.NewSet()
|
|
cacheableCodes = ast.NewSet()
|
|
requiredKeys = ast.NewSet(ast.StringTerm("method"), ast.StringTerm("url"))
|
|
httpSendLatencyMetricKey = "rego_builtin_" + strings.ReplaceAll(ast.HTTPSend.Name, ".", "_")
|
|
httpSendInterQueryCacheHits = httpSendLatencyMetricKey + "_interquery_cache_hits"
|
|
)
|
|
|
|
// httpSendKey is the type of the key under which http.send stores its own
// cache inside the builtin context cache.
type httpSendKey string

// httpSendBuiltinCacheKey is the key in the builtin context cache under which
// the http.send-specific cache (an *ast.ValueMap) is stored.
const httpSendBuiltinCacheKey httpSendKey = "HTTP_SEND_CACHE_KEY"

// Error codes reported in the "error" object of an http.send result when
// raise_error is false.
const (
	// HTTPSendInternalErr represents a runtime evaluation error.
	HTTPSendInternalErr string = "eval_http_send_internal_error"
	// HTTPSendNetworkErr represents a network error.
	HTTPSendNetworkErr string = "eval_http_send_network_error"
)
|
|
|
|
func builtinHTTPSend(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
|
|
req, err := validateHTTPRequestOperand(operands[0], 1)
|
|
if err != nil {
|
|
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
|
}
|
|
|
|
raiseError, err := getRaiseErrorValue(req)
|
|
if err != nil {
|
|
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
|
}
|
|
|
|
result, err := getHTTPResponse(bctx, req)
|
|
if err != nil {
|
|
if raiseError {
|
|
return handleHTTPSendErr(bctx, err)
|
|
}
|
|
|
|
obj := ast.NewObject()
|
|
obj.Insert(ast.StringTerm("status_code"), ast.IntNumberTerm(0))
|
|
|
|
errObj := ast.NewObject()
|
|
|
|
switch err.(type) {
|
|
case *url.Error:
|
|
errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendNetworkErr))
|
|
default:
|
|
errObj.Insert(ast.StringTerm("code"), ast.StringTerm(HTTPSendInternalErr))
|
|
}
|
|
|
|
errObj.Insert(ast.StringTerm("message"), ast.StringTerm(err.Error()))
|
|
obj.Insert(ast.StringTerm("error"), ast.NewTerm(errObj))
|
|
|
|
result = ast.NewTerm(obj)
|
|
}
|
|
return iter(result)
|
|
}
|
|
|
|
func getHTTPResponse(bctx BuiltinContext, req ast.Object) (*ast.Term, error) {
|
|
|
|
bctx.Metrics.Timer(httpSendLatencyMetricKey).Start()
|
|
|
|
reqExecutor, err := newHTTPRequestExecutor(bctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if cache already has a response for this query
|
|
resp, err := reqExecutor.CheckCache()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp == nil {
|
|
httpResp, err := reqExecutor.ExecuteHTTPRequest()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer util.Close(httpResp)
|
|
// Add result to intra/inter-query cache.
|
|
resp, err = reqExecutor.InsertIntoCache(httpResp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
bctx.Metrics.Timer(httpSendLatencyMetricKey).Stop()
|
|
|
|
return ast.NewTerm(resp), nil
|
|
}
|
|
|
|
func init() {
|
|
createAllowedKeys()
|
|
createCacheableHTTPStatusCodes()
|
|
initDefaults()
|
|
RegisterBuiltinFunc(ast.HTTPSend.Name, builtinHTTPSend)
|
|
}
|
|
|
|
func handleHTTPSendErr(bctx BuiltinContext, err error) error {
|
|
// Return HTTP client timeout errors in a generic error message to avoid confusion about what happened.
|
|
// Do not do this if the builtin context was cancelled and is what caused the request to stop.
|
|
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() && bctx.Context.Err() == nil {
|
|
err = fmt.Errorf("%s %s: request timed out", urlErr.Op, urlErr.URL)
|
|
}
|
|
if err := bctx.Context.Err(); err != nil {
|
|
return Halt{
|
|
Err: &Error{
|
|
Code: CancelErr,
|
|
Message: fmt.Sprintf("http.send: timed out (%s)", err.Error()),
|
|
},
|
|
}
|
|
}
|
|
return handleBuiltinErr(ast.HTTPSend.Name, bctx.Location, err)
|
|
}
|
|
|
|
func initDefaults() {
|
|
timeoutDuration := os.Getenv(defaultHTTPRequestTimeoutEnv)
|
|
if timeoutDuration != "" {
|
|
var err error
|
|
defaultHTTPRequestTimeout, err = time.ParseDuration(timeoutDuration)
|
|
if err != nil {
|
|
// If it is set to something not valid don't let the process continue in a state
|
|
// that will almost definitely give unexpected results by having it set at 0
|
|
// which means no timeout..
|
|
// This environment variable isn't considered part of the public API.
|
|
// TODO(patrick-east): Remove the environment variable
|
|
panic(fmt.Sprintf("invalid value for HTTP_SEND_TIMEOUT: %s", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
func validateHTTPRequestOperand(term *ast.Term, pos int) (ast.Object, error) {
|
|
|
|
obj, err := builtins.ObjectOperand(term.Value, pos)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
requestKeys := ast.NewSet(obj.Keys()...)
|
|
|
|
invalidKeys := requestKeys.Diff(allowedKeys)
|
|
if invalidKeys.Len() != 0 {
|
|
return nil, builtins.NewOperandErr(pos, "invalid request parameters(s): %v", invalidKeys)
|
|
}
|
|
|
|
missingKeys := requiredKeys.Diff(requestKeys)
|
|
if missingKeys.Len() != 0 {
|
|
return nil, builtins.NewOperandErr(pos, "missing required request parameters(s): %v", missingKeys)
|
|
}
|
|
|
|
return obj, nil
|
|
|
|
}
|
|
|
|
// canonicalizeHeaders returns a new map holding the same values as headers,
// keyed by the canonical HTTP form of each name (e.g. "content-type" ->
// "Content-Type"). The result is never nil.
func canonicalizeHeaders(headers map[string]interface{}) map[string]interface{} {
	out := make(map[string]interface{}, len(headers))
	for name, value := range headers {
		out[http.CanonicalHeaderKey(name)] = value
	}
	return out
}
|
|
|
|
// useSocket reports whether rawURL targets a UNIX domain socket, i.e. has the
// form "unix://host/path?socket=<url-encoded path>", for example
// "unix://localhost/end/point?socket=%2Ftmp%2Fhttp.sock".
//
// For such URLs it returns true, the URL rewritten to the "http" scheme with
// the socket parameter removed, and a transport whose dialer always connects
// to that socket (keep-alives disabled, TLS config set to tlsConfig). Only
// the first "socket" value is used; all of them are stripped from the query
// to prevent HTTP parameter pollution.
//
// Otherwise it returns false, rawURL unchanged (or "" when rawURL does not
// parse) and a nil transport.
func useSocket(rawURL string, tlsConfig *tls.Config) (bool, string, *http.Transport) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return false, "", nil
	}
	if parsed.Scheme != "unix" || parsed.RawQuery == "" {
		return false, rawURL, nil
	}

	query, err := url.ParseQuery(parsed.RawQuery)
	if err != nil {
		return false, rawURL, nil
	}

	socketPath := query.Get("socket")
	query.Del("socket")

	parsed.Scheme = "http"
	parsed.RawQuery = query.Encode()

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.TLSClientConfig = tlsConfig
	transport.DisableKeepAlives = true
	transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
		return http.DefaultTransport.(*http.Transport).DialContext(ctx, "unix", socketPath)
	}

	return true, parsed.String(), transport
}
|
|
|
|
func verifyHost(bctx BuiltinContext, host string) error {
|
|
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
|
|
return nil
|
|
}
|
|
|
|
for _, allowed := range bctx.Capabilities.AllowNet {
|
|
if allowed == host {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("unallowed host: %s", host)
|
|
}
|
|
|
|
func verifyURLHost(bctx BuiltinContext, unverifiedURL string) error {
|
|
// Eager return to avoid unnecessary URL parsing
|
|
if bctx.Capabilities == nil || bctx.Capabilities.AllowNet == nil {
|
|
return nil
|
|
}
|
|
|
|
parsedURL, err := url.Parse(unverifiedURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
host := strings.Split(parsedURL.Host, ":")[0]
|
|
|
|
return verifyHost(bctx, host)
|
|
}
|
|
|
|
func createHTTPRequest(bctx BuiltinContext, obj ast.Object) (*http.Request, *http.Client, error) {
|
|
var url string
|
|
var method string
|
|
|
|
// Additional CA certificates loading options.
|
|
var tlsCaCert []byte
|
|
var tlsCaCertEnvVar string
|
|
var tlsCaCertFile string
|
|
|
|
// Client TLS certificate and key options. Each input source
|
|
// comes in a matched pair.
|
|
var tlsClientCert []byte
|
|
var tlsClientKey []byte
|
|
|
|
var tlsClientCertEnvVar string
|
|
var tlsClientKeyEnvVar string
|
|
|
|
var tlsClientCertFile string
|
|
var tlsClientKeyFile string
|
|
|
|
var tlsServerName string
|
|
var body *bytes.Buffer
|
|
var rawBody *bytes.Buffer
|
|
var enableRedirect bool
|
|
var tlsUseSystemCerts *bool
|
|
var tlsConfig tls.Config
|
|
var customHeaders map[string]interface{}
|
|
var tlsInsecureSkipVerify bool
|
|
var timeout = defaultHTTPRequestTimeout
|
|
|
|
for _, val := range obj.Keys() {
|
|
key, err := ast.JSON(val.Value)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
key = key.(string)
|
|
|
|
var strVal string
|
|
|
|
if s, ok := obj.Get(val).Value.(ast.String); ok {
|
|
strVal = strings.Trim(string(s), "\"")
|
|
} else {
|
|
// Most parameters are strings, so consolidate the type checking.
|
|
switch key {
|
|
case "method",
|
|
"url",
|
|
"raw_body",
|
|
"tls_ca_cert",
|
|
"tls_ca_cert_file",
|
|
"tls_ca_cert_env_variable",
|
|
"tls_client_cert",
|
|
"tls_client_cert_file",
|
|
"tls_client_cert_env_variable",
|
|
"tls_client_key",
|
|
"tls_client_key_file",
|
|
"tls_client_key_env_variable",
|
|
"tls_server_name":
|
|
return nil, nil, fmt.Errorf("%q must be a string", key)
|
|
}
|
|
}
|
|
|
|
switch key {
|
|
case "method":
|
|
method = strings.ToUpper(strVal)
|
|
case "url":
|
|
err := verifyURLHost(bctx, strVal)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
url = strVal
|
|
case "enable_redirect":
|
|
enableRedirect, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
case "body":
|
|
bodyVal := obj.Get(val).Value
|
|
bodyValInterface, err := ast.JSON(bodyVal)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
bodyValBytes, err := json.Marshal(bodyValInterface)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
body = bytes.NewBuffer(bodyValBytes)
|
|
case "raw_body":
|
|
rawBody = bytes.NewBuffer([]byte(strVal))
|
|
case "tls_use_system_certs":
|
|
tempTLSUseSystemCerts, err := strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
tlsUseSystemCerts = &tempTLSUseSystemCerts
|
|
case "tls_ca_cert":
|
|
tlsCaCert = []byte(strVal)
|
|
case "tls_ca_cert_file":
|
|
tlsCaCertFile = strVal
|
|
case "tls_ca_cert_env_variable":
|
|
tlsCaCertEnvVar = strVal
|
|
case "tls_client_cert":
|
|
tlsClientCert = []byte(strVal)
|
|
case "tls_client_cert_file":
|
|
tlsClientCertFile = strVal
|
|
case "tls_client_cert_env_variable":
|
|
tlsClientCertEnvVar = strVal
|
|
case "tls_client_key":
|
|
tlsClientKey = []byte(strVal)
|
|
case "tls_client_key_file":
|
|
tlsClientKeyFile = strVal
|
|
case "tls_client_key_env_variable":
|
|
tlsClientKeyEnvVar = strVal
|
|
case "tls_server_name":
|
|
tlsServerName = strVal
|
|
case "headers":
|
|
headersVal := obj.Get(val).Value
|
|
headersValInterface, err := ast.JSON(headersVal)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
var ok bool
|
|
customHeaders, ok = headersValInterface.(map[string]interface{})
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("invalid type for headers key")
|
|
}
|
|
case "tls_insecure_skip_verify":
|
|
tlsInsecureSkipVerify, err = strconv.ParseBool(obj.Get(val).String())
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
case "timeout":
|
|
timeout, err = parseTimeout(obj.Get(val).Value)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
case "cache", "caching_mode",
|
|
"force_cache", "force_cache_duration_seconds",
|
|
"force_json_decode", "force_yaml_decode",
|
|
"raise_error": // no-op
|
|
default:
|
|
return nil, nil, fmt.Errorf("invalid parameter %q", key)
|
|
}
|
|
}
|
|
|
|
isTLS := false
|
|
client := &http.Client{
|
|
Timeout: timeout,
|
|
CheckRedirect: func(*http.Request, []*http.Request) error {
|
|
return http.ErrUseLastResponse
|
|
},
|
|
}
|
|
|
|
if tlsInsecureSkipVerify {
|
|
isTLS = true
|
|
tlsConfig.InsecureSkipVerify = tlsInsecureSkipVerify
|
|
}
|
|
|
|
if len(tlsClientCert) > 0 && len(tlsClientKey) > 0 {
|
|
cert, err := tls.X509KeyPair(tlsClientCert, tlsClientKey)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
|
|
}
|
|
|
|
if tlsClientCertFile != "" && tlsClientKeyFile != "" {
|
|
cert, err := tls.LoadX509KeyPair(tlsClientCertFile, tlsClientKeyFile)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
|
|
}
|
|
|
|
if tlsClientCertEnvVar != "" && tlsClientKeyEnvVar != "" {
|
|
cert, err := tls.X509KeyPair(
|
|
[]byte(os.Getenv(tlsClientCertEnvVar)),
|
|
[]byte(os.Getenv(tlsClientKeyEnvVar)))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("cannot extract public/private key pair from envvars %q, %q: %w",
|
|
tlsClientCertEnvVar, tlsClientKeyEnvVar, err)
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
|
|
}
|
|
|
|
// Use system certs if no CA cert is provided
|
|
// or system certs flag is not set
|
|
if len(tlsCaCert) == 0 && tlsCaCertFile == "" && tlsCaCertEnvVar == "" && tlsUseSystemCerts == nil {
|
|
trueValue := true
|
|
tlsUseSystemCerts = &trueValue
|
|
}
|
|
|
|
// Check the system certificates config first so that we
|
|
// load additional certificated into the correct pool.
|
|
if tlsUseSystemCerts != nil && *tlsUseSystemCerts && runtime.GOOS != "windows" {
|
|
pool, err := x509.SystemCertPool()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.RootCAs = pool
|
|
}
|
|
|
|
if len(tlsCaCert) != 0 {
|
|
tlsCaCert = bytes.Replace(tlsCaCert, []byte("\\n"), []byte("\n"), -1)
|
|
pool, err := addCACertsFromBytes(tlsConfig.RootCAs, tlsCaCert)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.RootCAs = pool
|
|
}
|
|
|
|
if tlsCaCertFile != "" {
|
|
pool, err := addCACertsFromFile(tlsConfig.RootCAs, tlsCaCertFile)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.RootCAs = pool
|
|
}
|
|
|
|
if tlsCaCertEnvVar != "" {
|
|
pool, err := addCACertsFromEnv(tlsConfig.RootCAs, tlsCaCertEnvVar)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
isTLS = true
|
|
tlsConfig.RootCAs = pool
|
|
}
|
|
|
|
if isTLS {
|
|
if ok, parsedURL, tr := useSocket(url, &tlsConfig); ok {
|
|
client.Transport = tr
|
|
url = parsedURL
|
|
} else {
|
|
tr := http.DefaultTransport.(*http.Transport).Clone()
|
|
tr.TLSClientConfig = &tlsConfig
|
|
tr.DisableKeepAlives = true
|
|
client.Transport = tr
|
|
}
|
|
} else {
|
|
if ok, parsedURL, tr := useSocket(url, nil); ok {
|
|
client.Transport = tr
|
|
url = parsedURL
|
|
}
|
|
}
|
|
|
|
// check if redirects are enabled
|
|
if enableRedirect {
|
|
client.CheckRedirect = func(req *http.Request, _ []*http.Request) error {
|
|
return verifyURLHost(bctx, req.URL.String())
|
|
}
|
|
}
|
|
|
|
if rawBody != nil {
|
|
body = rawBody
|
|
} else if body == nil {
|
|
body = bytes.NewBufferString("")
|
|
}
|
|
|
|
// create the http request, use the builtin context's context to ensure
|
|
// the request is cancelled if evaluation is cancelled.
|
|
req, err := http.NewRequest(method, url, body)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
req = req.WithContext(bctx.Context)
|
|
|
|
// Add custom headers
|
|
if len(customHeaders) != 0 {
|
|
customHeaders = canonicalizeHeaders(customHeaders)
|
|
|
|
for k, v := range customHeaders {
|
|
header, ok := v.(string)
|
|
if !ok {
|
|
return nil, nil, fmt.Errorf("invalid type for headers value %q", v)
|
|
}
|
|
|
|
req.Header.Add(k, header)
|
|
}
|
|
|
|
// Don't overwrite or append to one that was set in the custom headers
|
|
if _, hasUA := customHeaders["User-Agent"]; !hasUA {
|
|
req.Header.Add("User-Agent", version.UserAgent)
|
|
}
|
|
|
|
// If the caller specifies the Host header, use it for the HTTP
|
|
// request host and the TLS server name.
|
|
if host, hasHost := customHeaders["Host"]; hasHost {
|
|
host := host.(string) // We already checked that it's a string.
|
|
req.Host = host
|
|
|
|
// Only default the ServerName if the caller has
|
|
// specified the host. If we don't specify anything,
|
|
// Go will default to the target hostname. This name
|
|
// is not the same as the default that Go populates
|
|
// `req.Host` with, which is why we don't just set
|
|
// this unconditionally.
|
|
tlsConfig.ServerName = host
|
|
}
|
|
}
|
|
|
|
if tlsServerName != "" {
|
|
tlsConfig.ServerName = tlsServerName
|
|
}
|
|
|
|
if len(bctx.DistributedTracingOpts) > 0 {
|
|
client.Transport = tracing.NewTransport(client.Transport, bctx.DistributedTracingOpts)
|
|
}
|
|
|
|
return req, client, nil
|
|
}
|
|
|
|
// executeHTTPRequest sends req using client and returns the raw response.
// Closing the response body is left to the caller.
func executeHTTPRequest(req *http.Request, client *http.Client) (*http.Response, error) {
	resp, err := client.Do(req)
	return resp, err
}
|
|
|
|
// isContentType reports whether the Content-Type header contains any of the
// given substrings (e.g. "application/json" matches
// "application/json; charset=utf-8").
func isContentType(header http.Header, typ ...string) bool {
	contentType := header.Get("Content-Type")
	for i := range typ {
		if strings.Contains(contentType, typ[i]) {
			return true
		}
	}
	return false
}
|
|
|
|
// In the BuiltinContext cache we only store a single entry that points to
|
|
// our ValueMap which is the "real" http.send() cache.
|
|
func getHTTPSendCache(bctx BuiltinContext) *ast.ValueMap {
|
|
raw, ok := bctx.Cache.Get(httpSendBuiltinCacheKey)
|
|
if !ok {
|
|
// Initialize if it isn't there
|
|
cache := ast.NewValueMap()
|
|
bctx.Cache.Put(httpSendBuiltinCacheKey, cache)
|
|
return cache
|
|
}
|
|
|
|
cache, ok := raw.(*ast.ValueMap)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return cache
|
|
}
|
|
|
|
// checkHTTPSendCache checks for the given key's value in the cache
|
|
func checkHTTPSendCache(bctx BuiltinContext, key ast.Object) ast.Value {
|
|
requestCache := getHTTPSendCache(bctx)
|
|
if requestCache == nil {
|
|
return nil
|
|
}
|
|
|
|
return requestCache.Get(key)
|
|
}
|
|
|
|
func insertIntoHTTPSendCache(bctx BuiltinContext, key ast.Object, value ast.Value) {
|
|
requestCache := getHTTPSendCache(bctx)
|
|
if requestCache == nil {
|
|
// Should never happen.. if it does just skip caching the value
|
|
return
|
|
}
|
|
requestCache.Put(key, value)
|
|
}
|
|
|
|
// checkHTTPSendInterQueryCache checks for the given key's value in the inter-query cache
|
|
func (c *interQueryCache) checkHTTPSendInterQueryCache() (ast.Value, error) {
|
|
requestCache := c.bctx.InterQueryBuiltinCache
|
|
|
|
value, found := requestCache.Get(c.key)
|
|
if !found {
|
|
return nil, nil
|
|
}
|
|
c.bctx.Metrics.Counter(httpSendInterQueryCacheHits).Incr()
|
|
var cachedRespData *interQueryCacheData
|
|
|
|
switch v := value.(type) {
|
|
case *interQueryCacheValue:
|
|
var err error
|
|
cachedRespData, err = v.copyCacheData()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case *interQueryCacheData:
|
|
cachedRespData = v
|
|
default:
|
|
return nil, nil
|
|
}
|
|
|
|
if getCurrentTime(c.bctx).Before(cachedRespData.ExpiresAt) {
|
|
return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode)
|
|
}
|
|
|
|
var err error
|
|
c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
headers, err := parseResponseHeaders(cachedRespData.Headers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// check with the server if the stale response is still up-to-date.
|
|
// If server returns a new response (ie. status_code=200), update the cache with the new response
|
|
// If server returns an unmodified response (ie. status_code=304), update the headers for the existing response
|
|
result, modified, err := revalidateCachedResponse(c.httpReq, c.httpClient, headers)
|
|
requestCache.Delete(c.key)
|
|
if err != nil || result == nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer result.Body.Close()
|
|
|
|
if !modified {
|
|
// update the headers in the cached response with their corresponding values from the 304 (Not Modified) response
|
|
for headerName, values := range result.Header {
|
|
cachedRespData.Headers.Del(headerName)
|
|
for _, v := range values {
|
|
cachedRespData.Headers.Add(headerName, v)
|
|
}
|
|
}
|
|
|
|
expiresAt, err := expiryFromHeaders(result.Header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cachedRespData.ExpiresAt = expiresAt
|
|
|
|
cachingMode, err := getCachingMode(c.key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var pcv cache.InterQueryCacheValue
|
|
|
|
if cachingMode == defaultCachingMode {
|
|
pcv, err = cachedRespData.toCacheValue()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
pcv = cachedRespData
|
|
}
|
|
|
|
c.bctx.InterQueryBuiltinCache.Insert(c.key, pcv)
|
|
|
|
return cachedRespData.formatToAST(c.forceJSONDecode, c.forceYAMLDecode)
|
|
}
|
|
|
|
newValue, respBody, err := formatHTTPResponseToAST(result, c.forceJSONDecode, c.forceYAMLDecode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := insertIntoHTTPSendInterQueryCache(c.bctx, c.key, result, respBody, c.forceCacheParams); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return newValue, nil
|
|
}
|
|
|
|
// insertIntoHTTPSendInterQueryCache inserts given key and value in the inter-query cache
|
|
func insertIntoHTTPSendInterQueryCache(bctx BuiltinContext, key ast.Value, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) error {
|
|
if resp == nil || (!forceCaching(cacheParams) && !canStore(resp.Header)) || !cacheableCodes.Contains(ast.IntNumberTerm(resp.StatusCode)) {
|
|
return nil
|
|
}
|
|
|
|
requestCache := bctx.InterQueryBuiltinCache
|
|
|
|
obj, ok := key.(ast.Object)
|
|
if !ok {
|
|
return fmt.Errorf("interface conversion error")
|
|
}
|
|
|
|
cachingMode, err := getCachingMode(obj)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var pcv cache.InterQueryCacheValue
|
|
|
|
if cachingMode == defaultCachingMode {
|
|
pcv, err = newInterQueryCacheValue(bctx, resp, respBody, cacheParams)
|
|
} else {
|
|
pcv, err = newInterQueryCacheData(bctx, resp, respBody, cacheParams)
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
requestCache.Insert(key, pcv)
|
|
return nil
|
|
}
|
|
|
|
func createAllowedKeys() {
|
|
for _, element := range allowedKeyNames {
|
|
allowedKeys.Add(ast.StringTerm(element))
|
|
}
|
|
}
|
|
|
|
func createCacheableHTTPStatusCodes() {
|
|
for _, element := range cacheableHTTPStatusCodes {
|
|
cacheableCodes.Add(ast.IntNumberTerm(element))
|
|
}
|
|
}
|
|
|
|
func parseTimeout(timeoutVal ast.Value) (time.Duration, error) {
|
|
var timeout time.Duration
|
|
switch t := timeoutVal.(type) {
|
|
case ast.Number:
|
|
timeoutInt, ok := t.Int64()
|
|
if !ok {
|
|
return timeout, fmt.Errorf("invalid timeout number value %v, must be int64", timeoutVal)
|
|
}
|
|
return time.Duration(timeoutInt), nil
|
|
case ast.String:
|
|
// Support strings without a unit, treat them the same as just a number value (ns)
|
|
var err error
|
|
timeoutInt, err := strconv.ParseInt(string(t), 10, 64)
|
|
if err == nil {
|
|
return time.Duration(timeoutInt), nil
|
|
}
|
|
|
|
// Try parsing it as a duration (requires a supported units suffix)
|
|
timeout, err = time.ParseDuration(string(t))
|
|
if err != nil {
|
|
return timeout, fmt.Errorf("invalid timeout value %v: %s", timeoutVal, err)
|
|
}
|
|
return timeout, nil
|
|
default:
|
|
return timeout, builtins.NewOperandErr(1, "'timeout' must be one of {string, number} but got %s", ast.TypeName(t))
|
|
}
|
|
}
|
|
|
|
func getBoolValFromReqObj(req ast.Object, key *ast.Term) (bool, error) {
|
|
var b ast.Boolean
|
|
var ok bool
|
|
if v := req.Get(key); v != nil {
|
|
if b, ok = v.Value.(ast.Boolean); !ok {
|
|
return false, fmt.Errorf("invalid value for %v field", key.String())
|
|
}
|
|
}
|
|
return bool(b), nil
|
|
}
|
|
|
|
func getCachingMode(req ast.Object) (cachingMode, error) {
|
|
key := ast.StringTerm("caching_mode")
|
|
var s ast.String
|
|
var ok bool
|
|
if v := req.Get(key); v != nil {
|
|
if s, ok = v.Value.(ast.String); !ok {
|
|
return "", fmt.Errorf("invalid value for %v field", key.String())
|
|
}
|
|
|
|
switch cachingMode(s) {
|
|
case defaultCachingMode, cachingModeDeserialized:
|
|
return cachingMode(s), nil
|
|
default:
|
|
return "", fmt.Errorf("invalid value specified for %v field: %v", key.String(), string(s))
|
|
}
|
|
}
|
|
return defaultCachingMode, nil
|
|
}
|
|
|
|
// interQueryCacheValue is the inter-query cache entry used in the default
// "serialized" caching mode.
type interQueryCacheValue struct {
	// Data holds a JSON-encoded interQueryCacheData.
	Data []byte
}
|
|
|
|
func newInterQueryCacheValue(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheValue, error) {
|
|
data, err := newInterQueryCacheData(bctx, resp, respBody, cacheParams)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
b, err := json.Marshal(data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &interQueryCacheValue{Data: b}, nil
|
|
}
|
|
|
|
func (cb interQueryCacheValue) SizeInBytes() int64 {
|
|
return int64(len(cb.Data))
|
|
}
|
|
|
|
func (cb *interQueryCacheValue) copyCacheData() (*interQueryCacheData, error) {
|
|
var res interQueryCacheData
|
|
err := util.UnmarshalJSON(cb.Data, &res)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &res, nil
|
|
}
|
|
|
|
// interQueryCacheData is a cached http.send response. It is stored as-is in
// "deserialized" caching mode, and JSON-encoded into an interQueryCacheValue
// in the default "serialized" mode.
type interQueryCacheData struct {
	RespBody   []byte      // raw response body
	Status     string      // status line text, e.g. "200 OK"
	StatusCode int         // numeric HTTP status code
	Headers    http.Header // response headers
	ExpiresAt  time.Time   // instant after which the entry is stale
}
|
|
|
|
func forceCaching(cacheParams *forceCacheParams) bool {
|
|
return cacheParams != nil && cacheParams.forceCacheDurationSeconds > 0
|
|
}
|
|
|
|
func expiryFromHeaders(headers http.Header) (time.Time, error) {
|
|
var expiresAt time.Time
|
|
maxAge, err := parseMaxAgeCacheDirective(parseCacheControlHeader(headers))
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
if maxAge != -1 {
|
|
createdAt, err := getResponseHeaderDate(headers)
|
|
if err != nil {
|
|
return time.Time{}, err
|
|
}
|
|
expiresAt = createdAt.Add(time.Second * time.Duration(maxAge))
|
|
} else {
|
|
expiresAt = getResponseHeaderExpires(headers)
|
|
}
|
|
return expiresAt, nil
|
|
}
|
|
|
|
func newInterQueryCacheData(bctx BuiltinContext, resp *http.Response, respBody []byte, cacheParams *forceCacheParams) (*interQueryCacheData, error) {
|
|
var expiresAt time.Time
|
|
|
|
if forceCaching(cacheParams) {
|
|
createdAt := getCurrentTime(bctx)
|
|
expiresAt = createdAt.Add(time.Second * time.Duration(cacheParams.forceCacheDurationSeconds))
|
|
} else {
|
|
var err error
|
|
expiresAt, err = expiryFromHeaders(resp.Header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
cv := interQueryCacheData{
|
|
ExpiresAt: expiresAt,
|
|
RespBody: respBody,
|
|
Status: resp.Status,
|
|
StatusCode: resp.StatusCode,
|
|
Headers: resp.Header}
|
|
|
|
return &cv, nil
|
|
}
|
|
|
|
func (c *interQueryCacheData) formatToAST(forceJSONDecode, forceYAMLDecode bool) (ast.Value, error) {
|
|
return prepareASTResult(c.Headers, forceJSONDecode, forceYAMLDecode, c.RespBody, c.Status, c.StatusCode)
|
|
}
|
|
|
|
func (c *interQueryCacheData) toCacheValue() (*interQueryCacheValue, error) {
|
|
b, err := json.Marshal(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &interQueryCacheValue{Data: b}, nil
|
|
}
|
|
|
|
func (c *interQueryCacheData) SizeInBytes() int64 {
|
|
return 0
|
|
}
|
|
|
|
type responseHeaders struct {
|
|
date time.Time // origination date and time of response
|
|
cacheControl map[string]string // response cache-control header
|
|
maxAge deltaSeconds // max-age cache control directive
|
|
expires time.Time // date/time after which the response is considered stale
|
|
etag string // identifier for a specific version of the response
|
|
lastModified string // date and time response was last modified as per origin server
|
|
}
|
|
|
|
// deltaSeconds is a time span in whole seconds, defined by RFC 7234 as a
// non-negative integer (http://tools.ietf.org/html/rfc7234#section-1.2.1).
// parseMaxAgeCacheDirective also uses -1 to mean "no max-age directive".
type deltaSeconds int32
|
|
|
|
func parseResponseHeaders(headers http.Header) (*responseHeaders, error) {
|
|
var err error
|
|
result := responseHeaders{}
|
|
|
|
result.date, err = getResponseHeaderDate(headers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result.cacheControl = parseCacheControlHeader(headers)
|
|
result.maxAge, err = parseMaxAgeCacheDirective(result.cacheControl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result.expires = getResponseHeaderExpires(headers)
|
|
|
|
result.etag = headers.Get("etag")
|
|
|
|
result.lastModified = headers.Get("last-modified")
|
|
|
|
return &result, nil
|
|
}
|
|
|
|
func revalidateCachedResponse(req *http.Request, client *http.Client, headers *responseHeaders) (*http.Response, bool, error) {
|
|
etag := headers.etag
|
|
lastModified := headers.lastModified
|
|
|
|
if etag == "" && lastModified == "" {
|
|
return nil, false, nil
|
|
}
|
|
|
|
cloneReq := req.Clone(req.Context())
|
|
|
|
if etag != "" {
|
|
cloneReq.Header.Set("if-none-match", etag)
|
|
}
|
|
|
|
if lastModified != "" {
|
|
cloneReq.Header.Set("if-modified-since", lastModified)
|
|
}
|
|
|
|
response, err := client.Do(cloneReq)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
|
|
switch response.StatusCode {
|
|
case http.StatusOK:
|
|
return response, true, nil
|
|
|
|
case http.StatusNotModified:
|
|
return response, false, nil
|
|
}
|
|
util.Close(response)
|
|
return nil, false, nil
|
|
}
|
|
|
|
func canStore(headers http.Header) bool {
|
|
ccHeaders := parseCacheControlHeader(headers)
|
|
|
|
// Check "no-store" cache directive
|
|
// The "no-store" response directive indicates that a cache MUST NOT
|
|
// store any part of either the immediate request or response.
|
|
if _, ok := ccHeaders["no-store"]; ok {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func getCurrentTime(bctx BuiltinContext) time.Time {
|
|
var current time.Time
|
|
|
|
value, err := ast.JSON(bctx.Time.Value)
|
|
if err != nil {
|
|
return current
|
|
}
|
|
|
|
valueNum, ok := value.(json.Number)
|
|
if !ok {
|
|
return current
|
|
}
|
|
|
|
valueNumInt, err := valueNum.Int64()
|
|
if err != nil {
|
|
return current
|
|
}
|
|
|
|
current = time.Unix(0, valueNumInt).UTC()
|
|
return current
|
|
}
|
|
|
|
// parseCacheControlHeader parses every Cache-Control field in headers into a
// map from directive name to directive argument.
//
//   - Directive names are lower-cased, because RFC 7234 §5.2 requires them to
//     be compared case-insensitively. Callers can therefore look up "no-store",
//     "max-age", etc. regardless of how the server capitalised them.
//   - Directives without an argument map to "".
//   - Whitespace around names, "=" and arguments is ignored.
//   - Arguments may use the token or the quoted-string form; the quotes are
//     removed, and commas inside a quoted-string do not split the directive.
//   - Empty list elements and elements with an empty name are skipped.
//   - When a directive appears more than once, the last occurrence wins.
//
// The returned map is never nil.
func parseCacheControlHeader(headers http.Header) map[string]string {
	ccDirectives := map[string]string{}

	// A response may carry several Cache-Control field lines. Together they
	// are equivalent to one comma-separated list (RFC 7230 §3.2.2), so all of
	// them are parsed rather than only the first.
	for _, field := range headers.Values("cache-control") {
		for _, part := range splitCacheControlDirectives(field) {
			name, arg, hasArg := part, "", false
			if i := strings.IndexByte(part, '='); i >= 0 {
				name, arg, hasArg = part[:i], part[i+1:], true
			}

			name = strings.ToLower(strings.TrimSpace(name))
			if name == "" {
				continue
			}

			if hasArg {
				ccDirectives[name] = unquoteCacheControlArg(strings.TrimSpace(arg))
			} else {
				ccDirectives[name] = ""
			}
		}
	}

	return ccDirectives
}

// splitCacheControlDirectives splits a Cache-Control field value on the commas
// that are not inside a quoted-string (honouring backslash quoted-pairs).
func splitCacheControlDirectives(s string) []string {
	var parts []string
	inQuotes, escaped := false, false
	start := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case escaped:
			escaped = false
		case inQuotes && c == '\\':
			escaped = true
		case c == '"':
			inQuotes = !inQuotes
		case c == ',' && !inQuotes:
			parts = append(parts, s[start:i])
			start = i + 1
		}
	}
	return append(parts, s[start:])
}

// unquoteCacheControlArg strips the surrounding double quotes and quoted-pair
// backslashes from a quoted-string argument; other values are returned as-is.
func unquoteCacheControlArg(s string) string {
	if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' {
		return s
	}
	inner := s[1 : len(s)-1]
	var b strings.Builder
	b.Grow(len(inner))
	for i := 0; i < len(inner); i++ {
		if inner[i] == '\\' && i+1 < len(inner) {
			i++
		}
		b.WriteByte(inner[i])
	}
	return b.String()
}
|
|
|
|
// getResponseHeaderDate parses the response's Date header with http.ParseTime.
// It returns the zero time and an error when the header is missing or cannot
// be parsed.
func getResponseHeaderDate(headers http.Header) (time.Time, error) {
	raw := headers.Get("date")
	if raw != "" {
		return http.ParseTime(raw)
	}
	return time.Time{}, fmt.Errorf("no date header")
}
|
|
|
|
// getResponseHeaderExpires returns the response's Expires header parsed with
// http.ParseTime. A missing or unparsable header yields the zero time.Time.
func getResponseHeaderExpires(headers http.Header) time.Time {
	var expires time.Time

	if raw := headers.Get("expires"); raw != "" {
		// Servers may send an invalid date such as `Expires: 0` to mark the
		// content as expired; such values fall through to the zero time.
		if parsed, err := http.ParseTime(raw); err == nil {
			expires = parsed
		}
	}

	return expires
}
|
|
|
|
// parseMaxAgeCacheDirective parses the max-age directive expressed in delta-seconds as per
|
|
// https://tools.ietf.org/html/rfc7234#section-1.2.1
|
|
func parseMaxAgeCacheDirective(cc map[string]string) (deltaSeconds, error) {
|
|
maxAge, ok := cc["max-age"]
|
|
if !ok {
|
|
return deltaSeconds(-1), nil
|
|
}
|
|
|
|
val, err := strconv.ParseUint(maxAge, 10, 32)
|
|
if err != nil {
|
|
if numError, ok := err.(*strconv.NumError); ok {
|
|
if numError.Err == strconv.ErrRange {
|
|
return deltaSeconds(math.MaxInt32), nil
|
|
}
|
|
}
|
|
return deltaSeconds(-1), err
|
|
}
|
|
|
|
if val > math.MaxInt32 {
|
|
return deltaSeconds(math.MaxInt32), nil
|
|
}
|
|
return deltaSeconds(val), nil
|
|
}
|
|
|
|
func formatHTTPResponseToAST(resp *http.Response, forceJSONDecode, forceYAMLDecode bool) (ast.Value, []byte, error) {
|
|
|
|
resultRawBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
resultObj, err := prepareASTResult(resp.Header, forceJSONDecode, forceYAMLDecode, resultRawBody, resp.Status, resp.StatusCode)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
return resultObj, resultRawBody, nil
|
|
}
|
|
|
|
func prepareASTResult(headers http.Header, forceJSONDecode, forceYAMLDecode bool, body []byte, status string, statusCode int) (ast.Value, error) {
|
|
var resultBody interface{}
|
|
|
|
// If the response body cannot be JSON/YAML decoded,
|
|
// an error will not be returned. Instead, the "body" field
|
|
// in the result will be null.
|
|
switch {
|
|
case forceJSONDecode || isContentType(headers, "application/json"):
|
|
_ = util.UnmarshalJSON(body, &resultBody)
|
|
case forceYAMLDecode || isContentType(headers, "application/yaml", "application/x-yaml"):
|
|
_ = util.Unmarshal(body, &resultBody)
|
|
}
|
|
|
|
result := make(map[string]interface{})
|
|
result["status"] = status
|
|
result["status_code"] = statusCode
|
|
result["body"] = resultBody
|
|
result["raw_body"] = string(body)
|
|
result["headers"] = getResponseHeaders(headers)
|
|
|
|
resultObj, err := ast.InterfaceToValue(result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return resultObj, nil
|
|
}
|
|
|
|
// getResponseHeaders converts headers into a map keyed by the lower-cased
// header name, whose values are the header's values as []interface{}. A name
// with no values maps to a nil []interface{}.
func getResponseHeaders(headers http.Header) map[string]interface{} {
	out := make(map[string]interface{}, len(headers))
	for name := range headers {
		var converted []interface{}
		for _, v := range headers[name] {
			converted = append(converted, v)
		}
		out[strings.ToLower(name)] = converted
	}
	return out
}
|
|
|
|
// httpRequestExecutor defines an interface for the http send cache
|
|
type httpRequestExecutor interface {
|
|
CheckCache() (ast.Value, error)
|
|
InsertIntoCache(value *http.Response) (ast.Value, error)
|
|
ExecuteHTTPRequest() (*http.Response, error)
|
|
}
|
|
|
|
// newHTTPRequestExecutor returns a new HTTP request executor that wraps either an inter-query or
|
|
// intra-query cache implementation
|
|
func newHTTPRequestExecutor(bctx BuiltinContext, key ast.Object) (httpRequestExecutor, error) {
|
|
useInterQueryCache, forceCacheParams, err := useInterQueryCache(key)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(bctx, err)
|
|
}
|
|
|
|
if useInterQueryCache && bctx.InterQueryBuiltinCache != nil {
|
|
return newInterQueryCache(bctx, key, forceCacheParams)
|
|
}
|
|
return newIntraQueryCache(bctx, key)
|
|
}
|
|
|
|
type interQueryCache struct {
|
|
bctx BuiltinContext
|
|
key ast.Object
|
|
httpReq *http.Request
|
|
httpClient *http.Client
|
|
forceJSONDecode bool
|
|
forceYAMLDecode bool
|
|
forceCacheParams *forceCacheParams
|
|
}
|
|
|
|
func newInterQueryCache(bctx BuiltinContext, key ast.Object, forceCacheParams *forceCacheParams) (*interQueryCache, error) {
|
|
return &interQueryCache{bctx: bctx, key: key, forceCacheParams: forceCacheParams}, nil
|
|
}
|
|
|
|
// CheckCache checks the cache for the value of the key set on this object
|
|
func (c *interQueryCache) CheckCache() (ast.Value, error) {
|
|
var err error
|
|
|
|
c.forceJSONDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode"))
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
c.forceYAMLDecode, err = getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode"))
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
resp, err := c.checkHTTPSendInterQueryCache()
|
|
|
|
// fallback to the http send cache if response not found in the inter-query cache or inter-query cache look-up results
|
|
// in an error
|
|
if resp == nil || err != nil {
|
|
return checkHTTPSendCache(c.bctx, c.key), nil
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
// InsertIntoCache inserts the key set on this object into the cache with the given value
|
|
func (c *interQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) {
|
|
result, respBody, err := formatHTTPResponseToAST(value, c.forceJSONDecode, c.forceYAMLDecode)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
// fallback to the http send cache if error encountered while inserting response in inter-query cache
|
|
err = insertIntoHTTPSendInterQueryCache(c.bctx, c.key, value, respBody, c.forceCacheParams)
|
|
if err != nil {
|
|
insertIntoHTTPSendCache(c.bctx, c.key, result)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// ExecuteHTTPRequest executes a HTTP request
|
|
func (c *interQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
|
|
var err error
|
|
c.httpReq, c.httpClient, err = createHTTPRequest(c.bctx, c.key)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
return executeHTTPRequest(c.httpReq, c.httpClient)
|
|
}
|
|
|
|
type intraQueryCache struct {
|
|
bctx BuiltinContext
|
|
key ast.Object
|
|
}
|
|
|
|
func newIntraQueryCache(bctx BuiltinContext, key ast.Object) (*intraQueryCache, error) {
|
|
return &intraQueryCache{bctx: bctx, key: key}, nil
|
|
}
|
|
|
|
// CheckCache checks the cache for the value of the key set on this object
|
|
func (c *intraQueryCache) CheckCache() (ast.Value, error) {
|
|
return checkHTTPSendCache(c.bctx, c.key), nil
|
|
}
|
|
|
|
// InsertIntoCache inserts the key set on this object into the cache with the given value
|
|
func (c *intraQueryCache) InsertIntoCache(value *http.Response) (ast.Value, error) {
|
|
forceJSONDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_json_decode"))
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
forceYAMLDecode, err := getBoolValFromReqObj(c.key, ast.StringTerm("force_yaml_decode"))
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
result, _, err := formatHTTPResponseToAST(value, forceJSONDecode, forceYAMLDecode)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
|
|
if cacheableCodes.Contains(ast.IntNumberTerm(value.StatusCode)) {
|
|
insertIntoHTTPSendCache(c.bctx, c.key, result)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ExecuteHTTPRequest executes a HTTP request
|
|
func (c *intraQueryCache) ExecuteHTTPRequest() (*http.Response, error) {
|
|
httpReq, httpClient, err := createHTTPRequest(c.bctx, c.key)
|
|
if err != nil {
|
|
return nil, handleHTTPSendErr(c.bctx, err)
|
|
}
|
|
return executeHTTPRequest(httpReq, httpClient)
|
|
}
|
|
|
|
func useInterQueryCache(req ast.Object) (bool, *forceCacheParams, error) {
|
|
value, err := getBoolValFromReqObj(req, ast.StringTerm("cache"))
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
|
|
valueForceCache, err := getBoolValFromReqObj(req, ast.StringTerm("force_cache"))
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
|
|
if valueForceCache {
|
|
forceCacheParams, err := newForceCacheParams(req)
|
|
return true, forceCacheParams, err
|
|
}
|
|
|
|
return value, nil, nil
|
|
}
|
|
|
|
// forceCacheParams holds the settings used when http.send is called with
// "force_cache": true.
type forceCacheParams struct {
	// forceCacheDurationSeconds is parsed from the request's
	// "force_cache_duration_seconds" field by newForceCacheParams.
	forceCacheDurationSeconds int32
}
|
|
|
|
func newForceCacheParams(req ast.Object) (*forceCacheParams, error) {
|
|
term := req.Get(ast.StringTerm("force_cache_duration_seconds"))
|
|
if term == nil {
|
|
return nil, fmt.Errorf("'force_cache' set but 'force_cache_duration_seconds' parameter is missing")
|
|
}
|
|
|
|
forceCacheDurationSeconds := term.String()
|
|
|
|
value, err := strconv.ParseInt(forceCacheDurationSeconds, 10, 32)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &forceCacheParams{forceCacheDurationSeconds: int32(value)}, nil
|
|
}
|
|
|
|
func getRaiseErrorValue(req ast.Object) (bool, error) {
|
|
result := ast.Boolean(true)
|
|
var ok bool
|
|
if v := req.Get(ast.StringTerm("raise_error")); v != nil {
|
|
if result, ok = v.Value.(ast.Boolean); !ok {
|
|
return false, fmt.Errorf("invalid value for raise_error field")
|
|
}
|
|
}
|
|
return bool(result), nil
|
|
}
|